1 /*
2  * Copyright (c) 2021-2022 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__detail/__concepts.hpp"
19 #include "__detail/__config.hpp"
20 #include "concepts.hpp"
21 
22 #include <version>
23 #if __cpp_impl_coroutine >= 201902 && __cpp_lib_coroutine >= 201902
24 #include <coroutine>
25 namespace __coro = std;
26 #elif defined(__cpp_coroutines) && __has_include(<experimental/coroutine>)
27 #include <experimental/coroutine>
28 namespace __coro = std::experimental;
29 #else
30 #define STDEXEC_STD_NO_COROUTINES_ 1
31 #endif
32 
33 namespace stdexec
34 {
35 #if !STDEXEC_STD_NO_COROUTINES_
36 // Define some concepts and utilities for working with awaitables
37 template <class _Tp>
38 concept __await_suspend_result =
39     __one_of<_Tp, void, bool> ||
40     __is_instance_of<_Tp, __coro::coroutine_handle>;
41 
42 template <class _Awaiter, class _Promise>
43 concept __with_await_suspend =
44     same_as<_Promise, void> || //
45     requires(_Awaiter& __awaiter, __coro::coroutine_handle<_Promise> __h) {
46         {
47             __awaiter.await_suspend(__h)
48         } -> __await_suspend_result;
49     };
50 
51 template <class _Awaiter, class _Promise = void>
52 concept __awaiter = //
53     requires(_Awaiter& __awaiter) {
54         __awaiter.await_ready() ? 1 : 0;
55         __awaiter.await_resume();
56     } && //
57     __with_await_suspend<_Awaiter, _Promise>;
58 
59 #if STDEXEC_MSVC()
60 // MSVCBUG
61 // https://developercommunity.visualstudio.com/t/operator-co_await-not-found-in-requires/10452721
62 
63 template <class _Awaitable>
__co_await_constraint(_Awaitable && __awaitable)64 void __co_await_constraint(_Awaitable&& __awaitable)
65     requires requires {
66                  operator co_await(static_cast<_Awaitable&&>(__awaitable));
67              };
68 #endif
69 
70 template <class _Awaitable>
__get_awaiter(_Awaitable && __awaitable,void *)71 auto __get_awaiter(_Awaitable&& __awaitable, void*) -> decltype(auto)
72 {
73     if constexpr (requires {
74                       static_cast<_Awaitable&&>(__awaitable)
75                           .
76                           operator co_await();
77                   })
78     {
79         return static_cast<_Awaitable&&>(__awaitable).operator co_await();
80     }
81     else if constexpr (requires {
82 #if STDEXEC_MSVC()
83                            __co_await_constraint(
84                                static_cast<_Awaitable&&>(__awaitable));
85 #else
86                            operator co_await(
87                                static_cast<_Awaitable&&>(__awaitable));
88 #endif
89                        })
90     {
91         return operator co_await(static_cast<_Awaitable&&>(__awaitable));
92     }
93     else
94     {
95         return static_cast<_Awaitable&&>(__awaitable);
96     }
97 }
98 
99 template <class _Awaitable, class _Promise>
__get_awaiter(_Awaitable && __awaitable,_Promise * __promise)100 auto __get_awaiter(_Awaitable&& __awaitable, _Promise* __promise)
101     -> decltype(auto)
102     requires requires {
103                  __promise->await_transform(
104                      static_cast<_Awaitable&&>(__awaitable));
105              }
106 {
107     if constexpr (requires {
108                       __promise
109                           ->await_transform(
110                               static_cast<_Awaitable&&>(__awaitable))
111                           .
112                           operator co_await();
113                   })
114     {
115         return __promise
116             ->await_transform(static_cast<_Awaitable&&>(__awaitable))
117             .
118             operator co_await();
119     }
120     else if constexpr (requires {
121 #if STDEXEC_MSVC()
122                            __co_await_constraint(__promise->await_transform(
123                                static_cast<_Awaitable&&>(__awaitable)));
124 #else
125                            operator co_await(__promise->await_transform(
126                                static_cast<_Awaitable&&>(__awaitable)));
127 #endif
128                        })
129     {
130         return operator co_await(
131             __promise->await_transform(static_cast<_Awaitable&&>(__awaitable)));
132     }
133     else
134     {
135         return __promise->await_transform(
136             static_cast<_Awaitable&&>(__awaitable));
137     }
138 }
139 
140 template <class _Awaitable, class _Promise = void>
141 concept __awaitable = //
142     requires(_Awaitable&& __awaitable, _Promise* __promise) {
143         {
144             stdexec::__get_awaiter(static_cast<_Awaitable&&>(__awaitable),
145                                    __promise)
146         } -> __awaiter<_Promise>;
147     };
148 
/// Unevaluated-context helper that names any expression as an lvalue. The
/// result type is `_Tp&` after reference collapsing: a prvalue `T` becomes
/// `T&`, and an lvalue `T&` stays `T&`. No definition is provided, so use it
/// only inside `decltype`, `requires`, or other unevaluated operands.
template <typename _Tp>
_Tp& __as_lvalue(_Tp&&);
151 
152 template <class _Awaitable, class _Promise = void>
153     requires __awaitable<_Awaitable, _Promise>
154 using __await_result_t =
155     decltype(stdexec::__as_lvalue(
156                  stdexec::__get_awaiter(std::declval<_Awaitable>(),
157                                         static_cast<_Promise*>(nullptr)))
158                  .await_resume());
159 
160 #else
161 
162 template <class _Awaitable, class _Promise = void>
163 concept __awaitable = false;
164 
165 template <class _Awaitable, class _Promise = void>
166     requires __awaitable<_Awaitable, _Promise>
167 using __await_result_t = void;
168 
169 #endif
170 } // namespace stdexec
171