xref: /openbmc/sdbusplus/include/sdbusplus/async/stdexec/__detail/__as_awaitable.hpp (revision 10d0b4b7d1498cfd5c3d37edea271a54d1984e41)
1 /*
2  * Copyright (c) 2021-2024 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__execution_fwd.hpp"
19 
20 #include "__awaitable.hpp"
21 #include "__concepts.hpp"
22 #include "__config.hpp"
23 #include "__meta.hpp"
24 #include "__receivers.hpp"
25 #include "__senders.hpp"
26 #include "__tag_invoke.hpp"
27 #include "__transform_completion_signatures.hpp"
28 #include "__type_traits.hpp"
29 
30 #include <exception>
31 #include <system_error>
32 #include <variant>
33 
34 namespace stdexec {
35 #if !STDEXEC_STD_NO_COROUTINES()
36   /////////////////////////////////////////////////////////////////////////////
37   // stdexec::as_awaitable [execution.coro_utils.as_awaitable]
38   namespace __as_awaitable {
39     struct __void { };
40 
41     template <class _Value>
42     using __value_or_void_t = __if_c<__same_as<_Value, void>, __void, _Value>;
43 
44     template <class _Value>
45     using __expected_t =
46       std::variant<std::monostate, __value_or_void_t<_Value>, std::exception_ptr>;
47 
48     template <class _Value>
49     struct __receiver_base {
50       using receiver_concept = receiver_t;
51 
52       template <class... _Us>
53         requires constructible_from<__value_or_void_t<_Value>, _Us...>
set_valuestdexec::__as_awaitable::__receiver_base54       void set_value(_Us&&... __us) noexcept {
55         STDEXEC_TRY {
56           __result_->template emplace<1>(static_cast<_Us&&>(__us)...);
57           __continuation_.resume();
58         }
59         STDEXEC_CATCH_ALL {
60           stdexec::set_error(static_cast<__receiver_base&&>(*this), std::current_exception());
61         }
62       }
63 
64       template <class _Error>
set_errorstdexec::__as_awaitable::__receiver_base65       void set_error(_Error&& __err) noexcept {
66         if constexpr (__decays_to<_Error, std::exception_ptr>)
67           __result_->template emplace<2>(static_cast<_Error&&>(__err));
68         else if constexpr (__decays_to<_Error, std::error_code>)
69           __result_->template emplace<2>(std::make_exception_ptr(std::system_error(__err)));
70         else
71           __result_->template emplace<2>(std::make_exception_ptr(static_cast<_Error&&>(__err)));
72         __continuation_.resume();
73       }
74 
75       __expected_t<_Value>* __result_;
76       __coro::coroutine_handle<> __continuation_;
77     };
78 
79     template <class _PromiseId, class _Value>
80     struct __receiver {
81       using _Promise = stdexec::__t<_PromiseId>;
82 
83       struct __t : __receiver_base<_Value> {
84         using __id = __receiver;
85 
set_stoppedstdexec::__as_awaitable::__receiver::__t86         void set_stopped() noexcept {
87           auto __continuation = __coro::coroutine_handle<_Promise>::from_address(
88             this->__continuation_.address());
89           __coro::coroutine_handle<> __stopped_continuation = __continuation.promise()
90                                                                 .unhandled_stopped();
91           __stopped_continuation.resume();
92         }
93 
94         // Forward get_env query to the coroutine promise
get_envstdexec::__as_awaitable::__receiver::__t95         auto get_env() const noexcept -> env_of_t<_Promise&> {
96           auto __continuation = __coro::coroutine_handle<_Promise>::from_address(
97             this->__continuation_.address());
98           return stdexec::get_env(__continuation.promise());
99         }
100       };
101     };
102 
    // BUGBUG NOT TO SPEC: make senders of more-than-one-value awaitable
    // by packaging the values into a tuple.
    // See: https://github.com/cplusplus/sender-receiver/issues/182
    //
    // __as_single<N> picks, from the number N of values a sender completes
    // with, the metafunction that collapses those values into one type:
    //   N == 0 -> __mconst<void>
    //   N == 1 -> __q<__midentity>
    //   N >= 2 -> __q<__decayed_std_tuple> (the primary template)
    // Going by the names, these are presumably "void", "the lone value type",
    // and "a std::tuple of the decayed values" respectively. Only the declared
    // types matter: __single_value below reads them through decltype.
    template <std::size_t _Count>
    extern const __q<__decayed_std_tuple> __as_single;

    template <>
    inline const __q<__midentity> __as_single<1>;

    template <>
    inline const __mconst<void> __as_single<0>;

    // Collapses one value completion's arguments `_Values...` into the single
    // type selected by __as_single<sizeof...(_Values)>.
    template <class... _Values>
    using __single_value = __minvoke<decltype(__as_single<sizeof...(_Values)>), _Values...>;
117 
118     template <class _Sender, class _Promise>
119     using __value_t = __decay_t<
120       __value_types_of_t<_Sender, env_of_t<_Promise&>, __q<__single_value>, __msingle_or<void>>
121     >;
122 
123     template <class _Sender, class _Promise>
124     using __receiver_t = __t<__receiver<__id<_Promise>, __value_t<_Sender, _Promise>>>;
125 
126     template <class _Value>
127     struct __sender_awaitable_base {
128       [[nodiscard]]
await_readystdexec::__as_awaitable::__sender_awaitable_base129       auto await_ready() const noexcept -> bool {
130         return false;
131       }
132 
await_resumestdexec::__as_awaitable::__sender_awaitable_base133       auto await_resume() -> _Value {
134         switch (__result_.index()) {
135         case 0: // receiver contract not satisfied
136           STDEXEC_ASSERT(false && +"_Should never get here" == nullptr);
137           break;
138         case 1: // set_value
139           if constexpr (!__same_as<_Value, void>)
140             return static_cast<_Value&&>(std::get<1>(__result_));
141           else
142             return;
143         case 2: // set_error
144           std::rethrow_exception(std::get<2>(__result_));
145         }
146         std::terminate();
147       }
148 
149      protected:
150       __expected_t<_Value> __result_;
151     };
152 
153     template <class _PromiseId, class _SenderId>
154     struct __sender_awaitable {
155       using _Promise = stdexec::__t<_PromiseId>;
156       using _Sender = stdexec::__t<_SenderId>;
157       using __value = __value_t<_Sender, _Promise>;
158 
159       struct __t : __sender_awaitable_base<__value> {
__tstdexec::__as_awaitable::__sender_awaitable::__t160         __t(_Sender&& sndr, __coro::coroutine_handle<_Promise> __hcoro)
161           noexcept(__nothrow_connectable<_Sender, __receiver>)
162           : __op_state_(connect(
163               static_cast<_Sender&&>(sndr),
164               __receiver{
165                 {&this->__result_, __hcoro}
166         })) {
167         }
168 
await_suspendstdexec::__as_awaitable::__sender_awaitable::__t169         void await_suspend(__coro::coroutine_handle<_Promise>) noexcept {
170           stdexec::start(__op_state_);
171         }
172 
173        private:
174         using __receiver = __receiver_t<_Sender, _Promise>;
175         connect_result_t<_Sender, __receiver> __op_state_;
176       };
177     };
178 
179     template <class _Promise, class _Sender>
180     using __sender_awaitable_t = __t<__sender_awaitable<__id<_Promise>, __id<_Sender>>>;
181 
    // Satisfied when `_Sender` can be co_awaited through __sender_awaitable_t
    // from a coroutine whose promise type is `_Promise`:
    //  - it is a sender in the promise's environment;
    //  - its value completions reduce to one type (__value_t is well-formed);
    //  - it can be connected to the awaitable's receiver; and
    //  - the promise has unhandled_stopped() returning something convertible
    //    to a coroutine handle (which __receiver::set_stopped resumes).
    // Keep the conjunct order: concept conjunctions short-circuit, and
    // __receiver_t is built from __value_t, so it is only formed after the
    // __mvalid check has succeeded.
    template <class _Sender, class _Promise>
    concept __awaitable_sender = sender_in<_Sender, env_of_t<_Promise&>>
                              && __mvalid<__value_t, _Sender, _Promise>
                              && sender_to<_Sender, __receiver_t<_Sender, _Promise>>
                              && requires(_Promise& __promise) {
                                   {
                                     __promise.unhandled_stopped()
                                   } -> convertible_to<__coro::coroutine_handle<>>;
                                 };
191 
192     struct __unspecified {
193       auto get_return_object() noexcept -> __unspecified;
194       auto initial_suspend() noexcept -> __unspecified;
195       auto final_suspend() noexcept -> __unspecified;
196       void unhandled_exception() noexcept;
197       void return_void() noexcept;
198       auto unhandled_stopped() noexcept -> __coro::coroutine_handle<>;
199     };
200 
201     template <class _Tp, class _Promise>
202     concept __has_as_awaitable_member = requires(_Tp&& __t, _Promise& __promise) {
203       static_cast<_Tp &&>(__t).as_awaitable(__promise);
204     };
205 
    // Type of the stdexec::as_awaitable customization point object. It adapts
    // `t` into something that can be co_awaited from the coroutine owning
    // `promise`, trying these alternatives in order:
    //  1. the member t.as_awaitable(promise);
    //  2. tag_invoke(as_awaitable, t, promise);
    //  3. t itself, if it is already awaitable (tested against __unspecified,
    //     not against the real promise);
    //  4. a __sender_awaitable_t wrapping t, if t is an __awaitable_sender;
    //  5. otherwise t itself, unchanged.
    struct as_awaitable_t {
      // Encodes the result type and noexcept-ness of operator() for `_Tp` and
      // `_Promise` as the type of a null function pointer,
      // `_Result (*)() noexcept(_Nothrow)`. It is only used through decltype
      // (see __select_impl_t). Its branches must mirror operator() exactly.
      template <class _Tp, class _Promise>
      static constexpr auto __select_impl_() noexcept {
        if constexpr (__has_as_awaitable_member<_Tp, _Promise>) {
          using _Result = decltype(__declval<_Tp>().as_awaitable(__declval<_Promise&>()));
          constexpr bool _Nothrow = noexcept(__declval<_Tp>().as_awaitable(__declval<_Promise&>()));
          return static_cast<_Result (*)() noexcept(_Nothrow)>(nullptr);
        } else if constexpr (tag_invocable<as_awaitable_t, _Tp, _Promise&>) {
          using _Result = tag_invoke_result_t<as_awaitable_t, _Tp, _Promise&>;
          constexpr bool _Nothrow = nothrow_tag_invocable<as_awaitable_t, _Tp, _Promise&>;
          return static_cast<_Result (*)() noexcept(_Nothrow)>(nullptr);
          // NOLINTNEXTLINE(bugprone-branch-clone)
        } else if constexpr (__awaitable<_Tp, __unspecified>) { // NOT __awaitable<_Tp, _Promise> !!
          // Already awaitable: returned as-is (forwarded), which cannot throw.
          using _Result = _Tp&&;
          return static_cast<_Result (*)() noexcept>(nullptr);
        } else if constexpr (__awaitable_sender<_Tp, _Promise>) {
          // Wrapped: noexcept iff constructing the awaitable (i.e. connecting
          // the sender) is noexcept.
          using _Result = __sender_awaitable_t<_Promise, _Tp>;
          constexpr bool _Nothrow =
            __nothrow_constructible_from<_Result, _Tp, __coro::coroutine_handle<_Promise>>;
          return static_cast<_Result (*)() noexcept(_Nothrow)>(nullptr);
        } else {
          // Fallback: returned as-is.
          using _Result = _Tp&&;
          return static_cast<_Result (*)() noexcept>(nullptr);
        }
      }

      // Function-pointer type produced by __select_impl_; operator() takes its
      // return type and noexcept from it.
      template <class _Tp, class _Promise>
      using __select_impl_t = decltype(__select_impl_<_Tp, _Promise>());

      // Adapts `__t` for co_await in the coroutine whose promise is
      // `__promise`. For the two customized branches the result is
      // static_asserted to be awaitable with the real `_Promise`.
      // NOTE(review): branch 4 rebuilds the coroutine handle with
      // from_promise, so `__promise` must be the promise object of the
      // awaiting coroutine.
      template <class _Tp, class _Promise>
      auto operator()(_Tp&& __t, _Promise& __promise) const
        noexcept(__nothrow_callable<__select_impl_t<_Tp, _Promise>>)
          -> __call_result_t<__select_impl_t<_Tp, _Promise>> {
        if constexpr (__has_as_awaitable_member<_Tp, _Promise>) {
          using _Result = decltype(static_cast<_Tp&&>(__t).as_awaitable(__promise));
          static_assert(__awaitable<_Result, _Promise>);
          return static_cast<_Tp&&>(__t).as_awaitable(__promise);
        } else if constexpr (tag_invocable<as_awaitable_t, _Tp, _Promise&>) {
          using _Result = tag_invoke_result_t<as_awaitable_t, _Tp, _Promise&>;
          static_assert(__awaitable<_Result, _Promise>);
          return tag_invoke(*this, static_cast<_Tp&&>(__t), __promise);
          // NOLINTNEXTLINE(bugprone-branch-clone)
        } else if constexpr (__awaitable<_Tp, __unspecified>) { // NOT __awaitable<_Tp, _Promise> !!
          // Checked against __unspecified rather than _Promise, presumably so
          // that the real promise's await_transform (which may itself call
          // as_awaitable) is not consulted here.
          return static_cast<_Tp&&>(__t);
        } else if constexpr (__awaitable_sender<_Tp, _Promise>) {
          auto __hcoro = __coro::coroutine_handle<_Promise>::from_promise(__promise);
          return __sender_awaitable_t<_Promise, _Tp>{static_cast<_Tp&&>(__t), __hcoro};
        } else {
          return static_cast<_Tp&&>(__t);
        }
      }
    };
258   } // namespace __as_awaitable
259 
260   using __as_awaitable::as_awaitable_t;
261   inline constexpr as_awaitable_t as_awaitable{};
262 #endif
263 } // namespace stdexec
264