1 /*
2 * Copyright (c) 2021-2024 NVIDIA Corporation
3 *
4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5 * (the "License"); you may not use this file except in compliance with
6 * the License. You may obtain a copy of the License at
7 *
8 * https://llvm.org/LICENSE.txt
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #pragma once
17
18 #include "../stdexec/__detail/__meta.hpp"
19 #include "../stdexec/__detail/__optional.hpp"
20 #include "../stdexec/__detail/__variant.hpp"
21 #include "../stdexec/coroutine.hpp"
22 #include "../stdexec/execution.hpp"
23 #include "any_sender_of.hpp"
24 #include "at_coroutine_exit.hpp"
25 #include "inline_scheduler.hpp"
26 #include "scope.hpp"
27
28 #include <any>
29 #include <cassert>
30 #include <exception>
31 #include <utility>
32
33 STDEXEC_PRAGMA_PUSH()
34 STDEXEC_PRAGMA_IGNORE_GNU("-Wundefined-inline")
35
36 namespace exec
37 {
38 namespace __task
39 {
40 using namespace stdexec;
41
42 using __any_scheduler = //
43 any_receiver_ref< //
44 completion_signatures<set_error_t(std::exception_ptr),
45 set_stopped_t()> //
46 >::any_sender<>::any_scheduler<>;
47 static_assert(scheduler<__any_scheduler>);
48
49 template <class _Ty>
50 concept __stop_token_provider = //
51 requires(const _Ty& t) { //
52 get_stop_token(t);
53 };
54
55 template <class _Ty>
56 concept __indirect_stop_token_provider = //
57 requires(const _Ty& t) {
58 { get_env(t) } -> __stop_token_provider;
59 };
60
61 template <class _Ty>
62 concept __indirect_scheduler_provider = //
63 requires(const _Ty& t) {
64 { get_env(t) } -> __scheduler_provider;
65 };
66
67 template <class _ParentPromise>
__check_parent_promise_has_scheduler()68 constexpr bool __check_parent_promise_has_scheduler() noexcept
69 {
70 static_assert(__indirect_scheduler_provider<_ParentPromise>,
71 "exec::task<T> cannot be co_await-ed in a coroutine that "
72 "does not have an associated scheduler.");
73 return __indirect_scheduler_provider<_ParentPromise>;
74 }
75
76 struct __forward_stop_request
77 {
78 inplace_stop_source& __stop_source_;
79
operator ()exec::__task::__forward_stop_request80 void operator()() noexcept
81 {
82 __stop_source_.request_stop();
83 }
84 };
85
// Forward declaration; the primary template and its specializations are
// defined after __default_task_context_impl, which befriends them.
template <class /*_ParentPromise*/>
struct __default_awaiter_context;
88
////////////////////////////////////////////////////////////////////////////////
// The context that is associated with basic_task's promise type by default is
// __default_task_context_impl (below). It handles forwarding of stop requests
// from parent to child.
//
// __scheduler_affinity is that context's template parameter: with __sticky
// the context stores a scheduler, with __none it does not.
enum class __scheduler_affinity
{
    __none = 0,
    __sticky = 1
};
97
// Empty tag type. It selects the __default_task_context_impl constructor that
// initializes the context from a parent coroutine's promise.
struct __parent_promise_t
{
};
100
101 template <__scheduler_affinity _SchedulerAffinity =
102 __scheduler_affinity::__sticky>
103 class __default_task_context_impl
104 {
105 template <class _ParentPromise>
106 friend struct __default_awaiter_context;
107
108 static constexpr bool __with_scheduler =
109 _SchedulerAffinity == __scheduler_affinity::__sticky;
110
111 STDEXEC_ATTRIBUTE((no_unique_address))
112 __if_c<__with_scheduler, __any_scheduler, __ignore> //
113 __scheduler_{exec::inline_scheduler{}};
114 inplace_stop_token __stop_token_;
115
116 public:
117 template <class _ParentPromise>
__default_task_context_impl(__parent_promise_t,_ParentPromise & __parent)118 explicit __default_task_context_impl(__parent_promise_t,
119 _ParentPromise& __parent) noexcept
120 {
121 if constexpr (_SchedulerAffinity == __scheduler_affinity::__sticky)
122 {
123 if constexpr (__check_parent_promise_has_scheduler<
124 _ParentPromise>())
125 {
126 __scheduler_ = get_scheduler(get_env(__parent));
127 }
128 }
129 }
130
131 template <scheduler _Scheduler>
__default_task_context_impl(_Scheduler && __scheduler)132 explicit __default_task_context_impl(_Scheduler&& __scheduler) :
133 __scheduler_{static_cast<_Scheduler&&>(__scheduler)}
134 {}
135
query(get_scheduler_t) const136 auto query(get_scheduler_t) const noexcept
137 -> const __any_scheduler& requires(__with_scheduler) {
138 return __scheduler_;
139 }
140
query(get_stop_token_t) const141 auto query(get_stop_token_t) const noexcept -> inplace_stop_token
142 {
143 return __stop_token_;
144 }
145
stop_requested() const146 [[nodiscard]] auto stop_requested() const noexcept -> bool
147 {
148 return __stop_token_.stop_requested();
149 }
150
151 template <scheduler _Scheduler>
set_scheduler(_Scheduler && __sched)152 void set_scheduler(_Scheduler&& __sched)
153 requires(__with_scheduler)
154 {
155 __scheduler_ = static_cast<_Scheduler&&>(__sched);
156 }
157
158 template <class _ThisPromise>
159 using promise_context_t = __default_task_context_impl;
160
161 template <class _ThisPromise, class _ParentPromise = void>
162 requires(!__with_scheduler) ||
163 __indirect_scheduler_provider<_ParentPromise>
164 using awaiter_context_t = __default_awaiter_context<_ParentPromise>;
165 };
166
167 template <class _Ty>
168 using default_task_context =
169 __default_task_context_impl<__scheduler_affinity::__sticky>;
170
171 template <class _Ty>
172 using __raw_task_context =
173 __default_task_context_impl<__scheduler_affinity::__none>;
174
175 // This is the context associated with basic_task's awaiter. By default
176 // it does nothing.
177 template <class _ParentPromise>
178 struct __default_awaiter_context
179 {
180 template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context181 explicit __default_awaiter_context(
182 __default_task_context_impl<_Affinity>& __self,
183 _ParentPromise& __parent) noexcept
184 {}
185 };
186
187 ////////////////////////////////////////////////////////////////////////////////
188 // This is the context to be associated with basic_task's awaiter when
189 // the parent coroutine's promise type is known, is a __stop_token_provider,
190 // and its stop token type is neither inplace_stop_token nor unstoppable.
191 template <__indirect_stop_token_provider _ParentPromise>
192 struct __default_awaiter_context<_ParentPromise>
193 {
194 using __stop_token_t = stop_token_of_t<env_of_t<_ParentPromise>>;
195 using __stop_callback_t =
196 typename __stop_token_t::template callback_type<__forward_stop_request>;
197
198 template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context199 explicit __default_awaiter_context(
200 __default_task_context_impl<_Affinity>& __self,
201 _ParentPromise& __parent) noexcept
202 // Register a callback that will request stop on this basic_task's
203 // stop_source when stop is requested on the parent coroutine's stop
204 // token.
205 :
206 __stop_callback_{get_stop_token(get_env(__parent)),
207 __forward_stop_request{__stop_source_}}
208 {
209 static_assert(
210 std::is_nothrow_constructible_v<__stop_callback_t, __stop_token_t,
211 __forward_stop_request>);
212 __self.__stop_token_ = __stop_source_.get_token();
213 }
214
215 inplace_stop_source __stop_source_{};
216 __stop_callback_t __stop_callback_;
217 };
218
219 // If the parent coroutine's type has a stop token of type inplace_stop_token,
220 // we don't need to register a stop callback.
221 template <__indirect_stop_token_provider _ParentPromise>
222 requires std::same_as<inplace_stop_token,
223 stop_token_of_t<env_of_t<_ParentPromise>>>
224 struct __default_awaiter_context<_ParentPromise>
225 {
226 template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context227 explicit __default_awaiter_context(
228 __default_task_context_impl<_Affinity>& __self,
229 _ParentPromise& __parent) noexcept
230 {
231 __self.__stop_token_ = get_stop_token(get_env(__parent));
232 }
233 };
234
235 // If the parent coroutine's stop token is unstoppable, there's no point
236 // forwarding stop tokens or stop requests at all.
237 template <__indirect_stop_token_provider _ParentPromise>
238 requires unstoppable_token<stop_token_of_t<env_of_t<_ParentPromise>>>
239 struct __default_awaiter_context<_ParentPromise>
240 {
241 template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context242 explicit __default_awaiter_context(__default_task_context_impl<_Affinity>&,
243 _ParentPromise&) noexcept
244 {}
245 };
246
247 // Finally, if we don't know the parent coroutine's promise type, assume the
248 // worst and save a type-erased stop callback.
249 template <>
250 struct __default_awaiter_context<void>
251 {
252 template <__scheduler_affinity _Affinity, class _ParentPromise>
__default_awaiter_contextexec::__task::__default_awaiter_context253 explicit __default_awaiter_context(
254 __default_task_context_impl<_Affinity>& __self,
255 _ParentPromise& __parent) noexcept
256 {}
257
258 template <__scheduler_affinity _Affinity,
259 __indirect_stop_token_provider _ParentPromise>
__default_awaiter_contextexec::__task::__default_awaiter_context260 explicit __default_awaiter_context(
261 __default_task_context_impl<_Affinity>& __self,
262 _ParentPromise& __parent)
263 {
264 // Register a callback that will request stop on this basic_task's
265 // stop_source when stop is requested on the parent coroutine's stop
266 // token.
267 using __stop_token_t = stop_token_of_t<env_of_t<_ParentPromise>>;
268 using __stop_callback_t =
269 stop_callback_for_t<__stop_token_t, __forward_stop_request>;
270
271 if constexpr (std::same_as<__stop_token_t, inplace_stop_token>)
272 {
273 __self.__stop_token_ = get_stop_token(get_env(__parent));
274 }
275 else if (auto __token = get_stop_token(get_env(__parent));
276 __token.stop_possible())
277 {
278 __stop_callback_.emplace<__stop_callback_t>(
279 std::move(__token), __forward_stop_request{__stop_source_});
280 __self.__stop_token_ = __stop_source_.get_token();
281 }
282 }
283
284 inplace_stop_source __stop_source_{};
285 std::any __stop_callback_{};
286 };
287
288 template <class _Promise, class _ParentPromise = void>
289 using awaiter_context_t = //
290 typename __decay_t<env_of_t<_Promise>> //
291 ::template awaiter_context_t<_Promise, _ParentPromise>;
292
293 ////////////////////////////////////////////////////////////////////////////////
294 // In a base class so it can be specialized when _Ty is void:
295 template <class _Ty>
296 struct __promise_base
297 {
return_valueexec::__task::__promise_base298 void return_value(_Ty value)
299 {
300 __data_.template emplace<0>(std::move(value));
301 }
302
303 __variant_for<_Ty, std::exception_ptr> __data_{};
304 };
305
306 template <>
307 struct __promise_base<void>
308 {
309 struct __void
310 {};
311
return_voidexec::__task::__promise_base312 void return_void()
313 {
314 __data_.template emplace<0>(__void{});
315 }
316
317 __variant_for<__void, std::exception_ptr> __data_{};
318 };
319
// How a task finished, as reported by the promise's disposition() member.
enum class disposition : unsigned
{
    stopped = 0u,   // neither a value nor an exception was stored
    succeeded = 1u, // a result value was stored
    failed = 2u     // an exception was stored
};
326
327 struct __reschedule_coroutine_on
328 {
329 template <class _Scheduler>
330 struct __wrap
331 {
332 _Scheduler __sched_;
333 };
334
335 template <scheduler _Scheduler>
operator ()exec::__task::__reschedule_coroutine_on336 auto operator()(_Scheduler __sched) const noexcept -> __wrap<_Scheduler>
337 {
338 return {static_cast<_Scheduler&&>(__sched)};
339 }
340 };
341
342 ////////////////////////////////////////////////////////////////////////////////
343 // basic_task
344 template <class _Ty, class _Context = default_task_context<_Ty>>
345 class [[nodiscard]] basic_task
346 {
347 struct __promise;
348
349 public:
350 using __t = basic_task;
351 using __id = basic_task;
352 using promise_type = __promise;
353
basic_task(basic_task && __that)354 basic_task(basic_task&& __that) noexcept :
355 __coro_(std::exchange(__that.__coro_, {}))
356 {}
357
~basic_task()358 ~basic_task()
359 {
360 if (__coro_)
361 __coro_.destroy();
362 }
363
364 private:
365 struct __final_awaitable
366 {
await_readyexec::__task::basic_task::__final_awaitable367 static constexpr auto await_ready() noexcept -> bool
368 {
369 return false;
370 }
371
await_suspendexec::__task::basic_task::__final_awaitable372 static auto await_suspend(
373 __coro::coroutine_handle<__promise> __h) noexcept
374 -> __coro::coroutine_handle<>
375 {
376 return __h.promise().continuation().handle();
377 }
378
await_resumeexec::__task::basic_task::__final_awaitable379 static void await_resume() noexcept {}
380 };
381
382 using __promise_context_t =
383 typename _Context::template promise_context_t<__promise>;
384
385 struct __promise : __promise_base<_Ty>, with_awaitable_senders<__promise>
386 {
387 using __t = __promise;
388 using __id = __promise;
389
get_return_objectexec::__task::basic_task::__promise390 auto get_return_object() noexcept -> basic_task
391 {
392 return basic_task(
393 __coro::coroutine_handle<__promise>::from_promise(*this));
394 }
395
initial_suspendexec::__task::basic_task::__promise396 auto initial_suspend() noexcept -> __coro::suspend_always
397 {
398 return {};
399 }
400
final_suspendexec::__task::basic_task::__promise401 auto final_suspend() noexcept -> __final_awaitable
402 {
403 return {};
404 }
405
dispositionexec::__task::basic_task::__promise406 [[nodiscard]] auto disposition() const noexcept -> __task::disposition
407 {
408 switch (this->__data_.index())
409 {
410 case 0:
411 return __task::disposition::succeeded;
412 case 1:
413 return __task::disposition::failed;
414 default:
415 return __task::disposition::stopped;
416 }
417 }
418
unhandled_exceptionexec::__task::basic_task::__promise419 void unhandled_exception() noexcept
420 {
421 this->__data_.template emplace<1>(std::current_exception());
422 }
423
424 template <sender _Awaitable>
425 requires __scheduler_provider<_Context>
await_transformexec::__task::basic_task::__promise426 auto await_transform(_Awaitable&& __awaitable) noexcept
427 -> decltype(auto)
428 {
429 // TODO: If we have a complete-where-it-starts query then we can
430 // optimize this to avoid the reschedule
431 return as_awaitable(
432 continues_on(static_cast<_Awaitable&&>(__awaitable),
433 get_scheduler(*__context_)),
434 *this);
435 }
436
437 template <class _Scheduler>
438 requires __scheduler_provider<_Context>
await_transformexec::__task::basic_task::__promise439 auto await_transform(
440 __reschedule_coroutine_on::__wrap<_Scheduler> __box) noexcept
441 -> decltype(auto)
442 {
443 if (!std::exchange(__rescheduled_, true))
444 {
445 // Create a cleanup action that transitions back onto the
446 // current scheduler:
447 auto __sched = get_scheduler(*__context_);
448 auto __cleanup_task =
449 at_coroutine_exit(schedule, std::move(__sched));
450 // Insert the cleanup action into the head of the continuation
451 // chain by making direct calls to the cleanup task's awaiter
452 // member functions. See type _cleanup_task in
453 // at_coroutine_exit.hpp:
454 __cleanup_task.await_suspend(
455 __coro::coroutine_handle<__promise>::from_promise(*this));
456 (void)__cleanup_task.await_resume();
457 }
458 __context_->set_scheduler(__box.__sched_);
459 return as_awaitable(schedule(__box.__sched_), *this);
460 }
461
462 template <class _Awaitable>
await_transformexec::__task::basic_task::__promise463 auto await_transform(_Awaitable&& __awaitable) noexcept
464 -> decltype(auto)
465 {
466 return with_awaitable_senders<__promise>::await_transform(
467 static_cast<_Awaitable&&>(__awaitable));
468 }
469
get_envexec::__task::basic_task::__promise470 auto get_env() const noexcept -> const __promise_context_t&
471 {
472 return *__context_;
473 }
474
475 __optional<__promise_context_t> __context_{};
476 bool __rescheduled_{false};
477 };
478
479 template <class _ParentPromise = void>
480 struct __task_awaitable
481 {
482 __coro::coroutine_handle<__promise> __coro_;
483 __optional<awaiter_context_t<__promise, _ParentPromise>> __context_{};
484
~__task_awaitableexec::__task::basic_task::__task_awaitable485 ~__task_awaitable()
486 {
487 if (__coro_)
488 __coro_.destroy();
489 }
490
await_readyexec::__task::basic_task::__task_awaitable491 static constexpr auto await_ready() noexcept -> bool
492 {
493 return false;
494 }
495
496 template <class _ParentPromise2>
await_suspendexec::__task::basic_task::__task_awaitable497 auto await_suspend(
498 __coro::coroutine_handle<_ParentPromise2> __parent) noexcept
499 -> __coro::coroutine_handle<>
500 {
501 static_assert(__one_of<_ParentPromise, _ParentPromise2, void>);
502 __coro_.promise().__context_.emplace(__parent_promise_t(),
503 __parent.promise());
504 __context_.emplace(*__coro_.promise().__context_,
505 __parent.promise());
506 __coro_.promise().set_continuation(__parent);
507 if constexpr (requires {
508 __coro_.promise().stop_requested() ? 0 : 1;
509 })
510 {
511 if (__coro_.promise().stop_requested())
512 return __parent.promise().unhandled_stopped();
513 }
514 return __coro_;
515 }
516
await_resumeexec::__task::basic_task::__task_awaitable517 auto await_resume() -> _Ty
518 {
519 __context_.reset();
520 scope_guard __on_exit{[this]() noexcept {
521 std::exchange(__coro_, {}).destroy();
522 }};
523 if (__coro_.promise().__data_.index() == 1)
524 std::rethrow_exception(
525 std::move(__coro_.promise().__data_.template get<1>()));
526 if constexpr (!std::is_void_v<_Ty>)
527 return std::move(__coro_.promise().__data_.template get<0>());
528 }
529 };
530
531 public:
532 // Make this task awaitable within a particular context:
533 template <class _ParentPromise>
534 requires constructible_from<
535 awaiter_context_t<__promise, _ParentPromise>, __promise_context_t&,
536 _ParentPromise&>
STDEXEC_MEMFN_DECL(auto as_awaitable)537 STDEXEC_MEMFN_DECL(auto as_awaitable)(this basic_task&& __self,
538 _ParentPromise&) noexcept
539 -> __task_awaitable<_ParentPromise>
540 {
541 return __task_awaitable<_ParentPromise>{
542 std::exchange(__self.__coro_, {})};
543 }
544
545 // Make this task generally awaitable:
operator co_await()546 auto operator co_await() && noexcept -> __task_awaitable<>
547 requires __mvalid<awaiter_context_t, __promise>
548 {
549 return __task_awaitable<>{std::exchange(__coro_, {})};
550 }
551
552 // From the list of types [_Ty], remove any types that are void, and send
553 // the resulting list to __qf<set_value_t>, which uses the list of types
554 // as arguments of a function type. In other words, set_value_t() if _Ty
555 // is void, and set_value_t(_Ty) otherwise.
556 using __set_value_sig_t =
557 __minvoke<__mremove<void, __qf<set_value_t>>, _Ty>;
558
559 // Specify basic_task's completion signatures
560 // This is only necessary when basic_task is not generally awaitable
561 // owing to constraints imposed by its _Context parameter.
562 using __task_traits_t = //
563 completion_signatures<__set_value_sig_t,
564 set_error_t(std::exception_ptr), set_stopped_t()>;
565
get_completion_signatures(__ignore={}) const566 auto get_completion_signatures(__ignore = {}) const -> __task_traits_t
567 {
568 return {};
569 }
570
basic_task(__coro::coroutine_handle<promise_type> __coro)571 explicit basic_task(__coro::coroutine_handle<promise_type> __coro) noexcept
572 : __coro_(__coro)
573 {}
574
575 __coro::coroutine_handle<promise_type> __coro_;
576 };
577 } // namespace __task
578
579 using task_disposition = __task::disposition;
580
581 template <class _Ty>
582 using default_task_context = __task::default_task_context<_Ty>;
583
584 template <class _Promise, class _ParentPromise = void>
585 using awaiter_context_t = __task::awaiter_context_t<_Promise, _ParentPromise>;
586
587 template <class _Ty, class _Context = default_task_context<_Ty>>
588 using basic_task = __task::basic_task<_Ty, _Context>;
589
590 template <class _Ty>
591 using task = basic_task<_Ty, default_task_context<_Ty>>;
592
593 inline constexpr __task::__reschedule_coroutine_on reschedule_coroutine_on{};
594 } // namespace exec
595
596 namespace stdexec
597 {
598 template <class _Ty, class _Context>
599 inline constexpr bool enable_sender<exec::basic_task<_Ty, _Context>> = true;
600 } // namespace stdexec
601
602 STDEXEC_PRAGMA_POP()
603