/*
 * Copyright (c) 2021-2024 NVIDIA Corporation
 *
 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *   https://llvm.org/LICENSE.txt
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "__execution_fwd.hpp"

// include these after __execution_fwd.hpp
#include "../functional.hpp"
#include "../stop_token.hpp"
#include "__basic_sender.hpp"
#include "__cpo.hpp"
#include "__env.hpp"
#include "__intrusive_ptr.hpp"
#include "__intrusive_slist.hpp"
#include "__meta.hpp"
#include "__optional.hpp"
#include "__transform_completion_signatures.hpp"
#include "__tuple.hpp"
#include "__variant.hpp"

#include <exception>
#include <mutex>

namespace stdexec
{
////////////////////////////////////////////////////////////////////////////
// shared components of split and ensure_started
//
// The split and ensure_started algorithms are very similar in implementation.
// The salient differences are:
//
// split: the input async operation is always connected. It is only
//   started when one of the split senders is connected and started.
//   split senders are copyable, so there can be multiple operation states
//   to be notified on completion. These are stored in an intrusive
//   linked list.
//
// ensure_started: the input async operation is always started, so
//   the internal receiver will always be completed. The ensure_started
//   sender is move-only and single-shot, so there will only ever be one
//   operation state to be notified on completion.
//
// The shared state should add-ref itself when the input async
// operation is started and release itself when its completion
// is notified.
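//
// Illustrative usage (a sketch, not part of this header's API surface;
// it assumes only the public stdexec::split, stdexec::ensure_started,
// stdexec::just, and stdexec::sync_wait algorithms):
//
//   auto __s1 = stdexec::split(stdexec::just(42)); // copyable sender
//   auto __s2 = __s1;                              // shares one operation
//   stdexec::sync_wait(std::move(__s2));           // runs just(42) once
//   stdexec::sync_wait(std::move(__s1));           // observes the cached result
//
//   auto __e = stdexec::ensure_started(stdexec::just(42)); // starts eagerly
//   stdexec::sync_wait(std::move(__e));            // move-only, single-shot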
namespace __shared
{
template <class _BaseEnv>
using __env_t =                //
    __env::__join_t<prop<get_stop_token_t, inplace_stop_token>,
                    _BaseEnv>; // BUGBUG NOT TO SPEC

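// Builds a visitor that, given one of the result tuples stored in the shared
// state, re-sends the stored completion (tag and arguments) to a waiting
// operation's receiver.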
template <class _Receiver>
auto __make_notify_visitor(_Receiver& __rcvr) noexcept
{
    return [&]<class _Tuple>(_Tuple&& __tupl) noexcept -> void {
        __tupl.apply(
            [&](auto __tag, auto&&... __args) noexcept -> void {
                __tag(static_cast<_Receiver&&>(__rcvr),
                      __forward_like<_Tuple>(__args)...);
            },
            __tupl);
    };
}

struct __local_state_base : __immovable
{
    using __notify_fn = void(__local_state_base*) noexcept;

    __notify_fn* __notify_{};
    __local_state_base* __next_{};
};

template <class _CvrefSender, class _Env>
struct __shared_state;

// The operation state of ensure_started, and each operation state of split,
// has one of these, created when the sender is connected. There are 0 or more
// of them for each underlying async operation. It is what the ensure_started
// and split senders' `get_state` function returns. It holds a counted
// reference to the shared state.
template <class _CvrefSender, class _Receiver>
struct __local_state :
    __local_state_base,
    __enable_receiver_from_this<_CvrefSender, _Receiver>
{
    using __tag_t = tag_of_t<_CvrefSender>;
    using __stok_t = stop_token_of_t<env_of_t<_Receiver>>;
    static_assert(__one_of<__tag_t, __split::__split_t,
                           __ensure_started::__ensure_started_t>);

    explicit __local_state(_CvrefSender&& __sndr) noexcept :
        __local_state::__local_state_base{{},
                                          &__notify<tag_of_t<_CvrefSender>>},
        __sh_state_(__get_sh_state(__sndr))
    {}

    ~__local_state()
    {
        __sh_state_t::__detach(__sh_state_);
    }

    // Stop request callback:
    void operator()() noexcept
    {
        // We reach here when a split/ensure_started sender has received a stop
        // request from the receiver to which it is connected.
        if (std::unique_lock __lock{__sh_state_->__mutex_})
        {
            // Remove this operation from the waiters list. Removal can fail if:
            //   1. It was already removed by another thread, or
            //   2. It hasn't been added yet (see `start` below), or
            //   3. The underlying operation has already completed.
            //
            // In each case, the right thing to do is nothing. If (1) then we
            // raced with another thread and lost. In that case, the other
            // thread will take care of it. If (2) then `start` will take care
            // of it. If (3) then this stop request is safe to ignore.
            if (!__sh_state_->__waiters_.remove(this))
                return;
        }

        // The following code and the __notify function cannot both execute.
        // This is because the __notify function is called from the shared
        // state's __notify_waiters function, which first sets __waiters_ to
        // the completed state. As a result, the attempt to remove `this` from
        // the waiters list above will fail and this stop request is ignored.
        __sh_state_t::__detach(__sh_state_);
        stdexec::set_stopped(static_cast<_Receiver&&>(this->__receiver()));
    }

    // This is called from __shared_state::__notify_waiters when the input
    // async operation completes; or, if it had already completed when start
    // was called, it is called from start. __notify cannot race with the stop
    // request callback above; see the comment in that callback.
    template <class _Tag>
    static void __notify(__local_state_base* __base) noexcept
    {
        auto* const __self = static_cast<__local_state*>(__base);

        // The split algorithm sends by T const&. ensure_started sends by T&&.
        constexpr bool __is_split = same_as<__split::__split_t, _Tag>;
        using __variant_t = decltype(__self->__sh_state_->__results_);
        using __cv_variant_t =
            __if_c<__is_split, const __variant_t&, __variant_t>;

        __self->__on_stop_.reset();

        auto __visitor = __make_notify_visitor(__self->__receiver());
        __variant_t::visit(__visitor, static_cast<__cv_variant_t&&>(
                                          __self->__sh_state_->__results_));
    }

    static auto __get_sh_state(_CvrefSender& __sndr) noexcept
    {
        return __sndr
            .apply(static_cast<_CvrefSender&&>(__sndr), __detail::__get_data())
            .__sh_state_;
    }

    using __sh_state_ptr_t = __result_of<__get_sh_state, _CvrefSender&>;
    using __sh_state_t = typename __sh_state_ptr_t::element_type;

    __optional<stop_callback_for_t<__stok_t, __local_state&>> __on_stop_{};
    __sh_state_ptr_t __sh_state_;
};

template <class _CvrefSenderId, class _EnvId>
struct __receiver
{
    using _CvrefSender = stdexec::__cvref_t<_CvrefSenderId>;
    using _Env = stdexec::__t<_EnvId>;

    struct __t
    {
        using receiver_concept = receiver_t;
        using __id = __receiver;

        template <class... _As>
        STDEXEC_ATTRIBUTE((always_inline))
        void set_value(_As&&... __as) noexcept
        {
            __sh_state_->__complete(set_value_t(), static_cast<_As&&>(__as)...);
        }

        template <class _Error>
        STDEXEC_ATTRIBUTE((always_inline))
        void set_error(_Error&& __err) noexcept
        {
            __sh_state_->__complete(set_error_t(),
                                    static_cast<_Error&&>(__err));
        }

        STDEXEC_ATTRIBUTE((always_inline))
        void set_stopped() noexcept
        {
            __sh_state_->__complete(set_stopped_t());
        }

        auto get_env() const noexcept -> const __env_t<_Env>&
        {
            return __sh_state_->__env_;
        }

        // The receiver does not hold a reference to the shared state.
        __shared_state<_CvrefSender, _Env>* __sh_state_;
    };
};

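// A sentinel node address. When __waiters_ holds this value, the shared
// operation has already completed and no new waiters may be enqueued.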
inline __local_state_base* __get_tombstone() noexcept
{
    static __local_state_base __tombstone_{{}, nullptr, nullptr};
    return &__tombstone_;
}

template <class _CvrefSender, class _Env>
struct __shared_state :
    private __enable_intrusive_from_this<__shared_state<_CvrefSender, _Env>, 2>
{
    using __receiver_t = __t<__receiver<__cvref_id<_CvrefSender>, __id<_Env>>>;
    using __waiters_list_t = __intrusive_slist<&__local_state_base::__next_>;

    using __variant_t = //
        __transform_completion_signatures<
            __completion_signatures_of_t<_CvrefSender, _Env>,
            __mbind_front_q<__decayed_tuple, set_value_t>::__f,
            __mbind_front_q<__decayed_tuple, set_error_t>::__f,
            __tuple_for<set_error_t, std::exception_ptr>,
            __munique<__mbind_front_q<__variant_for,
                                      __tuple_for<set_stopped_t>>>::__f,
            __tuple_for<set_error_t, std::exception_ptr>>;

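    // These index the two flag bits that the intrusive ref count carries
    // alongside the count itself: one records that the shared operation has
    // been started, the other that it has completed. __count, __bit, and
    // __set_bit (used below) extract and update them.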
    static constexpr std::size_t __started_bit = 0;
    static constexpr std::size_t __completed_bit = 1;

    inplace_stop_source __stop_source_{};
    __env_t<_Env> __env_;
    __variant_t __results_{}; // Defaults to the "set_stopped" state
    std::mutex __mutex_;      // This mutex guards access to __waiters_.
    __waiters_list_t __waiters_{};
    connect_result_t<_CvrefSender, __receiver_t> __shared_op_;

    explicit __shared_state(_CvrefSender&& __sndr, _Env __env) :
        __env_(__env::__join(prop{get_stop_token, __stop_source_.get_token()},
                             static_cast<_Env&&>(__env))),
        __shared_op_(
            connect(static_cast<_CvrefSender&&>(__sndr), __receiver_t{this}))
    {
        // add one ref count to account for the case where there are no
        // watchers left but the shared op is still running.
        this->__inc_ref();
    }

    // The caller of this wants to release their reference to the shared state.
    // The ref count must be at least 2 at this point: one owned by the caller,
    // and one added in the __shared_state ctor.
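    // (For example, with split: each live split sender or connected operation
    // state owns one reference, plus the extra one added in the ctor. When
    // the last watcher detaches, __count(__old) below is exactly 2.)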
    static void __detach(__intrusive_ptr<__shared_state, 2>& __ptr) noexcept
    {
        // Ask the intrusive ptr to stop managing the reference count so we can
        // manage it manually.
        if (auto* __self = __ptr.__release_())
        {
            auto __old = __self->__dec_ref();
            STDEXEC_ASSERT(__count(__old) >= 2);

            if (__count(__old) == 2)
            {
                // The last watcher has released its reference. Ask the shared
                // op to stop.
                static_cast<__shared_state*>(__self)
                    ->__stop_source_.request_stop();

                // Additionally, if the shared op was never started, or if it
                // has already completed, then the shared state is no longer
                // needed. Decrement the ref count to 0 here, which will delete
                // __self.
                if (!__bit<__started_bit>(__old) ||
                    __bit<__completed_bit>(__old))
                {
                    __self->__dec_ref();
                }
            }
        }
    }

    /// @post The started bit is set in the shared state's ref count, OR the
    /// __waiters_ list is set to the known "tombstone" value indicating
    /// completion.
    void __try_start() noexcept
    {
        // With the split algorithm, multiple split senders can be started
        // simultaneously, but only one should start the shared async operation.
        // If the "started" bit is set, then someone else has already started
        // the shared operation. Do nothing.
        if (this->template __is_set<__started_bit>())
        {
            return;
        }
        else if (__bit<__started_bit>(
                     this->template __set_bit<__started_bit>()))
        {
            return;
        }
        else if (__stop_source_.stop_requested())
        {
            // Stop has already been requested. Rather than starting the
            // operation, complete with set_stopped immediately.
            // 1. Sets __waiters_ to a known "tombstone" value
            // 2. Notifies all the waiters that the operation has stopped
            // 3. Sets the "completed" bit in the ref count.
            __notify_waiters();
            return;
        }
        else
        {
            stdexec::start(__shared_op_);
        }
    }

    template <class _StopToken>
    bool __try_add_waiter(__local_state_base* __waiter,
                          _StopToken __stok) noexcept
    {
        std::unique_lock __lock{__mutex_};
        if (__waiters_.front() == __get_tombstone())
        {
            // The work has already completed. Notify the waiter immediately.
            __lock.unlock();
            __waiter->__notify_(__waiter);
            return true;
        }
        else if (__stok.stop_requested())
        {
            // Stop has been requested. Do not add the waiter.
            return false;
        }
        else
        {
            // Add the waiter to the list.
            __waiters_.push_front(__waiter);
            return true;
        }
    }

    /// @brief This is called when the shared async operation completes.
    /// @post __waiters_ is set to a known "tombstone" value.
    template <class _Tag, class... _As>
    void __complete(_Tag, _As&&... __as) noexcept
    {
        try
        {
            using __tuple_t = __decayed_tuple<_Tag, _As...>;
            __results_.template emplace<__tuple_t>(_Tag(),
                                                   static_cast<_As&&>(__as)...);
        }
        catch (...)
        {
            using __tuple_t = __decayed_tuple<set_error_t, std::exception_ptr>;
            __results_.template emplace<__tuple_t>(set_error,
                                                   std::current_exception());
        }

        __notify_waiters();
    }

    /// @brief This is called when the shared async operation completes (from
    /// __complete above), or from __try_start when stop was requested before
    /// the operation could be started.
    /// @post __waiters_ is set to a known "tombstone" value.
    void __notify_waiters() noexcept
    {
        __waiters_list_t __waiters_copy{__get_tombstone()};

        // Set the waiters list to a known "tombstone" value that we can check
        // later.
        {
            std::lock_guard __lock{__mutex_};
            __waiters_.swap(__waiters_copy);
        }

        STDEXEC_ASSERT(__waiters_copy.front() != __get_tombstone());
        for (auto __itr = __waiters_copy.begin();
             __itr != __waiters_copy.end();)
        {
            __local_state_base* __item = *__itr;

            // We must increment the iterator before calling notify, since
            // notify may end up triggering *__item to be destructed on another
            // thread, and the intrusive slist's iterator increment relies on
            // __item.
            ++__itr;

            __item->__notify_(__item);
        }

        // Set the "completed" bit in the ref count. If the ref count is 1, then
        // there are no more waiters. Release the final reference.
        if (__count(this->template __set_bit<__completed_bit>()) == 1)
        {
            this->__dec_ref(); // release the extra ref count, deletes this
        }
    }
};

template <class _Cvref, class _CvrefSender, class _Env>
using __make_completions = //
    __try_make_completion_signatures<
        // NOT TO SPEC:
        // See https://github.com/cplusplus/sender-receiver/issues/23
        _CvrefSender, __env_t<_Env>,
        completion_signatures<set_error_t(
                                  __minvoke<_Cvref, std::exception_ptr>),
                              set_stopped_t()>, // NOT TO SPEC
        __mtransform<_Cvref,
                     __mcompose<__q<completion_signatures>, __qf<set_value_t>>>,
        __mtransform<
            _Cvref, __mcompose<__q<completion_signatures>, __qf<set_error_t>>>>;

// split completes with const T&. ensure_started completes with T&&.
template <class _Tag>
using __cvref_results_t = //
    __mcompose<__if_c<same_as<_Tag, __split::__split_t>, __cpclr, __cp>,
               __q<__decay_t>>;

// NOTE: the use of __mapply in the return type below takes advantage of the
// fact that _ShState denotes an instance of the __shared_state template, which
// is parameterized on the cvref-qualified sender and the environment.
template <class _Tag, class _ShState>
using __completions = //
    __mapply<__mbind_front_q<__make_completions, __cvref_results_t<_Tag>>,
             _ShState>;

template <class _CvrefSender, class _Env, bool _Copyable = true>
struct __box
{
    using __tag_t = __if_c<_Copyable, __split::__split_t,
                           __ensure_started::__ensure_started_t>;
    using __sh_state_t = __shared_state<_CvrefSender, _Env>;

    __box(__tag_t, __intrusive_ptr<__sh_state_t, 2> __sh_state) noexcept :
        __sh_state_(std::move(__sh_state))
    {}

    __box(__box&&) noexcept = default;
    __box(const __box&) noexcept
        requires _Copyable
    = default;

    ~__box()
    {
        __sh_state_t::__detach(__sh_state_);
    }

    __intrusive_ptr<__sh_state_t, 2> __sh_state_;
};

template <class _CvrefSender, class _Env>
__box(__split::__split_t,
      __intrusive_ptr<__shared_state<_CvrefSender, _Env>, 2>) //
    -> __box<_CvrefSender, _Env, true>;

template <class _CvrefSender, class _Env>
__box(__ensure_started::__ensure_started_t,
      __intrusive_ptr<__shared_state<_CvrefSender, _Env>, 2>)
    -> __box<_CvrefSender, _Env, false>;
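
// The deduction guides above select the copyable box for split and the
// move-only box for ensure_started, mirroring the copyability of the
// corresponding senders.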

template <class _Tag>
struct __shared_impl : __sexpr_defaults
{
    static constexpr auto get_state = //
        []<class _CvrefSender, class _Receiver>(
            _CvrefSender&& __sndr,
            _Receiver&) noexcept -> __local_state<_CvrefSender, _Receiver> {
        static_assert(sender_expr_for<_CvrefSender, _Tag>);
        return __local_state<_CvrefSender, _Receiver>{
            static_cast<_CvrefSender&&>(__sndr)};
    };

    static constexpr auto get_completion_signatures = //
        []<class _Self>(const _Self&, auto&&...) noexcept
        -> __completions<_Tag, typename __data_of<_Self>::__sh_state_t> {
        static_assert(sender_expr_for<_Self, _Tag>);
        return {};
    };

    static constexpr auto start = //
        []<class _Sender, class _Receiver>(
            __local_state<_Sender, _Receiver>& __self,
            _Receiver& __rcvr) noexcept -> void {
        using __sh_state_t =
            typename __local_state<_Sender, _Receiver>::__sh_state_t;
        // Scenario: there are no more split senders, this is the only
        // operation state, the underlying operation has not yet been started,
        // and the receiver's stop token is already in the "stop requested"
        // state. Then registering the stop callback will invoke __self's stop
        // request callback synchronously. It may also be called asynchronously
        // at any point after the callback is registered. Beware. We are
        // guaranteed, however, that the stop request callback will not
        // complete the operation or decrement the shared state's ref count
        // until after __self has been added to the waiters list.
        const auto __stok = stdexec::get_stop_token(stdexec::get_env(__rcvr));
        __self.__on_stop_.emplace(__stok, __self);

        // We haven't put __self in the waiters list yet, and we are holding a
        // ref count to __sh_state_, so nothing can happen to __sh_state_ here.

        // Start the shared op. As an optimization, skip it if the receiver's
        // stop token has already been signaled.
        if (!__stok.stop_requested())
        {
            __self.__sh_state_->__try_start();
            if (__self.__sh_state_->__try_add_waiter(&__self, __stok))
            {
                // successfully added the waiter
                return;
            }
        }

        // Otherwise, we failed to add the waiter because of a stop request.
        // Complete synchronously with set_stopped().
        __self.__on_stop_.reset();
        __sh_state_t::__detach(__self.__sh_state_);
        stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr));
    };
};
} // namespace __shared
} // namespace stdexec