/*
 * Copyright (c) 2021-2024 NVIDIA Corporation
 *
 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *   https://llvm.org/LICENSE.txt
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "__execution_legacy.hpp"
#include "__execution_fwd.hpp"

// include these after __execution_fwd.hpp
#include "__basic_sender.hpp"
#include "__diagnostics.hpp"
#include "__domain.hpp"
#include "__meta.hpp"
#include "__senders_core.hpp"
#include "__sender_adaptor_closure.hpp"
#include "__transform_completion_signatures.hpp"
#include "__transform_sender.hpp"
#include "__senders.hpp" // IWYU pragma: keep for __well_formed_sender

STDEXEC_PRAGMA_PUSH()
STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces")

namespace stdexec {
  /////////////////////////////////////////////////////////////////////////////
  // [execution.senders.adaptors.bulk]
  namespace __bulk {
    struct bulk_t;
    struct bulk_chunked_t;
    struct bulk_unchunked_t;

    //! Wrapper for a policy object.
    //!
    //! The primary template stores the execution policy object if it's a
    //! non-standard one. For the standard execution policies nothing is
    //! stored, as the type alone identifies the policy.
    //! Provides a way to query the execution policy object.
    template <class _Pol>
    struct __policy_wrapper {
      _Pol __pol_;

      /*implicit*/ __policy_wrapper(_Pol __pol)
        : __pol_{__pol} {
      }

      const _Pol& __get() const noexcept {
        return __pol_;
      }
    };

    template <>
    struct __policy_wrapper<sequenced_policy> {
      /*implicit*/ __policy_wrapper(const sequenced_policy&) {
      }

      const sequenced_policy& __get() const noexcept {
        return seq;
      }
    };

    template <>
    struct __policy_wrapper<parallel_policy> {
      /*implicit*/ __policy_wrapper(const parallel_policy&) {
      }

      const parallel_policy& __get() const noexcept {
        return par;
      }
    };

    template <>
    struct __policy_wrapper<parallel_unsequenced_policy> {
      /*implicit*/ __policy_wrapper(const parallel_unsequenced_policy&) {
      }

      const parallel_unsequenced_policy& __get() const noexcept {
        return par_unseq;
      }
    };

    template <>
    struct __policy_wrapper<unsequenced_policy> {
      /*implicit*/ __policy_wrapper(const unsequenced_policy&) {
      }

      const unsequenced_policy& __get() const noexcept {
        return unseq;
      }
    };

    template <class _Pol, class _Shape, class _Fun>
    struct __data {
      STDEXEC_ATTRIBUTE(no_unique_address) __policy_wrapper<_Pol> __pol_;
      _Shape __shape_;
      STDEXEC_ATTRIBUTE(no_unique_address) _Fun __fun_;
      static constexpr auto __mbrs_ =
        __mliterals<&__data::__pol_, &__data::__shape_, &__data::__fun_>();
    };

    template <class _Pol, class _Shape, class _Fun>
    __data(const _Pol&, _Shape, _Fun) -> __data<_Pol, _Shape, _Fun>;
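
    // Illustrative note (a sketch, not part of this header's interface):
    // because the specializations above are empty, wrapping a standard policy
    // adds no storage to __data when STDEXEC_ATTRIBUTE(no_unique_address)
    // applies. E.g., assuming a captureless function object _Fn and an int
    // shape:
    //
    //   static_assert(std::is_empty_v<__policy_wrapper<parallel_policy>>);
    //   // sizeof(__data<parallel_policy, int, _Fn>) is then typically
    //   // sizeof(int).
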
    template <class _AlgoTag>
    struct __bulk_traits;

    template <>
    struct __bulk_traits<bulk_t> {
      using __on_not_callable =
        __callable_error<"In stdexec::bulk(Sender, Policy, Shape, Function)..."_mstr>;

      // Curried function, after passing the required indices.
      template <class _Fun, class _Shape>
      using __fun_curried =
        __mbind_front<__mtry_catch_q<__nothrow_invocable_t, __on_not_callable>, _Fun, _Shape>;
    };

    template <>
    struct __bulk_traits<bulk_chunked_t> {
      using __on_not_callable =
        __callable_error<"In stdexec::bulk_chunked(Sender, Policy, Shape, Function)..."_mstr>;

      // Curried function, after passing the required indices.
      template <class _Fun, class _Shape>
      using __fun_curried = __mbind_front<
        __mtry_catch_q<__nothrow_invocable_t, __on_not_callable>,
        _Fun,
        _Shape,
        _Shape
      >;
    };

    template <>
    struct __bulk_traits<bulk_unchunked_t> {
      using __on_not_callable =
        __callable_error<"In stdexec::bulk_unchunked(Sender, Policy, Shape, Function)..."_mstr>;

      // Curried function, after passing the required indices.
      template <class _Fun, class _Shape>
      using __fun_curried =
        __mbind_front<__mtry_catch_q<__nothrow_invocable_t, __on_not_callable>, _Fun, _Shape>;
    };

    template <class _Ty>
    using __decay_ref = __decay_t<_Ty>&;

    template <class _AlgoTag, class _Fun, class _Shape, class _CvrefSender, class... _Env>
    using __with_error_invoke_t = __if<
      __value_types_t<
        __completion_signatures_of_t<_CvrefSender, _Env...>,
        __mtransform<
          __q<__decay_ref>,
          typename __bulk_traits<_AlgoTag>::template __fun_curried<_Fun, _Shape>
        >,
        __q<__mand>
      >,
      completion_signatures<>,
      __eptr_completion
    >;
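
    // Illustrative note (a sketch under the definitions above): for a child
    // sender that completes with set_value_t(int), bulk_chunked's traits ask
    // whether
    //
    //   __nothrow_invocable_t<_Fun, _Shape, _Shape, int&>
    //
    // holds. If every value completion is nothrow-invocable, the alias yields
    // completion_signatures<>; otherwise it adds a
    // set_error_t(std::exception_ptr) completion via __eptr_completion.
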
    template <class _AlgoTag, class _Fun, class _Shape, class _CvrefSender, class... _Env>
    using __completion_signatures = transform_completion_signatures<
      __completion_signatures_of_t<_CvrefSender, _Env...>,
      __with_error_invoke_t<_AlgoTag, _Fun, _Shape, _CvrefSender, _Env...>
    >;

    template <class _AlgoTag>
    struct __generic_bulk_t { // NOLINT(bugprone-crtp-constructor-accessibility)
      template <sender _Sender, typename _Policy, integral _Shape, copy_constructible _Fun>
        requires is_execution_policy_v<std::remove_cvref_t<_Policy>>
      STDEXEC_ATTRIBUTE(host, device)
      auto operator()(_Sender&& __sndr, _Policy&& __pol, _Shape __shape, _Fun __fun) const
        -> __well_formed_sender auto {
        auto __domain = __get_early_domain(__sndr);
        return stdexec::transform_sender(
          __domain,
          __make_sexpr<_AlgoTag>(
            __data{__pol, __shape, static_cast<_Fun&&>(__fun)}, static_cast<_Sender&&>(__sndr)));
      }

      template <typename _Policy, integral _Shape, copy_constructible _Fun>
        requires is_execution_policy_v<std::remove_cvref_t<_Policy>>
      STDEXEC_ATTRIBUTE(always_inline)
      auto operator()(_Policy&& __pol, _Shape __shape, _Fun __fun) const
        -> __binder_back<_AlgoTag, _Policy, _Shape, _Fun> {
        return {
          {static_cast<_Policy&&>(__pol),
           static_cast<_Shape&&>(__shape),
           static_cast<_Fun&&>(__fun)},
          {},
          {}
        };
      }

      template <sender _Sender, integral _Shape, copy_constructible _Fun>
      [[deprecated(
        "The bulk algorithm now requires an execution policy such as stdexec::par as an "
        "argument.")]]
      STDEXEC_ATTRIBUTE(host, device)
      auto operator()(_Sender&& __sndr, _Shape __shape, _Fun __fun) const {
        return (*this)(
          static_cast<_Sender&&>(__sndr),
          par,
          static_cast<_Shape&&>(__shape),
          static_cast<_Fun&&>(__fun));
      }

      template <integral _Shape, copy_constructible _Fun>
      [[deprecated(
        "The bulk algorithm now requires an execution policy such as stdexec::par as an "
        "argument.")]]
      STDEXEC_ATTRIBUTE(always_inline)
      auto operator()(_Shape __shape, _Fun __fun) const {
        return (*this)(par, static_cast<_Shape&&>(__shape), static_cast<_Fun&&>(__fun));
      }
    };
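
    // Usage sketch (hedged; the lambda and shape are illustrative only):
    //
    //   auto s1 = bulk(just(42), par, 8, [](int i, int& v) { /*...*/ });
    //   auto s2 = just(42) | bulk(par, 8, [](int i, int& v) { /*...*/ });
    //
    // Both spellings build the same sender expression; the second returns a
    // pipeable __binder_back closure that binds the policy, shape, and
    // function until a sender is supplied.
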
    struct bulk_t : __generic_bulk_t<bulk_t> {
      template <class _Env>
      static auto __transform_sender_fn(const _Env&) {
        return [&]<class _Data, class _Child>(__ignore, _Data&& __data, _Child&& __child) {
          using __shape_t = std::remove_cvref_t<decltype(__data.__shape_)>;
          auto __new_f =
            [__func = std::move(
               __data.__fun_)](__shape_t __begin, __shape_t __end, auto&&... __vs) mutable
#if !STDEXEC_MSVC()
            // MSVCBUG https://developercommunity.visualstudio.com/t/noexcept-expression-in-lambda-template-n/10718680
            noexcept(noexcept(__data.__fun_(__begin++, __vs...)))
#endif
          {
            while (__begin != __end)
              __func(__begin++, __vs...);
          };

          // Lower `bulk` to `bulk_chunked`. If `bulk_chunked` is customized,
          // we will see the customization.
          return bulk_chunked(
            static_cast<_Child&&>(__child),
            __data.__pol_.__get(),
            __data.__shape_,
            std::move(__new_f));
        };
      }

      template <class _Sender, class _Env>
      static auto transform_sender(_Sender&& __sndr, const _Env& __env) {
        return __sexpr_apply(static_cast<_Sender&&>(__sndr), __transform_sender_fn(__env));
      }
    };

    struct bulk_chunked_t : __generic_bulk_t<bulk_chunked_t> { };

    struct bulk_unchunked_t : __generic_bulk_t<bulk_unchunked_t> { };

    template <class _AlgoTag>
    struct __bulk_impl_base : __sexpr_defaults {
      template <class _Sender>
      using __fun_t = decltype(__decay_t<__data_of<_Sender>>::__fun_);

      template <class _Sender>
      using __shape_t = decltype(__decay_t<__data_of<_Sender>>::__shape_);

      static constexpr auto get_completion_signatures =
        []<class _Sender, class... _Env>(_Sender&&, _Env&&...) noexcept
        -> __completion_signatures<
             _AlgoTag,
             __fun_t<_Sender>,
             __shape_t<_Sender>,
             __child_of<_Sender>,
             _Env...
           > {
        static_assert(sender_expr_for<_Sender, _AlgoTag>);
        return {};
      };
    };

    struct __bulk_chunked_impl : __bulk_impl_base<bulk_chunked_t> {
      //! This implements the core default behavior for `bulk_chunked`:
      //! When setting value, it calls the function once with the entire range.
      //! Note: This is not done in parallel. That is customized by the scheduler.
      //! See, e.g., static_thread_pool::bulk_receiver::__t.
      static constexpr auto complete =
        []<class _Tag, class _State, class _Receiver, class... _Args>(
          __ignore,
          _State& __state,
          _Receiver& __rcvr,
          _Tag,
          _Args&&... __args) noexcept -> void {
        if constexpr (same_as<_Tag, set_value_t>) {
          // Intercept set_value and dispatch to the bulk operation.
          using __shape_t = decltype(__state.__shape_);
          if constexpr (noexcept(__state.__fun_(__shape_t{}, __shape_t{}, __args...))) {
            // The noexcept version that doesn't need try/catch:
            __state.__fun_(static_cast<__shape_t>(0), __state.__shape_, __args...);
            _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...);
          } else {
            STDEXEC_TRY {
              __state.__fun_(static_cast<__shape_t>(0), __state.__shape_, __args...);
              _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...);
            }
            STDEXEC_CATCH_ALL {
              stdexec::set_error(static_cast<_Receiver&&>(__rcvr), std::current_exception());
            }
          }
        } else {
          _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...);
        }
      };
    };
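
    // Illustrative note (a sketch; the vector payload is an assumption): the
    // default bulk_chunked completion invokes the function exactly once with
    // the whole range [0, shape), so
    //
    //   bulk_chunked(just(std::vector<int>(8, 0)), par, 8,
    //     [](int b, int e, std::vector<int>& v) {
    //       for (int i = b; i < e; ++i) v[i] += 1;
    //     })
    //
    // runs the loop body once over [0, 8) here; schedulers may customize the
    // algorithm to split the range into chunks executed in parallel.
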
    struct __bulk_unchunked_impl : __bulk_impl_base<bulk_unchunked_t> {
      //! This implements the core default behavior for `bulk_unchunked`:
      //! When setting value, it loops over the shape and invokes the function.
      //! Note: This is not done concurrently. That is customized by the scheduler.
      static constexpr auto complete =
        []<class _Tag, class _State, class _Receiver, class... _Args>(
          __ignore,
          _State& __state,
          _Receiver& __rcvr,
          _Tag,
          _Args&&... __args) noexcept -> void {
        if constexpr (same_as<_Tag, set_value_t>) {
          using __shape_t = decltype(__state.__shape_);
          if constexpr (noexcept(__state.__fun_(__shape_t{}, __args...))) {
            // The noexcept version that doesn't need try/catch:
            for (__shape_t __i{}; __i != __state.__shape_; ++__i) {
              __state.__fun_(__i, __args...);
            }
            _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...);
          } else {
            STDEXEC_TRY {
              for (__shape_t __i{}; __i != __state.__shape_; ++__i) {
                __state.__fun_(__i, __args...);
              }
              _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...);
            }
            STDEXEC_CATCH_ALL {
              stdexec::set_error(static_cast<_Receiver&&>(__rcvr), std::current_exception());
            }
          }
        } else {
          _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...);
        }
      };
    };

    struct __bulk_impl : __bulk_impl_base<bulk_t> {
      // Implementation is handled by lowering to `bulk_chunked` in `transform_sender`.
    };
  } // namespace __bulk

  using __bulk::bulk_t;
  using __bulk::bulk_chunked_t;
  using __bulk::bulk_unchunked_t;
  inline constexpr bulk_t bulk{};
  inline constexpr bulk_chunked_t bulk_chunked{};
  inline constexpr bulk_unchunked_t bulk_unchunked{};

  template <>
  struct __sexpr_impl<bulk_t> : __bulk::__bulk_impl { };

  template <>
  struct __sexpr_impl<bulk_chunked_t> : __bulk::__bulk_chunked_impl { };

  template <>
  struct __sexpr_impl<bulk_unchunked_t> : __bulk::__bulk_unchunked_impl { };
} // namespace stdexec

STDEXEC_PRAGMA_POP()
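
// Usage sketch (hedged; a minimal end-to-end example against the default,
// sequential implementations; sync_wait is stdexec's blocking consumer):
//
//   std::vector<int> v(100, 1);
//   auto snd = stdexec::just(std::move(v))
//            | stdexec::bulk(stdexec::par, 100,
//                [](std::size_t i, std::vector<int>& vec) { vec[i] *= 2; });
//   auto [res] = stdexec::sync_wait(std::move(snd)).value();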