1 /*
2 * Copyright (c) 2022 NVIDIA Corporation
3 *
4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5 * (the "License"); you may not use this file except in compliance with
6 * the License. You may obtain a copy of the License at
7 *
8 * https://llvm.org/LICENSE.txt
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #pragma once
17
#include "__concepts.hpp"
#include "__meta.hpp"

#include <atomic>
#include <cstddef>
#include <limits>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
26
27 #if STDEXEC_TSAN()
28 #include <sanitizer/tsan_interface.h>
29 #endif
30
31 namespace stdexec
32 {
33 namespace __ptr
34 {
// Describes how a single std::size_t word packs a reference count together
// with `_ReservedBits` low-order flag bits: the flags occupy bits
// [0, _ReservedBits) and the count occupies the bits above them.
template <std::size_t _ReservedBits>
struct __count_and_bits
{
  // At least one bit must remain for the count itself; this also keeps the
  // shift below well-defined.
  static_assert(_ReservedBits < std::numeric_limits<std::size_t>::digits,
                "Too many reserved bits for a std::size_t reference count");

  // Value added to / subtracted from the packed word for one reference.
  // The shift is done in std::size_t rather than `unsigned long`, because the
  // latter is narrower than std::size_t on some targets (e.g. LLP64).
  static constexpr std::size_t __ref_count_increment = std::size_t{1}
                                                       << _ReservedBits;

  // Opaque snapshot of a packed word (count + flag bits).
  enum struct __bits : std::size_t
  {
  };

  // Extracts the reference count from a packed snapshot.
  friend constexpr std::size_t __count(__bits __b) noexcept
  {
    return static_cast<std::size_t>(__b) / __ref_count_increment;
  }

  // Reports whether reserved flag bit `_Bit` is set in a packed snapshot.
  template <std::size_t _Bit>
  friend constexpr bool __bit(__bits __b) noexcept
  {
    static_assert(_Bit < _ReservedBits, "Bit index out of range");
    return (static_cast<std::size_t>(__b) & (std::size_t{1} << _Bit)) != 0;
  }
};
56
57 template <std::size_t _ReservedBits>
58 using __bits_t = typename __count_and_bits<_ReservedBits>::__bits;
59
// Forward declarations: the smart pointer, its factory and
// __enable_intrusive_from_this all name one another before being defined.
// The default for the reserved-bit count is established here.
template <class _Tp, std::size_t _Nbits = 0ul>
class __intrusive_ptr;

template <class _Tp, std::size_t _Nbits>
struct __make_intrusive_t;
65
66 template <class _Ty, std::size_t _ReservedBits = 0ul>
67 struct __enable_intrusive_from_this
68 {
69 auto
70 __intrusive_from_this() noexcept -> __intrusive_ptr<_Ty, _ReservedBits>;
71 auto __intrusive_from_this() const noexcept
72 -> __intrusive_ptr<const _Ty, _ReservedBits>;
73
74 private:
75 using __bits_t = typename __count_and_bits<_ReservedBits>::__bits;
76 friend _Ty;
77 __bits_t __inc_ref() noexcept;
78 __bits_t __dec_ref() noexcept;
79
80 template <std::size_t _Bit>
81 bool __is_set() const noexcept;
82 template <std::size_t _Bit>
83 __bits_t __set_bit() noexcept;
84 template <std::size_t _Bit>
85 __bits_t __clear_bit() noexcept;
86 };
87
88 STDEXEC_PRAGMA_PUSH()
89 STDEXEC_PRAGMA_IGNORE_GNU("-Wtsan")
90
91 template <class _Ty, std::size_t _ReservedBits>
92 struct __control_block
93 {
94 using __bits_t = typename __count_and_bits<_ReservedBits>::__bits;
95 static constexpr std::size_t __ref_count_increment = 1ul << _ReservedBits;
96
97 alignas(_Ty) unsigned char __value_[sizeof(_Ty)];
98 std::atomic<std::size_t> __ref_count_;
99
100 template <class... _Us>
__control_blockstdexec::__ptr::__control_block101 explicit __control_block(_Us&&... __us) noexcept(noexcept(_Ty{
102 __declval<_Us>()...})) : __ref_count_(__ref_count_increment)
103 {
104 // Construct the value *after* the initialization of the atomic in case
105 // the constructor of _Ty calls __intrusive_from_this() (which
106 // increments the ref count):
107 ::new (static_cast<void*>(__value_)) _Ty{static_cast<_Us&&>(__us)...};
108 }
109
~__control_blockstdexec::__ptr::__control_block110 ~__control_block()
111 {
112 __value().~_Ty();
113 }
114
__valuestdexec::__ptr::__control_block115 auto __value() noexcept -> _Ty&
116 {
117 return *reinterpret_cast<_Ty*>(__value_);
118 }
119
__inc_ref_stdexec::__ptr::__control_block120 __bits_t __inc_ref_() noexcept
121 {
122 auto __old = __ref_count_.fetch_add(__ref_count_increment,
123 std::memory_order_relaxed);
124 return static_cast<__bits_t>(__old);
125 }
126
__dec_ref_stdexec::__ptr::__control_block127 __bits_t __dec_ref_() noexcept
128 {
129 auto __old = __ref_count_.fetch_sub(__ref_count_increment,
130 std::memory_order_acq_rel);
131 if (__count(static_cast<__bits_t>(__old)) == 1)
132 {
133 delete this;
134 }
135 return static_cast<__bits_t>(__old);
136 }
137
138 // Returns true if the bit was set, false if it was already set.
139 template <std::size_t _Bit>
__is_set_stdexec::__ptr::__control_block140 [[nodiscard]] bool __is_set_() const noexcept
141 {
142 auto __old = __ref_count_.load(std::memory_order_relaxed);
143 return __bit<_Bit>(static_cast<__bits_t>(__old));
144 }
145
146 template <std::size_t _Bit>
__set_bit_stdexec::__ptr::__control_block147 __bits_t __set_bit_() noexcept
148 {
149 static_assert(_Bit < _ReservedBits, "Bit index out of range");
150 constexpr std::size_t __mask = 1ul << _Bit;
151 auto __old = __ref_count_.fetch_or(__mask, std::memory_order_acq_rel);
152 return static_cast<__bits_t>(__old);
153 }
154
155 // Returns true if the bit was cleared, false if it was already cleared.
156 template <std::size_t _Bit>
__clear_bit_stdexec::__ptr::__control_block157 __bits_t __clear_bit_() noexcept
158 {
159 static_assert(_Bit < _ReservedBits, "Bit index out of range");
160 constexpr std::size_t __mask = 1ul << _Bit;
161 auto __old = __ref_count_.fetch_and(~__mask, std::memory_order_acq_rel);
162 return static_cast<__bits_t>(__old);
163 }
164 };
165
166 STDEXEC_PRAGMA_POP()
167
168 template <class _Ty, std::size_t _ReservedBits /* = 0ul */>
169 class __intrusive_ptr
170 {
171 using _UncvTy = std::remove_cv_t<_Ty>;
172 using __enable_intrusive_t =
173 __enable_intrusive_from_this<_UncvTy, _ReservedBits>;
174 friend _Ty;
175 friend struct __make_intrusive_t<_Ty, _ReservedBits>;
176 friend struct __enable_intrusive_from_this<_UncvTy, _ReservedBits>;
177
178 __control_block<_UncvTy, _ReservedBits>* __data_{nullptr};
179
__intrusive_ptr(__control_block<_UncvTy,_ReservedBits> * __data)180 explicit __intrusive_ptr(
181 __control_block<_UncvTy, _ReservedBits>* __data) noexcept :
182 __data_(__data)
183 {}
184
__inc_ref_()185 void __inc_ref_() noexcept
186 {
187 if (__data_)
188 {
189 __data_->__inc_ref_();
190 }
191 }
192
__dec_ref_()193 void __dec_ref_() noexcept
194 {
195 if (__data_)
196 {
197 __data_->__dec_ref_();
198 }
199 }
200
201 // For use when types want to take over manual control of the reference
202 // count. Very unsafe, but useful for implementing custom reference
203 // counting.
__release_()204 [[nodiscard]] __enable_intrusive_t* __release_() noexcept
205 {
206 auto* __data = std::exchange(__data_, nullptr);
207 return __data ? &__c_upcast<__enable_intrusive_t>(__data->__value())
208 : nullptr;
209 }
210
211 public:
212 using element_type = _Ty;
213
214 __intrusive_ptr() = default;
215
__intrusive_ptr(__intrusive_ptr && __that)216 __intrusive_ptr(__intrusive_ptr&& __that) noexcept :
217 __data_(std::exchange(__that.__data_, nullptr))
218 {}
219
__intrusive_ptr(const __intrusive_ptr & __that)220 __intrusive_ptr(const __intrusive_ptr& __that) noexcept :
221 __data_(__that.__data_)
222 {
223 __inc_ref_();
224 }
225
__intrusive_ptr(__enable_intrusive_from_this<_Ty,_ReservedBits> * __that)226 __intrusive_ptr(
227 __enable_intrusive_from_this<_Ty, _ReservedBits>* __that) noexcept :
228 __intrusive_ptr(
229 __that ? __that->__intrusive_from_this() : __intrusive_ptr())
230 {}
231
operator =(__intrusive_ptr && __that)232 auto operator=(__intrusive_ptr&& __that) noexcept -> __intrusive_ptr&
233 {
234 [[maybe_unused]] __intrusive_ptr __old{
235 std::exchange(__data_, std::exchange(__that.__data_, nullptr))};
236 return *this;
237 }
238
operator =(const __intrusive_ptr & __that)239 auto operator=(const __intrusive_ptr& __that) noexcept -> __intrusive_ptr&
240 {
241 return operator=(__intrusive_ptr(__that));
242 }
243
operator =(__enable_intrusive_from_this<_Ty,_ReservedBits> * __that)244 auto operator=(__enable_intrusive_from_this<_Ty, _ReservedBits>*
245 __that) noexcept -> __intrusive_ptr&
246 {
247 return operator=(
248 __that ? __that->__intrusive_from_this() : __intrusive_ptr());
249 }
250
~__intrusive_ptr()251 ~__intrusive_ptr()
252 {
253 __dec_ref_();
254 }
255
reset()256 void reset() noexcept
257 {
258 operator=({});
259 }
260
swap(__intrusive_ptr & __that)261 void swap(__intrusive_ptr& __that) noexcept
262 {
263 std::swap(__data_, __that.__data_);
264 }
265
get() const266 auto get() const noexcept -> _Ty*
267 {
268 return &__data_->__value();
269 }
270
operator ->() const271 auto operator->() const noexcept -> _Ty*
272 {
273 return &__data_->__value();
274 }
275
operator *() const276 auto operator*() const noexcept -> _Ty&
277 {
278 return __data_->__value();
279 }
280
operator bool() const281 explicit operator bool() const noexcept
282 {
283 return __data_ != nullptr;
284 }
285
operator !() const286 auto operator!() const noexcept -> bool
287 {
288 return __data_ == nullptr;
289 }
290
291 auto operator==(const __intrusive_ptr&) const -> bool = default;
292
operator ==(std::nullptr_t) const293 auto operator==(std::nullptr_t) const noexcept -> bool
294 {
295 return __data_ == nullptr;
296 }
297 };
298
299 template <class _Ty, std::size_t _ReservedBits>
300 auto __enable_intrusive_from_this<
__intrusive_from_this()301 _Ty, _ReservedBits>::__intrusive_from_this() noexcept
302 -> __intrusive_ptr<_Ty, _ReservedBits>
303 {
304 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
305 &__c_downcast<_Ty>(*this));
306 __data->__inc_ref_();
307 return __intrusive_ptr<_Ty, _ReservedBits>{__data};
308 }
309
310 template <class _Ty, std::size_t _ReservedBits>
__intrusive_from_this() const311 auto __enable_intrusive_from_this<_Ty, _ReservedBits>::__intrusive_from_this()
312 const noexcept -> __intrusive_ptr<const _Ty, _ReservedBits>
313 {
314 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
315 &__c_downcast<_Ty>(*this));
316 __data->__inc_ref_();
317 return __intrusive_ptr<const _Ty, _ReservedBits>{__data};
318 }
319
320 template <class _Ty, std::size_t _ReservedBits>
321 __bits_t<_ReservedBits>
__inc_ref()322 __enable_intrusive_from_this<_Ty, _ReservedBits>::__inc_ref() noexcept
323 {
324 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
325 &__c_downcast<_Ty>(*this));
326 return __data->__inc_ref_();
327 }
328
329 template <class _Ty, std::size_t _ReservedBits>
330 __bits_t<_ReservedBits>
__dec_ref()331 __enable_intrusive_from_this<_Ty, _ReservedBits>::__dec_ref() noexcept
332 {
333 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
334 &__c_downcast<_Ty>(*this));
335 return __data->__dec_ref_();
336 }
337
338 template <class _Ty, std::size_t _ReservedBits>
339 template <std::size_t _Bit>
__is_set() const340 bool __enable_intrusive_from_this<_Ty, _ReservedBits>::__is_set() const noexcept
341 {
342 auto* __data = reinterpret_cast<const __control_block<_Ty, _ReservedBits>*>(
343 &__c_downcast<_Ty>(*this));
344 return __data->template __is_set_<_Bit>();
345 }
346
347 template <class _Ty, std::size_t _ReservedBits>
348 template <std::size_t _Bit>
349 __bits_t<_ReservedBits>
__set_bit()350 __enable_intrusive_from_this<_Ty, _ReservedBits>::__set_bit() noexcept
351 {
352 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
353 &__c_downcast<_Ty>(*this));
354 return __data->template __set_bit_<_Bit>();
355 }
356
357 template <class _Ty, std::size_t _ReservedBits>
358 template <std::size_t _Bit>
359 __bits_t<_ReservedBits>
__clear_bit()360 __enable_intrusive_from_this<_Ty, _ReservedBits>::__clear_bit() noexcept
361 {
362 auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
363 &__c_downcast<_Ty>(*this));
364 return __data->template __clear_bit_<_Bit>();
365 }
366
367 template <class _Ty, std::size_t _ReservedBits>
368 struct __make_intrusive_t
369 {
370 template <class... _Us>
371 requires constructible_from<_Ty, _Us...>
operator ()stdexec::__ptr::__make_intrusive_t372 auto operator()(_Us&&... __us) const -> __intrusive_ptr<_Ty, _ReservedBits>
373 {
374 using _UncvTy = std::remove_cv_t<_Ty>;
375 return __intrusive_ptr<_Ty, _ReservedBits>{
376 ::new __control_block<_UncvTy, _ReservedBits>{
377 static_cast<_Us&&>(__us)...}};
378 }
379 };
380 } // namespace __ptr
381
382 using __ptr::__enable_intrusive_from_this;
383 using __ptr::__intrusive_ptr;
384 template <class _Ty, std::size_t _ReservedBits = 0ul>
385 inline constexpr __ptr::__make_intrusive_t<_Ty, _ReservedBits>
386 __make_intrusive{};
387
388 } // namespace stdexec
389