1 /*
2  * Copyright (c) 2024 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__meta.hpp"
19 #include "__type_traits.hpp"
20 #include "__utility.hpp"
21 
22 #include <cstddef>
23 #include <memory>
24 #include <new>
25 #include <type_traits>
26 
27 /********************************************************************************/
/* NB: The variant type implemented here default-constructs into the valueless  */
/* state. This differs from std::variant, which default-constructs into its     */
/* first alternative. This simplifies the implementation and avoids requiring   */
/* each alternative type to be default-constructible.                           */
35 /********************************************************************************/
36 
37 STDEXEC_PRAGMA_PUSH()
38 STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces")
39 
40 namespace stdexec
41 {
// Sentinel returned by index() when a variant holds no alternative.
// NOTE(review): `~0UL` is only 32 bits wide on LLP64 targets (e.g. Windows),
// so there this is 0xFFFFFFFF rather than SIZE_MAX. It must stay equal to the
// "not found" result of `__index_of` (the `static_assert`s in `emplace`
// compare against it) -- confirm that definition before changing the value.
inline constexpr std::size_t __variant_npos = static_cast<std::size_t>(~0UL);

// Stateless placeholder type carrying no data.
struct __monostate
{
};
46 
47 namespace __var
48 {
49 template <auto _Idx, class... _Ts>
50 class __variant;
51 
52 template <>
53 class __variant<__indices<>{}>
54 {
55   public:
56     template <class _Fn, class... _Us>
57     STDEXEC_ATTRIBUTE((host, device))
visit(_Fn &&,_Us &&...) const58     void visit(_Fn&&, _Us&&...) const noexcept
59     {
60         STDEXEC_ASSERT(false);
61     }
62 
63     STDEXEC_ATTRIBUTE((host, device))
index()64     static constexpr std::size_t index() noexcept
65     {
66         return __variant_npos;
67     }
68 
69     STDEXEC_ATTRIBUTE((host, device))
is_valueless()70     static constexpr bool is_valueless() noexcept
71     {
72         return true;
73     }
74 };
75 
76 template <std::size_t... _Is, __indices<_Is...> _Idx, class... _Ts>
77 class __variant<_Idx, _Ts...>
78 {
79     static constexpr std::size_t __max_size = stdexec::__umax({sizeof(_Ts)...});
80     static_assert(__max_size != 0);
81     std::size_t __index_{__variant_npos};
82     alignas(_Ts...) unsigned char __storage_[__max_size];
83 
84     STDEXEC_ATTRIBUTE((host, device))
__destroy()85     void __destroy() noexcept
86     {
87         auto __index = std::exchange(__index_, __variant_npos);
88         if (__variant_npos != __index)
89         {
90             ((_Is == __index ? std::destroy_at(static_cast<_Ts*>(__get_ptr()))
91                              : void(0)),
92              ...);
93         }
94     }
95 
96     template <std::size_t _Ny>
97     using __at = __m_at_c<_Ny, _Ts...>;
98 
99   public:
100     // immovable:
101     __variant(__variant&&) = delete;
102 
103     STDEXEC_ATTRIBUTE((host, device))
__variant()104     __variant() noexcept {}
105 
106     STDEXEC_ATTRIBUTE((host, device))
~__variant()107     ~__variant()
108     {
109         __destroy();
110     }
111 
112     STDEXEC_ATTRIBUTE((host, device, always_inline))
__get_ptr()113     void* __get_ptr() noexcept
114     {
115         return __storage_;
116     }
117 
118     STDEXEC_ATTRIBUTE((host, device, always_inline))
index() const119     std::size_t index() const noexcept
120     {
121         return __index_;
122     }
123 
124     STDEXEC_ATTRIBUTE((host, device, always_inline))
is_valueless() const125     bool is_valueless() const noexcept
126     {
127         return __index_ == __variant_npos;
128     }
129 
130     template <class _Ty, class... _As>
131     STDEXEC_ATTRIBUTE((host, device))
emplace(_As &&...__as)132     _Ty& emplace(_As&&... __as) //
133         noexcept(__nothrow_constructible_from<_Ty, _As...>)
134     {
135         constexpr std::size_t __new_index = stdexec::__index_of<_Ty, _Ts...>();
136         static_assert(__new_index != __variant_npos, "Type not in variant");
137 
138         __destroy();
139         ::new (__storage_) _Ty{static_cast<_As&&>(__as)...};
140         __index_ = __new_index;
141         return *reinterpret_cast<_Ty*>(__storage_);
142     }
143 
144     template <std::size_t _Ny, class... _As>
145     STDEXEC_ATTRIBUTE((host, device))
emplace(_As &&...__as)146     __at<_Ny>& emplace(_As&&... __as) //
147         noexcept(__nothrow_constructible_from<__at<_Ny>, _As...>)
148     {
149         static_assert(_Ny < sizeof...(_Ts), "variant index is too large");
150 
151         __destroy();
152         ::new (__storage_) __at<_Ny>{static_cast<_As&&>(__as)...};
153         __index_ = _Ny;
154         return *reinterpret_cast<__at<_Ny>*>(__storage_);
155     }
156 
157     template <class _Fn, class... _As>
158     STDEXEC_ATTRIBUTE((host, device))
emplace_from(_Fn && __fn,_As &&...__as)159     auto emplace_from(_Fn&& __fn, _As&&... __as) //
160         noexcept(__nothrow_callable<_Fn, _As...>)
161             -> __call_result_t<_Fn, _As...>&
162     {
163         using __result_t = __call_result_t<_Fn, _As...>;
164         constexpr std::size_t __new_index =
165             stdexec::__index_of<__result_t, _Ts...>();
166         static_assert(__new_index != __variant_npos, "Type not in variant");
167 
168         __destroy();
169         ::new (__storage_)
170             __result_t(static_cast<_Fn&&>(__fn)(static_cast<_As&&>(__as)...));
171         __index_ = __new_index;
172         return *reinterpret_cast<__result_t*>(__storage_);
173     }
174 
175     template <class _Fn, class _Self, class... _As>
176     STDEXEC_ATTRIBUTE((host, device))
visit(_Fn && __fn,_Self && __self,_As &&...__as)177     static void visit(_Fn&& __fn, _Self&& __self, _As&&... __as) //
178         noexcept((__nothrow_callable<_Fn, _As..., __copy_cvref_t<_Self, _Ts>> &&
179                   ...))
180     {
181         STDEXEC_ASSERT(__self.__index_ != __variant_npos);
182         auto __index = __self.__index_; // make it local so we don't access it
183                                         // after it's deleted.
184         ((_Is == __index ? static_cast<_Fn&&>(__fn)(
185                                static_cast<_As&&>(__as)...,
186                                static_cast<_Self&&>(__self).template get<_Is>())
187                          : void()),
188          ...);
189     }
190 
191     template <std::size_t _Ny>
192     STDEXEC_ATTRIBUTE((host, device, always_inline))
get()193     decltype(auto) get() && noexcept
194     {
195         STDEXEC_ASSERT(_Ny == __index_);
196         return static_cast<__at<_Ny>&&>(
197             *reinterpret_cast<__at<_Ny>*>(__storage_));
198     }
199 
200     template <std::size_t _Ny>
201     STDEXEC_ATTRIBUTE((host, device, always_inline))
get()202     decltype(auto) get() & noexcept
203     {
204         STDEXEC_ASSERT(_Ny == __index_);
205         return *reinterpret_cast<__at<_Ny>*>(__storage_);
206     }
207 
208     template <std::size_t _Ny>
209     STDEXEC_ATTRIBUTE((host, device, always_inline))
get() const210     decltype(auto) get() const& noexcept
211     {
212         STDEXEC_ASSERT(_Ny == __index_);
213         return *reinterpret_cast<const __at<_Ny>*>(__storage_);
214     }
215 };
216 } // namespace __var
217 
218 using __var::__variant;
219 
220 template <class... _Ts>
221 using __variant_for = __variant<__indices_for<_Ts...>{}, _Ts...>;
222 
223 template <class... Ts>
224 using __uniqued_variant_for =
225     __mcall<__munique<__qq<__variant_for>>, __decay_t<Ts>...>;
226 
227 // So we can use __variant as a typelist
228 template <auto _Idx, class... _Ts>
229 struct __muncurry_<__variant<_Idx, _Ts...>>
230 {
231     template <class _Fn>
232     using __f = __minvoke<_Fn, _Ts...>;
233 };
234 } // namespace stdexec
235 
236 STDEXEC_PRAGMA_POP()
237