1 /*
2  * Copyright (c) 2021-2024 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__execution_fwd.hpp"
19 
20 // include these after __execution_fwd.hpp
21 #include "../stop_token.hpp"
22 #include "__basic_sender.hpp"
23 #include "__concepts.hpp"
24 #include "__continue_on.hpp"
25 #include "__diagnostics.hpp"
26 #include "__domain.hpp"
27 #include "__env.hpp"
28 #include "__into_variant.hpp"
29 #include "__meta.hpp"
30 #include "__optional.hpp"
31 #include "__schedulers.hpp"
32 #include "__senders.hpp"
33 #include "__transform_completion_signatures.hpp"
34 #include "__transform_sender.hpp"
35 #include "__tuple.hpp"
36 #include "__type_traits.hpp"
37 #include "__utility.hpp"
38 #include "__variant.hpp"
39 
40 #include <atomic>
41 #include <exception>
42 
43 namespace stdexec
44 {
45 /////////////////////////////////////////////////////////////////////////////
46 // [execution.senders.adaptors.when_all]
47 // [execution.senders.adaptors.when_all_with_variant]
48 namespace __when_all
49 {
// Disposition of a when_all operation. The shared state starts out as
// __started; child completion handlers may switch it to __error or __stopped,
// and __when_all_state::__complete dispatches on the final value.
enum __state_t
{
    __started, // initial value of __when_all_state::__state_
    __error,   // stored by __when_all_impl::__set_error
    __stopped  // stored when a child reports set_stopped while still __started
};
56 
// Stop-callback functor: when invoked, forwards the stop request to the
// referenced inplace_stop_source.
struct __on_stop_request
{
    // Stop source to trigger; held by reference, so it must outlive this
    // functor.
    inplace_stop_source& __stop_source_;

    void operator()() noexcept
    {
        __stop_source_.request_stop();
    }
};
66 
// Builds an environment by joining a get_stop_token property, whose value is
// the token of __stop_source, with the given environment __env.
// NOTE(review): presumably the first argument of __env::__join takes
// precedence, so get_stop_token queries see the when_all stop source's token
// -- confirm against __env::__join.
template <class _Env>
auto __mkenv(_Env&& __env, const inplace_stop_source& __stop_source) noexcept
{
    return __env::__join(prop{get_stop_token, __stop_source.get_token()},
                         static_cast<_Env&&>(__env));
}
73 
// The type returned by __mkenv when called with an _Env and an
// inplace_stop_source.
template <class _Env>
using __env_t = //
    decltype(__when_all::__mkenv(__declval<_Env>(),
                                 __declval<inplace_stop_source&>()));
78 
// Satisfied when _Sender is a sender in _Env and its value completions can be
// folded through __msingle_or<void> (each completion is first collapsed to
// int by __mconst<int>). As the name suggests, this is meant to admit only
// senders with at most one value completion.
template <class _Sender, class _Env>
concept __max1_sender =
    sender_in<_Sender, _Env> && __mvalid<__value_types_of_t, _Sender, _Env,
                                         __mconst<int>, __msingle_or<void>>;
83 
84 template <
85     __mstring _Context = "In stdexec::when_all()..."_mstr,
86     __mstring _Diagnostic =
87         "The given sender can complete successfully in more that one way. "
88         "Use stdexec::when_all_with_variant() instead."_mstr>
89 struct _INVALID_WHEN_ALL_ARGUMENT_;
90 
// Diagnostic type produced when a sender passed to when_all() has more than
// one value completion: an __mexception tagged with
// _INVALID_WHEN_ALL_ARGUMENT_ that records the offending sender and the
// environment(s) it was evaluated in.
template <class _Sender, class... _Env>
using __too_many_value_completions_error =
    __mexception<_INVALID_WHEN_ALL_ARGUMENT_<>, _WITH_SENDER_<_Sender>,
                 _WITH_ENVIRONMENT_<_Env>...>;
95 
// __mbool<true> when every type in _Args satisfies __nothrow_decay_copyable
// (vacuously true for an empty pack), __mbool<false> otherwise.
template <class... _Args>
using __all_nothrow_decay_copyable =
    __mbool<(__nothrow_decay_copyable<_Args> && ...)>;
99 
// Single-signature completion set: set_error_t carrying the decayed _Error.
template <class _Error>
using __set_error_t = completion_signatures<set_error_t(__decay_t<_Error>)>;
102 
// Applies __all_nothrow_decay_copyable to each completion signature of
// _Sender (in _Env...) and combines the per-signature answers with __mand_t.
// NOTE(review): presumably the result is true only if the arguments of every
// completion are nothrow decay-copyable -- confirm against
// __for_each_completion_signature.
template <class _Sender, class... _Env>
using __nothrow_decay_copyable_results = //
    __for_each_completion_signature<
        __completion_signatures_of_t<_Sender, _Env...>,
        __all_nothrow_decay_copyable, __mand_t>;
108 
// Computes the completion signatures of a when_all sender whose children are
// evaluated in the environment(s) _Env...; the result is the nested alias
// template __f applied to the child sender types.
template <class... _Env>
struct __completions_t
{
    // Conjunction of __nothrow_decay_copyable_results over all _Senders.
    template <class... _Senders>
    using __all_nothrow_decay_copyable_results = //
        __mand<__nothrow_decay_copyable_results<_Senders, _Env...>...>;

    // Yields _ValueTuple when it is the sender's only value completion (no
    // _Rest); otherwise yields the "too many value completions" diagnostic
    // for _Sender and _Env....
    template <class _Sender, class _ValueTuple, class... _Rest>
    using __value_tuple_t =
        __minvoke<__if_c<(0 == sizeof...(_Rest)), __mconst<_ValueTuple>,
                         __q<__too_many_value_completions_error>>,
                  _Sender, _Env...>;

    // The decayed value types of _Sender's value completion as a __types<...>
    // list, funnelled through __value_tuple_t so that more than one value
    // completion turns into the diagnostic type.
    template <class _Sender>
    using __single_values_of_t = //
        __value_types_t<__completion_signatures_of_t<_Sender, _Env...>,
                        __mtransform<__q<__decay_t>, __q<__types>>,
                        __mbind_front_q<__value_tuple_t, _Sender>>;

    // Concatenates the per-sender value lists and wraps the resulting
    // set_value_t signature in completion_signatures.
    template <class... _Senders>
    using __set_values_sig_t = //
        __meval<completion_signatures,
                __minvoke<__mconcat<__qf<set_value_t>>,
                          __single_values_of_t<_Senders>...>>;

    // The full set of completion signatures, concatenated from:
    //  - the result of __eptr_completion_if_t on the all-nothrow flag
    //    (NOTE(review): presumably contributes set_error_t(exception_ptr)
    //    only when some result is not nothrow decay-copyable -- confirm),
    //  - set_stopped_t(),
    //  - the combined set_value signature, or nothing if it cannot be formed,
    //  - for every child, its error completions with the error type decayed
    //    (value completions are mapped to an empty set here).
    template <class... _Senders>
    using __f =  //
        __meval< //
            __concat_completion_signatures,
            __meval<__eptr_completion_if_t,
                    __all_nothrow_decay_copyable_results<_Senders...>>,
            completion_signatures<set_stopped_t()>,
            __minvoke<__with_default<__qq<__set_values_sig_t>,
                                     completion_signatures<>>,
                      _Senders...>,
            __transform_completion_signatures<
                __completion_signatures_of_t<_Senders, _Env...>,
                __mconst<completion_signatures<>>::__f, __set_error_t,
                completion_signatures<>, __concat_completion_signatures>...>;
};
149 
// Returns a callable that completes __rcvr with the completion tag _Tag.
// The callable accepts its arguments as lvalues and forwards each one to the
// tag as an rvalue, alongside the receiver (also cast to an rvalue). The
// receiver is captured by reference and must outlive the returned callable.
template <class _Tag, class _Receiver>
auto __complete_fn(_Tag, _Receiver& __rcvr) noexcept
{
    return [&__rcvr]<class... _Args>(_Args&... __args) noexcept {
        _Tag()(static_cast<_Receiver&&>(__rcvr),
               static_cast<_Args&&>(__args)...);
    };
}
157 
158 template <class _Receiver, class _ValuesTuple>
__set_values(_Receiver & __rcvr,_ValuesTuple & __values)159 void __set_values(_Receiver& __rcvr, _ValuesTuple& __values) noexcept
160 {
161     __values.apply(
162         [&](auto&... __opt_vals) noexcept -> void {
163             __tup::__cat_apply(__when_all::__complete_fn(set_value, __rcvr),
164                                *__opt_vals...);
165         },
166         __values);
167 }
168 
// The value types of _Sender, evaluated in __env_t<_Env>, with each value
// completion stored as a __decayed_tuple and the result wrapped in
// __optional.
template <class _Env, class _Sender>
using __values_opt_tuple_t = //
    value_types_of_t<_Sender, __env_t<_Env>, __decayed_tuple, __optional>;
172 
173 template <class _Env, __max1_sender<__env_t<_Env>>... _Senders>
174 struct __traits
175 {
176     // tuple<optional<tuple<Vs1...>>, optional<tuple<Vs2...>>, ...>
177     using __values_tuple = //
178         __minvoke<__with_default<
179                       __mtransform<__mbind_front_q<__values_opt_tuple_t, _Env>,
180                                    __q<__tuple_for>>,
181                       __ignore>,
182                   _Senders...>;
183 
184     using __collect_errors = __mbind_front_q<__mset_insert, __mset<>>;
185 
186     using __errors_list = //
187         __minvoke<
188             __mconcat<>,
189             __if<__mand<__nothrow_decay_copyable_results<_Senders, _Env>...>,
190                  __types<>, __types<std::exception_ptr>>,
191             __error_types_of_t<_Senders, __env_t<_Env>, __q<__types>>...>;
192 
193     using __errors_variant =
194         __mapply<__q<__uniqued_variant_for>, __errors_list>;
195 };
196 
// Tag type naming the generic "invalid arguments" diagnostic for when_all();
// used as the first argument of the __mexception built by
// __when_all_impl::__error_t.
struct _INVALID_ARGUMENTS_TO_WHEN_ALL_
{};
199 
200 template <class _ErrorsVariant, class _ValuesTuple, class _StopToken>
201 struct __when_all_state
202 {
203     using __stop_callback_t =
204         stop_callback_for_t<_StopToken, __on_stop_request>;
205 
206     template <class _Receiver>
__arrivestdexec::__when_all::__when_all_state207     void __arrive(_Receiver& __rcvr) noexcept
208     {
209         if (0 == --__count_)
210         {
211             __complete(__rcvr);
212         }
213     }
214 
215     template <class _Receiver>
__completestdexec::__when_all::__when_all_state216     void __complete(_Receiver& __rcvr) noexcept
217     {
218         // Stop callback is no longer needed. Destroy it.
219         __on_stop_.reset();
220         // All child operations have completed and arrived at the barrier.
221         switch (__state_.load(std::memory_order_relaxed))
222         {
223             case __started:
224                 if constexpr (!same_as<_ValuesTuple, __ignore>)
225                 {
226                     // All child operations completed successfully:
227                     __when_all::__set_values(__rcvr, __values_);
228                 }
229                 break;
230             case __error:
231                 if constexpr (!__same_as<_ErrorsVariant, __variant_for<>>)
232                 {
233                     // One or more child operations completed with an error:
234                     __errors_.visit(__complete_fn(set_error, __rcvr),
235                                     __errors_);
236                 }
237                 break;
238             case __stopped:
239                 stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr));
240                 break;
241             default:;
242         }
243     }
244 
245     std::atomic<std::size_t> __count_;
246     inplace_stop_source __stop_source_{};
247     // Could be non-atomic here and atomic_ref everywhere except __completion_fn
248     std::atomic<__state_t> __state_{__started};
249     _ErrorsVariant __errors_{};
250     STDEXEC_ATTRIBUTE((no_unique_address))
251     _ValuesTuple __values_{};
252     __optional<__stop_callback_t> __on_stop_{};
253 };
254 
255 template <class _Env>
__mk_state_fn(const _Env &)256 static auto __mk_state_fn(const _Env&) noexcept
257 {
258     return []<__max1_sender<__env_t<_Env>>... _Child>(__ignore, __ignore,
259                                                       _Child&&...) {
260         using _Traits = __traits<_Env, _Child...>;
261         using _ErrorsVariant = typename _Traits::__errors_variant;
262         using _ValuesTuple = typename _Traits::__values_tuple;
263         using _State = __when_all_state<_ErrorsVariant, _ValuesTuple,
264                                         stop_token_of_t<_Env>>;
265         return _State{sizeof...(_Child)};
266     };
267 }
268 
// Type of the callable returned by __mk_state_fn for environment _Env.
template <class _Env>
using __mk_state_fn_t = decltype(__when_all::__mk_state_fn(__declval<_Env>()));
271 
272 struct when_all_t
273 {
274     // Used by the default_domain to find legacy customizations:
275     using _Sender = __1;
276     using __legacy_customizations_t = //
277         __types<tag_invoke_t(when_all_t, _Sender...)>;
278 
279     template <sender... _Senders>
280         requires __domain::__has_common_domain<_Senders...>
operator ()stdexec::__when_all::when_all_t281     auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto
282     {
283         auto __domain = __domain::__common_domain_t<_Senders...>();
284         return stdexec::transform_sender(
285             __domain, __make_sexpr<when_all_t>(
286                           __(), static_cast<_Senders&&>(__sndrs)...));
287     }
288 };
289 
290 struct __when_all_impl : __sexpr_defaults
291 {
292     template <class _Self, class _Env>
293     using __error_t = __mexception<_INVALID_ARGUMENTS_TO_WHEN_ALL_,
294                                    __children_of<_Self, __q<_WITH_SENDERS_>>,
295                                    _WITH_ENVIRONMENT_<_Env>>;
296 
297     template <class _Self, class... _Env>
298     using __completions =
299         __children_of<_Self, __completions_t<__env_t<_Env>...>>;
300 
301     static constexpr auto get_attrs = //
302         []<class... _Child>(__ignore, const _Child&...) noexcept {
303             using _Domain = __domain::__common_domain_t<_Child...>;
304             if constexpr (__same_as<_Domain, default_domain>)
305             {
306                 return env();
307             }
308             else
309             {
310                 return prop{get_domain, _Domain()};
311             }
312         };
313 
314     static constexpr auto get_completion_signatures = //
315         []<class _Self, class... _Env>(_Self&&, _Env&&...) noexcept {
316             static_assert(sender_expr_for<_Self, when_all_t>);
317             return __minvoke<__mtry_catch<__q<__completions>, __q<__error_t>>,
318                              _Self, _Env...>();
319         };
320 
321     static constexpr auto get_env =                                         //
322         []<class _State, class _Receiver>(__ignore, _State& __state,
323                                           const _Receiver& __rcvr) noexcept //
324         -> __env_t<env_of_t<const _Receiver&>> {
325         return __mkenv(stdexec::get_env(__rcvr), __state.__stop_source_);
326     };
327 
328     static constexpr auto get_state = //
329         []<class _Self, class _Receiver>(_Self&& __self, _Receiver& __rcvr)
330         -> __sexpr_apply_result_t<_Self, __mk_state_fn_t<env_of_t<_Receiver>>> {
331         return __sexpr_apply(
332             static_cast<_Self&&>(__self),
333             __when_all::__mk_state_fn(stdexec::get_env(__rcvr)));
334     };
335 
336     static constexpr auto start = //
337         []<class _State, class _Receiver, class... _Operations>(
338             _State& __state, _Receiver& __rcvr,
339             _Operations&... __child_ops) noexcept -> void {
340         // register stop callback:
341         __state.__on_stop_.emplace(get_stop_token(stdexec::get_env(__rcvr)),
342                                    __on_stop_request{__state.__stop_source_});
343         if (__state.__stop_source_.stop_requested())
344         {
345             // Stop has already been requested. Don't bother starting
346             // the child operations.
347             stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr));
348         }
349         else
350         {
351             (stdexec::start(__child_ops), ...);
352             if constexpr (sizeof...(__child_ops) == 0)
353             {
354                 __state.__complete(__rcvr);
355             }
356         }
357     };
358 
359     template <class _State, class _Receiver, class _Error>
__set_errorstdexec::__when_all::__when_all_impl360     static void __set_error(_State& __state, _Receiver&,
361                             _Error&& __err) noexcept
362     {
363         // TODO: What memory orderings are actually needed here?
364         if (__error != __state.__state_.exchange(__error))
365         {
366             __state.__stop_source_.request_stop();
367             // We won the race, free to write the error into the operation
368             // state without worry.
369             if constexpr (__nothrow_decay_copyable<_Error>)
370             {
371                 __state.__errors_.template emplace<__decay_t<_Error>>(
372                     static_cast<_Error&&>(__err));
373             }
374             else
375             {
376                 try
377                 {
378                     __state.__errors_.template emplace<__decay_t<_Error>>(
379                         static_cast<_Error&&>(__err));
380                 }
381                 catch (...)
382                 {
383                     __state.__errors_.template emplace<std::exception_ptr>(
384                         std::current_exception());
385                 }
386             }
387         }
388     }
389 
390     static constexpr auto complete = //
391         []<class _Index, class _State, class _Receiver, class _Set,
392            class... _Args>(_Index, _State& __state, _Receiver& __rcvr, _Set,
393                            _Args&&... __args) noexcept -> void {
394         if constexpr (__same_as<_Set, set_error_t>)
395         {
396             __set_error(__state, __rcvr, static_cast<_Args&&>(__args)...);
397         }
398         else if constexpr (__same_as<_Set, set_stopped_t>)
399         {
400             __state_t __expected = __started;
401             // Transition to the "stopped" state if and only if we're in the
402             // "started" state. (If this fails, it's because we're in an
403             // error state, which trumps cancellation.)
404             if (__state.__state_.compare_exchange_strong(__expected, __stopped))
405             {
406                 __state.__stop_source_.request_stop();
407             }
408         }
409         else if constexpr (!__same_as<decltype(_State::__values_), __ignore>)
410         {
411             // We only need to bother recording the completion values
412             // if we're not already in the "error" or "stopped" state.
413             if (__state.__state_ == __started)
414             {
415                 auto& __opt_values = __tup::get<__v<_Index>>(__state.__values_);
416                 using _Tuple = __decayed_tuple<_Args...>;
417                 static_assert(
418                     __same_as<decltype(*__opt_values), _Tuple&>,
419                     "One of the senders in this when_all() is fibbing about what types it sends");
420                 if constexpr ((__nothrow_decay_copyable<_Args> && ...))
421                 {
422                     __opt_values.emplace(
423                         _Tuple{{static_cast<_Args&&>(__args)}...});
424                 }
425                 else
426                 {
427                     try
428                     {
429                         __opt_values.emplace(
430                             _Tuple{{static_cast<_Args&&>(__args)}...});
431                     }
432                     catch (...)
433                     {
434                         __set_error(__state, __rcvr, std::current_exception());
435                     }
436                 }
437             }
438         }
439 
440         __state.__arrive(__rcvr);
441     };
442 };
443 
444 struct when_all_with_variant_t
445 {
446     using _Sender = __1;
447     using __legacy_customizations_t = //
448         __types<tag_invoke_t(when_all_with_variant_t, _Sender...)>;
449 
450     template <sender... _Senders>
451         requires __domain::__has_common_domain<_Senders...>
operator ()stdexec::__when_all::when_all_with_variant_t452     auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto
453     {
454         auto __domain = __domain::__common_domain_t<_Senders...>();
455         return stdexec::transform_sender(
456             __domain, __make_sexpr<when_all_with_variant_t>(
457                           __(), static_cast<_Senders&&>(__sndrs)...));
458     }
459 
460     template <class _Sender, class _Env>
transform_senderstdexec::__when_all::when_all_with_variant_t461     static auto transform_sender(_Sender&& __sndr, const _Env&)
462     {
463         // transform the when_all_with_variant into a regular when_all (looking
464         // for early when_all customizations), then transform it again to look
465         // for late customizations.
466         return __sexpr_apply(
467             static_cast<_Sender&&>(__sndr),
468             [&]<class... _Child>(__ignore, __ignore, _Child&&... __child) {
469                 return when_all_t()(
470                     into_variant(static_cast<_Child&&>(__child))...);
471             });
472     }
473 };
474 
475 struct __when_all_with_variant_impl : __sexpr_defaults
476 {
477     static constexpr auto get_attrs = //
478         []<class... _Child>(__ignore, const _Child&...) noexcept {
479             using _Domain = __domain::__common_domain_t<_Child...>;
480             if constexpr (same_as<_Domain, default_domain>)
481             {
482                 return env();
483             }
484             else
485             {
486                 return prop{get_domain, _Domain()};
487             }
488         };
489 
490     static constexpr auto get_completion_signatures = //
491         []<class _Sender>(_Sender&&) noexcept         //
492         -> __completion_signatures_of_t<              //
493             transform_sender_result_t<default_domain, _Sender, empty_env>> {
494         return {};
495     };
496 };
497 
498 struct transfer_when_all_t
499 {
500     using _Env = __0;
501     using _Sender = __1;
502     using __legacy_customizations_t = //
503         __types<tag_invoke_t(
504             transfer_when_all_t,
505             get_completion_scheduler_t<set_value_t>(const _Env&), _Sender...)>;
506 
507     template <scheduler _Scheduler, sender... _Senders>
508         requires __domain::__has_common_domain<_Senders...>
operator ()stdexec::__when_all::transfer_when_all_t509     auto operator()(_Scheduler&& __sched,
510                     _Senders&&... __sndrs) const -> __well_formed_sender auto
511     {
512         using _Env = __t<__schfr::__environ<__id<__decay_t<_Scheduler>>>>;
513         auto __domain = query_or(get_domain, __sched, default_domain());
514         return stdexec::transform_sender(
515             __domain, __make_sexpr<transfer_when_all_t>(
516                           _Env{static_cast<_Scheduler&&>(__sched)},
517                           static_cast<_Senders&&>(__sndrs)...));
518     }
519 
520     template <class _Sender, class _Env>
transform_senderstdexec::__when_all::transfer_when_all_t521     static auto transform_sender(_Sender&& __sndr, const _Env&)
522     {
523         // transform the transfer_when_all into a regular transform | when_all
524         // (looking for early customizations), then transform it again to look
525         // for late customizations.
526         return __sexpr_apply(
527             static_cast<_Sender&&>(__sndr),
528             [&]<class _Data, class... _Child>(__ignore, _Data&& __data,
529                                               _Child&&... __child) {
530                 return continue_on(
531                     when_all_t()(static_cast<_Child&&>(__child)...),
532                     get_completion_scheduler<set_value_t>(__data));
533             });
534     }
535 };
536 
537 struct __transfer_when_all_impl : __sexpr_defaults
538 {
539     static constexpr auto get_attrs = //
540         []<class _Data>(const _Data& __data,
541                         const auto&...) noexcept -> const _Data& {
542         return __data;
543     };
544 
545     static constexpr auto get_completion_signatures = //
546         []<class _Sender>(_Sender&&) noexcept         //
547         -> __completion_signatures_of_t<              //
548             transform_sender_result_t<default_domain, _Sender, empty_env>> {
549         return {};
550     };
551 };
552 
553 struct transfer_when_all_with_variant_t
554 {
555     using _Env = __0;
556     using _Sender = __1;
557     using __legacy_customizations_t = //
558         __types<tag_invoke_t(
559             transfer_when_all_with_variant_t,
560             get_completion_scheduler_t<set_value_t>(const _Env&), _Sender...)>;
561 
562     template <scheduler _Scheduler, sender... _Senders>
563         requires __domain::__has_common_domain<_Senders...>
operator ()stdexec::__when_all::transfer_when_all_with_variant_t564     auto operator()(_Scheduler&& __sched,
565                     _Senders&&... __sndrs) const -> __well_formed_sender auto
566     {
567         using _Env = __t<__schfr::__environ<__id<__decay_t<_Scheduler>>>>;
568         auto __domain = query_or(get_domain, __sched, default_domain());
569         return stdexec::transform_sender(
570             __domain, __make_sexpr<transfer_when_all_with_variant_t>(
571                           _Env{{static_cast<_Scheduler&&>(__sched)}},
572                           static_cast<_Senders&&>(__sndrs)...));
573     }
574 
575     template <class _Sender, class _Env>
transform_senderstdexec::__when_all::transfer_when_all_with_variant_t576     static auto transform_sender(_Sender&& __sndr, const _Env&)
577     {
578         // transform the transfer_when_all_with_variant into regular
579         // transform_when_all and into_variant calls/ (looking for early
580         // customizations), then transform it again to look for late
581         // customizations.
582         return __sexpr_apply(
583             static_cast<_Sender&&>(__sndr),
584             [&]<class _Data, class... _Child>(__ignore, _Data&& __data,
585                                               _Child&&... __child) {
586                 return transfer_when_all_t()(
587                     get_completion_scheduler<set_value_t>(
588                         static_cast<_Data&&>(__data)),
589                     into_variant(static_cast<_Child&&>(__child))...);
590             });
591     }
592 };
593 
594 struct __transfer_when_all_with_variant_impl : __sexpr_defaults
595 {
596     static constexpr auto get_attrs = //
597         []<class _Data>(const _Data& __data,
598                         const auto&...) noexcept -> const _Data& {
599         return __data;
600     };
601 
602     static constexpr auto get_completion_signatures = //
603         []<class _Sender>(_Sender&&) noexcept         //
604         -> __completion_signatures_of_t<              //
605             transform_sender_result_t<default_domain, _Sender, empty_env>> {
606         return {};
607     };
608 };
609 } // namespace __when_all
610 
// Export the four adaptor types into namespace stdexec and define one
// function object of each type.
using __when_all::when_all_t;
inline constexpr when_all_t when_all{};

using __when_all::when_all_with_variant_t;
inline constexpr when_all_with_variant_t when_all_with_variant{};

using __when_all::transfer_when_all_t;
inline constexpr transfer_when_all_t transfer_when_all{};

using __when_all::transfer_when_all_with_variant_t;
inline constexpr transfer_when_all_with_variant_t
    transfer_when_all_with_variant{};
623 
// Bind each adaptor tag to its sender-expression implementation by
// specializing __sexpr_impl.
template <>
struct __sexpr_impl<when_all_t> : __when_all::__when_all_impl
{};

template <>
struct __sexpr_impl<when_all_with_variant_t> :
    __when_all::__when_all_with_variant_impl
{};

template <>
struct __sexpr_impl<transfer_when_all_t> : __when_all::__transfer_when_all_impl
{};

template <>
struct __sexpr_impl<transfer_when_all_with_variant_t> :
    __when_all::__transfer_when_all_with_variant_impl
{};
641 } // namespace stdexec
642