1 /* 2 * Copyright (c) 2024 NVIDIA Corporation 3 * 4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions 5 * (the "License"); you may not use this file except in compliance with 6 * the License. You may obtain a copy of the License at 7 * 8 * https://llvm.org/LICENSE.txt 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include "__meta.hpp" 19 #include "__type_traits.hpp" 20 #include "__utility.hpp" 21 #include "__scope.hpp" 22 23 #include <cstddef> 24 #include <new> 25 #include <utility> 26 27 /********************************************************************************/ 28 /* NB: The variant type implemented here default-constructs into the valueless */ 29 /* state. This is different from std::variant which default-constructs into the */ 30 /* first alternative. This is done to simplify the implementation and to avoid */ 31 /* the need for a default constructor for each alternative type. */ 32 /********************************************************************************/ 33 34 STDEXEC_PRAGMA_PUSH() 35 STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces") 36 37 namespace stdexec { 38 #if STDEXEC_NVHPC() 39 enum __variant_npos_t : std::size_t { 40 __variant_npos = ~0UL 41 }; 42 #else 43 STDEXEC_GLOBAL_CONSTANT std::size_t __variant_npos = ~0UL; 44 #endif 45 46 struct __monostate { }; 47 48 namespace __var { 49 template <class _Ty> STDEXEC_ATTRIBUTE(host,device,always_inline)50 STDEXEC_ATTRIBUTE(host, device, always_inline) 51 void __destroy_at(_Ty *ptr) noexcept { 52 ptr->~_Ty(); 53 } 54 STDEXEC_ATTRIBUTE(host,device)55 STDEXEC_ATTRIBUTE(host, device) 56 inline auto __mk_index_guard(std::size_t &__index, std::size_t __new) noexcept { 57 __index = __new; 58 return __scope_guard{[&__index]() noexcept { __index = __variant_npos; }}; 59 } 60 61 template <auto _Idx, class... _Ts> 62 class __variant; 63 64 template <> 65 class __variant<__indices<>{}> { 66 public: 67 template <class _Fn, class... _Us> STDEXEC_ATTRIBUTE(host,device)68 STDEXEC_ATTRIBUTE(host, device) 69 void visit(_Fn &&, _Us &&...) const noexcept { 70 STDEXEC_ASSERT(false); 71 } 72 STDEXEC_ATTRIBUTE(host,device)73 STDEXEC_ATTRIBUTE(host, device) static constexpr auto index() noexcept -> std::size_t { 74 return __variant_npos; 75 } 76 STDEXEC_ATTRIBUTE(host,device)77 STDEXEC_ATTRIBUTE(host, device) static constexpr auto is_valueless() noexcept -> bool { 78 return true; 79 } 80 }; 81 82 template <std::size_t... _Is, __indices<_Is...> _Idx, class... _Ts> 83 class __variant<_Idx, _Ts...> { 84 static constexpr std::size_t __max_size = stdexec::__umax({sizeof(_Ts)...}); 85 static_assert(__max_size != 0); 86 std::size_t __index_{__variant_npos}; 87 alignas(_Ts...) unsigned char __storage_[__max_size]; 88 STDEXEC_ATTRIBUTE(host,device)89 STDEXEC_ATTRIBUTE(host, device) void __destroy() noexcept { 90 auto __index = std::exchange(__index_, __variant_npos); 91 if (__variant_npos != __index) { 92 ((_Is == __index ? __var::__destroy_at(static_cast<_Ts *>(__get_ptr())) : void(0)), ...); 93 } 94 } 95 96 public: 97 98 template <std::size_t _Ny> 99 using __at = __m_at_c<_Ny, _Ts...>; 100 101 // immovable: 102 __variant(__variant &&) = delete; 103 104 STDEXEC_ATTRIBUTE(host, device) __variant() noexcept = default; 105 STDEXEC_ATTRIBUTE(host,device)106 STDEXEC_ATTRIBUTE(host, device) ~__variant() { 107 __destroy(); 108 } 109 110 [[nodiscard]] STDEXEC_ATTRIBUTE(host,device,always_inline)111 STDEXEC_ATTRIBUTE(host, device, always_inline) auto __get_ptr() noexcept -> void * { 112 return __storage_; 113 } 114 115 [[nodiscard]] STDEXEC_ATTRIBUTE(host,device,always_inline)116 STDEXEC_ATTRIBUTE(host, device, always_inline) auto index() const noexcept -> std::size_t { 117 return __index_; 118 } 119 120 [[nodiscard]] STDEXEC_ATTRIBUTE(host,device,always_inline)121 STDEXEC_ATTRIBUTE(host, device, always_inline) auto is_valueless() const noexcept -> bool { 122 return __index_ == __variant_npos; 123 } 124 125 // The following emplace functions must take great care to avoid use-after-free bugs. 126 // If the object being constructed calls `start` on a newly created operation state 127 // (as does the object returned from `submit`), and if `start` completes inline, it 128 // could cause the destruction of the outer operation state that owns *this. The 129 // function below uses the following pattern to avoid this: 130 // 1. Store the new index in __index_. 131 // 2. Create a scope guard that will reset __index_ to __variant_npos if the 132 // constructor throws. 133 // 3. Construct the new object in the storage, which may cause the invalidation of 134 // *this. The emplace function must not access any members of *this after this point. 135 // 4. Dismiss the scope guard, which will leave __index_ set to the new index. 136 // 5. Return a reference to the new object -- which may be invalid! Calling code 137 // must be aware of the danger. 138 template <class _Ty, class... _As> STDEXEC_ATTRIBUTE(host,device)139 STDEXEC_ATTRIBUTE(host, device) 140 auto emplace(_As &&...__as) noexcept(__nothrow_constructible_from<_Ty, _As...>) -> _Ty & { 141 constexpr std::size_t __new_index = stdexec::__index_of<_Ty, _Ts...>(); 142 static_assert(__new_index != __variant_npos, "Type not in variant"); 143 144 __destroy(); 145 auto __sg = __mk_index_guard(__index_, __new_index); 146 auto *__p = ::new (__storage_) _Ty{static_cast<_As &&>(__as)...}; 147 __sg.__dismiss(); 148 return *std::launder(__p); 149 } 150 151 template <std::size_t _Ny, class... _As> STDEXEC_ATTRIBUTE(host,device)152 STDEXEC_ATTRIBUTE(host, device) 153 auto emplace(_As &&...__as) noexcept(__nothrow_constructible_from<__at<_Ny>, _As...>) 154 -> __at<_Ny> & { 155 static_assert(_Ny < sizeof...(_Ts), "variant index is too large"); 156 157 __destroy(); 158 auto __sg = __mk_index_guard(__index_, _Ny); 159 auto *__p = ::new (__storage_) __at<_Ny>{static_cast<_As &&>(__as)...}; 160 __sg.__dismiss(); 161 return *std::launder(__p); 162 } 163 164 template <std::size_t _Ny, class _Fn, class... _As> STDEXEC_ATTRIBUTE(host,device)165 STDEXEC_ATTRIBUTE(host, device) 166 auto emplace_from_at(_Fn &&__fn, _As &&...__as) noexcept(__nothrow_callable<_Fn, _As...>) 167 -> __at<_Ny> & { 168 static_assert( 169 __same_as<__call_result_t<_Fn, _As...>, __at<_Ny>>, 170 "callable does not return the correct type"); 171 172 __destroy(); 173 auto __sg = __mk_index_guard(__index_, _Ny); 174 auto *__p = ::new (__storage_) 175 __at<_Ny>(static_cast<_Fn &&>(__fn)(static_cast<_As &&>(__as)...)); 176 __sg.__dismiss(); 177 return *std::launder(__p); 178 } 179 180 template <class _Fn, class... _As> STDEXEC_ATTRIBUTE(host,device,always_inline)181 STDEXEC_ATTRIBUTE(host, device, always_inline) 182 auto emplace_from(_Fn &&__fn, _As &&...__as) noexcept(__nothrow_callable<_Fn, _As...>) 183 -> __call_result_t<_Fn, _As...> & { 184 using __result_t = __call_result_t<_Fn, _As...>; 185 constexpr std::size_t __new_index = stdexec::__index_of<__result_t, _Ts...>(); 186 static_assert(__new_index != __variant_npos, "Type not in variant"); 187 return emplace_from_at<__new_index>( 188 static_cast<_Fn &&>(__fn), static_cast<_As &&>(__as)...); 189 } 190 191 template <class _Fn, class _Self, class... _As> STDEXEC_ATTRIBUTE(host,device)192 STDEXEC_ATTRIBUTE(host, device) 193 static void visit(_Fn &&__fn, _Self &&__self, _As &&...__as) 194 noexcept((__nothrow_callable<_Fn, _As..., __copy_cvref_t<_Self, _Ts>> && ...)) { 195 STDEXEC_ASSERT(__self.__index_ != __variant_npos); 196 auto __index = __self.__index_; // make it local so we don't access it after it's deleted. 197 ((_Is == __index 198 ? static_cast<_Fn &&>(__fn)( 199 static_cast<_As &&>(__as)..., static_cast<_Self &&>(__self).template get<_Is>()) 200 : void()), 201 ...); 202 } 203 204 template <std::size_t _Ny> STDEXEC_ATTRIBUTE(host,device,always_inline)205 STDEXEC_ATTRIBUTE(host, device, always_inline) 206 auto get() && noexcept -> decltype(auto) { 207 STDEXEC_ASSERT(_Ny == __index_); 208 return static_cast<__at<_Ny> &&>(*reinterpret_cast<__at<_Ny> *>(__storage_)); 209 } 210 211 template <std::size_t _Ny> STDEXEC_ATTRIBUTE(host,device,always_inline)212 STDEXEC_ATTRIBUTE(host, device, always_inline) 213 auto get() & noexcept -> decltype(auto) { 214 STDEXEC_ASSERT(_Ny == __index_); 215 return *reinterpret_cast<__at<_Ny> *>(__storage_); 216 } 217 218 template <std::size_t _Ny> STDEXEC_ATTRIBUTE(host,device,always_inline)219 STDEXEC_ATTRIBUTE(host, device, always_inline) 220 auto get() const & noexcept -> decltype(auto) { 221 STDEXEC_ASSERT(_Ny == __index_); 222 return *reinterpret_cast<const __at<_Ny> *>(__storage_); 223 } 224 }; 225 } // namespace __var 226 227 using __var::__variant; 228 229 template <class... _Ts> 230 using __variant_for = __variant<__indices_for<_Ts...>{}, _Ts...>; 231 232 template <class... Ts> 233 using __uniqued_variant_for = __mcall<__munique<__qq<__variant_for>>, __decay_t<Ts>...>; 234 235 // So we can use __variant as a typelist 236 template <auto _Idx, class... _Ts> 237 struct __muncurry_<__variant<_Idx, _Ts...>> { 238 template <class _Fn> 239 using __f = __minvoke<_Fn, _Ts...>; 240 }; 241 } // namespace stdexec 242 243 STDEXEC_PRAGMA_POP() 244