1 /*
2  * Copyright (c) 2021-2022 Facebook, Inc. and its affiliates
3  * Copyright (c) 2021-2024 NVIDIA Corporation
4  *
5  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
6  * (the "License"); you may not use this file except in compliance with
7  * the License. You may obtain a copy of the License at
8  *
9  *   https://llvm.org/LICENSE.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 #pragma once
18 
19 #include "__detail/__stop_token.hpp"
20 
21 #include <atomic>
22 #include <concepts>
23 #include <cstdint>
24 #include <thread>
25 #include <type_traits>
26 #include <utility>
27 #include <version>
28 
29 #if __has_include(<stop_token>) && __cpp_lib_jthread >= 201911
30 #include <stop_token>
31 #endif
32 
33 namespace stdexec
34 {
35 namespace __stok
36 {
37 struct __inplace_stop_callback_base
38 {
__executestdexec::__stok::__inplace_stop_callback_base39     void __execute() noexcept
40     {
41         this->__execute_(this);
42     }
43 
44   protected:
45     using __execute_fn_t = void(__inplace_stop_callback_base*) noexcept;
46 
__inplace_stop_callback_basestdexec::__stok::__inplace_stop_callback_base47     explicit __inplace_stop_callback_base(   //
48         const inplace_stop_source* __source, //
49         __execute_fn_t* __execute) noexcept :
50         __source_(__source), __execute_(__execute)
51     {}
52 
53     void __register_callback_() noexcept;
54 
55     friend inplace_stop_source;
56 
57     const inplace_stop_source* __source_;
58     __execute_fn_t* __execute_;
59     __inplace_stop_callback_base* __next_ = nullptr;
60     __inplace_stop_callback_base** __prev_ptr_ = nullptr;
61     bool* __removed_during_callback_ = nullptr;
62     std::atomic<bool> __callback_completed_{false};
63 };
64 
// Back-off helper for the spin loops in this file: the first
// __yield_threshold_ waits return immediately, every later wait yields the
// current thread's time slice.
struct __spin_wait
{
    __spin_wait() noexcept = default;

    // Performs one back-off step.
    void __wait() noexcept
    {
        const bool __busy_phase = __count_++ < __yield_threshold_;
        if (__busy_phase)
        {
            // TODO: _mm_pause();
            return;
        }

        // The counter wrapped around to zero; move it back to the threshold
        // so that the busy phase is not entered a second time.
        if (__count_ == 0)
        {
            __count_ = __yield_threshold_;
        }
        std::this_thread::yield();
    }

  private:
    // Number of waits that are served without yielding.
    static constexpr uint32_t __yield_threshold_ = 20;
    // Number of waits performed so far (pinned at the threshold on overflow).
    uint32_t __count_ = 0;
};
87 } // namespace __stok
88 
89 // [stoptoken.never], class never_stop_token
90 struct never_stop_token
91 {
92   private:
93     struct __callback_type
94     {
__callback_typestdexec::never_stop_token::__callback_type95         explicit __callback_type(never_stop_token, auto&&) noexcept {}
96     };
97 
98   public:
99     template <class>
100     using callback_type = __callback_type;
101 
stop_requestedstdexec::never_stop_token102     static constexpr auto stop_requested() noexcept -> bool
103     {
104         return false;
105     }
106 
stop_possiblestdexec::never_stop_token107     static constexpr auto stop_possible() noexcept -> bool
108     {
109         return false;
110     }
111 
112     auto operator==(const never_stop_token&) const noexcept -> bool = default;
113 };
114 
// Forward declaration so that inplace_stop_source and inplace_stop_token can
// befriend / name the callback template before it is defined.
template <class _Fun>
class inplace_stop_callback;
117 
118 // [stopsource.inplace], class inplace_stop_source
119 class inplace_stop_source
120 {
121   public:
122     inplace_stop_source() noexcept = default;
123     ~inplace_stop_source();
124     inplace_stop_source(inplace_stop_source&&) = delete;
125 
126     auto get_token() const noexcept -> inplace_stop_token;
127 
128     auto request_stop() noexcept -> bool;
129 
stop_requested() const130     auto stop_requested() const noexcept -> bool
131     {
132         return (__state_.load(std::memory_order_acquire) &
133                 __stop_requested_flag_) != 0;
134     }
135 
136   private:
137     friend inplace_stop_token;
138     friend __stok::__inplace_stop_callback_base;
139     template <class>
140     friend class inplace_stop_callback;
141 
142     auto __lock_() const noexcept -> uint8_t;
143     void __unlock_(uint8_t) const noexcept;
144 
145     auto __try_lock_unless_stop_requested_(bool) const noexcept -> bool;
146 
147     auto __try_add_callback_(
148         __stok::__inplace_stop_callback_base*) const noexcept -> bool;
149 
150     void __remove_callback_(
151         __stok::__inplace_stop_callback_base*) const noexcept;
152 
153     static constexpr uint8_t __stop_requested_flag_ = 1;
154     static constexpr uint8_t __locked_flag_ = 2;
155 
156     mutable std::atomic<uint8_t> __state_{0};
157     mutable __stok::__inplace_stop_callback_base* __callbacks_ = nullptr;
158     std::thread::id __notifying_thread_;
159 };
160 
161 // [stoptoken.inplace], class inplace_stop_token
162 class inplace_stop_token
163 {
164   public:
165     template <class _Fun>
166     using callback_type = inplace_stop_callback<_Fun>;
167 
inplace_stop_token()168     inplace_stop_token() noexcept : __source_(nullptr) {}
169 
170     inplace_stop_token(const inplace_stop_token& __other) noexcept = default;
171 
inplace_stop_token(inplace_stop_token && __other)172     inplace_stop_token(inplace_stop_token&& __other) noexcept :
173         __source_(std::exchange(__other.__source_, {}))
174     {}
175 
176     auto operator=(const inplace_stop_token& __other) noexcept
177         -> inplace_stop_token& = default;
178 
operator =(inplace_stop_token && __other)179     auto operator=(inplace_stop_token&& __other) noexcept -> inplace_stop_token&
180     {
181         __source_ = std::exchange(__other.__source_, nullptr);
182         return *this;
183     }
184 
stop_requested() const185     [[nodiscard]] auto stop_requested() const noexcept -> bool
186     {
187         return __source_ != nullptr && __source_->stop_requested();
188     }
189 
stop_possible() const190     [[nodiscard]] auto stop_possible() const noexcept -> bool
191     {
192         return __source_ != nullptr;
193     }
194 
swap(inplace_stop_token & __other)195     void swap(inplace_stop_token& __other) noexcept
196     {
197         std::swap(__source_, __other.__source_);
198     }
199 
200     auto operator==(const inplace_stop_token&) const noexcept -> bool = default;
201 
202   private:
203     friend inplace_stop_source;
204     template <class>
205     friend class inplace_stop_callback;
206 
inplace_stop_token(const inplace_stop_source * __source)207     explicit inplace_stop_token(const inplace_stop_source* __source) noexcept :
208         __source_(__source)
209     {}
210 
211     const inplace_stop_source* __source_;
212 };
213 
214 inline auto
get_token() const215     inplace_stop_source::get_token() const noexcept -> inplace_stop_token
216 {
217     return inplace_stop_token{this};
218 }
219 
220 // [stopcallback.inplace], class template inplace_stop_callback
221 template <class _Fun>
222 class inplace_stop_callback : __stok::__inplace_stop_callback_base
223 {
224   public:
225     template <class _Fun2>
226         requires constructible_from<_Fun, _Fun2>
inplace_stop_callback(inplace_stop_token __token,_Fun2 && __fun)227     explicit inplace_stop_callback(inplace_stop_token __token,
228                                    _Fun2&& __fun) //
229         noexcept(__nothrow_constructible_from<_Fun, _Fun2>) :
230         __stok::__inplace_stop_callback_base(
231             __token.__source_, &inplace_stop_callback::__execute_impl_),
232         __fun_(static_cast<_Fun2&&>(__fun))
233     {
234         __register_callback_();
235     }
236 
~inplace_stop_callback()237     ~inplace_stop_callback()
238     {
239         if (__source_ != nullptr)
240             __source_->__remove_callback_(this);
241     }
242 
243   private:
244     static void
__execute_impl_(__stok::__inplace_stop_callback_base * cb)245         __execute_impl_(__stok::__inplace_stop_callback_base* cb) noexcept
246     {
247         std::move(static_cast<inplace_stop_callback*>(cb)->__fun_)();
248     }
249 
250     STDEXEC_ATTRIBUTE((no_unique_address))
251     _Fun __fun_;
252 };
253 
254 namespace __stok
255 {
__register_callback_()256 inline void __inplace_stop_callback_base::__register_callback_() noexcept
257 {
258     if (__source_ != nullptr)
259     {
260         if (!__source_->__try_add_callback_(this))
261         {
262             __source_ = nullptr;
263             // Callback not registered because stop_requested() was true.
264             // Execute inline here.
265             __execute();
266         }
267     }
268 }
269 } // namespace __stok
270 
// Checks the preconditions for destroying a source; performs no clean-up.
inline inplace_stop_source::~inplace_stop_source()
{
    // The internal spin lock must not be held at this point.
    STDEXEC_ASSERT(
        (__state_.load(std::memory_order_relaxed) & __locked_flag_) == 0);
    // No callback may still be linked into the list: the nodes point back at
    // this object.
    STDEXEC_ASSERT(__callbacks_ == nullptr);
}
277 
request_stop()278 inline auto inplace_stop_source::request_stop() noexcept -> bool
279 {
280     if (!__try_lock_unless_stop_requested_(true))
281         return true;
282 
283     __notifying_thread_ = std::this_thread::get_id();
284 
285     // We are responsible for executing callbacks.
286     while (__callbacks_ != nullptr)
287     {
288         auto* __callbk = __callbacks_;
289         __callbk->__prev_ptr_ = nullptr;
290         __callbacks_ = __callbk->__next_;
291         if (__callbacks_ != nullptr)
292             __callbacks_->__prev_ptr_ = &__callbacks_;
293 
294         __state_.store(__stop_requested_flag_, std::memory_order_release);
295 
296         bool __removed_during_callback = false;
297         __callbk->__removed_during_callback_ = &__removed_during_callback;
298 
299         __callbk->__execute();
300 
301         if (!__removed_during_callback)
302         {
303             __callbk->__removed_during_callback_ = nullptr;
304             __callbk->__callback_completed_.store(true,
305                                                   std::memory_order_release);
306         }
307 
308         __lock_();
309     }
310 
311     __state_.store(__stop_requested_flag_, std::memory_order_release);
312     return false;
313 }
314 
__lock_() const315 inline auto inplace_stop_source::__lock_() const noexcept -> uint8_t
316 {
317     __stok::__spin_wait __spin;
318     auto __old_state = __state_.load(std::memory_order_relaxed);
319     do
320     {
321         while ((__old_state & __locked_flag_) != 0)
322         {
323             __spin.__wait();
324             __old_state = __state_.load(std::memory_order_relaxed);
325         }
326     } while (!__state_.compare_exchange_weak(
327         __old_state, __old_state | __locked_flag_, std::memory_order_acquire,
328         std::memory_order_relaxed));
329 
330     return __old_state;
331 }
332 
__unlock_(uint8_t __old_state) const333 inline void inplace_stop_source::__unlock_(uint8_t __old_state) const noexcept
334 {
335     (void)__state_.store(__old_state, std::memory_order_release);
336 }
337 
__try_lock_unless_stop_requested_(bool __set_stop_requested) const338 inline auto inplace_stop_source::__try_lock_unless_stop_requested_(
339     bool __set_stop_requested) const noexcept -> bool
340 {
341     __stok::__spin_wait __spin;
342     auto __old_state = __state_.load(std::memory_order_relaxed);
343     do
344     {
345         while (true)
346         {
347             if ((__old_state & __stop_requested_flag_) != 0)
348             {
349                 // Stop already requested.
350                 return false;
351             }
352             else if (__old_state == 0)
353             {
354                 break;
355             }
356             else
357             {
358                 __spin.__wait();
359                 __old_state = __state_.load(std::memory_order_relaxed);
360             }
361         }
362     } while (!__state_.compare_exchange_weak(
363         __old_state,
364         __set_stop_requested ? (__locked_flag_ | __stop_requested_flag_)
365                              : __locked_flag_,
366         std::memory_order_acq_rel, std::memory_order_relaxed));
367 
368     // Lock acquired successfully
369     return true;
370 }
371 
__try_add_callback_(__stok::__inplace_stop_callback_base * __callbk) const372 inline auto inplace_stop_source::__try_add_callback_(
373     __stok::__inplace_stop_callback_base* __callbk) const noexcept -> bool
374 {
375     if (!__try_lock_unless_stop_requested_(false))
376     {
377         return false;
378     }
379 
380     __callbk->__next_ = __callbacks_;
381     __callbk->__prev_ptr_ = &__callbacks_;
382     if (__callbacks_ != nullptr)
383     {
384         __callbacks_->__prev_ptr_ = &__callbk->__next_;
385     }
386     __callbacks_ = __callbk;
387 
388     __unlock_(0);
389 
390     return true;
391 }
392 
__remove_callback_(__stok::__inplace_stop_callback_base * __callbk) const393 inline void inplace_stop_source::__remove_callback_(
394     __stok::__inplace_stop_callback_base* __callbk) const noexcept
395 {
396     auto __old_state = __lock_();
397 
398     if (__callbk->__prev_ptr_ != nullptr)
399     {
400         // Callback has not been executed yet.
401         // Remove from the list.
402         *__callbk->__prev_ptr_ = __callbk->__next_;
403         if (__callbk->__next_ != nullptr)
404         {
405             __callbk->__next_->__prev_ptr_ = __callbk->__prev_ptr_;
406         }
407         __unlock_(__old_state);
408     }
409     else
410     {
411         auto __notifying_thread = __notifying_thread_;
412         __unlock_(__old_state);
413 
414         // Callback has either already been executed or is
415         // currently executing on another thread.
416         if (std::this_thread::get_id() == __notifying_thread)
417         {
418             if (__callbk->__removed_during_callback_ != nullptr)
419             {
420                 *__callbk->__removed_during_callback_ = true;
421             }
422         }
423         else
424         {
425             // Concurrently executing on another thread.
426             // Wait until the other thread finishes executing the callback.
427             __stok::__spin_wait __spin;
428             while (!__callbk->__callback_completed_.load(
429                 std::memory_order_acquire))
430             {
431                 __spin.__wait();
432             }
433         }
434     }
435 }
436 
437 using in_place_stop_token
438     [[deprecated("in_place_stop_token has been renamed inplace_stop_token")]] =
439         inplace_stop_token;
440 
441 using in_place_stop_source
442     [[deprecated("in_place_stop_token has been renamed inplace_stop_source")]] =
443         inplace_stop_source;
444 
445 template <class _Fun>
446 using in_place_stop_callback [[deprecated(
447     "in_place_stop_callback has been renamed inplace_stop_callback")]] =
448     inplace_stop_callback<_Fun>;
449 } // namespace stdexec
450