1 /* 2 * Copyright (c) 2021-2024 NVIDIA Corporation 3 * 4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 5 * (the "License"); you may not use this file except in compliance with 6 * the License. You may obtain a copy of the License at 7 * 8 * https://llvm.org/LICENSE.txt 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include "__awaitable.hpp" 19 #include "__completion_signatures.hpp" 20 #include "__concepts.hpp" 21 #include "__config.hpp" 22 #include "__execution_fwd.hpp" 23 #include "__meta.hpp" 24 #include "__receivers.hpp" 25 #include "__tag_invoke.hpp" 26 27 #include <exception> 28 #include <utility> 29 30 namespace stdexec 31 { 32 #if !STDEXEC_STD_NO_COROUTINES() 33 ///////////////////////////////////////////////////////////////////////////// 34 // __connect_awaitable_ 35 namespace __connect_awaitable_ 36 { 37 struct __promise_base 38 { initial_suspendstdexec::__connect_awaitable_::__promise_base39 auto initial_suspend() noexcept -> __coro::suspend_always 40 { 41 return {}; 42 } 43 final_suspendstdexec::__connect_awaitable_::__promise_base44 [[noreturn]] auto final_suspend() noexcept -> __coro::suspend_always 45 { 46 std::terminate(); 47 } 48 unhandled_exceptionstdexec::__connect_awaitable_::__promise_base49 [[noreturn]] void unhandled_exception() noexcept 50 { 51 std::terminate(); 52 } 53 return_voidstdexec::__connect_awaitable_::__promise_base54 [[noreturn]] void return_void() noexcept 55 { 56 std::terminate(); 57 } 58 }; 59 60 struct __operation_base 61 { 62 __coro::coroutine_handle<> __coro_; 63 __operation_basestdexec::__connect_awaitable_::__operation_base64 explicit __operation_base(__coro::coroutine_handle<> __hcoro) noexcept : 65 __coro_(__hcoro) 66 {} 67 __operation_basestdexec::__connect_awaitable_::__operation_base68 __operation_base(__operation_base&& __other) noexcept : 69 __coro_(std::exchange(__other.__coro_, {})) 70 {} 71 ~__operation_basestdexec::__connect_awaitable_::__operation_base72 ~__operation_base() 73 { 74 if (__coro_) 75 { 76 #if STDEXEC_MSVC() 77 // MSVCBUG 78 // https://developercommunity.visualstudio.com/t/Double-destroy-of-a-local-in-coroutine-d/10456428 79 80 // Reassign __coro_ before calling destroy to make the mutation 81 // observable and to hopefully ensure that the compiler does not 82 // eliminate it. 83 auto __coro = __coro_; 84 __coro_ = {}; 85 __coro.destroy(); 86 #else 87 __coro_.destroy(); 88 #endif 89 } 90 } 91 startstdexec::__connect_awaitable_::__operation_base92 void start() & noexcept 93 { 94 __coro_.resume(); 95 } 96 }; 97 98 template <class _ReceiverId> 99 struct __promise; 100 101 template <class _ReceiverId> 102 struct __operation 103 { 104 struct __t : __operation_base 105 { 106 using promise_type = stdexec::__t<__promise<_ReceiverId>>; 107 using __operation_base::__operation_base; 108 }; 109 }; 110 111 template <class _ReceiverId> 112 struct __promise 113 { 114 using _Receiver = stdexec::__t<_ReceiverId>; 115 116 struct __t : __promise_base 117 { 118 using __id = __promise; 119 120 #if STDEXEC_EDG() __tstdexec::__connect_awaitable_::__promise::__t121 __t(auto&&, _Receiver&& __rcvr) noexcept : __rcvr_(__rcvr) {} 122 #else __tstdexec::__connect_awaitable_::__promise::__t123 explicit __t(auto&, _Receiver& __rcvr) noexcept : __rcvr_(__rcvr) {} 124 #endif 125 unhandled_stoppedstdexec::__connect_awaitable_::__promise::__t126 auto unhandled_stopped() noexcept -> __coro::coroutine_handle<> 127 { 128 stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr_)); 129 // Returning noop_coroutine here causes the __connect_awaitable 130 // coroutine to never resume past the point where it co_await's 131 // the awaitable. 132 return __coro::noop_coroutine(); 133 } 134 get_return_objectstdexec::__connect_awaitable_::__promise::__t135 auto get_return_object() noexcept 136 -> stdexec::__t<__operation<_ReceiverId>> 137 { 138 return stdexec::__t<__operation<_ReceiverId>>{ 139 __coro::coroutine_handle<__t>::from_promise(*this)}; 140 } 141 142 template <class _Awaitable> await_transformstdexec::__connect_awaitable_::__promise::__t143 auto await_transform(_Awaitable&& __awaitable) noexcept -> _Awaitable&& 144 { 145 return static_cast<_Awaitable&&>(__awaitable); 146 } 147 148 template <class _Awaitable> 149 requires tag_invocable<as_awaitable_t, _Awaitable, __t&> await_transformstdexec::__connect_awaitable_::__promise::__t150 auto await_transform(_Awaitable&& __awaitable) // 151 noexcept(nothrow_tag_invocable<as_awaitable_t, _Awaitable, __t&>) 152 -> tag_invoke_result_t<as_awaitable_t, _Awaitable, __t&> 153 { 154 return tag_invoke(as_awaitable, 155 static_cast<_Awaitable&&>(__awaitable), *this); 156 } 157 158 // Pass through the get_env receiver query get_envstdexec::__connect_awaitable_::__promise::__t159 auto get_env() const noexcept -> env_of_t<_Receiver> 160 { 161 return stdexec::get_env(__rcvr_); 162 } 163 164 _Receiver& __rcvr_; 165 }; 166 }; 167 168 template <receiver _Receiver> 169 using __promise_t = __t<__promise<__id<_Receiver>>>; 170 171 template <receiver _Receiver> 172 using __operation_t = __t<__operation<__id<_Receiver>>>; 173 174 struct __connect_awaitable_t 175 { 176 private: 177 template <class _Fun, class... _Ts> __co_callstdexec::__connect_awaitable_::__connect_awaitable_t178 static auto __co_call(_Fun __fun, _Ts&&... __as) noexcept 179 { 180 auto __fn = [&, __fun]() noexcept { 181 __fun(static_cast<_Ts&&>(__as)...); 182 }; 183 184 struct __awaiter 185 { 186 decltype(__fn) __fn_; 187 188 static constexpr auto await_ready() noexcept -> bool 189 { 190 return false; 191 } 192 193 void await_suspend(__coro::coroutine_handle<>) noexcept 194 { 195 __fn_(); 196 } 197 198 [[noreturn]] void await_resume() noexcept 199 { 200 std::terminate(); 201 } 202 }; 203 204 return __awaiter{__fn}; 205 } 206 207 template <class _Awaitable, class _Receiver> 208 #if STDEXEC_GCC() && (__GNUC__ > 11) 209 __attribute__((__used__)) 210 #endif 211 static auto __co_implstdexec::__connect_awaitable_::__connect_awaitable_t212 __co_impl(_Awaitable __awaitable, _Receiver __rcvr) 213 -> __operation_t<_Receiver> 214 { 215 using __result_t = __await_result_t<_Awaitable, __promise_t<_Receiver>>; 216 std::exception_ptr __eptr; 217 try 218 { 219 if constexpr (same_as<__result_t, void>) 220 co_await (co_await static_cast<_Awaitable&&>(__awaitable), 221 __co_call(set_value, 222 static_cast<_Receiver&&>(__rcvr))); 223 else 224 co_await __co_call( 225 set_value, static_cast<_Receiver&&>(__rcvr), 226 co_await static_cast<_Awaitable&&>(__awaitable)); 227 } 228 catch (...) 229 { 230 __eptr = std::current_exception(); 231 } 232 co_await __co_call(set_error, static_cast<_Receiver&&>(__rcvr), 233 static_cast<std::exception_ptr&&>(__eptr)); 234 } 235 236 template <receiver _Receiver, class _Awaitable> 237 using __completions_t = // 238 completion_signatures< 239 __minvoke< // set_value_t() or set_value_t(T) 240 __mremove<void, __qf<set_value_t>>, 241 __await_result_t<_Awaitable, __promise_t<_Receiver>>>, 242 set_error_t(std::exception_ptr), set_stopped_t()>; 243 244 public: 245 template <class _Receiver, __awaitable<__promise_t<_Receiver>> _Awaitable> 246 requires receiver_of<_Receiver, __completions_t<_Receiver, _Awaitable>> operator ()stdexec::__connect_awaitable_::__connect_awaitable_t247 auto operator()(_Awaitable&& __awaitable, _Receiver __rcvr) const 248 -> __operation_t<_Receiver> 249 { 250 return __co_impl(static_cast<_Awaitable&&>(__awaitable), 251 static_cast<_Receiver&&>(__rcvr)); 252 } 253 }; 254 } // namespace __connect_awaitable_ 255 256 using __connect_awaitable_::__connect_awaitable_t; 257 #else 258 struct __connect_awaitable_t 259 {}; 260 #endif 261 inline constexpr __connect_awaitable_t __connect_awaitable{}; 262 } // namespace stdexec 263