1 /*
2  * Copyright (c) 2021-2024 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__awaitable.hpp"
19 #include "__concepts.hpp"
20 #include "__config.hpp"
21 #include "__cpo.hpp"
22 #include "__execution_fwd.hpp"
23 #include "__meta.hpp"
24 #include "__receivers.hpp"
25 #include "__senders.hpp"
26 #include "__tag_invoke.hpp"
27 #include "__transform_completion_signatures.hpp"
28 #include "__type_traits.hpp"
29 
30 #include <exception>
31 #include <system_error>
32 #include <variant>
33 
34 namespace stdexec
35 {
36 #if !STDEXEC_STD_NO_COROUTINES()
37 /////////////////////////////////////////////////////////////////////////////
38 // stdexec::as_awaitable [execution.coro_utils.as_awaitable]
39 namespace __as_awaitable
40 {
// Empty placeholder type used in place of `void` wherever an object type is
// required (see __value_or_void_t).
struct __void
{
};
43 
44 template <class _Value>
45 using __value_or_void_t = __if_c<__same_as<_Value, void>, __void, _Value>;
46 
47 template <class _Value>
48 using __expected_t =
49     std::variant<std::monostate, __value_or_void_t<_Value>, std::exception_ptr>;
50 
51 template <class _Value>
52 struct __receiver_base
53 {
54     using receiver_concept = receiver_t;
55 
56     template <class... _Us>
57         requires constructible_from<__value_or_void_t<_Value>, _Us...>
set_valuestdexec::__as_awaitable::__receiver_base58     void set_value(_Us&&... __us) noexcept
59     {
60         try
61         {
62             __result_->template emplace<1>(static_cast<_Us&&>(__us)...);
63             __continuation_.resume();
64         }
65         catch (...)
66         {
67             stdexec::set_error(static_cast<__receiver_base&&>(*this),
68                                std::current_exception());
69         }
70     }
71 
72     template <class _Error>
set_errorstdexec::__as_awaitable::__receiver_base73     void set_error(_Error&& __err) noexcept
74     {
75         if constexpr (__decays_to<_Error, std::exception_ptr>)
76             __result_->template emplace<2>(static_cast<_Error&&>(__err));
77         else if constexpr (__decays_to<_Error, std::error_code>)
78             __result_->template emplace<2>(
79                 std::make_exception_ptr(std::system_error(__err)));
80         else
81             __result_->template emplace<2>(
82                 std::make_exception_ptr(static_cast<_Error&&>(__err)));
83         __continuation_.resume();
84     }
85 
86     __expected_t<_Value>* __result_;
87     __coro::coroutine_handle<> __continuation_;
88 };
89 
90 template <class _PromiseId, class _Value>
91 struct __receiver
92 {
93     using _Promise = stdexec::__t<_PromiseId>;
94 
95     struct __t : __receiver_base<_Value>
96     {
97         using __id = __receiver;
98 
set_stoppedstdexec::__as_awaitable::__receiver::__t99         void set_stopped() noexcept
100         {
101             auto __continuation =
102                 __coro::coroutine_handle<_Promise>::from_address(
103                     this->__continuation_.address());
104             __coro::coroutine_handle<> __stopped_continuation =
105                 __continuation.promise().unhandled_stopped();
106             __stopped_continuation.resume();
107         }
108 
109         // Forward get_env query to the coroutine promise
get_envstdexec::__as_awaitable::__receiver::__t110         auto get_env() const noexcept -> env_of_t<_Promise&>
111         {
112             auto __continuation =
113                 __coro::coroutine_handle<_Promise>::from_address(
114                     this->__continuation_.address());
115             return stdexec::get_env(__continuation.promise());
116         }
117     };
118 };
119 
120 // BUGBUG NOT TO SPEC: make senders of more-than-one-value awaitable
121 // by packaging the values into a tuple.
122 // See: https://github.com/cplusplus/sender-receiver/issues/182
123 template <std::size_t _Count>
124 extern const __q<__decayed_std_tuple> __as_single;
125 
126 template <>
127 inline const __q<__midentity> __as_single<1>;
128 
129 template <>
130 inline const __mconst<void> __as_single<0>;
131 
132 template <class... _Values>
133 using __single_value =
134     __minvoke<decltype(__as_single<sizeof...(_Values)>), _Values...>;
135 
136 template <class _Sender, class _Promise>
137 using __value_t =
138     __decay_t<__value_types_of_t<_Sender, env_of_t<_Promise&>,
139                                  __q<__single_value>, __msingle_or<void>>>;
140 
141 template <class _Sender, class _Promise>
142 using __receiver_t =
143     __t<__receiver<__id<_Promise>, __value_t<_Sender, _Promise>>>;
144 
145 template <class _Value>
146 struct __sender_awaitable_base
147 {
await_readystdexec::__as_awaitable::__sender_awaitable_base148     [[nodiscard]] auto await_ready() const noexcept -> bool
149     {
150         return false;
151     }
152 
await_resumestdexec::__as_awaitable::__sender_awaitable_base153     auto await_resume() -> _Value
154     {
155         switch (__result_.index())
156         {
157             case 0: // receiver contract not satisfied
158                 STDEXEC_ASSERT(false && +"_Should never get here" == nullptr);
159                 break;
160             case 1: // set_value
161                 if constexpr (!__same_as<_Value, void>)
162                     return static_cast<_Value&&>(std::get<1>(__result_));
163                 else
164                     return;
165             case 2: // set_error
166                 std::rethrow_exception(std::get<2>(__result_));
167         }
168         std::terminate();
169     }
170 
171   protected:
172     __expected_t<_Value> __result_;
173 };
174 
175 template <class _PromiseId, class _SenderId>
176 struct __sender_awaitable
177 {
178     using _Promise = stdexec::__t<_PromiseId>;
179     using _Sender = stdexec::__t<_SenderId>;
180     using __value = __value_t<_Sender, _Promise>;
181 
182     struct __t : __sender_awaitable_base<__value>
183     {
__tstdexec::__as_awaitable::__sender_awaitable::__t184         __t(_Sender&& sndr, __coro::coroutine_handle<_Promise> __hcoro) //
185             noexcept(__nothrow_connectable<_Sender, __receiver>) :
186             __op_state_(connect(static_cast<_Sender&&>(sndr),
187                                 __receiver{{&this->__result_, __hcoro}}))
188         {}
189 
await_suspendstdexec::__as_awaitable::__sender_awaitable::__t190         void await_suspend(__coro::coroutine_handle<_Promise>) noexcept
191         {
192             stdexec::start(__op_state_);
193         }
194 
195       private:
196         using __receiver = __receiver_t<_Sender, _Promise>;
197         connect_result_t<_Sender, __receiver> __op_state_;
198     };
199 };
200 
201 template <class _Promise, class _Sender>
202 using __sender_awaitable_t =
203     __t<__sender_awaitable<__id<_Promise>, __id<_Sender>>>;
204 
205 template <class _Sender, class _Promise>
206 concept __awaitable_sender =
207     sender_in<_Sender, env_of_t<_Promise&>> &&             //
208     __mvalid<__value_t, _Sender, _Promise> &&              //
209     sender_to<_Sender, __receiver_t<_Sender, _Promise>> && //
210     requires(_Promise& __promise) {
211         {
212             __promise.unhandled_stopped()
213         } -> convertible_to<__coro::coroutine_handle<>>;
214     };
215 
216 struct __unspecified
217 {
218     auto get_return_object() noexcept -> __unspecified;
219     auto initial_suspend() noexcept -> __unspecified;
220     auto final_suspend() noexcept -> __unspecified;
221     void unhandled_exception() noexcept;
222     void return_void() noexcept;
223     auto unhandled_stopped() noexcept -> __coro::coroutine_handle<>;
224 };
225 
226 struct as_awaitable_t
227 {
228     template <class _Tp, class _Promise>
__select_impl_stdexec::__as_awaitable::as_awaitable_t229     static constexpr auto __select_impl_() noexcept
230     {
231         if constexpr (tag_invocable<as_awaitable_t, _Tp, _Promise&>)
232         {
233             using _Result = tag_invoke_result_t<as_awaitable_t, _Tp, _Promise&>;
234             constexpr bool _Nothrow =
235                 nothrow_tag_invocable<as_awaitable_t, _Tp, _Promise&>;
236             return static_cast<_Result (*)() noexcept(_Nothrow)>(nullptr);
237             // NOLINTNEXTLINE(bugprone-branch-clone)
238         }
239         else if constexpr (__awaitable<_Tp, __unspecified>)
240         { // NOT __awaitable<_Tp, _Promise> !!
241             using _Result = _Tp&&;
242             return static_cast<_Result (*)() noexcept>(nullptr);
243         }
244         else if constexpr (__awaitable_sender<_Tp, _Promise>)
245         {
246             using _Result = __sender_awaitable_t<_Promise, _Tp>;
247             constexpr bool _Nothrow = __nothrow_constructible_from<
248                 _Result, _Tp, __coro::coroutine_handle<_Promise>>;
249             return static_cast<_Result (*)() noexcept(_Nothrow)>(nullptr);
250         }
251         else
252         {
253             using _Result = _Tp&&;
254             return static_cast<_Result (*)() noexcept>(nullptr);
255         }
256     }
257 
258     template <class _Tp, class _Promise>
259     using __select_impl_t = decltype(__select_impl_<_Tp, _Promise>());
260 
261     template <class _Tp, class _Promise>
operator ()stdexec::__as_awaitable::as_awaitable_t262     auto operator()(_Tp&& __t, _Promise& __promise) const
263         noexcept(__nothrow_callable<__select_impl_t<_Tp, _Promise>>)
264             -> __call_result_t<__select_impl_t<_Tp, _Promise>>
265     {
266         if constexpr (tag_invocable<as_awaitable_t, _Tp, _Promise&>)
267         {
268             using _Result = tag_invoke_result_t<as_awaitable_t, _Tp, _Promise&>;
269             static_assert(__awaitable<_Result, _Promise>);
270             return tag_invoke(*this, static_cast<_Tp&&>(__t), __promise);
271             // NOLINTNEXTLINE(bugprone-branch-clone)
272         }
273         else if constexpr (__awaitable<_Tp, __unspecified>)
274         { // NOT __awaitable<_Tp, _Promise> !!
275             return static_cast<_Tp&&>(__t);
276         }
277         else if constexpr (__awaitable_sender<_Tp, _Promise>)
278         {
279             auto __hcoro =
280                 __coro::coroutine_handle<_Promise>::from_promise(__promise);
281             return __sender_awaitable_t<_Promise, _Tp>{static_cast<_Tp&&>(__t),
282                                                        __hcoro};
283         }
284         else
285         {
286             return static_cast<_Tp&&>(__t);
287         }
288     }
289 };
290 } // namespace __as_awaitable
291 
292 using __as_awaitable::as_awaitable_t;
293 inline constexpr as_awaitable_t as_awaitable{};
294 #endif
295 } // namespace stdexec
296