1 /*
2  * Copyright (c) 2021-2024 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__awaitable.hpp"
19 #include "__completion_signatures.hpp"
20 #include "__concepts.hpp"
21 #include "__config.hpp"
22 #include "__execution_fwd.hpp"
23 #include "__meta.hpp"
24 #include "__receivers.hpp"
25 #include "__tag_invoke.hpp"
26 
27 #include <exception>
28 #include <utility>
29 
30 namespace stdexec
31 {
32 #if !STDEXEC_STD_NO_COROUTINES()
33 /////////////////////////////////////////////////////////////////////////////
34 // __connect_awaitable_
35 namespace __connect_awaitable_
36 {
37 struct __promise_base
38 {
initial_suspendstdexec::__connect_awaitable_::__promise_base39     auto initial_suspend() noexcept -> __coro::suspend_always
40     {
41         return {};
42     }
43 
final_suspendstdexec::__connect_awaitable_::__promise_base44     [[noreturn]] auto final_suspend() noexcept -> __coro::suspend_always
45     {
46         std::terminate();
47     }
48 
unhandled_exceptionstdexec::__connect_awaitable_::__promise_base49     [[noreturn]] void unhandled_exception() noexcept
50     {
51         std::terminate();
52     }
53 
return_voidstdexec::__connect_awaitable_::__promise_base54     [[noreturn]] void return_void() noexcept
55     {
56         std::terminate();
57     }
58 };
59 
60 struct __operation_base
61 {
62     __coro::coroutine_handle<> __coro_;
63 
__operation_basestdexec::__connect_awaitable_::__operation_base64     explicit __operation_base(__coro::coroutine_handle<> __hcoro) noexcept :
65         __coro_(__hcoro)
66     {}
67 
__operation_basestdexec::__connect_awaitable_::__operation_base68     __operation_base(__operation_base&& __other) noexcept :
69         __coro_(std::exchange(__other.__coro_, {}))
70     {}
71 
~__operation_basestdexec::__connect_awaitable_::__operation_base72     ~__operation_base()
73     {
74         if (__coro_)
75         {
76 #if STDEXEC_MSVC()
77             // MSVCBUG
78             // https://developercommunity.visualstudio.com/t/Double-destroy-of-a-local-in-coroutine-d/10456428
79 
80             // Reassign __coro_ before calling destroy to make the mutation
81             // observable and to hopefully ensure that the compiler does not
82             // eliminate it.
83             auto __coro = __coro_;
84             __coro_ = {};
85             __coro.destroy();
86 #else
87             __coro_.destroy();
88 #endif
89         }
90     }
91 
startstdexec::__connect_awaitable_::__operation_base92     void start() & noexcept
93     {
94         __coro_.resume();
95     }
96 };
97 
98 template <class _ReceiverId>
99 struct __promise;
100 
101 template <class _ReceiverId>
102 struct __operation
103 {
104     struct __t : __operation_base
105     {
106         using promise_type = stdexec::__t<__promise<_ReceiverId>>;
107         using __operation_base::__operation_base;
108     };
109 };
110 
111 template <class _ReceiverId>
112 struct __promise
113 {
114     using _Receiver = stdexec::__t<_ReceiverId>;
115 
116     struct __t : __promise_base
117     {
118         using __id = __promise;
119 
__tstdexec::__connect_awaitable_::__promise::__t120         explicit __t(auto&, _Receiver& __rcvr) noexcept : __rcvr_(__rcvr) {}
121 
unhandled_stoppedstdexec::__connect_awaitable_::__promise::__t122         auto unhandled_stopped() noexcept -> __coro::coroutine_handle<>
123         {
124             stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr_));
125             // Returning noop_coroutine here causes the __connect_awaitable
126             // coroutine to never resume past the point where it co_await's
127             // the awaitable.
128             return __coro::noop_coroutine();
129         }
130 
get_return_objectstdexec::__connect_awaitable_::__promise::__t131         auto get_return_object() noexcept
132             -> stdexec::__t<__operation<_ReceiverId>>
133         {
134             return stdexec::__t<__operation<_ReceiverId>>{
135                 __coro::coroutine_handle<__t>::from_promise(*this)};
136         }
137 
138         template <class _Awaitable>
await_transformstdexec::__connect_awaitable_::__promise::__t139         auto await_transform(_Awaitable&& __awaitable) noexcept -> _Awaitable&&
140         {
141             return static_cast<_Awaitable&&>(__awaitable);
142         }
143 
144         template <class _Awaitable>
145             requires tag_invocable<as_awaitable_t, _Awaitable, __t&>
await_transformstdexec::__connect_awaitable_::__promise::__t146         auto await_transform(_Awaitable&& __awaitable) //
147             noexcept(nothrow_tag_invocable<as_awaitable_t, _Awaitable, __t&>)
148                 -> tag_invoke_result_t<as_awaitable_t, _Awaitable, __t&>
149         {
150             return tag_invoke(as_awaitable,
151                               static_cast<_Awaitable&&>(__awaitable), *this);
152         }
153 
154         // Pass through the get_env receiver query
get_envstdexec::__connect_awaitable_::__promise::__t155         auto get_env() const noexcept -> env_of_t<_Receiver>
156         {
157             return stdexec::get_env(__rcvr_);
158         }
159 
160         _Receiver& __rcvr_;
161     };
162 };
163 
164 template <receiver _Receiver>
165 using __promise_t = __t<__promise<__id<_Receiver>>>;
166 
167 template <receiver _Receiver>
168 using __operation_t = __t<__operation<__id<_Receiver>>>;
169 
// Function-object type behind __connect_awaitable. It turns an awaitable and
// a receiver into an operation state by running the awaitable inside a
// dedicated coroutine whose promise type is __promise_t<_Receiver>.
//
// The coroutine (__co_impl) is created suspended and begins running when
// the returned operation state is started. After awaiting the awaitable it
// completes the receiver with exactly one of:
//   - set_value(rcvr)          when the await result type is void,
//   - set_value(rcvr, result)  otherwise,
//   - set_error(rcvr, eptr)    when awaiting the awaitable throws.
// A stopped completion is delivered separately, through the promise's
// unhandled_stopped(), if an awaited object invokes it.
// Each completion is issued from inside an await_suspend (see __co_call).
// The coroutine then stays suspended for good, and its frame is released by
// the operation state's destructor.
struct __connect_awaitable_t
{
  private:
    // Builds an awaiter that, when co_await-ed, suspends the calling
    // coroutine and then invokes __fun(__as...) from await_suspend.
    // await_ready() is always false and await_resume() terminates, so the
    // awaiting coroutine must never be resumed afterwards.
    //
    // The arguments are captured by reference, so they must stay alive until
    // the returned awaiter has been co_await-ed. At the call sites in
    // __co_impl they are objects in the coroutine frame (the receiver
    // parameter, the local exception_ptr) or the result of a co_await in the
    // same full-expression.
    template <class _Fun, class... _Ts>
    static auto __co_call(_Fun __fun, _Ts&&... __as) noexcept
    {
        // noexcept: if the completion function (e.g. set_value) throws,
        // std::terminate is called instead of the exception escaping.
        auto __fn = [&, __fun]() noexcept {
            __fun(static_cast<_Ts&&>(__as)...);
        };

        struct __awaiter
        {
            decltype(__fn) __fn_;

            static constexpr auto await_ready() noexcept -> bool
            {
                return false;
            }

            // Returns void, so the coroutine remains suspended after the call.
            void await_suspend(__coro::coroutine_handle<>) noexcept
            {
                __fn_();
            }

            [[noreturn]] void await_resume() noexcept
            {
                std::terminate();
            }
        };

        return __awaiter{__fn};
    }

    // The coroutine body. Both parameters are taken by value, so they live in
    // the coroutine frame. The promise constructor is handed references to
    // them and keeps the one to the receiver.
    template <class _Awaitable, class _Receiver>
#if STDEXEC_GCC() && (__GNUC__ > 11)
    // NOTE(review): presumably a workaround for a GCC >= 12 issue with this
    // coroutine template; the reason is not visible in this file.
    __attribute__((__used__))
#endif
    static auto
        __co_impl(_Awaitable __awaitable,
                  _Receiver __rcvr) -> __operation_t<_Receiver>
    {
        using __result_t = __await_result_t<_Awaitable, __promise_t<_Receiver>>;
        std::exception_ptr __eptr;
        try
        {
            // Await the awaitable, then complete the receiver with its result.
            // The __co_call awaiter is never resumed, so control leaves this
            // block only by an exception.
            if constexpr (same_as<__result_t, void>)
                // void result: the built-in comma operator evaluates the
                // inner co_await first, then builds the set_value(rcvr)
                // awaiter that the outer co_await suspends on.
                co_await (co_await static_cast<_Awaitable&&>(__awaitable),
                          __co_call(set_value,
                                    static_cast<_Receiver&&>(__rcvr)));
            else
                co_await __co_call(
                    set_value, static_cast<_Receiver&&>(__rcvr),
                    co_await static_cast<_Awaitable&&>(__awaitable));
        }
        catch (...)
        {
            __eptr = std::current_exception();
        }
        // Only reached when the try block threw: report it via set_error.
        co_await __co_call(set_error, static_cast<_Receiver&&>(__rcvr),
                           static_cast<std::exception_ptr&&>(__eptr));
    }

    // Completion signatures of the resulting operation:
    // set_value_t() for a void await result or set_value_t(T) for a result of
    // type T, plus set_error_t(std::exception_ptr) and set_stopped_t().
    template <receiver _Receiver, class _Awaitable>
    using __completions_t = //
        completion_signatures<
            __minvoke<      // set_value_t() or set_value_t(T)
                __mremove<void, __qf<set_value_t>>,
                __await_result_t<_Awaitable, __promise_t<_Receiver>>>,
            set_error_t(std::exception_ptr), set_stopped_t()>;

  public:
    // Connects __awaitable to __rcvr and returns an operation state that owns
    // the not-yet-started coroutine frame. Requires _Awaitable to be awaitable
    // from a coroutine with promise __promise_t<_Receiver>, and the receiver
    // to accept every signature in __completions_t. Not noexcept: creating
    // the coroutine frame may throw (e.g. on allocation failure).
    template <class _Receiver, __awaitable<__promise_t<_Receiver>> _Awaitable>
        requires receiver_of<_Receiver, __completions_t<_Receiver, _Awaitable>>
    auto operator()(_Awaitable&& __awaitable,
                    _Receiver __rcvr) const -> __operation_t<_Receiver>
    {
        return __co_impl(static_cast<_Awaitable&&>(__awaitable),
                         static_cast<_Receiver&&>(__rcvr));
    }
};
250 } // namespace __connect_awaitable_
251 
using __connect_awaitable_::__connect_awaitable_t;
#else
// Coroutines are unavailable: an empty placeholder type lets the
// __connect_awaitable object below still be declared. It has no call
// operator, so nothing can be connected through it.
struct __connect_awaitable_t
{};
#endif
// Function object that, when coroutines are enabled, connects an awaitable
// to a receiver and yields an operation state.
inline constexpr __connect_awaitable_t __connect_awaitable{};
258 } // namespace stdexec
259