1 /* 2 * Copyright (c) 2021-2024 NVIDIA Corporation 3 * 4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 5 * (the "License"); you may not use this file except in compliance with 6 * the License. You may obtain a copy of the License at 7 * 8 * https://llvm.org/LICENSE.txt 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include "__execution_fwd.hpp" 19 20 // include these after __execution_fwd.hpp 21 #include "__basic_sender.hpp" 22 #include "__concepts.hpp" 23 #include "__continues_on.hpp" 24 #include "__diagnostics.hpp" 25 #include "__domain.hpp" 26 #include "__env.hpp" 27 #include "__into_variant.hpp" 28 #include "__meta.hpp" 29 #include "__optional.hpp" 30 #include "__schedulers.hpp" 31 #include "__senders.hpp" 32 #include "__transform_completion_signatures.hpp" 33 #include "__transform_sender.hpp" 34 #include "__tuple.hpp" 35 #include "__type_traits.hpp" 36 #include "__utility.hpp" 37 #include "__variant.hpp" 38 39 #include "../stop_token.hpp" 40 41 #include <atomic> 42 #include <exception> 43 44 namespace stdexec { 45 ///////////////////////////////////////////////////////////////////////////// 46 // [execution.senders.adaptors.when_all] 47 // [execution.senders.adaptors.when_all_with_variant] 48 namespace __when_all { 49 enum __state_t { 50 __started, 51 __error, 52 __stopped 53 }; 54 55 struct __on_stop_request { 56 inplace_stop_source& __stop_source_; 57 operator ()stdexec::__when_all::__on_stop_request58 void operator()() noexcept { 59 __stop_source_.request_stop(); 60 } 61 }; 62 63 template <class _Env> __mkenv(_Env && __env,const inplace_stop_source & __stop_source)64 auto __mkenv(_Env&& __env, const inplace_stop_source& __stop_source) noexcept { 65 return __env::__join( 66 prop{get_stop_token, __stop_source.get_token()}, static_cast<_Env&&>(__env)); 67 } 68 69 template <class _Env> 70 using __env_t = 71 decltype(__when_all::__mkenv(__declval<_Env>(), __declval<inplace_stop_source&>())); 72 73 template <class _Sender, class _Env> 74 concept __max1_sender = 75 sender_in<_Sender, _Env> 76 && __mvalid<__value_types_of_t, _Sender, _Env, __mconst<int>, __msingle_or<void>>; 77 78 template < 79 __mstring _Context = "In stdexec::when_all()..."_mstr, 80 __mstring _Diagnostic = 81 "The given sender can complete successfully in more that one way. " 82 "Use stdexec::when_all_with_variant() instead."_mstr 83 > 84 struct _INVALID_WHEN_ALL_ARGUMENT_; 85 86 template <class _Sender, class... _Env> 87 using __too_many_value_completions_error = __mexception< 88 _INVALID_WHEN_ALL_ARGUMENT_<>, 89 _WITH_SENDER_<_Sender>, 90 _WITH_ENVIRONMENT_<_Env>... 91 >; 92 93 template <class... _Args> 94 using __all_nothrow_decay_copyable = __mbool<(__nothrow_decay_copyable<_Args> && ...)>; 95 96 template <class _Error> 97 using __set_error_t = completion_signatures<set_error_t(__decay_t<_Error>)>; 98 99 template <class _Sender, class... _Env> 100 using __nothrow_decay_copyable_results = __for_each_completion_signature< 101 __completion_signatures_of_t<_Sender, _Env...>, 102 __all_nothrow_decay_copyable, 103 __mand_t 104 >; 105 106 template <class... _Env> 107 struct __completions_t { 108 template <class... _Senders> 109 using __all_nothrow_decay_copyable_results = 110 __mand<__nothrow_decay_copyable_results<_Senders, _Env...>...>; 111 112 template <class _Sender, class _ValueTuple, class... _Rest> 113 using __value_tuple_t = __minvoke< 114 __if_c< 115 (0 == sizeof...(_Rest)), 116 __mconst<_ValueTuple>, 117 __q<__too_many_value_completions_error> 118 >, 119 _Sender, 120 _Env... 121 >; 122 123 template <class _Sender> 124 using __single_values_of_t = __value_types_t< 125 __completion_signatures_of_t<_Sender, _Env...>, 126 __mtransform<__q<__decay_t>, __q<__types>>, 127 __mbind_front_q<__value_tuple_t, _Sender> 128 >; 129 130 template <class... _Senders> 131 using __set_values_sig_t = __meval< 132 completion_signatures, 133 __minvoke<__mconcat<__qf<set_value_t>>, __single_values_of_t<_Senders>...> 134 >; 135 136 template <class... _Senders> 137 using __f = __meval< 138 __concat_completion_signatures, 139 __meval<__eptr_completion_if_t, __all_nothrow_decay_copyable_results<_Senders...>>, 140 __minvoke<__with_default<__qq<__set_values_sig_t>, completion_signatures<>>, _Senders...>, 141 __transform_completion_signatures< 142 __completion_signatures_of_t<_Senders, _Env...>, 143 __mconst<completion_signatures<>>::__f, 144 __set_error_t, 145 completion_signatures<set_stopped_t()>, 146 __concat_completion_signatures 147 >... 148 >; 149 }; 150 151 template <class _Receiver, class _ValuesTuple> __set_values(_Receiver & __rcvr,_ValuesTuple & __values)152 void __set_values(_Receiver& __rcvr, _ValuesTuple& __values) noexcept { 153 __values.apply( 154 [&]<class... OptTuples>(OptTuples&&... __opt_vals) noexcept -> void { 155 __tup::__cat_apply( 156 __mk_completion_fn(set_value, __rcvr), *static_cast<OptTuples&&>(__opt_vals)...); 157 }, 158 static_cast<_ValuesTuple&&>(__values)); 159 } 160 161 template <class _Env, class _Sender> 162 using __values_opt_tuple_t = 163 value_types_of_t<_Sender, __env_t<_Env>, __decayed_tuple, __optional>; 164 165 template <class _Env, __max1_sender<__env_t<_Env>>... _Senders> 166 struct __traits { 167 // tuple<optional<tuple<Vs1...>>, optional<tuple<Vs2...>>, ...> 168 using __values_tuple = __minvoke< 169 __with_default< 170 __mtransform<__mbind_front_q<__values_opt_tuple_t, _Env>, __q<__tuple_for>>, 171 __ignore 172 >, 173 _Senders... 174 >; 175 176 using __collect_errors = __mbind_front_q<__mset_insert, __mset<>>; 177 178 using __errors_list = __minvoke< 179 __mconcat<>, 180 __if< 181 __mand<__nothrow_decay_copyable_results<_Senders, _Env>...>, 182 __types<>, 183 __types<std::exception_ptr> 184 >, 185 __error_types_of_t<_Senders, __env_t<_Env>, __q<__types>>... 186 >; 187 188 using __errors_variant = __mapply<__q<__uniqued_variant_for>, __errors_list>; 189 }; 190 191 struct _INVALID_ARGUMENTS_TO_WHEN_ALL_ { }; 192 193 template <class _ErrorsVariant, class _ValuesTuple, class _StopToken, bool _SendsStopped> 194 struct __when_all_state { 195 using __stop_callback_t = stop_callback_for_t<_StopToken, __on_stop_request>; 196 197 template <class _Receiver> __arrivestdexec::__when_all::__when_all_state198 void __arrive(_Receiver& __rcvr) noexcept { 199 if (1 == __count_.fetch_sub(1)) { 200 __complete(__rcvr); 201 } 202 } 203 204 template <class _Receiver> __completestdexec::__when_all::__when_all_state205 void __complete(_Receiver& __rcvr) noexcept { 206 // Stop callback is no longer needed. Destroy it. 207 __on_stop_.reset(); 208 // All child operations have completed and arrived at the barrier. 209 switch (__state_.load(std::memory_order_relaxed)) { 210 case __started: 211 if constexpr (!same_as<_ValuesTuple, __ignore>) { 212 // All child operations completed successfully: 213 __when_all::__set_values(__rcvr, __values_); 214 } 215 break; 216 case __error: 217 if constexpr (!__same_as<_ErrorsVariant, __variant_for<>>) { 218 // One or more child operations completed with an error: 219 __errors_.visit( 220 __mk_completion_fn(set_error, __rcvr), static_cast<_ErrorsVariant&&>(__errors_)); 221 } 222 break; 223 case __stopped: 224 if constexpr (_SendsStopped) { 225 stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr)); 226 } else { 227 STDEXEC_UNREACHABLE(); 228 } 229 break; 230 default:; 231 } 232 } 233 234 std::atomic<std::size_t> __count_; 235 inplace_stop_source __stop_source_{}; 236 // Could be non-atomic here and atomic_ref everywhere except __completion_fn 237 std::atomic<__state_t> __state_{__started}; 238 _ErrorsVariant __errors_{}; STDEXEC_ATTRIBUTEstdexec::__when_all::__when_all_state239 STDEXEC_ATTRIBUTE(no_unique_address) _ValuesTuple __values_ { }; 240 __optional<__stop_callback_t> __on_stop_{}; 241 }; 242 243 template <class _Env> __mk_state_fn(const _Env &)244 static auto __mk_state_fn(const _Env&) noexcept { 245 return []<__max1_sender<__env_t<_Env>>... _Child>(__ignore, __ignore, _Child&&...) { 246 using _Traits = __traits<_Env, _Child...>; 247 using _ErrorsVariant = _Traits::__errors_variant; 248 using _ValuesTuple = _Traits::__values_tuple; 249 using _State = __when_all_state< 250 _ErrorsVariant, 251 _ValuesTuple, 252 stop_token_of_t<_Env>, 253 (sends_stopped<_Child, _Env> || ...)>; 254 return _State{sizeof...(_Child)}; 255 }; 256 } 257 258 template <class _Env> 259 using __mk_state_fn_t = decltype(__when_all::__mk_state_fn(__declval<_Env>())); 260 261 struct when_all_t { 262 template <sender... _Senders> 263 requires __has_common_domain<_Senders...> operator ()stdexec::__when_all::when_all_t264 auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto { 265 auto __domain = __common_domain_t<_Senders...>(); 266 return stdexec::transform_sender( 267 __domain, __make_sexpr<when_all_t>(__(), static_cast<_Senders&&>(__sndrs)...)); 268 } 269 }; 270 271 struct __when_all_impl : __sexpr_defaults { 272 template <class _Self, class _Env> 273 using __error_t = __mexception< 274 _INVALID_ARGUMENTS_TO_WHEN_ALL_, 275 __children_of<_Self, __q<_WITH_SENDERS_>>, 276 _WITH_ENVIRONMENT_<_Env> 277 >; 278 279 template <class _Self, class... _Env> 280 using __completions = __children_of<_Self, __completions_t<__env_t<_Env>...>>; 281 282 static constexpr auto get_attrs = []<class... _Child>(__ignore, const _Child&...) noexcept { 283 using _Domain = __common_domain_t<_Child...>; 284 if constexpr (__same_as<_Domain, default_domain>) { 285 return env(); 286 } else { 287 return prop{get_domain, _Domain()}; 288 } 289 }; 290 291 static constexpr auto get_completion_signatures = 292 []<class _Self, class... _Env>(_Self&&, _Env&&...) noexcept { 293 static_assert(sender_expr_for<_Self, when_all_t>); 294 return __minvoke<__mtry_catch<__q<__completions>, __q<__error_t>>, _Self, _Env...>(); 295 }; 296 297 static constexpr auto get_env = 298 []<class _State, class _Receiver>( 299 __ignore, 300 _State& __state, 301 const _Receiver& __rcvr) noexcept -> __env_t<env_of_t<const _Receiver&>> { 302 return __mkenv(stdexec::get_env(__rcvr), __state.__stop_source_); 303 }; 304 305 static constexpr auto get_state = 306 []<class _Self, class _Receiver>(_Self&& __self, _Receiver& __rcvr) 307 -> __sexpr_apply_result_t<_Self, __mk_state_fn_t<env_of_t<_Receiver>>> { 308 return __sexpr_apply( 309 static_cast<_Self&&>(__self), __when_all::__mk_state_fn(stdexec::get_env(__rcvr))); 310 }; 311 312 static constexpr auto start = []<class _State, class _Receiver, class... _Operations>( 313 _State& __state, 314 _Receiver& __rcvr, 315 _Operations&... __child_ops) noexcept -> void { 316 // register stop callback: 317 __state.__on_stop_.emplace( 318 get_stop_token(stdexec::get_env(__rcvr)), __on_stop_request{__state.__stop_source_}); 319 (stdexec::start(__child_ops), ...); 320 if constexpr (sizeof...(__child_ops) == 0) { 321 __state.__complete(__rcvr); 322 } 323 }; 324 325 template <class _State, class _Receiver, class _Error> __set_errorstdexec::__when_all::__when_all_impl326 static void __set_error(_State& __state, _Receiver&, _Error&& __err) noexcept { 327 // Transition to the "error" state and switch on the prior state. 328 // TODO: What memory orderings are actually needed here? 329 switch (__state.__state_.exchange(__error)) { 330 case __started: 331 // We must request stop. When the previous state is __error or __stopped, then stop has 332 // already been requested. 333 __state.__stop_source_.request_stop(); 334 [[fallthrough]]; 335 case __stopped: 336 // We are the first child to complete with an error, so we must save the error. (Any 337 // subsequent errors are ignored.) 338 if constexpr (__nothrow_decay_copyable<_Error>) { 339 __state.__errors_.template emplace<__decay_t<_Error>>(static_cast<_Error&&>(__err)); 340 } else { 341 STDEXEC_TRY { 342 __state.__errors_.template emplace<__decay_t<_Error>>(static_cast<_Error&&>(__err)); 343 } 344 STDEXEC_CATCH_ALL { 345 __state.__errors_.template emplace<std::exception_ptr>(std::current_exception()); 346 } 347 } 348 break; 349 case __error:; // We're already in the "error" state. Ignore the error. 350 } 351 } 352 353 static constexpr auto complete = 354 []<class _Index, class _State, class _Receiver, class _Set, class... _Args>( 355 _Index, 356 _State& __state, 357 _Receiver& __rcvr, 358 _Set, 359 _Args&&... __args) noexcept -> void { 360 using _ValuesTuple = decltype(_State::__values_); 361 if constexpr (__same_as<_Set, set_error_t>) { 362 __set_error(__state, __rcvr, static_cast<_Args&&>(__args)...); 363 } else if constexpr (__same_as<_Set, set_stopped_t>) { 364 __state_t __expected = __started; 365 // Transition to the "stopped" state if and only if we're in the 366 // "started" state. (If this fails, it's because we're in an 367 // error state, which trumps cancellation.) 368 if (__state.__state_.compare_exchange_strong(__expected, __stopped)) { 369 __state.__stop_source_.request_stop(); 370 } 371 } else if constexpr (!__same_as<_ValuesTuple, __ignore>) { 372 // We only need to bother recording the completion values 373 // if we're not already in the "error" or "stopped" state. 374 if (__state.__state_.load() == __started) { 375 auto& __opt_values = _ValuesTuple::template __get<__v<_Index>>(__state.__values_); 376 using _Tuple = __decayed_tuple<_Args...>; 377 static_assert( 378 __same_as<decltype(*__opt_values), _Tuple&>, 379 "One of the senders in this when_all() is fibbing about what types it sends"); 380 if constexpr ((__nothrow_decay_copyable<_Args> && ...)) { 381 __opt_values.emplace(_Tuple{static_cast<_Args&&>(__args)...}); 382 } else { 383 STDEXEC_TRY { 384 __opt_values.emplace(_Tuple{static_cast<_Args&&>(__args)...}); 385 } 386 STDEXEC_CATCH_ALL { 387 __set_error(__state, __rcvr, std::current_exception()); 388 } 389 } 390 } 391 } 392 393 __state.__arrive(__rcvr); 394 }; 395 }; 396 397 struct when_all_with_variant_t { 398 template <sender... _Senders> 399 requires __has_common_domain<_Senders...> operator ()stdexec::__when_all::when_all_with_variant_t400 auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto { 401 auto __domain = __common_domain_t<_Senders...>(); 402 return stdexec::transform_sender( 403 __domain, 404 __make_sexpr<when_all_with_variant_t>(__(), static_cast<_Senders&&>(__sndrs)...)); 405 } 406 407 template <class _Sender, class _Env> transform_senderstdexec::__when_all::when_all_with_variant_t408 static auto transform_sender(_Sender&& __sndr, const _Env&) { 409 // transform the when_all_with_variant into a regular when_all (looking for 410 // early when_all customizations), then transform it again to look for 411 // late customizations. 412 return __sexpr_apply( 413 static_cast<_Sender&&>(__sndr), 414 [&]<class... _Child>(__ignore, __ignore, _Child&&... __child) { 415 return when_all_t()(into_variant(static_cast<_Child&&>(__child))...); 416 }); 417 } 418 }; 419 420 struct __when_all_with_variant_impl : __sexpr_defaults { 421 static constexpr auto get_attrs = []<class... _Child>(__ignore, const _Child&...) noexcept { 422 using _Domain = __common_domain_t<_Child...>; 423 if constexpr (same_as<_Domain, default_domain>) { 424 return env(); 425 } else { 426 return prop{get_domain, _Domain()}; 427 } 428 }; 429 430 static constexpr auto get_completion_signatures = []<class _Sender>(_Sender&&) noexcept 431 -> __completion_signatures_of_t<transform_sender_result_t<default_domain, _Sender, env<>>> { 432 return {}; 433 }; 434 }; 435 436 struct transfer_when_all_t { 437 template <scheduler _Scheduler, sender... _Senders> 438 requires __has_common_domain<_Senders...> 439 auto operator ()stdexec::__when_all::transfer_when_all_t440 operator()(_Scheduler __sched, _Senders&&... __sndrs) const -> __well_formed_sender auto { 441 auto __domain = query_or(get_domain, __sched, default_domain()); 442 return stdexec::transform_sender( 443 __domain, 444 __make_sexpr<transfer_when_all_t>( 445 static_cast<_Scheduler&&>(__sched), static_cast<_Senders&&>(__sndrs)...)); 446 } 447 448 template <class _Sender, class _Env> transform_senderstdexec::__when_all::transfer_when_all_t449 static auto transform_sender(_Sender&& __sndr, const _Env&) { 450 // transform the transfer_when_all into a regular transform | when_all 451 // (looking for early customizations), then transform it again to look for 452 // late customizations. 453 return __sexpr_apply( 454 static_cast<_Sender&&>(__sndr), 455 [&]<class _Data, class... _Child>(__ignore, _Data&& __data, _Child&&... __child) { 456 return continues_on( 457 when_all_t()(static_cast<_Child&&>(__child)...), static_cast<_Data&&>(__data)); 458 }); 459 } 460 }; 461 462 struct __transfer_when_all_impl : __sexpr_defaults { 463 static constexpr auto get_attrs = []<class _Scheduler, class... _Child>( 464 const _Scheduler& __sched, 465 const _Child&...) noexcept { 466 using __sndr_t = __call_result_t<when_all_t, _Child...>; 467 using __domain_t = __detail::__early_domain_of_t<__sndr_t, __none_such>; 468 return __sched_attrs{std::cref(__sched), __domain_t{}}; 469 }; 470 471 static constexpr auto get_completion_signatures = []<class _Sender>(_Sender&&) noexcept 472 -> __completion_signatures_of_t<transform_sender_result_t<default_domain, _Sender, env<>>> { 473 return {}; 474 }; 475 }; 476 477 struct transfer_when_all_with_variant_t { 478 template <scheduler _Scheduler, sender... _Senders> 479 requires __has_common_domain<_Senders...> 480 auto operator ()stdexec::__when_all::transfer_when_all_with_variant_t481 operator()(_Scheduler&& __sched, _Senders&&... __sndrs) const -> __well_formed_sender auto { 482 auto __domain = query_or(get_domain, __sched, default_domain()); 483 return stdexec::transform_sender( 484 __domain, 485 __make_sexpr<transfer_when_all_with_variant_t>( 486 static_cast<_Scheduler&&>(__sched), static_cast<_Senders&&>(__sndrs)...)); 487 } 488 489 template <class _Sender, class _Env> transform_senderstdexec::__when_all::transfer_when_all_with_variant_t490 static auto transform_sender(_Sender&& __sndr, const _Env&) { 491 // transform the transfer_when_all_with_variant into regular transform_when_all 492 // and into_variant calls/ (looking for early customizations), then transform it 493 // again to look for late customizations. 494 return __sexpr_apply( 495 static_cast<_Sender&&>(__sndr), 496 [&]<class _Data, class... _Child>(__ignore, _Data&& __data, _Child&&... __child) { 497 return transfer_when_all_t()( 498 static_cast<_Data&&>(__data), into_variant(static_cast<_Child&&>(__child))...); 499 }); 500 } 501 }; 502 503 struct __transfer_when_all_with_variant_impl : __sexpr_defaults { 504 static constexpr auto get_attrs = []<class _Scheduler, class... _Child>( 505 const _Scheduler& __sched, 506 const _Child&...) noexcept { 507 using __sndr_t = __call_result_t<when_all_with_variant_t, _Child...>; 508 using __domain_t = __detail::__early_domain_of_t<__sndr_t, __none_such>; 509 return __sched_attrs{std::cref(__sched), __domain_t{}}; 510 }; 511 512 static constexpr auto get_completion_signatures = []<class _Sender>(_Sender&&) noexcept 513 -> __completion_signatures_of_t<transform_sender_result_t<default_domain, _Sender, env<>>> { 514 return {}; 515 }; 516 }; 517 } // namespace __when_all 518 519 using __when_all::when_all_t; 520 inline constexpr when_all_t when_all{}; 521 522 using __when_all::when_all_with_variant_t; 523 inline constexpr when_all_with_variant_t when_all_with_variant{}; 524 525 using __when_all::transfer_when_all_t; 526 inline constexpr transfer_when_all_t transfer_when_all{}; 527 528 using __when_all::transfer_when_all_with_variant_t; 529 inline constexpr transfer_when_all_with_variant_t transfer_when_all_with_variant{}; 530 531 template <> 532 struct __sexpr_impl<when_all_t> : __when_all::__when_all_impl { }; 533 534 template <> 535 struct __sexpr_impl<when_all_with_variant_t> : __when_all::__when_all_with_variant_impl { }; 536 537 template <> 538 struct __sexpr_impl<transfer_when_all_t> : __when_all::__transfer_when_all_impl { }; 539 540 template <> 541 struct __sexpr_impl<transfer_when_all_with_variant_t> 542 : __when_all::__transfer_when_all_with_variant_impl { }; 543 } // namespace stdexec 544