Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:54:53

0001 ///////////////////////////////////////////////////////////////////////////////
0002 // Copyright (c) Lewis Baker
0003 // Licenced under MIT license. See LICENSE.txt for details.
0004 ///////////////////////////////////////////////////////////////////////////////
0005 #ifndef CPPCORO_TASK_HPP_INCLUDED
0006 #define CPPCORO_TASK_HPP_INCLUDED
0007 
0008 #include <cppcoro/config.hpp>
0009 #include <cppcoro/awaitable_traits.hpp>
0010 #include <cppcoro/broken_promise.hpp>
0011 
0012 #include <cppcoro/detail/remove_rvalue_reference.hpp>
0013 
0014 #include <atomic>
0015 #include <exception>
0016 #include <utility>
0017 #include <type_traits>
0018 #include <cstdint>
0019 #include <cassert>
0020 #include <new>
0021 
0022 #include <cppcoro/coroutine.hpp>
0023 
0024 namespace cppcoro
0025 {
    // Forward declaration: the promise types below name task<T> as the return
    // type of get_return_object() before the task class template is defined.
    template<typename T> class task;
0027 
0028     namespace detail
0029     {
0030         class task_promise_base
0031         {
0032             friend struct final_awaitable;
0033 
0034             struct final_awaitable
0035             {
0036                 bool await_ready() const noexcept { return false; }
0037 
0038 #if CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER
0039                 template<typename PROMISE>
0040                 cppcoro::coroutine_handle<> await_suspend(
0041                     cppcoro::coroutine_handle<PROMISE> coro) noexcept
0042                 {
0043                     return coro.promise().m_continuation;
0044                 }
0045 #else
0046                 // HACK: Need to add CPPCORO_NOINLINE to await_suspend() method
0047                 // to avoid MSVC 2017.8 from spilling some local variables in
0048                 // await_suspend() onto the coroutine frame in some cases.
0049                 // Without this, some tests in async_auto_reset_event_tests.cpp
0050                 // were crashing under x86 optimised builds.
0051                 template<typename PROMISE>
0052                 CPPCORO_NOINLINE
0053                 void await_suspend(cppcoro::coroutine_handle<PROMISE> coroutine) noexcept
0054                 {
0055                     task_promise_base& promise = coroutine.promise();
0056 
0057                     // Use 'release' memory semantics in case we finish before the
0058                     // awaiter can suspend so that the awaiting thread sees our
0059                     // writes to the resulting value.
0060                     // Use 'acquire' memory semantics in case the caller registered
0061                     // the continuation before we finished. Ensure we see their write
0062                     // to m_continuation.
0063                     if (promise.m_state.exchange(true, std::memory_order_acq_rel))
0064                     {
0065                         promise.m_continuation.resume();
0066                     }
0067                 }
0068 #endif
0069 
0070                 void await_resume() noexcept {}
0071             };
0072 
0073         public:
0074 
0075             task_promise_base() noexcept
0076 #if !CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER
0077                 : m_state(false)
0078 #endif
0079             {}
0080 
0081             auto initial_suspend() noexcept
0082             {
0083                 return cppcoro::suspend_always{};
0084             }
0085 
0086             auto final_suspend() noexcept
0087             {
0088                 return final_awaitable{};
0089             }
0090 
0091 #if CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER
0092             void set_continuation(cppcoro::coroutine_handle<> continuation) noexcept
0093             {
0094                 m_continuation = continuation;
0095             }
0096 #else
0097             bool try_set_continuation(cppcoro::coroutine_handle<> continuation)
0098             {
0099                 m_continuation = continuation;
0100                 return !m_state.exchange(true, std::memory_order_acq_rel);
0101             }
0102 #endif
0103 
0104         private:
0105 
0106             cppcoro::coroutine_handle<> m_continuation;
0107 
0108 #if !CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER
0109             // Initially false. Set to true when either a continuation is registered
0110             // or when the coroutine has run to completion. Whichever operation
0111             // successfully transitions from false->true got there first.
0112             std::atomic<bool> m_state;
0113 #endif
0114 
0115         };
0116 
0117         template<typename T>
0118         class task_promise final : public task_promise_base
0119         {
0120         public:
0121 
0122             task_promise() noexcept {}
0123 
0124             ~task_promise()
0125             {
0126                 switch (m_resultType)
0127                 {
0128                 case result_type::value:
0129                     m_value.~T();
0130                     break;
0131                 case result_type::exception:
0132                     m_exception.~exception_ptr();
0133                     break;
0134                 default:
0135                     break;
0136                 }
0137             }
0138 
0139             task<T> get_return_object() noexcept;
0140 
0141             void unhandled_exception() noexcept
0142             {
0143                 ::new (static_cast<void*>(std::addressof(m_exception))) std::exception_ptr(
0144                     std::current_exception());
0145                 m_resultType = result_type::exception;
0146             }
0147 
0148             template<
0149                 typename VALUE,
0150                 typename = std::enable_if_t<std::is_convertible_v<VALUE&&, T>>>
0151             void return_value(VALUE&& value)
0152                 noexcept(std::is_nothrow_constructible_v<T, VALUE&&>)
0153             {
0154                 ::new (static_cast<void*>(std::addressof(m_value))) T(std::forward<VALUE>(value));
0155                 m_resultType = result_type::value;
0156             }
0157 
0158             T& result() &
0159             {
0160                 if (m_resultType == result_type::exception)
0161                 {
0162                     std::rethrow_exception(m_exception);
0163                 }
0164 
0165                 assert(m_resultType == result_type::value);
0166 
0167                 return m_value;
0168             }
0169 
0170             // HACK: Need to have co_await of task<int> return prvalue rather than
0171             // rvalue-reference to work around an issue with MSVC where returning
0172             // rvalue reference of a fundamental type from await_resume() will
0173             // cause the value to be copied to a temporary. This breaks the
0174             // sync_wait() implementation.
0175             // See https://github.com/lewissbaker/cppcoro/issues/40#issuecomment-326864107
0176             using rvalue_type = std::conditional_t<
0177                 std::is_arithmetic_v<T> || std::is_pointer_v<T>,
0178                 T,
0179                 T&&>;
0180 
0181             rvalue_type result() &&
0182             {
0183                 if (m_resultType == result_type::exception)
0184                 {
0185                     std::rethrow_exception(m_exception);
0186                 }
0187 
0188                 assert(m_resultType == result_type::value);
0189 
0190                 return std::move(m_value);
0191             }
0192 
0193         private:
0194 
0195             enum class result_type { empty, value, exception };
0196 
0197             result_type m_resultType = result_type::empty;
0198 
0199             union
0200             {
0201                 T m_value;
0202                 std::exception_ptr m_exception;
0203             };
0204 
0205         };
0206 
0207         template<>
0208         class task_promise<void> : public task_promise_base
0209         {
0210         public:
0211 
0212             task_promise() noexcept = default;
0213 
0214             task<void> get_return_object() noexcept;
0215 
0216             void return_void() noexcept
0217             {}
0218 
0219             void unhandled_exception() noexcept
0220             {
0221                 m_exception = std::current_exception();
0222             }
0223 
0224             void result()
0225             {
0226                 if (m_exception)
0227                 {
0228                     std::rethrow_exception(m_exception);
0229                 }
0230             }
0231 
0232         private:
0233 
0234             std::exception_ptr m_exception;
0235 
0236         };
0237 
0238         template<typename T>
0239         class task_promise<T&> : public task_promise_base
0240         {
0241         public:
0242 
0243             task_promise() noexcept = default;
0244 
0245             task<T&> get_return_object() noexcept;
0246 
0247             void unhandled_exception() noexcept
0248             {
0249                 m_exception = std::current_exception();
0250             }
0251 
0252             void return_value(T& value) noexcept
0253             {
0254                 m_value = std::addressof(value);
0255             }
0256 
0257             T& result()
0258             {
0259                 if (m_exception)
0260                 {
0261                     std::rethrow_exception(m_exception);
0262                 }
0263 
0264                 return *m_value;
0265             }
0266 
0267         private:
0268 
0269             T* m_value = nullptr;
0270             std::exception_ptr m_exception;
0271 
0272         };
0273     }
0274 
0275     /// \brief
0276     /// A task represents an operation that produces a result both lazily
0277     /// and asynchronously.
0278     ///
0279     /// When you call a coroutine that returns a task, the coroutine
0280     /// simply captures any passed parameters and returns exeuction to the
0281     /// caller. Execution of the coroutine body does not start until the
0282     /// coroutine is first co_await'ed.
0283     template<typename T = void>
0284     class [[nodiscard]] task
0285     {
0286     public:
0287 
0288         using promise_type = detail::task_promise<T>;
0289 
0290         using value_type = T;
0291 
0292     private:
0293 
0294         struct awaitable_base
0295         {
0296             cppcoro::coroutine_handle<promise_type> m_coroutine;
0297 
0298             awaitable_base(cppcoro::coroutine_handle<promise_type> coroutine) noexcept
0299                 : m_coroutine(coroutine)
0300             {}
0301 
0302             bool await_ready() const noexcept
0303             {
0304                 return !m_coroutine || m_coroutine.done();
0305             }
0306 
0307 #if CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER
0308             cppcoro::coroutine_handle<> await_suspend(
0309                 cppcoro::coroutine_handle<> awaitingCoroutine) noexcept
0310             {
0311                 m_coroutine.promise().set_continuation(awaitingCoroutine);
0312                 return m_coroutine;
0313             }
0314 #else
0315             bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept
0316             {
0317                 // NOTE: We are using the bool-returning version of await_suspend() here
0318                 // to work around a potential stack-overflow issue if a coroutine
0319                 // awaits many synchronously-completing tasks in a loop.
0320                 //
0321                 // We first start the task by calling resume() and then conditionally
0322                 // attach the continuation if it has not already completed. This allows us
0323                 // to immediately resume the awaiting coroutine without increasing
0324                 // the stack depth, avoiding the stack-overflow problem. However, it has
0325                 // the down-side of requiring a std::atomic to arbitrate the race between
0326                 // the coroutine potentially completing on another thread concurrently
0327                 // with registering the continuation on this thread.
0328                 //
0329                 // We can eliminate the use of the std::atomic once we have access to
0330                 // coroutine_handle-returning await_suspend() on both MSVC and Clang
0331                 // as this will provide ability to suspend the awaiting coroutine and
0332                 // resume another coroutine with a guaranteed tail-call to resume().
0333                 m_coroutine.resume();
0334                 return m_coroutine.promise().try_set_continuation(awaitingCoroutine);
0335             }
0336 #endif
0337         };
0338 
0339     public:
0340 
0341         task() noexcept
0342             : m_coroutine(nullptr)
0343         {}
0344 
0345         explicit task(cppcoro::coroutine_handle<promise_type> coroutine)
0346             : m_coroutine(coroutine)
0347         {}
0348 
0349         task(task&& t) noexcept
0350             : m_coroutine(t.m_coroutine)
0351         {
0352             t.m_coroutine = nullptr;
0353         }
0354 
0355         /// Disable copy construction/assignment.
0356         task(const task&) = delete;
0357         task& operator=(const task&) = delete;
0358 
0359         /// Frees resources used by this task.
0360         ~task()
0361         {
0362             if (m_coroutine)
0363             {
0364                 m_coroutine.destroy();
0365             }
0366         }
0367 
0368         task& operator=(task&& other) noexcept
0369         {
0370             if (std::addressof(other) != this)
0371             {
0372                 if (m_coroutine)
0373                 {
0374                     m_coroutine.destroy();
0375                 }
0376 
0377                 m_coroutine = other.m_coroutine;
0378                 other.m_coroutine = nullptr;
0379             }
0380 
0381             return *this;
0382         }
0383 
0384         /// \brief
0385         /// Query if the task result is complete.
0386         ///
0387         /// Awaiting a task that is ready is guaranteed not to block/suspend.
0388         bool is_ready() const noexcept
0389         {
0390             return !m_coroutine || m_coroutine.done();
0391         }
0392 
0393         auto operator co_await() const & noexcept
0394         {
0395             struct awaitable : awaitable_base
0396             {
0397                 using awaitable_base::awaitable_base;
0398 
0399                 decltype(auto) await_resume()
0400                 {
0401                     if (!this->m_coroutine)
0402                     {
0403                         throw broken_promise{};
0404                     }
0405 
0406                     return this->m_coroutine.promise().result();
0407                 }
0408             };
0409 
0410             return awaitable{ m_coroutine };
0411         }
0412 
0413         auto operator co_await() const && noexcept
0414         {
0415             struct awaitable : awaitable_base
0416             {
0417                 using awaitable_base::awaitable_base;
0418 
0419                 decltype(auto) await_resume()
0420                 {
0421                     if (!this->m_coroutine)
0422                     {
0423                         throw broken_promise{};
0424                     }
0425 
0426                     return std::move(this->m_coroutine.promise()).result();
0427                 }
0428             };
0429 
0430             return awaitable{ m_coroutine };
0431         }
0432 
0433         /// \brief
0434         /// Returns an awaitable that will await completion of the task without
0435         /// attempting to retrieve the result.
0436         auto when_ready() const noexcept
0437         {
0438             struct awaitable : awaitable_base
0439             {
0440                 using awaitable_base::awaitable_base;
0441 
0442                 void await_resume() const noexcept {}
0443             };
0444 
0445             return awaitable{ m_coroutine };
0446         }
0447 
0448     private:
0449 
0450         cppcoro::coroutine_handle<promise_type> m_coroutine;
0451 
0452     };
0453 
0454     namespace detail
0455     {
0456         template<typename T>
0457         task<T> task_promise<T>::get_return_object() noexcept
0458         {
0459             return task<T>{ cppcoro::coroutine_handle<task_promise>::from_promise(*this) };
0460         }
0461 
0462         inline task<void> task_promise<void>::get_return_object() noexcept
0463         {
0464             return task<void>{ cppcoro::coroutine_handle<task_promise>::from_promise(*this) };
0465         }
0466 
0467         template<typename T>
0468         task<T&> task_promise<T&>::get_return_object() noexcept
0469         {
0470             return task<T&>{ cppcoro::coroutine_handle<task_promise>::from_promise(*this) };
0471         }
0472     }
0473 
0474     template<typename AWAITABLE>
0475     auto make_task(AWAITABLE awaitable)
0476         -> task<detail::remove_rvalue_reference_t<typename awaitable_traits<AWAITABLE>::await_result_t>>
0477     {
0478         co_return co_await static_cast<AWAITABLE&&>(awaitable);
0479     }
0480 }
0481 
0482 #endif