1 /* 2 * Copyright (c) 2024 NVIDIA Corporation 3 * 4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 5 * (the "License"); you may not use this file except in compliance with 6 * the License. You may obtain a copy of the License at 7 * 8 * https://llvm.org/LICENSE.txt 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include "__meta.hpp" 19 #include "__type_traits.hpp" 20 #include "__utility.hpp" 21 22 #include <cstddef> 23 #include <memory> 24 #include <new> 25 #include <type_traits> 26 27 /********************************************************************************/ 28 /* NB: The variant type implemented here default-constructs into the valueless 29 */ 30 /* state. This is different from std::variant which default-constructs into the 31 */ 32 /* first alternative. This is done to simplify the implementation and to avoid 33 */ 34 /* the need for a default constructor for each alternative type. */ 35 /********************************************************************************/ 36 37 STDEXEC_PRAGMA_PUSH() 38 STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces") 39 40 namespace stdexec 41 { 42 inline constexpr std::size_t __variant_npos = ~0UL; 43 44 struct __monostate 45 {}; 46 47 namespace __var 48 { 49 template <auto _Idx, class... _Ts> 50 class __variant; 51 52 template <> 53 class __variant<__indices<>{}> 54 { 55 public: 56 template <class _Fn, class... _Us> 57 STDEXEC_ATTRIBUTE((host, device)) visit(_Fn &&,_Us &&...) const58 void visit(_Fn&&, _Us&&...) const noexcept 59 { 60 STDEXEC_ASSERT(false); 61 } 62 63 STDEXEC_ATTRIBUTE((host, device)) index()64 static constexpr std::size_t index() noexcept 65 { 66 return __variant_npos; 67 } 68 69 STDEXEC_ATTRIBUTE((host, device)) is_valueless()70 static constexpr bool is_valueless() noexcept 71 { 72 return true; 73 } 74 }; 75 76 template <std::size_t... _Is, __indices<_Is...> _Idx, class... _Ts> 77 class __variant<_Idx, _Ts...> 78 { 79 static constexpr std::size_t __max_size = stdexec::__umax({sizeof(_Ts)...}); 80 static_assert(__max_size != 0); 81 std::size_t __index_{__variant_npos}; 82 alignas(_Ts...) unsigned char __storage_[__max_size]; 83 84 STDEXEC_ATTRIBUTE((host, device)) __destroy()85 void __destroy() noexcept 86 { 87 auto __index = std::exchange(__index_, __variant_npos); 88 if (__variant_npos != __index) 89 { 90 ((_Is == __index ? std::destroy_at(static_cast<_Ts*>(__get_ptr())) 91 : void(0)), 92 ...); 93 } 94 } 95 96 template <std::size_t _Ny> 97 using __at = __m_at_c<_Ny, _Ts...>; 98 99 public: 100 // immovable: 101 __variant(__variant&&) = delete; 102 103 STDEXEC_ATTRIBUTE((host, device)) __variant()104 __variant() noexcept {} 105 106 STDEXEC_ATTRIBUTE((host, device)) ~__variant()107 ~__variant() 108 { 109 __destroy(); 110 } 111 112 STDEXEC_ATTRIBUTE((host, device, always_inline)) __get_ptr()113 void* __get_ptr() noexcept 114 { 115 return __storage_; 116 } 117 118 STDEXEC_ATTRIBUTE((host, device, always_inline)) index() const119 std::size_t index() const noexcept 120 { 121 return __index_; 122 } 123 124 STDEXEC_ATTRIBUTE((host, device, always_inline)) is_valueless() const125 bool is_valueless() const noexcept 126 { 127 return __index_ == __variant_npos; 128 } 129 130 template <class _Ty, class... _As> 131 STDEXEC_ATTRIBUTE((host, device)) emplace(_As &&...__as)132 _Ty& emplace(_As&&... __as) // 133 noexcept(__nothrow_constructible_from<_Ty, _As...>) 134 { 135 constexpr std::size_t __new_index = stdexec::__index_of<_Ty, _Ts...>(); 136 static_assert(__new_index != __variant_npos, "Type not in variant"); 137 138 __destroy(); 139 ::new (__storage_) _Ty{static_cast<_As&&>(__as)...}; 140 __index_ = __new_index; 141 return *std::launder(reinterpret_cast<_Ty*>(__storage_)); 142 } 143 144 template <std::size_t _Ny, class... _As> 145 STDEXEC_ATTRIBUTE((host, device)) emplace(_As &&...__as)146 __at<_Ny>& emplace(_As&&... __as) // 147 noexcept(__nothrow_constructible_from<__at<_Ny>, _As...>) 148 { 149 static_assert(_Ny < sizeof...(_Ts), "variant index is too large"); 150 151 __destroy(); 152 ::new (__storage_) __at<_Ny>{static_cast<_As&&>(__as)...}; 153 __index_ = _Ny; 154 return *std::launder(reinterpret_cast<__at<_Ny>*>(__storage_)); 155 } 156 157 template <std::size_t _Ny, class _Fn, class... _As> 158 STDEXEC_ATTRIBUTE((host, device)) emplace_from_at(_Fn && __fn,_As &&...__as)159 __at<_Ny>& emplace_from_at(_Fn&& __fn, _As&&... __as) // 160 noexcept(__nothrow_callable<_Fn, _As...>) 161 { 162 static_assert(__same_as<__call_result_t<_Fn, _As...>, __at<_Ny>>, 163 "callable does not return the correct type"); 164 165 __destroy(); 166 ::new (__storage_) 167 __at<_Ny>(static_cast<_Fn&&>(__fn)(static_cast<_As&&>(__as)...)); 168 __index_ = _Ny; 169 return *std::launder(reinterpret_cast<__at<_Ny>*>(__storage_)); 170 } 171 172 template <class _Fn, class... _As> 173 STDEXEC_ATTRIBUTE((host, device, always_inline)) emplace_from(_Fn && __fn,_As &&...__as)174 auto emplace_from(_Fn&& __fn, _As&&... __as) // 175 noexcept(__nothrow_callable<_Fn, _As...>) 176 -> __call_result_t<_Fn, _As...>& 177 { 178 using __result_t = __call_result_t<_Fn, _As...>; 179 constexpr std::size_t __new_index = 180 stdexec::__index_of<__result_t, _Ts...>(); 181 static_assert(__new_index != __variant_npos, "Type not in variant"); 182 return emplace_from_at<__new_index>(static_cast<_Fn&&>(__fn), 183 static_cast<_As&&>(__as)...); 184 } 185 186 template <class _Fn, class _Self, class... _As> 187 STDEXEC_ATTRIBUTE((host, device)) visit(_Fn && __fn,_Self && __self,_As &&...__as)188 static void visit(_Fn&& __fn, _Self&& __self, _As&&... __as) // 189 noexcept((__nothrow_callable<_Fn, _As..., __copy_cvref_t<_Self, _Ts>> && 190 ...)) 191 { 192 STDEXEC_ASSERT(__self.__index_ != __variant_npos); 193 auto __index = __self.__index_; // make it local so we don't access it 194 // after it's deleted. 195 ((_Is == __index ? static_cast<_Fn&&>(__fn)( 196 static_cast<_As&&>(__as)..., 197 static_cast<_Self&&>(__self).template get<_Is>()) 198 : void()), 199 ...); 200 } 201 202 template <std::size_t _Ny> 203 STDEXEC_ATTRIBUTE((host, device, always_inline)) get()204 decltype(auto) get() && noexcept 205 { 206 STDEXEC_ASSERT(_Ny == __index_); 207 return static_cast<__at<_Ny>&&>( 208 *reinterpret_cast<__at<_Ny>*>(__storage_)); 209 } 210 211 template <std::size_t _Ny> 212 STDEXEC_ATTRIBUTE((host, device, always_inline)) get()213 decltype(auto) get() & noexcept 214 { 215 STDEXEC_ASSERT(_Ny == __index_); 216 return *reinterpret_cast<__at<_Ny>*>(__storage_); 217 } 218 219 template <std::size_t _Ny> 220 STDEXEC_ATTRIBUTE((host, device, always_inline)) get() const221 decltype(auto) get() const& noexcept 222 { 223 STDEXEC_ASSERT(_Ny == __index_); 224 return *reinterpret_cast<const __at<_Ny>*>(__storage_); 225 } 226 }; 227 } // namespace __var 228 229 using __var::__variant; 230 231 template <class... _Ts> 232 using __variant_for = __variant<__indices_for<_Ts...>{}, _Ts...>; 233 234 template <class... Ts> 235 using __uniqued_variant_for = 236 __mcall<__munique<__qq<__variant_for>>, __decay_t<Ts>...>; 237 238 // So we can use __variant as a typelist 239 template <auto _Idx, class... _Ts> 240 struct __muncurry_<__variant<_Idx, _Ts...>> 241 { 242 template <class _Fn> 243 using __f = __minvoke<_Fn, _Ts...>; 244 }; 245 } // namespace stdexec 246 247 STDEXEC_PRAGMA_POP() 248