1 /* 2 * Copyright (c) Facebook, Inc. and its affiliates. 3 * Copyright (c) 2021-2024 NVIDIA Corporation 4 * 5 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * https://llvm.org/LICENSE.txt 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 #pragma once 18 19 // The original idea is taken from libunifex and adapted to stdexec. 20 21 #include "../stdexec/execution.hpp" 22 #include "any_sender_of.hpp" 23 #include "inline_scheduler.hpp" 24 25 #include <exception> 26 #include <type_traits> 27 28 namespace exec 29 { 30 namespace __at_coro_exit 31 { 32 using namespace stdexec; 33 34 using __any_scheduler = // 35 any_receiver_ref< // 36 completion_signatures<set_error_t(std::exception_ptr), 37 set_stopped_t()>> // 38 ::any_sender<>::any_scheduler<>; 39 40 struct __die_on_stop_t 41 { 42 template <class _Receiver> 43 struct __receiver_id 44 { 45 struct __t 46 { 47 using receiver_concept = stdexec::receiver_t; 48 using __id = __receiver_id; 49 _Receiver __receiver_; 50 51 template <class... _Args> 52 requires __callable<set_value_t, _Receiver, _Args...> set_valueexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t53 void set_value(_Args&&... __args) noexcept 54 { 55 stdexec::set_value(static_cast<_Receiver&&>(__receiver_), 56 static_cast<_Args&&>(__args)...); 57 } 58 59 template <class _Error> 60 requires __callable<set_error_t, _Receiver, _Error> set_errorexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t61 void set_error(_Error&& __err) noexcept 62 { 63 stdexec::set_error(static_cast<_Receiver&&>(__receiver_), 64 static_cast<_Error&&>(__err)); 65 } 66 set_stoppedexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t67 [[noreturn]] void set_stopped() noexcept 68 { 69 std::terminate(); 70 } 71 get_envexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t72 auto get_env() const noexcept -> env_of_t<_Receiver> 73 { 74 return stdexec::get_env(__receiver_); 75 } 76 }; 77 }; 78 79 template <class _Rec> 80 using __receiver = __t<__receiver_id<_Rec>>; 81 82 template <class _Sender> 83 struct __sender_id 84 { 85 template <class... _Env> 86 using __completion_signatures = // 87 __mapply<__mremove<set_stopped_t(), __q<completion_signatures>>, 88 __completion_signatures_of_t<_Sender, _Env...>>; 89 90 struct __t 91 { 92 using __id = __sender_id; 93 using sender_concept = stdexec::sender_t; 94 95 _Sender __sender_; 96 97 template <receiver _Receiver> 98 requires sender_to<_Sender, __receiver<_Receiver>> connectexec::__at_coro_exit::__die_on_stop_t::__sender_id::__t99 auto connect(_Receiver __rcvr) && noexcept 100 -> connect_result_t<_Sender, __receiver<_Receiver>> 101 { 102 return stdexec::connect( 103 static_cast<_Sender&&>(__sender_), 104 __receiver<_Receiver>{static_cast<_Receiver&&>(__rcvr)}); 105 } 106 107 template <class... _Env> get_completion_signaturesexec::__at_coro_exit::__die_on_stop_t::__sender_id::__t108 auto get_completion_signatures(_Env&&...) 109 -> __completion_signatures<_Env...> 110 { 111 return {}; 112 } 113 get_envexec::__at_coro_exit::__die_on_stop_t::__sender_id::__t114 auto get_env() const noexcept -> env_of_t<_Sender> 115 { 116 return stdexec::get_env(__sender_); 117 } 118 }; 119 }; 120 template <class _Sender> 121 using __sender = __t<__sender_id<__decay_t<_Sender>>>; 122 123 template <sender _Sender> operator ()exec::__at_coro_exit::__die_on_stop_t124 auto operator()(_Sender&& __sndr) const 125 noexcept(__nothrow_decay_copyable<_Sender>) -> __sender<_Sender> 126 { 127 return __sender<_Sender>{static_cast<_Sender&&>(__sndr)}; 128 } 129 130 template <class _Value> operator ()exec::__at_coro_exit::__die_on_stop_t131 auto operator()(_Value&& __value) const noexcept -> _Value&& 132 { 133 return static_cast<_Value&&>(__value); 134 } 135 }; 136 137 inline constexpr __die_on_stop_t __die_on_stop; 138 139 template <class _Promise> 140 concept __has_continuation = // 141 requires(_Promise& __promise, __continuation_handle<> __c) { 142 { __promise.continuation() } -> convertible_to<__continuation_handle<>>; 143 { __promise.set_continuation(__c) }; 144 }; 145 146 template <class... _Ts> 147 class [[nodiscard]] __task 148 { 149 struct __promise; 150 151 public: 152 using promise_type = __promise; 153 154 #if STDEXEC_EDG() __task(__coro::coroutine_handle<__promise> __coro)155 __task(__coro::coroutine_handle<__promise> __coro) noexcept : 156 __coro_(__coro) 157 {} 158 #else __task(__coro::coroutine_handle<__promise> __coro)159 explicit __task(__coro::coroutine_handle<__promise> __coro) noexcept : 160 __coro_(__coro) 161 {} 162 #endif 163 __task(__task && __that)164 __task(__task&& __that) noexcept : 165 __coro_(std::exchange(__that.__coro_, {})) 166 {} 167 await_ready() const168 [[nodiscard]] auto await_ready() const noexcept -> bool 169 { 170 return false; 171 } 172 173 template <__has_continuation _Promise> await_suspend(__coro::coroutine_handle<_Promise> __parent)174 auto await_suspend(__coro::coroutine_handle<_Promise> __parent) noexcept 175 -> bool 176 { 177 __coro_.promise().__scheduler_ = 178 get_scheduler(get_env(__parent.promise())); 179 __coro_.promise().set_continuation(__parent.promise().continuation()); 180 __parent.promise().set_continuation(__coro_); 181 return false; 182 } 183 await_resume()184 auto await_resume() noexcept -> std::tuple<_Ts&...> 185 { 186 return std::exchange(__coro_, {}).promise().__args_; 187 } 188 189 private: 190 struct __final_awaitable 191 { await_readyexec::__at_coro_exit::__task::__final_awaitable192 static constexpr auto await_ready() noexcept -> bool 193 { 194 return false; 195 } 196 await_suspendexec::__at_coro_exit::__task::__final_awaitable197 static auto await_suspend( 198 __coro::coroutine_handle<__promise> __h) noexcept 199 -> __coro::coroutine_handle<> 200 { 201 __promise& __p = __h.promise(); 202 auto __coro = __p.__is_unhandled_stopped_ 203 ? __p.continuation().unhandled_stopped() 204 : __p.continuation().handle(); 205 return STDEXEC_DESTROY_AND_CONTINUE(__h, __coro); 206 } 207 await_resumeexec::__at_coro_exit::__task::__final_awaitable208 void await_resume() const noexcept {} 209 }; 210 211 struct __env 212 { 213 const __promise& __promise_; 214 queryexec::__at_coro_exit::__task::__env215 auto query(get_scheduler_t) const noexcept -> __any_scheduler 216 { 217 return __promise_.__scheduler_; 218 } 219 }; 220 221 struct __promise : with_awaitable_senders<__promise> 222 { 223 #if STDEXEC_EDG() 224 template <class _Action> __promiseexec::__at_coro_exit::__task::__promise225 __promise(_Action&&, _Ts&&... __ts) noexcept : __args_{__ts...} 226 {} 227 #else 228 template <class _Action> 229 explicit __promise(_Action&&, _Ts&... __ts) noexcept : __args_{__ts...} 230 {} 231 #endif 232 initial_suspendexec::__at_coro_exit::__task::__promise233 auto initial_suspend() noexcept -> __coro::suspend_always 234 { 235 return {}; 236 } 237 final_suspendexec::__at_coro_exit::__task::__promise238 auto final_suspend() noexcept -> __final_awaitable 239 { 240 return {}; 241 } 242 return_voidexec::__at_coro_exit::__task::__promise243 void return_void() noexcept {} 244 unhandled_exceptionexec::__at_coro_exit::__task::__promise245 [[noreturn]] void unhandled_exception() noexcept 246 { 247 std::terminate(); 248 } 249 unhandled_stoppedexec::__at_coro_exit::__task::__promise250 auto unhandled_stopped() noexcept -> __coro::coroutine_handle<__promise> 251 { 252 __is_unhandled_stopped_ = true; 253 return __coro::coroutine_handle<__promise>::from_promise(*this); 254 } 255 get_return_objectexec::__at_coro_exit::__task::__promise256 auto get_return_object() noexcept -> __task 257 { 258 return __task( 259 __coro::coroutine_handle<__promise>::from_promise(*this)); 260 } 261 262 template <class _Awaitable> await_transformexec::__at_coro_exit::__task::__promise263 auto await_transform(_Awaitable&& __awaitable) noexcept 264 -> decltype(auto) 265 { 266 return as_awaitable( 267 __die_on_stop(static_cast<_Awaitable&&>(__awaitable)), *this); 268 } 269 get_envexec::__at_coro_exit::__task::__promise270 auto get_env() const noexcept -> __env 271 { 272 return {*this}; 273 } 274 275 bool __is_unhandled_stopped_{false}; 276 std::tuple<_Ts&...> __args_{}; 277 __any_scheduler __scheduler_{inline_scheduler{}}; 278 }; 279 280 __coro::coroutine_handle<__promise> __coro_; 281 }; 282 283 struct __at_coro_exit_t 284 { 285 private: 286 template <class _Action, class... _Ts> __implexec::__at_coro_exit::__at_coro_exit_t287 static auto __impl(_Action __action, _Ts... __ts) -> __task<_Ts...> 288 { 289 co_await static_cast<_Action&&>(__action)(static_cast<_Ts&&>(__ts)...); 290 } 291 292 public: 293 template <class _Action, class... _Ts> 294 requires __callable<__decay_t<_Action>, __decay_t<_Ts>...> operator ()exec::__at_coro_exit::__at_coro_exit_t295 auto operator()(_Action&& __action, _Ts&&... __ts) const -> __task<_Ts...> 296 { 297 return __impl(static_cast<_Action&&>(__action), 298 static_cast<_Ts&&>(__ts)...); 299 } 300 }; 301 } // namespace __at_coro_exit 302 303 inline constexpr __at_coro_exit::__at_coro_exit_t at_coroutine_exit{}; 304 } // namespace exec 305