1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * Copyright (c) 2021-2024 NVIDIA Corporation
4  *
5  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
6  * (the "License"); you may not use this file except in compliance with
7  * the License. You may obtain a copy of the License at
8  *
9  *   https://llvm.org/LICENSE.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 #pragma once
18 
19 // The original idea is taken from libunifex and adapted to stdexec.
20 
21 #include "../stdexec/execution.hpp"
22 #include "any_sender_of.hpp"
23 #include "inline_scheduler.hpp"
24 
25 #include <exception>
26 #include <type_traits>
27 
28 namespace exec
29 {
30 namespace __at_coro_exit
31 {
32 using namespace stdexec;
33 
34 using __any_scheduler =                         //
35     any_receiver_ref<                           //
36         completion_signatures<set_error_t(std::exception_ptr),
37                               set_stopped_t()>> //
38     ::any_sender<>::any_scheduler<>;
39 
40 struct __die_on_stop_t
41 {
42     template <class _Receiver>
43     struct __receiver_id
44     {
45         struct __t
46         {
47             using receiver_concept = stdexec::receiver_t;
48             using __id = __receiver_id;
49             _Receiver __receiver_;
50 
51             template <class... _Args>
52                 requires __callable<set_value_t, _Receiver, _Args...>
set_valueexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t53             void set_value(_Args&&... __args) noexcept
54             {
55                 stdexec::set_value(static_cast<_Receiver&&>(__receiver_),
56                                    static_cast<_Args&&>(__args)...);
57             }
58 
59             template <class _Error>
60                 requires __callable<set_error_t, _Receiver, _Error>
set_errorexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t61             void set_error(_Error&& __err) noexcept
62             {
63                 stdexec::set_error(static_cast<_Receiver&&>(__receiver_),
64                                    static_cast<_Error&&>(__err));
65             }
66 
set_stoppedexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t67             [[noreturn]] void set_stopped() noexcept
68             {
69                 std::terminate();
70             }
71 
get_envexec::__at_coro_exit::__die_on_stop_t::__receiver_id::__t72             auto get_env() const noexcept -> env_of_t<_Receiver>
73             {
74                 return stdexec::get_env(__receiver_);
75             }
76         };
77     };
78 
79     template <class _Rec>
80     using __receiver = __t<__receiver_id<_Rec>>;
81 
82     template <class _Sender>
83     struct __sender_id
84     {
85         template <class... _Env>
86         using __completion_signatures = //
87             __mapply<__mremove<set_stopped_t(), __q<completion_signatures>>,
88                      __completion_signatures_of_t<_Sender, _Env...>>;
89 
90         struct __t
91         {
92             using __id = __sender_id;
93             using sender_concept = stdexec::sender_t;
94 
95             _Sender __sender_;
96 
97             template <receiver _Receiver>
98                 requires sender_to<_Sender, __receiver<_Receiver>>
connectexec::__at_coro_exit::__die_on_stop_t::__sender_id::__t99             auto connect(_Receiver __rcvr) && noexcept
100                 -> connect_result_t<_Sender, __receiver<_Receiver>>
101             {
102                 return stdexec::connect(
103                     static_cast<_Sender&&>(__sender_),
104                     __receiver<_Receiver>{static_cast<_Receiver&&>(__rcvr)});
105             }
106 
107             template <class... _Env>
get_completion_signaturesexec::__at_coro_exit::__die_on_stop_t::__sender_id::__t108             auto get_completion_signatures(_Env&&...)
109                 -> __completion_signatures<_Env...>
110             {
111                 return {};
112             }
113 
get_envexec::__at_coro_exit::__die_on_stop_t::__sender_id::__t114             auto get_env() const noexcept -> env_of_t<_Sender>
115             {
116                 return stdexec::get_env(__sender_);
117             }
118         };
119     };
120     template <class _Sender>
121     using __sender = __t<__sender_id<__decay_t<_Sender>>>;
122 
123     template <sender _Sender>
operator ()exec::__at_coro_exit::__die_on_stop_t124     auto operator()(_Sender&& __sndr) const
125         noexcept(__nothrow_decay_copyable<_Sender>) -> __sender<_Sender>
126     {
127         return __sender<_Sender>{static_cast<_Sender&&>(__sndr)};
128     }
129 
130     template <class _Value>
operator ()exec::__at_coro_exit::__die_on_stop_t131     auto operator()(_Value&& __value) const noexcept -> _Value&&
132     {
133         return static_cast<_Value&&>(__value);
134     }
135 };
136 
137 inline constexpr __die_on_stop_t __die_on_stop;
138 
139 template <class _Promise>
140 concept __has_continuation = //
141     requires(_Promise& __promise, __continuation_handle<> __c) {
142         { __promise.continuation() } -> convertible_to<__continuation_handle<>>;
143         { __promise.set_continuation(__c) };
144     };
145 
146 template <class... _Ts>
147 class [[nodiscard]] __task
148 {
149     struct __promise;
150 
151   public:
152     using promise_type = __promise;
153 
__task(__coro::coroutine_handle<__promise> __coro)154     explicit __task(__coro::coroutine_handle<__promise> __coro) noexcept :
155         __coro_(__coro)
156     {}
157 
__task(__task && __that)158     __task(__task&& __that) noexcept :
159         __coro_(std::exchange(__that.__coro_, {}))
160     {}
161 
await_ready() const162     [[nodiscard]] auto await_ready() const noexcept -> bool
163     {
164         return false;
165     }
166 
167     template <__has_continuation _Promise>
await_suspend(__coro::coroutine_handle<_Promise> __parent)168     auto await_suspend(__coro::coroutine_handle<_Promise> __parent) noexcept
169         -> bool
170     {
171         __coro_.promise().__scheduler_ =
172             get_scheduler(get_env(__parent.promise()));
173         __coro_.promise().set_continuation(__parent.promise().continuation());
174         __parent.promise().set_continuation(__coro_);
175         return false;
176     }
177 
await_resume()178     auto await_resume() noexcept -> std::tuple<_Ts&...>
179     {
180         return std::exchange(__coro_, {}).promise().__args_;
181     }
182 
183   private:
184     struct __final_awaitable
185     {
await_readyexec::__at_coro_exit::__task::__final_awaitable186         static constexpr auto await_ready() noexcept -> bool
187         {
188             return false;
189         }
190 
191         static auto
await_suspendexec::__at_coro_exit::__task::__final_awaitable192             await_suspend(__coro::coroutine_handle<__promise> __h) noexcept
193             -> __coro::coroutine_handle<>
194         {
195             __promise& __p = __h.promise();
196             auto __coro = __p.__is_unhandled_stopped_
197                               ? __p.continuation().unhandled_stopped()
198                               : __p.continuation().handle();
199             return STDEXEC_DESTROY_AND_CONTINUE(__h, __coro);
200         }
201 
await_resumeexec::__at_coro_exit::__task::__final_awaitable202         void await_resume() const noexcept {}
203     };
204 
205     struct __env
206     {
207         const __promise& __promise_;
208 
queryexec::__at_coro_exit::__task::__env209         auto query(get_scheduler_t) const noexcept -> __any_scheduler
210         {
211             return __promise_.__scheduler_;
212         }
213     };
214 
215     struct __promise : with_awaitable_senders<__promise>
216     {
217         template <class _Action>
__promiseexec::__at_coro_exit::__task::__promise218         explicit __promise(_Action&&, _Ts&... __ts) noexcept : __args_{__ts...}
219         {}
220 
initial_suspendexec::__at_coro_exit::__task::__promise221         auto initial_suspend() noexcept -> __coro::suspend_always
222         {
223             return {};
224         }
225 
final_suspendexec::__at_coro_exit::__task::__promise226         auto final_suspend() noexcept -> __final_awaitable
227         {
228             return {};
229         }
230 
return_voidexec::__at_coro_exit::__task::__promise231         void return_void() noexcept {}
232 
unhandled_exceptionexec::__at_coro_exit::__task::__promise233         [[noreturn]] void unhandled_exception() noexcept
234         {
235             std::terminate();
236         }
237 
unhandled_stoppedexec::__at_coro_exit::__task::__promise238         auto unhandled_stopped() noexcept -> __coro::coroutine_handle<__promise>
239         {
240             __is_unhandled_stopped_ = true;
241             return __coro::coroutine_handle<__promise>::from_promise(*this);
242         }
243 
get_return_objectexec::__at_coro_exit::__task::__promise244         auto get_return_object() noexcept -> __task
245         {
246             return __task(
247                 __coro::coroutine_handle<__promise>::from_promise(*this));
248         }
249 
250         template <class _Awaitable>
251         auto
await_transformexec::__at_coro_exit::__task::__promise252             await_transform(_Awaitable&& __awaitable) noexcept -> decltype(auto)
253         {
254             return as_awaitable(
255                 __die_on_stop(static_cast<_Awaitable&&>(__awaitable)), *this);
256         }
257 
get_envexec::__at_coro_exit::__task::__promise258         auto get_env() const noexcept -> __env
259         {
260             return {*this};
261         }
262 
263         bool __is_unhandled_stopped_{false};
264         std::tuple<_Ts&...> __args_{};
265         __any_scheduler __scheduler_{inline_scheduler{}};
266     };
267 
268     __coro::coroutine_handle<__promise> __coro_;
269 };
270 
271 struct __at_coro_exit_t
272 {
273   private:
274     template <class _Action, class... _Ts>
__implexec::__at_coro_exit::__at_coro_exit_t275     static auto __impl(_Action __action, _Ts... __ts) -> __task<_Ts...>
276     {
277         co_await static_cast<_Action&&>(__action)(static_cast<_Ts&&>(__ts)...);
278     }
279 
280   public:
281     template <class _Action, class... _Ts>
282         requires __callable<__decay_t<_Action>, __decay_t<_Ts>...>
operator ()exec::__at_coro_exit::__at_coro_exit_t283     auto operator()(_Action&& __action, _Ts&&... __ts) const -> __task<_Ts...>
284     {
285         return __impl(static_cast<_Action&&>(__action),
286                       static_cast<_Ts&&>(__ts)...);
287     }
288 };
289 } // namespace __at_coro_exit
290 
291 inline constexpr __at_coro_exit::__at_coro_exit_t at_coroutine_exit{};
292 } // namespace exec
293