1 /* 2 * Copyright (c) 2021-2022 Facebook, Inc. and its affiliates 3 * Copyright (c) 2021-2024 NVIDIA Corporation 4 * 5 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * https://llvm.org/LICENSE.txt 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 #pragma once 18 19 #include "__detail/__stop_token.hpp" // IWYU pragma: export 20 21 #include <version> 22 #include <cstdint> 23 #include <utility> 24 #include <atomic> 25 #include <thread> 26 27 #if __has_include(<stop_token>) && __cpp_lib_jthread >= 2019'11L 28 # include <stop_token> // IWYU pragma: export 29 #endif 30 31 namespace stdexec { 32 namespace __stok { 33 struct __inplace_stop_callback_base { __executestdexec::__stok::__inplace_stop_callback_base34 void __execute() noexcept { 35 this->__execute_(this); 36 } 37 38 protected: 39 using __execute_fn_t = void(__inplace_stop_callback_base*) noexcept; 40 __inplace_stop_callback_basestdexec::__stok::__inplace_stop_callback_base41 explicit __inplace_stop_callback_base( 42 const inplace_stop_source* __source, 43 __execute_fn_t* __execute) noexcept 44 : __source_(__source) 45 , __execute_(__execute) { 46 } 47 48 void __register_callback_() noexcept; 49 50 friend inplace_stop_source; 51 52 const inplace_stop_source* __source_; 53 __execute_fn_t* __execute_; 54 __inplace_stop_callback_base* __next_ = nullptr; 55 __inplace_stop_callback_base** __prev_ptr_ = nullptr; 56 bool* __removed_during_callback_ = nullptr; 57 std::atomic<bool> __callback_completed_{false}; 58 }; 59 60 struct __spin_wait { 61 __spin_wait() noexcept = default; 62 __waitstdexec::__stok::__spin_wait63 void __wait() noexcept { 64 if (__count_++ < __yield_threshold_) { 65 // TODO: _mm_pause(); 66 } else { 67 if (__count_ == 0) 68 __count_ = __yield_threshold_; 69 std::this_thread::yield(); 70 } 71 } 72 73 private: 74 static constexpr uint32_t __yield_threshold_ = 20; 75 uint32_t __count_ = 0; 76 }; 77 } // namespace __stok 78 79 template <class _Callback> 80 class inplace_stop_callback; 81 82 // [stopsource.inplace], class inplace_stop_source 83 class inplace_stop_source { 84 public: 85 inplace_stop_source() noexcept = default; 86 ~inplace_stop_source(); 87 inplace_stop_source(inplace_stop_source&&) = delete; 88 89 auto get_token() const noexcept -> inplace_stop_token; 90 91 auto request_stop() noexcept -> bool; 92 stop_requested() const93 auto stop_requested() const noexcept -> bool { 94 return (__state_.load(std::memory_order_acquire) & __stop_requested_flag_) != 0; 95 } 96 97 private: 98 friend inplace_stop_token; 99 friend __stok::__inplace_stop_callback_base; 100 template <class> 101 friend class inplace_stop_callback; 102 103 auto __lock_() const noexcept -> uint8_t; 104 void __unlock_(uint8_t) const noexcept; 105 106 auto __try_lock_unless_stop_requested_(bool) const noexcept -> bool; 107 108 auto __try_add_callback_(__stok::__inplace_stop_callback_base*) const noexcept -> bool; 109 110 void __remove_callback_(__stok::__inplace_stop_callback_base*) const noexcept; 111 112 static constexpr uint8_t __stop_requested_flag_ = 1; 113 static constexpr uint8_t __locked_flag_ = 2; 114 115 mutable std::atomic<uint8_t> __state_{0}; 116 mutable __stok::__inplace_stop_callback_base* __callbacks_ = nullptr; 117 std::thread::id __notifying_thread_; 118 }; 119 120 // [stoptoken.inplace], class inplace_stop_token 121 class inplace_stop_token { 122 public: 123 template <class _Fun> 124 using callback_type = inplace_stop_callback<_Fun>; 125 inplace_stop_token()126 inplace_stop_token() noexcept 127 : __source_(nullptr) { 128 } 129 130 inplace_stop_token(const inplace_stop_token& __other) noexcept = default; 131 inplace_stop_token(inplace_stop_token && __other)132 inplace_stop_token(inplace_stop_token&& __other) noexcept 133 : __source_(std::exchange(__other.__source_, {})) { 134 } 135 136 auto operator=(const inplace_stop_token& __other) noexcept -> inplace_stop_token& = default; 137 operator =(inplace_stop_token && __other)138 auto operator=(inplace_stop_token&& __other) noexcept -> inplace_stop_token& { 139 __source_ = std::exchange(__other.__source_, nullptr); 140 return *this; 141 } 142 143 [[nodiscard]] stop_requested() const144 auto stop_requested() const noexcept -> bool { 145 return __source_ != nullptr && __source_->stop_requested(); 146 } 147 148 [[nodiscard]] stop_possible() const149 auto stop_possible() const noexcept -> bool { 150 return __source_ != nullptr; 151 } 152 swap(inplace_stop_token & __other)153 void swap(inplace_stop_token& __other) noexcept { 154 std::swap(__source_, __other.__source_); 155 } 156 157 auto operator==(const inplace_stop_token&) const noexcept -> bool = default; 158 159 private: 160 friend inplace_stop_source; 161 template <class> 162 friend class inplace_stop_callback; 163 inplace_stop_token(const inplace_stop_source * __source)164 explicit inplace_stop_token(const inplace_stop_source* __source) noexcept 165 : __source_(__source) { 166 } 167 168 const inplace_stop_source* __source_; 169 }; 170 get_token() const171 inline auto inplace_stop_source::get_token() const noexcept -> inplace_stop_token { 172 return inplace_stop_token{this}; 173 } 174 175 // [stopcallback.inplace], class template inplace_stop_callback 176 template <class _Fun> 177 class inplace_stop_callback : __stok::__inplace_stop_callback_base { 178 public: 179 template <class _Fun2> 180 requires constructible_from<_Fun, _Fun2> inplace_stop_callback(inplace_stop_token __token,_Fun2 && __fun)181 explicit inplace_stop_callback(inplace_stop_token __token, _Fun2&& __fun) 182 noexcept(__nothrow_constructible_from<_Fun, _Fun2>) 183 : __stok::__inplace_stop_callback_base( 184 __token.__source_, 185 &inplace_stop_callback::__execute_impl_) 186 , __fun_(static_cast<_Fun2&&>(__fun)) { 187 __register_callback_(); 188 } 189 ~inplace_stop_callback()190 ~inplace_stop_callback() { 191 if (__source_ != nullptr) 192 __source_->__remove_callback_(this); 193 } 194 195 private: __execute_impl_(__stok::__inplace_stop_callback_base * cb)196 static void __execute_impl_(__stok::__inplace_stop_callback_base* cb) noexcept { 197 std::move(static_cast<inplace_stop_callback*>(cb)->__fun_)(); 198 } 199 200 STDEXEC_ATTRIBUTE(no_unique_address) _Fun __fun_; 201 }; 202 203 namespace __stok { __register_callback_()204 inline void __inplace_stop_callback_base::__register_callback_() noexcept { 205 if (__source_ != nullptr) { 206 if (!__source_->__try_add_callback_(this)) { 207 __source_ = nullptr; 208 // Callback not registered because stop_requested() was true. 209 // Execute inline here. 210 __execute(); 211 } 212 } 213 } 214 } // namespace __stok 215 ~inplace_stop_source()216 inline inplace_stop_source::~inplace_stop_source() { 217 STDEXEC_ASSERT((__state_.load(std::memory_order_relaxed) & __locked_flag_) == 0); 218 STDEXEC_ASSERT(__callbacks_ == nullptr); 219 } 220 request_stop()221 inline auto inplace_stop_source::request_stop() noexcept -> bool { 222 if (!__try_lock_unless_stop_requested_(true)) 223 return true; 224 225 __notifying_thread_ = std::this_thread::get_id(); 226 227 // We are responsible for executing callbacks. 228 while (__callbacks_ != nullptr) { 229 auto* __callbk = __callbacks_; 230 __callbk->__prev_ptr_ = nullptr; 231 __callbacks_ = __callbk->__next_; 232 if (__callbacks_ != nullptr) 233 __callbacks_->__prev_ptr_ = &__callbacks_; 234 235 __state_.store(__stop_requested_flag_, std::memory_order_release); 236 237 bool __removed_during_callback = false; 238 __callbk->__removed_during_callback_ = &__removed_during_callback; 239 240 __callbk->__execute(); 241 242 if (!__removed_during_callback) { 243 __callbk->__removed_during_callback_ = nullptr; 244 __callbk->__callback_completed_.store(true, std::memory_order_release); 245 } 246 247 __lock_(); 248 } 249 250 __state_.store(__stop_requested_flag_, std::memory_order_release); 251 return false; 252 } 253 __lock_() const254 inline auto inplace_stop_source::__lock_() const noexcept -> uint8_t { 255 __stok::__spin_wait __spin; 256 auto __old_state = __state_.load(std::memory_order_relaxed); 257 do { 258 while ((__old_state & __locked_flag_) != 0) { 259 __spin.__wait(); 260 __old_state = __state_.load(std::memory_order_relaxed); 261 } 262 } while (!__state_.compare_exchange_weak( 263 __old_state, 264 __old_state | __locked_flag_, 265 std::memory_order_acquire, 266 std::memory_order_relaxed)); 267 268 return __old_state; 269 } 270 __unlock_(uint8_t __old_state) const271 inline void inplace_stop_source::__unlock_(uint8_t __old_state) const noexcept { 272 (void) __state_.store(__old_state, std::memory_order_release); 273 } 274 275 inline auto __try_lock_unless_stop_requested_(bool __set_stop_requested) const276 inplace_stop_source::__try_lock_unless_stop_requested_(bool __set_stop_requested) const noexcept 277 -> bool { 278 __stok::__spin_wait __spin; 279 auto __old_state = __state_.load(std::memory_order_relaxed); 280 do { 281 while (true) { 282 if ((__old_state & __stop_requested_flag_) != 0) { 283 // Stop already requested. 284 return false; 285 } else if (__old_state == 0) { 286 break; 287 } else { 288 __spin.__wait(); 289 __old_state = __state_.load(std::memory_order_relaxed); 290 } 291 } 292 } while (!__state_.compare_exchange_weak( 293 __old_state, 294 __set_stop_requested ? (__locked_flag_ | __stop_requested_flag_) : __locked_flag_, 295 std::memory_order_acq_rel, 296 std::memory_order_relaxed)); 297 298 // Lock acquired successfully 299 return true; 300 } 301 __try_add_callback_(__stok::__inplace_stop_callback_base * __callbk) const302 inline auto inplace_stop_source::__try_add_callback_( 303 __stok::__inplace_stop_callback_base* __callbk) const noexcept -> bool { 304 if (!__try_lock_unless_stop_requested_(false)) { 305 return false; 306 } 307 308 __callbk->__next_ = __callbacks_; 309 __callbk->__prev_ptr_ = &__callbacks_; 310 if (__callbacks_ != nullptr) { 311 __callbacks_->__prev_ptr_ = &__callbk->__next_; 312 } 313 __callbacks_ = __callbk; 314 315 __unlock_(0); 316 317 return true; 318 } 319 __remove_callback_(__stok::__inplace_stop_callback_base * __callbk) const320 inline void inplace_stop_source::__remove_callback_( 321 __stok::__inplace_stop_callback_base* __callbk) const noexcept { 322 auto __old_state = __lock_(); 323 324 if (__callbk->__prev_ptr_ != nullptr) { 325 // Callback has not been executed yet. 326 // Remove from the list. 327 *__callbk->__prev_ptr_ = __callbk->__next_; 328 if (__callbk->__next_ != nullptr) { 329 __callbk->__next_->__prev_ptr_ = __callbk->__prev_ptr_; 330 } 331 __unlock_(__old_state); 332 } else { 333 auto __notifying_thread = __notifying_thread_; 334 __unlock_(__old_state); 335 336 // Callback has either already been executed or is 337 // currently executing on another thread. 338 if (std::this_thread::get_id() == __notifying_thread) { 339 if (__callbk->__removed_during_callback_ != nullptr) { 340 *__callbk->__removed_during_callback_ = true; 341 } 342 } else { 343 // Concurrently executing on another thread. 344 // Wait until the other thread finishes executing the callback. 345 __stok::__spin_wait __spin; 346 while (!__callbk->__callback_completed_.load(std::memory_order_acquire)) { 347 __spin.__wait(); 348 } 349 } 350 } 351 } 352 353 using in_place_stop_token 354 [[deprecated("in_place_stop_token has been renamed inplace_stop_token")]] = inplace_stop_token; 355 356 using in_place_stop_source [[deprecated( 357 "in_place_stop_token has been renamed inplace_stop_source")]] = inplace_stop_source; 358 359 template <class _Fun> 360 using in_place_stop_callback 361 [[deprecated("in_place_stop_callback has been renamed inplace_stop_callback")]] = 362 inplace_stop_callback<_Fun>; 363 } // namespace stdexec 364