Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:03:46

0001 ///////////////////////////////////////////////////////////////////////////////
0002 // Copyright (c) Lewis Baker
0003 // Licenced under MIT license. See LICENSE.txt for details.
0004 ///////////////////////////////////////////////////////////////////////////////
0005 #ifndef CPPCORO_DETAIL_WIN32_OVERLAPPED_OPERATION_HPP_INCLUDED
0006 #define CPPCORO_DETAIL_WIN32_OVERLAPPED_OPERATION_HPP_INCLUDED
0007 
0008 #include <cppcoro/cancellation_registration.hpp>
0009 #include <cppcoro/cancellation_token.hpp>
0010 #include <cppcoro/operation_cancelled.hpp>
0011 
0012 #include <cppcoro/detail/win32.hpp>
0013 
0014 #include <optional>
0015 #include <system_error>
0016 #include <cppcoro/coroutine.hpp>
0017 #include <cassert>
0018 
0019 namespace cppcoro
0020 {
0021     namespace detail
0022     {
0023         class win32_overlapped_operation_base
0024             : protected detail::win32::io_state
0025         {
0026         public:
0027 
0028             win32_overlapped_operation_base(
0029                 detail::win32::io_state::callback_type* callback) noexcept
0030                 : detail::win32::io_state(callback)
0031                 , m_errorCode(0)
0032                 , m_numberOfBytesTransferred(0)
0033             {}
0034 
0035             win32_overlapped_operation_base(
0036                 void* pointer,
0037                 detail::win32::io_state::callback_type* callback) noexcept
0038                 : detail::win32::io_state(pointer, callback)
0039                 , m_errorCode(0)
0040                 , m_numberOfBytesTransferred(0)
0041             {}
0042 
0043             win32_overlapped_operation_base(
0044                 std::uint64_t offset,
0045                 detail::win32::io_state::callback_type* callback) noexcept
0046                 : detail::win32::io_state(offset, callback)
0047                 , m_errorCode(0)
0048                 , m_numberOfBytesTransferred(0)
0049             {}
0050 
0051             _OVERLAPPED* get_overlapped() noexcept
0052             {
0053                 return reinterpret_cast<_OVERLAPPED*>(
0054                     static_cast<detail::win32::overlapped*>(this));
0055             }
0056 
0057             std::size_t get_result()
0058             {
0059                 if (m_errorCode != 0)
0060                 {
0061                     throw std::system_error{
0062                         static_cast<int>(m_errorCode),
0063                         std::system_category()
0064                     };
0065                 }
0066 
0067                 return m_numberOfBytesTransferred;
0068             }
0069 
0070             detail::win32::dword_t m_errorCode;
0071             detail::win32::dword_t m_numberOfBytesTransferred;
0072 
0073         };
0074 
0075         template<typename OPERATION>
0076         class win32_overlapped_operation
0077             : protected win32_overlapped_operation_base
0078         {
0079         protected:
0080 
0081             win32_overlapped_operation() noexcept
0082                 : win32_overlapped_operation_base(
0083                     &win32_overlapped_operation::on_operation_completed)
0084             {}
0085 
0086             win32_overlapped_operation(void* pointer) noexcept
0087                 : win32_overlapped_operation_base(
0088                     pointer,
0089                     &win32_overlapped_operation::on_operation_completed)
0090             {}
0091 
0092             win32_overlapped_operation(std::uint64_t offset) noexcept
0093                 : win32_overlapped_operation_base(
0094                     offset,
0095                     &win32_overlapped_operation::on_operation_completed)
0096             {}
0097 
0098         public:
0099 
0100             bool await_ready() const noexcept { return false; }
0101 
0102             CPPCORO_NOINLINE
0103             bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine)
0104             {
0105                 static_assert(std::is_base_of_v<win32_overlapped_operation, OPERATION>);
0106 
0107                 m_awaitingCoroutine = awaitingCoroutine;
0108                 return static_cast<OPERATION*>(this)->try_start();
0109             }
0110 
0111             decltype(auto) await_resume()
0112             {
0113                 return static_cast<OPERATION*>(this)->get_result();
0114             }
0115 
0116         private:
0117 
0118             static void on_operation_completed(
0119                 detail::win32::io_state* ioState,
0120                 detail::win32::dword_t errorCode,
0121                 detail::win32::dword_t numberOfBytesTransferred,
0122                 [[maybe_unused]] detail::win32::ulongptr_t completionKey) noexcept
0123             {
0124                 auto* operation = static_cast<win32_overlapped_operation*>(ioState);
0125                 operation->m_errorCode = errorCode;
0126                 operation->m_numberOfBytesTransferred = numberOfBytesTransferred;
0127                 operation->m_awaitingCoroutine.resume();
0128             }
0129 
0130             cppcoro::coroutine_handle<> m_awaitingCoroutine;
0131 
0132         };
0133 
0134         template<typename OPERATION>
0135         class win32_overlapped_operation_cancellable
0136             : protected win32_overlapped_operation_base
0137         {
0138             // ERROR_OPERATION_ABORTED value from <Windows.h>
0139             static constexpr detail::win32::dword_t error_operation_aborted = 995L;
0140 
0141         protected:
0142 
0143             win32_overlapped_operation_cancellable(cancellation_token&& ct) noexcept
0144                 : win32_overlapped_operation_base(&win32_overlapped_operation_cancellable::on_operation_completed)
0145                 , m_state(ct.is_cancellation_requested() ? state::completed : state::not_started)
0146                 , m_cancellationToken(std::move(ct))
0147             {
0148                 m_errorCode = error_operation_aborted;
0149             }
0150 
0151             win32_overlapped_operation_cancellable(
0152                 void* pointer,
0153                 cancellation_token&& ct) noexcept
0154                 : win32_overlapped_operation_base(pointer, &win32_overlapped_operation_cancellable::on_operation_completed)
0155                 , m_state(ct.is_cancellation_requested() ? state::completed : state::not_started)
0156                 , m_cancellationToken(std::move(ct))
0157             {
0158                 m_errorCode = error_operation_aborted;
0159             }
0160 
0161             win32_overlapped_operation_cancellable(
0162                 std::uint64_t offset,
0163                 cancellation_token&& ct) noexcept
0164                 : win32_overlapped_operation_base(offset, &win32_overlapped_operation_cancellable::on_operation_completed)
0165                 , m_state(ct.is_cancellation_requested() ? state::completed : state::not_started)
0166                 , m_cancellationToken(std::move(ct))
0167             {
0168                 m_errorCode = error_operation_aborted;
0169             }
0170 
0171             win32_overlapped_operation_cancellable(
0172                 win32_overlapped_operation_cancellable&& other) noexcept
0173                 : win32_overlapped_operation_base(std::move(other))
0174                 , m_state(other.m_state.load(std::memory_order_relaxed))
0175                 , m_cancellationToken(std::move(other.m_cancellationToken))
0176             {
0177                 assert(m_errorCode == other.m_errorCode);
0178                 assert(m_numberOfBytesTransferred == other.m_numberOfBytesTransferred);
0179             }
0180 
0181         public:
0182 
0183             bool await_ready() const noexcept
0184             {
0185                 return m_state.load(std::memory_order_relaxed) == state::completed;
0186             }
0187 
0188             CPPCORO_NOINLINE
0189             bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine)
0190             {
0191                 static_assert(std::is_base_of_v<win32_overlapped_operation_cancellable, OPERATION>);
0192 
0193                 m_awaitingCoroutine = awaitingCoroutine;
0194 
0195                 // TRICKY: Register cancellation callback before starting the operation
0196                 // in case the callback registration throws due to insufficient
0197                 // memory. We need to make sure that the logic that occurs after
0198                 // starting the operation is noexcept, otherwise we run into the
0199                 // problem of not being able to cancel the started operation and
0200                 // the dilemma of what to do with the exception.
0201                 //
0202                 // However, doing this means that the cancellation callback may run
0203                 // prior to returning below so in the case that cancellation may
0204                 // occur we defer setting the state to 'started' until after
0205                 // the operation has finished starting. The cancellation callback
0206                 // will only attempt to request cancellation of the operation with
0207                 // CancelIoEx() once the state has been set to 'started'.
0208                 const bool canBeCancelled = m_cancellationToken.can_be_cancelled();
0209                 if (canBeCancelled)
0210                 {
0211                     m_cancellationCallback.emplace(
0212                         std::move(m_cancellationToken),
0213                         [this] { this->on_cancellation_requested(); });
0214                 }
0215                 else
0216                 {
0217                     m_state.store(state::started, std::memory_order_relaxed);
0218                 }
0219 
0220                 // Now start the operation.
0221                 const bool willCompleteAsynchronously = static_cast<OPERATION*>(this)->try_start();
0222                 if (!willCompleteAsynchronously)
0223                 {
0224                     // Operation completed synchronously, resume awaiting coroutine immediately.
0225                     return false;
0226                 }
0227 
0228                 if (canBeCancelled)
0229                 {
0230                     // Need to flag that the operation has finished starting now.
0231 
0232                     // However, the operation may have completed concurrently on
0233                     // another thread, transitioning directly from not_started -> complete.
0234                     // Or it may have had the cancellation callback execute and transition
0235                     // from not_started -> cancellation_requested. We use a compare-exchange
0236                     // to determine a winner between these potential racing cases.
0237                     state oldState = state::not_started;
0238                     if (!m_state.compare_exchange_strong(
0239                         oldState,
0240                         state::started,
0241                         std::memory_order_release,
0242                         std::memory_order_acquire))
0243                     {
0244                         if (oldState == state::cancellation_requested)
0245                         {
0246                             // Request the operation be cancelled.
0247                             // Note that it may have already completed on a background
0248                             // thread by now so this request for cancellation may end up
0249                             // being ignored.
0250                             static_cast<OPERATION*>(this)->cancel();
0251 
0252                             if (!m_state.compare_exchange_strong(
0253                                 oldState,
0254                                 state::started,
0255                                 std::memory_order_release,
0256                                 std::memory_order_acquire))
0257                             {
0258                                 assert(oldState == state::completed);
0259                                 return false;
0260                             }
0261                         }
0262                         else
0263                         {
0264                             assert(oldState == state::completed);
0265                             return false;
0266                         }
0267                     }
0268                 }
0269 
0270                 return true;
0271             }
0272 
0273             decltype(auto) await_resume()
0274             {
0275                 // Free memory used by the cancellation callback now that the operation
0276                 // has completed rather than waiting until the operation object destructs.
0277                 // eg. If the operation is passed to when_all() then the operation object
0278                 // may not be destructed until all of the operations complete.
0279                 m_cancellationCallback.reset();
0280 
0281                 if (m_errorCode == error_operation_aborted)
0282                 {
0283                     throw operation_cancelled{};
0284                 }
0285 
0286                 return static_cast<OPERATION*>(this)->get_result();
0287             }
0288 
0289         private:
0290 
0291             enum class state
0292             {
0293                 not_started,
0294                 started,
0295                 cancellation_requested,
0296                 completed
0297             };
0298 
0299             void on_cancellation_requested() noexcept
0300             {
0301                 auto oldState = m_state.load(std::memory_order_acquire);
0302                 if (oldState == state::not_started)
0303                 {
0304                     // This callback is running concurrently with await_suspend().
0305                     // The call to start the operation may not have returned yet so
0306                     // we can't safely request cancellation of it. Instead we try to
0307                     // notify the await_suspend() thread by transitioning the state
0308                     // to state::cancellation_requested so that the await_suspend()
0309                     // thread can request cancellation after it has finished starting
0310                     // the operation.
0311                     const bool transferredCancelResponsibility =
0312                         m_state.compare_exchange_strong(
0313                             oldState,
0314                             state::cancellation_requested,
0315                             std::memory_order_release,
0316                             std::memory_order_acquire);
0317                     if (transferredCancelResponsibility)
0318                     {
0319                         return;
0320                     }
0321                 }
0322 
0323                 // No point requesting cancellation if the operation has already completed.
0324                 if (oldState != state::completed)
0325                 {
0326                     static_cast<OPERATION*>(this)->cancel();
0327                 }
0328             }
0329 
0330             static void on_operation_completed(
0331                 detail::win32::io_state* ioState,
0332                 detail::win32::dword_t errorCode,
0333                 detail::win32::dword_t numberOfBytesTransferred,
0334                 [[maybe_unused]] detail::win32::ulongptr_t completionKey) noexcept
0335             {
0336                 auto* operation = static_cast<win32_overlapped_operation_cancellable*>(ioState);
0337 
0338                 operation->m_errorCode = errorCode;
0339                 operation->m_numberOfBytesTransferred = numberOfBytesTransferred;
0340 
0341                 auto state = operation->m_state.load(std::memory_order_acquire);
0342                 if (state == state::started)
0343                 {
0344                     operation->m_state.store(state::completed, std::memory_order_relaxed);
0345                     operation->m_awaitingCoroutine.resume();
0346                 }
0347                 else
0348                 {
0349                     // We are racing with await_suspend() call suspending.
0350                     // Try to mark it as completed using an atomic exchange and look
0351                     // at the previous value to determine whether the coroutine suspended
0352                     // first (in which case we resume it now) or we marked it as completed
0353                     // first (in which case await_suspend() will return false and immediately
0354                     // resume the coroutine).
0355                     state = operation->m_state.exchange(
0356                         state::completed,
0357                         std::memory_order_acq_rel);
0358                     if (state == state::started)
0359                     {
0360                         // The await_suspend() method returned (or will return) 'true' and so
0361                         // we need to resume the coroutine.
0362                         operation->m_awaitingCoroutine.resume();
0363                     }
0364                 }
0365             }
0366 
0367             std::atomic<state> m_state;
0368             cppcoro::cancellation_token m_cancellationToken;
0369             std::optional<cppcoro::cancellation_registration> m_cancellationCallback;
0370             cppcoro::coroutine_handle<> m_awaitingCoroutine;
0371 
0372         };
0373     }
0374 }
0375 
0376 #endif