1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * Copyright (c) 2021-2022 NVIDIA Corporation
4  *
5  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
6  * (the "License"); you may not use this file except in compliance with
7  * the License. You may obtain a copy of the License at
8  *
9  *   https://llvm.org/LICENSE.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 #pragma once
18 
19 // The original idea is taken from libunifex and adapted to stdexec.
20 
21 #include "../stdexec/execution.hpp"
22 #include "any_sender_of.hpp"
23 #include "inline_scheduler.hpp"
24 
25 #include <exception>
26 #include <type_traits>
27 
28 namespace exec
29 {
30 namespace __at_coro_exit
31 {
32 using namespace stdexec;
33 
34 using __any_scheduler =                         //
35     any_receiver_ref<                           //
36         completion_signatures<set_error_t(std::exception_ptr),
37                               set_stopped_t()>> //
38     ::any_sender<>::any_scheduler<>;
39 
40 struct __die_on_stop_t
41 {
42     template <class _Receiver>
43     struct __receiver_id
44     {
45         struct __t
46         {
47             using receiver_concept = stdexec::receiver_t;
48             using __id = __receiver_id;
49             _Receiver __receiver_;
50 
51             template <__one_of<set_value_t, set_error_t> _Tag,
52                       __decays_to<__t> _Self, class... _Args>
53                 requires __callable<_Tag, _Receiver, _Args...>
tag_invokeexec::__at_coro_exit::__die_on_stop_t::__receiver_id54             friend void tag_invoke(_Tag, _Self&& __self,
55                                    _Args&&... __args) noexcept
56             {
57                 _Tag{}(static_cast<_Receiver&&>(__self.__receiver_),
58                        static_cast<_Args&&>(__args)...);
59             }
60 
61             template <same_as<set_stopped_t> _Tag>
tag_invokeexec::__at_coro_exit::__die_on_stop_t::__receiver_id62             [[noreturn]] friend void tag_invoke(_Tag, __t&&) noexcept
63             {
64                 std::terminate();
65             }
66 
tag_invokeexec::__at_coro_exit::__die_on_stop_t::__receiver_id67             friend auto tag_invoke(get_env_t, const __t& __self) noexcept
68                 -> env_of_t<_Receiver>
69             {
70                 return get_env(__self.__receiver_);
71             }
72         };
73     };
74     template <class _Rec>
75     using __receiver = __t<__receiver_id<_Rec>>;
76 
77     template <class _Sender>
78     struct __sender_id
79     {
80         template <class _Env>
81         using __completion_signatures = //
82             __mapply<__remove<set_stopped_t(), __q<completion_signatures>>,
83                      completion_signatures_of_t<_Sender, _Env>>;
84 
85         struct __t
86         {
87             using __id = __sender_id;
88             using sender_concept = stdexec::sender_t;
89 
90             _Sender __sender_;
91 
92             template <receiver _Receiver>
93                 requires sender_to<_Sender, __receiver<_Receiver>>
tag_invokeexec::__at_coro_exit::__die_on_stop_t::__sender_id94             friend auto tag_invoke(connect_t, __t&& __self,
95                                    _Receiver&& __rcvr) noexcept
96                 -> connect_result_t<_Sender, __receiver<_Receiver>>
97             {
98                 return stdexec::connect(
99                     static_cast<_Sender&&>(__self.__sender_),
100                     __receiver<_Receiver>{static_cast<_Receiver&&>(__rcvr)});
101             }
102 
103             template <__decays_to<__t> _Self, class _Env>
tag_invokeexec::__at_coro_exit::__die_on_stop_t::__sender_id104             friend auto tag_invoke(get_completion_signatures_t, _Self&&, _Env&&)
105                 -> __completion_signatures<_Env>
106             {
107                 return {};
108             }
109 
tag_invokeexec::__at_coro_exit::__die_on_stop_t::__sender_id110             friend auto tag_invoke(get_env_t, const __t& __self) noexcept
111                 -> env_of_t<_Sender>
112             {
113                 return get_env(__self.__sender_);
114             }
115         };
116     };
117     template <class _Sender>
118     using __sender = __t<__sender_id<__decay_t<_Sender>>>;
119 
120     template <sender _Sender>
operator ()exec::__at_coro_exit::__die_on_stop_t121     auto operator()(_Sender&& __sndr) const
122         noexcept(__nothrow_decay_copyable<_Sender>) -> __sender<_Sender>
123     {
124         return __sender<_Sender>{static_cast<_Sender&&>(__sndr)};
125     }
126 
127     template <class _Value>
operator ()exec::__at_coro_exit::__die_on_stop_t128     auto operator()(_Value&& __value) const noexcept -> _Value&&
129     {
130         return static_cast<_Value&&>(__value);
131     }
132 };
133 
134 inline constexpr __die_on_stop_t __die_on_stop;
135 
136 template <class _Promise>
137 concept __has_continuation = //
138     requires(_Promise& __promise, __continuation_handle<> __c) {
139         {
140             __promise.continuation()
141         } -> convertible_to<__continuation_handle<>>;
142         {
143             __promise.set_continuation(__c)
144         };
145     };
146 
147 template <class... _Ts>
148 class [[nodiscard]] __task
149 {
150     struct __promise;
151 
152   public:
153     using promise_type = __promise;
154 
__task(__coro::coroutine_handle<__promise> __coro)155     explicit __task(__coro::coroutine_handle<__promise> __coro) noexcept :
156         __coro_(__coro)
157     {}
158 
__task(__task && __that)159     __task(__task&& __that) noexcept :
160         __coro_(std::exchange(__that.__coro_, {}))
161     {}
162 
await_ready() const163     [[nodiscard]] auto await_ready() const noexcept -> bool
164     {
165         return false;
166     }
167 
168     template <__has_continuation _Promise>
await_suspend(__coro::coroutine_handle<_Promise> __parent)169     auto await_suspend(__coro::coroutine_handle<_Promise> __parent) noexcept
170         -> bool
171     {
172         __coro_.promise().__scheduler_ =
173             get_scheduler(get_env(__parent.promise()));
174         __coro_.promise().set_continuation(__parent.promise().continuation());
175         __parent.promise().set_continuation(__coro_);
176         return false;
177     }
178 
await_resume()179     auto await_resume() noexcept -> std::tuple<_Ts&...>
180     {
181         return std::exchange(__coro_, {}).promise().__args_;
182     }
183 
184   private:
185     struct __final_awaitable
186     {
await_readyexec::__at_coro_exit::__task::__final_awaitable187         static constexpr auto await_ready() noexcept -> bool
188         {
189             return false;
190         }
191 
192         static auto
await_suspendexec::__at_coro_exit::__task::__final_awaitable193             await_suspend(__coro::coroutine_handle<__promise> __h) noexcept
194             -> __coro::coroutine_handle<>
195         {
196             __promise& __p = __h.promise();
197             auto __coro = __p.__is_unhandled_stopped_
198                               ? __p.continuation().unhandled_stopped()
199                               : __p.continuation().handle();
200             return STDEXEC_DESTROY_AND_CONTINUE(__h, __coro);
201         }
202 
await_resumeexec::__at_coro_exit::__task::__final_awaitable203         void await_resume() const noexcept {}
204     };
205 
206     struct __env
207     {
208         const __promise& __promise_;
209 
tag_invoke(get_scheduler_t,__env __self)210         friend auto tag_invoke(get_scheduler_t, __env __self) noexcept
211             -> __any_scheduler
212         {
213             return __self.__promise_.__scheduler_;
214         }
215     };
216 
217     struct __promise : with_awaitable_senders<__promise>
218     {
219         template <class _Action>
__promiseexec::__at_coro_exit::__task::__promise220         explicit __promise(_Action&&, _Ts&... __ts) noexcept : __args_{__ts...}
221         {}
222 
initial_suspendexec::__at_coro_exit::__task::__promise223         auto initial_suspend() noexcept -> __coro::suspend_always
224         {
225             return {};
226         }
227 
final_suspendexec::__at_coro_exit::__task::__promise228         auto final_suspend() noexcept -> __final_awaitable
229         {
230             return {};
231         }
232 
return_voidexec::__at_coro_exit::__task::__promise233         void return_void() noexcept {}
234 
unhandled_exceptionexec::__at_coro_exit::__task::__promise235         [[noreturn]] void unhandled_exception() noexcept
236         {
237             std::terminate();
238         }
239 
unhandled_stoppedexec::__at_coro_exit::__task::__promise240         auto unhandled_stopped() noexcept -> __coro::coroutine_handle<__promise>
241         {
242             __is_unhandled_stopped_ = true;
243             return __coro::coroutine_handle<__promise>::from_promise(*this);
244         }
245 
get_return_objectexec::__at_coro_exit::__task::__promise246         auto get_return_object() noexcept -> __task
247         {
248             return __task(
249                 __coro::coroutine_handle<__promise>::from_promise(*this));
250         }
251 
252         template <class _Awaitable>
await_transformexec::__at_coro_exit::__task::__promise253         auto await_transform(_Awaitable&& __awaitable) noexcept
254             -> decltype(auto)
255         {
256             return as_awaitable(
257                 __die_on_stop(static_cast<_Awaitable&&>(__awaitable)), *this);
258         }
259 
tag_invoke(get_env_t,const __promise & __self)260         friend auto tag_invoke(get_env_t, const __promise& __self) noexcept
261             -> __env
262         {
263             return {__self};
264         }
265 
266         bool __is_unhandled_stopped_{false};
267         std::tuple<_Ts&...> __args_{};
268         __any_scheduler __scheduler_{inline_scheduler{}};
269     };
270 
271     __coro::coroutine_handle<__promise> __coro_;
272 };
273 
274 struct __at_coro_exit_t
275 {
276   private:
277     template <class _Action, class... _Ts>
__implexec::__at_coro_exit::__at_coro_exit_t278     static auto __impl(_Action __action, _Ts... __ts) -> __task<_Ts...>
279     {
280         co_await static_cast<_Action&&>(__action)(static_cast<_Ts&&>(__ts)...);
281     }
282 
283   public:
284     template <class _Action, class... _Ts>
285         requires __callable<__decay_t<_Action>, __decay_t<_Ts>...>
operator ()exec::__at_coro_exit::__at_coro_exit_t286     auto operator()(_Action&& __action, _Ts&&... __ts) const -> __task<_Ts...>
287     {
288         return __impl(static_cast<_Action&&>(__action),
289                       static_cast<_Ts&&>(__ts)...);
290     }
291 };
292 } // namespace __at_coro_exit
293 
294 inline constexpr __at_coro_exit::__at_coro_exit_t at_coroutine_exit{};
295 } // namespace exec
296