1 /*
2  * Copyright (c) 2021-2024 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__execution_fwd.hpp"
19 
20 // include these after __execution_fwd.hpp
21 #include "__basic_sender.hpp"
22 #include "__diagnostics.hpp"
23 #include "__domain.hpp"
24 #include "__meta.hpp"
25 #include "__sender_adaptor_closure.hpp"
26 #include "__senders.hpp"
27 #include "__senders_core.hpp"
28 #include "__transform_completion_signatures.hpp"
29 #include "__transform_sender.hpp"
30 
31 STDEXEC_PRAGMA_PUSH()
32 STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces")
33 
34 namespace stdexec
35 {
36 /////////////////////////////////////////////////////////////////////////////
37 // [execution.senders.adaptors.bulk]
38 namespace __bulk
39 {
// Context string attached to compile-time diagnostics produced by bulk.
inline constexpr __mstring __bulk_context =
    "In stdexec::bulk(Sender, Shape, Function)..."_mstr;
// Diagnostic type passed as the _Catch argument of __with_error_invoke_t;
// presumably reported when the function cannot be invoked with the shape and
// the predecessor's values -- confirm against __callable_error.
using __on_not_callable = __callable_error<__bulk_context>;
43 
// State stored in a bulk sender expression: the number of iterations
// (__shape_) and the function invoked once per index (__fun_).
//
// Member order matters: bulk_t's legacy-customization description refers to
// __shape_ as member 0 and __fun_ as member 1, and __mbrs_ lists the members
// in the same order for the library's member-wise machinery. Keep all three
// in sync.
template <class _Shape, class _Fun>
struct __data
{
    _Shape __shape_;
    // Lets an empty (stateless) function object occupy no storage.
    STDEXEC_ATTRIBUTE((no_unique_address))
    _Fun __fun_;
    static constexpr auto __mbrs_ =
        __mliterals<&__data::__shape_, &__data::__fun_>();
};
// Deduction guide so that __data{shape, fun} deduces both template arguments.
template <class _Shape, class _Fun>
__data(_Shape, _Fun) -> __data<_Shape, _Fun>;
55 
56 template <class _Ty>
57 using __decay_ref = __decay_t<_Ty>&;
58 
// Additional completion signatures that bulk contributes on top of the
// child's.
//
// For each value completion set_value_t(_Args...) of _CvrefSender (in
// _Env...), this checks whether _Fun is nothrow-invocable with
// (_Shape, __decay_t<_Args>&...). The per-completion results are combined
// with __mand (logical AND):
//   - all nothrow -> completion_signatures<> (nothing is added);
//   - otherwise   -> __eptr_completion (presumably the
//                    set_error_t(std::exception_ptr) signature), because
//                    complete() catches exceptions thrown by _Fun and reports
//                    them with set_error.
// If the invocability check itself fails, __mtry_catch_q presumably yields
// the _Catch diagnostic (here __on_not_callable) -- confirm in __meta.hpp.
template <class _Catch, class _Fun, class _Shape, class _CvrefSender,
          class... _Env>
using __with_error_invoke_t = //
    __if<__value_types_t<
             __completion_signatures_of_t<_CvrefSender, _Env...>,
             __mtransform<
                 __q<__decay_ref>,
                 __mbind_front<__mtry_catch_q<__nothrow_invocable_t, _Catch>,
                               _Fun, _Shape>>,
             __q<__mand>>,
         completion_signatures<>, __eptr_completion>;
70 
71 template <class _Fun, class _Shape, class _CvrefSender, class... _Env>
72 using __completion_signatures = //
73     transform_completion_signatures<
74         __completion_signatures_of_t<_CvrefSender, _Env...>,
75         __with_error_invoke_t<__on_not_callable, _Fun, _Shape, _CvrefSender,
76                               _Env...>>;
77 
78 struct bulk_t
79 {
80     template <sender _Sender, integral _Shape, __movable_value _Fun>
81     STDEXEC_ATTRIBUTE((host, device))
operator ()stdexec::__bulk::bulk_t82     auto operator()(_Sender&& __sndr, _Shape __shape,
83                     _Fun __fun) const -> __well_formed_sender auto
84     {
85         auto __domain = __get_early_domain(__sndr);
86         return stdexec::transform_sender(
87             __domain,
88             __make_sexpr<bulk_t>(__data{__shape, static_cast<_Fun&&>(__fun)},
89                                  static_cast<_Sender&&>(__sndr)));
90     }
91 
92     template <integral _Shape, class _Fun>
93     STDEXEC_ATTRIBUTE((always_inline))
operator ()stdexec::__bulk::bulk_t94     auto operator()(_Shape __shape,
95                     _Fun __fun) const -> __binder_back<bulk_t, _Shape, _Fun>
96     {
97         return {{static_cast<_Shape&&>(__shape), static_cast<_Fun&&>(__fun)},
98                 {},
99                 {}};
100     }
101 
102     // This describes how to use the pieces of a bulk sender to find
103     // legacy customizations of the bulk algorithm.
104     using _Sender = __1;
105     using _Shape = __nth_member<0>(__0);
106     using _Fun = __nth_member<1>(__0);
107     using __legacy_customizations_t =
108         __types<tag_invoke_t(bulk_t,
109                              get_completion_scheduler_t<set_value_t>(
110                                  get_env_t(_Sender&)),
111                              _Sender, _Shape, _Fun),
112                 tag_invoke_t(bulk_t, _Sender, _Shape, _Fun)>;
113 };
114 
115 struct __bulk_impl : __sexpr_defaults
116 {
117     template <class _Sender>
118     using __fun_t = decltype(__decay_t<__data_of<_Sender>>::__fun_);
119 
120     template <class _Sender>
121     using __shape_t = decltype(__decay_t<__data_of<_Sender>>::__shape_);
122 
123     static constexpr auto get_completion_signatures = //
124         []<class _Sender, class... _Env>(_Sender&&, _Env&&...) noexcept
125         -> __completion_signatures<__fun_t<_Sender>, __shape_t<_Sender>,
126                                    __child_of<_Sender>, _Env...> {
127         static_assert(sender_expr_for<_Sender, bulk_t>);
128         return {};
129     };
130 
131     static constexpr auto complete = //
132         []<class _Tag, class _State, class _Receiver, class... _Args>(
133             __ignore, _State& __state, _Receiver& __rcvr, _Tag,
134             _Args&&... __args) noexcept -> void {
135         if constexpr (std::same_as<_Tag, set_value_t>)
136         {
137             using __shape_t = decltype(__state.__shape_);
138             if constexpr (noexcept(__state.__fun_(__shape_t{}, __args...)))
139             {
140                 for (__shape_t __i{}; __i != __state.__shape_; ++__i)
141                 {
142                     __state.__fun_(__i, __args...);
143                 }
144                 _Tag()(static_cast<_Receiver&&>(__rcvr),
145                        static_cast<_Args&&>(__args)...);
146             }
147             else
148             {
149                 try
150                 {
151                     for (__shape_t __i{}; __i != __state.__shape_; ++__i)
152                     {
153                         __state.__fun_(__i, __args...);
154                     }
155                     _Tag()(static_cast<_Receiver&&>(__rcvr),
156                            static_cast<_Args&&>(__args)...);
157                 }
158                 catch (...)
159                 {
160                     stdexec::set_error(static_cast<_Receiver&&>(__rcvr),
161                                        std::current_exception());
162                 }
163             }
164         }
165         else
166         {
167             _Tag()(static_cast<_Receiver&&>(__rcvr),
168                    static_cast<_Args&&>(__args)...);
169         }
170     };
171 };
172 } // namespace __bulk
173 
// Export the customization-point type and its object into namespace stdexec:
// stdexec::bulk(sndr, shape, fun) or sndr | stdexec::bulk(shape, fun).
using __bulk::bulk_t;
inline constexpr bulk_t bulk{};
176 
// Associates the bulk_t tag with its sender-expression implementation. That
// implementation supplies get_completion_signatures and complete; everything
// else comes from __sexpr_defaults.
template <>
struct __sexpr_impl<bulk_t> : __bulk::__bulk_impl
{};
180 } // namespace stdexec
181 
182 STDEXEC_PRAGMA_POP()
183