Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:54:52

0001 ///////////////////////////////////////////////////////////////////////////////
0002 // Copyright (c) Lewis Baker
0003 // Licenced under MIT license. See LICENSE.txt for details.
0004 ///////////////////////////////////////////////////////////////////////////////
0005 #ifndef CPPCORO_SHARED_LAZY_TASK_HPP_INCLUDED
0006 #define CPPCORO_SHARED_LAZY_TASK_HPP_INCLUDED
0007 
#include <cppcoro/config.hpp>
#include <cppcoro/awaitable_traits.hpp>
#include <cppcoro/broken_promise.hpp>
#include <cppcoro/task.hpp>

#include <cppcoro/detail/remove_rvalue_reference.hpp>

#include <atomic>
#include <cstdint>
#include <exception>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>

#include <cppcoro/coroutine.hpp>
0021 
0022 namespace cppcoro
0023 {
0024     template<typename T>
0025     class shared_task;
0026 
0027     namespace detail
0028     {
0029         struct shared_task_waiter
0030         {
0031             cppcoro::coroutine_handle<> m_continuation;
0032             shared_task_waiter* m_next;
0033         };
0034 
0035         class shared_task_promise_base
0036         {
0037             friend struct final_awaiter;
0038 
0039             struct final_awaiter
0040             {
0041                 bool await_ready() const noexcept { return false; }
0042 
0043                 template<typename PROMISE>
0044                 void await_suspend(cppcoro::coroutine_handle<PROMISE> h) noexcept
0045                 {
0046                     shared_task_promise_base& promise = h.promise();
0047 
0048                     // Exchange operation needs to be 'release' so that subsequent awaiters have
0049                     // visibility of the result. Also needs to be 'acquire' so we have visibility
0050                     // of writes to the waiters list.
0051                     void* const valueReadyValue = &promise;
0052                     void* waiters = promise.m_waiters.exchange(valueReadyValue, std::memory_order_acq_rel);
0053                     if (waiters != nullptr)
0054                     {
0055                         shared_task_waiter* waiter = static_cast<shared_task_waiter*>(waiters);
0056                         while (waiter->m_next != nullptr)
0057                         {
0058                             // Read the m_next pointer before resuming the coroutine
0059                             // since resuming the coroutine may destroy the shared_task_waiter value.
0060                             auto* next = waiter->m_next;
0061                             waiter->m_continuation.resume();
0062                             waiter = next;
0063                         }
0064 
0065                         // Resume last waiter in tail position to allow it to potentially
0066                         // be compiled as a tail-call.
0067                         waiter->m_continuation.resume();
0068                     }
0069                 }
0070 
0071                 void await_resume() noexcept {}
0072             };
0073 
0074         public:
0075 
0076             shared_task_promise_base() noexcept
0077                 : m_refCount(1)
0078                 , m_waiters(&this->m_waiters)               
0079                 , m_exception(nullptr)
0080             {}
0081 
0082             cppcoro::suspend_always initial_suspend() noexcept { return {}; }
0083             final_awaiter final_suspend() noexcept { return {}; }
0084 
0085             void unhandled_exception() noexcept
0086             {
0087                 m_exception = std::current_exception();
0088             }
0089 
0090             bool is_ready() const noexcept
0091             {
0092                 const void* const valueReadyValue = this;
0093                 return m_waiters.load(std::memory_order_acquire) == valueReadyValue;
0094             }
0095 
0096             void add_ref() noexcept
0097             {
0098                 m_refCount.fetch_add(1, std::memory_order_relaxed);
0099             }
0100 
0101             /// Decrement the reference count.
0102             ///
0103             /// \return
0104             /// true if successfully detached, false if this was the last
0105             /// reference to the coroutine, in which case the caller must
0106             /// call destroy() on the coroutine handle.
0107             bool try_detach() noexcept
0108             {
0109                 return m_refCount.fetch_sub(1, std::memory_order_acq_rel) != 1;
0110             }
0111 
0112             /// Try to enqueue a waiter to the list of waiters.
0113             ///
0114             /// \param waiter
0115             /// Pointer to the state from the waiter object.
0116             /// Must have waiter->m_coroutine member populated with the coroutine
0117             /// handle of the awaiting coroutine.
0118             ///
0119             /// \param coroutine
0120             /// Coroutine handle for this promise object.
0121             ///
0122             /// \return
0123             /// true if the waiter was successfully queued, in which case
0124             /// waiter->m_coroutine will be resumed when the task completes.
0125             /// false if the coroutine was already completed and the awaiting
0126             /// coroutine can continue without suspending.
0127             bool try_await(shared_task_waiter* waiter, cppcoro::coroutine_handle<> coroutine)
0128             {
0129                 void* const valueReadyValue = this;
0130                 void* const notStartedValue = &this->m_waiters;
0131                 constexpr void* startedNoWaitersValue = static_cast<shared_task_waiter*>(nullptr);
0132 
0133                 // NOTE: If the coroutine is not yet started then the first waiter
0134                 // will start the coroutine before enqueuing itself up to the list
0135                 // of suspended waiters waiting for completion. We split this into
0136                 // two steps to allow the first awaiter to return without suspending.
0137                 // This avoids recursively resuming the first waiter inside the call to
0138                 // coroutine.resume() in the case that the coroutine completes
0139                 // synchronously, which could otherwise lead to stack-overflow if
0140                 // the awaiting coroutine awaited many synchronously-completing
0141                 // tasks in a row.
0142 
0143                 // Start the coroutine if not already started.
0144                 void* oldWaiters = m_waiters.load(std::memory_order_acquire);
0145                 if (oldWaiters == notStartedValue &&
0146                     m_waiters.compare_exchange_strong(
0147                       oldWaiters,
0148                       startedNoWaitersValue,
0149                       std::memory_order_relaxed))
0150                 {
0151                     // Start the task executing.
0152                     coroutine.resume();
0153                     oldWaiters = m_waiters.load(std::memory_order_acquire);
0154                 }
0155 
0156                 // Enqueue the waiter into the list of waiting coroutines.
0157                 do
0158                 {
0159                     if (oldWaiters == valueReadyValue)
0160                     {
0161                         // Coroutine already completed, don't suspend.
0162                         return false;
0163                     }
0164 
0165                     waiter->m_next = static_cast<shared_task_waiter*>(oldWaiters);
0166                 } while (!m_waiters.compare_exchange_weak(
0167                     oldWaiters,
0168                     static_cast<void*>(waiter),
0169                     std::memory_order_release,
0170                     std::memory_order_acquire));
0171 
0172                 return true;
0173             }
0174 
0175         protected:
0176 
0177             bool completed_with_unhandled_exception()
0178             {
0179                 return m_exception != nullptr;
0180             }
0181 
0182             void rethrow_if_unhandled_exception()
0183             {
0184                 if (m_exception != nullptr)
0185                 {
0186                     std::rethrow_exception(m_exception);
0187                 }
0188             }
0189 
0190         private:
0191 
0192             std::atomic<std::uint32_t> m_refCount;
0193 
0194             // Value is either
0195             // - nullptr          - indicates started, no waiters
0196             // - this             - indicates value is ready
0197             // - &this->m_waiters - indicates coroutine not started
0198             // - other            - pointer to head item in linked-list of waiters.
0199             //                      values are of type 'cppcoro::shared_task_waiter'.
0200             //                      indicates that the coroutine has been started.
0201             std::atomic<void*> m_waiters;
0202 
0203             std::exception_ptr m_exception;
0204 
0205         };
0206 
0207         template<typename T>
0208         class shared_task_promise : public shared_task_promise_base
0209         {
0210         public:
0211 
0212             shared_task_promise() noexcept = default;
0213 
0214             ~shared_task_promise()
0215             {
0216                 if (this->is_ready() && !this->completed_with_unhandled_exception())
0217                 {
0218                     reinterpret_cast<T*>(&m_valueStorage)->~T();
0219                 }
0220             }
0221 
0222             shared_task<T> get_return_object() noexcept;
0223 
0224             template<
0225                 typename VALUE,
0226                 typename = std::enable_if_t<std::is_convertible_v<VALUE&&, T>>>
0227             void return_value(VALUE&& value)
0228                 noexcept(std::is_nothrow_constructible_v<T, VALUE&&>)
0229             {
0230                 new (&m_valueStorage) T(std::forward<VALUE>(value));
0231             }
0232 
0233             T& result()
0234             {
0235                 this->rethrow_if_unhandled_exception();
0236                 return *reinterpret_cast<T*>(&m_valueStorage);
0237             }
0238 
0239         private:
0240 
0241             // Not using std::aligned_storage here due to bug in MSVC 2015 Update 2
0242             // that means it doesn't work for types with alignof(T) > 8.
0243             // See MS-Connect bug #2658635.
0244             alignas(T) char m_valueStorage[sizeof(T)];
0245 
0246         };
0247 
0248         template<>
0249         class shared_task_promise<void> : public shared_task_promise_base
0250         {
0251         public:
0252 
0253             shared_task_promise() noexcept = default;
0254 
0255             shared_task<void> get_return_object() noexcept;
0256 
0257             void return_void() noexcept
0258             {}
0259 
0260             void result()
0261             {
0262                 this->rethrow_if_unhandled_exception();
0263             }
0264 
0265         };
0266 
0267         template<typename T>
0268         class shared_task_promise<T&> : public shared_task_promise_base
0269         {
0270         public:
0271 
0272             shared_task_promise() noexcept = default;
0273 
0274             shared_task<T&> get_return_object() noexcept;
0275 
0276             void return_value(T& value) noexcept
0277             {
0278                 m_value = std::addressof(value);
0279             }
0280 
0281             T& result()
0282             {
0283                 this->rethrow_if_unhandled_exception();
0284                 return *m_value;
0285             }
0286 
0287         private:
0288 
0289             T* m_value;
0290 
0291         };
0292     }
0293 
0294     template<typename T = void>
0295     class [[nodiscard]] shared_task
0296     {
0297     public:
0298 
0299         using promise_type = detail::shared_task_promise<T>;
0300 
0301         using value_type = T;
0302 
0303     private:
0304 
0305         struct awaitable_base
0306         {
0307             cppcoro::coroutine_handle<promise_type> m_coroutine;
0308             detail::shared_task_waiter m_waiter;
0309 
0310             awaitable_base(cppcoro::coroutine_handle<promise_type> coroutine) noexcept
0311                 : m_coroutine(coroutine)
0312             {}
0313 
0314             bool await_ready() const noexcept
0315             {
0316                 return !m_coroutine || m_coroutine.promise().is_ready();
0317             }
0318 
0319             bool await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept
0320             {
0321                 m_waiter.m_continuation = awaiter;
0322                 return m_coroutine.promise().try_await(&m_waiter, m_coroutine);
0323             }
0324         };
0325 
0326     public:
0327 
0328         shared_task() noexcept
0329             : m_coroutine(nullptr)
0330         {}
0331 
0332         explicit shared_task(cppcoro::coroutine_handle<promise_type> coroutine)
0333             : m_coroutine(coroutine)
0334         {
0335             // Don't increment the ref-count here since it has already been
0336             // initialised to 2 (one for shared_task and one for coroutine)
0337             // in the shared_task_promise constructor.
0338         }
0339 
0340         shared_task(shared_task&& other) noexcept
0341             : m_coroutine(other.m_coroutine)
0342         {
0343             other.m_coroutine = nullptr;
0344         }
0345 
0346         shared_task(const shared_task& other) noexcept
0347             : m_coroutine(other.m_coroutine)
0348         {
0349             if (m_coroutine)
0350             {
0351                 m_coroutine.promise().add_ref();
0352             }
0353         }
0354 
0355         ~shared_task()
0356         {
0357             destroy();
0358         }
0359 
0360         shared_task& operator=(shared_task&& other) noexcept
0361         {
0362             if (&other != this)
0363             {
0364                 destroy();
0365 
0366                 m_coroutine = other.m_coroutine;
0367                 other.m_coroutine = nullptr;
0368             }
0369 
0370             return *this;
0371         }
0372 
0373         shared_task& operator=(const shared_task& other) noexcept
0374         {
0375             if (m_coroutine != other.m_coroutine)
0376             {
0377                 destroy();
0378 
0379                 m_coroutine = other.m_coroutine;
0380 
0381                 if (m_coroutine)
0382                 {
0383                     m_coroutine.promise().add_ref();
0384                 }
0385             }
0386 
0387             return *this;
0388         }
0389 
0390         void swap(shared_task& other) noexcept
0391         {
0392             std::swap(m_coroutine, other.m_coroutine);
0393         }
0394 
0395         /// \brief
0396         /// Query if the task result is complete.
0397         ///
0398         /// Awaiting a task that is ready will not block.
0399         bool is_ready() const noexcept
0400         {
0401             return !m_coroutine || m_coroutine.promise().is_ready();
0402         }
0403 
0404         auto operator co_await() const noexcept
0405         {
0406             struct awaitable : awaitable_base
0407             {
0408                 using awaitable_base::awaitable_base;
0409 
0410                 decltype(auto) await_resume()
0411                 {
0412                     if (!this->m_coroutine)
0413                     {
0414                         throw broken_promise{};
0415                     }
0416 
0417                     return this->m_coroutine.promise().result();
0418                 }
0419             };
0420 
0421             return awaitable{ m_coroutine };
0422         }
0423 
0424         /// \brief
0425         /// Returns an awaitable that will await completion of the task without
0426         /// attempting to retrieve the result.
0427         auto when_ready() const noexcept
0428         {
0429             struct awaitable : awaitable_base
0430             {
0431                 using awaitable_base::awaitable_base;
0432 
0433                 void await_resume() const noexcept {}
0434             };
0435 
0436             return awaitable{ m_coroutine };
0437         }
0438 
0439     private:
0440 
0441         template<typename U>
0442         friend bool operator==(const shared_task<U>&, const shared_task<U>&) noexcept;
0443 
0444         void destroy() noexcept
0445         {
0446             if (m_coroutine)
0447             {
0448                 if (!m_coroutine.promise().try_detach())
0449                 {
0450                     m_coroutine.destroy();
0451                 }
0452             }
0453         }
0454 
0455         cppcoro::coroutine_handle<promise_type> m_coroutine;
0456 
0457     };
0458 
0459     template<typename T>
0460     bool operator==(const shared_task<T>& lhs, const shared_task<T>& rhs) noexcept
0461     {
0462         return lhs.m_coroutine == rhs.m_coroutine;
0463     }
0464 
0465     template<typename T>
0466     bool operator!=(const shared_task<T>& lhs, const shared_task<T>& rhs) noexcept
0467     {
0468         return !(lhs == rhs);
0469     }
0470 
0471     template<typename T>
0472     void swap(shared_task<T>& a, shared_task<T>& b) noexcept
0473     {
0474         a.swap(b);
0475     }
0476 
0477     namespace detail
0478     {
0479         template<typename T>
0480         shared_task<T> shared_task_promise<T>::get_return_object() noexcept
0481         {
0482             return shared_task<T>{
0483                 cppcoro::coroutine_handle<shared_task_promise>::from_promise(*this)
0484             };
0485         }
0486 
0487         template<typename T>
0488         shared_task<T&> shared_task_promise<T&>::get_return_object() noexcept
0489         {
0490             return shared_task<T&>{
0491                 cppcoro::coroutine_handle<shared_task_promise>::from_promise(*this)
0492             };
0493         }
0494 
0495         inline shared_task<void> shared_task_promise<void>::get_return_object() noexcept
0496         {
0497             return shared_task<void>{
0498                 cppcoro::coroutine_handle<shared_task_promise>::from_promise(*this)
0499             };
0500         }
0501     }
0502 
0503     template<typename AWAITABLE>
0504     auto make_shared_task(AWAITABLE awaitable)
0505         -> shared_task<detail::remove_rvalue_reference_t<typename awaitable_traits<AWAITABLE>::await_result_t>>
0506     {
0507         co_return co_await static_cast<AWAITABLE&&>(awaitable);
0508     }
0509 }
0510 
0511 #endif