Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:03:46

0001 ///////////////////////////////////////////////////////////////////////////////
0002 // Copyright (c) Lewis Baker
0003 // Licenced under MIT license. See LICENSE.txt for details.
0004 ///////////////////////////////////////////////////////////////////////////////
0005 #ifndef CPPCORO_DETAIL_WHEN_ALL_TASK_HPP_INCLUDED
0006 #define CPPCORO_DETAIL_WHEN_ALL_TASK_HPP_INCLUDED
0007 
0008 #include <cppcoro/awaitable_traits.hpp>
0009 
0010 #include <cppcoro/detail/when_all_counter.hpp>
0011 #include <cppcoro/detail/void_value.hpp>
0012 
0013 #include <cppcoro/coroutine.hpp>
0014 #include <cassert>
0015 #include <exception>
0016 
0017 namespace cppcoro
0018 {
0019     namespace detail
0020     {
        // Forward declaration of the awaitable that operates on a container
        // of when_all_task objects. when_all_task declares every
        // specialisation of this template a friend so that it can reach the
        // task's private start() member.
        template<typename TASK_CONTAINER>
        class when_all_ready_awaitable;

        // Forward declaration of the task wrapper defined further down in
        // this header; declared early so that code preceding its definition
        // can name it.
        template<typename RESULT>
        class when_all_task;
0026 
0027         template<typename RESULT>
0028         class when_all_task_promise final
0029         {
0030         public:
0031 
0032             using coroutine_handle_t = cppcoro::coroutine_handle<when_all_task_promise<RESULT>>;
0033 
0034             when_all_task_promise() noexcept
0035             {}
0036 
0037             auto get_return_object() noexcept
0038             {
0039                 return coroutine_handle_t::from_promise(*this);
0040             }
0041 
0042             cppcoro::suspend_always initial_suspend() noexcept
0043             {
0044                 return{};
0045             }
0046 
0047             auto final_suspend() noexcept
0048             {
0049                 class completion_notifier
0050                 {
0051                 public:
0052 
0053                     bool await_ready() const noexcept { return false; }
0054 
0055                     void await_suspend(coroutine_handle_t coro) const noexcept
0056                     {
0057                         coro.promise().m_counter->notify_awaitable_completed();
0058                     }
0059 
0060                     void await_resume() const noexcept {}
0061 
0062                 };
0063 
0064                 return completion_notifier{};
0065             }
0066 
0067             void unhandled_exception() noexcept
0068             {
0069                 m_exception = std::current_exception();
0070             }
0071 
0072             void return_void() noexcept
0073             {
0074                 // We should have either suspended at co_yield point or
0075                 // an exception was thrown before running off the end of
0076                 // the coroutine.
0077                 assert(false);
0078             }
0079 
0080 #if CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC < 19'20'00000
0081             // HACK: This is needed to work around a bug in MSVC 2017.7/2017.8.
0082             // See comment in make_when_all_task below.
0083             template<typename Awaitable>
0084             Awaitable&& await_transform(Awaitable&& awaitable)
0085             {
0086                 return static_cast<Awaitable&&>(awaitable);
0087             }
0088 
0089             struct get_promise_t {};
0090             static constexpr get_promise_t get_promise = {};
0091 
0092             auto await_transform(get_promise_t)
0093             {
0094                 class awaiter
0095                 {
0096                 public:
0097                     awaiter(when_all_task_promise* promise) noexcept : m_promise(promise) {}
0098                     bool await_ready() noexcept {
0099                         return true;
0100                     }
0101                     void await_suspend(cppcoro::coroutine_handle<>) noexcept {}
0102                     when_all_task_promise& await_resume() noexcept
0103                     {
0104                         return *m_promise;
0105                     }
0106                 private:
0107                     when_all_task_promise* m_promise;
0108                 };
0109                 return awaiter{ this };
0110             }
0111 #endif
0112 
0113 
0114             auto yield_value(RESULT&& result) noexcept
0115             {
0116                 m_result = std::addressof(result);
0117                 return final_suspend();
0118             }
0119 
0120             void start(when_all_counter& counter) noexcept
0121             {
0122                 m_counter = &counter;
0123                 coroutine_handle_t::from_promise(*this).resume();
0124             }
0125 
0126             RESULT& result() &
0127             {
0128                 rethrow_if_exception();
0129                 return *m_result;
0130             }
0131 
0132             RESULT&& result() &&
0133             {
0134                 rethrow_if_exception();
0135                 return std::forward<RESULT>(*m_result);
0136             }
0137 
0138         private:
0139 
0140             void rethrow_if_exception()
0141             {
0142                 if (m_exception)
0143                 {
0144                     std::rethrow_exception(m_exception);
0145                 }
0146             }
0147 
0148             when_all_counter* m_counter;
0149             std::exception_ptr m_exception;
0150             std::add_pointer_t<RESULT> m_result;
0151 
0152         };
0153 
0154         template<>
0155         class when_all_task_promise<void> final
0156         {
0157         public:
0158 
0159             using coroutine_handle_t = cppcoro::coroutine_handle<when_all_task_promise<void>>;
0160 
0161             when_all_task_promise() noexcept
0162             {}
0163 
0164             auto get_return_object() noexcept
0165             {
0166                 return coroutine_handle_t::from_promise(*this);
0167             }
0168 
0169             cppcoro::suspend_always initial_suspend() noexcept
0170             {
0171                 return{};
0172             }
0173 
0174             auto final_suspend() noexcept
0175             {
0176                 class completion_notifier
0177                 {
0178                 public:
0179 
0180                     bool await_ready() const noexcept { return false; }
0181 
0182                     void await_suspend(coroutine_handle_t coro) const noexcept
0183                     {
0184                         coro.promise().m_counter->notify_awaitable_completed();
0185                     }
0186 
0187                     void await_resume() const noexcept {}
0188 
0189                 };
0190 
0191                 return completion_notifier{};
0192             }
0193 
0194             void unhandled_exception() noexcept
0195             {
0196                 m_exception = std::current_exception();
0197             }
0198 
0199             void return_void() noexcept
0200             {
0201             }
0202 
0203             void start(when_all_counter& counter) noexcept
0204             {
0205                 m_counter = &counter;
0206                 coroutine_handle_t::from_promise(*this).resume();
0207             }
0208 
0209             void result()
0210             {
0211                 if (m_exception)
0212                 {
0213                     std::rethrow_exception(m_exception);
0214                 }
0215             }
0216 
0217         private:
0218 
0219             when_all_counter* m_counter;
0220             std::exception_ptr m_exception;
0221 
0222         };
0223 
0224         template<typename RESULT>
0225         class when_all_task final
0226         {
0227         public:
0228 
0229             using promise_type = when_all_task_promise<RESULT>;
0230 
0231             using coroutine_handle_t = typename promise_type::coroutine_handle_t;
0232 
0233             when_all_task(coroutine_handle_t coroutine) noexcept
0234                 : m_coroutine(coroutine)
0235             {}
0236 
0237             when_all_task(when_all_task&& other) noexcept
0238                 : m_coroutine(std::exchange(other.m_coroutine, coroutine_handle_t{}))
0239             {}
0240 
0241             ~when_all_task()
0242             {
0243                 if (m_coroutine) m_coroutine.destroy();
0244             }
0245 
0246             when_all_task(const when_all_task&) = delete;
0247             when_all_task& operator=(const when_all_task&) = delete;
0248 
0249             decltype(auto) result() &
0250             {
0251                 return m_coroutine.promise().result();
0252             }
0253 
0254             decltype(auto) result() &&
0255             {
0256                 return std::move(m_coroutine.promise()).result();
0257             }
0258 
0259             decltype(auto) non_void_result() &
0260             {
0261                 if constexpr (std::is_void_v<decltype(this->result())>)
0262                 {
0263                     this->result();
0264                     return void_value{};
0265                 }
0266                 else
0267                 {
0268                     return this->result();
0269                 }
0270             }
0271 
0272             decltype(auto) non_void_result() &&
0273             {
0274                 if constexpr (std::is_void_v<decltype(this->result())>)
0275                 {
0276                     std::move(*this).result();
0277                     return void_value{};
0278                 }
0279                 else
0280                 {
0281                     return std::move(*this).result();
0282                 }
0283             }
0284 
0285         private:
0286 
0287             template<typename TASK_CONTAINER>
0288             friend class when_all_ready_awaitable;
0289 
0290             void start(when_all_counter& counter) noexcept
0291             {
0292                 m_coroutine.promise().start(counter);
0293             }
0294 
0295             coroutine_handle_t m_coroutine;
0296 
0297         };
0298 
0299         template<
0300             typename AWAITABLE,
0301             typename RESULT = typename cppcoro::awaitable_traits<AWAITABLE&&>::await_result_t,
0302             std::enable_if_t<!std::is_void_v<RESULT>, int> = 0>
0303         when_all_task<RESULT> make_when_all_task(AWAITABLE awaitable)
0304         {
0305 #if CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC < 19'20'00000
0306             // HACK: Workaround another bug in MSVC where the expression 'co_yield co_await x' seems
0307             // to completely ignore the co_yield an never calls promise.yield_value().
0308             // The coroutine seems to be resuming the 'co_await' after the 'co_yield'
0309             // rather than before the 'co_yield'.
0310             // This bug is present in VS 2017.7 and VS 2017.8.
0311             auto& promise = co_await when_all_task_promise<RESULT>::get_promise;
0312             co_await promise.yield_value(co_await std::forward<AWAITABLE>(awaitable));
0313 #else
0314             co_yield co_await static_cast<AWAITABLE&&>(awaitable);
0315 #endif
0316         }
0317 
0318         template<
0319             typename AWAITABLE,
0320             typename RESULT = typename cppcoro::awaitable_traits<AWAITABLE&&>::await_result_t,
0321             std::enable_if_t<std::is_void_v<RESULT>, int> = 0>
0322         when_all_task<void> make_when_all_task(AWAITABLE awaitable)
0323         {
0324             co_await static_cast<AWAITABLE&&>(awaitable);
0325         }
0326 
0327         template<
0328             typename AWAITABLE,
0329             typename RESULT = typename cppcoro::awaitable_traits<AWAITABLE&>::await_result_t,
0330             std::enable_if_t<!std::is_void_v<RESULT>, int> = 0>
0331         when_all_task<RESULT> make_when_all_task(std::reference_wrapper<AWAITABLE> awaitable)
0332         {
0333 #if CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC < 19'20'00000
0334             // HACK: Workaround another bug in MSVC where the expression 'co_yield co_await x' seems
0335             // to completely ignore the co_yield and never calls promise.yield_value().
0336             // The coroutine seems to be resuming the 'co_await' after the 'co_yield'
0337             // rather than before the 'co_yield'.
0338             // This bug is present in VS 2017.7 and VS 2017.8.
0339             auto& promise = co_await when_all_task_promise<RESULT>::get_promise;
0340             co_await promise.yield_value(co_await awaitable.get());
0341 #else
0342             co_yield co_await awaitable.get();
0343 #endif
0344         }
0345 
0346         template<
0347             typename AWAITABLE,
0348             typename RESULT = typename cppcoro::awaitable_traits<AWAITABLE&>::await_result_t,
0349             std::enable_if_t<std::is_void_v<RESULT>, int> = 0>
0350         when_all_task<void> make_when_all_task(std::reference_wrapper<AWAITABLE> awaitable)
0351         {
0352             co_await awaitable.get();
0353         }
0354     }
0355 }
0356 
0357 #endif