File indexing completed on 2025-01-18 09:54:52
0001
0002
0003
0004
0005 #ifndef CPPCORO_SHARED_LAZY_TASK_HPP_INCLUDED
0006 #define CPPCORO_SHARED_LAZY_TASK_HPP_INCLUDED
0007
#include <cppcoro/config.hpp>
#include <cppcoro/awaitable_traits.hpp>
#include <cppcoro/broken_promise.hpp>
#include <cppcoro/coroutine.hpp>
#include <cppcoro/task.hpp>

#include <cppcoro/detail/remove_rvalue_reference.hpp>

#include <atomic>
#include <cstdint>
#include <exception>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
0021
0022 namespace cppcoro
0023 {
0024 template<typename T>
0025 class shared_task;
0026
0027 namespace detail
0028 {
0029 struct shared_task_waiter
0030 {
0031 cppcoro::coroutine_handle<> m_continuation;
0032 shared_task_waiter* m_next;
0033 };
0034
0035 class shared_task_promise_base
0036 {
0037 friend struct final_awaiter;
0038
0039 struct final_awaiter
0040 {
0041 bool await_ready() const noexcept { return false; }
0042
0043 template<typename PROMISE>
0044 void await_suspend(cppcoro::coroutine_handle<PROMISE> h) noexcept
0045 {
0046 shared_task_promise_base& promise = h.promise();
0047
0048
0049
0050
0051 void* const valueReadyValue = &promise;
0052 void* waiters = promise.m_waiters.exchange(valueReadyValue, std::memory_order_acq_rel);
0053 if (waiters != nullptr)
0054 {
0055 shared_task_waiter* waiter = static_cast<shared_task_waiter*>(waiters);
0056 while (waiter->m_next != nullptr)
0057 {
0058
0059
0060 auto* next = waiter->m_next;
0061 waiter->m_continuation.resume();
0062 waiter = next;
0063 }
0064
0065
0066
0067 waiter->m_continuation.resume();
0068 }
0069 }
0070
0071 void await_resume() noexcept {}
0072 };
0073
0074 public:
0075
0076 shared_task_promise_base() noexcept
0077 : m_refCount(1)
0078 , m_waiters(&this->m_waiters)
0079 , m_exception(nullptr)
0080 {}
0081
0082 cppcoro::suspend_always initial_suspend() noexcept { return {}; }
0083 final_awaiter final_suspend() noexcept { return {}; }
0084
0085 void unhandled_exception() noexcept
0086 {
0087 m_exception = std::current_exception();
0088 }
0089
0090 bool is_ready() const noexcept
0091 {
0092 const void* const valueReadyValue = this;
0093 return m_waiters.load(std::memory_order_acquire) == valueReadyValue;
0094 }
0095
0096 void add_ref() noexcept
0097 {
0098 m_refCount.fetch_add(1, std::memory_order_relaxed);
0099 }
0100
0101
0102
0103
0104
0105
0106
0107 bool try_detach() noexcept
0108 {
0109 return m_refCount.fetch_sub(1, std::memory_order_acq_rel) != 1;
0110 }
0111
0112
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122
0123
0124
0125
0126
0127 bool try_await(shared_task_waiter* waiter, cppcoro::coroutine_handle<> coroutine)
0128 {
0129 void* const valueReadyValue = this;
0130 void* const notStartedValue = &this->m_waiters;
0131 constexpr void* startedNoWaitersValue = static_cast<shared_task_waiter*>(nullptr);
0132
0133
0134
0135
0136
0137
0138
0139
0140
0141
0142
0143
0144 void* oldWaiters = m_waiters.load(std::memory_order_acquire);
0145 if (oldWaiters == notStartedValue &&
0146 m_waiters.compare_exchange_strong(
0147 oldWaiters,
0148 startedNoWaitersValue,
0149 std::memory_order_relaxed))
0150 {
0151
0152 coroutine.resume();
0153 oldWaiters = m_waiters.load(std::memory_order_acquire);
0154 }
0155
0156
0157 do
0158 {
0159 if (oldWaiters == valueReadyValue)
0160 {
0161
0162 return false;
0163 }
0164
0165 waiter->m_next = static_cast<shared_task_waiter*>(oldWaiters);
0166 } while (!m_waiters.compare_exchange_weak(
0167 oldWaiters,
0168 static_cast<void*>(waiter),
0169 std::memory_order_release,
0170 std::memory_order_acquire));
0171
0172 return true;
0173 }
0174
0175 protected:
0176
0177 bool completed_with_unhandled_exception()
0178 {
0179 return m_exception != nullptr;
0180 }
0181
0182 void rethrow_if_unhandled_exception()
0183 {
0184 if (m_exception != nullptr)
0185 {
0186 std::rethrow_exception(m_exception);
0187 }
0188 }
0189
0190 private:
0191
0192 std::atomic<std::uint32_t> m_refCount;
0193
0194
0195
0196
0197
0198
0199
0200
0201 std::atomic<void*> m_waiters;
0202
0203 std::exception_ptr m_exception;
0204
0205 };
0206
0207 template<typename T>
0208 class shared_task_promise : public shared_task_promise_base
0209 {
0210 public:
0211
0212 shared_task_promise() noexcept = default;
0213
0214 ~shared_task_promise()
0215 {
0216 if (this->is_ready() && !this->completed_with_unhandled_exception())
0217 {
0218 reinterpret_cast<T*>(&m_valueStorage)->~T();
0219 }
0220 }
0221
0222 shared_task<T> get_return_object() noexcept;
0223
0224 template<
0225 typename VALUE,
0226 typename = std::enable_if_t<std::is_convertible_v<VALUE&&, T>>>
0227 void return_value(VALUE&& value)
0228 noexcept(std::is_nothrow_constructible_v<T, VALUE&&>)
0229 {
0230 new (&m_valueStorage) T(std::forward<VALUE>(value));
0231 }
0232
0233 T& result()
0234 {
0235 this->rethrow_if_unhandled_exception();
0236 return *reinterpret_cast<T*>(&m_valueStorage);
0237 }
0238
0239 private:
0240
0241
0242
0243
0244 alignas(T) char m_valueStorage[sizeof(T)];
0245
0246 };
0247
0248 template<>
0249 class shared_task_promise<void> : public shared_task_promise_base
0250 {
0251 public:
0252
0253 shared_task_promise() noexcept = default;
0254
0255 shared_task<void> get_return_object() noexcept;
0256
0257 void return_void() noexcept
0258 {}
0259
0260 void result()
0261 {
0262 this->rethrow_if_unhandled_exception();
0263 }
0264
0265 };
0266
0267 template<typename T>
0268 class shared_task_promise<T&> : public shared_task_promise_base
0269 {
0270 public:
0271
0272 shared_task_promise() noexcept = default;
0273
0274 shared_task<T&> get_return_object() noexcept;
0275
0276 void return_value(T& value) noexcept
0277 {
0278 m_value = std::addressof(value);
0279 }
0280
0281 T& result()
0282 {
0283 this->rethrow_if_unhandled_exception();
0284 return *m_value;
0285 }
0286
0287 private:
0288
0289 T* m_value;
0290
0291 };
0292 }
0293
0294 template<typename T = void>
0295 class [[nodiscard]] shared_task
0296 {
0297 public:
0298
0299 using promise_type = detail::shared_task_promise<T>;
0300
0301 using value_type = T;
0302
0303 private:
0304
0305 struct awaitable_base
0306 {
0307 cppcoro::coroutine_handle<promise_type> m_coroutine;
0308 detail::shared_task_waiter m_waiter;
0309
0310 awaitable_base(cppcoro::coroutine_handle<promise_type> coroutine) noexcept
0311 : m_coroutine(coroutine)
0312 {}
0313
0314 bool await_ready() const noexcept
0315 {
0316 return !m_coroutine || m_coroutine.promise().is_ready();
0317 }
0318
0319 bool await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept
0320 {
0321 m_waiter.m_continuation = awaiter;
0322 return m_coroutine.promise().try_await(&m_waiter, m_coroutine);
0323 }
0324 };
0325
0326 public:
0327
0328 shared_task() noexcept
0329 : m_coroutine(nullptr)
0330 {}
0331
0332 explicit shared_task(cppcoro::coroutine_handle<promise_type> coroutine)
0333 : m_coroutine(coroutine)
0334 {
0335
0336
0337
0338 }
0339
0340 shared_task(shared_task&& other) noexcept
0341 : m_coroutine(other.m_coroutine)
0342 {
0343 other.m_coroutine = nullptr;
0344 }
0345
0346 shared_task(const shared_task& other) noexcept
0347 : m_coroutine(other.m_coroutine)
0348 {
0349 if (m_coroutine)
0350 {
0351 m_coroutine.promise().add_ref();
0352 }
0353 }
0354
0355 ~shared_task()
0356 {
0357 destroy();
0358 }
0359
0360 shared_task& operator=(shared_task&& other) noexcept
0361 {
0362 if (&other != this)
0363 {
0364 destroy();
0365
0366 m_coroutine = other.m_coroutine;
0367 other.m_coroutine = nullptr;
0368 }
0369
0370 return *this;
0371 }
0372
0373 shared_task& operator=(const shared_task& other) noexcept
0374 {
0375 if (m_coroutine != other.m_coroutine)
0376 {
0377 destroy();
0378
0379 m_coroutine = other.m_coroutine;
0380
0381 if (m_coroutine)
0382 {
0383 m_coroutine.promise().add_ref();
0384 }
0385 }
0386
0387 return *this;
0388 }
0389
0390 void swap(shared_task& other) noexcept
0391 {
0392 std::swap(m_coroutine, other.m_coroutine);
0393 }
0394
0395
0396
0397
0398
0399 bool is_ready() const noexcept
0400 {
0401 return !m_coroutine || m_coroutine.promise().is_ready();
0402 }
0403
0404 auto operator co_await() const noexcept
0405 {
0406 struct awaitable : awaitable_base
0407 {
0408 using awaitable_base::awaitable_base;
0409
0410 decltype(auto) await_resume()
0411 {
0412 if (!this->m_coroutine)
0413 {
0414 throw broken_promise{};
0415 }
0416
0417 return this->m_coroutine.promise().result();
0418 }
0419 };
0420
0421 return awaitable{ m_coroutine };
0422 }
0423
0424
0425
0426
0427 auto when_ready() const noexcept
0428 {
0429 struct awaitable : awaitable_base
0430 {
0431 using awaitable_base::awaitable_base;
0432
0433 void await_resume() const noexcept {}
0434 };
0435
0436 return awaitable{ m_coroutine };
0437 }
0438
0439 private:
0440
0441 template<typename U>
0442 friend bool operator==(const shared_task<U>&, const shared_task<U>&) noexcept;
0443
0444 void destroy() noexcept
0445 {
0446 if (m_coroutine)
0447 {
0448 if (!m_coroutine.promise().try_detach())
0449 {
0450 m_coroutine.destroy();
0451 }
0452 }
0453 }
0454
0455 cppcoro::coroutine_handle<promise_type> m_coroutine;
0456
0457 };
0458
0459 template<typename T>
0460 bool operator==(const shared_task<T>& lhs, const shared_task<T>& rhs) noexcept
0461 {
0462 return lhs.m_coroutine == rhs.m_coroutine;
0463 }
0464
0465 template<typename T>
0466 bool operator!=(const shared_task<T>& lhs, const shared_task<T>& rhs) noexcept
0467 {
0468 return !(lhs == rhs);
0469 }
0470
0471 template<typename T>
0472 void swap(shared_task<T>& a, shared_task<T>& b) noexcept
0473 {
0474 a.swap(b);
0475 }
0476
0477 namespace detail
0478 {
0479 template<typename T>
0480 shared_task<T> shared_task_promise<T>::get_return_object() noexcept
0481 {
0482 return shared_task<T>{
0483 cppcoro::coroutine_handle<shared_task_promise>::from_promise(*this)
0484 };
0485 }
0486
0487 template<typename T>
0488 shared_task<T&> shared_task_promise<T&>::get_return_object() noexcept
0489 {
0490 return shared_task<T&>{
0491 cppcoro::coroutine_handle<shared_task_promise>::from_promise(*this)
0492 };
0493 }
0494
0495 inline shared_task<void> shared_task_promise<void>::get_return_object() noexcept
0496 {
0497 return shared_task<void>{
0498 cppcoro::coroutine_handle<shared_task_promise>::from_promise(*this)
0499 };
0500 }
0501 }
0502
0503 template<typename AWAITABLE>
0504 auto make_shared_task(AWAITABLE awaitable)
0505 -> shared_task<detail::remove_rvalue_reference_t<typename awaitable_traits<AWAITABLE>::await_result_t>>
0506 {
0507 co_return co_await static_cast<AWAITABLE&&>(awaitable);
0508 }
0509 }
0510
0511 #endif