1 /*
2  * Copyright (c) 2021-2024 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__concepts.hpp"
19 #include "__config.hpp"
20 #include "__utility.hpp"
21 
22 namespace stdexec
23 {
24 #if !STDEXEC_STD_NO_COROUTINES()
25 // Define some concepts and utilities for working with awaitables
26 template <class _Tp>
27 concept __await_suspend_result =
28     __one_of<_Tp, void, bool> ||
29     __is_instance_of<_Tp, __coro::coroutine_handle>;
30 
31 template <class _Awaiter, class _Promise>
32 concept __with_await_suspend =
33     requires(_Awaiter& __awaiter, __coro::coroutine_handle<_Promise> __h) {
34         { __awaiter.await_suspend(__h) } -> __await_suspend_result;
35     };
36 
37 template <class _Awaiter, class... _Promise>
38 concept __awaiter = //
39     requires(_Awaiter& __awaiter) {
40         __awaiter.await_ready() ? 1 : 0;
41         __awaiter.await_resume();
42     } && //
43     (__with_await_suspend<_Awaiter, _Promise> && ...);
44 
#if STDEXEC_MSVC()
// MSVCBUG
// https://developercommunity.visualstudio.com/t/operator-co_await-not-found-in-requires/10452721
//
// Workaround: as the bug report above describes, MSVC fails to find a
// non-member `operator co_await` when the call is written directly inside a
// requires-expression. This declaration-only function moves that check into
// its own trailing requires-clause. `__get_awaiter` then tests whether a call
// to `__co_await_constraint(...)` is well-formed instead of naming
// `operator co_await` directly. No definition is given in this file; it is
// only meant to appear in unevaluated requires-expressions.
template <class _Awaitable>
void __co_await_constraint(_Awaitable&& __awaitable)
    requires requires {
                 operator co_await(static_cast<_Awaitable&&>(__awaitable));
             };
#endif
55 
// Returns the awaiter for `__awaitable` when no promise `await_transform` is
// applied. Resolution order (first viable branch wins):
//   1. a member `operator co_await()` callable on the forwarded awaitable;
//   2. a non-member `operator co_await(awaitable)` (on MSVC the check is
//      routed through `__co_await_constraint` because of the MSVCBUG above);
//   3. otherwise the forwarded awaitable itself serves as the awaiter.
// Because of the `if constexpr` order, a member `operator co_await` is chosen
// even when a non-member one would also be usable.
//
// The second, unnamed parameter is defaulted and its value is discarded. It
// presumably lets a two-argument call `__get_awaiter(a, p)` fall back to this
// overload when the promise-taking overload below is not viable. That
// requires `__ignore` (declared elsewhere) to be constructible from a promise
// pointer; TODO confirm.
//
// `decltype(auto)` preserves the value category of the chosen expression. In
// case 3 the return type is `_Awaitable&&`, a reference to the caller's
// object rather than a copy.
template <class _Awaitable>
auto __get_awaiter(_Awaitable&& __awaitable, __ignore = {}) -> decltype(auto)
{
    // 1. Member `operator co_await`.
    if constexpr (requires {
                      static_cast<_Awaitable&&>(__awaitable)
                          .operator co_await();
                  })
    {
        return static_cast<_Awaitable&&>(__awaitable).operator co_await();
    }
    // 2. Non-member `operator co_await`.
    else if constexpr (requires {
#if STDEXEC_MSVC()
                           __co_await_constraint(
                               static_cast<_Awaitable&&>(__awaitable));
#else
        operator co_await(static_cast<_Awaitable&&>(__awaitable));
#endif
                       })
    {
        return operator co_await(static_cast<_Awaitable&&>(__awaitable));
    }
    // 3. The awaitable is its own awaiter.
    else
    {
        return static_cast<_Awaitable&&>(__awaitable);
    }
}
82 
// Returns the awaiter for `__awaitable` as seen from a coroutine whose
// promise is `*__promise`. The awaitable is first passed through
// `__promise->await_transform(...)`. The transformed result is then resolved
// in this order: a member `operator co_await()`, then a non-member
// `operator co_await` (routed through `__co_await_constraint` on MSVC), then
// the transformed object itself.
//
// Constraint: this overload participates only when `await_transform` is
// callable with the forwarded awaitable. Otherwise a two-argument call
// presumably resolves to the `__ignore`-taking overload above, which does not
// consult the promise.
// NOTE(review): for `co_await`, the language applies `await_transform`
// whenever the promise declares any member of that name. If no overload
// accepts the operand, the expression is ill-formed. This fallback instead
// treats such an awaitable as untransformed; confirm that is intended.
//
// `decltype(auto)` preserves the value category of the selected expression,
// including a reference returned directly by `await_transform`.
template <class _Awaitable, class _Promise>
auto __get_awaiter(_Awaitable&& __awaitable,
                   _Promise* __promise) -> decltype(auto)
    requires requires {
                 __promise->await_transform(
                     static_cast<_Awaitable&&>(__awaitable));
             }
{
    // 1. Member `operator co_await` on the transformed awaitable.
    if constexpr (requires {
                      __promise
                          ->await_transform(
                              static_cast<_Awaitable&&>(__awaitable))
                          .operator co_await();
                  })
    {
        return __promise
            ->await_transform(static_cast<_Awaitable&&>(__awaitable))
            .operator co_await();
    }
    // 2. Non-member `operator co_await` on the transformed awaitable.
    else if constexpr (requires {
#if STDEXEC_MSVC()
                           __co_await_constraint(__promise->await_transform(
                               static_cast<_Awaitable&&>(__awaitable)));
#else
        operator co_await(__promise->await_transform(static_cast<_Awaitable&&>(__awaitable)));
#endif
                       })
    {
        return operator co_await(
            __promise->await_transform(static_cast<_Awaitable&&>(__awaitable)));
    }
    // 3. The transformed object is its own awaiter.
    else
    {
        return __promise->await_transform(
            static_cast<_Awaitable&&>(__awaitable));
    }
}
120 
121 template <class _Awaitable, class... _Promise>
122 concept __awaitable = //
123     requires(_Awaitable&& __awaitable, _Promise*... __promise) {
124         {
125             stdexec::__get_awaiter(static_cast<_Awaitable&&>(__awaitable),
126                                    __promise...)
127         } -> __awaiter<_Promise...>;
128     };
129 
// Unevaluated-context helper (for `decltype` and similar). It yields an
// lvalue reference to the argument's type whatever the argument's value
// category: rvalue `T` gives `T&`, and lvalue `T&` stays `T&`. No definition
// is provided here, so it should only appear in unevaluated operands.
template <class _Tp>
_Tp& __as_lvalue(_Tp&&);
132 
133 template <class _Awaitable, class... _Promise>
134     requires __awaitable<_Awaitable, _Promise...>
135 using __await_result_t =
136     decltype(stdexec::__as_lvalue(
137                  stdexec::__get_awaiter(std::declval<_Awaitable>(),
138                                         static_cast<_Promise*>(nullptr)...))
139                  .await_resume());
140 
141 #else
142 
143 template <class _Awaitable, class... _Promise>
144 concept __awaitable = false;
145 
146 template <class _Awaitable, class... _Promise>
147     requires __awaitable<_Awaitable, _Promise...>
148 using __await_result_t = void;
149 
150 #endif
151 } // namespace stdexec
152