1 /* 2 * Copyright (c) 2021-2024 NVIDIA Corporation 3 * 4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 5 * (the "License"); you may not use this file except in compliance with 6 * the License. You may obtain a copy of the License at 7 * 8 * https://llvm.org/LICENSE.txt 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include "__config.hpp" 19 #include "__concepts.hpp" 20 #include "__utility.hpp" 21 22 namespace stdexec { 23 #if !STDEXEC_STD_NO_COROUTINES() 24 // Define some concepts and utilities for working with awaitables 25 template <class _Tp> 26 concept __await_suspend_result = __one_of<_Tp, void, bool> 27 || __is_instance_of<_Tp, __coro::coroutine_handle>; 28 29 template <class _Awaiter, class... _Promise> 30 concept __awaiter = requires(_Awaiter& __awaiter, __coro::coroutine_handle<_Promise...> __h) { 31 __awaiter.await_ready() ? 1 : 0; 32 { __awaiter.await_suspend(__h) } -> __await_suspend_result; 33 __awaiter.await_resume(); 34 }; 35 36 # if STDEXEC_MSVC() 37 // MSVCBUG https://developercommunity.visualstudio.com/t/operator-co_await-not-found-in-requires/10452721 38 39 template <class _Awaitable> __co_await_constraint(_Awaitable && __awaitable)40 void __co_await_constraint(_Awaitable&& __awaitable) 41 requires requires { operator co_await(static_cast<_Awaitable &&>(__awaitable)); }; 42 # endif 43 44 template <class _Awaitable> __get_awaiter(_Awaitable && __awaitable,__ignore={})45 auto __get_awaiter(_Awaitable&& __awaitable, __ignore = {}) -> decltype(auto) { 46 if constexpr (requires { static_cast<_Awaitable &&>(__awaitable).operator co_await(); }) { 47 return static_cast<_Awaitable&&>(__awaitable).operator co_await(); 48 } else if constexpr (requires { 49 # if STDEXEC_MSVC() 50 __co_await_constraint(static_cast<_Awaitable &&>(__awaitable)); 51 # else 52 operator co_await(static_cast<_Awaitable&&>(__awaitable)); 53 # endif 54 }) { 55 return operator co_await(static_cast<_Awaitable&&>(__awaitable)); 56 } else { 57 return static_cast<_Awaitable&&>(__awaitable); 58 } 59 } 60 61 template <class _Awaitable, class _Promise> __get_awaiter(_Awaitable && __awaitable,_Promise * __promise)62 auto __get_awaiter(_Awaitable&& __awaitable, _Promise* __promise) -> decltype(auto) 63 requires requires { __promise->await_transform(static_cast<_Awaitable &&>(__awaitable)); } 64 { 65 if constexpr (requires { 66 __promise->await_transform(static_cast<_Awaitable &&>(__awaitable)) 67 .operator co_await(); 68 }) { 69 return __promise->await_transform(static_cast<_Awaitable&&>(__awaitable)).operator co_await(); 70 } else if constexpr (requires { 71 # if STDEXEC_MSVC() 72 __co_await_constraint( 73 __promise->await_transform(static_cast<_Awaitable &&>(__awaitable))); 74 # else 75 operator co_await(__promise->await_transform(static_cast<_Awaitable&&>(__awaitable))); 76 # endif 77 }) { 78 return operator co_await(__promise->await_transform(static_cast<_Awaitable&&>(__awaitable))); 79 } else { 80 return __promise->await_transform(static_cast<_Awaitable&&>(__awaitable)); 81 } 82 } 83 84 template <class _Awaitable, class... _Promise> 85 concept __awaitable = requires(_Awaitable&& __awaitable, _Promise*... __promise) { 86 { 87 stdexec::__get_awaiter(static_cast<_Awaitable &&>(__awaitable), __promise...) 88 } -> __awaiter<_Promise...>; 89 }; 90 91 template <class _Tp> 92 auto __as_lvalue(_Tp&&) -> _Tp&; 93 94 template <class _Awaitable, class... _Promise> 95 requires __awaitable<_Awaitable, _Promise...> 96 using __await_result_t = decltype(stdexec::__as_lvalue( 97 stdexec::__get_awaiter( 98 std::declval<_Awaitable>(), 99 static_cast<_Promise*>(nullptr)...)) 100 .await_resume()); 101 102 #else 103 104 template <class _Awaitable, class... _Promise> 105 concept __awaitable = false; 106 107 template <class _Awaitable, class... _Promise> 108 requires __awaitable<_Awaitable, _Promise...> 109 using __await_result_t = void; 110 111 #endif 112 } // namespace stdexec 113