1 /* 2 * Copyright (c) 2021-2024 NVIDIA Corporation 3 * 4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 5 * (the "License"); you may not use this file except in compliance with 6 * the License. You may obtain a copy of the License at 7 * 8 * https://llvm.org/LICENSE.txt 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include "__detail/__config.hpp" 19 #include "__detail/__meta.hpp" 20 21 #include "concepts.hpp" // IWYU pragma: keep 22 23 #include <functional> 24 #include <tuple> 25 #include <type_traits> 26 #include <cstddef> 27 28 namespace stdexec { 29 template <class _Fun0, class _Fun1> 30 struct __composed { 31 STDEXEC_ATTRIBUTE(no_unique_address) _Fun0 __t0_; 32 STDEXEC_ATTRIBUTE(no_unique_address) _Fun1 __t1_; 33 34 template <class... _Ts> 35 requires __callable<_Fun1, _Ts...> && __callable<_Fun0, __call_result_t<_Fun1, _Ts...>> STDEXEC_ATTRIBUTEstdexec::__composed36 STDEXEC_ATTRIBUTE(host, device, always_inline) 37 auto operator()(_Ts&&... __ts) && -> __call_result_t<_Fun0, __call_result_t<_Fun1, _Ts...>> { 38 return static_cast<_Fun0&&>(__t0_)(static_cast<_Fun1&&>(__t1_)(static_cast<_Ts&&>(__ts)...)); 39 } 40 41 template <class... _Ts> 42 requires __callable<const _Fun1&, _Ts...> 43 && __callable<const _Fun0&, __call_result_t<const _Fun1&, _Ts...>> STDEXEC_ATTRIBUTEstdexec::__composed44 STDEXEC_ATTRIBUTE(host, device, always_inline) 45 auto 46 operator()(_Ts&&... __ts) const & -> __call_result_t<_Fun0, __call_result_t<_Fun1, _Ts...>> { 47 return __t0_(__t1_(static_cast<_Ts&&>(__ts)...)); 48 } 49 }; 50 51 inline constexpr struct __compose_t { 52 template <class _Fun0, class _Fun1> STDEXEC_ATTRIBUTEstdexec::__compose_t53 STDEXEC_ATTRIBUTE(host, device, always_inline) 54 auto operator()(_Fun0 __fun0, _Fun1 __fun1) const -> __composed<_Fun0, _Fun1> { 55 return {static_cast<_Fun0&&>(__fun0), static_cast<_Fun1&&>(__fun1)}; 56 } 57 } __compose{}; 58 59 namespace __invoke_ { 60 template <class> 61 inline constexpr bool __is_refwrap = false; 62 template <class _Up> 63 inline constexpr bool __is_refwrap<std::reference_wrapper<_Up>> = true; 64 65 struct __funobj { 66 template <class _Fun, class... _Args> STDEXEC_ATTRIBUTEstdexec::__invoke_::__funobj67 STDEXEC_ATTRIBUTE(host, device, always_inline) 68 constexpr auto operator()(_Fun&& __fun, _Args&&... __args) const 69 noexcept(noexcept((static_cast<_Fun&&>(__fun))(static_cast<_Args&&>(__args)...))) 70 -> decltype((static_cast<_Fun&&>(__fun))(static_cast<_Args&&>(__args)...)) { 71 return static_cast<_Fun&&>(__fun)(static_cast<_Args&&>(__args)...); 72 } 73 }; 74 75 struct __memfn { 76 template <class _Memptr, class _Ty, class... _Args> STDEXEC_ATTRIBUTEstdexec::__invoke_::__memfn77 STDEXEC_ATTRIBUTE(host, device, always_inline) 78 constexpr auto operator()(_Memptr __mem_ptr, _Ty&& __ty, _Args&&... __args) const 79 noexcept(noexcept(((static_cast<_Ty&&>(__ty)).*__mem_ptr)(static_cast<_Args&&>(__args)...))) 80 -> decltype(((static_cast<_Ty&&>(__ty)).*__mem_ptr)(static_cast<_Args&&>(__args)...)) { 81 return ((static_cast<_Ty&&>(__ty)).*__mem_ptr)(static_cast<_Args&&>(__args)...); 82 } 83 }; 84 85 struct __memfn_refwrap { 86 template <class _Memptr, class _Ty, class... _Args> STDEXEC_ATTRIBUTEstdexec::__invoke_::__memfn_refwrap87 STDEXEC_ATTRIBUTE(host, device, always_inline) 88 constexpr auto operator()(_Memptr __mem_ptr, _Ty __ty, _Args&&... __args) const 89 noexcept(noexcept((__ty.get().*__mem_ptr)(static_cast<_Args&&>(__args)...))) 90 -> decltype((__ty.get().*__mem_ptr)(static_cast<_Args&&>(__args)...)) { 91 return (__ty.get().*__mem_ptr)(static_cast<_Args&&>(__args)...); 92 } 93 }; 94 95 struct __memfn_smartptr { 96 template <class _Memptr, class _Ty, class... _Args> STDEXEC_ATTRIBUTEstdexec::__invoke_::__memfn_smartptr97 STDEXEC_ATTRIBUTE(host, device, always_inline) 98 constexpr auto operator()(_Memptr __mem_ptr, _Ty&& __ty, _Args&&... __args) const noexcept( 99 noexcept(((*static_cast<_Ty&&>(__ty)).*__mem_ptr)(static_cast<_Args&&>(__args)...))) 100 -> decltype(((*static_cast<_Ty&&>(__ty)).*__mem_ptr)(static_cast<_Args&&>(__args)...)) { 101 return ((*static_cast<_Ty&&>(__ty)).*__mem_ptr)(static_cast<_Args&&>(__args)...); 102 } 103 }; 104 105 struct __memobj { 106 template <class _Mbr, class _Class, class _Ty> STDEXEC_ATTRIBUTEstdexec::__invoke_::__memobj107 STDEXEC_ATTRIBUTE(host, device, always_inline) 108 constexpr auto operator()(_Mbr _Class::* __mem_ptr, _Ty&& __ty) const noexcept 109 -> decltype(((static_cast<_Ty&&>(__ty)).*__mem_ptr)) { 110 return ((static_cast<_Ty&&>(__ty)).*__mem_ptr); 111 } 112 }; 113 114 struct __memobj_refwrap { 115 template <class _Mbr, class _Class, class _Ty> STDEXEC_ATTRIBUTEstdexec::__invoke_::__memobj_refwrap116 STDEXEC_ATTRIBUTE(host, device, always_inline) 117 constexpr auto operator()(_Mbr _Class::* __mem_ptr, _Ty __ty) const noexcept 118 -> decltype((__ty.get().*__mem_ptr)) { 119 return (__ty.get().*__mem_ptr); 120 } 121 }; 122 123 struct __memobj_smartptr { 124 template <class _Mbr, class _Class, class _Ty> STDEXEC_ATTRIBUTEstdexec::__invoke_::__memobj_smartptr125 STDEXEC_ATTRIBUTE(host, device, always_inline) 126 constexpr auto operator()(_Mbr _Class::* __mem_ptr, _Ty&& __ty) const noexcept 127 -> decltype(((*static_cast<_Ty&&>(__ty)).*__mem_ptr)) { 128 return ((*static_cast<_Ty&&>(__ty)).*__mem_ptr); 129 } 130 }; 131 132 STDEXEC_ATTRIBUTE(host, device) 133 auto __invoke_selector(__ignore, __ignore) noexcept -> __funobj; 134 135 template <class _Mbr, class _Class, class _Ty> STDEXEC_ATTRIBUTE(host,device)136 STDEXEC_ATTRIBUTE(host, device) 137 auto __invoke_selector(_Mbr _Class::*, const _Ty&) noexcept { 138 if constexpr (STDEXEC_IS_FUNCTION(_Mbr)) { 139 // member function ptr case 140 if constexpr (STDEXEC_IS_BASE_OF(_Class, _Ty)) { 141 return __memfn{}; 142 } else if constexpr (__is_refwrap<_Ty>) { 143 return __memfn_refwrap{}; 144 } else { 145 return __memfn_smartptr{}; 146 } 147 } else { 148 // member object ptr case 149 if constexpr (STDEXEC_IS_BASE_OF(_Class, _Ty)) { 150 return __memobj{}; 151 } else if constexpr (__is_refwrap<_Ty>) { 152 return __memobj_refwrap{}; 153 } else { 154 return __memobj_smartptr{}; 155 } 156 } 157 } 158 159 struct __invoke_t { 160 template <class _Fun> STDEXEC_ATTRIBUTEstdexec::__invoke_::__invoke_t161 STDEXEC_ATTRIBUTE(host, device, always_inline) 162 constexpr auto operator()(_Fun&& __fun) const noexcept(noexcept(static_cast<_Fun&&>(__fun)())) 163 -> decltype(static_cast<_Fun&&>(__fun)()) { 164 return static_cast<_Fun&&>(__fun)(); 165 } 166 167 template <class _Fun, class _Ty, class... _Args> STDEXEC_ATTRIBUTEstdexec::__invoke_::__invoke_t168 STDEXEC_ATTRIBUTE(host, device, always_inline) 169 constexpr auto operator()(_Fun&& __fun, _Ty&& __ty, _Args&&... __args) const 170 noexcept(noexcept(__invoke_::__invoke_selector(__fun, __ty)( 171 static_cast<_Fun&&>(__fun), 172 static_cast<_Ty&&>(__ty), 173 static_cast<_Args&&>(__args)...))) 174 -> decltype(__invoke_::__invoke_selector(__fun, __ty)( 175 static_cast<_Fun&&>(__fun), 176 static_cast<_Ty&&>(__ty), 177 static_cast<_Args&&>(__args)...)) { 178 return decltype(__invoke_::__invoke_selector(__fun, __ty))()( 179 static_cast<_Fun&&>(__fun), static_cast<_Ty&&>(__ty), static_cast<_Args&&>(__args)...); 180 } 181 }; 182 } // namespace __invoke_ 183 184 inline constexpr __invoke_::__invoke_t __invoke{}; 185 186 template <class _Fun, class... _As> 187 concept __invocable = requires(_Fun&& __f, _As&&... __as) { 188 __invoke(static_cast<_Fun &&>(__f), static_cast<_As &&>(__as)...); 189 }; 190 191 template <class _Fun, class... _As> 192 concept __nothrow_invocable = __invocable<_Fun, _As...> && requires(_Fun&& __f, _As&&... __as) { 193 { __invoke(static_cast<_Fun &&>(__f), static_cast<_As &&>(__as)...) } noexcept; 194 }; 195 196 template <class _Fun, class... _As> 197 using __invoke_result_t = decltype(__invoke(__declval<_Fun>(), __declval<_As>()...)); 198 199 namespace __apply_ { 200 using std::get; 201 202 template <std::size_t... _Is, class _Fn, class _Tup> STDEXEC_ATTRIBUTE(always_inline)203 STDEXEC_ATTRIBUTE(always_inline) 204 constexpr auto __impl(__indices<_Is...>, _Fn&& __fn, _Tup&& __tup) noexcept( 205 noexcept(__invoke(static_cast<_Fn&&>(__fn), get<_Is>(static_cast<_Tup&&>(__tup))...))) 206 -> decltype(__invoke(static_cast<_Fn&&>(__fn), get<_Is>(static_cast<_Tup&&>(__tup))...)) { 207 return __invoke(static_cast<_Fn&&>(__fn), get<_Is>(static_cast<_Tup&&>(__tup))...); 208 } 209 210 template <class _Tup> 211 using __tuple_indices = __make_indices<std::tuple_size<std::remove_cvref_t<_Tup>>::value>; 212 213 template <class _Fn, class _Tup> 214 using __result_t = 215 decltype(__apply_::__impl(__tuple_indices<_Tup>(), __declval<_Fn>(), __declval<_Tup>())); 216 } // namespace __apply_ 217 218 template <class _Fn, class _Tup> 219 concept __applicable = __mvalid<__apply_::__result_t, _Fn, _Tup>; 220 221 template <class _Fn, class _Tup> 222 concept __nothrow_applicable = 223 __applicable<_Fn, _Tup> 224 && noexcept( 225 __apply_::__impl(__apply_::__tuple_indices<_Tup>(), __declval<_Fn>(), __declval<_Tup>())); 226 227 template <class _Fn, class _Tup> 228 requires __applicable<_Fn, _Tup> 229 using __apply_result_t = __apply_::__result_t<_Fn, _Tup>; 230 231 struct __apply_t { 232 template <class _Fn, class _Tup> 233 requires __applicable<_Fn, _Tup> STDEXEC_ATTRIBUTEstdexec::__apply_t234 STDEXEC_ATTRIBUTE(always_inline) 235 constexpr auto operator()(_Fn&& __fn, _Tup&& __tup) const 236 noexcept(__nothrow_applicable<_Fn, _Tup>) -> __apply_result_t<_Fn, _Tup> { 237 return __apply_::__impl( 238 __apply_::__tuple_indices<_Tup>(), static_cast<_Fn&&>(__fn), static_cast<_Tup&&>(__tup)); 239 } 240 }; 241 242 inline constexpr __apply_t __apply{}; 243 } // namespace stdexec 244