1 /*
2  * Copyright (c) 2022 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__concepts.hpp"
19 #include "__meta.hpp"
20 
#include <atomic>
#include <cstddef>
#include <limits>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
26 
27 #if STDEXEC_TSAN()
28 #include <sanitizer/tsan_interface.h>
29 #endif
30 
31 namespace stdexec
32 {
33 namespace __ptr
34 {
/// Describes how one atomic word is split between `_ReservedBits` user flag
/// bits (the low bits) and a reference count (the remaining high bits).
template <std::size_t _ReservedBits>
struct __count_and_bits
{
    // Shifting a std::size_t by its full width or more is undefined, so the
    // flags must leave at least one bit for the count.
    static_assert(_ReservedBits < std::numeric_limits<std::size_t>::digits,
                  "Too many reserved bits");

    /// Amount added to / subtracted from the word per reference. Computed in
    /// std::size_t rather than `unsigned long`, which is only 32 bits wide on
    /// LLP64 targets.
    static constexpr std::size_t __ref_count_increment = std::size_t{1}
                                                         << _ReservedBits;

    /// Strongly typed snapshot of the whole word (count + flag bits).
    enum struct __bits : std::size_t
    {
    };

    /// Extracts the reference count from a snapshot.
    friend constexpr std::size_t __count(__bits __b) noexcept
    {
        return static_cast<std::size_t>(__b) / __ref_count_increment;
    }

    /// Returns whether flag bit `_Bit` is set in a snapshot.
    template <std::size_t _Bit>
    friend constexpr bool __bit(__bits __b) noexcept
    {
        static_assert(_Bit < _ReservedBits, "Bit index out of range");
        return (static_cast<std::size_t>(__b) & (std::size_t{1} << _Bit)) != 0;
    }
};
56 
57 template <std::size_t _ReservedBits>
58 using __bits_t = typename __count_and_bits<_ReservedBits>::__bits;
59 
60 template <class _Ty, std::size_t _ReservedBits>
61 struct __make_intrusive_t;
62 
63 template <class _Ty, std::size_t _ReservedBits = 0ul>
64 class __intrusive_ptr;
65 
66 template <class _Ty, std::size_t _ReservedBits = 0ul>
67 struct __enable_intrusive_from_this
68 {
69     auto
70         __intrusive_from_this() noexcept -> __intrusive_ptr<_Ty, _ReservedBits>;
71     auto __intrusive_from_this() const noexcept
72         -> __intrusive_ptr<const _Ty, _ReservedBits>;
73 
74   private:
75     using __bits_t = typename __count_and_bits<_ReservedBits>::__bits;
76     friend _Ty;
77     __bits_t __inc_ref() noexcept;
78     __bits_t __dec_ref() noexcept;
79 
80     template <std::size_t _Bit>
81     bool __is_set() const noexcept;
82     template <std::size_t _Bit>
83     __bits_t __set_bit() noexcept;
84     template <std::size_t _Bit>
85     __bits_t __clear_bit() noexcept;
86 };
87 
88 STDEXEC_PRAGMA_PUSH()
89 STDEXEC_PRAGMA_IGNORE_GNU("-Wtsan")
90 
91 template <class _Ty, std::size_t _ReservedBits>
92 struct __control_block
93 {
94     using __bits_t = typename __count_and_bits<_ReservedBits>::__bits;
95     static constexpr std::size_t __ref_count_increment = 1ul << _ReservedBits;
96 
97     alignas(_Ty) unsigned char __value_[sizeof(_Ty)];
98     std::atomic<std::size_t> __ref_count_;
99 
100     template <class... _Us>
__control_blockstdexec::__ptr::__control_block101     explicit __control_block(_Us&&... __us) noexcept(noexcept(_Ty{
102         __declval<_Us>()...})) : __ref_count_(__ref_count_increment)
103     {
104         // Construct the value *after* the initialization of the atomic in case
105         // the constructor of _Ty calls __intrusive_from_this() (which
106         // increments the ref count):
107         ::new (static_cast<void*>(__value_)) _Ty{static_cast<_Us&&>(__us)...};
108     }
109 
~__control_blockstdexec::__ptr::__control_block110     ~__control_block()
111     {
112         __value().~_Ty();
113     }
114 
__valuestdexec::__ptr::__control_block115     auto __value() noexcept -> _Ty&
116     {
117         return *reinterpret_cast<_Ty*>(__value_);
118     }
119 
__inc_ref_stdexec::__ptr::__control_block120     __bits_t __inc_ref_() noexcept
121     {
122         auto __old = __ref_count_.fetch_add(__ref_count_increment,
123                                             std::memory_order_relaxed);
124         return static_cast<__bits_t>(__old);
125     }
126 
__dec_ref_stdexec::__ptr::__control_block127     __bits_t __dec_ref_() noexcept
128     {
129         auto __old = __ref_count_.fetch_sub(__ref_count_increment,
130                                             std::memory_order_acq_rel);
131         if (__count(static_cast<__bits_t>(__old)) == 1)
132         {
133             delete this;
134         }
135         return static_cast<__bits_t>(__old);
136     }
137 
138     // Returns true if the bit was set, false if it was already set.
139     template <std::size_t _Bit>
__is_set_stdexec::__ptr::__control_block140     [[nodiscard]] bool __is_set_() const noexcept
141     {
142         auto __old = __ref_count_.load(std::memory_order_relaxed);
143         return __bit<_Bit>(static_cast<__bits_t>(__old));
144     }
145 
146     template <std::size_t _Bit>
__set_bit_stdexec::__ptr::__control_block147     __bits_t __set_bit_() noexcept
148     {
149         static_assert(_Bit < _ReservedBits, "Bit index out of range");
150         constexpr std::size_t __mask = 1ul << _Bit;
151         auto __old = __ref_count_.fetch_or(__mask, std::memory_order_acq_rel);
152         return static_cast<__bits_t>(__old);
153     }
154 
155     // Returns true if the bit was cleared, false if it was already cleared.
156     template <std::size_t _Bit>
__clear_bit_stdexec::__ptr::__control_block157     __bits_t __clear_bit_() noexcept
158     {
159         static_assert(_Bit < _ReservedBits, "Bit index out of range");
160         constexpr std::size_t __mask = 1ul << _Bit;
161         auto __old = __ref_count_.fetch_and(~__mask, std::memory_order_acq_rel);
162         return static_cast<__bits_t>(__old);
163     }
164 };
165 
166 STDEXEC_PRAGMA_POP()
167 
168 template <class _Ty, std::size_t _ReservedBits /* = 0ul */>
169 class __intrusive_ptr
170 {
171     using _UncvTy = std::remove_cv_t<_Ty>;
172     using __enable_intrusive_t =
173         __enable_intrusive_from_this<_UncvTy, _ReservedBits>;
174     friend _Ty;
175     friend struct __make_intrusive_t<_Ty, _ReservedBits>;
176     friend struct __enable_intrusive_from_this<_UncvTy, _ReservedBits>;
177 
178     __control_block<_UncvTy, _ReservedBits>* __data_{nullptr};
179 
__intrusive_ptr(__control_block<_UncvTy,_ReservedBits> * __data)180     explicit __intrusive_ptr(
181         __control_block<_UncvTy, _ReservedBits>* __data) noexcept :
182         __data_(__data)
183     {}
184 
__inc_ref_()185     void __inc_ref_() noexcept
186     {
187         if (__data_)
188         {
189             __data_->__inc_ref_();
190         }
191     }
192 
__dec_ref_()193     void __dec_ref_() noexcept
194     {
195         if (__data_)
196         {
197             __data_->__dec_ref_();
198         }
199     }
200 
201     // For use when types want to take over manual control of the reference
202     // count. Very unsafe, but useful for implementing custom reference
203     // counting.
__release_()204     [[nodiscard]] __enable_intrusive_t* __release_() noexcept
205     {
206         auto* __data = std::exchange(__data_, nullptr);
207         return __data ? &__c_upcast<__enable_intrusive_t>(__data->__value())
208                       : nullptr;
209     }
210 
211   public:
212     using element_type = _Ty;
213 
214     __intrusive_ptr() = default;
215 
__intrusive_ptr(__intrusive_ptr && __that)216     __intrusive_ptr(__intrusive_ptr&& __that) noexcept :
217         __data_(std::exchange(__that.__data_, nullptr))
218     {}
219 
__intrusive_ptr(const __intrusive_ptr & __that)220     __intrusive_ptr(const __intrusive_ptr& __that) noexcept :
221         __data_(__that.__data_)
222     {
223         __inc_ref_();
224     }
225 
__intrusive_ptr(__enable_intrusive_from_this<_Ty,_ReservedBits> * __that)226     __intrusive_ptr(
227         __enable_intrusive_from_this<_Ty, _ReservedBits>* __that) noexcept :
228         __intrusive_ptr(
229             __that ? __that->__intrusive_from_this() : __intrusive_ptr())
230     {}
231 
operator =(__intrusive_ptr && __that)232     auto operator=(__intrusive_ptr&& __that) noexcept -> __intrusive_ptr&
233     {
234         [[maybe_unused]] __intrusive_ptr __old{
235             std::exchange(__data_, std::exchange(__that.__data_, nullptr))};
236         return *this;
237     }
238 
operator =(const __intrusive_ptr & __that)239     auto operator=(const __intrusive_ptr& __that) noexcept -> __intrusive_ptr&
240     {
241         return operator=(__intrusive_ptr(__that));
242     }
243 
operator =(__enable_intrusive_from_this<_Ty,_ReservedBits> * __that)244     auto operator=(__enable_intrusive_from_this<_Ty, _ReservedBits>*
245                        __that) noexcept -> __intrusive_ptr&
246     {
247         return operator=(
248             __that ? __that->__intrusive_from_this() : __intrusive_ptr());
249     }
250 
~__intrusive_ptr()251     ~__intrusive_ptr()
252     {
253         __dec_ref_();
254     }
255 
reset()256     void reset() noexcept
257     {
258         operator=({});
259     }
260 
swap(__intrusive_ptr & __that)261     void swap(__intrusive_ptr& __that) noexcept
262     {
263         std::swap(__data_, __that.__data_);
264     }
265 
get() const266     auto get() const noexcept -> _Ty*
267     {
268         return &__data_->__value();
269     }
270 
operator ->() const271     auto operator->() const noexcept -> _Ty*
272     {
273         return &__data_->__value();
274     }
275 
operator *() const276     auto operator*() const noexcept -> _Ty&
277     {
278         return __data_->__value();
279     }
280 
operator bool() const281     explicit operator bool() const noexcept
282     {
283         return __data_ != nullptr;
284     }
285 
operator !() const286     auto operator!() const noexcept -> bool
287     {
288         return __data_ == nullptr;
289     }
290 
291     auto operator==(const __intrusive_ptr&) const -> bool = default;
292 
operator ==(std::nullptr_t) const293     auto operator==(std::nullptr_t) const noexcept -> bool
294     {
295         return __data_ == nullptr;
296     }
297 };
298 
299 template <class _Ty, std::size_t _ReservedBits>
300 auto __enable_intrusive_from_this<
__intrusive_from_this()301     _Ty, _ReservedBits>::__intrusive_from_this() noexcept
302     -> __intrusive_ptr<_Ty, _ReservedBits>
303 {
304     auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
305         &__c_downcast<_Ty>(*this));
306     __data->__inc_ref_();
307     return __intrusive_ptr<_Ty, _ReservedBits>{__data};
308 }
309 
310 template <class _Ty, std::size_t _ReservedBits>
__intrusive_from_this() const311 auto __enable_intrusive_from_this<_Ty, _ReservedBits>::__intrusive_from_this()
312     const noexcept -> __intrusive_ptr<const _Ty, _ReservedBits>
313 {
314     auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
315         &__c_downcast<_Ty>(*this));
316     __data->__inc_ref_();
317     return __intrusive_ptr<const _Ty, _ReservedBits>{__data};
318 }
319 
320 template <class _Ty, std::size_t _ReservedBits>
321 __bits_t<_ReservedBits>
__inc_ref()322     __enable_intrusive_from_this<_Ty, _ReservedBits>::__inc_ref() noexcept
323 {
324     auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
325         &__c_downcast<_Ty>(*this));
326     return __data->__inc_ref_();
327 }
328 
329 template <class _Ty, std::size_t _ReservedBits>
330 __bits_t<_ReservedBits>
__dec_ref()331     __enable_intrusive_from_this<_Ty, _ReservedBits>::__dec_ref() noexcept
332 {
333     auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
334         &__c_downcast<_Ty>(*this));
335     return __data->__dec_ref_();
336 }
337 
338 template <class _Ty, std::size_t _ReservedBits>
339 template <std::size_t _Bit>
__is_set() const340 bool __enable_intrusive_from_this<_Ty, _ReservedBits>::__is_set() const noexcept
341 {
342     auto* __data = reinterpret_cast<const __control_block<_Ty, _ReservedBits>*>(
343         &__c_downcast<_Ty>(*this));
344     return __data->template __is_set_<_Bit>();
345 }
346 
347 template <class _Ty, std::size_t _ReservedBits>
348 template <std::size_t _Bit>
349 __bits_t<_ReservedBits>
__set_bit()350     __enable_intrusive_from_this<_Ty, _ReservedBits>::__set_bit() noexcept
351 {
352     auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
353         &__c_downcast<_Ty>(*this));
354     return __data->template __set_bit_<_Bit>();
355 }
356 
357 template <class _Ty, std::size_t _ReservedBits>
358 template <std::size_t _Bit>
359 __bits_t<_ReservedBits>
__clear_bit()360     __enable_intrusive_from_this<_Ty, _ReservedBits>::__clear_bit() noexcept
361 {
362     auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
363         &__c_downcast<_Ty>(*this));
364     return __data->template __clear_bit_<_Bit>();
365 }
366 
367 template <class _Ty, std::size_t _ReservedBits>
368 struct __make_intrusive_t
369 {
370     template <class... _Us>
371         requires constructible_from<_Ty, _Us...>
operator ()stdexec::__ptr::__make_intrusive_t372     auto operator()(_Us&&... __us) const -> __intrusive_ptr<_Ty, _ReservedBits>
373     {
374         using _UncvTy = std::remove_cv_t<_Ty>;
375         return __intrusive_ptr<_Ty, _ReservedBits>{
376             ::new __control_block<_UncvTy, _ReservedBits>{
377                 static_cast<_Us&&>(__us)...}};
378     }
379 };
380 } // namespace __ptr
381 
382 using __ptr::__enable_intrusive_from_this;
383 using __ptr::__intrusive_ptr;
384 template <class _Ty, std::size_t _ReservedBits = 0ul>
385 inline constexpr __ptr::__make_intrusive_t<_Ty, _ReservedBits>
386     __make_intrusive{};
387 
388 } // namespace stdexec
389