1 /* 2 * Copyright (c) 2024 NVIDIA Corporation 3 * 4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 5 * (the "License"); you may not use this file except in compliance with 6 * the License. You may obtain a copy of the License at 7 * 8 * https://llvm.org/LICENSE.txt 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include "__meta.hpp" 19 #include "__type_traits.hpp" 20 #include "__utility.hpp" 21 22 #include <cstddef> 23 #include <memory> 24 #include <new> 25 #include <type_traits> 26 27 /********************************************************************************/ 28 /* NB: The variant type implemented here default-constructs into the valueless 29 */ 30 /* state. This is different from std::variant which default-constructs into the 31 */ 32 /* first alternative. This is done to simplify the implementation and to avoid 33 */ 34 /* the need for a default constructor for each alternative type. */ 35 /********************************************************************************/ 36 37 STDEXEC_PRAGMA_PUSH() 38 STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces") 39 40 namespace stdexec 41 { 42 inline constexpr std::size_t __variant_npos = ~0UL; 43 44 struct __monostate 45 {}; 46 47 namespace __var 48 { 49 template <auto _Idx, class... _Ts> 50 class __variant; 51 52 template <> 53 class __variant<__indices<>{}> 54 { 55 public: 56 template <class _Fn, class... _Us> 57 STDEXEC_ATTRIBUTE((host, device)) visit(_Fn &&,_Us &&...) const58 void visit(_Fn&&, _Us&&...) const noexcept 59 { 60 STDEXEC_ASSERT(false); 61 } 62 63 STDEXEC_ATTRIBUTE((host, device)) index()64 static constexpr std::size_t index() noexcept 65 { 66 return __variant_npos; 67 } 68 69 STDEXEC_ATTRIBUTE((host, device)) is_valueless()70 static constexpr bool is_valueless() noexcept 71 { 72 return true; 73 } 74 }; 75 76 template <std::size_t... _Is, __indices<_Is...> _Idx, class... _Ts> 77 class __variant<_Idx, _Ts...> 78 { 79 static constexpr std::size_t __max_size = stdexec::__umax({sizeof(_Ts)...}); 80 static_assert(__max_size != 0); 81 std::size_t __index_{__variant_npos}; 82 alignas(_Ts...) unsigned char __storage_[__max_size]; 83 84 STDEXEC_ATTRIBUTE((host, device)) __destroy()85 void __destroy() noexcept 86 { 87 auto __index = std::exchange(__index_, __variant_npos); 88 if (__variant_npos != __index) 89 { 90 ((_Is == __index ? std::destroy_at(static_cast<_Ts*>(__get_ptr())) 91 : void(0)), 92 ...); 93 } 94 } 95 96 template <std::size_t _Ny> 97 using __at = __m_at_c<_Ny, _Ts...>; 98 99 public: 100 // immovable: 101 __variant(__variant&&) = delete; 102 103 STDEXEC_ATTRIBUTE((host, device)) __variant()104 __variant() noexcept {} 105 106 STDEXEC_ATTRIBUTE((host, device)) ~__variant()107 ~__variant() 108 { 109 __destroy(); 110 } 111 112 STDEXEC_ATTRIBUTE((host, device, always_inline)) __get_ptr()113 void* __get_ptr() noexcept 114 { 115 return __storage_; 116 } 117 118 STDEXEC_ATTRIBUTE((host, device, always_inline)) index() const119 std::size_t index() const noexcept 120 { 121 return __index_; 122 } 123 124 STDEXEC_ATTRIBUTE((host, device, always_inline)) is_valueless() const125 bool is_valueless() const noexcept 126 { 127 return __index_ == __variant_npos; 128 } 129 130 template <class _Ty, class... _As> 131 STDEXEC_ATTRIBUTE((host, device)) emplace(_As &&...__as)132 _Ty& emplace(_As&&... __as) // 133 noexcept(__nothrow_constructible_from<_Ty, _As...>) 134 { 135 constexpr std::size_t __new_index = stdexec::__index_of<_Ty, _Ts...>(); 136 static_assert(__new_index != __variant_npos, "Type not in variant"); 137 138 __destroy(); 139 ::new (__storage_) _Ty{static_cast<_As&&>(__as)...}; 140 __index_ = __new_index; 141 return *reinterpret_cast<_Ty*>(__storage_); 142 } 143 144 template <std::size_t _Ny, class... _As> 145 STDEXEC_ATTRIBUTE((host, device)) emplace(_As &&...__as)146 __at<_Ny>& emplace(_As&&... __as) // 147 noexcept(__nothrow_constructible_from<__at<_Ny>, _As...>) 148 { 149 static_assert(_Ny < sizeof...(_Ts), "variant index is too large"); 150 151 __destroy(); 152 ::new (__storage_) __at<_Ny>{static_cast<_As&&>(__as)...}; 153 __index_ = _Ny; 154 return *reinterpret_cast<__at<_Ny>*>(__storage_); 155 } 156 157 template <class _Fn, class... _As> 158 STDEXEC_ATTRIBUTE((host, device)) emplace_from(_Fn && __fn,_As &&...__as)159 auto emplace_from(_Fn&& __fn, _As&&... __as) // 160 noexcept(__nothrow_callable<_Fn, _As...>) 161 -> __call_result_t<_Fn, _As...>& 162 { 163 using __result_t = __call_result_t<_Fn, _As...>; 164 constexpr std::size_t __new_index = 165 stdexec::__index_of<__result_t, _Ts...>(); 166 static_assert(__new_index != __variant_npos, "Type not in variant"); 167 168 __destroy(); 169 ::new (__storage_) 170 __result_t(static_cast<_Fn&&>(__fn)(static_cast<_As&&>(__as)...)); 171 __index_ = __new_index; 172 return *reinterpret_cast<__result_t*>(__storage_); 173 } 174 175 template <class _Fn, class _Self, class... _As> 176 STDEXEC_ATTRIBUTE((host, device)) visit(_Fn && __fn,_Self && __self,_As &&...__as)177 static void visit(_Fn&& __fn, _Self&& __self, _As&&... __as) // 178 noexcept((__nothrow_callable<_Fn, _As..., __copy_cvref_t<_Self, _Ts>> && 179 ...)) 180 { 181 STDEXEC_ASSERT(__self.__index_ != __variant_npos); 182 auto __index = __self.__index_; // make it local so we don't access it 183 // after it's deleted. 184 ((_Is == __index ? static_cast<_Fn&&>(__fn)( 185 static_cast<_As&&>(__as)..., 186 static_cast<_Self&&>(__self).template get<_Is>()) 187 : void()), 188 ...); 189 } 190 191 template <std::size_t _Ny> 192 STDEXEC_ATTRIBUTE((host, device, always_inline)) get()193 decltype(auto) get() && noexcept 194 { 195 STDEXEC_ASSERT(_Ny == __index_); 196 return static_cast<__at<_Ny>&&>( 197 *reinterpret_cast<__at<_Ny>*>(__storage_)); 198 } 199 200 template <std::size_t _Ny> 201 STDEXEC_ATTRIBUTE((host, device, always_inline)) get()202 decltype(auto) get() & noexcept 203 { 204 STDEXEC_ASSERT(_Ny == __index_); 205 return *reinterpret_cast<__at<_Ny>*>(__storage_); 206 } 207 208 template <std::size_t _Ny> 209 STDEXEC_ATTRIBUTE((host, device, always_inline)) get() const210 decltype(auto) get() const& noexcept 211 { 212 STDEXEC_ASSERT(_Ny == __index_); 213 return *reinterpret_cast<const __at<_Ny>*>(__storage_); 214 } 215 }; 216 } // namespace __var 217 218 using __var::__variant; 219 220 template <class... _Ts> 221 using __variant_for = __variant<__indices_for<_Ts...>{}, _Ts...>; 222 223 template <class... Ts> 224 using __uniqued_variant_for = 225 __mcall<__munique<__qq<__variant_for>>, __decay_t<Ts>...>; 226 227 // So we can use __variant as a typelist 228 template <auto _Idx, class... _Ts> 229 struct __muncurry_<__variant<_Idx, _Ts...>> 230 { 231 template <class _Fn> 232 using __f = __minvoke<_Fn, _Ts...>; 233 }; 234 } // namespace stdexec 235 236 STDEXEC_PRAGMA_POP() 237