/*
 * Copyright (c) 2021-2024 NVIDIA Corporation
 *
 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *   https://llvm.org/LICENSE.txt
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "__execution_legacy.hpp"
#include "__execution_fwd.hpp"

// include these after __execution_fwd.hpp
#include "__basic_sender.hpp"
#include "__diagnostics.hpp"
#include "__domain.hpp"
#include "__meta.hpp"
#include "__senders_core.hpp"
#include "__sender_adaptor_closure.hpp"
#include "__transform_completion_signatures.hpp"
#include "__transform_sender.hpp"
#include "__senders.hpp" // IWYU pragma: keep for __well_formed_sender

STDEXEC_PRAGMA_PUSH()
STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces")

namespace stdexec {
  /////////////////////////////////////////////////////////////////////////////
  // [execution.senders.adaptors.bulk]
  namespace __bulk {
    struct bulk_t;
    struct bulk_chunked_t;
    struct bulk_unchunked_t;

    //! Wrapper for an execution policy object.
    //!
    //! For a standard execution policy, nothing is stored, as the type alone identifies
    //! the policy; a non-standard policy object is stored by value. In both cases,
    //! `__get()` returns a reference to the wrapped policy (see the illustrative note
    //! after the specializations below).
    template <class _Pol>
    struct __policy_wrapper {
      _Pol __pol_;

      /*implicit*/ __policy_wrapper(_Pol __pol)
        : __pol_{__pol} {
      }

      const _Pol& __get() const noexcept {
        return __pol_;
      }
    };

    template <>
    struct __policy_wrapper<sequenced_policy> {
      /*implicit*/ __policy_wrapper(const sequenced_policy&) {
      }

      const sequenced_policy& __get() const noexcept {
        return seq;
      }
    };

    template <>
    struct __policy_wrapper<parallel_policy> {
      /*implicit*/ __policy_wrapper(const parallel_policy&) {
      }

      const parallel_policy& __get() const noexcept {
        return par;
      }
    };

    template <>
    struct __policy_wrapper<parallel_unsequenced_policy> {
      /*implicit*/ __policy_wrapper(const parallel_unsequenced_policy&) {
      }

      const parallel_unsequenced_policy& __get() const noexcept {
        return par_unseq;
      }
    };

    template <>
    struct __policy_wrapper<unsequenced_policy> {
      /*implicit*/ __policy_wrapper(const unsequenced_policy&) {
      }

      const unsequenced_policy& __get() const noexcept {
        return unseq;
      }
    };
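
    //! Illustrative note (a sketch, not part of the upstream header): for a standard
    //! policy the wrapper stores nothing and is an empty type, while a hypothetical
    //! non-standard policy is stored by value:
    //!
    //!   static_assert(std::is_empty_v<__policy_wrapper<parallel_policy>>);
    //!   struct my_policy { int hint; }; // hypothetical non-standard policy
    //!   static_assert(sizeof(__policy_wrapper<my_policy>) == sizeof(my_policy));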

    template <class _Pol, class _Shape, class _Fun>
    struct __data {
      STDEXEC_ATTRIBUTE(no_unique_address) __policy_wrapper<_Pol> __pol_;
      _Shape __shape_;
      STDEXEC_ATTRIBUTE(no_unique_address) _Fun __fun_;
      static constexpr auto __mbrs_ =
        __mliterals<&__data::__pol_, &__data::__shape_, &__data::__fun_>();
    };
    template <class _Pol, class _Shape, class _Fun>
    __data(const _Pol&, _Shape, _Fun) -> __data<_Pol, _Shape, _Fun>;

    template <class _AlgoTag>
    struct __bulk_traits;

    template <>
    struct __bulk_traits<bulk_t> {
      using __on_not_callable =
        __callable_error<"In stdexec::bulk(Sender, Policy, Shape, Function)..."_mstr>;

      // Curried function, after passing the required indices.
      template <class _Fun, class _Shape>
      using __fun_curried =
        __mbind_front<__mtry_catch_q<__nothrow_invocable_t, __on_not_callable>, _Fun, _Shape>;
    };

    template <>
    struct __bulk_traits<bulk_chunked_t> {
      using __on_not_callable =
        __callable_error<"In stdexec::bulk_chunked(Sender, Policy, Shape, Function)..."_mstr>;

      // Curried function, after passing the required indices.
      template <class _Fun, class _Shape>
      using __fun_curried = __mbind_front<
        __mtry_catch_q<__nothrow_invocable_t, __on_not_callable>,
        _Fun,
        _Shape,
        _Shape
      >;
    };

    template <>
    struct __bulk_traits<bulk_unchunked_t> {
      using __on_not_callable =
        __callable_error<"In stdexec::bulk_unchunked(Sender, Policy, Shape, Function)..."_mstr>;

      // Curried function, after passing the required indices.
      template <class _Fun, class _Shape>
      using __fun_curried =
        __mbind_front<__mtry_catch_q<__nothrow_invocable_t, __on_not_callable>, _Fun, _Shape>;
    };
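
    //! Illustrative note (a sketch, not part of the upstream header): the curried
    //! checks above mirror the call shapes of the three algorithms, with `__fun` the
    //! user function and `__vals...` the child sender's values:
    //!
    //!   bulk:           __fun(__idx, __vals...)            // one index per call
    //!   bulk_chunked:   __fun(__begin, __end, __vals...)   // one sub-range per call
    //!   bulk_unchunked: __fun(__idx, __vals...)            // exactly one call per index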

    template <class _Ty>
    using __decay_ref = __decay_t<_Ty>&;

    template <class _AlgoTag, class _Fun, class _Shape, class _CvrefSender, class... _Env>
    using __with_error_invoke_t = __if<
      __value_types_t<
        __completion_signatures_of_t<_CvrefSender, _Env...>,
        __mtransform<
          __q<__decay_ref>,
          typename __bulk_traits<_AlgoTag>::template __fun_curried<_Fun, _Shape>
        >,
        __q<__mand>
      >,
      completion_signatures<>,
      __eptr_completion
    >;

    template <class _AlgoTag, class _Fun, class _Shape, class _CvrefSender, class... _Env>
    using __completion_signatures = transform_completion_signatures<
      __completion_signatures_of_t<_CvrefSender, _Env...>,
      __with_error_invoke_t<_AlgoTag, _Fun, _Shape, _CvrefSender, _Env...>
    >;
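
    //! Illustrative note (a sketch, not part of the upstream header): the aliases
    //! above add a `set_error_t(std::exception_ptr)` completion only when invoking
    //! the function can throw for some value completion of the child sender. E.g.,
    //! for a child completing with `set_value_t(int)`:
    //!
    //!   [](int i, int& v) noexcept { /*...*/ } // no error completion added
    //!   [](int i, int& v) { /*...*/ }          // adds set_error_t(std::exception_ptr)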

    template <class _AlgoTag>
    struct __generic_bulk_t { // NOLINT(bugprone-crtp-constructor-accessibility)
      template <sender _Sender, typename _Policy, integral _Shape, copy_constructible _Fun>
        requires is_execution_policy_v<std::remove_cvref_t<_Policy>>
      STDEXEC_ATTRIBUTE(host, device)
      auto operator()(_Sender&& __sndr, _Policy&& __pol, _Shape __shape, _Fun __fun) const
        -> __well_formed_sender auto {
        auto __domain = __get_early_domain(__sndr);
        return stdexec::transform_sender(
          __domain,
          __make_sexpr<_AlgoTag>(
            __data{__pol, __shape, static_cast<_Fun&&>(__fun)}, static_cast<_Sender&&>(__sndr)));
      }

      template <typename _Policy, integral _Shape, copy_constructible _Fun>
        requires is_execution_policy_v<std::remove_cvref_t<_Policy>>
      STDEXEC_ATTRIBUTE(always_inline)
      auto operator()(_Policy&& __pol, _Shape __shape, _Fun __fun) const
        -> __binder_back<_AlgoTag, _Policy, _Shape, _Fun> {
        return {
          {static_cast<_Policy&&>(__pol),
           static_cast<_Shape&&>(__shape),
           static_cast<_Fun&&>(__fun)},
          {},
          {}
        };
      }

      template <sender _Sender, integral _Shape, copy_constructible _Fun>
      [[deprecated(
        "The bulk algorithm now requires an execution policy such as stdexec::par as an "
        "argument.")]]
      STDEXEC_ATTRIBUTE(host, device) auto
        operator()(_Sender&& __sndr, _Shape __shape, _Fun __fun) const {
        return (*this)(
          static_cast<_Sender&&>(__sndr),
          par,
          static_cast<_Shape&&>(__shape),
          static_cast<_Fun&&>(__fun));
      }

      template <integral _Shape, copy_constructible _Fun>
      [[deprecated(
        "The bulk algorithm now requires an execution policy such as stdexec::par as an "
        "argument.")]]
      STDEXEC_ATTRIBUTE(always_inline) auto operator()(_Shape __shape, _Fun __fun) const {
        return (*this)(par, static_cast<_Shape&&>(__shape), static_cast<_Fun&&>(__fun));
      }
    };
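
    //! Illustrative usage (a sketch, not part of the upstream header; `s1` and `s2`
    //! are hypothetical). The direct and the pipeable forms produce equivalent
    //! senders:
    //!
    //!   auto s1 = bulk(just(std::vector<int>(64, 1)), par, 64,
    //!                  [](int i, std::vector<int>& v) { v[i] *= 2; });
    //!   auto s2 = just(std::vector<int>(64, 1))
    //!           | bulk(par, 64, [](int i, std::vector<int>& v) { v[i] *= 2; });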

    struct bulk_t : __generic_bulk_t<bulk_t> {
      template <class _Env>
      static auto __transform_sender_fn(const _Env&) {
        return [&]<class _Data, class _Child>(__ignore, _Data&& __data, _Child&& __child) {
          using __shape_t = std::remove_cvref_t<decltype(__data.__shape_)>;
          auto __new_f =
            [__func = std::move(
               __data.__fun_)](__shape_t __begin, __shape_t __end, auto&&... __vs) mutable
#if !STDEXEC_MSVC()
            // MSVCBUG https://developercommunity.visualstudio.com/t/noexcept-expression-in-lambda-template-n/10718680
            noexcept(noexcept(__data.__fun_(__begin++, __vs...)))
#endif
          {
            while (__begin != __end)
              __func(__begin++, __vs...);
          };

          // Lower `bulk` to `bulk_chunked`. If `bulk_chunked` is customized, we will
          // see the customization.
          return bulk_chunked(
            static_cast<_Child&&>(__child),
            __data.__pol_.__get(),
            __data.__shape_,
            std::move(__new_f));
        };
      }

      template <class _Sender, class _Env>
      static auto transform_sender(_Sender&& __sndr, const _Env& __env) {
        return __sexpr_apply(static_cast<_Sender&&>(__sndr), __transform_sender_fn(__env));
      }
    };
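
    //! Illustrative note (a sketch, not part of the upstream header): the lowering
    //! above means that bulk(sndr, par, n, f) behaves like
    //!
    //!   bulk_chunked(sndr, par, n,
    //!                [f](auto begin, auto end, auto&&... vs) mutable {
    //!                  for (; begin != end; ++begin)
    //!                    f(begin, vs...);
    //!                });
    //!
    //! so a scheduler that customizes `bulk_chunked` transparently accelerates
    //! `bulk` as well.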

    struct bulk_chunked_t : __generic_bulk_t<bulk_chunked_t> { };

    struct bulk_unchunked_t : __generic_bulk_t<bulk_unchunked_t> { };

    template <class _AlgoTag>
    struct __bulk_impl_base : __sexpr_defaults {
      template <class _Sender>
      using __fun_t = decltype(__decay_t<__data_of<_Sender>>::__fun_);

      template <class _Sender>
      using __shape_t = decltype(__decay_t<__data_of<_Sender>>::__shape_);

      static constexpr auto get_completion_signatures =
        []<class _Sender, class... _Env>(_Sender&&, _Env&&...) noexcept
        -> __completion_signatures<
          _AlgoTag,
          __fun_t<_Sender>,
          __shape_t<_Sender>,
          __child_of<_Sender>,
          _Env...
        > {
        // Check against _AlgoTag, not bulk_t: this base is shared by bulk,
        // bulk_chunked, and bulk_unchunked.
        static_assert(sender_expr_for<_Sender, _AlgoTag>);
        return {};
      };
    };

    struct __bulk_chunked_impl : __bulk_impl_base<bulk_chunked_t> {
      //! This implements the core default behavior for `bulk_chunked`:
      //! on `set_value`, it calls the function once with the entire range `[0, __shape_)`.
      //! Note: this is not done in parallel; that is customized by the scheduler.
      //! See, e.g., static_thread_pool::bulk_receiver::__t.
      static constexpr auto complete =
        []<class _Tag, class _State, class _Receiver, class... _Args>(
          __ignore,
          _State& __state,
          _Receiver& __rcvr,
          _Tag,
          _Args&&... __args) noexcept -> void {
        if constexpr (same_as<_Tag, set_value_t>) {
          // Intercept set_value and dispatch to the bulk operation.
          using __shape_t = decltype(__state.__shape_);
          if constexpr (noexcept(__state.__fun_(__shape_t{}, __shape_t{}, __args...))) {
            // The noexcept version that doesn't need try/catch:
            __state.__fun_(static_cast<__shape_t>(0), __state.__shape_, __args...);
            _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...);
          } else {
            STDEXEC_TRY {
              __state.__fun_(static_cast<__shape_t>(0), __state.__shape_, __args...);
              _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...);
            }
            STDEXEC_CATCH_ALL {
              stdexec::set_error(static_cast<_Receiver&&>(__rcvr), std::current_exception());
            }
          }
        } else {
          _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...);
        }
      };
    };
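
    //! Illustrative note (a sketch, not part of the upstream header): on this
    //! default path a chunked function sees the whole shape in one call; for
    //! `__shape_ == 8` the single invocation is
    //!
    //!   __state.__fun_(0, 8, __args...);
    //!
    //! whereas a parallelizing scheduler may split [0, 8) into sub-ranges and
    //! invoke the function once per sub-range.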

    struct __bulk_unchunked_impl : __bulk_impl_base<bulk_unchunked_t> {
      //! This implements the core default behavior for `bulk_unchunked`:
      //! on `set_value`, it loops over the shape and invokes the function for each index.
      //! Note: this is not done concurrently; that is customized by the scheduler.
      static constexpr auto complete =
        []<class _Tag, class _State, class _Receiver, class... _Args>(
          __ignore,
          _State& __state,
          _Receiver& __rcvr,
          _Tag,
          _Args&&... __args) noexcept -> void {
        if constexpr (std::same_as<_Tag, set_value_t>) {
          using __shape_t = decltype(__state.__shape_);
          if constexpr (noexcept(__state.__fun_(__shape_t{}, __args...))) {
            // The noexcept version that doesn't need try/catch:
            for (__shape_t __i{}; __i != __state.__shape_; ++__i) {
              __state.__fun_(__i, __args...);
            }
            _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...);
          } else {
            STDEXEC_TRY {
              for (__shape_t __i{}; __i != __state.__shape_; ++__i) {
                __state.__fun_(__i, __args...);
              }
              _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...);
            }
            STDEXEC_CATCH_ALL {
              stdexec::set_error(static_cast<_Receiver&&>(__rcvr), std::current_exception());
            }
          }
        } else {
          _Tag()(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...);
        }
      };
    };
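
    //! Illustrative note (a sketch, not part of the upstream header): unlike the
    //! chunked default above, the unchunked default makes one call per index; for
    //! `__shape_ == 3` the invocations are
    //!
    //!   __state.__fun_(0, __args...);
    //!   __state.__fun_(1, __args...);
    //!   __state.__fun_(2, __args...);
    //!
    //! A customizing scheduler may instead run each index on its own execution agent.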

    struct __bulk_impl : __bulk_impl_base<bulk_t> {
      // Implementation is handled by lowering to `bulk_chunked` in `transform_sender`.
    };
  } // namespace __bulk

  using __bulk::bulk_t;
  using __bulk::bulk_chunked_t;
  using __bulk::bulk_unchunked_t;
  inline constexpr bulk_t bulk{};
  inline constexpr bulk_chunked_t bulk_chunked{};
  inline constexpr bulk_unchunked_t bulk_unchunked{};
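
  //! Illustrative end-to-end usage (a sketch, not part of the upstream header;
  //! exec::static_thread_pool and stdexec::sync_wait live in other headers):
  //!
  //!   exec::static_thread_pool pool{4};
  //!   std::vector<int> data(100, 1);
  //!   auto snd = stdexec::schedule(pool.get_scheduler())
  //!            | stdexec::bulk(stdexec::par, data.size(),
  //!                            [&](std::size_t i) { data[i] *= 2; });
  //!   stdexec::sync_wait(std::move(snd));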

  template <>
  struct __sexpr_impl<bulk_t> : __bulk::__bulk_impl { };

  template <>
  struct __sexpr_impl<bulk_chunked_t> : __bulk::__bulk_chunked_impl { };

  template <>
  struct __sexpr_impl<bulk_unchunked_t> : __bulk::__bulk_unchunked_impl { };
} // namespace stdexec

STDEXEC_PRAGMA_POP()