1 /*
2 * Copyright (c) 2021-2024 NVIDIA Corporation
3 *
4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5 * (the "License"); you may not use this file except in compliance with
6 * the License. You may obtain a copy of the License at
7 *
8 * https://llvm.org/LICENSE.txt
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #pragma once
17
18 #include "__detail/__config.hpp"
19 #include "__detail/__meta.hpp"
20 #include "__detail/__tag_invoke.hpp"
21 #include "concepts.hpp"
22
23 #include <cstddef>
24 #include <functional>
25 #include <tuple>
26 #include <type_traits>
27
28 namespace stdexec
29 {
30 template <class _Fun0, class _Fun1>
31 struct __composed
32 {
33 STDEXEC_ATTRIBUTE((no_unique_address))
34 _Fun0 __t0_;
35 STDEXEC_ATTRIBUTE((no_unique_address))
36 _Fun1 __t1_;
37
38 template <class... _Ts>
39 requires __callable<_Fun1, _Ts...> &&
40 __callable<_Fun0, __call_result_t<_Fun1, _Ts...>>
41 STDEXEC_ATTRIBUTE((always_inline))
42 __call_result_t<_Fun0, __call_result_t<_Fun1, _Ts...>>
operator ()stdexec::__composed43 operator()(_Ts&&... __ts) &&
44 {
45 return static_cast<_Fun0&&>(__t0_)(
46 static_cast<_Fun1&&>(__t1_)(static_cast<_Ts&&>(__ts)...));
47 }
48
49 template <class... _Ts>
50 requires __callable<const _Fun1&, _Ts...> &&
51 __callable<const _Fun0&, __call_result_t<const _Fun1&, _Ts...>>
52 STDEXEC_ATTRIBUTE((always_inline))
53 __call_result_t<_Fun0, __call_result_t<_Fun1, _Ts...>>
operator ()stdexec::__composed54 operator()(_Ts&&... __ts) const&
55 {
56 return __t0_(__t1_(static_cast<_Ts&&>(__ts)...));
57 }
58 };
59
60 inline constexpr struct __compose_t
61 {
62 template <class _Fun0, class _Fun1>
63 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__compose_t64 __composed<_Fun0, _Fun1> operator()(_Fun0 __fun0, _Fun1 __fun1) const
65 {
66 return {static_cast<_Fun0&&>(__fun0), static_cast<_Fun1&&>(__fun1)};
67 }
68 } __compose{};
69
70 namespace __invoke_
71 {
// `__is_refwrap<_Tp>` holds exactly when `_Tp` is a cv-unqualified,
// non-reference specialization of std::reference_wrapper. Every other type
// yields false, including references to a reference_wrapper and
// cv-qualified reference_wrappers.
template <class _Tp>
inline constexpr bool __is_refwrap = false;

template <class _Referent>
inline constexpr bool __is_refwrap<std::reference_wrapper<_Referent>> = true;
76
77 struct __funobj
78 {
79 template <class _Fun, class... _Args>
80 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__funobj81 constexpr auto operator()(_Fun&& __fun, _Args&&... __args) const noexcept(
82 noexcept((static_cast<_Fun&&>(__fun))(static_cast<_Args&&>(__args)...)))
83 -> decltype((static_cast<_Fun&&>(__fun))(
84 static_cast<_Args&&>(__args)...))
85 {
86 return static_cast<_Fun&&>(__fun)(static_cast<_Args&&>(__args)...);
87 }
88 };
89
90 struct __memfn
91 {
92 template <class _Memptr, class _Ty, class... _Args>
93 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memfn94 constexpr auto operator()(_Memptr __mem_ptr, _Ty&& __ty,
95 _Args&&... __args) const
96 noexcept(noexcept(((static_cast<_Ty&&>(__ty)).*
97 __mem_ptr)(static_cast<_Args&&>(__args)...)))
98 -> decltype(((static_cast<_Ty&&>(__ty)).*
99 __mem_ptr)(static_cast<_Args&&>(__args)...))
100 {
101 return ((static_cast<_Ty&&>(__ty)).*__mem_ptr)(
102 static_cast<_Args&&>(__args)...);
103 }
104 };
105
106 struct __memfn_refwrap
107 {
108 template <class _Memptr, class _Ty, class... _Args>
109 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memfn_refwrap110 constexpr auto operator()(_Memptr __mem_ptr, _Ty __ty,
111 _Args&&... __args) const
112 noexcept(noexcept((__ty.get().*
113 __mem_ptr)(static_cast<_Args&&>(__args)...)))
114 -> decltype((__ty.get().*
115 __mem_ptr)(static_cast<_Args&&>(__args)...))
116 {
117 return (__ty.get().*__mem_ptr)(static_cast<_Args&&>(__args)...);
118 }
119 };
120
121 struct __memfn_smartptr
122 {
123 template <class _Memptr, class _Ty, class... _Args>
124 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memfn_smartptr125 constexpr auto operator()(_Memptr __mem_ptr, _Ty&& __ty,
126 _Args&&... __args) const
127 noexcept(noexcept(((*static_cast<_Ty&&>(__ty)).*
128 __mem_ptr)(static_cast<_Args&&>(__args)...)))
129 -> decltype(((*static_cast<_Ty&&>(__ty)).*
130 __mem_ptr)(static_cast<_Args&&>(__args)...))
131 {
132 return ((*static_cast<_Ty&&>(__ty)).*__mem_ptr)(
133 static_cast<_Args&&>(__args)...);
134 }
135 };
136
137 struct __memobj
138 {
139 template <class _Mbr, class _Class, class _Ty>
140 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memobj141 constexpr auto operator()(_Mbr _Class::* __mem_ptr,
142 _Ty&& __ty) const noexcept
143 -> decltype(((static_cast<_Ty&&>(__ty)).*__mem_ptr))
144 {
145 return ((static_cast<_Ty&&>(__ty)).*__mem_ptr);
146 }
147 };
148
149 struct __memobj_refwrap
150 {
151 template <class _Mbr, class _Class, class _Ty>
152 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memobj_refwrap153 constexpr auto operator()(_Mbr _Class::* __mem_ptr, _Ty __ty) const noexcept
154 -> decltype((__ty.get().*__mem_ptr))
155 {
156 return (__ty.get().*__mem_ptr);
157 }
158 };
159
160 struct __memobj_smartptr
161 {
162 template <class _Mbr, class _Class, class _Ty>
163 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memobj_smartptr164 constexpr auto operator()(_Mbr _Class::* __mem_ptr,
165 _Ty&& __ty) const noexcept
166 -> decltype(((*static_cast<_Ty&&>(__ty)).*__mem_ptr))
167 {
168 return ((*static_cast<_Ty&&>(__ty)).*__mem_ptr);
169 }
170 };
171
172 auto __invoke_selector(__ignore, __ignore) noexcept -> __funobj;
173
174 template <class _Mbr, class _Class, class _Ty>
__invoke_selector(_Mbr _Class::*,const _Ty &)175 auto __invoke_selector(_Mbr _Class::*, const _Ty&) noexcept
176 {
177 if constexpr (STDEXEC_IS_CONST(_Mbr) || STDEXEC_IS_CONST(const _Mbr))
178 {
179 // member function ptr case
180 if constexpr (STDEXEC_IS_BASE_OF(_Class, _Ty))
181 {
182 return __memobj{};
183 }
184 else if constexpr (__is_refwrap<_Ty>)
185 {
186 return __memobj_refwrap{};
187 }
188 else
189 {
190 return __memobj_smartptr{};
191 }
192 }
193 else
194 {
195 // member object ptr case
196 if constexpr (STDEXEC_IS_BASE_OF(_Class, _Ty))
197 {
198 return __memfn{};
199 }
200 else if constexpr (__is_refwrap<_Ty>)
201 {
202 return __memfn_refwrap{};
203 }
204 else
205 {
206 return __memfn_smartptr{};
207 }
208 }
209 }
210
211 struct __invoke_t
212 {
213 template <class _Fun>
214 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__invoke_t215 constexpr auto operator()(_Fun&& __fun) const
216 noexcept(noexcept((static_cast<_Fun&&>(__fun))()))
217 -> decltype((static_cast<_Fun&&>(__fun))())
218 {
219 return static_cast<_Fun&&>(__fun)();
220 }
221
222 template <class _Fun, class _Ty, class... _Args>
223 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__invoke_t224 constexpr auto operator()(_Fun&& __fun, _Ty&& __ty, _Args&&... __args) const
225 noexcept(noexcept(__invoke_selector(__fun, __ty)(
226 static_cast<_Fun&&>(__fun), static_cast<_Ty&&>(__ty),
227 static_cast<_Args&&>(__args)...)))
228 -> decltype(__invoke_selector(__fun, __ty)(
229 static_cast<_Fun&&>(__fun), static_cast<_Ty&&>(__ty),
230 static_cast<_Args&&>(__args)...))
231 {
232 return decltype(__invoke_selector(__fun, __ty))()(
233 static_cast<_Fun&&>(__fun), static_cast<_Ty&&>(__ty),
234 static_cast<_Args&&>(__args)...);
235 }
236 };
237 } // namespace __invoke_
238
239 inline constexpr __invoke_::__invoke_t __invoke{};
240
241 template <class _Fun, class... _As>
242 concept __invocable = //
243 requires(_Fun&& __f, _As&&... __as) {
244 __invoke(static_cast<_Fun&&>(__f), static_cast<_As&&>(__as)...);
245 };
246
247 template <class _Fun, class... _As>
248 concept __nothrow_invocable = //
249 __invocable<_Fun, _As...> && //
250 requires(_Fun&& __f, _As&&... __as) {
251 {
252 __invoke(static_cast<_Fun&&>(__f), static_cast<_As&&>(__as)...)
253 } noexcept;
254 };
255
256 template <class _Fun, class... _As>
257 using __invoke_result_t = //
258 decltype(__invoke(__declval<_Fun>(), __declval<_As>()...));
259
260 namespace __apply_
261 {
262 using std::get;
263
264 template <std::size_t... _Is, class _Fn, class _Tup>
265 STDEXEC_ATTRIBUTE((always_inline))
__impl(__indices<_Is...>,_Fn && __fn,_Tup && __tup)266 constexpr auto __impl(__indices<_Is...>, _Fn&& __fn, _Tup&& __tup) //
267 noexcept(noexcept(__invoke(static_cast<_Fn&&>(__fn),
268 get<_Is>(static_cast<_Tup&&>(__tup))...)))
269 -> decltype(__invoke(static_cast<_Fn&&>(__fn),
270 get<_Is>(static_cast<_Tup&&>(__tup))...))
271 {
272 return __invoke(static_cast<_Fn&&>(__fn),
273 get<_Is>(static_cast<_Tup&&>(__tup))...);
274 }
275
276 template <class _Tup>
277 using __tuple_indices =
278 __make_indices<std::tuple_size<std::remove_cvref_t<_Tup>>::value>;
279
280 template <class _Fn, class _Tup>
281 using __result_t = decltype(__apply_::__impl(
282 __tuple_indices<_Tup>(), __declval<_Fn>(), __declval<_Tup>()));
283 } // namespace __apply_
284
285 template <class _Fn, class _Tup>
286 concept __applicable = __mvalid<__apply_::__result_t, _Fn, _Tup>;
287
288 template <class _Fn, class _Tup>
289 concept __nothrow_applicable =
290 __applicable<_Fn, _Tup> //
291 && //
292 noexcept(__apply_::__impl(__apply_::__tuple_indices<_Tup>(),
293 __declval<_Fn>(), __declval<_Tup>()));
294
295 template <class _Fn, class _Tup>
296 requires __applicable<_Fn, _Tup>
297 using __apply_result_t = __apply_::__result_t<_Fn, _Tup>;
298
299 struct __apply_t
300 {
301 template <class _Fn, class _Tup>
302 requires __applicable<_Fn, _Tup>
303 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__apply_t304 constexpr auto operator()(_Fn&& __fn, _Tup&& __tup) const
305 noexcept(__nothrow_applicable<_Fn, _Tup>) -> __apply_result_t<_Fn, _Tup>
306 {
307 return __apply_::__impl(__apply_::__tuple_indices<_Tup>(),
308 static_cast<_Fn&&>(__fn),
309 static_cast<_Tup&&>(__tup));
310 }
311 };
312
313 inline constexpr __apply_t __apply{};
314
315 template <class _Tag, class _Ty>
316 struct __field
317 {
318 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__field319 _Ty operator()(_Tag) const noexcept(__nothrow_decay_copyable<const _Ty&>)
320 {
321 return __t_;
322 }
323
324 _Ty __t_;
325 };
326
327 template <class _Tag>
328 struct __mkfield_
329 {
330 template <class _Ty>
331 STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__mkfield_332 __field<_Tag, __decay_t<_Ty>> operator()(_Ty&& __ty) const
333 noexcept(__nothrow_decay_copyable<_Ty>)
334 {
335 return {static_cast<_Ty&&>(__ty)};
336 }
337 };
338
339 template <class _Tag>
340 inline constexpr __mkfield_<_Tag> __mkfield{};
341 } // namespace stdexec
342