1 /*
2  * Copyright (c) 2021-2024 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__detail/__config.hpp"
19 #include "__detail/__meta.hpp"
20 #include "__detail/__tag_invoke.hpp"
21 #include "concepts.hpp"
22 
23 #include <cstddef>
24 #include <functional>
25 #include <tuple>
26 #include <type_traits>
27 
28 namespace stdexec
29 {
30 template <class _Fun0, class _Fun1>
31 struct __composed
32 {
33     STDEXEC_ATTRIBUTE((no_unique_address))
34     _Fun0 __t0_;
35     STDEXEC_ATTRIBUTE((no_unique_address))
36     _Fun1 __t1_;
37 
38     template <class... _Ts>
39         requires __callable<_Fun1, _Ts...> &&
40                  __callable<_Fun0, __call_result_t<_Fun1, _Ts...>>
41     STDEXEC_ATTRIBUTE((always_inline))
42     __call_result_t<_Fun0, __call_result_t<_Fun1, _Ts...>>
operator ()stdexec::__composed43         operator()(_Ts&&... __ts) &&
44     {
45         return static_cast<_Fun0&&>(__t0_)(
46             static_cast<_Fun1&&>(__t1_)(static_cast<_Ts&&>(__ts)...));
47     }
48 
49     template <class... _Ts>
50         requires __callable<const _Fun1&, _Ts...> &&
51                  __callable<const _Fun0&, __call_result_t<const _Fun1&, _Ts...>>
52     STDEXEC_ATTRIBUTE((always_inline))
53     __call_result_t<_Fun0, __call_result_t<_Fun1, _Ts...>>
operator ()stdexec::__composed54         operator()(_Ts&&... __ts) const&
55     {
56         return __t0_(__t1_(static_cast<_Ts&&>(__ts)...));
57     }
58 };
59 
60 inline constexpr struct __compose_t
61 {
62     template <class _Fun0, class _Fun1>
63     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__compose_t64     __composed<_Fun0, _Fun1> operator()(_Fun0 __fun0, _Fun1 __fun1) const
65     {
66         return {static_cast<_Fun0&&>(__fun0), static_cast<_Fun1&&>(__fun1)};
67     }
68 } __compose{};
69 
70 namespace __invoke_
71 {
72 template <class>
73 inline constexpr bool __is_refwrap = false;
74 template <class _Up>
75 inline constexpr bool __is_refwrap<std::reference_wrapper<_Up>> = true;
76 
77 struct __funobj
78 {
79     template <class _Fun, class... _Args>
80     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__funobj81     constexpr auto operator()(_Fun&& __fun, _Args&&... __args) const noexcept(
82         noexcept((static_cast<_Fun&&>(__fun))(static_cast<_Args&&>(__args)...)))
83         -> decltype((static_cast<_Fun&&>(__fun))(
84             static_cast<_Args&&>(__args)...))
85     {
86         return static_cast<_Fun&&>(__fun)(static_cast<_Args&&>(__args)...);
87     }
88 };
89 
90 struct __memfn
91 {
92     template <class _Memptr, class _Ty, class... _Args>
93     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memfn94     constexpr auto operator()(_Memptr __mem_ptr, _Ty&& __ty,
95                               _Args&&... __args) const
96         noexcept(noexcept(((static_cast<_Ty&&>(__ty)).*
97                            __mem_ptr)(static_cast<_Args&&>(__args)...)))
98             -> decltype(((static_cast<_Ty&&>(__ty)).*
99                          __mem_ptr)(static_cast<_Args&&>(__args)...))
100     {
101         return ((static_cast<_Ty&&>(__ty)).*__mem_ptr)(
102             static_cast<_Args&&>(__args)...);
103     }
104 };
105 
106 struct __memfn_refwrap
107 {
108     template <class _Memptr, class _Ty, class... _Args>
109     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memfn_refwrap110     constexpr auto operator()(_Memptr __mem_ptr, _Ty __ty,
111                               _Args&&... __args) const
112         noexcept(noexcept((__ty.get().*
113                            __mem_ptr)(static_cast<_Args&&>(__args)...)))
114             -> decltype((__ty.get().*
115                          __mem_ptr)(static_cast<_Args&&>(__args)...))
116     {
117         return (__ty.get().*__mem_ptr)(static_cast<_Args&&>(__args)...);
118     }
119 };
120 
121 struct __memfn_smartptr
122 {
123     template <class _Memptr, class _Ty, class... _Args>
124     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memfn_smartptr125     constexpr auto operator()(_Memptr __mem_ptr, _Ty&& __ty,
126                               _Args&&... __args) const
127         noexcept(noexcept(((*static_cast<_Ty&&>(__ty)).*
128                            __mem_ptr)(static_cast<_Args&&>(__args)...)))
129             -> decltype(((*static_cast<_Ty&&>(__ty)).*
130                          __mem_ptr)(static_cast<_Args&&>(__args)...))
131     {
132         return ((*static_cast<_Ty&&>(__ty)).*__mem_ptr)(
133             static_cast<_Args&&>(__args)...);
134     }
135 };
136 
137 struct __memobj
138 {
139     template <class _Mbr, class _Class, class _Ty>
140     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memobj141     constexpr auto operator()(_Mbr _Class::*__mem_ptr, _Ty&& __ty)
142         const noexcept -> decltype(((static_cast<_Ty&&>(__ty)).*__mem_ptr))
143     {
144         return ((static_cast<_Ty&&>(__ty)).*__mem_ptr);
145     }
146 };
147 
148 struct __memobj_refwrap
149 {
150     template <class _Mbr, class _Class, class _Ty>
151     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memobj_refwrap152     constexpr auto operator()(_Mbr _Class::*__mem_ptr, _Ty __ty) const noexcept
153         -> decltype((__ty.get().*__mem_ptr))
154     {
155         return (__ty.get().*__mem_ptr);
156     }
157 };
158 
159 struct __memobj_smartptr
160 {
161     template <class _Mbr, class _Class, class _Ty>
162     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__memobj_smartptr163     constexpr auto operator()(_Mbr _Class::*__mem_ptr, _Ty&& __ty)
164         const noexcept -> decltype(((*static_cast<_Ty&&>(__ty)).*__mem_ptr))
165     {
166         return ((*static_cast<_Ty&&>(__ty)).*__mem_ptr);
167     }
168 };
169 
170 auto __invoke_selector(__ignore, __ignore) noexcept -> __funobj;
171 
172 template <class _Mbr, class _Class, class _Ty>
__invoke_selector(_Mbr _Class::*,const _Ty &)173 auto __invoke_selector(_Mbr _Class::*, const _Ty&) noexcept
174 {
175     if constexpr (STDEXEC_IS_CONST(_Mbr) || STDEXEC_IS_CONST(const _Mbr))
176     {
177         // member function ptr case
178         if constexpr (STDEXEC_IS_BASE_OF(_Class, _Ty))
179         {
180             return __memobj{};
181         }
182         else if constexpr (__is_refwrap<_Ty>)
183         {
184             return __memobj_refwrap{};
185         }
186         else
187         {
188             return __memobj_smartptr{};
189         }
190     }
191     else
192     {
193         // member object ptr case
194         if constexpr (STDEXEC_IS_BASE_OF(_Class, _Ty))
195         {
196             return __memfn{};
197         }
198         else if constexpr (__is_refwrap<_Ty>)
199         {
200             return __memfn_refwrap{};
201         }
202         else
203         {
204             return __memfn_smartptr{};
205         }
206     }
207 }
208 
209 struct __invoke_t
210 {
211     template <class _Fun>
212     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__invoke_t213     constexpr auto operator()(_Fun&& __fun) const
214         noexcept(noexcept((static_cast<_Fun&&>(__fun))()))
215             -> decltype((static_cast<_Fun&&>(__fun))())
216     {
217         return static_cast<_Fun&&>(__fun)();
218     }
219 
220     template <class _Fun, class _Ty, class... _Args>
221     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__invoke_::__invoke_t222     constexpr auto operator()(_Fun&& __fun, _Ty&& __ty, _Args&&... __args) const
223         noexcept(noexcept(__invoke_selector(__fun, __ty)(
224             static_cast<_Fun&&>(__fun), static_cast<_Ty&&>(__ty),
225             static_cast<_Args&&>(__args)...)))
226             -> decltype(__invoke_selector(__fun, __ty)(
227                 static_cast<_Fun&&>(__fun), static_cast<_Ty&&>(__ty),
228                 static_cast<_Args&&>(__args)...))
229     {
230         return decltype(__invoke_selector(__fun, __ty))()(
231             static_cast<_Fun&&>(__fun), static_cast<_Ty&&>(__ty),
232             static_cast<_Args&&>(__args)...);
233     }
234 };
235 } // namespace __invoke_
236 
237 inline constexpr __invoke_::__invoke_t __invoke{};
238 
239 template <class _Fun, class... _As>
240 concept __invocable = //
241     requires(_Fun&& __f, _As&&... __as) {
242         __invoke(static_cast<_Fun&&>(__f), static_cast<_As&&>(__as)...);
243     };
244 
245 template <class _Fun, class... _As>
246 concept __nothrow_invocable =    //
247     __invocable<_Fun, _As...> && //
248     requires(_Fun&& __f, _As&&... __as) {
249         {
250             __invoke(static_cast<_Fun&&>(__f), static_cast<_As&&>(__as)...)
251         } noexcept;
252     };
253 
254 template <class _Fun, class... _As>
255 using __invoke_result_t = //
256     decltype(__invoke(__declval<_Fun>(), __declval<_As>()...));
257 
258 namespace __apply_
259 {
260 using std::get;
261 
262 template <std::size_t... _Is, class _Fn, class _Tup>
263 STDEXEC_ATTRIBUTE((always_inline))
__impl(__indices<_Is...>,_Fn && __fn,_Tup && __tup)264 constexpr auto __impl(__indices<_Is...>, _Fn&& __fn, _Tup&& __tup) //
265     noexcept(noexcept(__invoke(static_cast<_Fn&&>(__fn),
266                                get<_Is>(static_cast<_Tup&&>(__tup))...)))
267         -> decltype(__invoke(static_cast<_Fn&&>(__fn),
268                              get<_Is>(static_cast<_Tup&&>(__tup))...))
269 {
270     return __invoke(static_cast<_Fn&&>(__fn),
271                     get<_Is>(static_cast<_Tup&&>(__tup))...);
272 }
273 
274 template <class _Tup>
275 using __tuple_indices =
276     __make_indices<std::tuple_size<std::remove_cvref_t<_Tup>>::value>;
277 
278 template <class _Fn, class _Tup>
279 using __result_t = decltype(__apply_::__impl(
280     __tuple_indices<_Tup>(), __declval<_Fn>(), __declval<_Tup>()));
281 } // namespace __apply_
282 
283 template <class _Fn, class _Tup>
284 concept __applicable = __mvalid<__apply_::__result_t, _Fn, _Tup>;
285 
286 template <class _Fn, class _Tup>
287 concept __nothrow_applicable =
288     __applicable<_Fn, _Tup> //
289     &&                      //
290     noexcept(__apply_::__impl(__apply_::__tuple_indices<_Tup>(),
291                               __declval<_Fn>(), __declval<_Tup>()));
292 
293 template <class _Fn, class _Tup>
294     requires __applicable<_Fn, _Tup>
295 using __apply_result_t = __apply_::__result_t<_Fn, _Tup>;
296 
297 struct __apply_t
298 {
299     template <class _Fn, class _Tup>
300         requires __applicable<_Fn, _Tup>
301     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__apply_t302     constexpr auto operator()(_Fn&& __fn, _Tup&& __tup) const
303         noexcept(__nothrow_applicable<_Fn, _Tup>) -> __apply_result_t<_Fn, _Tup>
304     {
305         return __apply_::__impl(__apply_::__tuple_indices<_Tup>(),
306                                 static_cast<_Fn&&>(__fn),
307                                 static_cast<_Tup&&>(__tup));
308     }
309 };
310 
311 inline constexpr __apply_t __apply{};
312 
313 template <class _Tag, class _Ty>
314 struct __field
315 {
316     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__field317     _Ty operator()(_Tag) const noexcept(__nothrow_decay_copyable<const _Ty&>)
318     {
319         return __t_;
320     }
321 
322     _Ty __t_;
323 };
324 
325 template <class _Tag>
326 struct __mkfield_
327 {
328     template <class _Ty>
329     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__mkfield_330     __field<_Tag, __decay_t<_Ty>> operator()(_Ty&& __ty) const
331         noexcept(__nothrow_decay_copyable<_Ty>)
332     {
333         return {static_cast<_Ty&&>(__ty)};
334     }
335 };
336 
337 template <class _Tag>
338 inline constexpr __mkfield_<_Tag> __mkfield{};
339 } // namespace stdexec
340