xref: /openbmc/sdbusplus/include/sdbusplus/async/stdexec/stop_token.hpp (revision 10d0b4b7d1498cfd5c3d37edea271a54d1984e41)
1 /*
2  * Copyright (c) 2021-2022 Facebook, Inc. and its affiliates
3  * Copyright (c) 2021-2024 NVIDIA Corporation
4  *
5  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
6  * (the "License"); you may not use this file except in compliance with
7  * the License. You may obtain a copy of the License at
8  *
9  *   https://llvm.org/LICENSE.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 #pragma once
18 
19 #include "__detail/__stop_token.hpp" // IWYU pragma: export
20 
21 #include <version>
22 #include <cstdint>
23 #include <utility>
24 #include <atomic>
25 #include <thread>
26 
27 #if __has_include(<stop_token>) && __cpp_lib_jthread >= 2019'11L
28 #  include <stop_token> // IWYU pragma: export
29 #endif
30 
31 namespace stdexec {
32   namespace __stok {
33     struct __inplace_stop_callback_base {
__executestdexec::__stok::__inplace_stop_callback_base34       void __execute() noexcept {
35         this->__execute_(this);
36       }
37 
38      protected:
39       using __execute_fn_t = void(__inplace_stop_callback_base*) noexcept;
40 
__inplace_stop_callback_basestdexec::__stok::__inplace_stop_callback_base41       explicit __inplace_stop_callback_base(
42         const inplace_stop_source* __source,
43         __execute_fn_t* __execute) noexcept
44         : __source_(__source)
45         , __execute_(__execute) {
46       }
47 
48       void __register_callback_() noexcept;
49 
50       friend inplace_stop_source;
51 
52       const inplace_stop_source* __source_;
53       __execute_fn_t* __execute_;
54       __inplace_stop_callback_base* __next_ = nullptr;
55       __inplace_stop_callback_base** __prev_ptr_ = nullptr;
56       bool* __removed_during_callback_ = nullptr;
57       std::atomic<bool> __callback_completed_{false};
58     };
59 
60     struct __spin_wait {
61       __spin_wait() noexcept = default;
62 
__waitstdexec::__stok::__spin_wait63       void __wait() noexcept {
64         if (__count_++ < __yield_threshold_) {
65           // TODO: _mm_pause();
66         } else {
67           if (__count_ == 0)
68             __count_ = __yield_threshold_;
69           std::this_thread::yield();
70         }
71       }
72 
73      private:
74       static constexpr uint32_t __yield_threshold_ = 20;
75       uint32_t __count_ = 0;
76     };
77   } // namespace __stok
78 
79   template <class _Callback>
80   class inplace_stop_callback;
81 
82   // [stopsource.inplace], class inplace_stop_source
83   class inplace_stop_source {
84    public:
85     inplace_stop_source() noexcept = default;
86     ~inplace_stop_source();
87     inplace_stop_source(inplace_stop_source&&) = delete;
88 
89     auto get_token() const noexcept -> inplace_stop_token;
90 
91     auto request_stop() noexcept -> bool;
92 
stop_requested() const93     auto stop_requested() const noexcept -> bool {
94       return (__state_.load(std::memory_order_acquire) & __stop_requested_flag_) != 0;
95     }
96 
97    private:
98     friend inplace_stop_token;
99     friend __stok::__inplace_stop_callback_base;
100     template <class>
101     friend class inplace_stop_callback;
102 
103     auto __lock_() const noexcept -> uint8_t;
104     void __unlock_(uint8_t) const noexcept;
105 
106     auto __try_lock_unless_stop_requested_(bool) const noexcept -> bool;
107 
108     auto __try_add_callback_(__stok::__inplace_stop_callback_base*) const noexcept -> bool;
109 
110     void __remove_callback_(__stok::__inplace_stop_callback_base*) const noexcept;
111 
112     static constexpr uint8_t __stop_requested_flag_ = 1;
113     static constexpr uint8_t __locked_flag_ = 2;
114 
115     mutable std::atomic<uint8_t> __state_{0};
116     mutable __stok::__inplace_stop_callback_base* __callbacks_ = nullptr;
117     std::thread::id __notifying_thread_;
118   };
119 
120   // [stoptoken.inplace], class inplace_stop_token
121   class inplace_stop_token {
122    public:
123     template <class _Fun>
124     using callback_type = inplace_stop_callback<_Fun>;
125 
inplace_stop_token()126     inplace_stop_token() noexcept
127       : __source_(nullptr) {
128     }
129 
130     inplace_stop_token(const inplace_stop_token& __other) noexcept = default;
131 
inplace_stop_token(inplace_stop_token && __other)132     inplace_stop_token(inplace_stop_token&& __other) noexcept
133       : __source_(std::exchange(__other.__source_, {})) {
134     }
135 
136     auto operator=(const inplace_stop_token& __other) noexcept -> inplace_stop_token& = default;
137 
operator =(inplace_stop_token && __other)138     auto operator=(inplace_stop_token&& __other) noexcept -> inplace_stop_token& {
139       __source_ = std::exchange(__other.__source_, nullptr);
140       return *this;
141     }
142 
143     [[nodiscard]]
stop_requested() const144     auto stop_requested() const noexcept -> bool {
145       return __source_ != nullptr && __source_->stop_requested();
146     }
147 
148     [[nodiscard]]
stop_possible() const149     auto stop_possible() const noexcept -> bool {
150       return __source_ != nullptr;
151     }
152 
swap(inplace_stop_token & __other)153     void swap(inplace_stop_token& __other) noexcept {
154       std::swap(__source_, __other.__source_);
155     }
156 
157     auto operator==(const inplace_stop_token&) const noexcept -> bool = default;
158 
159    private:
160     friend inplace_stop_source;
161     template <class>
162     friend class inplace_stop_callback;
163 
inplace_stop_token(const inplace_stop_source * __source)164     explicit inplace_stop_token(const inplace_stop_source* __source) noexcept
165       : __source_(__source) {
166     }
167 
168     const inplace_stop_source* __source_;
169   };
170 
get_token() const171   inline auto inplace_stop_source::get_token() const noexcept -> inplace_stop_token {
172     return inplace_stop_token{this};
173   }
174 
175   // [stopcallback.inplace], class template inplace_stop_callback
176   template <class _Fun>
177   class inplace_stop_callback : __stok::__inplace_stop_callback_base {
178    public:
179     template <class _Fun2>
180       requires constructible_from<_Fun, _Fun2>
inplace_stop_callback(inplace_stop_token __token,_Fun2 && __fun)181     explicit inplace_stop_callback(inplace_stop_token __token, _Fun2&& __fun)
182       noexcept(__nothrow_constructible_from<_Fun, _Fun2>)
183       : __stok::__inplace_stop_callback_base(
184           __token.__source_,
185           &inplace_stop_callback::__execute_impl_)
186       , __fun_(static_cast<_Fun2&&>(__fun)) {
187       __register_callback_();
188     }
189 
~inplace_stop_callback()190     ~inplace_stop_callback() {
191       if (__source_ != nullptr)
192         __source_->__remove_callback_(this);
193     }
194 
195    private:
__execute_impl_(__stok::__inplace_stop_callback_base * cb)196     static void __execute_impl_(__stok::__inplace_stop_callback_base* cb) noexcept {
197       std::move(static_cast<inplace_stop_callback*>(cb)->__fun_)();
198     }
199 
200     STDEXEC_ATTRIBUTE(no_unique_address) _Fun __fun_;
201   };
202 
203   namespace __stok {
__register_callback_()204     inline void __inplace_stop_callback_base::__register_callback_() noexcept {
205       if (__source_ != nullptr) {
206         if (!__source_->__try_add_callback_(this)) {
207           __source_ = nullptr;
208           // Callback not registered because stop_requested() was true.
209           // Execute inline here.
210           __execute();
211         }
212       }
213     }
214   } // namespace __stok
215 
~inplace_stop_source()216   inline inplace_stop_source::~inplace_stop_source() {
217     STDEXEC_ASSERT((__state_.load(std::memory_order_relaxed) & __locked_flag_) == 0);
218     STDEXEC_ASSERT(__callbacks_ == nullptr);
219   }
220 
request_stop()221   inline auto inplace_stop_source::request_stop() noexcept -> bool {
222     if (!__try_lock_unless_stop_requested_(true))
223       return true;
224 
225     __notifying_thread_ = std::this_thread::get_id();
226 
227     // We are responsible for executing callbacks.
228     while (__callbacks_ != nullptr) {
229       auto* __callbk = __callbacks_;
230       __callbk->__prev_ptr_ = nullptr;
231       __callbacks_ = __callbk->__next_;
232       if (__callbacks_ != nullptr)
233         __callbacks_->__prev_ptr_ = &__callbacks_;
234 
235       __state_.store(__stop_requested_flag_, std::memory_order_release);
236 
237       bool __removed_during_callback = false;
238       __callbk->__removed_during_callback_ = &__removed_during_callback;
239 
240       __callbk->__execute();
241 
242       if (!__removed_during_callback) {
243         __callbk->__removed_during_callback_ = nullptr;
244         __callbk->__callback_completed_.store(true, std::memory_order_release);
245       }
246 
247       __lock_();
248     }
249 
250     __state_.store(__stop_requested_flag_, std::memory_order_release);
251     return false;
252   }
253 
__lock_() const254   inline auto inplace_stop_source::__lock_() const noexcept -> uint8_t {
255     __stok::__spin_wait __spin;
256     auto __old_state = __state_.load(std::memory_order_relaxed);
257     do {
258       while ((__old_state & __locked_flag_) != 0) {
259         __spin.__wait();
260         __old_state = __state_.load(std::memory_order_relaxed);
261       }
262     } while (!__state_.compare_exchange_weak(
263       __old_state,
264       __old_state | __locked_flag_,
265       std::memory_order_acquire,
266       std::memory_order_relaxed));
267 
268     return __old_state;
269   }
270 
__unlock_(uint8_t __old_state) const271   inline void inplace_stop_source::__unlock_(uint8_t __old_state) const noexcept {
272     (void) __state_.store(__old_state, std::memory_order_release);
273   }
274 
275   inline auto
__try_lock_unless_stop_requested_(bool __set_stop_requested) const276     inplace_stop_source::__try_lock_unless_stop_requested_(bool __set_stop_requested) const noexcept
277     -> bool {
278     __stok::__spin_wait __spin;
279     auto __old_state = __state_.load(std::memory_order_relaxed);
280     do {
281       while (true) {
282         if ((__old_state & __stop_requested_flag_) != 0) {
283           // Stop already requested.
284           return false;
285         } else if (__old_state == 0) {
286           break;
287         } else {
288           __spin.__wait();
289           __old_state = __state_.load(std::memory_order_relaxed);
290         }
291       }
292     } while (!__state_.compare_exchange_weak(
293       __old_state,
294       __set_stop_requested ? (__locked_flag_ | __stop_requested_flag_) : __locked_flag_,
295       std::memory_order_acq_rel,
296       std::memory_order_relaxed));
297 
298     // Lock acquired successfully
299     return true;
300   }
301 
__try_add_callback_(__stok::__inplace_stop_callback_base * __callbk) const302   inline auto inplace_stop_source::__try_add_callback_(
303     __stok::__inplace_stop_callback_base* __callbk) const noexcept -> bool {
304     if (!__try_lock_unless_stop_requested_(false)) {
305       return false;
306     }
307 
308     __callbk->__next_ = __callbacks_;
309     __callbk->__prev_ptr_ = &__callbacks_;
310     if (__callbacks_ != nullptr) {
311       __callbacks_->__prev_ptr_ = &__callbk->__next_;
312     }
313     __callbacks_ = __callbk;
314 
315     __unlock_(0);
316 
317     return true;
318   }
319 
__remove_callback_(__stok::__inplace_stop_callback_base * __callbk) const320   inline void inplace_stop_source::__remove_callback_(
321     __stok::__inplace_stop_callback_base* __callbk) const noexcept {
322     auto __old_state = __lock_();
323 
324     if (__callbk->__prev_ptr_ != nullptr) {
325       // Callback has not been executed yet.
326       // Remove from the list.
327       *__callbk->__prev_ptr_ = __callbk->__next_;
328       if (__callbk->__next_ != nullptr) {
329         __callbk->__next_->__prev_ptr_ = __callbk->__prev_ptr_;
330       }
331       __unlock_(__old_state);
332     } else {
333       auto __notifying_thread = __notifying_thread_;
334       __unlock_(__old_state);
335 
336       // Callback has either already been executed or is
337       // currently executing on another thread.
338       if (std::this_thread::get_id() == __notifying_thread) {
339         if (__callbk->__removed_during_callback_ != nullptr) {
340           *__callbk->__removed_during_callback_ = true;
341         }
342       } else {
343         // Concurrently executing on another thread.
344         // Wait until the other thread finishes executing the callback.
345         __stok::__spin_wait __spin;
346         while (!__callbk->__callback_completed_.load(std::memory_order_acquire)) {
347           __spin.__wait();
348         }
349       }
350     }
351   }
352 
353   using in_place_stop_token
354     [[deprecated("in_place_stop_token has been renamed inplace_stop_token")]] = inplace_stop_token;
355 
356   using in_place_stop_source [[deprecated(
357     "in_place_stop_token has been renamed inplace_stop_source")]] = inplace_stop_source;
358 
359   template <class _Fun>
360   using in_place_stop_callback
361     [[deprecated("in_place_stop_callback has been renamed inplace_stop_callback")]] =
362       inplace_stop_callback<_Fun>;
363 } // namespace stdexec
364