1 /*
2  * Copyright (c) 2022 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__concepts.hpp"
19 #include "__meta.hpp"
20 
#include <atomic>
#include <cstddef>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
26 
27 #if STDEXEC_TSAN()
28 #include <sanitizer/tsan_interface.h>
29 #endif
30 
31 namespace stdexec
32 {
33 namespace __ptr
34 {
/// Packs a reference count together with `_ReservedBits` low-order flag bits
/// into a single `std::size_t` word.
///
/// The flag bits occupy bit positions `[0, _ReservedBits)`; the reference
/// count lives in the remaining high-order bits, so adding
/// `__ref_count_increment` to the word raises the count by one without
/// touching the flags.
template <std::size_t _ReservedBits>
struct __count_and_bits
{
    /// The value that represents "one reference" inside the packed word.
    /// Built from `std::size_t{1}` (not `1ul`) so that the shift is performed
    /// at the full width of `std::size_t` even on targets where
    /// `unsigned long` is narrower than `std::size_t`.
    static constexpr std::size_t __ref_count_increment =
        std::size_t{1} << _ReservedBits;

    /// Opaque snapshot of the packed word (reference count + flag bits).
    enum struct __bits : std::size_t
    {
    };

    /// Extracts the reference count from a packed snapshot.
    friend constexpr std::size_t __count(__bits __b) noexcept
    {
        return static_cast<std::size_t>(__b) / __ref_count_increment;
    }

    /// Reports whether flag bit number `_Bit` is set in a packed snapshot.
    /// `_Bit` must be smaller than `_ReservedBits`.
    template <std::size_t _Bit>
    friend constexpr bool __bit(__bits __b) noexcept
    {
        static_assert(_Bit < _ReservedBits, "Bit index out of range");
        return (static_cast<std::size_t>(__b) & (std::size_t{1} << _Bit)) != 0;
    }
};
56 
// Shorthand for the packed "count + flag bits" snapshot type that belongs to a
// given number of reserved bits.
template <std::size_t _ReservedBits>
using __bits_t = typename __count_and_bits<_ReservedBits>::__bits;

// Forward declaration of the factory function object; defined further down.
template <class _Ty, std::size_t _ReservedBits>
struct __make_intrusive_t;

// Forward declaration of the smart pointer. The default of zero reserved bits
// is supplied here, which is why the definition only repeats it in a comment.
template <class _Ty, std::size_t _ReservedBits = 0ul>
class __intrusive_ptr;
65 
// Base class meant to be inherited by `_Ty` itself (the member definitions
// below downcast `*this` to `_Ty`). It gives the object access to the control
// block it is stored in: it can mint new `__intrusive_ptr`s to itself and,
// privately, adjust its own reference count and reserved flag bits.
template <class _Ty, std::size_t _ReservedBits = 0ul>
struct __enable_intrusive_from_this
{
    // Returns a new owning pointer to this object; the reference count is
    // incremented as part of the call.
    auto __intrusive_from_this() noexcept
        -> __intrusive_ptr<_Ty, _ReservedBits>;
    // Const overload: same as above, but the returned pointer refers to a
    // `const _Ty`.
    auto __intrusive_from_this() const noexcept
        -> __intrusive_ptr<const _Ty, _ReservedBits>;

  private:
    using __bits_t = typename __count_and_bits<_ReservedBits>::__bits;
    // Only the derived type may reach the manual ref-count / flag-bit API.
    friend _Ty;
    // Manually add / drop one reference. Both return the packed word as it was
    // before the update.
    __bits_t __inc_ref() noexcept;
    __bits_t __dec_ref() noexcept;

    // Reports whether reserved flag bit `_Bit` is currently set.
    template <std::size_t _Bit>
    bool __is_set() const noexcept;
    // Atomically sets reserved flag bit `_Bit`; returns the previous packed
    // word.
    template <std::size_t _Bit>
    __bits_t __set_bit() noexcept;
    // Atomically clears reserved flag bit `_Bit`; returns the previous packed
    // word.
    template <std::size_t _Bit>
    __bits_t __clear_bit() noexcept;
};
87 
88 STDEXEC_PRAGMA_PUSH()
89 STDEXEC_PRAGMA_IGNORE_GNU("-Wtsan")
90 
91 template <class _Ty, std::size_t _ReservedBits>
92 struct __control_block
93 {
94     using __bits_t = typename __count_and_bits<_ReservedBits>::__bits;
95     static constexpr std::size_t __ref_count_increment = 1ul << _ReservedBits;
96 
97     alignas(_Ty) unsigned char __value_[sizeof(_Ty)];
98     std::atomic<std::size_t> __ref_count_;
99 
100     template <class... _Us>
101     explicit __control_block(_Us&&... __us) noexcept(noexcept(_Ty{
102         __declval<_Us>()...})) :
103         __ref_count_(__ref_count_increment)
104     {
105         // Construct the value *after* the initialization of the atomic in case
106         // the constructor of _Ty calls __intrusive_from_this() (which
107         // increments the ref count):
108         ::new (static_cast<void*>(__value_)) _Ty{static_cast<_Us&&>(__us)...};
109     }
110 
111     ~__control_block()
112     {
113         __value().~_Ty();
114     }
115 
116     auto __value() noexcept -> _Ty&
117     {
118         return *reinterpret_cast<_Ty*>(__value_);
119     }
120 
121     __bits_t __inc_ref_() noexcept
122     {
123         auto __old = __ref_count_.fetch_add(__ref_count_increment,
124                                             std::memory_order_relaxed);
125         return static_cast<__bits_t>(__old);
126     }
127 
128     __bits_t __dec_ref_() noexcept
129     {
130         auto __old = __ref_count_.fetch_sub(__ref_count_increment,
131                                             std::memory_order_acq_rel);
132         if (__count(static_cast<__bits_t>(__old)) == 1)
133         {
134             delete this;
135         }
136         return static_cast<__bits_t>(__old);
137     }
138 
139     // Returns true if the bit was set, false if it was already set.
140     template <std::size_t _Bit>
141     [[nodiscard]] bool __is_set_() const noexcept
142     {
143         auto __old = __ref_count_.load(std::memory_order_relaxed);
144         return __bit<_Bit>(static_cast<__bits_t>(__old));
145     }
146 
147     template <std::size_t _Bit>
148     __bits_t __set_bit_() noexcept
149     {
150         static_assert(_Bit < _ReservedBits, "Bit index out of range");
151         constexpr std::size_t __mask = 1ul << _Bit;
152         auto __old = __ref_count_.fetch_or(__mask, std::memory_order_acq_rel);
153         return static_cast<__bits_t>(__old);
154     }
155 
156     // Returns true if the bit was cleared, false if it was already cleared.
157     template <std::size_t _Bit>
158     __bits_t __clear_bit_() noexcept
159     {
160         static_assert(_Bit < _ReservedBits, "Bit index out of range");
161         constexpr std::size_t __mask = 1ul << _Bit;
162         auto __old = __ref_count_.fetch_and(~__mask, std::memory_order_acq_rel);
163         return static_cast<__bits_t>(__old);
164     }
165 };
166 
167 STDEXEC_PRAGMA_POP()
168 
169 template <class _Ty, std::size_t _ReservedBits /* = 0ul */>
170 class __intrusive_ptr
171 {
172     using _UncvTy = std::remove_cv_t<_Ty>;
173     using __enable_intrusive_t =
174         __enable_intrusive_from_this<_UncvTy, _ReservedBits>;
175     friend _Ty;
176     friend struct __make_intrusive_t<_Ty, _ReservedBits>;
177     friend struct __enable_intrusive_from_this<_UncvTy, _ReservedBits>;
178 
179     __control_block<_UncvTy, _ReservedBits>* __data_{nullptr};
180 
181     explicit __intrusive_ptr(
182         __control_block<_UncvTy, _ReservedBits>* __data) noexcept :
183         __data_(__data)
184     {}
185 
186     void __inc_ref_() noexcept
187     {
188         if (__data_)
189         {
190             __data_->__inc_ref_();
191         }
192     }
193 
194     void __dec_ref_() noexcept
195     {
196         if (__data_)
197         {
198             __data_->__dec_ref_();
199         }
200     }
201 
202     // For use when types want to take over manual control of the reference
203     // count. Very unsafe, but useful for implementing custom reference
204     // counting.
205     [[nodiscard]] __enable_intrusive_t* __release_() noexcept
206     {
207         auto* __data = std::exchange(__data_, nullptr);
208         return __data ? &__c_upcast<__enable_intrusive_t>(__data->__value())
209                       : nullptr;
210     }
211 
212   public:
213     using element_type = _Ty;
214 
215     __intrusive_ptr() = default;
216 
217     __intrusive_ptr(__intrusive_ptr&& __that) noexcept :
218         __data_(std::exchange(__that.__data_, nullptr))
219     {}
220 
221     __intrusive_ptr(const __intrusive_ptr& __that) noexcept :
222         __data_(__that.__data_)
223     {
224         __inc_ref_();
225     }
226 
227     __intrusive_ptr(
228         __enable_intrusive_from_this<_Ty, _ReservedBits>* __that) noexcept :
229         __intrusive_ptr(__that ? __that->__intrusive_from_this()
230                                : __intrusive_ptr())
231     {}
232 
233     auto operator=(__intrusive_ptr&& __that) noexcept -> __intrusive_ptr&
234     {
235         [[maybe_unused]] __intrusive_ptr __old{
236             std::exchange(__data_, std::exchange(__that.__data_, nullptr))};
237         return *this;
238     }
239 
240     auto operator=(const __intrusive_ptr& __that) noexcept -> __intrusive_ptr&
241     {
242         return operator=(__intrusive_ptr(__that));
243     }
244 
245     auto operator=(
246         __enable_intrusive_from_this<_Ty, _ReservedBits>* __that) noexcept
247         -> __intrusive_ptr&
248     {
249         return operator=(__that ? __that->__intrusive_from_this()
250                                 : __intrusive_ptr());
251     }
252 
253     ~__intrusive_ptr()
254     {
255         __dec_ref_();
256     }
257 
258     void reset() noexcept
259     {
260         operator=({});
261     }
262 
263     void swap(__intrusive_ptr& __that) noexcept
264     {
265         std::swap(__data_, __that.__data_);
266     }
267 
268     auto get() const noexcept -> _Ty*
269     {
270         return &__data_->__value();
271     }
272 
273     auto operator->() const noexcept -> _Ty*
274     {
275         return &__data_->__value();
276     }
277 
278     auto operator*() const noexcept -> _Ty&
279     {
280         return __data_->__value();
281     }
282 
283     explicit operator bool() const noexcept
284     {
285         return __data_ != nullptr;
286     }
287 
288     auto operator!() const noexcept -> bool
289     {
290         return __data_ == nullptr;
291     }
292 
293     auto operator==(const __intrusive_ptr&) const -> bool = default;
294 
295     auto operator==(std::nullptr_t) const noexcept -> bool
296     {
297         return __data_ == nullptr;
298     }
299 };
300 
301 template <class _Ty, std::size_t _ReservedBits>
302 auto __enable_intrusive_from_this<
303     _Ty, _ReservedBits>::__intrusive_from_this() noexcept
304     -> __intrusive_ptr<_Ty, _ReservedBits>
305 {
306     auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
307         &__c_downcast<_Ty>(*this));
308     __data->__inc_ref_();
309     return __intrusive_ptr<_Ty, _ReservedBits>{__data};
310 }
311 
312 template <class _Ty, std::size_t _ReservedBits>
313 auto __enable_intrusive_from_this<_Ty, _ReservedBits>::__intrusive_from_this()
314     const noexcept -> __intrusive_ptr<const _Ty, _ReservedBits>
315 {
316     auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
317         &__c_downcast<_Ty>(*this));
318     __data->__inc_ref_();
319     return __intrusive_ptr<const _Ty, _ReservedBits>{__data};
320 }
321 
322 template <class _Ty, std::size_t _ReservedBits>
323 __bits_t<_ReservedBits>
324     __enable_intrusive_from_this<_Ty, _ReservedBits>::__inc_ref() noexcept
325 {
326     auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
327         &__c_downcast<_Ty>(*this));
328     return __data->__inc_ref_();
329 }
330 
331 template <class _Ty, std::size_t _ReservedBits>
332 __bits_t<_ReservedBits>
333     __enable_intrusive_from_this<_Ty, _ReservedBits>::__dec_ref() noexcept
334 {
335     auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
336         &__c_downcast<_Ty>(*this));
337     return __data->__dec_ref_();
338 }
339 
340 template <class _Ty, std::size_t _ReservedBits>
341 template <std::size_t _Bit>
342 bool __enable_intrusive_from_this<_Ty, _ReservedBits>::__is_set() const noexcept
343 {
344     auto* __data = reinterpret_cast<const __control_block<_Ty, _ReservedBits>*>(
345         &__c_downcast<_Ty>(*this));
346     return __data->template __is_set_<_Bit>();
347 }
348 
349 template <class _Ty, std::size_t _ReservedBits>
350 template <std::size_t _Bit>
351 __bits_t<_ReservedBits>
352     __enable_intrusive_from_this<_Ty, _ReservedBits>::__set_bit() noexcept
353 {
354     auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
355         &__c_downcast<_Ty>(*this));
356     return __data->template __set_bit_<_Bit>();
357 }
358 
359 template <class _Ty, std::size_t _ReservedBits>
360 template <std::size_t _Bit>
361 __bits_t<_ReservedBits>
362     __enable_intrusive_from_this<_Ty, _ReservedBits>::__clear_bit() noexcept
363 {
364     auto* __data = reinterpret_cast<__control_block<_Ty, _ReservedBits>*>(
365         &__c_downcast<_Ty>(*this));
366     return __data->template __clear_bit_<_Bit>();
367 }
368 
369 template <class _Ty, std::size_t _ReservedBits>
370 struct __make_intrusive_t
371 {
372     template <class... _Us>
373         requires constructible_from<_Ty, _Us...>
374     auto operator()(_Us&&... __us) const -> __intrusive_ptr<_Ty, _ReservedBits>
375     {
376         using _UncvTy = std::remove_cv_t<_Ty>;
377         return __intrusive_ptr<_Ty, _ReservedBits>{
378             ::new __control_block<_UncvTy, _ReservedBits>{
379                 static_cast<_Us&&>(__us)...}};
380     }
381 };
382 } // namespace __ptr
383 
// Re-export the intrusive pointer facilities into namespace stdexec.
using __ptr::__enable_intrusive_from_this;
using __ptr::__intrusive_ptr;
// Factory object: `__make_intrusive<T, Bits>(args...)` invokes
// `__make_intrusive_t<T, Bits>::operator()` with `args...`. `_ReservedBits`
// defaults to zero, i.e. no flag bits are carved out of the count word.
template <class _Ty, std::size_t _ReservedBits = 0ul>
inline constexpr __ptr::__make_intrusive_t<_Ty, _ReservedBits>
    __make_intrusive{};
389 
390 } // namespace stdexec
391