1 /* 2 * Copyright (c) Facebook, Inc. and its affiliates. 3 * Copyright (c) 2021-2024 NVIDIA Corporation 4 * 5 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * https://llvm.org/LICENSE.txt 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 #pragma once 18 19 // The original idea is taken from libunifex and adapted to stdexec. 20 21 #include <exception> 22 23 #include "../stdexec/execution.hpp" 24 25 #include "inline_scheduler.hpp" 26 #include "any_sender_of.hpp" 27 28 namespace exec { 29 namespace __at_coro_exit { 30 using namespace stdexec; 31 32 using __any_scheduler = any_receiver_ref< 33 completion_signatures<set_error_t(std::exception_ptr), set_stopped_t()> 34 >::any_sender<>::any_scheduler<>; 35 36 struct __die_on_stop_t { 37 template <class _Receiver> 38 struct __receiver_id { 39 struct __t { 40 using receiver_concept = stdexec::receiver_t; 41 using __id = __receiver_id; 42 _Receiver __receiver_; 43 44 template <class... _Args> 45 requires __callable<set_value_t, _Receiver, _Args...> set_valueexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t46 void set_value(_Args&&... __args) noexcept { 47 stdexec::set_value( 48 static_cast<_Receiver&&>(__receiver_), static_cast<_Args&&>(__args)...); 49 } 50 51 template <class _Error> 52 requires __callable<set_error_t, _Receiver, _Error> set_errorexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t53 void set_error(_Error&& __err) noexcept { 54 stdexec::set_error(static_cast<_Receiver&&>(__receiver_), static_cast<_Error&&>(__err)); 55 } 56 57 [[noreturn]] set_stoppedexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t58 void set_stopped() noexcept { 59 std::terminate(); 60 } 61 get_envexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t62 auto get_env() const noexcept -> env_of_t<_Receiver> { 63 return stdexec::get_env(__receiver_); 64 } 65 }; 66 }; 67 68 template <class _Rec> 69 using __receiver = __t<__receiver_id<_Rec>>; 70 71 template <class _Sender> 72 struct __sender_id { 73 template <class... _Env> 74 using __completion_signatures = __mapply< 75 __mremove<set_stopped_t(), __q<completion_signatures>>, 76 __completion_signatures_of_t<_Sender, _Env...> 77 >; 78 79 struct __t { 80 using __id = __sender_id; 81 using sender_concept = stdexec::sender_t; 82 83 _Sender __sender_; 84 85 template <receiver _Receiver> 86 requires sender_to<_Sender, __receiver<_Receiver>> connectexec::__at_coro_exit::__die_on_stop_t::__sender_id::__t87 auto connect(_Receiver __rcvr) && noexcept 88 -> connect_result_t<_Sender, __receiver<_Receiver>> { 89 return stdexec::connect( 90 static_cast<_Sender&&>(__sender_), 91 __receiver<_Receiver>{static_cast<_Receiver&&>(__rcvr)}); 92 } 93 94 template <class... _Env> get_completion_signaturesexec::__at_coro_exit::__die_on_stop_t::__sender_id::__t95 auto get_completion_signatures(_Env&&...) -> __completion_signatures<_Env...> { 96 return {}; 97 } 98 get_envexec::__at_coro_exit::__die_on_stop_t::__sender_id::__t99 auto get_env() const noexcept -> env_of_t<_Sender> { 100 return stdexec::get_env(__sender_); 101 } 102 }; 103 }; 104 template <class _Sender> 105 using __sender = __t<__sender_id<__decay_t<_Sender>>>; 106 107 template <sender _Sender> operator ()exec::__at_coro_exit::__die_on_stop_t108 auto operator()(_Sender&& __sndr) const noexcept(__nothrow_decay_copyable<_Sender>) 109 -> __sender<_Sender> { 110 return __sender<_Sender>{static_cast<_Sender&&>(__sndr)}; 111 } 112 113 template <class _Value> operator ()exec::__at_coro_exit::__die_on_stop_t114 auto operator()(_Value&& __value) const noexcept -> _Value&& { 115 return static_cast<_Value&&>(__value); 116 } 117 }; 118 119 inline constexpr __die_on_stop_t __die_on_stop; 120 121 template <class _Promise> 122 concept __has_continuation = requires(_Promise& __promise, __coroutine_handle<> __c) { 123 { __promise.continuation() } -> convertible_to<__coroutine_handle<>>; 124 { __promise.set_continuation(__c) }; 125 }; 126 127 template <class... _Ts> 128 class [[nodiscard]] __task { 129 struct __promise; 130 public: 131 using promise_type = __promise; 132 133 #if STDEXEC_EDG() __task(__coro::coroutine_handle<__promise> __coro)134 __task(__coro::coroutine_handle<__promise> __coro) noexcept 135 : __coro_(__coro) { 136 } 137 #else __task(__coro::coroutine_handle<__promise> __coro)138 explicit __task(__coro::coroutine_handle<__promise> __coro) noexcept 139 : __coro_(__coro) { 140 } 141 #endif 142 __task(__task && __that)143 __task(__task&& __that) noexcept 144 : __coro_(std::exchange(__that.__coro_, {})) { 145 } 146 147 [[nodiscard]] await_ready() const148 auto await_ready() const noexcept -> bool { 149 return false; 150 } 151 152 template <__has_continuation _Promise> await_suspend(__coro::coroutine_handle<_Promise> __parent)153 auto await_suspend(__coro::coroutine_handle<_Promise> __parent) noexcept -> bool { 154 __coro_.promise().__scheduler_ = get_scheduler(get_env(__parent.promise())); 155 __coro_.promise().set_continuation(__parent.promise().continuation()); 156 __parent.promise().set_continuation(__coro_); 157 return false; 158 } 159 await_resume()160 auto await_resume() noexcept -> std::tuple<_Ts&...> { 161 return std::exchange(__coro_, {}).promise().__args_; 162 } 163 164 private: 165 struct __final_awaitable { await_readyexec::__at_coro_exit::__task::__final_awaitable166 static constexpr auto await_ready() noexcept -> bool { 167 return false; 168 } 169 await_suspendexec::__at_coro_exit::__task::__final_awaitable170 static auto await_suspend(__coro::coroutine_handle<__promise> __h) noexcept 171 -> __coro::coroutine_handle<> { 172 __promise& __p = __h.promise(); 173 auto __coro = __p.__is_unhandled_stopped_ ? __p.continuation().unhandled_stopped() 174 : __p.continuation().handle(); 175 return STDEXEC_DESTROY_AND_CONTINUE(__h, __coro); 176 } 177 await_resumeexec::__at_coro_exit::__task::__final_awaitable178 void await_resume() const noexcept { 179 } 180 }; 181 182 struct __env { 183 const __promise& __promise_; 184 185 [[nodiscard]] queryexec::__at_coro_exit::__task::__env186 auto query(get_scheduler_t) const noexcept -> __any_scheduler { 187 return __promise_.__scheduler_; 188 } 189 }; 190 191 struct __promise : with_awaitable_senders<__promise> { 192 #if STDEXEC_EDG() 193 template <class _Action> __promiseexec::__at_coro_exit::__task::__promise194 __promise(_Action&&, _Ts&&... __ts) noexcept 195 : __args_{__ts...} { 196 } 197 #else 198 template <class _Action> 199 explicit __promise(_Action&&, _Ts&... __ts) noexcept 200 : __args_{__ts...} { 201 } 202 #endif 203 initial_suspendexec::__at_coro_exit::__task::__promise204 auto initial_suspend() noexcept -> __coro::suspend_always { 205 return {}; 206 } 207 final_suspendexec::__at_coro_exit::__task::__promise208 auto final_suspend() noexcept -> __final_awaitable { 209 return {}; 210 } 211 return_voidexec::__at_coro_exit::__task::__promise212 void return_void() noexcept { 213 } 214 215 [[noreturn]] unhandled_exceptionexec::__at_coro_exit::__task::__promise216 void unhandled_exception() noexcept { 217 std::terminate(); 218 } 219 unhandled_stoppedexec::__at_coro_exit::__task::__promise220 auto unhandled_stopped() noexcept -> __coro::coroutine_handle<__promise> { 221 __is_unhandled_stopped_ = true; 222 return __coro::coroutine_handle<__promise>::from_promise(*this); 223 } 224 get_return_objectexec::__at_coro_exit::__task::__promise225 auto get_return_object() noexcept -> __task { 226 return __task(__coro::coroutine_handle<__promise>::from_promise(*this)); 227 } 228 229 template <class _Awaitable> await_transformexec::__at_coro_exit::__task::__promise230 auto await_transform(_Awaitable&& __awaitable) noexcept -> decltype(auto) { 231 return as_awaitable(__die_on_stop(static_cast<_Awaitable&&>(__awaitable)), *this); 232 } 233 get_envexec::__at_coro_exit::__task::__promise234 auto get_env() const noexcept -> __env { 235 return {*this}; 236 } 237 238 bool __is_unhandled_stopped_{false}; 239 std::tuple<_Ts&...> __args_{}; 240 __any_scheduler __scheduler_{stdexec::inline_scheduler{}}; 241 }; 242 243 __coro::coroutine_handle<__promise> __coro_; 244 }; 245 246 struct __at_coro_exit_t { 247 private: 248 template <class _Action, class... _Ts> __implexec::__at_coro_exit::__at_coro_exit_t249 static auto __impl(_Action __action, _Ts... __ts) -> __task<_Ts...> { 250 co_await static_cast<_Action&&>(__action)(static_cast<_Ts&&>(__ts)...); 251 } 252 253 public: 254 template <class _Action, class... _Ts> 255 requires __callable<__decay_t<_Action>, __decay_t<_Ts>...> operator ()exec::__at_coro_exit::__at_coro_exit_t256 auto operator()(_Action&& __action, _Ts&&... __ts) const -> __task<_Ts...> { 257 return __impl(static_cast<_Action&&>(__action), static_cast<_Ts&&>(__ts)...); 258 } 259 }; 260 } // namespace __at_coro_exit 261 262 inline constexpr __at_coro_exit::__at_coro_exit_t at_coroutine_exit{}; 263 } // namespace exec 264