xref: /openbmc/sdbusplus/include/sdbusplus/async/stdexec/task.hpp (revision a23d26bd8cd0cc40ff6311e86e0de2c7f43f55ea)
1 /*
2  * Copyright (c) 2021-2024 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "../stdexec/__detail/__meta.hpp"
19 #include "../stdexec/__detail/__optional.hpp"
20 #include "../stdexec/__detail/__variant.hpp"
21 #include "../stdexec/coroutine.hpp"
22 #include "../stdexec/execution.hpp"
23 #include "any_sender_of.hpp"
24 #include "at_coroutine_exit.hpp"
25 #include "inline_scheduler.hpp"
26 #include "scope.hpp"
27 
28 #include <any>
29 #include <cassert>
30 #include <exception>
31 #include <utility>
32 
33 STDEXEC_PRAGMA_PUSH()
34 STDEXEC_PRAGMA_IGNORE_GNU("-Wundefined-inline")
35 
36 namespace exec
37 {
38 namespace __task
39 {
40 using namespace stdexec;
41 
42 using __any_scheduler =                        //
43     any_receiver_ref<                          //
44         completion_signatures<set_error_t(std::exception_ptr),
45                               set_stopped_t()> //
46         >::any_sender<>::any_scheduler<>;
47 static_assert(scheduler<__any_scheduler>);
48 
49 template <class _Ty>
50 concept __stop_token_provider = //
51     requires(const _Ty& t) {    //
52         get_stop_token(t);
53     };
54 
55 template <class _Ty>
56 concept __indirect_stop_token_provider = //
57     requires(const _Ty& t) {
58         { get_env(t) } -> __stop_token_provider;
59     };
60 
61 template <class _Ty>
62 concept __indirect_scheduler_provider = //
63     requires(const _Ty& t) {
64         { get_env(t) } -> __scheduler_provider;
65     };
66 
67 template <class _ParentPromise>
__check_parent_promise_has_scheduler()68 constexpr bool __check_parent_promise_has_scheduler() noexcept
69 {
70     static_assert(__indirect_scheduler_provider<_ParentPromise>,
71                   "exec::task<T> cannot be co_await-ed in a coroutine that "
72                   "does not have an associated scheduler.");
73     return __indirect_scheduler_provider<_ParentPromise>;
74 }
75 
76 struct __forward_stop_request
77 {
78     inplace_stop_source& __stop_source_;
79 
operator ()exec::__task::__forward_stop_request80     void operator()() noexcept
81     {
82         __stop_source_.request_stop();
83     }
84 };
85 
// Forward declaration: __default_task_context_impl befriends this template
// and names it in its awaiter_context_t alias before it is defined.
template <class _ParentPromise>
struct __default_awaiter_context;
88 
////////////////////////////////////////////////////////////////////////////////
// __default_task_context_impl, defined below, is the context that is
// associated with basic_task's promise type by default. Together with
// __default_awaiter_context it handles forwarding of stop requests from
// parent to child.
//
// __scheduler_affinity selects whether that context carries a scheduler:
//   __none   - no scheduler is stored (the member collapses to __ignore).
//   __sticky - an __any_scheduler is stored and answered to get_scheduler.
enum class __scheduler_affinity
{
    __none,
    __sticky
};
97 
// Empty tag type. Passing it as the first constructor argument selects the
// __default_task_context_impl constructor that takes the parent promise.
struct __parent_promise_t
{};
100 
101 template <__scheduler_affinity _SchedulerAffinity =
102               __scheduler_affinity::__sticky>
103 class __default_task_context_impl
104 {
105     template <class _ParentPromise>
106     friend struct __default_awaiter_context;
107 
108     static constexpr bool __with_scheduler =
109         _SchedulerAffinity == __scheduler_affinity::__sticky;
110 
111     STDEXEC_ATTRIBUTE((no_unique_address))
112     __if_c<__with_scheduler, __any_scheduler, __ignore> //
113         __scheduler_{exec::inline_scheduler{}};
114     inplace_stop_token __stop_token_;
115 
116   public:
117     template <class _ParentPromise>
__default_task_context_impl(__parent_promise_t,_ParentPromise & __parent)118     explicit __default_task_context_impl(__parent_promise_t,
119                                          _ParentPromise& __parent) noexcept
120     {
121         if constexpr (_SchedulerAffinity == __scheduler_affinity::__sticky)
122         {
123             if constexpr (__check_parent_promise_has_scheduler<
124                               _ParentPromise>())
125             {
126                 __scheduler_ = get_scheduler(get_env(__parent));
127             }
128         }
129     }
130 
131     template <scheduler _Scheduler>
__default_task_context_impl(_Scheduler && __scheduler)132     explicit __default_task_context_impl(_Scheduler&& __scheduler) :
133         __scheduler_{static_cast<_Scheduler&&>(__scheduler)}
134     {}
135 
query(get_scheduler_t) const136     auto query(get_scheduler_t) const noexcept -> const __any_scheduler&
137         requires(__with_scheduler)
138     {
139         return __scheduler_;
140     }
141 
query(get_stop_token_t) const142     auto query(get_stop_token_t) const noexcept -> inplace_stop_token
143     {
144         return __stop_token_;
145     }
146 
stop_requested() const147     [[nodiscard]] auto stop_requested() const noexcept -> bool
148     {
149         return __stop_token_.stop_requested();
150     }
151 
152     template <scheduler _Scheduler>
set_scheduler(_Scheduler && __sched)153     void set_scheduler(_Scheduler&& __sched)
154         requires(__with_scheduler)
155     {
156         __scheduler_ = static_cast<_Scheduler&&>(__sched);
157     }
158 
159     template <class _ThisPromise>
160     using promise_context_t = __default_task_context_impl;
161 
162     template <class _ThisPromise, class _ParentPromise = void>
163         requires(!__with_scheduler) ||
164                     __indirect_scheduler_provider<_ParentPromise>
165     using awaiter_context_t = __default_awaiter_context<_ParentPromise>;
166 };
167 
168 template <class _Ty>
169 using default_task_context =
170     __default_task_context_impl<__scheduler_affinity::__sticky>;
171 
172 template <class _Ty>
173 using __raw_task_context =
174     __default_task_context_impl<__scheduler_affinity::__none>;
175 
176 // This is the context associated with basic_task's awaiter. By default
177 // it does nothing.
178 template <class _ParentPromise>
179 struct __default_awaiter_context
180 {
181     template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context182     explicit __default_awaiter_context(
183         __default_task_context_impl<_Affinity>& __self,
184         _ParentPromise& __parent) noexcept
185     {}
186 };
187 
188 ////////////////////////////////////////////////////////////////////////////////
189 // This is the context to be associated with basic_task's awaiter when
190 // the parent coroutine's promise type is known, is a __stop_token_provider,
191 // and its stop token type is neither inplace_stop_token nor unstoppable.
192 template <__indirect_stop_token_provider _ParentPromise>
193 struct __default_awaiter_context<_ParentPromise>
194 {
195     using __stop_token_t = stop_token_of_t<env_of_t<_ParentPromise>>;
196     using __stop_callback_t =
197         typename __stop_token_t::template callback_type<__forward_stop_request>;
198 
199     template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context200     explicit __default_awaiter_context(
201         __default_task_context_impl<_Affinity>& __self,
202         _ParentPromise& __parent) noexcept
203         // Register a callback that will request stop on this basic_task's
204         // stop_source when stop is requested on the parent coroutine's stop
205         // token.
206         :
207         __stop_callback_{get_stop_token(get_env(__parent)),
208                          __forward_stop_request{__stop_source_}}
209     {
210         static_assert(
211             std::is_nothrow_constructible_v<__stop_callback_t, __stop_token_t,
212                                             __forward_stop_request>);
213         __self.__stop_token_ = __stop_source_.get_token();
214     }
215 
216     inplace_stop_source __stop_source_{};
217     __stop_callback_t __stop_callback_;
218 };
219 
220 // If the parent coroutine's type has a stop token of type inplace_stop_token,
221 // we don't need to register a stop callback.
222 template <__indirect_stop_token_provider _ParentPromise>
223     requires std::same_as<inplace_stop_token,
224                           stop_token_of_t<env_of_t<_ParentPromise>>>
225 struct __default_awaiter_context<_ParentPromise>
226 {
227     template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context228     explicit __default_awaiter_context(
229         __default_task_context_impl<_Affinity>& __self,
230         _ParentPromise& __parent) noexcept
231     {
232         __self.__stop_token_ = get_stop_token(get_env(__parent));
233     }
234 };
235 
236 // If the parent coroutine's stop token is unstoppable, there's no point
237 // forwarding stop tokens or stop requests at all.
238 template <__indirect_stop_token_provider _ParentPromise>
239     requires unstoppable_token<stop_token_of_t<env_of_t<_ParentPromise>>>
240 struct __default_awaiter_context<_ParentPromise>
241 {
242     template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context243     explicit __default_awaiter_context(__default_task_context_impl<_Affinity>&,
244                                        _ParentPromise&) noexcept
245     {}
246 };
247 
248 // Finally, if we don't know the parent coroutine's promise type, assume the
249 // worst and save a type-erased stop callback.
250 template <>
251 struct __default_awaiter_context<void>
252 {
253     template <__scheduler_affinity _Affinity, class _ParentPromise>
__default_awaiter_contextexec::__task::__default_awaiter_context254     explicit __default_awaiter_context(
255         __default_task_context_impl<_Affinity>& __self,
256         _ParentPromise& __parent) noexcept
257     {}
258 
259     template <__scheduler_affinity _Affinity,
260               __indirect_stop_token_provider _ParentPromise>
__default_awaiter_contextexec::__task::__default_awaiter_context261     explicit __default_awaiter_context(
262         __default_task_context_impl<_Affinity>& __self,
263         _ParentPromise& __parent)
264     {
265         // Register a callback that will request stop on this basic_task's
266         // stop_source when stop is requested on the parent coroutine's stop
267         // token.
268         using __stop_token_t = stop_token_of_t<env_of_t<_ParentPromise>>;
269         using __stop_callback_t =
270             stop_callback_for_t<__stop_token_t, __forward_stop_request>;
271 
272         if constexpr (std::same_as<__stop_token_t, inplace_stop_token>)
273         {
274             __self.__stop_token_ = get_stop_token(get_env(__parent));
275         }
276         else if (auto __token = get_stop_token(get_env(__parent));
277                  __token.stop_possible())
278         {
279             __stop_callback_.emplace<__stop_callback_t>(
280                 std::move(__token), __forward_stop_request{__stop_source_});
281             __self.__stop_token_ = __stop_source_.get_token();
282         }
283     }
284 
285     inplace_stop_source __stop_source_{};
286     std::any __stop_callback_{};
287 };
288 
289 template <class _Promise, class _ParentPromise = void>
290 using awaiter_context_t =                  //
291     typename __decay_t<env_of_t<_Promise>> //
292     ::template awaiter_context_t<_Promise, _ParentPromise>;
293 
294 ////////////////////////////////////////////////////////////////////////////////
295 // In a base class so it can be specialized when _Ty is void:
296 template <class _Ty>
297 struct __promise_base
298 {
return_valueexec::__task::__promise_base299     void return_value(_Ty value)
300     {
301         __data_.template emplace<0>(std::move(value));
302     }
303 
304     __variant_for<_Ty, std::exception_ptr> __data_{};
305 };
306 
307 template <>
308 struct __promise_base<void>
309 {
310     struct __void
311     {};
312 
return_voidexec::__task::__promise_base313     void return_void()
314     {
315         __data_.template emplace<0>(__void{});
316     }
317 
318     __variant_for<__void, std::exception_ptr> __data_{};
319 };
320 
// How a task finished, as reported by the promise's disposition() member:
// succeeded when a result is stored, failed when an exception is stored, and
// stopped otherwise.
enum class disposition : unsigned
{
    stopped,
    succeeded,
    failed,
};
327 
328 struct __reschedule_coroutine_on
329 {
330     template <class _Scheduler>
331     struct __wrap
332     {
333         _Scheduler __sched_;
334     };
335 
336     template <scheduler _Scheduler>
operator ()exec::__task::__reschedule_coroutine_on337     auto operator()(_Scheduler __sched) const noexcept -> __wrap<_Scheduler>
338     {
339         return {static_cast<_Scheduler&&>(__sched)};
340     }
341 };
342 
343 ////////////////////////////////////////////////////////////////////////////////
344 // basic_task
345 template <class _Ty, class _Context = default_task_context<_Ty>>
346 class [[nodiscard]] basic_task
347 {
348     struct __promise;
349 
350   public:
351     using __t = basic_task;
352     using __id = basic_task;
353     using promise_type = __promise;
354 
basic_task(basic_task && __that)355     basic_task(basic_task&& __that) noexcept :
356         __coro_(std::exchange(__that.__coro_, {}))
357     {}
358 
~basic_task()359     ~basic_task()
360     {
361         if (__coro_)
362             __coro_.destroy();
363     }
364 
365   private:
366     struct __final_awaitable
367     {
await_readyexec::__task::basic_task::__final_awaitable368         static constexpr auto await_ready() noexcept -> bool
369         {
370             return false;
371         }
372 
await_suspendexec::__task::basic_task::__final_awaitable373         static auto await_suspend(
374             __coro::coroutine_handle<__promise> __h) noexcept
375             -> __coro::coroutine_handle<>
376         {
377             return __h.promise().continuation().handle();
378         }
379 
await_resumeexec::__task::basic_task::__final_awaitable380         static void await_resume() noexcept {}
381     };
382 
383     using __promise_context_t =
384         typename _Context::template promise_context_t<__promise>;
385 
386     struct __promise : __promise_base<_Ty>, with_awaitable_senders<__promise>
387     {
388         using __t = __promise;
389         using __id = __promise;
390 
get_return_objectexec::__task::basic_task::__promise391         auto get_return_object() noexcept -> basic_task
392         {
393             return basic_task(
394                 __coro::coroutine_handle<__promise>::from_promise(*this));
395         }
396 
initial_suspendexec::__task::basic_task::__promise397         auto initial_suspend() noexcept -> __coro::suspend_always
398         {
399             return {};
400         }
401 
final_suspendexec::__task::basic_task::__promise402         auto final_suspend() noexcept -> __final_awaitable
403         {
404             return {};
405         }
406 
dispositionexec::__task::basic_task::__promise407         [[nodiscard]] auto disposition() const noexcept -> __task::disposition
408         {
409             switch (this->__data_.index())
410             {
411                 case 0:
412                     return __task::disposition::succeeded;
413                 case 1:
414                     return __task::disposition::failed;
415                 default:
416                     return __task::disposition::stopped;
417             }
418         }
419 
unhandled_exceptionexec::__task::basic_task::__promise420         void unhandled_exception() noexcept
421         {
422             this->__data_.template emplace<1>(std::current_exception());
423         }
424 
425         template <sender _Awaitable>
426             requires __scheduler_provider<_Context>
await_transformexec::__task::basic_task::__promise427         auto await_transform(_Awaitable&& __awaitable) noexcept
428             -> decltype(auto)
429         {
430             // TODO: If we have a complete-where-it-starts query then we can
431             // optimize this to avoid the reschedule
432             return as_awaitable(
433                 continues_on(static_cast<_Awaitable&&>(__awaitable),
434                              get_scheduler(*__context_)),
435                 *this);
436         }
437 
438         template <class _Scheduler>
439             requires __scheduler_provider<_Context>
await_transformexec::__task::basic_task::__promise440         auto await_transform(
441             __reschedule_coroutine_on::__wrap<_Scheduler> __box) noexcept
442             -> decltype(auto)
443         {
444             if (!std::exchange(__rescheduled_, true))
445             {
446                 // Create a cleanup action that transitions back onto the
447                 // current scheduler:
448                 auto __sched = get_scheduler(*__context_);
449                 auto __cleanup_task =
450                     at_coroutine_exit(schedule, std::move(__sched));
451                 // Insert the cleanup action into the head of the continuation
452                 // chain by making direct calls to the cleanup task's awaiter
453                 // member functions. See type _cleanup_task in
454                 // at_coroutine_exit.hpp:
455                 __cleanup_task.await_suspend(
456                     __coro::coroutine_handle<__promise>::from_promise(*this));
457                 (void)__cleanup_task.await_resume();
458             }
459             __context_->set_scheduler(__box.__sched_);
460             return as_awaitable(schedule(__box.__sched_), *this);
461         }
462 
463         template <class _Awaitable>
await_transformexec::__task::basic_task::__promise464         auto await_transform(_Awaitable&& __awaitable) noexcept
465             -> decltype(auto)
466         {
467             return with_awaitable_senders<__promise>::await_transform(
468                 static_cast<_Awaitable&&>(__awaitable));
469         }
470 
get_envexec::__task::basic_task::__promise471         auto get_env() const noexcept -> const __promise_context_t&
472         {
473             return *__context_;
474         }
475 
476         __optional<__promise_context_t> __context_{};
477         bool __rescheduled_{false};
478     };
479 
480     template <class _ParentPromise = void>
481     struct __task_awaitable
482     {
483         __coro::coroutine_handle<__promise> __coro_;
484         __optional<awaiter_context_t<__promise, _ParentPromise>> __context_{};
485 
~__task_awaitableexec::__task::basic_task::__task_awaitable486         ~__task_awaitable()
487         {
488             if (__coro_)
489                 __coro_.destroy();
490         }
491 
await_readyexec::__task::basic_task::__task_awaitable492         static constexpr auto await_ready() noexcept -> bool
493         {
494             return false;
495         }
496 
497         template <class _ParentPromise2>
await_suspendexec::__task::basic_task::__task_awaitable498         auto await_suspend(
499             __coro::coroutine_handle<_ParentPromise2> __parent) noexcept
500             -> __coro::coroutine_handle<>
501         {
502             static_assert(__one_of<_ParentPromise, _ParentPromise2, void>);
503             __coro_.promise().__context_.emplace(__parent_promise_t(),
504                                                  __parent.promise());
505             __context_.emplace(*__coro_.promise().__context_,
506                                __parent.promise());
507             __coro_.promise().set_continuation(__parent);
508             if constexpr (requires {
509                               __coro_.promise().stop_requested() ? 0 : 1;
510                           })
511             {
512                 if (__coro_.promise().stop_requested())
513                     return __parent.promise().unhandled_stopped();
514             }
515             return __coro_;
516         }
517 
await_resumeexec::__task::basic_task::__task_awaitable518         auto await_resume() -> _Ty
519         {
520             __context_.reset();
521             scope_guard __on_exit{[this]() noexcept {
522                 std::exchange(__coro_, {}).destroy();
523             }};
524             if (__coro_.promise().__data_.index() == 1)
525                 std::rethrow_exception(
526                     std::move(__coro_.promise().__data_.template get<1>()));
527             if constexpr (!std::is_void_v<_Ty>)
528                 return std::move(__coro_.promise().__data_.template get<0>());
529         }
530     };
531 
532   public:
533     // Make this task awaitable within a particular context:
534     template <class _ParentPromise>
535         requires constructible_from<
536             awaiter_context_t<__promise, _ParentPromise>, __promise_context_t&,
537             _ParentPromise&>
STDEXEC_MEMFN_DECL(auto as_awaitable)538     STDEXEC_MEMFN_DECL(auto as_awaitable)(this basic_task&& __self,
539                                           _ParentPromise&) noexcept
540         -> __task_awaitable<_ParentPromise>
541     {
542         return __task_awaitable<_ParentPromise>{
543             std::exchange(__self.__coro_, {})};
544     }
545 
546     // Make this task generally awaitable:
operator co_await()547     auto operator co_await() && noexcept -> __task_awaitable<>
548         requires __mvalid<awaiter_context_t, __promise>
549     {
550         return __task_awaitable<>{std::exchange(__coro_, {})};
551     }
552 
553     // From the list of types [_Ty], remove any types that are void, and send
554     //   the resulting list to __qf<set_value_t>, which uses the list of types
555     //   as arguments of a function type. In other words, set_value_t() if _Ty
556     //   is void, and set_value_t(_Ty) otherwise.
557     using __set_value_sig_t =
558         __minvoke<__mremove<void, __qf<set_value_t>>, _Ty>;
559 
560     // Specify basic_task's completion signatures
561     //   This is only necessary when basic_task is not generally awaitable
562     //   owing to constraints imposed by its _Context parameter.
563     using __task_traits_t = //
564         completion_signatures<__set_value_sig_t,
565                               set_error_t(std::exception_ptr), set_stopped_t()>;
566 
get_completion_signatures(__ignore={}) const567     auto get_completion_signatures(__ignore = {}) const -> __task_traits_t
568     {
569         return {};
570     }
571 
basic_task(__coro::coroutine_handle<promise_type> __coro)572     explicit basic_task(__coro::coroutine_handle<promise_type> __coro) noexcept
573         : __coro_(__coro)
574     {}
575 
576     __coro::coroutine_handle<promise_type> __coro_;
577 };
578 } // namespace __task
579 
580 using task_disposition = __task::disposition;
581 
582 template <class _Ty>
583 using default_task_context = __task::default_task_context<_Ty>;
584 
585 template <class _Promise, class _ParentPromise = void>
586 using awaiter_context_t = __task::awaiter_context_t<_Promise, _ParentPromise>;
587 
588 template <class _Ty, class _Context = default_task_context<_Ty>>
589 using basic_task = __task::basic_task<_Ty, _Context>;
590 
591 template <class _Ty>
592 using task = basic_task<_Ty, default_task_context<_Ty>>;
593 
594 inline constexpr __task::__reschedule_coroutine_on reschedule_coroutine_on{};
595 } // namespace exec
596 
597 namespace stdexec
598 {
599 template <class _Ty, class _Context>
600 inline constexpr bool enable_sender<exec::basic_task<_Ty, _Context>> = true;
601 } // namespace stdexec
602 
603 STDEXEC_PRAGMA_POP()
604