1 /*
2 * Copyright (c) 2021-2024 NVIDIA Corporation
3 *
4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5 * (the "License"); you may not use this file except in compliance with
6 * the License. You may obtain a copy of the License at
7 *
8 * https://llvm.org/LICENSE.txt
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #pragma once
17
18 #include "__concepts.hpp"
19 #include "__config.hpp"
20 #include "__utility.hpp"
21
22 namespace stdexec
23 {
24 #if !STDEXEC_STD_NO_COROUTINES()
25 // Define some concepts and utilities for working with awaitables
26 template <class _Tp>
27 concept __await_suspend_result =
28 __one_of<_Tp, void, bool> ||
29 __is_instance_of<_Tp, __coro::coroutine_handle>;
30
31 template <class _Awaiter, class _Promise>
32 concept __with_await_suspend =
33 requires(_Awaiter& __awaiter, __coro::coroutine_handle<_Promise> __h) {
34 { __awaiter.await_suspend(__h) } -> __await_suspend_result;
35 };
36
37 template <class _Awaiter, class... _Promise>
38 concept __awaiter = //
39 requires(_Awaiter& __awaiter) {
40 __awaiter.await_ready() ? 1 : 0;
41 __awaiter.await_resume();
42 } && //
43 (__with_await_suspend<_Awaiter, _Promise> && ...);
44
45 #if STDEXEC_MSVC()
46 // MSVCBUG
47 // https://developercommunity.visualstudio.com/t/operator-co_await-not-found-in-requires/10452721
48
49 template <class _Awaitable>
__co_await_constraint(_Awaitable && __awaitable)50 void __co_await_constraint(_Awaitable&& __awaitable)
51 requires requires {
52 operator co_await(static_cast<_Awaitable&&>(__awaitable));
53 };
54 #endif
55
56 template <class _Awaitable>
__get_awaiter(_Awaitable && __awaitable,__ignore={})57 auto __get_awaiter(_Awaitable&& __awaitable, __ignore = {}) -> decltype(auto)
58 {
59 if constexpr (requires {
60 static_cast<_Awaitable&&>(__awaitable)
61 .operator co_await();
62 })
63 {
64 return static_cast<_Awaitable&&>(__awaitable).operator co_await();
65 }
66 else if constexpr (requires {
67 #if STDEXEC_MSVC()
68 __co_await_constraint(
69 static_cast<_Awaitable&&>(__awaitable));
70 #else
71 operator co_await(static_cast<_Awaitable&&>(__awaitable));
72 #endif
73 })
74 {
75 return operator co_await(static_cast<_Awaitable&&>(__awaitable));
76 }
77 else
78 {
79 return static_cast<_Awaitable&&>(__awaitable);
80 }
81 }
82
83 template <class _Awaitable, class _Promise>
__get_awaiter(_Awaitable && __awaitable,_Promise * __promise)84 auto __get_awaiter(_Awaitable&& __awaitable,
85 _Promise* __promise) -> decltype(auto)
86 requires requires {
87 __promise->await_transform(
88 static_cast<_Awaitable&&>(__awaitable));
89 }
90 {
91 if constexpr (requires {
92 __promise
93 ->await_transform(
94 static_cast<_Awaitable&&>(__awaitable))
95 .operator co_await();
96 })
97 {
98 return __promise
99 ->await_transform(static_cast<_Awaitable&&>(__awaitable))
100 .operator co_await();
101 }
102 else if constexpr (requires {
103 #if STDEXEC_MSVC()
104 __co_await_constraint(__promise->await_transform(
105 static_cast<_Awaitable&&>(__awaitable)));
106 #else
107 operator co_await(__promise->await_transform(static_cast<_Awaitable&&>(__awaitable)));
108 #endif
109 })
110 {
111 return operator co_await(
112 __promise->await_transform(static_cast<_Awaitable&&>(__awaitable)));
113 }
114 else
115 {
116 return __promise->await_transform(
117 static_cast<_Awaitable&&>(__awaitable));
118 }
119 }
120
121 template <class _Awaitable, class... _Promise>
122 concept __awaitable = //
123 requires(_Awaitable&& __awaitable, _Promise*... __promise) {
124 {
125 stdexec::__get_awaiter(static_cast<_Awaitable&&>(__awaitable),
126 __promise...)
127 } -> __awaiter<_Promise...>;
128 };
129
// Declaration-only helper for unevaluated contexts: maps an expression of any
// value category to an lvalue reference (`_Ty&`, with reference collapsing
// applied when `_Ty` is itself deduced as a reference).
template <class _Ty>
_Ty& __as_lvalue(_Ty&&);
132
133 template <class _Awaitable, class... _Promise>
134 requires __awaitable<_Awaitable, _Promise...>
135 using __await_result_t =
136 decltype(stdexec::__as_lvalue(
137 stdexec::__get_awaiter(std::declval<_Awaitable>(),
138 static_cast<_Promise*>(nullptr)...))
139 .await_resume());
140
141 #else
142
143 template <class _Awaitable, class... _Promise>
144 concept __awaitable = false;
145
146 template <class _Awaitable, class... _Promise>
147 requires __awaitable<_Awaitable, _Promise...>
148 using __await_result_t = void;
149
150 #endif
151 } // namespace stdexec
152