1 /* 2 * Copyright (c) 2021-2024 NVIDIA Corporation 3 * 4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 5 * (the "License"); you may not use this file except in compliance with 6 * the License. You may obtain a copy of the License at 7 * 8 * https://llvm.org/LICENSE.txt 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include "__execution_fwd.hpp" 19 20 // include these after __execution_fwd.hpp 21 #include "../stop_token.hpp" 22 #include "__basic_sender.hpp" 23 #include "__concepts.hpp" 24 #include "__continue_on.hpp" 25 #include "__diagnostics.hpp" 26 #include "__domain.hpp" 27 #include "__env.hpp" 28 #include "__into_variant.hpp" 29 #include "__meta.hpp" 30 #include "__optional.hpp" 31 #include "__schedulers.hpp" 32 #include "__senders.hpp" 33 #include "__transform_completion_signatures.hpp" 34 #include "__transform_sender.hpp" 35 #include "__tuple.hpp" 36 #include "__type_traits.hpp" 37 #include "__utility.hpp" 38 #include "__variant.hpp" 39 40 #include <atomic> 41 #include <exception> 42 43 namespace stdexec 44 { 45 ///////////////////////////////////////////////////////////////////////////// 46 // [execution.senders.adaptors.when_all] 47 // [execution.senders.adaptors.when_all_with_variant] 48 namespace __when_all 49 { 50 enum __state_t 51 { 52 __started, 53 __error, 54 __stopped 55 }; 56 57 struct __on_stop_request 58 { 59 inplace_stop_source& __stop_source_; 60 61 void operator()() noexcept 62 { 63 __stop_source_.request_stop(); 64 } 65 }; 66 67 template <class _Env> 68 auto __mkenv(_Env&& __env, const inplace_stop_source& __stop_source) noexcept 69 { 70 return __env::__join(prop{get_stop_token, __stop_source.get_token()}, 71 static_cast<_Env&&>(__env)); 72 } 73 74 template <class _Env> 75 using __env_t = // 76 decltype(__when_all::__mkenv(__declval<_Env>(), 77 __declval<inplace_stop_source&>())); 78 79 template <class _Sender, class _Env> 80 concept __max1_sender = 81 sender_in<_Sender, _Env> && __mvalid<__value_types_of_t, _Sender, _Env, 82 __mconst<int>, __msingle_or<void>>; 83 84 template < 85 __mstring _Context = "In stdexec::when_all()..."_mstr, 86 __mstring _Diagnostic = 87 "The given sender can complete successfully in more that one way. " 88 "Use stdexec::when_all_with_variant() instead."_mstr> 89 struct _INVALID_WHEN_ALL_ARGUMENT_; 90 91 template <class _Sender, class... _Env> 92 using __too_many_value_completions_error = 93 __mexception<_INVALID_WHEN_ALL_ARGUMENT_<>, _WITH_SENDER_<_Sender>, 94 _WITH_ENVIRONMENT_<_Env>...>; 95 96 template <class... _Args> 97 using __all_nothrow_decay_copyable = 98 __mbool<(__nothrow_decay_copyable<_Args> && ...)>; 99 100 template <class _Error> 101 using __set_error_t = completion_signatures<set_error_t(__decay_t<_Error>)>; 102 103 template <class _Sender, class... _Env> 104 using __nothrow_decay_copyable_results = // 105 __for_each_completion_signature< 106 __completion_signatures_of_t<_Sender, _Env...>, 107 __all_nothrow_decay_copyable, __mand_t>; 108 109 template <class... _Env> 110 struct __completions_t 111 { 112 template <class... _Senders> 113 using __all_nothrow_decay_copyable_results = // 114 __mand<__nothrow_decay_copyable_results<_Senders, _Env...>...>; 115 116 template <class _Sender, class _ValueTuple, class... _Rest> 117 using __value_tuple_t = 118 __minvoke<__if_c<(0 == sizeof...(_Rest)), __mconst<_ValueTuple>, 119 __q<__too_many_value_completions_error>>, 120 _Sender, _Env...>; 121 122 template <class _Sender> 123 using __single_values_of_t = // 124 __value_types_t<__completion_signatures_of_t<_Sender, _Env...>, 125 __mtransform<__q<__decay_t>, __q<__types>>, 126 __mbind_front_q<__value_tuple_t, _Sender>>; 127 128 template <class... _Senders> 129 using __set_values_sig_t = // 130 __meval<completion_signatures, 131 __minvoke<__mconcat<__qf<set_value_t>>, 132 __single_values_of_t<_Senders>...>>; 133 134 template <class... _Senders> 135 using __f = // 136 __meval< // 137 __concat_completion_signatures, 138 __meval<__eptr_completion_if_t, 139 __all_nothrow_decay_copyable_results<_Senders...>>, 140 completion_signatures<set_stopped_t()>, 141 __minvoke<__with_default<__qq<__set_values_sig_t>, 142 completion_signatures<>>, 143 _Senders...>, 144 __transform_completion_signatures< 145 __completion_signatures_of_t<_Senders, _Env...>, 146 __mconst<completion_signatures<>>::__f, __set_error_t, 147 completion_signatures<>, __concat_completion_signatures>...>; 148 }; 149 150 template <class _Tag, class _Receiver> 151 auto __complete_fn(_Tag, _Receiver& __rcvr) noexcept 152 { 153 return [&]<class... _Ts>(_Ts&... __ts) noexcept { 154 _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Ts&&>(__ts)...); 155 }; 156 } 157 158 template <class _Receiver, class _ValuesTuple> 159 void __set_values(_Receiver& __rcvr, _ValuesTuple& __values) noexcept 160 { 161 __values.apply( 162 [&](auto&... __opt_vals) noexcept -> void { 163 __tup::__cat_apply(__when_all::__complete_fn(set_value, __rcvr), 164 *__opt_vals...); 165 }, 166 __values); 167 } 168 169 template <class _Env, class _Sender> 170 using __values_opt_tuple_t = // 171 value_types_of_t<_Sender, __env_t<_Env>, __decayed_tuple, __optional>; 172 173 template <class _Env, __max1_sender<__env_t<_Env>>... _Senders> 174 struct __traits 175 { 176 // tuple<optional<tuple<Vs1...>>, optional<tuple<Vs2...>>, ...> 177 using __values_tuple = // 178 __minvoke<__with_default< 179 __mtransform<__mbind_front_q<__values_opt_tuple_t, _Env>, 180 __q<__tuple_for>>, 181 __ignore>, 182 _Senders...>; 183 184 using __collect_errors = __mbind_front_q<__mset_insert, __mset<>>; 185 186 using __errors_list = // 187 __minvoke< 188 __mconcat<>, 189 __if<__mand<__nothrow_decay_copyable_results<_Senders, _Env>...>, 190 __types<>, __types<std::exception_ptr>>, 191 __error_types_of_t<_Senders, __env_t<_Env>, __q<__types>>...>; 192 193 using __errors_variant = 194 __mapply<__q<__uniqued_variant_for>, __errors_list>; 195 }; 196 197 struct _INVALID_ARGUMENTS_TO_WHEN_ALL_ 198 {}; 199 200 template <class _ErrorsVariant, class _ValuesTuple, class _StopToken> 201 struct __when_all_state 202 { 203 using __stop_callback_t = 204 stop_callback_for_t<_StopToken, __on_stop_request>; 205 206 template <class _Receiver> 207 void __arrive(_Receiver& __rcvr) noexcept 208 { 209 if (0 == --__count_) 210 { 211 __complete(__rcvr); 212 } 213 } 214 215 template <class _Receiver> 216 void __complete(_Receiver& __rcvr) noexcept 217 { 218 // Stop callback is no longer needed. Destroy it. 219 __on_stop_.reset(); 220 // All child operations have completed and arrived at the barrier. 221 switch (__state_.load(std::memory_order_relaxed)) 222 { 223 case __started: 224 if constexpr (!same_as<_ValuesTuple, __ignore>) 225 { 226 // All child operations completed successfully: 227 __when_all::__set_values(__rcvr, __values_); 228 } 229 break; 230 case __error: 231 if constexpr (!__same_as<_ErrorsVariant, __variant_for<>>) 232 { 233 // One or more child operations completed with an error: 234 __errors_.visit(__complete_fn(set_error, __rcvr), 235 __errors_); 236 } 237 break; 238 case __stopped: 239 stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr)); 240 break; 241 default:; 242 } 243 } 244 245 std::atomic<std::size_t> __count_; 246 inplace_stop_source __stop_source_{}; 247 // Could be non-atomic here and atomic_ref everywhere except __completion_fn 248 std::atomic<__state_t> __state_{__started}; 249 _ErrorsVariant __errors_{}; 250 STDEXEC_ATTRIBUTE((no_unique_address)) 251 _ValuesTuple __values_{}; 252 __optional<__stop_callback_t> __on_stop_{}; 253 }; 254 255 template <class _Env> 256 static auto __mk_state_fn(const _Env&) noexcept 257 { 258 return []<__max1_sender<__env_t<_Env>>... _Child>(__ignore, __ignore, 259 _Child&&...) { 260 using _Traits = __traits<_Env, _Child...>; 261 using _ErrorsVariant = typename _Traits::__errors_variant; 262 using _ValuesTuple = typename _Traits::__values_tuple; 263 using _State = __when_all_state<_ErrorsVariant, _ValuesTuple, 264 stop_token_of_t<_Env>>; 265 return _State{sizeof...(_Child)}; 266 }; 267 } 268 269 template <class _Env> 270 using __mk_state_fn_t = decltype(__when_all::__mk_state_fn(__declval<_Env>())); 271 272 struct when_all_t 273 { 274 // Used by the default_domain to find legacy customizations: 275 using _Sender = __1; 276 using __legacy_customizations_t = // 277 __types<tag_invoke_t(when_all_t, _Sender...)>; 278 279 template <sender... _Senders> 280 requires __domain::__has_common_domain<_Senders...> 281 auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto 282 { 283 auto __domain = __domain::__common_domain_t<_Senders...>(); 284 return stdexec::transform_sender( 285 __domain, __make_sexpr<when_all_t>( 286 __(), static_cast<_Senders&&>(__sndrs)...)); 287 } 288 }; 289 290 struct __when_all_impl : __sexpr_defaults 291 { 292 template <class _Self, class _Env> 293 using __error_t = __mexception<_INVALID_ARGUMENTS_TO_WHEN_ALL_, 294 __children_of<_Self, __q<_WITH_SENDERS_>>, 295 _WITH_ENVIRONMENT_<_Env>>; 296 297 template <class _Self, class... _Env> 298 using __completions = 299 __children_of<_Self, __completions_t<__env_t<_Env>...>>; 300 301 static constexpr auto get_attrs = // 302 []<class... _Child>(__ignore, const _Child&...) noexcept { 303 using _Domain = __domain::__common_domain_t<_Child...>; 304 if constexpr (__same_as<_Domain, default_domain>) 305 { 306 return env(); 307 } 308 else 309 { 310 return prop{get_domain, _Domain()}; 311 } 312 }; 313 314 static constexpr auto get_completion_signatures = // 315 []<class _Self, class... _Env>(_Self&&, _Env&&...) noexcept { 316 static_assert(sender_expr_for<_Self, when_all_t>); 317 return __minvoke<__mtry_catch<__q<__completions>, __q<__error_t>>, 318 _Self, _Env...>(); 319 }; 320 321 static constexpr auto get_env = // 322 []<class _State, class _Receiver>(__ignore, _State& __state, 323 const _Receiver& __rcvr) noexcept // 324 -> __env_t<env_of_t<const _Receiver&>> { 325 return __mkenv(stdexec::get_env(__rcvr), __state.__stop_source_); 326 }; 327 328 static constexpr auto get_state = // 329 []<class _Self, class _Receiver>(_Self&& __self, _Receiver& __rcvr) 330 -> __sexpr_apply_result_t<_Self, __mk_state_fn_t<env_of_t<_Receiver>>> { 331 return __sexpr_apply( 332 static_cast<_Self&&>(__self), 333 __when_all::__mk_state_fn(stdexec::get_env(__rcvr))); 334 }; 335 336 static constexpr auto start = // 337 []<class _State, class _Receiver, class... _Operations>( 338 _State& __state, _Receiver& __rcvr, 339 _Operations&... __child_ops) noexcept -> void { 340 // register stop callback: 341 __state.__on_stop_.emplace(get_stop_token(stdexec::get_env(__rcvr)), 342 __on_stop_request{__state.__stop_source_}); 343 if (__state.__stop_source_.stop_requested()) 344 { 345 // Stop has already been requested. Don't bother starting 346 // the child operations. 347 stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr)); 348 } 349 else 350 { 351 (stdexec::start(__child_ops), ...); 352 if constexpr (sizeof...(__child_ops) == 0) 353 { 354 __state.__complete(__rcvr); 355 } 356 } 357 }; 358 359 template <class _State, class _Receiver, class _Error> 360 static void __set_error(_State& __state, _Receiver&, 361 _Error&& __err) noexcept 362 { 363 // TODO: What memory orderings are actually needed here? 364 if (__error != __state.__state_.exchange(__error)) 365 { 366 __state.__stop_source_.request_stop(); 367 // We won the race, free to write the error into the operation 368 // state without worry. 369 if constexpr (__nothrow_decay_copyable<_Error>) 370 { 371 __state.__errors_.template emplace<__decay_t<_Error>>( 372 static_cast<_Error&&>(__err)); 373 } 374 else 375 { 376 try 377 { 378 __state.__errors_.template emplace<__decay_t<_Error>>( 379 static_cast<_Error&&>(__err)); 380 } 381 catch (...) 382 { 383 __state.__errors_.template emplace<std::exception_ptr>( 384 std::current_exception()); 385 } 386 } 387 } 388 } 389 390 static constexpr auto complete = // 391 []<class _Index, class _State, class _Receiver, class _Set, 392 class... _Args>(_Index, _State& __state, _Receiver& __rcvr, _Set, 393 _Args&&... __args) noexcept -> void { 394 if constexpr (__same_as<_Set, set_error_t>) 395 { 396 __set_error(__state, __rcvr, static_cast<_Args&&>(__args)...); 397 } 398 else if constexpr (__same_as<_Set, set_stopped_t>) 399 { 400 __state_t __expected = __started; 401 // Transition to the "stopped" state if and only if we're in the 402 // "started" state. (If this fails, it's because we're in an 403 // error state, which trumps cancellation.) 404 if (__state.__state_.compare_exchange_strong(__expected, __stopped)) 405 { 406 __state.__stop_source_.request_stop(); 407 } 408 } 409 else if constexpr (!__same_as<decltype(_State::__values_), __ignore>) 410 { 411 // We only need to bother recording the completion values 412 // if we're not already in the "error" or "stopped" state. 413 if (__state.__state_ == __started) 414 { 415 auto& __opt_values = __tup::get<__v<_Index>>(__state.__values_); 416 using _Tuple = __decayed_tuple<_Args...>; 417 static_assert( 418 __same_as<decltype(*__opt_values), _Tuple&>, 419 "One of the senders in this when_all() is fibbing about what types it sends"); 420 if constexpr ((__nothrow_decay_copyable<_Args> && ...)) 421 { 422 __opt_values.emplace( 423 _Tuple{{static_cast<_Args&&>(__args)}...}); 424 } 425 else 426 { 427 try 428 { 429 __opt_values.emplace( 430 _Tuple{{static_cast<_Args&&>(__args)}...}); 431 } 432 catch (...) 433 { 434 __set_error(__state, __rcvr, std::current_exception()); 435 } 436 } 437 } 438 } 439 440 __state.__arrive(__rcvr); 441 }; 442 }; 443 444 struct when_all_with_variant_t 445 { 446 using _Sender = __1; 447 using __legacy_customizations_t = // 448 __types<tag_invoke_t(when_all_with_variant_t, _Sender...)>; 449 450 template <sender... _Senders> 451 requires __domain::__has_common_domain<_Senders...> 452 auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto 453 { 454 auto __domain = __domain::__common_domain_t<_Senders...>(); 455 return stdexec::transform_sender( 456 __domain, __make_sexpr<when_all_with_variant_t>( 457 __(), static_cast<_Senders&&>(__sndrs)...)); 458 } 459 460 template <class _Sender, class _Env> 461 static auto transform_sender(_Sender&& __sndr, const _Env&) 462 { 463 // transform the when_all_with_variant into a regular when_all (looking 464 // for early when_all customizations), then transform it again to look 465 // for late customizations. 466 return __sexpr_apply( 467 static_cast<_Sender&&>(__sndr), 468 [&]<class... _Child>(__ignore, __ignore, _Child&&... __child) { 469 return when_all_t()( 470 into_variant(static_cast<_Child&&>(__child))...); 471 }); 472 } 473 }; 474 475 struct __when_all_with_variant_impl : __sexpr_defaults 476 { 477 static constexpr auto get_attrs = // 478 []<class... _Child>(__ignore, const _Child&...) noexcept { 479 using _Domain = __domain::__common_domain_t<_Child...>; 480 if constexpr (same_as<_Domain, default_domain>) 481 { 482 return env(); 483 } 484 else 485 { 486 return prop{get_domain, _Domain()}; 487 } 488 }; 489 490 static constexpr auto get_completion_signatures = // 491 []<class _Sender>(_Sender&&) noexcept // 492 -> __completion_signatures_of_t< // 493 transform_sender_result_t<default_domain, _Sender, empty_env>> { 494 return {}; 495 }; 496 }; 497 498 struct transfer_when_all_t 499 { 500 using _Env = __0; 501 using _Sender = __1; 502 using __legacy_customizations_t = // 503 __types<tag_invoke_t( 504 transfer_when_all_t, 505 get_completion_scheduler_t<set_value_t>(const _Env&), _Sender...)>; 506 507 template <scheduler _Scheduler, sender... _Senders> 508 requires __domain::__has_common_domain<_Senders...> 509 auto operator()(_Scheduler&& __sched, 510 _Senders&&... __sndrs) const -> __well_formed_sender auto 511 { 512 using _Env = __t<__schfr::__environ<__id<__decay_t<_Scheduler>>>>; 513 auto __domain = query_or(get_domain, __sched, default_domain()); 514 return stdexec::transform_sender( 515 __domain, __make_sexpr<transfer_when_all_t>( 516 _Env{static_cast<_Scheduler&&>(__sched)}, 517 static_cast<_Senders&&>(__sndrs)...)); 518 } 519 520 template <class _Sender, class _Env> 521 static auto transform_sender(_Sender&& __sndr, const _Env&) 522 { 523 // transform the transfer_when_all into a regular transform | when_all 524 // (looking for early customizations), then transform it again to look 525 // for late customizations. 526 return __sexpr_apply( 527 static_cast<_Sender&&>(__sndr), 528 [&]<class _Data, class... _Child>(__ignore, _Data&& __data, 529 _Child&&... __child) { 530 return continue_on( 531 when_all_t()(static_cast<_Child&&>(__child)...), 532 get_completion_scheduler<set_value_t>(__data)); 533 }); 534 } 535 }; 536 537 struct __transfer_when_all_impl : __sexpr_defaults 538 { 539 static constexpr auto get_attrs = // 540 []<class _Data>(const _Data& __data, 541 const auto&...) noexcept -> const _Data& { 542 return __data; 543 }; 544 545 static constexpr auto get_completion_signatures = // 546 []<class _Sender>(_Sender&&) noexcept // 547 -> __completion_signatures_of_t< // 548 transform_sender_result_t<default_domain, _Sender, empty_env>> { 549 return {}; 550 }; 551 }; 552 553 struct transfer_when_all_with_variant_t 554 { 555 using _Env = __0; 556 using _Sender = __1; 557 using __legacy_customizations_t = // 558 __types<tag_invoke_t( 559 transfer_when_all_with_variant_t, 560 get_completion_scheduler_t<set_value_t>(const _Env&), _Sender...)>; 561 562 template <scheduler _Scheduler, sender... _Senders> 563 requires __domain::__has_common_domain<_Senders...> 564 auto operator()(_Scheduler&& __sched, 565 _Senders&&... __sndrs) const -> __well_formed_sender auto 566 { 567 using _Env = __t<__schfr::__environ<__id<__decay_t<_Scheduler>>>>; 568 auto __domain = query_or(get_domain, __sched, default_domain()); 569 return stdexec::transform_sender( 570 __domain, __make_sexpr<transfer_when_all_with_variant_t>( 571 _Env{{static_cast<_Scheduler&&>(__sched)}}, 572 static_cast<_Senders&&>(__sndrs)...)); 573 } 574 575 template <class _Sender, class _Env> 576 static auto transform_sender(_Sender&& __sndr, const _Env&) 577 { 578 // transform the transfer_when_all_with_variant into regular 579 // transform_when_all and into_variant calls/ (looking for early 580 // customizations), then transform it again to look for late 581 // customizations. 582 return __sexpr_apply( 583 static_cast<_Sender&&>(__sndr), 584 [&]<class _Data, class... _Child>(__ignore, _Data&& __data, 585 _Child&&... __child) { 586 return transfer_when_all_t()( 587 get_completion_scheduler<set_value_t>( 588 static_cast<_Data&&>(__data)), 589 into_variant(static_cast<_Child&&>(__child))...); 590 }); 591 } 592 }; 593 594 struct __transfer_when_all_with_variant_impl : __sexpr_defaults 595 { 596 static constexpr auto get_attrs = // 597 []<class _Data>(const _Data& __data, 598 const auto&...) noexcept -> const _Data& { 599 return __data; 600 }; 601 602 static constexpr auto get_completion_signatures = // 603 []<class _Sender>(_Sender&&) noexcept // 604 -> __completion_signatures_of_t< // 605 transform_sender_result_t<default_domain, _Sender, empty_env>> { 606 return {}; 607 }; 608 }; 609 } // namespace __when_all 610 611 using __when_all::when_all_t; 612 inline constexpr when_all_t when_all{}; 613 614 using __when_all::when_all_with_variant_t; 615 inline constexpr when_all_with_variant_t when_all_with_variant{}; 616 617 using __when_all::transfer_when_all_t; 618 inline constexpr transfer_when_all_t transfer_when_all{}; 619 620 using __when_all::transfer_when_all_with_variant_t; 621 inline constexpr transfer_when_all_with_variant_t 622 transfer_when_all_with_variant{}; 623 624 template <> 625 struct __sexpr_impl<when_all_t> : __when_all::__when_all_impl 626 {}; 627 628 template <> 629 struct __sexpr_impl<when_all_with_variant_t> : 630 __when_all::__when_all_with_variant_impl 631 {}; 632 633 template <> 634 struct __sexpr_impl<transfer_when_all_t> : __when_all::__transfer_when_all_impl 635 {}; 636 637 template <> 638 struct __sexpr_impl<transfer_when_all_with_variant_t> : 639 __when_all::__transfer_when_all_with_variant_impl 640 {}; 641 } // namespace stdexec 642