File indexing completed on 2025-01-30 10:03:46
0001
0002
0003
0004
0005 #ifndef CPPCORO_DETAIL_WHEN_ALL_TASK_HPP_INCLUDED
0006 #define CPPCORO_DETAIL_WHEN_ALL_TASK_HPP_INCLUDED
0007
#include <cppcoro/awaitable_traits.hpp>

#include <cppcoro/detail/when_all_counter.hpp>
#include <cppcoro/detail/void_value.hpp>

#include <cppcoro/coroutine.hpp>

#include <cassert>
#include <exception>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
0016
0017 namespace cppcoro
0018 {
0019 namespace detail
0020 {
// Forward declarations of templates referenced later in this header
// (when_all_task befriends when_all_ready_awaitable so it can call start()).
template<class TaskContainer>
class when_all_ready_awaitable;

template<class Result>
class when_all_task;
0026
0027 template<typename RESULT>
0028 class when_all_task_promise final
0029 {
0030 public:
0031
0032 using coroutine_handle_t = cppcoro::coroutine_handle<when_all_task_promise<RESULT>>;
0033
0034 when_all_task_promise() noexcept
0035 {}
0036
0037 auto get_return_object() noexcept
0038 {
0039 return coroutine_handle_t::from_promise(*this);
0040 }
0041
0042 cppcoro::suspend_always initial_suspend() noexcept
0043 {
0044 return{};
0045 }
0046
0047 auto final_suspend() noexcept
0048 {
0049 class completion_notifier
0050 {
0051 public:
0052
0053 bool await_ready() const noexcept { return false; }
0054
0055 void await_suspend(coroutine_handle_t coro) const noexcept
0056 {
0057 coro.promise().m_counter->notify_awaitable_completed();
0058 }
0059
0060 void await_resume() const noexcept {}
0061
0062 };
0063
0064 return completion_notifier{};
0065 }
0066
0067 void unhandled_exception() noexcept
0068 {
0069 m_exception = std::current_exception();
0070 }
0071
0072 void return_void() noexcept
0073 {
0074
0075
0076
0077 assert(false);
0078 }
0079
0080 #if CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC < 19'20'00000
0081
0082
0083 template<typename Awaitable>
0084 Awaitable&& await_transform(Awaitable&& awaitable)
0085 {
0086 return static_cast<Awaitable&&>(awaitable);
0087 }
0088
0089 struct get_promise_t {};
0090 static constexpr get_promise_t get_promise = {};
0091
0092 auto await_transform(get_promise_t)
0093 {
0094 class awaiter
0095 {
0096 public:
0097 awaiter(when_all_task_promise* promise) noexcept : m_promise(promise) {}
0098 bool await_ready() noexcept {
0099 return true;
0100 }
0101 void await_suspend(cppcoro::coroutine_handle<>) noexcept {}
0102 when_all_task_promise& await_resume() noexcept
0103 {
0104 return *m_promise;
0105 }
0106 private:
0107 when_all_task_promise* m_promise;
0108 };
0109 return awaiter{ this };
0110 }
0111 #endif
0112
0113
0114 auto yield_value(RESULT&& result) noexcept
0115 {
0116 m_result = std::addressof(result);
0117 return final_suspend();
0118 }
0119
0120 void start(when_all_counter& counter) noexcept
0121 {
0122 m_counter = &counter;
0123 coroutine_handle_t::from_promise(*this).resume();
0124 }
0125
0126 RESULT& result() &
0127 {
0128 rethrow_if_exception();
0129 return *m_result;
0130 }
0131
0132 RESULT&& result() &&
0133 {
0134 rethrow_if_exception();
0135 return std::forward<RESULT>(*m_result);
0136 }
0137
0138 private:
0139
0140 void rethrow_if_exception()
0141 {
0142 if (m_exception)
0143 {
0144 std::rethrow_exception(m_exception);
0145 }
0146 }
0147
0148 when_all_counter* m_counter;
0149 std::exception_ptr m_exception;
0150 std::add_pointer_t<RESULT> m_result;
0151
0152 };
0153
0154 template<>
0155 class when_all_task_promise<void> final
0156 {
0157 public:
0158
0159 using coroutine_handle_t = cppcoro::coroutine_handle<when_all_task_promise<void>>;
0160
0161 when_all_task_promise() noexcept
0162 {}
0163
0164 auto get_return_object() noexcept
0165 {
0166 return coroutine_handle_t::from_promise(*this);
0167 }
0168
0169 cppcoro::suspend_always initial_suspend() noexcept
0170 {
0171 return{};
0172 }
0173
0174 auto final_suspend() noexcept
0175 {
0176 class completion_notifier
0177 {
0178 public:
0179
0180 bool await_ready() const noexcept { return false; }
0181
0182 void await_suspend(coroutine_handle_t coro) const noexcept
0183 {
0184 coro.promise().m_counter->notify_awaitable_completed();
0185 }
0186
0187 void await_resume() const noexcept {}
0188
0189 };
0190
0191 return completion_notifier{};
0192 }
0193
0194 void unhandled_exception() noexcept
0195 {
0196 m_exception = std::current_exception();
0197 }
0198
0199 void return_void() noexcept
0200 {
0201 }
0202
0203 void start(when_all_counter& counter) noexcept
0204 {
0205 m_counter = &counter;
0206 coroutine_handle_t::from_promise(*this).resume();
0207 }
0208
0209 void result()
0210 {
0211 if (m_exception)
0212 {
0213 std::rethrow_exception(m_exception);
0214 }
0215 }
0216
0217 private:
0218
0219 when_all_counter* m_counter;
0220 std::exception_ptr m_exception;
0221
0222 };
0223
0224 template<typename RESULT>
0225 class when_all_task final
0226 {
0227 public:
0228
0229 using promise_type = when_all_task_promise<RESULT>;
0230
0231 using coroutine_handle_t = typename promise_type::coroutine_handle_t;
0232
0233 when_all_task(coroutine_handle_t coroutine) noexcept
0234 : m_coroutine(coroutine)
0235 {}
0236
0237 when_all_task(when_all_task&& other) noexcept
0238 : m_coroutine(std::exchange(other.m_coroutine, coroutine_handle_t{}))
0239 {}
0240
0241 ~when_all_task()
0242 {
0243 if (m_coroutine) m_coroutine.destroy();
0244 }
0245
0246 when_all_task(const when_all_task&) = delete;
0247 when_all_task& operator=(const when_all_task&) = delete;
0248
0249 decltype(auto) result() &
0250 {
0251 return m_coroutine.promise().result();
0252 }
0253
0254 decltype(auto) result() &&
0255 {
0256 return std::move(m_coroutine.promise()).result();
0257 }
0258
0259 decltype(auto) non_void_result() &
0260 {
0261 if constexpr (std::is_void_v<decltype(this->result())>)
0262 {
0263 this->result();
0264 return void_value{};
0265 }
0266 else
0267 {
0268 return this->result();
0269 }
0270 }
0271
0272 decltype(auto) non_void_result() &&
0273 {
0274 if constexpr (std::is_void_v<decltype(this->result())>)
0275 {
0276 std::move(*this).result();
0277 return void_value{};
0278 }
0279 else
0280 {
0281 return std::move(*this).result();
0282 }
0283 }
0284
0285 private:
0286
0287 template<typename TASK_CONTAINER>
0288 friend class when_all_ready_awaitable;
0289
0290 void start(when_all_counter& counter) noexcept
0291 {
0292 m_coroutine.promise().start(counter);
0293 }
0294
0295 coroutine_handle_t m_coroutine;
0296
0297 };
0298
0299 template<
0300 typename AWAITABLE,
0301 typename RESULT = typename cppcoro::awaitable_traits<AWAITABLE&&>::await_result_t,
0302 std::enable_if_t<!std::is_void_v<RESULT>, int> = 0>
0303 when_all_task<RESULT> make_when_all_task(AWAITABLE awaitable)
0304 {
0305 #if CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC < 19'20'00000
0306
0307
0308
0309
0310
0311 auto& promise = co_await when_all_task_promise<RESULT>::get_promise;
0312 co_await promise.yield_value(co_await std::forward<AWAITABLE>(awaitable));
0313 #else
0314 co_yield co_await static_cast<AWAITABLE&&>(awaitable);
0315 #endif
0316 }
0317
0318 template<
0319 typename AWAITABLE,
0320 typename RESULT = typename cppcoro::awaitable_traits<AWAITABLE&&>::await_result_t,
0321 std::enable_if_t<std::is_void_v<RESULT>, int> = 0>
0322 when_all_task<void> make_when_all_task(AWAITABLE awaitable)
0323 {
0324 co_await static_cast<AWAITABLE&&>(awaitable);
0325 }
0326
0327 template<
0328 typename AWAITABLE,
0329 typename RESULT = typename cppcoro::awaitable_traits<AWAITABLE&>::await_result_t,
0330 std::enable_if_t<!std::is_void_v<RESULT>, int> = 0>
0331 when_all_task<RESULT> make_when_all_task(std::reference_wrapper<AWAITABLE> awaitable)
0332 {
0333 #if CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC < 19'20'00000
0334
0335
0336
0337
0338
0339 auto& promise = co_await when_all_task_promise<RESULT>::get_promise;
0340 co_await promise.yield_value(co_await awaitable.get());
0341 #else
0342 co_yield co_await awaitable.get();
0343 #endif
0344 }
0345
0346 template<
0347 typename AWAITABLE,
0348 typename RESULT = typename cppcoro::awaitable_traits<AWAITABLE&>::await_result_t,
0349 std::enable_if_t<std::is_void_v<RESULT>, int> = 0>
0350 when_all_task<void> make_when_all_task(std::reference_wrapper<AWAITABLE> awaitable)
0351 {
0352 co_await awaitable.get();
0353 }
0354 }
0355 }
0356
0357 #endif