1 /* 2 * Copyright (c) 2022 NVIDIA Corporation 3 * 4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 5 * (the "License"); you may not use this file except in compliance with 6 * the License. You may obtain a copy of the License at 7 * 8 * https://llvm.org/LICENSE.txt 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include "__concepts.hpp" 19 20 #include <atomic> 21 #include <new> // IWYU pragma: keep for ::new 22 #include <cstddef> 23 #include <type_traits> 24 #include <utility> 25 26 #if STDEXEC_TSAN() 27 # include <sanitizer/tsan_interface.h> 28 #endif 29 30 namespace stdexec { 31 namespace __ptr { 32 template <std::size_t _ReservedBits> 33 struct __count_and_bits { 34 static constexpr std::size_t __ref_count_increment = 1ul << _ReservedBits; 35 36 enum struct __bits : std::size_t { 37 }; 38 __count(__bits __b)39 friend constexpr auto __count(__bits __b) noexcept -> std::size_t { 40 return static_cast<std::size_t>(__b) / __ref_count_increment; 41 } 42 43 template <std::size_t _Bit> __bit(__bits __b)44 friend constexpr auto __bit(__bits __b) noexcept -> bool { 45 static_assert(_Bit < _ReservedBits, "Bit index out of range"); 46 return (static_cast<std::size_t>(__b) & (1ul << _Bit)) != 0; 47 } 48 }; 49 50 template <std::size_t _ReservedBits> 51 using __bits_t = __count_and_bits<_ReservedBits>::__bits; 52 53 template <class _Ty, std::size_t _ReservedBits> 54 struct __make_intrusive_t; 55 56 template <class _Ty, std::size_t _ReservedBits = 0ul> 57 class __intrusive_ptr; 58 59 template <class _Ty, std::size_t _ReservedBits = 0ul> 60 struct __enable_intrusive_from_this { 61 auto __intrusive_from_this() noexcept -> __intrusive_ptr<_Ty, _ReservedBits>; 62 auto __intrusive_from_this() const noexcept -> __intrusive_ptr<const _Ty, _ReservedBits>; 63 private: 64 using __bits_t = __count_and_bits<_ReservedBits>::__bits; 65 friend _Ty; 66 auto __inc_ref() noexcept -> __bits_t; 67 auto __dec_ref() noexcept -> __bits_t; 68 69 template <std::size_t _Bit> 70 [[nodiscard]] 71 auto __is_set() const noexcept -> bool; 72 template <std::size_t _Bit> 73 auto __set_bit() noexcept -> __bits_t; 74 template <std::size_t _Bit> 75 auto __clear_bit() noexcept -> __bits_t; 76 }; 77 78 STDEXEC_PRAGMA_PUSH() 79 STDEXEC_PRAGMA_IGNORE_GNU("-Wtsan") 80 81 template <class _Ty, std::size_t _ReservedBits> 82 struct __control_block { 83 using __bits_t = __count_and_bits<_ReservedBits>::__bits; 84 static constexpr std::size_t __ref_count_increment = 1ul << _ReservedBits; 85 86 alignas(_Ty) unsigned char __value_[sizeof(_Ty)]; 87 std::atomic<std::size_t> __ref_count_; 88 89 template <class... _Us> __control_blockstdexec::__ptr::__control_block90 explicit __control_block(_Us&&... __us) noexcept(noexcept(_Ty{__declval<_Us>()...})) 91 : __ref_count_(__ref_count_increment) { 92 // Construct the value *after* the initialization of the atomic in case the constructor of 93 // _Ty calls __intrusive_from_this() (which increments the ref count): 94 ::new (static_cast<void*>(__value_)) _Ty{static_cast<_Us&&>(__us)...}; 95 } 96 ~__control_blockstdexec::__ptr::__control_block97 ~__control_block() { 98 __value().~_Ty(); 99 } 100 __valuestdexec::__ptr::__control_block101 auto __value() noexcept -> _Ty& { 102 return *reinterpret_cast<_Ty*>(__value_); 103 } 104 __inc_ref_stdexec::__ptr::__control_block105 auto __inc_ref_() noexcept -> __bits_t { 106 auto __old = __ref_count_.fetch_add(__ref_count_increment, std::memory_order_relaxed); 107 return static_cast<__bits_t>(__old); 108 } 109 __dec_ref_stdexec::__ptr::__control_block110 auto __dec_ref_() noexcept -> __bits_t { 111 auto __old = __ref_count_.fetch_sub(__ref_count_increment, std::memory_order_acq_rel); 112 if (__count(static_cast<__bits_t>(__old)) == 1) { 113 delete this; 114 } 115 return static_cast<__bits_t>(__old); 116 } 117 118 // Returns true if the bit was set, false if it was already set. 119 template <std::size_t _Bit> 120 [[nodiscard]] __is_set_stdexec::__ptr::__control_block121 auto __is_set_() const noexcept -> bool { 122 auto __old = __ref_count_.load(std::memory_order_relaxed); 123 return __bit<_Bit>(static_cast<__bits_t>(__old)); 124 } 125 126 template <std::size_t _Bit> __set_bit_stdexec::__ptr::__control_block127 auto __set_bit_() noexcept -> __bits_t { 128 static_assert(_Bit < _ReservedBits, "Bit index out of range"); 129 constexpr std::size_t __mask = 1ul << _Bit; 130 auto __old = __ref_count_.fetch_or(__mask, std::memory_order_acq_rel); 131 return static_cast<__bits_t>(__old); 132 } 133 134 // Returns true if the bit was cleared, false if it was already cleared. 135 template <std::size_t _Bit> __clear_bit_stdexec::__ptr::__control_block136 auto __clear_bit_() noexcept -> __bits_t { 137 static_assert(_Bit < _ReservedBits, "Bit index out of range"); 138 constexpr std::size_t __mask = 1ul << _Bit; 139 auto __old = __ref_count_.fetch_and(~__mask, std::memory_order_acq_rel); 140 return static_cast<__bits_t>(__old); 141 } 142 }; 143 144 STDEXEC_PRAGMA_POP() 145 146 template <class _Ty, std::size_t _ReservedBits /* = 0ul */> 147 class __intrusive_ptr { 148 using _UncvTy = std::remove_cv_t<_Ty>; 149 using __enable_intrusive_t = __enable_intrusive_from_this<_UncvTy, _ReservedBits>; 150 friend _Ty; 151 friend struct __make_intrusive_t<_Ty, _ReservedBits>; 152 friend struct __enable_intrusive_from_this<_UncvTy, _ReservedBits>; 153 154 __control_block<_UncvTy, _ReservedBits>* __data_{nullptr}; 155 __intrusive_ptr(__control_block<_UncvTy,_ReservedBits> * __data)156 explicit __intrusive_ptr(__control_block<_UncvTy, _ReservedBits>* __data) noexcept 157 : __data_(__data) { 158 } 159 __inc_ref_()160 void __inc_ref_() noexcept { 161 if (__data_) { 162 __data_->__inc_ref_(); 163 } 164 } 165 __dec_ref_()166 void __dec_ref_() noexcept { 167 if (__data_) { 168 __data_->__dec_ref_(); 169 } 170 } 171 172 // For use when types want to take over manual control of the reference count. 173 // Very unsafe, but useful for implementing custom reference counting. 174 [[nodiscard]] __release_()175 auto __release_() noexcept -> __enable_intrusive_t* { 176 auto* __data = std::exchange(__data_, nullptr); 177 return __data ? &__c_upcast<__enable_intrusive_t>(__data->__value()) : nullptr; 178 } 179 180 public: 181 using element_type = _Ty; 182 183 __intrusive_ptr() = default; 184 __intrusive_ptr(__intrusive_ptr && __that)185 __intrusive_ptr(__intrusive_ptr&& __that) noexcept 186 : __data_(std::exchange(__that.__data_, nullptr)) { 187 } 188 __intrusive_ptr(const __intrusive_ptr & __that)189 __intrusive_ptr(const __intrusive_ptr& __that) noexcept 190 : __data_(__that.__data_) { 191 __inc_ref_(); 192 } 193 __intrusive_ptr(__enable_intrusive_from_this<_Ty,_ReservedBits> * __that)194 __intrusive_ptr(__enable_intrusive_from_this<_Ty, _ReservedBits>* __that) noexcept 195 : __intrusive_ptr(__that ? __that->__intrusive_from_this() : __intrusive_ptr()) { 196 } 197 operator =(__intrusive_ptr && __that)198 auto operator=(__intrusive_ptr&& __that) noexcept -> __intrusive_ptr& { 199 [[maybe_unused]] 200 __intrusive_ptr __old{std::exchange(__data_, std::exchange(__that.__data_, nullptr))}; 201 return *this; 202 } 203 operator =(const __intrusive_ptr & __that)204 auto operator=(const __intrusive_ptr& __that) noexcept -> __intrusive_ptr& { 205 return operator=(__intrusive_ptr(__that)); 206 } 207 operator =(__enable_intrusive_from_this<_Ty,_ReservedBits> * __that)208 auto operator=(__enable_intrusive_from_this<_Ty, _ReservedBits>* __that) noexcept 209 -> __intrusive_ptr& { 210 return operator=(__that ? __that->__intrusive_from_this() : __intrusive_ptr()); 211 } 212 ~__intrusive_ptr()213 ~__intrusive_ptr() { 214 __dec_ref_(); 215 } 216 reset()217 void reset() noexcept { 218 operator=({}); 219 } 220 swap(__intrusive_ptr & __that)221 void swap(__intrusive_ptr& __that) noexcept { 222 std::swap(__data_, __that.__data_); 223 } 224 get() const225 auto get() const noexcept -> _Ty* { 226 return &__data_->__value(); 227 } 228 operator ->() const229 auto operator->() const noexcept -> _Ty* { 230 return &__data_->__value(); 231 } 232 operator *() const233 auto operator*() const noexcept -> _Ty& { 234 return __data_->__value(); 235 } 236 operator bool() const237 explicit operator bool() const noexcept { 238 return __data_ != nullptr; 239 } 240 operator !() const241 auto operator!() const noexcept -> bool { 242 return __data_ == nullptr; 243 } 244 245 auto operator==(const __intrusive_ptr&) const -> bool = default; 246 operator ==(std::nullptr_t) const247 auto operator==(std::nullptr_t) const noexcept -> bool { 248 return __data_ == nullptr; 249 } 250 }; 251 252 template <class _Ty, std::size_t _ReservedBits> __intrusive_from_this()253 auto __enable_intrusive_from_this<_Ty, _ReservedBits>::__intrusive_from_this() noexcept 254 -> __intrusive_ptr<_Ty, _ReservedBits> { 255 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>( 256 &__c_downcast<_Ty>(*this)); 257 __data->__inc_ref_(); 258 return __intrusive_ptr<_Ty, _ReservedBits>{__data}; 259 } 260 261 template <class _Ty, std::size_t _ReservedBits> __intrusive_from_this() const262 auto __enable_intrusive_from_this<_Ty, _ReservedBits>::__intrusive_from_this() const noexcept 263 -> __intrusive_ptr<const _Ty, _ReservedBits> { 264 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>( 265 &__c_downcast<_Ty>(*this)); 266 __data->__inc_ref_(); 267 return __intrusive_ptr<const _Ty, _ReservedBits>{__data}; 268 } 269 270 template <class _Ty, std::size_t _ReservedBits> __inc_ref()271 auto __enable_intrusive_from_this<_Ty, _ReservedBits>::__inc_ref() noexcept 272 -> __ptr::__bits_t<_ReservedBits> { 273 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>( 274 &__c_downcast<_Ty>(*this)); 275 return __data->__inc_ref_(); 276 } 277 278 template <class _Ty, std::size_t _ReservedBits> __dec_ref()279 auto __enable_intrusive_from_this<_Ty, _ReservedBits>::__dec_ref() noexcept 280 -> __ptr::__bits_t<_ReservedBits> { 281 282 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>( 283 &__c_downcast<_Ty>(*this)); 284 return __data->__dec_ref_(); 285 } 286 287 template <class _Ty, std::size_t _ReservedBits> 288 template <std::size_t _Bit> __is_set() const289 auto __enable_intrusive_from_this<_Ty, _ReservedBits>::__is_set() const noexcept -> bool { 290 auto* __data = reinterpret_cast<const __control_block<_Ty, _ReservedBits>*>( 291 &__c_downcast<_Ty>(*this)); 292 return __data->template __is_set_<_Bit>(); 293 } 294 295 template <class _Ty, std::size_t _ReservedBits> 296 template <std::size_t _Bit> __set_bit()297 auto __enable_intrusive_from_this<_Ty, _ReservedBits>::__set_bit() noexcept 298 -> __ptr::__bits_t<_ReservedBits> { 299 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>( 300 &__c_downcast<_Ty>(*this)); 301 return __data->template __set_bit_<_Bit>(); 302 } 303 304 template <class _Ty, std::size_t _ReservedBits> 305 template <std::size_t _Bit> __clear_bit()306 auto __enable_intrusive_from_this<_Ty, _ReservedBits>::__clear_bit() noexcept 307 -> __ptr::__bits_t<_ReservedBits> { 308 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>( 309 &__c_downcast<_Ty>(*this)); 310 return __data->template __clear_bit_<_Bit>(); 311 } 312 313 template <class _Ty, std::size_t _ReservedBits> 314 struct __make_intrusive_t { 315 template <class... _Us> 316 requires constructible_from<_Ty, _Us...> operator ()stdexec::__ptr::__make_intrusive_t317 auto operator()(_Us&&... __us) const -> __intrusive_ptr<_Ty, _ReservedBits> { 318 using _UncvTy = std::remove_cv_t<_Ty>; 319 return __intrusive_ptr<_Ty, _ReservedBits>{ 320 ::new __control_block<_UncvTy, _ReservedBits>{static_cast<_Us&&>(__us)...}}; 321 } 322 }; 323 } // namespace __ptr 324 325 using __ptr::__intrusive_ptr; 326 using __ptr::__enable_intrusive_from_this; 327 template <class _Ty, std::size_t _ReservedBits = 0ul> 328 inline constexpr __ptr::__make_intrusive_t<_Ty, _ReservedBits> __make_intrusive{}; 329 330 } // namespace stdexec 331