1 /*
2 * Copyright (c) 2021-2024 NVIDIA Corporation
3 *
4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5 * (the "License"); you may not use this file except in compliance with
6 * the License. You may obtain a copy of the License at
7 *
8 * https://llvm.org/LICENSE.txt
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #pragma once
17
18 #include "__execution_fwd.hpp"
19
20 // include these after __execution_fwd.hpp
21 #include "../stop_token.hpp"
22 #include "__basic_sender.hpp"
23 #include "__concepts.hpp"
24 #include "__continue_on.hpp"
25 #include "__diagnostics.hpp"
26 #include "__domain.hpp"
27 #include "__env.hpp"
28 #include "__into_variant.hpp"
29 #include "__meta.hpp"
30 #include "__optional.hpp"
31 #include "__schedulers.hpp"
32 #include "__senders.hpp"
33 #include "__transform_completion_signatures.hpp"
34 #include "__transform_sender.hpp"
35 #include "__tuple.hpp"
36 #include "__type_traits.hpp"
37 #include "__utility.hpp"
38 #include "__variant.hpp"
39
40 #include <atomic>
41 #include <exception>
42
43 namespace stdexec
44 {
45 /////////////////////////////////////////////////////////////////////////////
46 // [execution.senders.adaptors.when_all]
47 // [execution.senders.adaptors.when_all_with_variant]
48 namespace __when_all
49 {
/// Aggregate outcome of a when_all operation, shared by all of its children.
/// A child's set_error moves the state to __error unconditionally; a child's
/// set_stopped moves it to __stopped only from __started.
enum __state_t
{
    __started = 0, // no child has reported an error or been stopped
    __error = 1,   // some child completed with set_error
    __stopped = 2  // some child was stopped and no error was recorded first
};
56
57 struct __on_stop_request
58 {
59 inplace_stop_source& __stop_source_;
60
operator ()stdexec::__when_all::__on_stop_request61 void operator()() noexcept
62 {
63 __stop_source_.request_stop();
64 }
65 };
66
67 template <class _Env>
__mkenv(_Env && __env,const inplace_stop_source & __stop_source)68 auto __mkenv(_Env&& __env, const inplace_stop_source& __stop_source) noexcept
69 {
70 return __env::__join(prop{get_stop_token, __stop_source.get_token()},
71 static_cast<_Env&&>(__env));
72 }
73
74 template <class _Env>
75 using __env_t = //
76 decltype(__when_all::__mkenv(__declval<_Env>(),
77 __declval<inplace_stop_source&>()));
78
79 template <class _Sender, class _Env>
80 concept __max1_sender =
81 sender_in<_Sender, _Env> && __mvalid<__value_types_of_t, _Sender, _Env,
82 __mconst<int>, __msingle_or<void>>;
83
84 template <
85 __mstring _Context = "In stdexec::when_all()..."_mstr,
86 __mstring _Diagnostic =
87 "The given sender can complete successfully in more that one way. "
88 "Use stdexec::when_all_with_variant() instead."_mstr>
89 struct _INVALID_WHEN_ALL_ARGUMENT_;
90
91 template <class _Sender, class... _Env>
92 using __too_many_value_completions_error =
93 __mexception<_INVALID_WHEN_ALL_ARGUMENT_<>, _WITH_SENDER_<_Sender>,
94 _WITH_ENVIRONMENT_<_Env>...>;
95
96 template <class... _Args>
97 using __all_nothrow_decay_copyable =
98 __mbool<(__nothrow_decay_copyable<_Args> && ...)>;
99
100 template <class _Error>
101 using __set_error_t = completion_signatures<set_error_t(__decay_t<_Error>)>;
102
103 template <class _Sender, class... _Env>
104 using __nothrow_decay_copyable_results = //
105 __for_each_completion_signature<
106 __completion_signatures_of_t<_Sender, _Env...>,
107 __all_nothrow_decay_copyable, __mand_t>;
108
109 template <class... _Env>
110 struct __completions_t
111 {
112 template <class... _Senders>
113 using __all_nothrow_decay_copyable_results = //
114 __mand<__nothrow_decay_copyable_results<_Senders, _Env...>...>;
115
116 template <class _Sender, class _ValueTuple, class... _Rest>
117 using __value_tuple_t =
118 __minvoke<__if_c<(0 == sizeof...(_Rest)), __mconst<_ValueTuple>,
119 __q<__too_many_value_completions_error>>,
120 _Sender, _Env...>;
121
122 template <class _Sender>
123 using __single_values_of_t = //
124 __value_types_t<__completion_signatures_of_t<_Sender, _Env...>,
125 __mtransform<__q<__decay_t>, __q<__types>>,
126 __mbind_front_q<__value_tuple_t, _Sender>>;
127
128 template <class... _Senders>
129 using __set_values_sig_t = //
130 __meval<completion_signatures,
131 __minvoke<__mconcat<__qf<set_value_t>>,
132 __single_values_of_t<_Senders>...>>;
133
134 template <class... _Senders>
135 using __f = //
136 __meval< //
137 __concat_completion_signatures,
138 __meval<__eptr_completion_if_t,
139 __all_nothrow_decay_copyable_results<_Senders...>>,
140 completion_signatures<set_stopped_t()>,
141 __minvoke<__with_default<__qq<__set_values_sig_t>,
142 completion_signatures<>>,
143 _Senders...>,
144 __transform_completion_signatures<
145 __completion_signatures_of_t<_Senders, _Env...>,
146 __mconst<completion_signatures<>>::__f, __set_error_t,
147 completion_signatures<>, __concat_completion_signatures>...>;
148 };
149
/// Returns a callable that completes __rcvr with the completion tag _Tag,
/// moving each (lvalue) argument it is given into the completion call.
/// The callable refers to __rcvr, so it must not outlive it.
template <class _Tag, class _Receiver>
auto __complete_fn(_Tag, _Receiver& __rcvr) noexcept
{
    return [&__rcvr]<class... _Vals>(_Vals&... __vals) noexcept {
        _Tag()(static_cast<_Receiver&&>(__rcvr),
               static_cast<_Vals&&>(__vals)...);
    };
}
157
158 template <class _Receiver, class _ValuesTuple>
__set_values(_Receiver & __rcvr,_ValuesTuple & __values)159 void __set_values(_Receiver& __rcvr, _ValuesTuple& __values) noexcept
160 {
161 __values.apply(
162 [&](auto&... __opt_vals) noexcept -> void {
163 __tup::__cat_apply(__when_all::__complete_fn(set_value, __rcvr),
164 *__opt_vals...);
165 },
166 __values);
167 }
168
169 template <class _Env, class _Sender>
170 using __values_opt_tuple_t = //
171 value_types_of_t<_Sender, __env_t<_Env>, __decayed_tuple, __optional>;
172
173 template <class _Env, __max1_sender<__env_t<_Env>>... _Senders>
174 struct __traits
175 {
176 // tuple<optional<tuple<Vs1...>>, optional<tuple<Vs2...>>, ...>
177 using __values_tuple = //
178 __minvoke<__with_default<
179 __mtransform<__mbind_front_q<__values_opt_tuple_t, _Env>,
180 __q<__tuple_for>>,
181 __ignore>,
182 _Senders...>;
183
184 using __collect_errors = __mbind_front_q<__mset_insert, __mset<>>;
185
186 using __errors_list = //
187 __minvoke<
188 __mconcat<>,
189 __if<__mand<__nothrow_decay_copyable_results<_Senders, _Env>...>,
190 __types<>, __types<std::exception_ptr>>,
191 __error_types_of_t<_Senders, __env_t<_Env>, __q<__types>>...>;
192
193 using __errors_variant =
194 __mapply<__q<__uniqued_variant_for>, __errors_list>;
195 };
196
/// Tag naming the diagnostic emitted when when_all's completion signatures
/// cannot be computed for the given children and environment.
struct _INVALID_ARGUMENTS_TO_WHEN_ALL_ {};
199
200 template <class _ErrorsVariant, class _ValuesTuple, class _StopToken>
201 struct __when_all_state
202 {
203 using __stop_callback_t =
204 stop_callback_for_t<_StopToken, __on_stop_request>;
205
206 template <class _Receiver>
__arrivestdexec::__when_all::__when_all_state207 void __arrive(_Receiver& __rcvr) noexcept
208 {
209 if (0 == --__count_)
210 {
211 __complete(__rcvr);
212 }
213 }
214
215 template <class _Receiver>
__completestdexec::__when_all::__when_all_state216 void __complete(_Receiver& __rcvr) noexcept
217 {
218 // Stop callback is no longer needed. Destroy it.
219 __on_stop_.reset();
220 // All child operations have completed and arrived at the barrier.
221 switch (__state_.load(std::memory_order_relaxed))
222 {
223 case __started:
224 if constexpr (!same_as<_ValuesTuple, __ignore>)
225 {
226 // All child operations completed successfully:
227 __when_all::__set_values(__rcvr, __values_);
228 }
229 break;
230 case __error:
231 if constexpr (!__same_as<_ErrorsVariant, __variant_for<>>)
232 {
233 // One or more child operations completed with an error:
234 __errors_.visit(__complete_fn(set_error, __rcvr),
235 __errors_);
236 }
237 break;
238 case __stopped:
239 stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr));
240 break;
241 default:;
242 }
243 }
244
245 std::atomic<std::size_t> __count_;
246 inplace_stop_source __stop_source_{};
247 // Could be non-atomic here and atomic_ref everywhere except __completion_fn
248 std::atomic<__state_t> __state_{__started};
249 _ErrorsVariant __errors_{};
250 STDEXEC_ATTRIBUTE((no_unique_address))
251 _ValuesTuple __values_{};
252 __optional<__stop_callback_t> __on_stop_{};
253 };
254
255 template <class _Env>
__mk_state_fn(const _Env &)256 static auto __mk_state_fn(const _Env&) noexcept
257 {
258 return []<__max1_sender<__env_t<_Env>>... _Child>(__ignore, __ignore,
259 _Child&&...) {
260 using _Traits = __traits<_Env, _Child...>;
261 using _ErrorsVariant = typename _Traits::__errors_variant;
262 using _ValuesTuple = typename _Traits::__values_tuple;
263 using _State = __when_all_state<_ErrorsVariant, _ValuesTuple,
264 stop_token_of_t<_Env>>;
265 return _State{sizeof...(_Child)};
266 };
267 }
268
269 template <class _Env>
270 using __mk_state_fn_t = decltype(__when_all::__mk_state_fn(__declval<_Env>()));
271
272 struct when_all_t
273 {
274 // Used by the default_domain to find legacy customizations:
275 using _Sender = __1;
276 using __legacy_customizations_t = //
277 __types<tag_invoke_t(when_all_t, _Sender...)>;
278
279 template <sender... _Senders>
280 requires __domain::__has_common_domain<_Senders...>
operator ()stdexec::__when_all::when_all_t281 auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto
282 {
283 auto __domain = __domain::__common_domain_t<_Senders...>();
284 return stdexec::transform_sender(
285 __domain, __make_sexpr<when_all_t>(
286 __(), static_cast<_Senders&&>(__sndrs)...));
287 }
288 };
289
290 struct __when_all_impl : __sexpr_defaults
291 {
292 template <class _Self, class _Env>
293 using __error_t = __mexception<_INVALID_ARGUMENTS_TO_WHEN_ALL_,
294 __children_of<_Self, __q<_WITH_SENDERS_>>,
295 _WITH_ENVIRONMENT_<_Env>>;
296
297 template <class _Self, class... _Env>
298 using __completions =
299 __children_of<_Self, __completions_t<__env_t<_Env>...>>;
300
301 static constexpr auto get_attrs = //
302 []<class... _Child>(__ignore, const _Child&...) noexcept {
303 using _Domain = __domain::__common_domain_t<_Child...>;
304 if constexpr (__same_as<_Domain, default_domain>)
305 {
306 return env();
307 }
308 else
309 {
310 return prop{get_domain, _Domain()};
311 }
312 };
313
314 static constexpr auto get_completion_signatures = //
315 []<class _Self, class... _Env>(_Self&&, _Env&&...) noexcept {
316 static_assert(sender_expr_for<_Self, when_all_t>);
317 return __minvoke<__mtry_catch<__q<__completions>, __q<__error_t>>,
318 _Self, _Env...>();
319 };
320
321 static constexpr auto get_env = //
322 []<class _State, class _Receiver>(__ignore, _State& __state,
323 const _Receiver& __rcvr) noexcept //
324 -> __env_t<env_of_t<const _Receiver&>> {
325 return __mkenv(stdexec::get_env(__rcvr), __state.__stop_source_);
326 };
327
328 static constexpr auto get_state = //
329 []<class _Self, class _Receiver>(_Self&& __self, _Receiver& __rcvr)
330 -> __sexpr_apply_result_t<_Self, __mk_state_fn_t<env_of_t<_Receiver>>> {
331 return __sexpr_apply(
332 static_cast<_Self&&>(__self),
333 __when_all::__mk_state_fn(stdexec::get_env(__rcvr)));
334 };
335
336 static constexpr auto start = //
337 []<class _State, class _Receiver, class... _Operations>(
338 _State& __state, _Receiver& __rcvr,
339 _Operations&... __child_ops) noexcept -> void {
340 // register stop callback:
341 __state.__on_stop_.emplace(get_stop_token(stdexec::get_env(__rcvr)),
342 __on_stop_request{__state.__stop_source_});
343 if (__state.__stop_source_.stop_requested())
344 {
345 // Stop has already been requested. Don't bother starting
346 // the child operations.
347 stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr));
348 }
349 else
350 {
351 (stdexec::start(__child_ops), ...);
352 if constexpr (sizeof...(__child_ops) == 0)
353 {
354 __state.__complete(__rcvr);
355 }
356 }
357 };
358
359 template <class _State, class _Receiver, class _Error>
__set_errorstdexec::__when_all::__when_all_impl360 static void __set_error(_State& __state, _Receiver&,
361 _Error&& __err) noexcept
362 {
363 // TODO: What memory orderings are actually needed here?
364 if (__error != __state.__state_.exchange(__error))
365 {
366 __state.__stop_source_.request_stop();
367 // We won the race, free to write the error into the operation
368 // state without worry.
369 if constexpr (__nothrow_decay_copyable<_Error>)
370 {
371 __state.__errors_.template emplace<__decay_t<_Error>>(
372 static_cast<_Error&&>(__err));
373 }
374 else
375 {
376 try
377 {
378 __state.__errors_.template emplace<__decay_t<_Error>>(
379 static_cast<_Error&&>(__err));
380 }
381 catch (...)
382 {
383 __state.__errors_.template emplace<std::exception_ptr>(
384 std::current_exception());
385 }
386 }
387 }
388 }
389
390 static constexpr auto complete = //
391 []<class _Index, class _State, class _Receiver, class _Set,
392 class... _Args>(_Index, _State& __state, _Receiver& __rcvr, _Set,
393 _Args&&... __args) noexcept -> void {
394 if constexpr (__same_as<_Set, set_error_t>)
395 {
396 __set_error(__state, __rcvr, static_cast<_Args&&>(__args)...);
397 }
398 else if constexpr (__same_as<_Set, set_stopped_t>)
399 {
400 __state_t __expected = __started;
401 // Transition to the "stopped" state if and only if we're in the
402 // "started" state. (If this fails, it's because we're in an
403 // error state, which trumps cancellation.)
404 if (__state.__state_.compare_exchange_strong(__expected, __stopped))
405 {
406 __state.__stop_source_.request_stop();
407 }
408 }
409 else if constexpr (!__same_as<decltype(_State::__values_), __ignore>)
410 {
411 // We only need to bother recording the completion values
412 // if we're not already in the "error" or "stopped" state.
413 if (__state.__state_ == __started)
414 {
415 auto& __opt_values = __tup::get<__v<_Index>>(__state.__values_);
416 using _Tuple = __decayed_tuple<_Args...>;
417 static_assert(
418 __same_as<decltype(*__opt_values), _Tuple&>,
419 "One of the senders in this when_all() is fibbing about what types it sends");
420 if constexpr ((__nothrow_decay_copyable<_Args> && ...))
421 {
422 __opt_values.emplace(
423 _Tuple{{static_cast<_Args&&>(__args)}...});
424 }
425 else
426 {
427 try
428 {
429 __opt_values.emplace(
430 _Tuple{{static_cast<_Args&&>(__args)}...});
431 }
432 catch (...)
433 {
434 __set_error(__state, __rcvr, std::current_exception());
435 }
436 }
437 }
438 }
439
440 __state.__arrive(__rcvr);
441 };
442 };
443
444 struct when_all_with_variant_t
445 {
446 using _Sender = __1;
447 using __legacy_customizations_t = //
448 __types<tag_invoke_t(when_all_with_variant_t, _Sender...)>;
449
450 template <sender... _Senders>
451 requires __domain::__has_common_domain<_Senders...>
operator ()stdexec::__when_all::when_all_with_variant_t452 auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto
453 {
454 auto __domain = __domain::__common_domain_t<_Senders...>();
455 return stdexec::transform_sender(
456 __domain, __make_sexpr<when_all_with_variant_t>(
457 __(), static_cast<_Senders&&>(__sndrs)...));
458 }
459
460 template <class _Sender, class _Env>
transform_senderstdexec::__when_all::when_all_with_variant_t461 static auto transform_sender(_Sender&& __sndr, const _Env&)
462 {
463 // transform the when_all_with_variant into a regular when_all (looking
464 // for early when_all customizations), then transform it again to look
465 // for late customizations.
466 return __sexpr_apply(
467 static_cast<_Sender&&>(__sndr),
468 [&]<class... _Child>(__ignore, __ignore, _Child&&... __child) {
469 return when_all_t()(
470 into_variant(static_cast<_Child&&>(__child))...);
471 });
472 }
473 };
474
475 struct __when_all_with_variant_impl : __sexpr_defaults
476 {
477 static constexpr auto get_attrs = //
478 []<class... _Child>(__ignore, const _Child&...) noexcept {
479 using _Domain = __domain::__common_domain_t<_Child...>;
480 if constexpr (same_as<_Domain, default_domain>)
481 {
482 return env();
483 }
484 else
485 {
486 return prop{get_domain, _Domain()};
487 }
488 };
489
490 static constexpr auto get_completion_signatures = //
491 []<class _Sender>(_Sender&&) noexcept //
492 -> __completion_signatures_of_t< //
493 transform_sender_result_t<default_domain, _Sender, empty_env>> {
494 return {};
495 };
496 };
497
498 struct transfer_when_all_t
499 {
500 using _Env = __0;
501 using _Sender = __1;
502 using __legacy_customizations_t = //
503 __types<tag_invoke_t(
504 transfer_when_all_t,
505 get_completion_scheduler_t<set_value_t>(const _Env&), _Sender...)>;
506
507 template <scheduler _Scheduler, sender... _Senders>
508 requires __domain::__has_common_domain<_Senders...>
operator ()stdexec::__when_all::transfer_when_all_t509 auto operator()(_Scheduler&& __sched,
510 _Senders&&... __sndrs) const -> __well_formed_sender auto
511 {
512 using _Env = __t<__schfr::__environ<__id<__decay_t<_Scheduler>>>>;
513 auto __domain = query_or(get_domain, __sched, default_domain());
514 return stdexec::transform_sender(
515 __domain, __make_sexpr<transfer_when_all_t>(
516 _Env{static_cast<_Scheduler&&>(__sched)},
517 static_cast<_Senders&&>(__sndrs)...));
518 }
519
520 template <class _Sender, class _Env>
transform_senderstdexec::__when_all::transfer_when_all_t521 static auto transform_sender(_Sender&& __sndr, const _Env&)
522 {
523 // transform the transfer_when_all into a regular transform | when_all
524 // (looking for early customizations), then transform it again to look
525 // for late customizations.
526 return __sexpr_apply(
527 static_cast<_Sender&&>(__sndr),
528 [&]<class _Data, class... _Child>(__ignore, _Data&& __data,
529 _Child&&... __child) {
530 return continue_on(
531 when_all_t()(static_cast<_Child&&>(__child)...),
532 get_completion_scheduler<set_value_t>(__data));
533 });
534 }
535 };
536
537 struct __transfer_when_all_impl : __sexpr_defaults
538 {
539 static constexpr auto get_attrs = //
540 []<class _Data>(const _Data& __data,
541 const auto&...) noexcept -> const _Data& {
542 return __data;
543 };
544
545 static constexpr auto get_completion_signatures = //
546 []<class _Sender>(_Sender&&) noexcept //
547 -> __completion_signatures_of_t< //
548 transform_sender_result_t<default_domain, _Sender, empty_env>> {
549 return {};
550 };
551 };
552
553 struct transfer_when_all_with_variant_t
554 {
555 using _Env = __0;
556 using _Sender = __1;
557 using __legacy_customizations_t = //
558 __types<tag_invoke_t(
559 transfer_when_all_with_variant_t,
560 get_completion_scheduler_t<set_value_t>(const _Env&), _Sender...)>;
561
562 template <scheduler _Scheduler, sender... _Senders>
563 requires __domain::__has_common_domain<_Senders...>
operator ()stdexec::__when_all::transfer_when_all_with_variant_t564 auto operator()(_Scheduler&& __sched,
565 _Senders&&... __sndrs) const -> __well_formed_sender auto
566 {
567 using _Env = __t<__schfr::__environ<__id<__decay_t<_Scheduler>>>>;
568 auto __domain = query_or(get_domain, __sched, default_domain());
569 return stdexec::transform_sender(
570 __domain, __make_sexpr<transfer_when_all_with_variant_t>(
571 _Env{{static_cast<_Scheduler&&>(__sched)}},
572 static_cast<_Senders&&>(__sndrs)...));
573 }
574
575 template <class _Sender, class _Env>
transform_senderstdexec::__when_all::transfer_when_all_with_variant_t576 static auto transform_sender(_Sender&& __sndr, const _Env&)
577 {
578 // transform the transfer_when_all_with_variant into regular
579 // transform_when_all and into_variant calls/ (looking for early
580 // customizations), then transform it again to look for late
581 // customizations.
582 return __sexpr_apply(
583 static_cast<_Sender&&>(__sndr),
584 [&]<class _Data, class... _Child>(__ignore, _Data&& __data,
585 _Child&&... __child) {
586 return transfer_when_all_t()(
587 get_completion_scheduler<set_value_t>(
588 static_cast<_Data&&>(__data)),
589 into_variant(static_cast<_Child&&>(__child))...);
590 });
591 }
592 };
593
594 struct __transfer_when_all_with_variant_impl : __sexpr_defaults
595 {
596 static constexpr auto get_attrs = //
597 []<class _Data>(const _Data& __data,
598 const auto&...) noexcept -> const _Data& {
599 return __data;
600 };
601
602 static constexpr auto get_completion_signatures = //
603 []<class _Sender>(_Sender&&) noexcept //
604 -> __completion_signatures_of_t< //
605 transform_sender_result_t<default_domain, _Sender, empty_env>> {
606 return {};
607 };
608 };
609 } // namespace __when_all
610
611 using __when_all::when_all_t;
612 inline constexpr when_all_t when_all{};
613
614 using __when_all::when_all_with_variant_t;
615 inline constexpr when_all_with_variant_t when_all_with_variant{};
616
617 using __when_all::transfer_when_all_t;
618 inline constexpr transfer_when_all_t transfer_when_all{};
619
620 using __when_all::transfer_when_all_with_variant_t;
621 inline constexpr transfer_when_all_with_variant_t
622 transfer_when_all_with_variant{};
623
624 template <>
625 struct __sexpr_impl<when_all_t> : __when_all::__when_all_impl
626 {};
627
628 template <>
629 struct __sexpr_impl<when_all_with_variant_t> :
630 __when_all::__when_all_with_variant_impl
631 {};
632
633 template <>
634 struct __sexpr_impl<transfer_when_all_t> : __when_all::__transfer_when_all_impl
635 {};
636
637 template <>
638 struct __sexpr_impl<transfer_when_all_with_variant_t> :
639 __when_all::__transfer_when_all_with_variant_impl
640 {};
641 } // namespace stdexec
642