1 /*
2  * Copyright (c) 2024 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "__meta.hpp"
19 #include "__type_traits.hpp"
20 #include "__utility.hpp"
21 
22 #include <cstddef>
23 #include <memory>
24 #include <new>
25 #include <type_traits>
26 
/********************************************************************************/
/* NB: The variant type implemented here default-constructs into the valueless  */
/* state. This is different from std::variant, which default-constructs into    */
/* the first alternative. This is done to simplify the implementation and to    */
/* avoid requiring every alternative type to be default-constructible.          */
/********************************************************************************/
36 
37 STDEXEC_PRAGMA_PUSH()
38 STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces")
39 
40 namespace stdexec
41 {
// Sentinel index stored by a variant that holds no alternative (valueless).
// NOTE(review): `~0UL` is only 32 bits wide on LLP64 targets (e.g. 64-bit
// Windows), so there this equals UINT32_MAX rather than SIZE_MAX. The value is
// intentionally unchanged in case it must agree with sentinels used elsewhere
// (e.g. whatever __index_of returns for "not found") -- confirm before changing.
inline constexpr std::size_t __variant_npos = static_cast<std::size_t>(~0UL);
43 
// An empty tag type with no members and no state.
struct __monostate {};
46 
47 namespace __var
48 {
// Primary template, declared but never defined. _Idx is an __indices<...>
// value enumerating the alternatives _Ts...; only its specializations
// (the empty one and the general one) have definitions.
template <auto _Idx, class... _Ts> class __variant;
51 
52 template <>
53 class __variant<__indices<>{}>
54 {
55   public:
56     template <class _Fn, class... _Us>
57     STDEXEC_ATTRIBUTE((host, device))
visit(_Fn &&,_Us &&...) const58     void visit(_Fn&&, _Us&&...) const noexcept
59     {
60         STDEXEC_ASSERT(false);
61     }
62 
63     STDEXEC_ATTRIBUTE((host, device))
index()64     static constexpr std::size_t index() noexcept
65     {
66         return __variant_npos;
67     }
68 
69     STDEXEC_ATTRIBUTE((host, device))
is_valueless()70     static constexpr bool is_valueless() noexcept
71     {
72         return true;
73     }
74 };
75 
76 template <std::size_t... _Is, __indices<_Is...> _Idx, class... _Ts>
77 class __variant<_Idx, _Ts...>
78 {
79     static constexpr std::size_t __max_size = stdexec::__umax({sizeof(_Ts)...});
80     static_assert(__max_size != 0);
81     std::size_t __index_{__variant_npos};
82     alignas(_Ts...) unsigned char __storage_[__max_size];
83 
84     STDEXEC_ATTRIBUTE((host, device))
__destroy()85     void __destroy() noexcept
86     {
87         auto __index = std::exchange(__index_, __variant_npos);
88         if (__variant_npos != __index)
89         {
90             ((_Is == __index ? std::destroy_at(static_cast<_Ts*>(__get_ptr()))
91                              : void(0)),
92              ...);
93         }
94     }
95 
96     template <std::size_t _Ny>
97     using __at = __m_at_c<_Ny, _Ts...>;
98 
99   public:
100     // immovable:
101     __variant(__variant&&) = delete;
102 
103     STDEXEC_ATTRIBUTE((host, device))
__variant()104     __variant() noexcept {}
105 
106     STDEXEC_ATTRIBUTE((host, device))
~__variant()107     ~__variant()
108     {
109         __destroy();
110     }
111 
112     STDEXEC_ATTRIBUTE((host, device, always_inline))
__get_ptr()113     void* __get_ptr() noexcept
114     {
115         return __storage_;
116     }
117 
118     STDEXEC_ATTRIBUTE((host, device, always_inline))
index() const119     std::size_t index() const noexcept
120     {
121         return __index_;
122     }
123 
124     STDEXEC_ATTRIBUTE((host, device, always_inline))
is_valueless() const125     bool is_valueless() const noexcept
126     {
127         return __index_ == __variant_npos;
128     }
129 
130     template <class _Ty, class... _As>
131     STDEXEC_ATTRIBUTE((host, device))
emplace(_As &&...__as)132     _Ty& emplace(_As&&... __as) //
133         noexcept(__nothrow_constructible_from<_Ty, _As...>)
134     {
135         constexpr std::size_t __new_index = stdexec::__index_of<_Ty, _Ts...>();
136         static_assert(__new_index != __variant_npos, "Type not in variant");
137 
138         __destroy();
139         ::new (__storage_) _Ty{static_cast<_As&&>(__as)...};
140         __index_ = __new_index;
141         return *std::launder(reinterpret_cast<_Ty*>(__storage_));
142     }
143 
144     template <std::size_t _Ny, class... _As>
145     STDEXEC_ATTRIBUTE((host, device))
emplace(_As &&...__as)146     __at<_Ny>& emplace(_As&&... __as) //
147         noexcept(__nothrow_constructible_from<__at<_Ny>, _As...>)
148     {
149         static_assert(_Ny < sizeof...(_Ts), "variant index is too large");
150 
151         __destroy();
152         ::new (__storage_) __at<_Ny>{static_cast<_As&&>(__as)...};
153         __index_ = _Ny;
154         return *std::launder(reinterpret_cast<__at<_Ny>*>(__storage_));
155     }
156 
157     template <std::size_t _Ny, class _Fn, class... _As>
158     STDEXEC_ATTRIBUTE((host, device))
emplace_from_at(_Fn && __fn,_As &&...__as)159     __at<_Ny>& emplace_from_at(_Fn&& __fn, _As&&... __as) //
160         noexcept(__nothrow_callable<_Fn, _As...>)
161     {
162         static_assert(__same_as<__call_result_t<_Fn, _As...>, __at<_Ny>>,
163                       "callable does not return the correct type");
164 
165         __destroy();
166         ::new (__storage_)
167             __at<_Ny>(static_cast<_Fn&&>(__fn)(static_cast<_As&&>(__as)...));
168         __index_ = _Ny;
169         return *std::launder(reinterpret_cast<__at<_Ny>*>(__storage_));
170     }
171 
172     template <class _Fn, class... _As>
173     STDEXEC_ATTRIBUTE((host, device, always_inline))
emplace_from(_Fn && __fn,_As &&...__as)174     auto emplace_from(_Fn&& __fn, _As&&... __as) //
175         noexcept(__nothrow_callable<_Fn, _As...>)
176             -> __call_result_t<_Fn, _As...>&
177     {
178         using __result_t = __call_result_t<_Fn, _As...>;
179         constexpr std::size_t __new_index =
180             stdexec::__index_of<__result_t, _Ts...>();
181         static_assert(__new_index != __variant_npos, "Type not in variant");
182         return emplace_from_at<__new_index>(static_cast<_Fn&&>(__fn),
183                                             static_cast<_As&&>(__as)...);
184     }
185 
186     template <class _Fn, class _Self, class... _As>
187     STDEXEC_ATTRIBUTE((host, device))
visit(_Fn && __fn,_Self && __self,_As &&...__as)188     static void visit(_Fn&& __fn, _Self&& __self, _As&&... __as) //
189         noexcept((__nothrow_callable<_Fn, _As..., __copy_cvref_t<_Self, _Ts>> &&
190                   ...))
191     {
192         STDEXEC_ASSERT(__self.__index_ != __variant_npos);
193         auto __index = __self.__index_; // make it local so we don't access it
194                                         // after it's deleted.
195         ((_Is == __index ? static_cast<_Fn&&>(__fn)(
196                                static_cast<_As&&>(__as)...,
197                                static_cast<_Self&&>(__self).template get<_Is>())
198                          : void()),
199          ...);
200     }
201 
202     template <std::size_t _Ny>
203     STDEXEC_ATTRIBUTE((host, device, always_inline))
get()204     decltype(auto) get() && noexcept
205     {
206         STDEXEC_ASSERT(_Ny == __index_);
207         return static_cast<__at<_Ny>&&>(
208             *reinterpret_cast<__at<_Ny>*>(__storage_));
209     }
210 
211     template <std::size_t _Ny>
212     STDEXEC_ATTRIBUTE((host, device, always_inline))
get()213     decltype(auto) get() & noexcept
214     {
215         STDEXEC_ASSERT(_Ny == __index_);
216         return *reinterpret_cast<__at<_Ny>*>(__storage_);
217     }
218 
219     template <std::size_t _Ny>
220     STDEXEC_ATTRIBUTE((host, device, always_inline))
get() const221     decltype(auto) get() const& noexcept
222     {
223         STDEXEC_ASSERT(_Ny == __index_);
224         return *reinterpret_cast<const __at<_Ny>*>(__storage_);
225     }
226 };
227 } // namespace __var
228 
229 using __var::__variant;
230 
231 template <class... _Ts>
232 using __variant_for = __variant<__indices_for<_Ts...>{}, _Ts...>;
233 
234 template <class... Ts>
235 using __uniqued_variant_for =
236     __mcall<__munique<__qq<__variant_for>>, __decay_t<Ts>...>;
237 
238 // So we can use __variant as a typelist
239 template <auto _Idx, class... _Ts>
240 struct __muncurry_<__variant<_Idx, _Ts...>>
241 {
242     template <class _Fn>
243     using __f = __minvoke<_Fn, _Ts...>;
244 };
245 } // namespace stdexec
246 
247 STDEXEC_PRAGMA_POP()
248