xref: /openbmc/sdbusplus/include/sdbusplus/async/stdexec/task.hpp (revision 6269157344064457b7e1241d4efe59c6c51c7a59)
1 /*
2  * Copyright (c) 2021-2024 NVIDIA Corporation
3  *
4  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
5  * (the "License"); you may not use this file except in compliance with
6  * the License. You may obtain a copy of the License at
7  *
8  *   https://llvm.org/LICENSE.txt
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
#include "../stdexec/__detail/__meta.hpp"
#include "../stdexec/__detail/__optional.hpp"
#include "../stdexec/__detail/__variant.hpp"
#include "../stdexec/coroutine.hpp"
#include "../stdexec/execution.hpp"
#include "any_sender_of.hpp"
#include "at_coroutine_exit.hpp"
#include "inline_scheduler.hpp"
#include "scope.hpp"

#include <any>
#include <cassert>
#include <exception>
#include <memory>
#include <utility>
32 
33 STDEXEC_PRAGMA_PUSH()
34 STDEXEC_PRAGMA_IGNORE_GNU("-Wundefined-inline")
35 
36 namespace exec
37 {
38 namespace __task
39 {
40 using namespace stdexec;
41 
42 using __any_scheduler =                        //
43     any_receiver_ref<                          //
44         completion_signatures<set_error_t(std::exception_ptr),
45                               set_stopped_t()> //
46         >::any_sender<>::any_scheduler<>;
47 static_assert(scheduler<__any_scheduler>);
48 
49 template <class _Ty>
50 concept __stop_token_provider = //
51     requires(const _Ty& t) {    //
52         get_stop_token(t);
53     };
54 
55 template <class _Ty>
56 concept __indirect_stop_token_provider = //
57     requires(const _Ty& t) {
58         { get_env(t) } -> __stop_token_provider;
59     };
60 
61 template <class _Ty>
62 concept __indirect_scheduler_provider = //
63     requires(const _Ty& t) {
64         { get_env(t) } -> __scheduler_provider;
65     };
66 
67 template <class _ParentPromise>
__check_parent_promise_has_scheduler()68 constexpr bool __check_parent_promise_has_scheduler() noexcept
69 {
70     static_assert(__indirect_scheduler_provider<_ParentPromise>,
71                   "exec::task<T> cannot be co_await-ed in a coroutine that "
72                   "does not have an associated scheduler.");
73     return __indirect_scheduler_provider<_ParentPromise>;
74 }
75 
76 struct __forward_stop_request
77 {
78     inplace_stop_source& __stop_source_;
79 
operator ()exec::__task::__forward_stop_request80     void operator()() noexcept
81     {
82         __stop_source_.request_stop();
83     }
84 };
85 
// Forward declaration so that __default_task_context_impl can befriend the
// awaiter contexts; the primary template and its partial specializations are
// defined after that class.
template <class _ParentPromise>
struct __default_awaiter_context;
88 
////////////////////////////////////////////////////////////////////////////////
// Selects whether a default task context stores a scheduler. __sticky keeps an
// __any_scheduler that can be queried with get_scheduler; __none keeps nothing.
enum class __scheduler_affinity
{
    __none = 0,
    __sticky = 1
};
97 
// Tag type that selects the __default_task_context_impl constructor taking the
// awaiting (parent) coroutine's promise.
struct __parent_promise_t {};
100 
101 template <__scheduler_affinity _SchedulerAffinity =
102               __scheduler_affinity::__sticky>
103 class __default_task_context_impl
104 {
105     template <class _ParentPromise>
106     friend struct __default_awaiter_context;
107 
108     static constexpr bool __with_scheduler =
109         _SchedulerAffinity == __scheduler_affinity::__sticky;
110 
111     STDEXEC_ATTRIBUTE((no_unique_address))
112     __if_c<__with_scheduler, __any_scheduler, __ignore> //
113         __scheduler_{exec::inline_scheduler{}};
114     inplace_stop_token __stop_token_;
115 
116   public:
117     template <class _ParentPromise>
__default_task_context_impl(__parent_promise_t,_ParentPromise & __parent)118     explicit __default_task_context_impl(__parent_promise_t,
119                                          _ParentPromise& __parent) noexcept
120     {
121         if constexpr (_SchedulerAffinity == __scheduler_affinity::__sticky)
122         {
123             if constexpr (__check_parent_promise_has_scheduler<
124                               _ParentPromise>())
125             {
126                 __scheduler_ = get_scheduler(get_env(__parent));
127             }
128         }
129     }
130 
131     template <scheduler _Scheduler>
__default_task_context_impl(_Scheduler && __scheduler)132     explicit __default_task_context_impl(_Scheduler&& __scheduler) :
133         __scheduler_{static_cast<_Scheduler&&>(__scheduler)}
134     {}
135 
query(get_scheduler_t) const136     auto query(get_scheduler_t) const noexcept
137         -> const __any_scheduler& requires(__with_scheduler) {
138                                       return __scheduler_;
139                                   }
140 
query(get_stop_token_t) const141     auto query(get_stop_token_t) const noexcept -> inplace_stop_token
142     {
143         return __stop_token_;
144     }
145 
stop_requested() const146     [[nodiscard]] auto stop_requested() const noexcept -> bool
147     {
148         return __stop_token_.stop_requested();
149     }
150 
151     template <scheduler _Scheduler>
set_scheduler(_Scheduler && __sched)152     void set_scheduler(_Scheduler&& __sched)
153         requires(__with_scheduler)
154     {
155         __scheduler_ = static_cast<_Scheduler&&>(__sched);
156     }
157 
158     template <class _ThisPromise>
159     using promise_context_t = __default_task_context_impl;
160 
161     template <class _ThisPromise, class _ParentPromise = void>
162         requires(!__with_scheduler) ||
163                     __indirect_scheduler_provider<_ParentPromise>
164     using awaiter_context_t = __default_awaiter_context<_ParentPromise>;
165 };
166 
167 template <class _Ty>
168 using default_task_context =
169     __default_task_context_impl<__scheduler_affinity::__sticky>;
170 
171 template <class _Ty>
172 using __raw_task_context =
173     __default_task_context_impl<__scheduler_affinity::__none>;
174 
175 // This is the context associated with basic_task's awaiter. By default
176 // it does nothing.
177 template <class _ParentPromise>
178 struct __default_awaiter_context
179 {
180     template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context181     explicit __default_awaiter_context(
182         __default_task_context_impl<_Affinity>& __self,
183         _ParentPromise& __parent) noexcept
184     {}
185 };
186 
187 ////////////////////////////////////////////////////////////////////////////////
188 // This is the context to be associated with basic_task's awaiter when
189 // the parent coroutine's promise type is known, is a __stop_token_provider,
190 // and its stop token type is neither inplace_stop_token nor unstoppable.
191 template <__indirect_stop_token_provider _ParentPromise>
192 struct __default_awaiter_context<_ParentPromise>
193 {
194     using __stop_token_t = stop_token_of_t<env_of_t<_ParentPromise>>;
195     using __stop_callback_t =
196         typename __stop_token_t::template callback_type<__forward_stop_request>;
197 
198     template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context199     explicit __default_awaiter_context(
200         __default_task_context_impl<_Affinity>& __self,
201         _ParentPromise& __parent) noexcept
202         // Register a callback that will request stop on this basic_task's
203         // stop_source when stop is requested on the parent coroutine's stop
204         // token.
205         :
206         __stop_callback_{get_stop_token(get_env(__parent)),
207                          __forward_stop_request{__stop_source_}}
208     {
209         static_assert(
210             std::is_nothrow_constructible_v<__stop_callback_t, __stop_token_t,
211                                             __forward_stop_request>);
212         __self.__stop_token_ = __stop_source_.get_token();
213     }
214 
215     inplace_stop_source __stop_source_{};
216     __stop_callback_t __stop_callback_;
217 };
218 
219 // If the parent coroutine's type has a stop token of type inplace_stop_token,
220 // we don't need to register a stop callback.
221 template <__indirect_stop_token_provider _ParentPromise>
222     requires std::same_as<inplace_stop_token,
223                           stop_token_of_t<env_of_t<_ParentPromise>>>
224 struct __default_awaiter_context<_ParentPromise>
225 {
226     template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context227     explicit __default_awaiter_context(
228         __default_task_context_impl<_Affinity>& __self,
229         _ParentPromise& __parent) noexcept
230     {
231         __self.__stop_token_ = get_stop_token(get_env(__parent));
232     }
233 };
234 
235 // If the parent coroutine's stop token is unstoppable, there's no point
236 // forwarding stop tokens or stop requests at all.
237 template <__indirect_stop_token_provider _ParentPromise>
238     requires unstoppable_token<stop_token_of_t<env_of_t<_ParentPromise>>>
239 struct __default_awaiter_context<_ParentPromise>
240 {
241     template <__scheduler_affinity _Affinity>
__default_awaiter_contextexec::__task::__default_awaiter_context242     explicit __default_awaiter_context(__default_task_context_impl<_Affinity>&,
243                                        _ParentPromise&) noexcept
244     {}
245 };
246 
247 // Finally, if we don't know the parent coroutine's promise type, assume the
248 // worst and save a type-erased stop callback.
249 template <>
250 struct __default_awaiter_context<void>
251 {
252     template <__scheduler_affinity _Affinity, class _ParentPromise>
__default_awaiter_contextexec::__task::__default_awaiter_context253     explicit __default_awaiter_context(
254         __default_task_context_impl<_Affinity>& __self,
255         _ParentPromise& __parent) noexcept
256     {}
257 
258     template <__scheduler_affinity _Affinity,
259               __indirect_stop_token_provider _ParentPromise>
__default_awaiter_contextexec::__task::__default_awaiter_context260     explicit __default_awaiter_context(
261         __default_task_context_impl<_Affinity>& __self,
262         _ParentPromise& __parent)
263     {
264         // Register a callback that will request stop on this basic_task's
265         // stop_source when stop is requested on the parent coroutine's stop
266         // token.
267         using __stop_token_t = stop_token_of_t<env_of_t<_ParentPromise>>;
268         using __stop_callback_t =
269             stop_callback_for_t<__stop_token_t, __forward_stop_request>;
270 
271         if constexpr (std::same_as<__stop_token_t, inplace_stop_token>)
272         {
273             __self.__stop_token_ = get_stop_token(get_env(__parent));
274         }
275         else if (auto __token = get_stop_token(get_env(__parent));
276                  __token.stop_possible())
277         {
278             __stop_callback_.emplace<__stop_callback_t>(
279                 std::move(__token), __forward_stop_request{__stop_source_});
280             __self.__stop_token_ = __stop_source_.get_token();
281         }
282     }
283 
284     inplace_stop_source __stop_source_{};
285     std::any __stop_callback_{};
286 };
287 
288 template <class _Promise, class _ParentPromise = void>
289 using awaiter_context_t =                  //
290     typename __decay_t<env_of_t<_Promise>> //
291     ::template awaiter_context_t<_Promise, _ParentPromise>;
292 
293 ////////////////////////////////////////////////////////////////////////////////
294 // In a base class so it can be specialized when _Ty is void:
295 template <class _Ty>
296 struct __promise_base
297 {
return_valueexec::__task::__promise_base298     void return_value(_Ty value)
299     {
300         __data_.template emplace<0>(std::move(value));
301     }
302 
303     __variant_for<_Ty, std::exception_ptr> __data_{};
304 };
305 
306 template <>
307 struct __promise_base<void>
308 {
309     struct __void
310     {};
311 
return_voidexec::__task::__promise_base312     void return_void()
313     {
314         __data_.template emplace<0>(__void{});
315     }
316 
317     __variant_for<__void, std::exception_ptr> __data_{};
318 };
319 
// How a basic_task finished, as reported by its promise's disposition().
enum class disposition : unsigned
{
    stopped = 0,
    succeeded = 1,
    failed = 2,
};
326 
327 struct __reschedule_coroutine_on
328 {
329     template <class _Scheduler>
330     struct __wrap
331     {
332         _Scheduler __sched_;
333     };
334 
335     template <scheduler _Scheduler>
operator ()exec::__task::__reschedule_coroutine_on336     auto operator()(_Scheduler __sched) const noexcept -> __wrap<_Scheduler>
337     {
338         return {static_cast<_Scheduler&&>(__sched)};
339     }
340 };
341 
342 ////////////////////////////////////////////////////////////////////////////////
343 // basic_task
344 template <class _Ty, class _Context = default_task_context<_Ty>>
345 class [[nodiscard]] basic_task
346 {
347     struct __promise;
348 
349   public:
350     using __t = basic_task;
351     using __id = basic_task;
352     using promise_type = __promise;
353 
basic_task(basic_task && __that)354     basic_task(basic_task&& __that) noexcept :
355         __coro_(std::exchange(__that.__coro_, {}))
356     {}
357 
~basic_task()358     ~basic_task()
359     {
360         if (__coro_)
361             __coro_.destroy();
362     }
363 
364   private:
365     struct __final_awaitable
366     {
await_readyexec::__task::basic_task::__final_awaitable367         static constexpr auto await_ready() noexcept -> bool
368         {
369             return false;
370         }
371 
await_suspendexec::__task::basic_task::__final_awaitable372         static auto await_suspend(
373             __coro::coroutine_handle<__promise> __h) noexcept
374             -> __coro::coroutine_handle<>
375         {
376             return __h.promise().continuation().handle();
377         }
378 
await_resumeexec::__task::basic_task::__final_awaitable379         static void await_resume() noexcept {}
380     };
381 
382     using __promise_context_t =
383         typename _Context::template promise_context_t<__promise>;
384 
385     struct __promise : __promise_base<_Ty>, with_awaitable_senders<__promise>
386     {
387         using __t = __promise;
388         using __id = __promise;
389 
get_return_objectexec::__task::basic_task::__promise390         auto get_return_object() noexcept -> basic_task
391         {
392             return basic_task(
393                 __coro::coroutine_handle<__promise>::from_promise(*this));
394         }
395 
initial_suspendexec::__task::basic_task::__promise396         auto initial_suspend() noexcept -> __coro::suspend_always
397         {
398             return {};
399         }
400 
final_suspendexec::__task::basic_task::__promise401         auto final_suspend() noexcept -> __final_awaitable
402         {
403             return {};
404         }
405 
dispositionexec::__task::basic_task::__promise406         [[nodiscard]] auto disposition() const noexcept -> __task::disposition
407         {
408             switch (this->__data_.index())
409             {
410                 case 0:
411                     return __task::disposition::succeeded;
412                 case 1:
413                     return __task::disposition::failed;
414                 default:
415                     return __task::disposition::stopped;
416             }
417         }
418 
unhandled_exceptionexec::__task::basic_task::__promise419         void unhandled_exception() noexcept
420         {
421             this->__data_.template emplace<1>(std::current_exception());
422         }
423 
424         template <sender _Awaitable>
425             requires __scheduler_provider<_Context>
await_transformexec::__task::basic_task::__promise426         auto await_transform(_Awaitable&& __awaitable) noexcept
427             -> decltype(auto)
428         {
429             // TODO: If we have a complete-where-it-starts query then we can
430             // optimize this to avoid the reschedule
431             return as_awaitable(
432                 continues_on(static_cast<_Awaitable&&>(__awaitable),
433                              get_scheduler(*__context_)),
434                 *this);
435         }
436 
437         template <class _Scheduler>
438             requires __scheduler_provider<_Context>
await_transformexec::__task::basic_task::__promise439         auto await_transform(
440             __reschedule_coroutine_on::__wrap<_Scheduler> __box) noexcept
441             -> decltype(auto)
442         {
443             if (!std::exchange(__rescheduled_, true))
444             {
445                 // Create a cleanup action that transitions back onto the
446                 // current scheduler:
447                 auto __sched = get_scheduler(*__context_);
448                 auto __cleanup_task =
449                     at_coroutine_exit(schedule, std::move(__sched));
450                 // Insert the cleanup action into the head of the continuation
451                 // chain by making direct calls to the cleanup task's awaiter
452                 // member functions. See type _cleanup_task in
453                 // at_coroutine_exit.hpp:
454                 __cleanup_task.await_suspend(
455                     __coro::coroutine_handle<__promise>::from_promise(*this));
456                 (void)__cleanup_task.await_resume();
457             }
458             __context_->set_scheduler(__box.__sched_);
459             return as_awaitable(schedule(__box.__sched_), *this);
460         }
461 
462         template <class _Awaitable>
await_transformexec::__task::basic_task::__promise463         auto await_transform(_Awaitable&& __awaitable) noexcept
464             -> decltype(auto)
465         {
466             return with_awaitable_senders<__promise>::await_transform(
467                 static_cast<_Awaitable&&>(__awaitable));
468         }
469 
get_envexec::__task::basic_task::__promise470         auto get_env() const noexcept -> const __promise_context_t&
471         {
472             return *__context_;
473         }
474 
475         __optional<__promise_context_t> __context_{};
476         bool __rescheduled_{false};
477     };
478 
479     template <class _ParentPromise = void>
480     struct __task_awaitable
481     {
482         __coro::coroutine_handle<__promise> __coro_;
483         __optional<awaiter_context_t<__promise, _ParentPromise>> __context_{};
484 
~__task_awaitableexec::__task::basic_task::__task_awaitable485         ~__task_awaitable()
486         {
487             if (__coro_)
488                 __coro_.destroy();
489         }
490 
await_readyexec::__task::basic_task::__task_awaitable491         static constexpr auto await_ready() noexcept -> bool
492         {
493             return false;
494         }
495 
496         template <class _ParentPromise2>
await_suspendexec::__task::basic_task::__task_awaitable497         auto await_suspend(
498             __coro::coroutine_handle<_ParentPromise2> __parent) noexcept
499             -> __coro::coroutine_handle<>
500         {
501             static_assert(__one_of<_ParentPromise, _ParentPromise2, void>);
502             __coro_.promise().__context_.emplace(__parent_promise_t(),
503                                                  __parent.promise());
504             __context_.emplace(*__coro_.promise().__context_,
505                                __parent.promise());
506             __coro_.promise().set_continuation(__parent);
507             if constexpr (requires {
508                               __coro_.promise().stop_requested() ? 0 : 1;
509                           })
510             {
511                 if (__coro_.promise().stop_requested())
512                     return __parent.promise().unhandled_stopped();
513             }
514             return __coro_;
515         }
516 
await_resumeexec::__task::basic_task::__task_awaitable517         auto await_resume() -> _Ty
518         {
519             __context_.reset();
520             scope_guard __on_exit{[this]() noexcept {
521                 std::exchange(__coro_, {}).destroy();
522             }};
523             if (__coro_.promise().__data_.index() == 1)
524                 std::rethrow_exception(
525                     std::move(__coro_.promise().__data_.template get<1>()));
526             if constexpr (!std::is_void_v<_Ty>)
527                 return std::move(__coro_.promise().__data_.template get<0>());
528         }
529     };
530 
531   public:
532     // Make this task awaitable within a particular context:
533     template <class _ParentPromise>
534         requires constructible_from<
535             awaiter_context_t<__promise, _ParentPromise>, __promise_context_t&,
536             _ParentPromise&>
STDEXEC_MEMFN_DECL(auto as_awaitable)537     STDEXEC_MEMFN_DECL(auto as_awaitable)(this basic_task&& __self,
538                                           _ParentPromise&) noexcept
539         -> __task_awaitable<_ParentPromise>
540     {
541         return __task_awaitable<_ParentPromise>{
542             std::exchange(__self.__coro_, {})};
543     }
544 
545     // Make this task generally awaitable:
operator co_await()546     auto operator co_await() && noexcept -> __task_awaitable<>
547         requires __mvalid<awaiter_context_t, __promise>
548     {
549         return __task_awaitable<>{std::exchange(__coro_, {})};
550     }
551 
552     // From the list of types [_Ty], remove any types that are void, and send
553     //   the resulting list to __qf<set_value_t>, which uses the list of types
554     //   as arguments of a function type. In other words, set_value_t() if _Ty
555     //   is void, and set_value_t(_Ty) otherwise.
556     using __set_value_sig_t =
557         __minvoke<__mremove<void, __qf<set_value_t>>, _Ty>;
558 
559     // Specify basic_task's completion signatures
560     //   This is only necessary when basic_task is not generally awaitable
561     //   owing to constraints imposed by its _Context parameter.
562     using __task_traits_t = //
563         completion_signatures<__set_value_sig_t,
564                               set_error_t(std::exception_ptr), set_stopped_t()>;
565 
get_completion_signatures(__ignore={}) const566     auto get_completion_signatures(__ignore = {}) const -> __task_traits_t
567     {
568         return {};
569     }
570 
basic_task(__coro::coroutine_handle<promise_type> __coro)571     explicit basic_task(__coro::coroutine_handle<promise_type> __coro) noexcept
572         : __coro_(__coro)
573     {}
574 
575     __coro::coroutine_handle<promise_type> __coro_;
576 };
577 } // namespace __task
578 
579 using task_disposition = __task::disposition;
580 
581 template <class _Ty>
582 using default_task_context = __task::default_task_context<_Ty>;
583 
584 template <class _Promise, class _ParentPromise = void>
585 using awaiter_context_t = __task::awaiter_context_t<_Promise, _ParentPromise>;
586 
587 template <class _Ty, class _Context = default_task_context<_Ty>>
588 using basic_task = __task::basic_task<_Ty, _Context>;
589 
590 template <class _Ty>
591 using task = basic_task<_Ty, default_task_context<_Ty>>;
592 
593 inline constexpr __task::__reschedule_coroutine_on reschedule_coroutine_on{};
594 } // namespace exec
595 
596 namespace stdexec
597 {
598 template <class _Ty, class _Context>
599 inline constexpr bool enable_sender<exec::basic_task<_Ty, _Context>> = true;
600 } // namespace stdexec
601 
602 STDEXEC_PRAGMA_POP()
603