1 /*
2  * Copyright (c) 2022 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
#include "../concepts.hpp"
#include "__config.hpp"
#include "__meta.hpp"

#include <atomic>
#include <cstddef>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>

#if STDEXEC_TSAN()
#include <sanitizer/tsan_interface.h>
#endif
31 
32 namespace stdexec
33 {
34 namespace __ptr
35 {
template <class _Ty>
class __intrusive_ptr;

template <class _Ty>
struct __make_intrusive_t;

/// CRTP mixin: a type `_Ty` deriving from `__enable_intrusive_from_this<_Ty>`
/// can mint new owning `__intrusive_ptr`s to itself. The object must live
/// inside a `__control_block<_Ty>` (i.e. have been created through
/// `__make_intrusive<_Ty>`), because the implementation recovers the control
/// block from the object's own address.
template <class _Ty>
struct __enable_intrusive_from_this
{
  private:
    // Manual reference-count control is reserved for the derived type.
    friend _Ty;

    // Adds one reference to the control block that holds *this.
    void __inc_ref() noexcept;

    // Drops one reference; dropping the last one destroys the object.
    void __dec_ref() noexcept;

  public:
    // Each call returns a new owner of the derived object (count + 1).
    auto __intrusive_from_this() noexcept -> __intrusive_ptr<_Ty>;
    auto __intrusive_from_this() const noexcept -> __intrusive_ptr<const _Ty>;
};
53 
54 STDEXEC_PRAGMA_PUSH()
55 STDEXEC_PRAGMA_IGNORE_GNU("-Wtsan")
56 
57 template <class _Ty>
58 struct __control_block
59 {
60     alignas(_Ty) unsigned char __value_[sizeof(_Ty)];
61     std::atomic<unsigned long> __refcount_;
62 
63     template <class... _Us>
__control_blockstdexec::__ptr::__control_block64     explicit __control_block(_Us&&... __us) noexcept(noexcept(_Ty{
65         __declval<_Us>()...})) :
66         __refcount_(1u)
67     {
68         // Construct the value *after* the initialization of the
69         // atomic in case the constructor of _Ty calls
70         // __intrusive_from_this() (which increments the atomic):
71         ::new (static_cast<void*>(__value_)) _Ty{static_cast<_Us&&>(__us)...};
72     }
73 
~__control_blockstdexec::__ptr::__control_block74     ~__control_block()
75     {
76         __value().~_Ty();
77     }
78 
__valuestdexec::__ptr::__control_block79     auto __value() noexcept -> _Ty&
80     {
81         return *reinterpret_cast<_Ty*>(__value_);
82     }
83 
__inc_ref_stdexec::__ptr::__control_block84     void __inc_ref_() noexcept
85     {
86         __refcount_.fetch_add(1, std::memory_order_relaxed);
87     }
88 
__dec_ref_stdexec::__ptr::__control_block89     void __dec_ref_() noexcept
90     {
91         if (1u == __refcount_.fetch_sub(1, std::memory_order_release))
92         {
93             std::atomic_thread_fence(std::memory_order_acquire);
94             // TSan does not support std::atomic_thread_fence, so we
95             // need to use the TSan-specific __tsan_acquire instead:
96             STDEXEC_TSAN(__tsan_acquire(&__refcount_));
97             delete this;
98         }
99     }
100 };
101 
102 STDEXEC_PRAGMA_POP()
103 
104 template <class _Ty>
105 class __intrusive_ptr
106 {
107     using _UncvTy = std::remove_cv_t<_Ty>;
108     friend struct __make_intrusive_t<_Ty>;
109     friend struct __enable_intrusive_from_this<_UncvTy>;
110 
111     __control_block<_UncvTy>* __data_{nullptr};
112 
__intrusive_ptr(__control_block<_UncvTy> * __data)113     explicit __intrusive_ptr(__control_block<_UncvTy>* __data) noexcept :
114         __data_(__data)
115     {}
116 
__inc_ref_()117     void __inc_ref_() noexcept
118     {
119         if (__data_)
120         {
121             __data_->__inc_ref_();
122         }
123     }
124 
__dec_ref_()125     void __dec_ref_() noexcept
126     {
127         if (__data_)
128         {
129             __data_->__dec_ref_();
130         }
131     }
132 
133   public:
134     using element_type = _Ty;
135 
136     __intrusive_ptr() = default;
137 
__intrusive_ptr(__intrusive_ptr && __that)138     __intrusive_ptr(__intrusive_ptr&& __that) noexcept :
139         __data_(std::exchange(__that.__data_, nullptr))
140     {}
141 
__intrusive_ptr(const __intrusive_ptr & __that)142     __intrusive_ptr(const __intrusive_ptr& __that) noexcept :
143         __data_(__that.__data_)
144     {
145         __inc_ref_();
146     }
147 
__intrusive_ptr(__enable_intrusive_from_this<_Ty> * __that)148     __intrusive_ptr(__enable_intrusive_from_this<_Ty>* __that) noexcept :
149         __intrusive_ptr(__that ? __that->__intrusive_from_this()
150                                : __intrusive_ptr())
151     {}
152 
operator =(__intrusive_ptr && __that)153     auto operator=(__intrusive_ptr&& __that) noexcept -> __intrusive_ptr&
154     {
155         [[maybe_unused]] __intrusive_ptr __old{
156             std::exchange(__data_, std::exchange(__that.__data_, nullptr))};
157         return *this;
158     }
159 
operator =(const __intrusive_ptr & __that)160     auto operator=(const __intrusive_ptr& __that) noexcept -> __intrusive_ptr&
161     {
162         return operator=(__intrusive_ptr(__that));
163     }
164 
operator =(__enable_intrusive_from_this<_Ty> * __that)165     auto operator=(__enable_intrusive_from_this<_Ty>* __that) noexcept
166         -> __intrusive_ptr&
167     {
168         return operator=(__that ? __that->__intrusive_from_this()
169                                 : __intrusive_ptr());
170     }
171 
~__intrusive_ptr()172     ~__intrusive_ptr()
173     {
174         __dec_ref_();
175     }
176 
reset()177     void reset() noexcept
178     {
179         operator=({});
180     }
181 
swap(__intrusive_ptr & __that)182     void swap(__intrusive_ptr& __that) noexcept
183     {
184         std::swap(__data_, __that.__data_);
185     }
186 
get() const187     auto get() const noexcept -> _Ty*
188     {
189         return &__data_->__value();
190     }
191 
operator ->() const192     auto operator->() const noexcept -> _Ty*
193     {
194         return &__data_->__value();
195     }
196 
operator *() const197     auto operator*() const noexcept -> _Ty&
198     {
199         return __data_->__value();
200     }
201 
operator bool() const202     explicit operator bool() const noexcept
203     {
204         return __data_ != nullptr;
205     }
206 
operator !() const207     auto operator!() const noexcept -> bool
208     {
209         return __data_ == nullptr;
210     }
211 
212     auto operator==(const __intrusive_ptr&) const -> bool = default;
213 
operator ==(std::nullptr_t) const214     auto operator==(std::nullptr_t) const noexcept -> bool
215     {
216         return __data_ == nullptr;
217     }
218 };
219 
220 template <class _Ty>
__intrusive_from_this()221 auto __enable_intrusive_from_this<_Ty>::__intrusive_from_this() noexcept
222     -> __intrusive_ptr<_Ty>
223 {
224     auto* __data =
225         reinterpret_cast<__control_block<_Ty>*>(static_cast<_Ty*>(this));
226     __data->__inc_ref_();
227     return __intrusive_ptr<_Ty>{__data};
228 }
229 
230 template <class _Ty>
__intrusive_from_this() const231 auto __enable_intrusive_from_this<_Ty>::__intrusive_from_this() const noexcept
232     -> __intrusive_ptr<const _Ty>
233 {
234     auto* __data =
235         reinterpret_cast<__control_block<_Ty>*>(static_cast<const _Ty*>(this));
236     __data->__inc_ref_();
237     return __intrusive_ptr<const _Ty>{__data};
238 }
239 
240 template <class _Ty>
__inc_ref()241 void __enable_intrusive_from_this<_Ty>::__inc_ref() noexcept
242 {
243     auto* __data =
244         reinterpret_cast<__control_block<_Ty>*>(static_cast<_Ty*>(this));
245     __data->__inc_ref_();
246 }
247 
248 template <class _Ty>
__dec_ref()249 void __enable_intrusive_from_this<_Ty>::__dec_ref() noexcept
250 {
251     auto* __data =
252         reinterpret_cast<__control_block<_Ty>*>(static_cast<_Ty*>(this));
253     __data->__dec_ref_();
254 }
255 
256 template <class _Ty>
257 struct __make_intrusive_t
258 {
259     template <class... _Us>
260         requires constructible_from<_Ty, _Us...>
operator ()stdexec::__ptr::__make_intrusive_t261     auto operator()(_Us&&... __us) const -> __intrusive_ptr<_Ty>
262     {
263         using _UncvTy = std::remove_cv_t<_Ty>;
264         return __intrusive_ptr<_Ty>{
265             ::new __control_block<_UncvTy>{static_cast<_Us&&>(__us)...}};
266     }
267 };
268 } // namespace __ptr
269 
270 using __ptr::__enable_intrusive_from_this;
271 using __ptr::__intrusive_ptr;
272 template <class _Ty>
273 inline constexpr __ptr::__make_intrusive_t<_Ty> __make_intrusive{};
274 
275 } // namespace stdexec
276