1 /*
2 * Copyright (c) 2022 NVIDIA Corporation
3 *
4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5 * (the "License"); you may not use this file except in compliance with
6 * the License. You may obtain a copy of the License at
7 *
8 * https://llvm.org/LICENSE.txt
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #pragma once
17
#include "__concepts.hpp"
#include "__meta.hpp"

#include <atomic>
#include <cstddef>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
26
27 #if STDEXEC_TSAN()
28 #include <sanitizer/tsan_interface.h>
29 #endif
30
31 namespace stdexec
32 {
33 namespace __ptr
34 {
/// Helpers for a single word that packs a reference count together with
/// `_ReservedBits` low-order flag bits.
///
/// The flag bits occupy positions [0, _ReservedBits); the reference count
/// occupies the remaining high-order bits, so adding one reference means
/// adding `__ref_count_increment` to the word.
template <std::size_t _ReservedBits>
struct __count_and_bits
{
    // Built from std::size_t{1} rather than 1ul: `unsigned long` can be
    // narrower than `std::size_t` (e.g. on 64-bit Windows), in which case a
    // shift of an `unsigned long` would be performed in the wrong width.
    static constexpr std::size_t __ref_count_increment = std::size_t{1}
                                                         << _ReservedBits;

    /// Opaque snapshot of the packed word. Decode it with __count() and
    /// __bit<N>(), both of which are found by argument-dependent lookup.
    enum struct __bits : std::size_t
    {
    };

    /// Extracts the reference count stored above the reserved flag bits.
    friend constexpr std::size_t __count(__bits __b) noexcept
    {
        return static_cast<std::size_t>(__b) / __ref_count_increment;
    }

    /// Reports whether reserved flag bit `_Bit` is set in the snapshot.
    template <std::size_t _Bit>
    friend constexpr bool __bit(__bits __b) noexcept
    {
        static_assert(_Bit < _ReservedBits, "Bit index out of range");
        return (static_cast<std::size_t>(__b) & (std::size_t{1} << _Bit)) != 0;
    }
};
56
57 template <std::size_t _ReservedBits>
58 using __bits_t = typename __count_and_bits<_ReservedBits>::__bits;
59
// Forward declarations so that __enable_intrusive_from_this and
// __intrusive_ptr can name each other (and the factory) before they are
// defined. The default for _ReservedBits is specified here, once.
template <class _Ty, std::size_t _ReservedBits = 0ul>
class __intrusive_ptr;

template <class _Ty, std::size_t _ReservedBits>
struct __make_intrusive_t;
65
66 template <class _Ty, std::size_t _ReservedBits = 0ul>
67 struct __enable_intrusive_from_this
68 {
69 auto __intrusive_from_this() noexcept
70 -> __intrusive_ptr<_Ty, _ReservedBits>;
71 auto __intrusive_from_this() const noexcept
72 -> __intrusive_ptr<const _Ty, _ReservedBits>;
73
74 private:
75 using __bits_t = typename __count_and_bits<_ReservedBits>::__bits;
76 friend _Ty;
77 __bits_t __inc_ref() noexcept;
78 __bits_t __dec_ref() noexcept;
79
80 template <std::size_t _Bit>
81 bool __is_set() const noexcept;
82 template <std::size_t _Bit>
83 __bits_t __set_bit() noexcept;
84 template <std::size_t _Bit>
85 __bits_t __clear_bit() noexcept;
86 };
87
88 STDEXEC_PRAGMA_PUSH()
89 STDEXEC_PRAGMA_IGNORE_GNU("-Wtsan")
90
91 template <class _Ty, std::size_t _ReservedBits>
92 struct __control_block
93 {
94 using __bits_t = typename __count_and_bits<_ReservedBits>::__bits;
95 static constexpr std::size_t __ref_count_increment = 1ul << _ReservedBits;
96
97 alignas(_Ty) unsigned char __value_[sizeof(_Ty)];
98 std::atomic<std::size_t> __ref_count_;
99
100 template <class... _Us>
__control_blockstdexec::__ptr::__control_block101 explicit __control_block(_Us&&... __us) noexcept(noexcept(_Ty{
102 __declval<_Us>()...})) : __ref_count_(__ref_count_increment)
103 {
104 // Construct the value *after* the initialization of the atomic in case
105 // the constructor of _Ty calls __intrusive_from_this() (which
106 // increments the ref count):
107 ::new (static_cast<void*>(__value_)) _Ty{static_cast<_Us&&>(__us)...};
108 }
109
~__control_blockstdexec::__ptr::__control_block110 ~__control_block()
111 {
112 __value().~_Ty();
113 }
114
__valuestdexec::__ptr::__control_block115 auto __value() noexcept -> _Ty&
116 {
117 return *reinterpret_cast<_Ty*>(__value_);
118 }
119
__inc_ref_stdexec::__ptr::__control_block120 __bits_t __inc_ref_() noexcept
121 {
122 auto __old = __ref_count_.fetch_add(__ref_count_increment,
123 std::memory_order_relaxed);
124 return static_cast<__bits_t>(__old);
125 }
126
__dec_ref_stdexec::__ptr::__control_block127 __bits_t __dec_ref_() noexcept
128 {
129 auto __old = __ref_count_.fetch_sub(__ref_count_increment,
130 std::memory_order_acq_rel);
131 if (__count(static_cast<__bits_t>(__old)) == 1)
132 {
133 delete this;
134 }
135 return static_cast<__bits_t>(__old);
136 }
137
138 // Returns true if the bit was set, false if it was already set.
139 template <std::size_t _Bit>
__is_set_stdexec::__ptr::__control_block140 [[nodiscard]] bool __is_set_() const noexcept
141 {
142 auto __old = __ref_count_.load(std::memory_order_relaxed);
143 return __bit<_Bit>(static_cast<__bits_t>(__old));
144 }
145
146 template <std::size_t _Bit>
__set_bit_stdexec::__ptr::__control_block147 __bits_t __set_bit_() noexcept
148 {
149 static_assert(_Bit < _ReservedBits, "Bit index out of range");
150 constexpr std::size_t __mask = 1ul << _Bit;
151 auto __old = __ref_count_.fetch_or(__mask, std::memory_order_acq_rel);
152 return static_cast<__bits_t>(__old);
153 }
154
155 // Returns true if the bit was cleared, false if it was already cleared.
156 template <std::size_t _Bit>
__clear_bit_stdexec::__ptr::__control_block157 __bits_t __clear_bit_() noexcept
158 {
159 static_assert(_Bit < _ReservedBits, "Bit index out of range");
160 constexpr std::size_t __mask = 1ul << _Bit;
161 auto __old = __ref_count_.fetch_and(~__mask, std::memory_order_acq_rel);
162 return static_cast<__bits_t>(__old);
163 }
164 };
165
166 STDEXEC_PRAGMA_POP()
167
168 template <class _Ty, std::size_t _ReservedBits /* = 0ul */>
169 class __intrusive_ptr
170 {
171 using _UncvTy = std::remove_cv_t<_Ty>;
172 using __enable_intrusive_t =
173 __enable_intrusive_from_this<_UncvTy, _ReservedBits>;
174 friend _Ty;
175 friend struct __make_intrusive_t<_Ty, _ReservedBits>;
176 friend struct __enable_intrusive_from_this<_UncvTy, _ReservedBits>;
177
178 __control_block<_UncvTy, _ReservedBits>* __data_{nullptr};
179
__intrusive_ptr(__control_block<_UncvTy,_ReservedBits> * __data)180 explicit __intrusive_ptr(
181 __control_block<_UncvTy, _ReservedBits>* __data) noexcept :
182 __data_(__data)
183 {}
184
__inc_ref_()185 void __inc_ref_() noexcept
186 {
187 if (__data_)
188 {
189 __data_->__inc_ref_();
190 }
191 }
192
__dec_ref_()193 void __dec_ref_() noexcept
194 {
195 if (__data_)
196 {
197 __data_->__dec_ref_();
198 }
199 }
200
201 // For use when types want to take over manual control of the reference
202 // count. Very unsafe, but useful for implementing custom reference
203 // counting.
__release_()204 [[nodiscard]] __enable_intrusive_t* __release_() noexcept
205 {
206 auto* __data = std::exchange(__data_, nullptr);
207 return __data ? &__c_upcast<__enable_intrusive_t>(__data->__value())
208 : nullptr;
209 }
210
211 public:
212 using element_type = _Ty;
213
214 __intrusive_ptr() = default;
215
__intrusive_ptr(__intrusive_ptr && __that)216 __intrusive_ptr(__intrusive_ptr&& __that) noexcept :
217 __data_(std::exchange(__that.__data_, nullptr))
218 {}
219
__intrusive_ptr(const __intrusive_ptr & __that)220 __intrusive_ptr(const __intrusive_ptr& __that) noexcept :
221 __data_(__that.__data_)
222 {
223 __inc_ref_();
224 }
225
__intrusive_ptr(__enable_intrusive_from_this<_Ty,_ReservedBits> * __that)226 __intrusive_ptr(
227 __enable_intrusive_from_this<_Ty, _ReservedBits>* __that) noexcept :
228 __intrusive_ptr(
229 __that ? __that->__intrusive_from_this() : __intrusive_ptr())
230 {}
231
operator =(__intrusive_ptr && __that)232 auto operator=(__intrusive_ptr&& __that) noexcept -> __intrusive_ptr&
233 {
234 [[maybe_unused]] __intrusive_ptr __old{
235 std::exchange(__data_, std::exchange(__that.__data_, nullptr))};
236 return *this;
237 }
238
operator =(const __intrusive_ptr & __that)239 auto operator=(const __intrusive_ptr& __that) noexcept -> __intrusive_ptr&
240 {
241 return operator=(__intrusive_ptr(__that));
242 }
243
operator =(__enable_intrusive_from_this<_Ty,_ReservedBits> * __that)244 auto operator=(
245 __enable_intrusive_from_this<_Ty, _ReservedBits>* __that) noexcept
246 -> __intrusive_ptr&
247 {
248 return operator=(
249 __that ? __that->__intrusive_from_this() : __intrusive_ptr());
250 }
251
~__intrusive_ptr()252 ~__intrusive_ptr()
253 {
254 __dec_ref_();
255 }
256
reset()257 void reset() noexcept
258 {
259 operator=({});
260 }
261
swap(__intrusive_ptr & __that)262 void swap(__intrusive_ptr& __that) noexcept
263 {
264 std::swap(__data_, __that.__data_);
265 }
266
get() const267 auto get() const noexcept -> _Ty*
268 {
269 return &__data_->__value();
270 }
271
operator ->() const272 auto operator->() const noexcept -> _Ty*
273 {
274 return &__data_->__value();
275 }
276
operator *() const277 auto operator*() const noexcept -> _Ty&
278 {
279 return __data_->__value();
280 }
281
operator bool() const282 explicit operator bool() const noexcept
283 {
284 return __data_ != nullptr;
285 }
286
operator !() const287 auto operator!() const noexcept -> bool
288 {
289 return __data_ == nullptr;
290 }
291
292 auto operator==(const __intrusive_ptr&) const -> bool = default;
293
operator ==(std::nullptr_t) const294 auto operator==(std::nullptr_t) const noexcept -> bool
295 {
296 return __data_ == nullptr;
297 }
298 };
299
300 template <class _Ty, std::size_t _ReservedBits>
301 auto __enable_intrusive_from_this<
__intrusive_from_this()302 _Ty, _ReservedBits>::__intrusive_from_this() noexcept
303 -> __intrusive_ptr<_Ty, _ReservedBits>
304 {
305 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
306 &__c_downcast<_Ty>(*this));
307 __data->__inc_ref_();
308 return __intrusive_ptr<_Ty, _ReservedBits>{__data};
309 }
310
311 template <class _Ty, std::size_t _ReservedBits>
__intrusive_from_this() const312 auto __enable_intrusive_from_this<_Ty, _ReservedBits>::__intrusive_from_this()
313 const noexcept -> __intrusive_ptr<const _Ty, _ReservedBits>
314 {
315 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
316 &__c_downcast<_Ty>(*this));
317 __data->__inc_ref_();
318 return __intrusive_ptr<const _Ty, _ReservedBits>{__data};
319 }
320
321 template <class _Ty, std::size_t _ReservedBits>
322 __bits_t<_ReservedBits>
__inc_ref()323 __enable_intrusive_from_this<_Ty, _ReservedBits>::__inc_ref() noexcept
324 {
325 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
326 &__c_downcast<_Ty>(*this));
327 return __data->__inc_ref_();
328 }
329
330 template <class _Ty, std::size_t _ReservedBits>
331 __bits_t<_ReservedBits>
__dec_ref()332 __enable_intrusive_from_this<_Ty, _ReservedBits>::__dec_ref() noexcept
333 {
334 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
335 &__c_downcast<_Ty>(*this));
336 return __data->__dec_ref_();
337 }
338
339 template <class _Ty, std::size_t _ReservedBits>
340 template <std::size_t _Bit>
__is_set() const341 bool __enable_intrusive_from_this<_Ty, _ReservedBits>::__is_set() const noexcept
342 {
343 auto* __data = reinterpret_cast<const __control_block<_Ty, _ReservedBits>*>(
344 &__c_downcast<_Ty>(*this));
345 return __data->template __is_set_<_Bit>();
346 }
347
348 template <class _Ty, std::size_t _ReservedBits>
349 template <std::size_t _Bit>
350 __bits_t<_ReservedBits>
__set_bit()351 __enable_intrusive_from_this<_Ty, _ReservedBits>::__set_bit() noexcept
352 {
353 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
354 &__c_downcast<_Ty>(*this));
355 return __data->template __set_bit_<_Bit>();
356 }
357
358 template <class _Ty, std::size_t _ReservedBits>
359 template <std::size_t _Bit>
360 __bits_t<_ReservedBits>
__clear_bit()361 __enable_intrusive_from_this<_Ty, _ReservedBits>::__clear_bit() noexcept
362 {
363 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
364 &__c_downcast<_Ty>(*this));
365 return __data->template __clear_bit_<_Bit>();
366 }
367
368 template <class _Ty, std::size_t _ReservedBits>
369 struct __make_intrusive_t
370 {
371 template <class... _Us>
372 requires constructible_from<_Ty, _Us...>
operator ()stdexec::__ptr::__make_intrusive_t373 auto operator()(_Us&&... __us) const -> __intrusive_ptr<_Ty, _ReservedBits>
374 {
375 using _UncvTy = std::remove_cv_t<_Ty>;
376 return __intrusive_ptr<_Ty, _ReservedBits>{
377 ::new __control_block<_UncvTy, _ReservedBits>{
378 static_cast<_Us&&>(__us)...}};
379 }
380 };
381 } // namespace __ptr
382
383 using __ptr::__enable_intrusive_from_this;
384 using __ptr::__intrusive_ptr;
385 template <class _Ty, std::size_t _ReservedBits = 0ul>
386 inline constexpr __ptr::__make_intrusive_t<_Ty, _ReservedBits>
387 __make_intrusive{};
388
389 } // namespace stdexec
390