File indexing completed on 2025-01-18 09:54:53
0001
0002
0003
0004
0005 #ifndef CPPCORO_TASK_HPP_INCLUDED
0006 #define CPPCORO_TASK_HPP_INCLUDED
0007
0008 #include <cppcoro/config.hpp>
0009 #include <cppcoro/awaitable_traits.hpp>
0010 #include <cppcoro/broken_promise.hpp>
0011
0012 #include <cppcoro/detail/remove_rvalue_reference.hpp>
0013
0014 #include <atomic>
0015 #include <exception>
0016 #include <utility>
0017 #include <type_traits>
0018 #include <cstdint>
0019 #include <cassert>
0020 #include <new>
0021
0022 #include <cppcoro/coroutine.hpp>
0023
0024 namespace cppcoro
0025 {
0026 template<typename T> class task;
0027
0028 namespace detail
0029 {
0030 class task_promise_base
0031 {
0032 friend struct final_awaitable;
0033
0034 struct final_awaitable
0035 {
0036 bool await_ready() const noexcept { return false; }
0037
0038 #if CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER
0039 template<typename PROMISE>
0040 cppcoro::coroutine_handle<> await_suspend(
0041 cppcoro::coroutine_handle<PROMISE> coro) noexcept
0042 {
0043 return coro.promise().m_continuation;
0044 }
0045 #else
0046
0047
0048
0049
0050
0051 template<typename PROMISE>
0052 CPPCORO_NOINLINE
0053 void await_suspend(cppcoro::coroutine_handle<PROMISE> coroutine) noexcept
0054 {
0055 task_promise_base& promise = coroutine.promise();
0056
0057
0058
0059
0060
0061
0062
0063 if (promise.m_state.exchange(true, std::memory_order_acq_rel))
0064 {
0065 promise.m_continuation.resume();
0066 }
0067 }
0068 #endif
0069
0070 void await_resume() noexcept {}
0071 };
0072
0073 public:
0074
0075 task_promise_base() noexcept
0076 #if !CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER
0077 : m_state(false)
0078 #endif
0079 {}
0080
0081 auto initial_suspend() noexcept
0082 {
0083 return cppcoro::suspend_always{};
0084 }
0085
0086 auto final_suspend() noexcept
0087 {
0088 return final_awaitable{};
0089 }
0090
0091 #if CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER
0092 void set_continuation(cppcoro::coroutine_handle<> continuation) noexcept
0093 {
0094 m_continuation = continuation;
0095 }
0096 #else
0097 bool try_set_continuation(cppcoro::coroutine_handle<> continuation)
0098 {
0099 m_continuation = continuation;
0100 return !m_state.exchange(true, std::memory_order_acq_rel);
0101 }
0102 #endif
0103
0104 private:
0105
0106 cppcoro::coroutine_handle<> m_continuation;
0107
0108 #if !CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER
0109
0110
0111
0112 std::atomic<bool> m_state;
0113 #endif
0114
0115 };
0116
0117 template<typename T>
0118 class task_promise final : public task_promise_base
0119 {
0120 public:
0121
0122 task_promise() noexcept {}
0123
0124 ~task_promise()
0125 {
0126 switch (m_resultType)
0127 {
0128 case result_type::value:
0129 m_value.~T();
0130 break;
0131 case result_type::exception:
0132 m_exception.~exception_ptr();
0133 break;
0134 default:
0135 break;
0136 }
0137 }
0138
0139 task<T> get_return_object() noexcept;
0140
0141 void unhandled_exception() noexcept
0142 {
0143 ::new (static_cast<void*>(std::addressof(m_exception))) std::exception_ptr(
0144 std::current_exception());
0145 m_resultType = result_type::exception;
0146 }
0147
0148 template<
0149 typename VALUE,
0150 typename = std::enable_if_t<std::is_convertible_v<VALUE&&, T>>>
0151 void return_value(VALUE&& value)
0152 noexcept(std::is_nothrow_constructible_v<T, VALUE&&>)
0153 {
0154 ::new (static_cast<void*>(std::addressof(m_value))) T(std::forward<VALUE>(value));
0155 m_resultType = result_type::value;
0156 }
0157
0158 T& result() &
0159 {
0160 if (m_resultType == result_type::exception)
0161 {
0162 std::rethrow_exception(m_exception);
0163 }
0164
0165 assert(m_resultType == result_type::value);
0166
0167 return m_value;
0168 }
0169
0170
0171
0172
0173
0174
0175
0176 using rvalue_type = std::conditional_t<
0177 std::is_arithmetic_v<T> || std::is_pointer_v<T>,
0178 T,
0179 T&&>;
0180
0181 rvalue_type result() &&
0182 {
0183 if (m_resultType == result_type::exception)
0184 {
0185 std::rethrow_exception(m_exception);
0186 }
0187
0188 assert(m_resultType == result_type::value);
0189
0190 return std::move(m_value);
0191 }
0192
0193 private:
0194
0195 enum class result_type { empty, value, exception };
0196
0197 result_type m_resultType = result_type::empty;
0198
0199 union
0200 {
0201 T m_value;
0202 std::exception_ptr m_exception;
0203 };
0204
0205 };
0206
0207 template<>
0208 class task_promise<void> : public task_promise_base
0209 {
0210 public:
0211
0212 task_promise() noexcept = default;
0213
0214 task<void> get_return_object() noexcept;
0215
0216 void return_void() noexcept
0217 {}
0218
0219 void unhandled_exception() noexcept
0220 {
0221 m_exception = std::current_exception();
0222 }
0223
0224 void result()
0225 {
0226 if (m_exception)
0227 {
0228 std::rethrow_exception(m_exception);
0229 }
0230 }
0231
0232 private:
0233
0234 std::exception_ptr m_exception;
0235
0236 };
0237
0238 template<typename T>
0239 class task_promise<T&> : public task_promise_base
0240 {
0241 public:
0242
0243 task_promise() noexcept = default;
0244
0245 task<T&> get_return_object() noexcept;
0246
0247 void unhandled_exception() noexcept
0248 {
0249 m_exception = std::current_exception();
0250 }
0251
0252 void return_value(T& value) noexcept
0253 {
0254 m_value = std::addressof(value);
0255 }
0256
0257 T& result()
0258 {
0259 if (m_exception)
0260 {
0261 std::rethrow_exception(m_exception);
0262 }
0263
0264 return *m_value;
0265 }
0266
0267 private:
0268
0269 T* m_value = nullptr;
0270 std::exception_ptr m_exception;
0271
0272 };
0273 }
0274
0275
0276
0277
0278
0279
0280
0281
0282
0283 template<typename T = void>
0284 class [[nodiscard]] task
0285 {
0286 public:
0287
0288 using promise_type = detail::task_promise<T>;
0289
0290 using value_type = T;
0291
0292 private:
0293
0294 struct awaitable_base
0295 {
0296 cppcoro::coroutine_handle<promise_type> m_coroutine;
0297
0298 awaitable_base(cppcoro::coroutine_handle<promise_type> coroutine) noexcept
0299 : m_coroutine(coroutine)
0300 {}
0301
0302 bool await_ready() const noexcept
0303 {
0304 return !m_coroutine || m_coroutine.done();
0305 }
0306
0307 #if CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER
0308 cppcoro::coroutine_handle<> await_suspend(
0309 cppcoro::coroutine_handle<> awaitingCoroutine) noexcept
0310 {
0311 m_coroutine.promise().set_continuation(awaitingCoroutine);
0312 return m_coroutine;
0313 }
0314 #else
0315 bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept
0316 {
0317
0318
0319
0320
0321
0322
0323
0324
0325
0326
0327
0328
0329
0330
0331
0332
0333 m_coroutine.resume();
0334 return m_coroutine.promise().try_set_continuation(awaitingCoroutine);
0335 }
0336 #endif
0337 };
0338
0339 public:
0340
0341 task() noexcept
0342 : m_coroutine(nullptr)
0343 {}
0344
0345 explicit task(cppcoro::coroutine_handle<promise_type> coroutine)
0346 : m_coroutine(coroutine)
0347 {}
0348
0349 task(task&& t) noexcept
0350 : m_coroutine(t.m_coroutine)
0351 {
0352 t.m_coroutine = nullptr;
0353 }
0354
0355
0356 task(const task&) = delete;
0357 task& operator=(const task&) = delete;
0358
0359
0360 ~task()
0361 {
0362 if (m_coroutine)
0363 {
0364 m_coroutine.destroy();
0365 }
0366 }
0367
0368 task& operator=(task&& other) noexcept
0369 {
0370 if (std::addressof(other) != this)
0371 {
0372 if (m_coroutine)
0373 {
0374 m_coroutine.destroy();
0375 }
0376
0377 m_coroutine = other.m_coroutine;
0378 other.m_coroutine = nullptr;
0379 }
0380
0381 return *this;
0382 }
0383
0384
0385
0386
0387
0388 bool is_ready() const noexcept
0389 {
0390 return !m_coroutine || m_coroutine.done();
0391 }
0392
0393 auto operator co_await() const & noexcept
0394 {
0395 struct awaitable : awaitable_base
0396 {
0397 using awaitable_base::awaitable_base;
0398
0399 decltype(auto) await_resume()
0400 {
0401 if (!this->m_coroutine)
0402 {
0403 throw broken_promise{};
0404 }
0405
0406 return this->m_coroutine.promise().result();
0407 }
0408 };
0409
0410 return awaitable{ m_coroutine };
0411 }
0412
0413 auto operator co_await() const && noexcept
0414 {
0415 struct awaitable : awaitable_base
0416 {
0417 using awaitable_base::awaitable_base;
0418
0419 decltype(auto) await_resume()
0420 {
0421 if (!this->m_coroutine)
0422 {
0423 throw broken_promise{};
0424 }
0425
0426 return std::move(this->m_coroutine.promise()).result();
0427 }
0428 };
0429
0430 return awaitable{ m_coroutine };
0431 }
0432
0433
0434
0435
0436 auto when_ready() const noexcept
0437 {
0438 struct awaitable : awaitable_base
0439 {
0440 using awaitable_base::awaitable_base;
0441
0442 void await_resume() const noexcept {}
0443 };
0444
0445 return awaitable{ m_coroutine };
0446 }
0447
0448 private:
0449
0450 cppcoro::coroutine_handle<promise_type> m_coroutine;
0451
0452 };
0453
0454 namespace detail
0455 {
0456 template<typename T>
0457 task<T> task_promise<T>::get_return_object() noexcept
0458 {
0459 return task<T>{ cppcoro::coroutine_handle<task_promise>::from_promise(*this) };
0460 }
0461
0462 inline task<void> task_promise<void>::get_return_object() noexcept
0463 {
0464 return task<void>{ cppcoro::coroutine_handle<task_promise>::from_promise(*this) };
0465 }
0466
0467 template<typename T>
0468 task<T&> task_promise<T&>::get_return_object() noexcept
0469 {
0470 return task<T&>{ cppcoro::coroutine_handle<task_promise>::from_promise(*this) };
0471 }
0472 }
0473
0474 template<typename AWAITABLE>
0475 auto make_task(AWAITABLE awaitable)
0476 -> task<detail::remove_rvalue_reference_t<typename awaitable_traits<AWAITABLE>::await_result_t>>
0477 {
0478 co_return co_await static_cast<AWAITABLE&&>(awaitable);
0479 }
0480 }
0481
0482 #endif