xref: /openbmc/sdbusplus/include/sdbusplus/async/stdexec/__detail/__connect_awaitable.hpp (revision 36137e09614746b13603b5fbae79e6f70819c46b)
1 /*
2  * Copyright (c) 2021-2024 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__awaitable.hpp"
19 #include "__completion_signatures.hpp"
20 #include "__concepts.hpp"
21 #include "__config.hpp"
22 #include "__execution_fwd.hpp"
23 #include "__meta.hpp"
24 #include "__receivers.hpp"
25 #include "__tag_invoke.hpp"
26 
27 #include <exception>
28 #include <utility>
29 
30 namespace stdexec
31 {
32 #if !STDEXEC_STD_NO_COROUTINES()
33 /////////////////////////////////////////////////////////////////////////////
34 // __connect_awaitable_
35 namespace __connect_awaitable_
36 {
37 struct __promise_base
38 {
initial_suspendstdexec::__connect_awaitable_::__promise_base39     auto initial_suspend() noexcept -> __coro::suspend_always
40     {
41         return {};
42     }
43 
final_suspendstdexec::__connect_awaitable_::__promise_base44     [[noreturn]] auto final_suspend() noexcept -> __coro::suspend_always
45     {
46         std::terminate();
47     }
48 
unhandled_exceptionstdexec::__connect_awaitable_::__promise_base49     [[noreturn]] void unhandled_exception() noexcept
50     {
51         std::terminate();
52     }
53 
return_voidstdexec::__connect_awaitable_::__promise_base54     [[noreturn]] void return_void() noexcept
55     {
56         std::terminate();
57     }
58 };
59 
60 struct __operation_base
61 {
62     __coro::coroutine_handle<> __coro_;
63 
__operation_basestdexec::__connect_awaitable_::__operation_base64     explicit __operation_base(__coro::coroutine_handle<> __hcoro) noexcept :
65         __coro_(__hcoro)
66     {}
67 
__operation_basestdexec::__connect_awaitable_::__operation_base68     __operation_base(__operation_base&& __other) noexcept :
69         __coro_(std::exchange(__other.__coro_, {}))
70     {}
71 
~__operation_basestdexec::__connect_awaitable_::__operation_base72     ~__operation_base()
73     {
74         if (__coro_)
75         {
76 #if STDEXEC_MSVC()
77             // MSVCBUG
78             // https://developercommunity.visualstudio.com/t/Double-destroy-of-a-local-in-coroutine-d/10456428
79 
80             // Reassign __coro_ before calling destroy to make the mutation
81             // observable and to hopefully ensure that the compiler does not
82             // eliminate it.
83             auto __coro = __coro_;
84             __coro_ = {};
85             __coro.destroy();
86 #else
87             __coro_.destroy();
88 #endif
89         }
90     }
91 
startstdexec::__connect_awaitable_::__operation_base92     void start() & noexcept
93     {
94         __coro_.resume();
95     }
96 };
97 
98 template <class _ReceiverId>
99 struct __promise;
100 
101 template <class _ReceiverId>
102 struct __operation
103 {
104     struct __t : __operation_base
105     {
106         using promise_type = stdexec::__t<__promise<_ReceiverId>>;
107         using __operation_base::__operation_base;
108     };
109 };
110 
111 template <class _ReceiverId>
112 struct __promise
113 {
114     using _Receiver = stdexec::__t<_ReceiverId>;
115 
116     struct __t : __promise_base
117     {
118         using __id = __promise;
119 
120 #if STDEXEC_EDG()
__tstdexec::__connect_awaitable_::__promise::__t121         __t(auto&&, _Receiver&& __rcvr) noexcept : __rcvr_(__rcvr) {}
122 #else
__tstdexec::__connect_awaitable_::__promise::__t123         explicit __t(auto&, _Receiver& __rcvr) noexcept : __rcvr_(__rcvr) {}
124 #endif
125 
unhandled_stoppedstdexec::__connect_awaitable_::__promise::__t126         auto unhandled_stopped() noexcept -> __coro::coroutine_handle<>
127         {
128             stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr_));
129             // Returning noop_coroutine here causes the __connect_awaitable
130             // coroutine to never resume past the point where it co_await's
131             // the awaitable.
132             return __coro::noop_coroutine();
133         }
134 
get_return_objectstdexec::__connect_awaitable_::__promise::__t135         auto get_return_object() noexcept
136             -> stdexec::__t<__operation<_ReceiverId>>
137         {
138             return stdexec::__t<__operation<_ReceiverId>>{
139                 __coro::coroutine_handle<__t>::from_promise(*this)};
140         }
141 
142         template <class _Awaitable>
await_transformstdexec::__connect_awaitable_::__promise::__t143         auto await_transform(_Awaitable&& __awaitable) noexcept -> _Awaitable&&
144         {
145             return static_cast<_Awaitable&&>(__awaitable);
146         }
147 
148         template <class _Awaitable>
149             requires tag_invocable<as_awaitable_t, _Awaitable, __t&>
await_transformstdexec::__connect_awaitable_::__promise::__t150         auto await_transform(_Awaitable&& __awaitable) //
151             noexcept(nothrow_tag_invocable<as_awaitable_t, _Awaitable, __t&>)
152                 -> tag_invoke_result_t<as_awaitable_t, _Awaitable, __t&>
153         {
154             return tag_invoke(as_awaitable,
155                               static_cast<_Awaitable&&>(__awaitable), *this);
156         }
157 
158         // Pass through the get_env receiver query
get_envstdexec::__connect_awaitable_::__promise::__t159         auto get_env() const noexcept -> env_of_t<_Receiver>
160         {
161             return stdexec::get_env(__rcvr_);
162         }
163 
164         _Receiver& __rcvr_;
165     };
166 };
167 
168 template <receiver _Receiver>
169 using __promise_t = __t<__promise<__id<_Receiver>>>;
170 
171 template <receiver _Receiver>
172 using __operation_t = __t<__operation<__id<_Receiver>>>;
173 
174 struct __connect_awaitable_t
175 {
176   private:
177     template <class _Fun, class... _Ts>
__co_callstdexec::__connect_awaitable_::__connect_awaitable_t178     static auto __co_call(_Fun __fun, _Ts&&... __as) noexcept
179     {
180         auto __fn = [&, __fun]() noexcept {
181             __fun(static_cast<_Ts&&>(__as)...);
182         };
183 
184         struct __awaiter
185         {
186             decltype(__fn) __fn_;
187 
188             static constexpr auto await_ready() noexcept -> bool
189             {
190                 return false;
191             }
192 
193             void await_suspend(__coro::coroutine_handle<>) noexcept
194             {
195                 __fn_();
196             }
197 
198             [[noreturn]] void await_resume() noexcept
199             {
200                 std::terminate();
201             }
202         };
203 
204         return __awaiter{__fn};
205     }
206 
207     template <class _Awaitable, class _Receiver>
208 #if STDEXEC_GCC() && (__GNUC__ > 11)
209     __attribute__((__used__))
210 #endif
211     static auto
__co_implstdexec::__connect_awaitable_::__connect_awaitable_t212         __co_impl(_Awaitable __awaitable, _Receiver __rcvr)
213             -> __operation_t<_Receiver>
214     {
215         using __result_t = __await_result_t<_Awaitable, __promise_t<_Receiver>>;
216         std::exception_ptr __eptr;
217         try
218         {
219             if constexpr (same_as<__result_t, void>)
220                 co_await (co_await static_cast<_Awaitable&&>(__awaitable),
221                           __co_call(set_value,
222                                     static_cast<_Receiver&&>(__rcvr)));
223             else
224                 co_await __co_call(
225                     set_value, static_cast<_Receiver&&>(__rcvr),
226                     co_await static_cast<_Awaitable&&>(__awaitable));
227         }
228         catch (...)
229         {
230             __eptr = std::current_exception();
231         }
232         co_await __co_call(set_error, static_cast<_Receiver&&>(__rcvr),
233                            static_cast<std::exception_ptr&&>(__eptr));
234     }
235 
236     template <receiver _Receiver, class _Awaitable>
237     using __completions_t = //
238         completion_signatures<
239             __minvoke<      // set_value_t() or set_value_t(T)
240                 __mremove<void, __qf<set_value_t>>,
241                 __await_result_t<_Awaitable, __promise_t<_Receiver>>>,
242             set_error_t(std::exception_ptr), set_stopped_t()>;
243 
244   public:
245     template <class _Receiver, __awaitable<__promise_t<_Receiver>> _Awaitable>
246         requires receiver_of<_Receiver, __completions_t<_Receiver, _Awaitable>>
operator ()stdexec::__connect_awaitable_::__connect_awaitable_t247     auto operator()(_Awaitable&& __awaitable, _Receiver __rcvr) const
248         -> __operation_t<_Receiver>
249     {
250         return __co_impl(static_cast<_Awaitable&&>(__awaitable),
251                          static_cast<_Receiver&&>(__rcvr));
252     }
253 };
254 } // namespace __connect_awaitable_
255 
// Re-export the implementation type into namespace stdexec.
using __connect_awaitable_::__connect_awaitable_t;
#else
// Coroutines unavailable: an empty placeholder type keeps the name declared.
struct __connect_awaitable_t
{};
#endif
// The function object used to connect an awaitable to a receiver.
inline constexpr __connect_awaitable_t __connect_awaitable{};
262 } // namespace stdexec
263