1 /* 2 * Copyright (c) Facebook, Inc. and its affiliates. 3 * Copyright (c) 2021-2024 NVIDIA Corporation 4 * 5 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * https://llvm.org/LICENSE.txt 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 #pragma once 18 19 // The original idea is taken from libunifex and adapted to stdexec. 20 21 #include "../stdexec/execution.hpp" 22 #include "any_sender_of.hpp" 23 #include "inline_scheduler.hpp" 24 25 #include <exception> 26 #include <type_traits> 27 28 namespace exec 29 { 30 namespace __at_coro_exit 31 { 32 using namespace stdexec; 33 34 using __any_scheduler = // 35 any_receiver_ref< // 36 completion_signatures<set_error_t(std::exception_ptr), 37 set_stopped_t()>> // 38 ::any_sender<>::any_scheduler<>; 39 40 struct __die_on_stop_t 41 { 42 template <class _Receiver> 43 struct __receiver_id 44 { 45 struct __t 46 { 47 using receiver_concept = stdexec::receiver_t; 48 using __id = __receiver_id; 49 _Receiver __receiver_; 50 51 template <class... _Args> 52 requires __callable<set_value_t, _Receiver, _Args...> set_valueexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t53 void set_value(_Args&&... __args) noexcept 54 { 55 stdexec::set_value(static_cast<_Receiver&&>(__receiver_), 56 static_cast<_Args&&>(__args)...); 57 } 58 59 template <class _Error> 60 requires __callable<set_error_t, _Receiver, _Error> set_errorexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t61 void set_error(_Error&& __err) noexcept 62 { 63 stdexec::set_error(static_cast<_Receiver&&>(__receiver_), 64 static_cast<_Error&&>(__err)); 65 } 66 set_stoppedexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t67 [[noreturn]] void set_stopped() noexcept 68 { 69 std::terminate(); 70 } 71 get_envexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t72 auto get_env() const noexcept -> env_of_t<_Receiver> 73 { 74 return stdexec::get_env(__receiver_); 75 } 76 }; 77 }; 78 79 template <class _Rec> 80 using __receiver = __t<__receiver_id<_Rec>>; 81 82 template <class _Sender> 83 struct __sender_id 84 { 85 template <class... _Env> 86 using __completion_signatures = // 87 __mapply<__mremove<set_stopped_t(), __q<completion_signatures>>, 88 __completion_signatures_of_t<_Sender, _Env...>>; 89 90 struct __t 91 { 92 using __id = __sender_id; 93 using sender_concept = stdexec::sender_t; 94 95 _Sender __sender_; 96 97 template <receiver _Receiver> 98 requires sender_to<_Sender, __receiver<_Receiver>> connectexec::__at_coro_exit::__die_on_stop_t::__sender_id::__t99 auto connect(_Receiver __rcvr) && noexcept 100 -> connect_result_t<_Sender, __receiver<_Receiver>> 101 { 102 return stdexec::connect( 103 static_cast<_Sender&&>(__sender_), 104 __receiver<_Receiver>{static_cast<_Receiver&&>(__rcvr)}); 105 } 106 107 template <class... _Env> get_completion_signaturesexec::__at_coro_exit::__die_on_stop_t::__sender_id::__t108 auto get_completion_signatures(_Env&&...) 109 -> __completion_signatures<_Env...> 110 { 111 return {}; 112 } 113 get_envexec::__at_coro_exit::__die_on_stop_t::__sender_id::__t114 auto get_env() const noexcept -> env_of_t<_Sender> 115 { 116 return stdexec::get_env(__sender_); 117 } 118 }; 119 }; 120 template <class _Sender> 121 using __sender = __t<__sender_id<__decay_t<_Sender>>>; 122 123 template <sender _Sender> operator ()exec::__at_coro_exit::__die_on_stop_t124 auto operator()(_Sender&& __sndr) const 125 noexcept(__nothrow_decay_copyable<_Sender>) -> __sender<_Sender> 126 { 127 return __sender<_Sender>{static_cast<_Sender&&>(__sndr)}; 128 } 129 130 template <class _Value> operator ()exec::__at_coro_exit::__die_on_stop_t131 auto operator()(_Value&& __value) const noexcept -> _Value&& 132 { 133 return static_cast<_Value&&>(__value); 134 } 135 }; 136 137 inline constexpr __die_on_stop_t __die_on_stop; 138 139 template <class _Promise> 140 concept __has_continuation = // 141 requires(_Promise& __promise, __continuation_handle<> __c) { 142 { __promise.continuation() } -> convertible_to<__continuation_handle<>>; 143 { __promise.set_continuation(__c) }; 144 }; 145 146 template <class... _Ts> 147 class [[nodiscard]] __task 148 { 149 struct __promise; 150 151 public: 152 using promise_type = __promise; 153 __task(__coro::coroutine_handle<__promise> __coro)154 explicit __task(__coro::coroutine_handle<__promise> __coro) noexcept : 155 __coro_(__coro) 156 {} 157 __task(__task && __that)158 __task(__task&& __that) noexcept : 159 __coro_(std::exchange(__that.__coro_, {})) 160 {} 161 await_ready() const162 [[nodiscard]] auto await_ready() const noexcept -> bool 163 { 164 return false; 165 } 166 167 template <__has_continuation _Promise> await_suspend(__coro::coroutine_handle<_Promise> __parent)168 auto await_suspend(__coro::coroutine_handle<_Promise> __parent) noexcept 169 -> bool 170 { 171 __coro_.promise().__scheduler_ = 172 get_scheduler(get_env(__parent.promise())); 173 __coro_.promise().set_continuation(__parent.promise().continuation()); 174 __parent.promise().set_continuation(__coro_); 175 return false; 176 } 177 await_resume()178 auto await_resume() noexcept -> std::tuple<_Ts&...> 179 { 180 return std::exchange(__coro_, {}).promise().__args_; 181 } 182 183 private: 184 struct __final_awaitable 185 { await_readyexec::__at_coro_exit::__task::__final_awaitable186 static constexpr auto await_ready() noexcept -> bool 187 { 188 return false; 189 } 190 191 static auto await_suspendexec::__at_coro_exit::__task::__final_awaitable192 await_suspend(__coro::coroutine_handle<__promise> __h) noexcept 193 -> __coro::coroutine_handle<> 194 { 195 __promise& __p = __h.promise(); 196 auto __coro = __p.__is_unhandled_stopped_ 197 ? __p.continuation().unhandled_stopped() 198 : __p.continuation().handle(); 199 return STDEXEC_DESTROY_AND_CONTINUE(__h, __coro); 200 } 201 await_resumeexec::__at_coro_exit::__task::__final_awaitable202 void await_resume() const noexcept {} 203 }; 204 205 struct __env 206 { 207 const __promise& __promise_; 208 queryexec::__at_coro_exit::__task::__env209 auto query(get_scheduler_t) const noexcept -> __any_scheduler 210 { 211 return __promise_.__scheduler_; 212 } 213 }; 214 215 struct __promise : with_awaitable_senders<__promise> 216 { 217 template <class _Action> __promiseexec::__at_coro_exit::__task::__promise218 explicit __promise(_Action&&, _Ts&... __ts) noexcept : __args_{__ts...} 219 {} 220 initial_suspendexec::__at_coro_exit::__task::__promise221 auto initial_suspend() noexcept -> __coro::suspend_always 222 { 223 return {}; 224 } 225 final_suspendexec::__at_coro_exit::__task::__promise226 auto final_suspend() noexcept -> __final_awaitable 227 { 228 return {}; 229 } 230 return_voidexec::__at_coro_exit::__task::__promise231 void return_void() noexcept {} 232 unhandled_exceptionexec::__at_coro_exit::__task::__promise233 [[noreturn]] void unhandled_exception() noexcept 234 { 235 std::terminate(); 236 } 237 unhandled_stoppedexec::__at_coro_exit::__task::__promise238 auto unhandled_stopped() noexcept -> __coro::coroutine_handle<__promise> 239 { 240 __is_unhandled_stopped_ = true; 241 return __coro::coroutine_handle<__promise>::from_promise(*this); 242 } 243 get_return_objectexec::__at_coro_exit::__task::__promise244 auto get_return_object() noexcept -> __task 245 { 246 return __task( 247 __coro::coroutine_handle<__promise>::from_promise(*this)); 248 } 249 250 template <class _Awaitable> 251 auto await_transformexec::__at_coro_exit::__task::__promise252 await_transform(_Awaitable&& __awaitable) noexcept -> decltype(auto) 253 { 254 return as_awaitable( 255 __die_on_stop(static_cast<_Awaitable&&>(__awaitable)), *this); 256 } 257 get_envexec::__at_coro_exit::__task::__promise258 auto get_env() const noexcept -> __env 259 { 260 return {*this}; 261 } 262 263 bool __is_unhandled_stopped_{false}; 264 std::tuple<_Ts&...> __args_{}; 265 __any_scheduler __scheduler_{inline_scheduler{}}; 266 }; 267 268 __coro::coroutine_handle<__promise> __coro_; 269 }; 270 271 struct __at_coro_exit_t 272 { 273 private: 274 template <class _Action, class... _Ts> __implexec::__at_coro_exit::__at_coro_exit_t275 static auto __impl(_Action __action, _Ts... __ts) -> __task<_Ts...> 276 { 277 co_await static_cast<_Action&&>(__action)(static_cast<_Ts&&>(__ts)...); 278 } 279 280 public: 281 template <class _Action, class... _Ts> 282 requires __callable<__decay_t<_Action>, __decay_t<_Ts>...> operator ()exec::__at_coro_exit::__at_coro_exit_t283 auto operator()(_Action&& __action, _Ts&&... __ts) const -> __task<_Ts...> 284 { 285 return __impl(static_cast<_Action&&>(__action), 286 static_cast<_Ts&&>(__ts)...); 287 } 288 }; 289 } // namespace __at_coro_exit 290 291 inline constexpr __at_coro_exit::__at_coro_exit_t at_coroutine_exit{}; 292 } // namespace exec 293