1 /* 2 * Copyright (c) 2021-2024 NVIDIA Corporation 3 * 4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 5 * (the "License"); you may not use this file except in compliance with 6 * the License. You may obtain a copy of the License at 7 * 8 * https://llvm.org/LICENSE.txt 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include "__execution_fwd.hpp" 19 20 #include "__awaitable.hpp" 21 #include "__completion_signatures.hpp" 22 #include "__concepts.hpp" 23 #include "__config.hpp" 24 #include "__env.hpp" 25 #include "__meta.hpp" 26 #include "__receivers.hpp" 27 28 #include <exception> 29 #include <utility> 30 31 namespace stdexec { 32 #if !STDEXEC_STD_NO_COROUTINES() 33 ///////////////////////////////////////////////////////////////////////////// 34 // __connect_awaitable_ 35 namespace __connect_awaitable_ { 36 struct __promise_base { initial_suspendstdexec::__connect_awaitable_::__promise_base37 auto initial_suspend() noexcept -> __coro::suspend_always { 38 return {}; 39 } 40 41 [[noreturn]] final_suspendstdexec::__connect_awaitable_::__promise_base42 auto final_suspend() noexcept -> __coro::suspend_always { 43 std::terminate(); 44 } 45 46 [[noreturn]] unhandled_exceptionstdexec::__connect_awaitable_::__promise_base47 void unhandled_exception() noexcept { 48 std::terminate(); 49 } 50 51 [[noreturn]] return_voidstdexec::__connect_awaitable_::__promise_base52 void return_void() noexcept { 53 std::terminate(); 54 } 55 }; 56 57 struct __operation_base { 58 __coro::coroutine_handle<> __coro_; 59 __operation_basestdexec::__connect_awaitable_::__operation_base60 explicit __operation_base(__coro::coroutine_handle<> __hcoro) noexcept 61 : __coro_(__hcoro) { 62 } 63 __operation_basestdexec::__connect_awaitable_::__operation_base64 __operation_base(__operation_base&& __other) noexcept 65 : __coro_(std::exchange(__other.__coro_, {})) { 66 } 67 ~__operation_basestdexec::__connect_awaitable_::__operation_base68 ~__operation_base() { 69 if (__coro_) { 70 # if STDEXEC_MSVC() 71 // MSVCBUG https://developercommunity.visualstudio.com/t/Double-destroy-of-a-local-in-coroutine-d/10456428 72 73 // Reassign __coro_ before calling destroy to make the mutation 74 // observable and to hopefully ensure that the compiler does not eliminate it. 75 auto __coro = __coro_; 76 __coro_ = {}; 77 __coro.destroy(); 78 # else 79 __coro_.destroy(); 80 # endif 81 } 82 } 83 startstdexec::__connect_awaitable_::__operation_base84 void start() & noexcept { 85 __coro_.resume(); 86 } 87 }; 88 89 template <class _ReceiverId> 90 struct __promise; 91 92 template <class _ReceiverId> 93 struct __operation { 94 struct __t : __operation_base { 95 using promise_type = stdexec::__t<__promise<_ReceiverId>>; 96 using __operation_base::__operation_base; 97 }; 98 }; 99 100 template <class _ReceiverId> 101 struct __promise { 102 using _Receiver = stdexec::__t<_ReceiverId>; 103 104 struct __t 105 : __promise_base 106 , __env::__with_await_transform<__t> { 107 using __id = __promise; 108 109 # if STDEXEC_EDG() __tstdexec::__connect_awaitable_::__promise::__t110 __t(auto&&, _Receiver&& __rcvr) noexcept 111 : __rcvr_(__rcvr) { 112 } 113 # else __tstdexec::__connect_awaitable_::__promise::__t114 explicit __t(auto&, _Receiver& __rcvr) noexcept 115 : __rcvr_(__rcvr) { 116 } 117 # endif 118 unhandled_stoppedstdexec::__connect_awaitable_::__promise::__t119 auto unhandled_stopped() noexcept -> __coro::coroutine_handle<> { 120 stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr_)); 121 // Returning noop_coroutine here causes the __connect_awaitable 122 // coroutine to never resume past the point where it co_await's 123 // the awaitable. 124 return __coro::noop_coroutine(); 125 } 126 get_return_objectstdexec::__connect_awaitable_::__promise::__t127 auto get_return_object() noexcept -> stdexec::__t<__operation<_ReceiverId>> { 128 return stdexec::__t<__operation<_ReceiverId>>{ 129 __coro::coroutine_handle<__t>::from_promise(*this)}; 130 } 131 132 // Pass through the get_env receiver query get_envstdexec::__connect_awaitable_::__promise::__t133 auto get_env() const noexcept -> env_of_t<_Receiver> { 134 return stdexec::get_env(__rcvr_); 135 } 136 137 _Receiver& __rcvr_; 138 }; 139 }; 140 141 template <receiver _Receiver> 142 using __promise_t = __t<__promise<__id<_Receiver>>>; 143 144 template <receiver _Receiver> 145 using __operation_t = __t<__operation<__id<_Receiver>>>; 146 147 struct __connect_awaitable_t { 148 private: 149 template <class _Fun, class... _Ts> __co_callstdexec::__connect_awaitable_::__connect_awaitable_t150 static auto __co_call(_Fun __fun, _Ts&&... __as) noexcept { 151 auto __fn = [&, __fun]() noexcept { 152 __fun(static_cast<_Ts&&>(__as)...); 153 }; 154 155 struct __awaiter { 156 decltype(__fn) __fn_; 157 158 static constexpr auto await_ready() noexcept -> bool { 159 return false; 160 } 161 162 void await_suspend(__coro::coroutine_handle<>) noexcept { 163 __fn_(); 164 } 165 166 [[noreturn]] 167 void await_resume() noexcept { 168 std::terminate(); 169 } 170 }; 171 172 return __awaiter{__fn}; 173 } 174 175 template <class _Awaitable, class _Receiver> 176 # if STDEXEC_GCC() && (STDEXEC_GCC_VERSION >= 12'00) 177 __attribute__((__used__)) 178 # endif __co_implstdexec::__connect_awaitable_::__connect_awaitable_t179 static auto __co_impl(_Awaitable __awaitable, _Receiver __rcvr) -> __operation_t<_Receiver> { 180 using __result_t = __await_result_t<_Awaitable, __promise_t<_Receiver>>; 181 std::exception_ptr __eptr; 182 STDEXEC_TRY { 183 if constexpr (same_as<__result_t, void>) 184 co_await ( 185 co_await static_cast<_Awaitable&&>(__awaitable), 186 __co_call(set_value, static_cast<_Receiver&&>(__rcvr))); 187 else 188 co_await __co_call( 189 set_value, 190 static_cast<_Receiver&&>(__rcvr), 191 co_await static_cast<_Awaitable&&>(__awaitable)); 192 } 193 STDEXEC_CATCH_ALL { 194 __eptr = std::current_exception(); 195 } 196 co_await __co_call( 197 set_error, static_cast<_Receiver&&>(__rcvr), static_cast<std::exception_ptr&&>(__eptr)); 198 } 199 200 template <receiver _Receiver, class _Awaitable> 201 using __completions_t = 202 completion_signatures< 203 __minvoke< // set_value_t() or set_value_t(T) 204 __mremove<void, __qf<set_value_t>>, 205 __await_result_t<_Awaitable, __promise_t<_Receiver>>>, 206 set_error_t(std::exception_ptr), 207 set_stopped_t()>; 208 209 public: 210 template <class _Receiver, __awaitable<__promise_t<_Receiver>> _Awaitable> 211 requires receiver_of<_Receiver, __completions_t<_Receiver, _Awaitable>> 212 auto operator ()stdexec::__connect_awaitable_::__connect_awaitable_t213 operator()(_Awaitable&& __awaitable, _Receiver __rcvr) const -> __operation_t<_Receiver> { 214 return __co_impl(static_cast<_Awaitable&&>(__awaitable), static_cast<_Receiver&&>(__rcvr)); 215 } 216 }; 217 } // namespace __connect_awaitable_ 218 219 using __connect_awaitable_::__connect_awaitable_t; 220 #else 221 struct __connect_awaitable_t { }; 222 #endif 223 inline constexpr __connect_awaitable_t __connect_awaitable{}; 224 } // namespace stdexec 225