1 /*
2 * Copyright (c) 2021-2024 NVIDIA Corporation
3 *
4 * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5 * (the "License"); you may not use this file except in compliance with
6 * the License. You may obtain a copy of the License at
7 *
8 * https://llvm.org/LICENSE.txt
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #pragma once
17
#include "../stdexec/__detail/__meta.hpp"
#include "../stdexec/__detail/__optional.hpp"
#include "../stdexec/__detail/__variant.hpp"
#include "../stdexec/coroutine.hpp"
#include "../stdexec/execution.hpp"
#include "any_sender_of.hpp"
#include "at_coroutine_exit.hpp"
#include "inline_scheduler.hpp"
#include "scope.hpp"

#include <any>
#include <cassert>
#include <concepts>
#include <exception>
#include <memory>
#include <type_traits>
#include <utility>
32
33 STDEXEC_PRAGMA_PUSH()
34 STDEXEC_PRAGMA_IGNORE_GNU("-Wundefined-inline")
35
36 namespace exec
37 {
38 namespace __task
39 {
40 using namespace stdexec;
41
42 using __any_scheduler = //
43 any_receiver_ref< //
44 completion_signatures<set_error_t(std::exception_ptr),
45 set_stopped_t()> //
46 >::any_sender<>::any_scheduler<>;
47 static_assert(scheduler<__any_scheduler>);
48
49 template <class _Ty>
50 concept __stop_token_provider = //
51 requires(const _Ty& t) { //
52 get_stop_token(t);
53 };
54
55 template <class _Ty>
56 concept __indirect_stop_token_provider = //
57 requires(const _Ty& t) {
58 { get_env(t) } -> __stop_token_provider;
59 };
60
61 template <class _Ty>
62 concept __indirect_scheduler_provider = //
63 requires(const _Ty& t) {
64 { get_env(t) } -> __scheduler_provider;
65 };
66
67 template <class _ParentPromise>
__check_parent_promise_has_scheduler()68 constexpr bool __check_parent_promise_has_scheduler() noexcept
69 {
70 static_assert(__indirect_scheduler_provider<_ParentPromise>,
71 "exec::task<T> cannot be co_await-ed in a coroutine that "
72 "does not have an associated scheduler.");
73 return __indirect_scheduler_provider<_ParentPromise>;
74 }
75
76 struct __forward_stop_request
77 {
78 inplace_stop_source& __stop_source_;
79
operator ()exec::__task::__forward_stop_request80 void operator()() noexcept
81 {
82 __stop_source_.request_stop();
83 }
84 };
85
template <class _ParentPromise>
struct __default_awaiter_context;

////////////////////////////////////////////////////////////////////////////////
// Scheduler affinity of __default_task_context_impl, the context that is
// associated with basic_task's promise type by default. Together with the
// __default_awaiter_context specializations, that context forwards stop
// requests from the parent coroutine to the child task.
//  - __none:   the context stores no scheduler.
//  - __sticky: the context stores a scheduler, taken from the parent
//              coroutine's environment when the task is awaited.
enum class __scheduler_affinity
{
    __none,
    __sticky
};

// Tag type selecting the __default_task_context_impl constructor that
// initializes the context from the awaiting (parent) coroutine's promise.
struct __parent_promise_t
{};
100
101 template <__scheduler_affinity _SchedulerAffinity =
102 __scheduler_affinity::__sticky>
103 class __default_task_context_impl
104 {
105 template <class _ParentPromise>
106 friend struct __default_awaiter_context;
107
108 static constexpr bool __with_scheduler =
109 _SchedulerAffinity == __scheduler_affinity::__sticky;
110
111 STDEXEC_ATTRIBUTE((no_unique_address))
112 __if_c<__with_scheduler, __any_scheduler, __ignore> //
113 __scheduler_{exec::inline_scheduler{}};
114 inplace_stop_token __stop_token_;
115
116 public:
117 template <class _ParentPromise>
__default_task_context_impl(__parent_promise_t,_ParentPromise & __parent)118 explicit __default_task_context_impl(__parent_promise_t,
119 _ParentPromise& __parent) noexcept
120 {
121 if constexpr (_SchedulerAffinity == __scheduler_affinity::__sticky)
122 {
123 if constexpr (__check_parent_promise_has_scheduler<
124 _ParentPromise>())
125 {
126 __scheduler_ = get_scheduler(get_env(__parent));
127 }
128 }
129 }
130
131 template <scheduler _Scheduler>
__default_task_context_impl(_Scheduler && __scheduler)132 explicit __default_task_context_impl(_Scheduler&& __scheduler) :
133 __scheduler_{static_cast<_Scheduler&&>(__scheduler)}
134 {}
135
query(get_scheduler_t) const136 auto query(get_scheduler_t) const noexcept -> const __any_scheduler&
137 requires(__with_scheduler)
138 {
139 return __scheduler_;
140 }
141
query(get_stop_token_t) const142 auto query(get_stop_token_t) const noexcept -> inplace_stop_token
143 {
144 return __stop_token_;
145 }
146
stop_requested() const147 [[nodiscard]] auto stop_requested() const noexcept -> bool
148 {
149 return __stop_token_.stop_requested();
150 }
151
152 template <scheduler _Scheduler>
set_scheduler(_Scheduler && __sched)153 void set_scheduler(_Scheduler&& __sched)
154 requires(__with_scheduler)
155 {
156 __scheduler_ = static_cast<_Scheduler&&>(__sched);
157 }
158
159 template <class _ThisPromise>
160 using promise_context_t = __default_task_context_impl;
161
162 template <class _ThisPromise, class _ParentPromise = void>
163 requires(!__with_scheduler) ||
164 __indirect_scheduler_provider<_ParentPromise>
165 using awaiter_context_t = __default_awaiter_context<_ParentPromise>;
166 };
167
168 template <class _Ty>
169 using default_task_context =
170 __default_task_context_impl<__scheduler_affinity::__sticky>;
171
172 template <class _Ty>
173 using __raw_task_context =
174 __default_task_context_impl<__scheduler_affinity::__none>;
175
176 // This is the context associated with basic_task's awaiter. By default
177 // it does nothing.
178 template <class _ParentPromise>
179 struct __default_awaiter_context
180 {
181 template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context182 explicit __default_awaiter_context(
183 __default_task_context_impl<_Affinity>& __self,
184 _ParentPromise& __parent) noexcept
185 {}
186 };
187
188 ////////////////////////////////////////////////////////////////////////////////
189 // This is the context to be associated with basic_task's awaiter when
190 // the parent coroutine's promise type is known, is a __stop_token_provider,
191 // and its stop token type is neither inplace_stop_token nor unstoppable.
192 template <__indirect_stop_token_provider _ParentPromise>
193 struct __default_awaiter_context<_ParentPromise>
194 {
195 using __stop_token_t = stop_token_of_t<env_of_t<_ParentPromise>>;
196 using __stop_callback_t =
197 typename __stop_token_t::template callback_type<__forward_stop_request>;
198
199 template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context200 explicit __default_awaiter_context(
201 __default_task_context_impl<_Affinity>& __self,
202 _ParentPromise& __parent) noexcept
203 // Register a callback that will request stop on this basic_task's
204 // stop_source when stop is requested on the parent coroutine's stop
205 // token.
206 :
207 __stop_callback_{get_stop_token(get_env(__parent)),
208 __forward_stop_request{__stop_source_}}
209 {
210 static_assert(
211 std::is_nothrow_constructible_v<__stop_callback_t, __stop_token_t,
212 __forward_stop_request>);
213 __self.__stop_token_ = __stop_source_.get_token();
214 }
215
216 inplace_stop_source __stop_source_{};
217 __stop_callback_t __stop_callback_;
218 };
219
220 // If the parent coroutine's type has a stop token of type inplace_stop_token,
221 // we don't need to register a stop callback.
222 template <__indirect_stop_token_provider _ParentPromise>
223 requires std::same_as<inplace_stop_token,
224 stop_token_of_t<env_of_t<_ParentPromise>>>
225 struct __default_awaiter_context<_ParentPromise>
226 {
227 template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context228 explicit __default_awaiter_context(
229 __default_task_context_impl<_Affinity>& __self,
230 _ParentPromise& __parent) noexcept
231 {
232 __self.__stop_token_ = get_stop_token(get_env(__parent));
233 }
234 };
235
236 // If the parent coroutine's stop token is unstoppable, there's no point
237 // forwarding stop tokens or stop requests at all.
238 template <__indirect_stop_token_provider _ParentPromise>
239 requires unstoppable_token<stop_token_of_t<env_of_t<_ParentPromise>>>
240 struct __default_awaiter_context<_ParentPromise>
241 {
242 template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context243 explicit __default_awaiter_context(__default_task_context_impl<_Affinity>&,
244 _ParentPromise&) noexcept
245 {}
246 };
247
248 // Finally, if we don't know the parent coroutine's promise type, assume the
249 // worst and save a type-erased stop callback.
250 template <>
251 struct __default_awaiter_context<void>
252 {
253 template <__scheduler_affinity _Affinity, class _ParentPromise>
__default_awaiter_contextexec::__task::__default_awaiter_context254 explicit __default_awaiter_context(
255 __default_task_context_impl<_Affinity>& __self,
256 _ParentPromise& __parent) noexcept
257 {}
258
259 template <__scheduler_affinity _Affinity,
260 __indirect_stop_token_provider _ParentPromise>
__default_awaiter_contextexec::__task::__default_awaiter_context261 explicit __default_awaiter_context(
262 __default_task_context_impl<_Affinity>& __self,
263 _ParentPromise& __parent)
264 {
265 // Register a callback that will request stop on this basic_task's
266 // stop_source when stop is requested on the parent coroutine's stop
267 // token.
268 using __stop_token_t = stop_token_of_t<env_of_t<_ParentPromise>>;
269 using __stop_callback_t =
270 stop_callback_for_t<__stop_token_t, __forward_stop_request>;
271
272 if constexpr (std::same_as<__stop_token_t, inplace_stop_token>)
273 {
274 __self.__stop_token_ = get_stop_token(get_env(__parent));
275 }
276 else if (auto __token = get_stop_token(get_env(__parent));
277 __token.stop_possible())
278 {
279 __stop_callback_.emplace<__stop_callback_t>(
280 std::move(__token), __forward_stop_request{__stop_source_});
281 __self.__stop_token_ = __stop_source_.get_token();
282 }
283 }
284
285 inplace_stop_source __stop_source_{};
286 std::any __stop_callback_{};
287 };
288
289 template <class _Promise, class _ParentPromise = void>
290 using awaiter_context_t = //
291 typename __decay_t<env_of_t<_Promise>> //
292 ::template awaiter_context_t<_Promise, _ParentPromise>;
293
294 ////////////////////////////////////////////////////////////////////////////////
295 // In a base class so it can be specialized when _Ty is void:
296 template <class _Ty>
297 struct __promise_base
298 {
return_valueexec::__task::__promise_base299 void return_value(_Ty value)
300 {
301 __data_.template emplace<0>(std::move(value));
302 }
303
304 __variant_for<_Ty, std::exception_ptr> __data_{};
305 };
306
307 template <>
308 struct __promise_base<void>
309 {
310 struct __void
311 {};
312
return_voidexec::__task::__promise_base313 void return_void()
314 {
315 __data_.template emplace<0>(__void{});
316 }
317
318 __variant_for<__void, std::exception_ptr> __data_{};
319 };
320
// How a task finished. The enumerator values are the same as the implicit
// ones of the original declaration order.
enum class disposition : unsigned
{
    stopped = 0u,
    succeeded = 1u,
    failed = 2u,
};
327
328 struct __reschedule_coroutine_on
329 {
330 template <class _Scheduler>
331 struct __wrap
332 {
333 _Scheduler __sched_;
334 };
335
336 template <scheduler _Scheduler>
operator ()exec::__task::__reschedule_coroutine_on337 auto operator()(_Scheduler __sched) const noexcept -> __wrap<_Scheduler>
338 {
339 return {static_cast<_Scheduler&&>(__sched)};
340 }
341 };
342
343 ////////////////////////////////////////////////////////////////////////////////
344 // basic_task
345 template <class _Ty, class _Context = default_task_context<_Ty>>
346 class [[nodiscard]] basic_task
347 {
348 struct __promise;
349
350 public:
351 using __t = basic_task;
352 using __id = basic_task;
353 using promise_type = __promise;
354
basic_task(basic_task && __that)355 basic_task(basic_task&& __that) noexcept :
356 __coro_(std::exchange(__that.__coro_, {}))
357 {}
358
~basic_task()359 ~basic_task()
360 {
361 if (__coro_)
362 __coro_.destroy();
363 }
364
365 private:
366 struct __final_awaitable
367 {
await_readyexec::__task::basic_task::__final_awaitable368 static constexpr auto await_ready() noexcept -> bool
369 {
370 return false;
371 }
372
await_suspendexec::__task::basic_task::__final_awaitable373 static auto await_suspend(
374 __coro::coroutine_handle<__promise> __h) noexcept
375 -> __coro::coroutine_handle<>
376 {
377 return __h.promise().continuation().handle();
378 }
379
await_resumeexec::__task::basic_task::__final_awaitable380 static void await_resume() noexcept {}
381 };
382
383 using __promise_context_t =
384 typename _Context::template promise_context_t<__promise>;
385
386 struct __promise : __promise_base<_Ty>, with_awaitable_senders<__promise>
387 {
388 using __t = __promise;
389 using __id = __promise;
390
get_return_objectexec::__task::basic_task::__promise391 auto get_return_object() noexcept -> basic_task
392 {
393 return basic_task(
394 __coro::coroutine_handle<__promise>::from_promise(*this));
395 }
396
initial_suspendexec::__task::basic_task::__promise397 auto initial_suspend() noexcept -> __coro::suspend_always
398 {
399 return {};
400 }
401
final_suspendexec::__task::basic_task::__promise402 auto final_suspend() noexcept -> __final_awaitable
403 {
404 return {};
405 }
406
dispositionexec::__task::basic_task::__promise407 [[nodiscard]] auto disposition() const noexcept -> __task::disposition
408 {
409 switch (this->__data_.index())
410 {
411 case 0:
412 return __task::disposition::succeeded;
413 case 1:
414 return __task::disposition::failed;
415 default:
416 return __task::disposition::stopped;
417 }
418 }
419
unhandled_exceptionexec::__task::basic_task::__promise420 void unhandled_exception() noexcept
421 {
422 this->__data_.template emplace<1>(std::current_exception());
423 }
424
425 template <sender _Awaitable>
426 requires __scheduler_provider<_Context>
await_transformexec::__task::basic_task::__promise427 auto await_transform(_Awaitable&& __awaitable) noexcept
428 -> decltype(auto)
429 {
430 // TODO: If we have a complete-where-it-starts query then we can
431 // optimize this to avoid the reschedule
432 return as_awaitable(
433 continues_on(static_cast<_Awaitable&&>(__awaitable),
434 get_scheduler(*__context_)),
435 *this);
436 }
437
438 template <class _Scheduler>
439 requires __scheduler_provider<_Context>
await_transformexec::__task::basic_task::__promise440 auto await_transform(
441 __reschedule_coroutine_on::__wrap<_Scheduler> __box) noexcept
442 -> decltype(auto)
443 {
444 if (!std::exchange(__rescheduled_, true))
445 {
446 // Create a cleanup action that transitions back onto the
447 // current scheduler:
448 auto __sched = get_scheduler(*__context_);
449 auto __cleanup_task =
450 at_coroutine_exit(schedule, std::move(__sched));
451 // Insert the cleanup action into the head of the continuation
452 // chain by making direct calls to the cleanup task's awaiter
453 // member functions. See type _cleanup_task in
454 // at_coroutine_exit.hpp:
455 __cleanup_task.await_suspend(
456 __coro::coroutine_handle<__promise>::from_promise(*this));
457 (void)__cleanup_task.await_resume();
458 }
459 __context_->set_scheduler(__box.__sched_);
460 return as_awaitable(schedule(__box.__sched_), *this);
461 }
462
463 template <class _Awaitable>
await_transformexec::__task::basic_task::__promise464 auto await_transform(_Awaitable&& __awaitable) noexcept
465 -> decltype(auto)
466 {
467 return with_awaitable_senders<__promise>::await_transform(
468 static_cast<_Awaitable&&>(__awaitable));
469 }
470
get_envexec::__task::basic_task::__promise471 auto get_env() const noexcept -> const __promise_context_t&
472 {
473 return *__context_;
474 }
475
476 __optional<__promise_context_t> __context_{};
477 bool __rescheduled_{false};
478 };
479
480 template <class _ParentPromise = void>
481 struct __task_awaitable
482 {
483 __coro::coroutine_handle<__promise> __coro_;
484 __optional<awaiter_context_t<__promise, _ParentPromise>> __context_{};
485
~__task_awaitableexec::__task::basic_task::__task_awaitable486 ~__task_awaitable()
487 {
488 if (__coro_)
489 __coro_.destroy();
490 }
491
await_readyexec::__task::basic_task::__task_awaitable492 static constexpr auto await_ready() noexcept -> bool
493 {
494 return false;
495 }
496
497 template <class _ParentPromise2>
await_suspendexec::__task::basic_task::__task_awaitable498 auto await_suspend(
499 __coro::coroutine_handle<_ParentPromise2> __parent) noexcept
500 -> __coro::coroutine_handle<>
501 {
502 static_assert(__one_of<_ParentPromise, _ParentPromise2, void>);
503 __coro_.promise().__context_.emplace(__parent_promise_t(),
504 __parent.promise());
505 __context_.emplace(*__coro_.promise().__context_,
506 __parent.promise());
507 __coro_.promise().set_continuation(__parent);
508 if constexpr (requires {
509 __coro_.promise().stop_requested() ? 0 : 1;
510 })
511 {
512 if (__coro_.promise().stop_requested())
513 return __parent.promise().unhandled_stopped();
514 }
515 return __coro_;
516 }
517
await_resumeexec::__task::basic_task::__task_awaitable518 auto await_resume() -> _Ty
519 {
520 __context_.reset();
521 scope_guard __on_exit{[this]() noexcept {
522 std::exchange(__coro_, {}).destroy();
523 }};
524 if (__coro_.promise().__data_.index() == 1)
525 std::rethrow_exception(
526 std::move(__coro_.promise().__data_.template get<1>()));
527 if constexpr (!std::is_void_v<_Ty>)
528 return std::move(__coro_.promise().__data_.template get<0>());
529 }
530 };
531
532 public:
533 // Make this task awaitable within a particular context:
534 template <class _ParentPromise>
535 requires constructible_from<
536 awaiter_context_t<__promise, _ParentPromise>, __promise_context_t&,
537 _ParentPromise&>
STDEXEC_MEMFN_DECL(auto as_awaitable)538 STDEXEC_MEMFN_DECL(auto as_awaitable)(this basic_task&& __self,
539 _ParentPromise&) noexcept
540 -> __task_awaitable<_ParentPromise>
541 {
542 return __task_awaitable<_ParentPromise>{
543 std::exchange(__self.__coro_, {})};
544 }
545
546 // Make this task generally awaitable:
operator co_await()547 auto operator co_await() && noexcept -> __task_awaitable<>
548 requires __mvalid<awaiter_context_t, __promise>
549 {
550 return __task_awaitable<>{std::exchange(__coro_, {})};
551 }
552
553 // From the list of types [_Ty], remove any types that are void, and send
554 // the resulting list to __qf<set_value_t>, which uses the list of types
555 // as arguments of a function type. In other words, set_value_t() if _Ty
556 // is void, and set_value_t(_Ty) otherwise.
557 using __set_value_sig_t =
558 __minvoke<__mremove<void, __qf<set_value_t>>, _Ty>;
559
560 // Specify basic_task's completion signatures
561 // This is only necessary when basic_task is not generally awaitable
562 // owing to constraints imposed by its _Context parameter.
563 using __task_traits_t = //
564 completion_signatures<__set_value_sig_t,
565 set_error_t(std::exception_ptr), set_stopped_t()>;
566
get_completion_signatures(__ignore={}) const567 auto get_completion_signatures(__ignore = {}) const -> __task_traits_t
568 {
569 return {};
570 }
571
basic_task(__coro::coroutine_handle<promise_type> __coro)572 explicit basic_task(__coro::coroutine_handle<promise_type> __coro) noexcept
573 : __coro_(__coro)
574 {}
575
576 __coro::coroutine_handle<promise_type> __coro_;
577 };
578 } // namespace __task
579
580 using task_disposition = __task::disposition;
581
582 template <class _Ty>
583 using default_task_context = __task::default_task_context<_Ty>;
584
585 template <class _Promise, class _ParentPromise = void>
586 using awaiter_context_t = __task::awaiter_context_t<_Promise, _ParentPromise>;
587
588 template <class _Ty, class _Context = default_task_context<_Ty>>
589 using basic_task = __task::basic_task<_Ty, _Context>;
590
591 template <class _Ty>
592 using task = basic_task<_Ty, default_task_context<_Ty>>;
593
594 inline constexpr __task::__reschedule_coroutine_on reschedule_coroutine_on{};
595 } // namespace exec
596
597 namespace stdexec
598 {
599 template <class _Ty, class _Context>
600 inline constexpr bool enable_sender<exec::basic_task<_Ty, _Context>> = true;
601 } // namespace stdexec
602
603 STDEXEC_PRAGMA_POP()
604