1 /*
2 * Copyright (c) 2021-2024 NVIDIA Corporation
3 *
4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5 * (the "License"); you may not use this file except in compliance with
6 * the License. You may obtain a copy of the License at
7 *
8 * https://llvm.org/LICENSE.txt
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #pragma once
17
18 #include "__detail/__config.hpp"
19 #include "__detail/__meta.hpp"
20 #include "__detail/__tag_invoke.hpp"
21 #include "concepts.hpp"
22
23 #include <cstddef>
24 #include <functional>
25 #include <tuple>
26 #include <type_traits>
27
28 namespace stdexec
29 {
30 template <class _Fun0, class _Fun1>
31 struct __composed
32 {
33 STDEXEC_ATTRIBUTE((no_unique_address))
34 _Fun0 __t0_;
35 STDEXEC_ATTRIBUTE((no_unique_address))
36 _Fun1 __t1_;
37
38 template <class... _Ts>
39 requires __callable<_Fun1, _Ts...> &&
40 __callable<_Fun0, __call_result_t<_Fun1, _Ts...>>
41 STDEXEC_ATTRIBUTE((always_inline))
42 __call_result_t<_Fun0, __call_result_t<_Fun1, _Ts...>>
operator ()stdexec::__composed43 operator()(_Ts&&... __ts) &&
44 {
45 return static_cast<_Fun0&&>(__t0_)(
46 static_cast<_Fun1&&>(__t1_)(static_cast<_Ts&&>(__ts)...));
47 }
48
49 template <class... _Ts>
50 requires __callable<const _Fun1&, _Ts...> &&
51 __callable<const _Fun0&, __call_result_t<const _Fun1&, _Ts...>>
52 STDEXEC_ATTRIBUTE((always_inline))
53 __call_result_t<_Fun0, __call_result_t<_Fun1, _Ts...>>
operator ()stdexec::__composed54 operator()(_Ts&&... __ts) const&
55 {
56 return __t0_(__t1_(static_cast<_Ts&&>(__ts)...));
57 }
58 };
59
60 inline constexpr struct __compose_t
61 {
62 template <class _Fun0, class _Fun1>
63 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__compose_t64 __composed<_Fun0, _Fun1> operator()(_Fun0 __fun0, _Fun1 __fun1) const
65 {
66 return {static_cast<_Fun0&&>(__fun0), static_cast<_Fun1&&>(__fun1)};
67 }
68 } __compose{};
69
70 namespace __invoke_
71 {
// Trait: true exactly when the (unqualified, non-reference) type is a
// specialization of std::reference_wrapper. Used by __invoke_selector to
// route member-pointer calls through `.get()`.
template <class _Ty>
inline constexpr bool __is_refwrap = false;

template <class _Referent>
inline constexpr bool __is_refwrap<std::reference_wrapper<_Referent>> = true;
76
77 struct __funobj
78 {
79 template <class _Fun, class... _Args>
80 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__funobj81 constexpr auto operator()(_Fun&& __fun, _Args&&... __args) const noexcept(
82 noexcept((static_cast<_Fun&&>(__fun))(static_cast<_Args&&>(__args)...)))
83 -> decltype((static_cast<_Fun&&>(__fun))(
84 static_cast<_Args&&>(__args)...))
85 {
86 return static_cast<_Fun&&>(__fun)(static_cast<_Args&&>(__args)...);
87 }
88 };
89
90 struct __memfn
91 {
92 template <class _Memptr, class _Ty, class... _Args>
93 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memfn94 constexpr auto operator()(_Memptr __mem_ptr, _Ty&& __ty,
95 _Args&&... __args) const
96 noexcept(noexcept(((static_cast<_Ty&&>(__ty)).*
97 __mem_ptr)(static_cast<_Args&&>(__args)...)))
98 -> decltype(((static_cast<_Ty&&>(__ty)).*
99 __mem_ptr)(static_cast<_Args&&>(__args)...))
100 {
101 return ((static_cast<_Ty&&>(__ty)).*__mem_ptr)(
102 static_cast<_Args&&>(__args)...);
103 }
104 };
105
106 struct __memfn_refwrap
107 {
108 template <class _Memptr, class _Ty, class... _Args>
109 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memfn_refwrap110 constexpr auto operator()(_Memptr __mem_ptr, _Ty __ty,
111 _Args&&... __args) const
112 noexcept(noexcept((__ty.get().*
113 __mem_ptr)(static_cast<_Args&&>(__args)...)))
114 -> decltype((__ty.get().*
115 __mem_ptr)(static_cast<_Args&&>(__args)...))
116 {
117 return (__ty.get().*__mem_ptr)(static_cast<_Args&&>(__args)...);
118 }
119 };
120
121 struct __memfn_smartptr
122 {
123 template <class _Memptr, class _Ty, class... _Args>
124 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memfn_smartptr125 constexpr auto operator()(_Memptr __mem_ptr, _Ty&& __ty,
126 _Args&&... __args) const
127 noexcept(noexcept(((*static_cast<_Ty&&>(__ty)).*
128 __mem_ptr)(static_cast<_Args&&>(__args)...)))
129 -> decltype(((*static_cast<_Ty&&>(__ty)).*
130 __mem_ptr)(static_cast<_Args&&>(__args)...))
131 {
132 return ((*static_cast<_Ty&&>(__ty)).*__mem_ptr)(
133 static_cast<_Args&&>(__args)...);
134 }
135 };
136
137 struct __memobj
138 {
139 template <class _Mbr, class _Class, class _Ty>
140 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memobj141 constexpr auto operator()(_Mbr _Class::*__mem_ptr, _Ty&& __ty)
142 const noexcept -> decltype(((static_cast<_Ty&&>(__ty)).*__mem_ptr))
143 {
144 return ((static_cast<_Ty&&>(__ty)).*__mem_ptr);
145 }
146 };
147
148 struct __memobj_refwrap
149 {
150 template <class _Mbr, class _Class, class _Ty>
151 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memobj_refwrap152 constexpr auto operator()(_Mbr _Class::*__mem_ptr, _Ty __ty) const noexcept
153 -> decltype((__ty.get().*__mem_ptr))
154 {
155 return (__ty.get().*__mem_ptr);
156 }
157 };
158
159 struct __memobj_smartptr
160 {
161 template <class _Mbr, class _Class, class _Ty>
162 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memobj_smartptr163 constexpr auto operator()(_Mbr _Class::*__mem_ptr, _Ty&& __ty)
164 const noexcept -> decltype(((*static_cast<_Ty&&>(__ty)).*__mem_ptr))
165 {
166 return ((*static_cast<_Ty&&>(__ty)).*__mem_ptr);
167 }
168 };
169
170 auto __invoke_selector(__ignore, __ignore) noexcept -> __funobj;
171
172 template <class _Mbr, class _Class, class _Ty>
__invoke_selector(_Mbr _Class::*,const _Ty &)173 auto __invoke_selector(_Mbr _Class::*, const _Ty&) noexcept
174 {
175 if constexpr (STDEXEC_IS_CONST(_Mbr) || STDEXEC_IS_CONST(const _Mbr))
176 {
177 // member function ptr case
178 if constexpr (STDEXEC_IS_BASE_OF(_Class, _Ty))
179 {
180 return __memobj{};
181 }
182 else if constexpr (__is_refwrap<_Ty>)
183 {
184 return __memobj_refwrap{};
185 }
186 else
187 {
188 return __memobj_smartptr{};
189 }
190 }
191 else
192 {
193 // member object ptr case
194 if constexpr (STDEXEC_IS_BASE_OF(_Class, _Ty))
195 {
196 return __memfn{};
197 }
198 else if constexpr (__is_refwrap<_Ty>)
199 {
200 return __memfn_refwrap{};
201 }
202 else
203 {
204 return __memfn_smartptr{};
205 }
206 }
207 }
208
209 struct __invoke_t
210 {
211 template <class _Fun>
212 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__invoke_t213 constexpr auto operator()(_Fun&& __fun) const
214 noexcept(noexcept((static_cast<_Fun&&>(__fun))()))
215 -> decltype((static_cast<_Fun&&>(__fun))())
216 {
217 return static_cast<_Fun&&>(__fun)();
218 }
219
220 template <class _Fun, class _Ty, class... _Args>
221 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__invoke_t222 constexpr auto operator()(_Fun&& __fun, _Ty&& __ty, _Args&&... __args) const
223 noexcept(noexcept(__invoke_selector(__fun, __ty)(
224 static_cast<_Fun&&>(__fun), static_cast<_Ty&&>(__ty),
225 static_cast<_Args&&>(__args)...)))
226 -> decltype(__invoke_selector(__fun, __ty)(
227 static_cast<_Fun&&>(__fun), static_cast<_Ty&&>(__ty),
228 static_cast<_Args&&>(__args)...))
229 {
230 return decltype(__invoke_selector(__fun, __ty))()(
231 static_cast<_Fun&&>(__fun), static_cast<_Ty&&>(__ty),
232 static_cast<_Args&&>(__args)...);
233 }
234 };
235 } // namespace __invoke_
236
237 inline constexpr __invoke_::__invoke_t __invoke{};
238
239 template <class _Fun, class... _As>
240 concept __invocable = //
241 requires(_Fun&& __f, _As&&... __as) {
242 __invoke(static_cast<_Fun&&>(__f), static_cast<_As&&>(__as)...);
243 };
244
245 template <class _Fun, class... _As>
246 concept __nothrow_invocable = //
247 __invocable<_Fun, _As...> && //
248 requires(_Fun&& __f, _As&&... __as) {
249 {
250 __invoke(static_cast<_Fun&&>(__f), static_cast<_As&&>(__as)...)
251 } noexcept;
252 };
253
254 template <class _Fun, class... _As>
255 using __invoke_result_t = //
256 decltype(__invoke(__declval<_Fun>(), __declval<_As>()...));
257
258 namespace __apply_
259 {
260 using std::get;
261
262 template <std::size_t... _Is, class _Fn, class _Tup>
263 STDEXEC_ATTRIBUTE((always_inline))
__impl(__indices<_Is...>,_Fn && __fn,_Tup && __tup)264 constexpr auto __impl(__indices<_Is...>, _Fn&& __fn, _Tup&& __tup) //
265 noexcept(noexcept(__invoke(static_cast<_Fn&&>(__fn),
266 get<_Is>(static_cast<_Tup&&>(__tup))...)))
267 -> decltype(__invoke(static_cast<_Fn&&>(__fn),
268 get<_Is>(static_cast<_Tup&&>(__tup))...))
269 {
270 return __invoke(static_cast<_Fn&&>(__fn),
271 get<_Is>(static_cast<_Tup&&>(__tup))...);
272 }
273
274 template <class _Tup>
275 using __tuple_indices =
276 __make_indices<std::tuple_size<std::remove_cvref_t<_Tup>>::value>;
277
278 template <class _Fn, class _Tup>
279 using __result_t = decltype(__apply_::__impl(
280 __tuple_indices<_Tup>(), __declval<_Fn>(), __declval<_Tup>()));
281 } // namespace __apply_
282
283 template <class _Fn, class _Tup>
284 concept __applicable = __mvalid<__apply_::__result_t, _Fn, _Tup>;
285
286 template <class _Fn, class _Tup>
287 concept __nothrow_applicable =
288 __applicable<_Fn, _Tup> //
289 && //
290 noexcept(__apply_::__impl(__apply_::__tuple_indices<_Tup>(),
291 __declval<_Fn>(), __declval<_Tup>()));
292
293 template <class _Fn, class _Tup>
294 requires __applicable<_Fn, _Tup>
295 using __apply_result_t = __apply_::__result_t<_Fn, _Tup>;
296
297 struct __apply_t
298 {
299 template <class _Fn, class _Tup>
300 requires __applicable<_Fn, _Tup>
301 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__apply_t302 constexpr auto operator()(_Fn&& __fn, _Tup&& __tup) const
303 noexcept(__nothrow_applicable<_Fn, _Tup>) -> __apply_result_t<_Fn, _Tup>
304 {
305 return __apply_::__impl(__apply_::__tuple_indices<_Tup>(),
306 static_cast<_Fn&&>(__fn),
307 static_cast<_Tup&&>(__tup));
308 }
309 };
310
311 inline constexpr __apply_t __apply{};
312
313 template <class _Tag, class _Ty>
314 struct __field
315 {
316 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__field317 _Ty operator()(_Tag) const noexcept(__nothrow_decay_copyable<const _Ty&>)
318 {
319 return __t_;
320 }
321
322 _Ty __t_;
323 };
324
325 template <class _Tag>
326 struct __mkfield_
327 {
328 template <class _Ty>
329 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__mkfield_330 __field<_Tag, __decay_t<_Ty>> operator()(_Ty&& __ty) const
331 noexcept(__nothrow_decay_copyable<_Ty>)
332 {
333 return {static_cast<_Ty&&>(__ty)};
334 }
335 };
336
337 template <class _Tag>
338 inline constexpr __mkfield_<_Tag> __mkfield{};
339 } // namespace stdexec
340