xref: /openbmc/sdbusplus/include/sdbusplus/async/stdexec/__detail/__when_all.hpp (revision 10d0b4b7d1498cfd5c3d37edea271a54d1984e41)
/*
 * Copyright (c) 2021-2024 NVIDIA Corporation
 *
 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *   https://llvm.org/LICENSE.txt
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #pragma once
17 
18 #include "__execution_fwd.hpp"
19 
20 // include these after __execution_fwd.hpp
21 #include "__basic_sender.hpp"
22 #include "__concepts.hpp"
23 #include "__continues_on.hpp"
24 #include "__diagnostics.hpp"
25 #include "__domain.hpp"
26 #include "__env.hpp"
27 #include "__into_variant.hpp"
28 #include "__meta.hpp"
29 #include "__optional.hpp"
30 #include "__schedulers.hpp"
31 #include "__senders.hpp"
32 #include "__transform_completion_signatures.hpp"
33 #include "__transform_sender.hpp"
34 #include "__tuple.hpp"
35 #include "__type_traits.hpp"
36 #include "__utility.hpp"
37 #include "__variant.hpp"
38 
39 #include "../stop_token.hpp"
40 
41 #include <atomic>
42 #include <exception>
43 
44 namespace stdexec {
45   /////////////////////////////////////////////////////////////////////////////
46   // [execution.senders.adaptors.when_all]
47   // [execution.senders.adaptors.when_all_with_variant]
48   namespace __when_all {
/// Overall status of a when_all operation, held atomically in
/// __when_all_state::__state_.
///
/// A child's error always moves the state to __error (and nothing moves it
/// back). A child's stopped completion moves it to __stopped only when the
/// state is still __started.
enum __state_t {
  __started, // no child has reported an error or a stop (yet)
  __error,   // at least one child completed with an error
  __stopped  // a child completed with set_stopped before any error was seen
};
54 
55     struct __on_stop_request {
56       inplace_stop_source& __stop_source_;
57 
operator ()stdexec::__when_all::__on_stop_request58       void operator()() noexcept {
59         __stop_source_.request_stop();
60       }
61     };
62 
63     template <class _Env>
__mkenv(_Env && __env,const inplace_stop_source & __stop_source)64     auto __mkenv(_Env&& __env, const inplace_stop_source& __stop_source) noexcept {
65       return __env::__join(
66         prop{get_stop_token, __stop_source.get_token()}, static_cast<_Env&&>(__env));
67     }
68 
69     template <class _Env>
70     using __env_t =
71       decltype(__when_all::__mkenv(__declval<_Env>(), __declval<inplace_stop_source&>()));
72 
73     template <class _Sender, class _Env>
74     concept __max1_sender =
75       sender_in<_Sender, _Env>
76       && __mvalid<__value_types_of_t, _Sender, _Env, __mconst<int>, __msingle_or<void>>;
77 
78     template <
79       __mstring _Context = "In stdexec::when_all()..."_mstr,
80       __mstring _Diagnostic =
81         "The given sender can complete successfully in more that one way. "
82         "Use stdexec::when_all_with_variant() instead."_mstr
83     >
84     struct _INVALID_WHEN_ALL_ARGUMENT_;
85 
86     template <class _Sender, class... _Env>
87     using __too_many_value_completions_error = __mexception<
88       _INVALID_WHEN_ALL_ARGUMENT_<>,
89       _WITH_SENDER_<_Sender>,
90       _WITH_ENVIRONMENT_<_Env>...
91     >;
92 
93     template <class... _Args>
94     using __all_nothrow_decay_copyable = __mbool<(__nothrow_decay_copyable<_Args> && ...)>;
95 
96     template <class _Error>
97     using __set_error_t = completion_signatures<set_error_t(__decay_t<_Error>)>;
98 
99     template <class _Sender, class... _Env>
100     using __nothrow_decay_copyable_results = __for_each_completion_signature<
101       __completion_signatures_of_t<_Sender, _Env...>,
102       __all_nothrow_decay_copyable,
103       __mand_t
104     >;
105 
106     template <class... _Env>
107     struct __completions_t {
108       template <class... _Senders>
109       using __all_nothrow_decay_copyable_results =
110         __mand<__nothrow_decay_copyable_results<_Senders, _Env...>...>;
111 
112       template <class _Sender, class _ValueTuple, class... _Rest>
113       using __value_tuple_t = __minvoke<
114         __if_c<
115           (0 == sizeof...(_Rest)),
116           __mconst<_ValueTuple>,
117           __q<__too_many_value_completions_error>
118         >,
119         _Sender,
120         _Env...
121       >;
122 
123       template <class _Sender>
124       using __single_values_of_t = __value_types_t<
125         __completion_signatures_of_t<_Sender, _Env...>,
126         __mtransform<__q<__decay_t>, __q<__types>>,
127         __mbind_front_q<__value_tuple_t, _Sender>
128       >;
129 
130       template <class... _Senders>
131       using __set_values_sig_t = __meval<
132         completion_signatures,
133         __minvoke<__mconcat<__qf<set_value_t>>, __single_values_of_t<_Senders>...>
134       >;
135 
136       template <class... _Senders>
137       using __f = __meval<
138         __concat_completion_signatures,
139         __meval<__eptr_completion_if_t, __all_nothrow_decay_copyable_results<_Senders...>>,
140         __minvoke<__with_default<__qq<__set_values_sig_t>, completion_signatures<>>, _Senders...>,
141         __transform_completion_signatures<
142           __completion_signatures_of_t<_Senders, _Env...>,
143           __mconst<completion_signatures<>>::__f,
144           __set_error_t,
145           completion_signatures<set_stopped_t()>,
146           __concat_completion_signatures
147         >...
148       >;
149     };
150 
151     template <class _Receiver, class _ValuesTuple>
__set_values(_Receiver & __rcvr,_ValuesTuple & __values)152     void __set_values(_Receiver& __rcvr, _ValuesTuple& __values) noexcept {
153       __values.apply(
154         [&]<class... OptTuples>(OptTuples&&... __opt_vals) noexcept -> void {
155           __tup::__cat_apply(
156             __mk_completion_fn(set_value, __rcvr), *static_cast<OptTuples&&>(__opt_vals)...);
157         },
158         static_cast<_ValuesTuple&&>(__values));
159     }
160 
161     template <class _Env, class _Sender>
162     using __values_opt_tuple_t =
163       value_types_of_t<_Sender, __env_t<_Env>, __decayed_tuple, __optional>;
164 
165     template <class _Env, __max1_sender<__env_t<_Env>>... _Senders>
166     struct __traits {
167       // tuple<optional<tuple<Vs1...>>, optional<tuple<Vs2...>>, ...>
168       using __values_tuple = __minvoke<
169         __with_default<
170           __mtransform<__mbind_front_q<__values_opt_tuple_t, _Env>, __q<__tuple_for>>,
171           __ignore
172         >,
173         _Senders...
174       >;
175 
176       using __collect_errors = __mbind_front_q<__mset_insert, __mset<>>;
177 
178       using __errors_list = __minvoke<
179         __mconcat<>,
180         __if<
181           __mand<__nothrow_decay_copyable_results<_Senders, _Env>...>,
182           __types<>,
183           __types<std::exception_ptr>
184         >,
185         __error_types_of_t<_Senders, __env_t<_Env>, __q<__types>>...
186       >;
187 
188       using __errors_variant = __mapply<__q<__uniqued_variant_for>, __errors_list>;
189     };
190 
/// Diagnostic tag used when when_all's completion signatures cannot be computed.
struct _INVALID_ARGUMENTS_TO_WHEN_ALL_ { };
192 
193     template <class _ErrorsVariant, class _ValuesTuple, class _StopToken, bool _SendsStopped>
194     struct __when_all_state {
195       using __stop_callback_t = stop_callback_for_t<_StopToken, __on_stop_request>;
196 
197       template <class _Receiver>
__arrivestdexec::__when_all::__when_all_state198       void __arrive(_Receiver& __rcvr) noexcept {
199         if (1 == __count_.fetch_sub(1)) {
200           __complete(__rcvr);
201         }
202       }
203 
204       template <class _Receiver>
__completestdexec::__when_all::__when_all_state205       void __complete(_Receiver& __rcvr) noexcept {
206         // Stop callback is no longer needed. Destroy it.
207         __on_stop_.reset();
208         // All child operations have completed and arrived at the barrier.
209         switch (__state_.load(std::memory_order_relaxed)) {
210         case __started:
211           if constexpr (!same_as<_ValuesTuple, __ignore>) {
212             // All child operations completed successfully:
213             __when_all::__set_values(__rcvr, __values_);
214           }
215           break;
216         case __error:
217           if constexpr (!__same_as<_ErrorsVariant, __variant_for<>>) {
218             // One or more child operations completed with an error:
219             __errors_.visit(
220               __mk_completion_fn(set_error, __rcvr), static_cast<_ErrorsVariant&&>(__errors_));
221           }
222           break;
223         case __stopped:
224           if constexpr (_SendsStopped) {
225             stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr));
226           } else {
227             STDEXEC_UNREACHABLE();
228           }
229           break;
230         default:;
231         }
232       }
233 
234       std::atomic<std::size_t> __count_;
235       inplace_stop_source __stop_source_{};
236       // Could be non-atomic here and atomic_ref everywhere except __completion_fn
237       std::atomic<__state_t> __state_{__started};
238       _ErrorsVariant __errors_{};
STDEXEC_ATTRIBUTEstdexec::__when_all::__when_all_state239       STDEXEC_ATTRIBUTE(no_unique_address) _ValuesTuple __values_ { };
240       __optional<__stop_callback_t> __on_stop_{};
241     };
242 
243     template <class _Env>
__mk_state_fn(const _Env &)244     static auto __mk_state_fn(const _Env&) noexcept {
245       return []<__max1_sender<__env_t<_Env>>... _Child>(__ignore, __ignore, _Child&&...) {
246         using _Traits = __traits<_Env, _Child...>;
247         using _ErrorsVariant = _Traits::__errors_variant;
248         using _ValuesTuple = _Traits::__values_tuple;
249         using _State = __when_all_state<
250           _ErrorsVariant,
251           _ValuesTuple,
252           stop_token_of_t<_Env>,
253           (sends_stopped<_Child, _Env> || ...)>;
254         return _State{sizeof...(_Child)};
255       };
256     }
257 
258     template <class _Env>
259     using __mk_state_fn_t = decltype(__when_all::__mk_state_fn(__declval<_Env>()));
260 
261     struct when_all_t {
262       template <sender... _Senders>
263         requires __has_common_domain<_Senders...>
operator ()stdexec::__when_all::when_all_t264       auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto {
265         auto __domain = __common_domain_t<_Senders...>();
266         return stdexec::transform_sender(
267           __domain, __make_sexpr<when_all_t>(__(), static_cast<_Senders&&>(__sndrs)...));
268       }
269     };
270 
271     struct __when_all_impl : __sexpr_defaults {
272       template <class _Self, class _Env>
273       using __error_t = __mexception<
274         _INVALID_ARGUMENTS_TO_WHEN_ALL_,
275         __children_of<_Self, __q<_WITH_SENDERS_>>,
276         _WITH_ENVIRONMENT_<_Env>
277       >;
278 
279       template <class _Self, class... _Env>
280       using __completions = __children_of<_Self, __completions_t<__env_t<_Env>...>>;
281 
282       static constexpr auto get_attrs = []<class... _Child>(__ignore, const _Child&...) noexcept {
283         using _Domain = __common_domain_t<_Child...>;
284         if constexpr (__same_as<_Domain, default_domain>) {
285           return env();
286         } else {
287           return prop{get_domain, _Domain()};
288         }
289       };
290 
291       static constexpr auto get_completion_signatures =
292         []<class _Self, class... _Env>(_Self&&, _Env&&...) noexcept {
293           static_assert(sender_expr_for<_Self, when_all_t>);
294           return __minvoke<__mtry_catch<__q<__completions>, __q<__error_t>>, _Self, _Env...>();
295         };
296 
297       static constexpr auto get_env =
298         []<class _State, class _Receiver>(
299           __ignore,
300           _State& __state,
301           const _Receiver& __rcvr) noexcept -> __env_t<env_of_t<const _Receiver&>> {
302         return __mkenv(stdexec::get_env(__rcvr), __state.__stop_source_);
303       };
304 
305       static constexpr auto get_state =
306         []<class _Self, class _Receiver>(_Self&& __self, _Receiver& __rcvr)
307         -> __sexpr_apply_result_t<_Self, __mk_state_fn_t<env_of_t<_Receiver>>> {
308         return __sexpr_apply(
309           static_cast<_Self&&>(__self), __when_all::__mk_state_fn(stdexec::get_env(__rcvr)));
310       };
311 
312       static constexpr auto start = []<class _State, class _Receiver, class... _Operations>(
313                                       _State& __state,
314                                       _Receiver& __rcvr,
315                                       _Operations&... __child_ops) noexcept -> void {
316         // register stop callback:
317         __state.__on_stop_.emplace(
318           get_stop_token(stdexec::get_env(__rcvr)), __on_stop_request{__state.__stop_source_});
319         (stdexec::start(__child_ops), ...);
320         if constexpr (sizeof...(__child_ops) == 0) {
321           __state.__complete(__rcvr);
322         }
323       };
324 
325       template <class _State, class _Receiver, class _Error>
__set_errorstdexec::__when_all::__when_all_impl326       static void __set_error(_State& __state, _Receiver&, _Error&& __err) noexcept {
327         // Transition to the "error" state and switch on the prior state.
328         // TODO: What memory orderings are actually needed here?
329         switch (__state.__state_.exchange(__error)) {
330         case __started:
331           // We must request stop. When the previous state is __error or __stopped, then stop has
332           // already been requested.
333           __state.__stop_source_.request_stop();
334           [[fallthrough]];
335         case __stopped:
336           // We are the first child to complete with an error, so we must save the error. (Any
337           // subsequent errors are ignored.)
338           if constexpr (__nothrow_decay_copyable<_Error>) {
339             __state.__errors_.template emplace<__decay_t<_Error>>(static_cast<_Error&&>(__err));
340           } else {
341             STDEXEC_TRY {
342               __state.__errors_.template emplace<__decay_t<_Error>>(static_cast<_Error&&>(__err));
343             }
344             STDEXEC_CATCH_ALL {
345               __state.__errors_.template emplace<std::exception_ptr>(std::current_exception());
346             }
347           }
348           break;
349         case __error:; // We're already in the "error" state. Ignore the error.
350         }
351       }
352 
353       static constexpr auto complete =
354         []<class _Index, class _State, class _Receiver, class _Set, class... _Args>(
355           _Index,
356           _State& __state,
357           _Receiver& __rcvr,
358           _Set,
359           _Args&&... __args) noexcept -> void {
360         using _ValuesTuple = decltype(_State::__values_);
361         if constexpr (__same_as<_Set, set_error_t>) {
362           __set_error(__state, __rcvr, static_cast<_Args&&>(__args)...);
363         } else if constexpr (__same_as<_Set, set_stopped_t>) {
364           __state_t __expected = __started;
365           // Transition to the "stopped" state if and only if we're in the
366           // "started" state. (If this fails, it's because we're in an
367           // error state, which trumps cancellation.)
368           if (__state.__state_.compare_exchange_strong(__expected, __stopped)) {
369             __state.__stop_source_.request_stop();
370           }
371         } else if constexpr (!__same_as<_ValuesTuple, __ignore>) {
372           // We only need to bother recording the completion values
373           // if we're not already in the "error" or "stopped" state.
374           if (__state.__state_.load() == __started) {
375             auto& __opt_values = _ValuesTuple::template __get<__v<_Index>>(__state.__values_);
376             using _Tuple = __decayed_tuple<_Args...>;
377             static_assert(
378               __same_as<decltype(*__opt_values), _Tuple&>,
379               "One of the senders in this when_all() is fibbing about what types it sends");
380             if constexpr ((__nothrow_decay_copyable<_Args> && ...)) {
381               __opt_values.emplace(_Tuple{static_cast<_Args&&>(__args)...});
382             } else {
383               STDEXEC_TRY {
384                 __opt_values.emplace(_Tuple{static_cast<_Args&&>(__args)...});
385               }
386               STDEXEC_CATCH_ALL {
387                 __set_error(__state, __rcvr, std::current_exception());
388               }
389             }
390           }
391         }
392 
393         __state.__arrive(__rcvr);
394       };
395     };
396 
397     struct when_all_with_variant_t {
398       template <sender... _Senders>
399         requires __has_common_domain<_Senders...>
operator ()stdexec::__when_all::when_all_with_variant_t400       auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto {
401         auto __domain = __common_domain_t<_Senders...>();
402         return stdexec::transform_sender(
403           __domain,
404           __make_sexpr<when_all_with_variant_t>(__(), static_cast<_Senders&&>(__sndrs)...));
405       }
406 
407       template <class _Sender, class _Env>
transform_senderstdexec::__when_all::when_all_with_variant_t408       static auto transform_sender(_Sender&& __sndr, const _Env&) {
409         // transform the when_all_with_variant into a regular when_all (looking for
410         // early when_all customizations), then transform it again to look for
411         // late customizations.
412         return __sexpr_apply(
413           static_cast<_Sender&&>(__sndr),
414           [&]<class... _Child>(__ignore, __ignore, _Child&&... __child) {
415             return when_all_t()(into_variant(static_cast<_Child&&>(__child))...);
416           });
417       }
418     };
419 
420     struct __when_all_with_variant_impl : __sexpr_defaults {
421       static constexpr auto get_attrs = []<class... _Child>(__ignore, const _Child&...) noexcept {
422         using _Domain = __common_domain_t<_Child...>;
423         if constexpr (same_as<_Domain, default_domain>) {
424           return env();
425         } else {
426           return prop{get_domain, _Domain()};
427         }
428       };
429 
430       static constexpr auto get_completion_signatures = []<class _Sender>(_Sender&&) noexcept
431         -> __completion_signatures_of_t<transform_sender_result_t<default_domain, _Sender, env<>>> {
432         return {};
433       };
434     };
435 
436     struct transfer_when_all_t {
437       template <scheduler _Scheduler, sender... _Senders>
438         requires __has_common_domain<_Senders...>
439       auto
operator ()stdexec::__when_all::transfer_when_all_t440         operator()(_Scheduler __sched, _Senders&&... __sndrs) const -> __well_formed_sender auto {
441         auto __domain = query_or(get_domain, __sched, default_domain());
442         return stdexec::transform_sender(
443           __domain,
444           __make_sexpr<transfer_when_all_t>(
445             static_cast<_Scheduler&&>(__sched), static_cast<_Senders&&>(__sndrs)...));
446       }
447 
448       template <class _Sender, class _Env>
transform_senderstdexec::__when_all::transfer_when_all_t449       static auto transform_sender(_Sender&& __sndr, const _Env&) {
450         // transform the transfer_when_all into a regular transform | when_all
451         // (looking for early customizations), then transform it again to look for
452         // late customizations.
453         return __sexpr_apply(
454           static_cast<_Sender&&>(__sndr),
455           [&]<class _Data, class... _Child>(__ignore, _Data&& __data, _Child&&... __child) {
456             return continues_on(
457               when_all_t()(static_cast<_Child&&>(__child)...), static_cast<_Data&&>(__data));
458           });
459       }
460     };
461 
462     struct __transfer_when_all_impl : __sexpr_defaults {
463       static constexpr auto get_attrs = []<class _Scheduler, class... _Child>(
464                                           const _Scheduler& __sched,
465                                           const _Child&...) noexcept {
466         using __sndr_t = __call_result_t<when_all_t, _Child...>;
467         using __domain_t = __detail::__early_domain_of_t<__sndr_t, __none_such>;
468         return __sched_attrs{std::cref(__sched), __domain_t{}};
469       };
470 
471       static constexpr auto get_completion_signatures = []<class _Sender>(_Sender&&) noexcept
472         -> __completion_signatures_of_t<transform_sender_result_t<default_domain, _Sender, env<>>> {
473         return {};
474       };
475     };
476 
477     struct transfer_when_all_with_variant_t {
478       template <scheduler _Scheduler, sender... _Senders>
479         requires __has_common_domain<_Senders...>
480       auto
operator ()stdexec::__when_all::transfer_when_all_with_variant_t481         operator()(_Scheduler&& __sched, _Senders&&... __sndrs) const -> __well_formed_sender auto {
482         auto __domain = query_or(get_domain, __sched, default_domain());
483         return stdexec::transform_sender(
484           __domain,
485           __make_sexpr<transfer_when_all_with_variant_t>(
486             static_cast<_Scheduler&&>(__sched), static_cast<_Senders&&>(__sndrs)...));
487       }
488 
489       template <class _Sender, class _Env>
transform_senderstdexec::__when_all::transfer_when_all_with_variant_t490       static auto transform_sender(_Sender&& __sndr, const _Env&) {
491         // transform the transfer_when_all_with_variant into regular transform_when_all
492         // and into_variant calls/ (looking for early customizations), then transform it
493         // again to look for late customizations.
494         return __sexpr_apply(
495           static_cast<_Sender&&>(__sndr),
496           [&]<class _Data, class... _Child>(__ignore, _Data&& __data, _Child&&... __child) {
497             return transfer_when_all_t()(
498               static_cast<_Data&&>(__data), into_variant(static_cast<_Child&&>(__child))...);
499           });
500       }
501     };
502 
503     struct __transfer_when_all_with_variant_impl : __sexpr_defaults {
504       static constexpr auto get_attrs = []<class _Scheduler, class... _Child>(
505                                           const _Scheduler& __sched,
506                                           const _Child&...) noexcept {
507         using __sndr_t = __call_result_t<when_all_with_variant_t, _Child...>;
508         using __domain_t = __detail::__early_domain_of_t<__sndr_t, __none_such>;
509         return __sched_attrs{std::cref(__sched), __domain_t{}};
510       };
511 
512       static constexpr auto get_completion_signatures = []<class _Sender>(_Sender&&) noexcept
513         -> __completion_signatures_of_t<transform_sender_result_t<default_domain, _Sender, env<>>> {
514         return {};
515       };
516     };
517   } // namespace __when_all
518 
519   using __when_all::when_all_t;
520   inline constexpr when_all_t when_all{};
521 
522   using __when_all::when_all_with_variant_t;
523   inline constexpr when_all_with_variant_t when_all_with_variant{};
524 
525   using __when_all::transfer_when_all_t;
526   inline constexpr transfer_when_all_t transfer_when_all{};
527 
528   using __when_all::transfer_when_all_with_variant_t;
529   inline constexpr transfer_when_all_with_variant_t transfer_when_all_with_variant{};
530 
531   template <>
532   struct __sexpr_impl<when_all_t> : __when_all::__when_all_impl { };
533 
534   template <>
535   struct __sexpr_impl<when_all_with_variant_t> : __when_all::__when_all_with_variant_impl { };
536 
537   template <>
538   struct __sexpr_impl<transfer_when_all_t> : __when_all::__transfer_when_all_impl { };
539 
540   template <>
541   struct __sexpr_impl<transfer_when_all_with_variant_t>
542     : __when_all::__transfer_when_all_with_variant_impl { };
543 } // namespace stdexec
544