1 /*
2 * Copyright (c) 2021-2022 Facebook, Inc. and its affiliates
3 * Copyright (c) 2021-2024 NVIDIA Corporation
4 *
5 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://llvm.org/LICENSE.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 #pragma once
18
19 #include "__detail/__stop_token.hpp"
20
21 #include <atomic>
22 #include <concepts>
23 #include <cstdint>
24 #include <thread>
25 #include <type_traits>
26 #include <utility>
27 #include <version>
28
29 #if __has_include(<stop_token>) && __cpp_lib_jthread >= 201911
30 #include <stop_token>
31 #endif
32
33 namespace stdexec
34 {
35 namespace __stok
36 {
37 struct __inplace_stop_callback_base
38 {
__executestdexec::__stok::__inplace_stop_callback_base39 void __execute() noexcept
40 {
41 this->__execute_(this);
42 }
43
44 protected:
45 using __execute_fn_t = void(__inplace_stop_callback_base*) noexcept;
46
__inplace_stop_callback_basestdexec::__stok::__inplace_stop_callback_base47 explicit __inplace_stop_callback_base( //
48 const inplace_stop_source* __source, //
49 __execute_fn_t* __execute) noexcept :
50 __source_(__source), __execute_(__execute)
51 {}
52
53 void __register_callback_() noexcept;
54
55 friend inplace_stop_source;
56
57 const inplace_stop_source* __source_;
58 __execute_fn_t* __execute_;
59 __inplace_stop_callback_base* __next_ = nullptr;
60 __inplace_stop_callback_base** __prev_ptr_ = nullptr;
61 bool* __removed_during_callback_ = nullptr;
62 std::atomic<bool> __callback_completed_{false};
63 };
64
// Backoff helper for the spin loops in inplace_stop_source. The first
// __yield_threshold_ calls to __wait() return immediately (busy spinning);
// every later call yields the current thread.
struct __spin_wait
{
    __spin_wait() noexcept = default;

    void __wait() noexcept
    {
        const bool __busy_phase = __count_ < __yield_threshold_;
        ++__count_;
        if (__busy_phase)
        {
            // TODO: _mm_pause();
            return;
        }
        // The counter wrapped around: clamp it back to the threshold so we
        // stay in the yielding phase instead of resuming busy spinning.
        if (__count_ == 0)
        {
            __count_ = __yield_threshold_;
        }
        std::this_thread::yield();
    }

  private:
    static constexpr uint32_t __yield_threshold_ = 20;
    uint32_t __count_ = 0;
};
87 } // namespace __stok
88
89 // [stoptoken.never], class never_stop_token
90 struct never_stop_token
91 {
92 private:
93 struct __callback_type
94 {
__callback_typestdexec::never_stop_token::__callback_type95 explicit __callback_type(never_stop_token, auto&&) noexcept {}
96 };
97
98 public:
99 template <class>
100 using callback_type = __callback_type;
101
stop_requestedstdexec::never_stop_token102 static constexpr auto stop_requested() noexcept -> bool
103 {
104 return false;
105 }
106
stop_possiblestdexec::never_stop_token107 static constexpr auto stop_possible() noexcept -> bool
108 {
109 return false;
110 }
111
112 auto operator==(const never_stop_token&) const noexcept -> bool = default;
113 };
114
115 template <class _Callback>
116 class inplace_stop_callback;
117
118 // [stopsource.inplace], class inplace_stop_source
119 class inplace_stop_source
120 {
121 public:
122 inplace_stop_source() noexcept = default;
123 ~inplace_stop_source();
124 inplace_stop_source(inplace_stop_source&&) = delete;
125
126 auto get_token() const noexcept -> inplace_stop_token;
127
128 auto request_stop() noexcept -> bool;
129
stop_requested() const130 auto stop_requested() const noexcept -> bool
131 {
132 return (__state_.load(std::memory_order_acquire) &
133 __stop_requested_flag_) != 0;
134 }
135
136 private:
137 friend inplace_stop_token;
138 friend __stok::__inplace_stop_callback_base;
139 template <class>
140 friend class inplace_stop_callback;
141
142 auto __lock_() const noexcept -> uint8_t;
143 void __unlock_(uint8_t) const noexcept;
144
145 auto __try_lock_unless_stop_requested_(bool) const noexcept -> bool;
146
147 auto __try_add_callback_(
148 __stok::__inplace_stop_callback_base*) const noexcept -> bool;
149
150 void __remove_callback_(
151 __stok::__inplace_stop_callback_base*) const noexcept;
152
153 static constexpr uint8_t __stop_requested_flag_ = 1;
154 static constexpr uint8_t __locked_flag_ = 2;
155
156 mutable std::atomic<uint8_t> __state_{0};
157 mutable __stok::__inplace_stop_callback_base* __callbacks_ = nullptr;
158 std::thread::id __notifying_thread_;
159 };
160
161 // [stoptoken.inplace], class inplace_stop_token
162 class inplace_stop_token
163 {
164 public:
165 template <class _Fun>
166 using callback_type = inplace_stop_callback<_Fun>;
167
inplace_stop_token()168 inplace_stop_token() noexcept : __source_(nullptr) {}
169
170 inplace_stop_token(const inplace_stop_token& __other) noexcept = default;
171
inplace_stop_token(inplace_stop_token && __other)172 inplace_stop_token(inplace_stop_token&& __other) noexcept :
173 __source_(std::exchange(__other.__source_, {}))
174 {}
175
176 auto operator=(const inplace_stop_token& __other) noexcept
177 -> inplace_stop_token& = default;
178
operator =(inplace_stop_token && __other)179 auto operator=(inplace_stop_token&& __other) noexcept -> inplace_stop_token&
180 {
181 __source_ = std::exchange(__other.__source_, nullptr);
182 return *this;
183 }
184
stop_requested() const185 [[nodiscard]] auto stop_requested() const noexcept -> bool
186 {
187 return __source_ != nullptr && __source_->stop_requested();
188 }
189
stop_possible() const190 [[nodiscard]] auto stop_possible() const noexcept -> bool
191 {
192 return __source_ != nullptr;
193 }
194
swap(inplace_stop_token & __other)195 void swap(inplace_stop_token& __other) noexcept
196 {
197 std::swap(__source_, __other.__source_);
198 }
199
200 auto operator==(const inplace_stop_token&) const noexcept -> bool = default;
201
202 private:
203 friend inplace_stop_source;
204 template <class>
205 friend class inplace_stop_callback;
206
inplace_stop_token(const inplace_stop_source * __source)207 explicit inplace_stop_token(const inplace_stop_source* __source) noexcept :
208 __source_(__source)
209 {}
210
211 const inplace_stop_source* __source_;
212 };
213
214 inline auto
get_token() const215 inplace_stop_source::get_token() const noexcept -> inplace_stop_token
216 {
217 return inplace_stop_token{this};
218 }
219
220 // [stopcallback.inplace], class template inplace_stop_callback
221 template <class _Fun>
222 class inplace_stop_callback : __stok::__inplace_stop_callback_base
223 {
224 public:
225 template <class _Fun2>
226 requires constructible_from<_Fun, _Fun2>
inplace_stop_callback(inplace_stop_token __token,_Fun2 && __fun)227 explicit inplace_stop_callback(inplace_stop_token __token,
228 _Fun2&& __fun) //
229 noexcept(__nothrow_constructible_from<_Fun, _Fun2>) :
230 __stok::__inplace_stop_callback_base(
231 __token.__source_, &inplace_stop_callback::__execute_impl_),
232 __fun_(static_cast<_Fun2&&>(__fun))
233 {
234 __register_callback_();
235 }
236
~inplace_stop_callback()237 ~inplace_stop_callback()
238 {
239 if (__source_ != nullptr)
240 __source_->__remove_callback_(this);
241 }
242
243 private:
244 static void
__execute_impl_(__stok::__inplace_stop_callback_base * cb)245 __execute_impl_(__stok::__inplace_stop_callback_base* cb) noexcept
246 {
247 std::move(static_cast<inplace_stop_callback*>(cb)->__fun_)();
248 }
249
250 STDEXEC_ATTRIBUTE((no_unique_address))
251 _Fun __fun_;
252 };
253
254 namespace __stok
255 {
__register_callback_()256 inline void __inplace_stop_callback_base::__register_callback_() noexcept
257 {
258 if (__source_ != nullptr)
259 {
260 if (!__source_->__try_add_callback_(this))
261 {
262 __source_ = nullptr;
263 // Callback not registered because stop_requested() was true.
264 // Execute inline here.
265 __execute();
266 }
267 }
268 }
269 } // namespace __stok
270
// Destructor. Asserts (via STDEXEC_ASSERT) that the source is not locked and
// that no callbacks remain registered. Presumably every inplace_stop_callback
// registered with this source must be destroyed before the source is.
inline inplace_stop_source::~inplace_stop_source()
{
    STDEXEC_ASSERT(
        (__state_.load(std::memory_order_relaxed) & __locked_flag_) == 0);
    STDEXEC_ASSERT(__callbacks_ == nullptr);
}
277
request_stop()278 inline auto inplace_stop_source::request_stop() noexcept -> bool
279 {
280 if (!__try_lock_unless_stop_requested_(true))
281 return true;
282
283 __notifying_thread_ = std::this_thread::get_id();
284
285 // We are responsible for executing callbacks.
286 while (__callbacks_ != nullptr)
287 {
288 auto* __callbk = __callbacks_;
289 __callbk->__prev_ptr_ = nullptr;
290 __callbacks_ = __callbk->__next_;
291 if (__callbacks_ != nullptr)
292 __callbacks_->__prev_ptr_ = &__callbacks_;
293
294 __state_.store(__stop_requested_flag_, std::memory_order_release);
295
296 bool __removed_during_callback = false;
297 __callbk->__removed_during_callback_ = &__removed_during_callback;
298
299 __callbk->__execute();
300
301 if (!__removed_during_callback)
302 {
303 __callbk->__removed_during_callback_ = nullptr;
304 __callbk->__callback_completed_.store(true,
305 std::memory_order_release);
306 }
307
308 __lock_();
309 }
310
311 __state_.store(__stop_requested_flag_, std::memory_order_release);
312 return false;
313 }
314
__lock_() const315 inline auto inplace_stop_source::__lock_() const noexcept -> uint8_t
316 {
317 __stok::__spin_wait __spin;
318 auto __old_state = __state_.load(std::memory_order_relaxed);
319 do
320 {
321 while ((__old_state & __locked_flag_) != 0)
322 {
323 __spin.__wait();
324 __old_state = __state_.load(std::memory_order_relaxed);
325 }
326 } while (!__state_.compare_exchange_weak(
327 __old_state, __old_state | __locked_flag_, std::memory_order_acquire,
328 std::memory_order_relaxed));
329
330 return __old_state;
331 }
332
__unlock_(uint8_t __old_state) const333 inline void inplace_stop_source::__unlock_(uint8_t __old_state) const noexcept
334 {
335 (void)__state_.store(__old_state, std::memory_order_release);
336 }
337
__try_lock_unless_stop_requested_(bool __set_stop_requested) const338 inline auto inplace_stop_source::__try_lock_unless_stop_requested_(
339 bool __set_stop_requested) const noexcept -> bool
340 {
341 __stok::__spin_wait __spin;
342 auto __old_state = __state_.load(std::memory_order_relaxed);
343 do
344 {
345 while (true)
346 {
347 if ((__old_state & __stop_requested_flag_) != 0)
348 {
349 // Stop already requested.
350 return false;
351 }
352 else if (__old_state == 0)
353 {
354 break;
355 }
356 else
357 {
358 __spin.__wait();
359 __old_state = __state_.load(std::memory_order_relaxed);
360 }
361 }
362 } while (!__state_.compare_exchange_weak(
363 __old_state,
364 __set_stop_requested ? (__locked_flag_ | __stop_requested_flag_)
365 : __locked_flag_,
366 std::memory_order_acq_rel, std::memory_order_relaxed));
367
368 // Lock acquired successfully
369 return true;
370 }
371
__try_add_callback_(__stok::__inplace_stop_callback_base * __callbk) const372 inline auto inplace_stop_source::__try_add_callback_(
373 __stok::__inplace_stop_callback_base* __callbk) const noexcept -> bool
374 {
375 if (!__try_lock_unless_stop_requested_(false))
376 {
377 return false;
378 }
379
380 __callbk->__next_ = __callbacks_;
381 __callbk->__prev_ptr_ = &__callbacks_;
382 if (__callbacks_ != nullptr)
383 {
384 __callbacks_->__prev_ptr_ = &__callbk->__next_;
385 }
386 __callbacks_ = __callbk;
387
388 __unlock_(0);
389
390 return true;
391 }
392
// Deregisters __callbk from this source. The only caller in this file is
// ~inplace_stop_callback, which calls it only when the callback's __source_
// is non-null, i.e. when registration succeeded.
//
// - If the node is still in the list (__prev_ptr_ non-null), it is unlinked
//   under the lock.
// - Otherwise request_stop() has already popped it, so the callback has run
//   or is running.
//   - On the notifying thread we only set the flag that request_stop() is
//     watching, so it will not touch the node after the callback returns.
//   - On any other thread we spin until request_stop() publishes
//     __callback_completed_. The callback is therefore no longer running
//     when this returns.
inline void inplace_stop_source::__remove_callback_(
    __stok::__inplace_stop_callback_base* __callbk) const noexcept
{
    auto __old_state = __lock_();

    if (__callbk->__prev_ptr_ != nullptr)
    {
        // Callback has not been executed yet.
        // Remove from the list.
        *__callbk->__prev_ptr_ = __callbk->__next_;
        if (__callbk->__next_ != nullptr)
        {
            __callbk->__next_->__prev_ptr_ = __callbk->__prev_ptr_;
        }
        __unlock_(__old_state);
    }
    else
    {
        // Snapshot the notifying thread's id while the lock is still held.
        auto __notifying_thread = __notifying_thread_;
        __unlock_(__old_state);

        // Callback has either already been executed or is
        // currently executing on another thread.
        if (std::this_thread::get_id() == __notifying_thread)
        {
            // Non-null only while request_stop() is running this callback on
            // this thread; tells it the node may already be gone.
            if (__callbk->__removed_during_callback_ != nullptr)
            {
                *__callbk->__removed_during_callback_ = true;
            }
        }
        else
        {
            // Concurrently executing on another thread.
            // Wait until the other thread finishes executing the callback.
            __stok::__spin_wait __spin;
            while (!__callbk->__callback_completed_.load(
                std::memory_order_acquire))
            {
                __spin.__wait();
            }
        }
    }
}
436
437 using in_place_stop_token
438 [[deprecated("in_place_stop_token has been renamed inplace_stop_token")]] =
439 inplace_stop_token;
440
441 using in_place_stop_source
442 [[deprecated("in_place_stop_token has been renamed inplace_stop_source")]] =
443 inplace_stop_source;
444
445 template <class _Fun>
446 using in_place_stop_callback [[deprecated(
447 "in_place_stop_callback has been renamed inplace_stop_callback")]] =
448 inplace_stop_callback<_Fun>;
449 } // namespace stdexec
450