Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 10:14:26

0001 // This file is part of Eigen, a lightweight C++ template library
0002 // for linear algebra.
0003 //
0004 // Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
0005 //
0006 // This Source Code Form is subject to the terms of the Mozilla
0007 // Public License v. 2.0. If a copy of the MPL was not distributed
0008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
0009 
0010 #ifndef EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
0011 #define EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
0012 
0013 namespace Eigen {
0014 
0015 // EventCount allows to wait for arbitrary predicates in non-blocking
0016 // algorithms. Think of condition variable, but wait predicate does not need to
0017 // be protected by a mutex. Usage:
0018 // Waiting thread does:
0019 //
0020 //   if (predicate)
0021 //     return act();
0022 //   EventCount::Waiter& w = waiters[my_index];
0023 //   ec.Prewait(&w);
0024 //   if (predicate) {
0025 //     ec.CancelWait(&w);
0026 //     return act();
0027 //   }
0028 //   ec.CommitWait(&w);
0029 //
0030 // Notifying thread does:
0031 //
0032 //   predicate = true;
0033 //   ec.Notify(true);
0034 //
0035 // Notify is cheap if there are no waiting threads. Prewait/CommitWait are not
0036 // cheap, but they are executed only if the preceding predicate check has
0037 // failed.
0038 //
0039 // Algorithm outline:
0040 // There are two main variables: predicate (managed by user) and state_.
0041 // Operation closely resembles Dekker mutual algorithm:
0042 // https://en.wikipedia.org/wiki/Dekker%27s_algorithm
0043 // Waiting thread sets state_ then checks predicate, Notifying thread sets
0044 // predicate then checks state_. Due to seq_cst fences in between these
0045 // operations it is guaranteed than either waiter will see predicate change
0046 // and won't block, or notifying thread will see state_ change and will unblock
0047 // the waiter, or both. But it can't happen that both threads don't see each
0048 // other changes, which would lead to deadlock.
0049 class EventCount {
0050  public:
0051   class Waiter;
0052 
0053   EventCount(MaxSizeVector<Waiter>& waiters)
0054       : state_(kStackMask), waiters_(waiters) {
0055     eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
0056   }
0057 
0058   ~EventCount() {
0059     // Ensure there are no waiters.
0060     eigen_plain_assert(state_.load() == kStackMask);
0061   }
0062 
0063   // Prewait prepares for waiting.
0064   // After calling Prewait, the thread must re-check the wait predicate
0065   // and then call either CancelWait or CommitWait.
0066   void Prewait() {
0067     uint64_t state = state_.load(std::memory_order_relaxed);
0068     for (;;) {
0069       CheckState(state);
0070       uint64_t newstate = state + kWaiterInc;
0071       CheckState(newstate);
0072       if (state_.compare_exchange_weak(state, newstate,
0073                                        std::memory_order_seq_cst))
0074         return;
0075     }
0076   }
0077 
0078   // CommitWait commits waiting after Prewait.
0079   void CommitWait(Waiter* w) {
0080     eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
0081     w->state = Waiter::kNotSignaled;
0082     const uint64_t me = (w - &waiters_[0]) | w->epoch;
0083     uint64_t state = state_.load(std::memory_order_seq_cst);
0084     for (;;) {
0085       CheckState(state, true);
0086       uint64_t newstate;
0087       if ((state & kSignalMask) != 0) {
0088         // Consume the signal and return immidiately.
0089         newstate = state - kWaiterInc - kSignalInc;
0090       } else {
0091         // Remove this thread from pre-wait counter and add to the waiter stack.
0092         newstate = ((state & kWaiterMask) - kWaiterInc) | me;
0093         w->next.store(state & (kStackMask | kEpochMask),
0094                       std::memory_order_relaxed);
0095       }
0096       CheckState(newstate);
0097       if (state_.compare_exchange_weak(state, newstate,
0098                                        std::memory_order_acq_rel)) {
0099         if ((state & kSignalMask) == 0) {
0100           w->epoch += kEpochInc;
0101           Park(w);
0102         }
0103         return;
0104       }
0105     }
0106   }
0107 
0108   // CancelWait cancels effects of the previous Prewait call.
0109   void CancelWait() {
0110     uint64_t state = state_.load(std::memory_order_relaxed);
0111     for (;;) {
0112       CheckState(state, true);
0113       uint64_t newstate = state - kWaiterInc;
0114       // We don't know if the thread was also notified or not,
0115       // so we should not consume a signal unconditionaly.
0116       // Only if number of waiters is equal to number of signals,
0117       // we know that the thread was notified and we must take away the signal.
0118       if (((state & kWaiterMask) >> kWaiterShift) ==
0119           ((state & kSignalMask) >> kSignalShift))
0120         newstate -= kSignalInc;
0121       CheckState(newstate);
0122       if (state_.compare_exchange_weak(state, newstate,
0123                                        std::memory_order_acq_rel))
0124         return;
0125     }
0126   }
0127 
0128   // Notify wakes one or all waiting threads.
0129   // Must be called after changing the associated wait predicate.
0130   void Notify(bool notifyAll) {
0131     std::atomic_thread_fence(std::memory_order_seq_cst);
0132     uint64_t state = state_.load(std::memory_order_acquire);
0133     for (;;) {
0134       CheckState(state);
0135       const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
0136       const uint64_t signals = (state & kSignalMask) >> kSignalShift;
0137       // Easy case: no waiters.
0138       if ((state & kStackMask) == kStackMask && waiters == signals) return;
0139       uint64_t newstate;
0140       if (notifyAll) {
0141         // Empty wait stack and set signal to number of pre-wait threads.
0142         newstate =
0143             (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
0144       } else if (signals < waiters) {
0145         // There is a thread in pre-wait state, unblock it.
0146         newstate = state + kSignalInc;
0147       } else {
0148         // Pop a waiter from list and unpark it.
0149         Waiter* w = &waiters_[state & kStackMask];
0150         uint64_t next = w->next.load(std::memory_order_relaxed);
0151         newstate = (state & (kWaiterMask | kSignalMask)) | next;
0152       }
0153       CheckState(newstate);
0154       if (state_.compare_exchange_weak(state, newstate,
0155                                        std::memory_order_acq_rel)) {
0156         if (!notifyAll && (signals < waiters))
0157           return;  // unblocked pre-wait thread
0158         if ((state & kStackMask) == kStackMask) return;
0159         Waiter* w = &waiters_[state & kStackMask];
0160         if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
0161         Unpark(w);
0162         return;
0163       }
0164     }
0165   }
0166 
0167   class Waiter {
0168     friend class EventCount;
0169     // Align to 128 byte boundary to prevent false sharing with other Waiter
0170     // objects in the same vector.
0171     EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
0172     std::mutex mu;
0173     std::condition_variable cv;
0174     uint64_t epoch = 0;
0175     unsigned state = kNotSignaled;
0176     enum {
0177       kNotSignaled,
0178       kWaiting,
0179       kSignaled,
0180     };
0181   };
0182 
0183  private:
0184   // State_ layout:
0185   // - low kWaiterBits is a stack of waiters committed wait
0186   //   (indexes in waiters_ array are used as stack elements,
0187   //   kStackMask means empty stack).
0188   // - next kWaiterBits is count of waiters in prewait state.
0189   // - next kWaiterBits is count of pending signals.
0190   // - remaining bits are ABA counter for the stack.
0191   //   (stored in Waiter node and incremented on push).
0192   static const uint64_t kWaiterBits = 14;
0193   static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
0194   static const uint64_t kWaiterShift = kWaiterBits;
0195   static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
0196                                       << kWaiterShift;
0197   static const uint64_t kWaiterInc = 1ull << kWaiterShift;
0198   static const uint64_t kSignalShift = 2 * kWaiterBits;
0199   static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
0200                                       << kSignalShift;
0201   static const uint64_t kSignalInc = 1ull << kSignalShift;
0202   static const uint64_t kEpochShift = 3 * kWaiterBits;
0203   static const uint64_t kEpochBits = 64 - kEpochShift;
0204   static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
0205   static const uint64_t kEpochInc = 1ull << kEpochShift;
0206   std::atomic<uint64_t> state_;
0207   MaxSizeVector<Waiter>& waiters_;
0208 
0209   static void CheckState(uint64_t state, bool waiter = false) {
0210     static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
0211     const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
0212     const uint64_t signals = (state & kSignalMask) >> kSignalShift;
0213     eigen_plain_assert(waiters >= signals);
0214     eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
0215     eigen_plain_assert(!waiter || waiters > 0);
0216     (void)waiters;
0217     (void)signals;
0218   }
0219 
0220   void Park(Waiter* w) {
0221     std::unique_lock<std::mutex> lock(w->mu);
0222     while (w->state != Waiter::kSignaled) {
0223       w->state = Waiter::kWaiting;
0224       w->cv.wait(lock);
0225     }
0226   }
0227 
0228   void Unpark(Waiter* w) {
0229     for (Waiter* next; w; w = next) {
0230       uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
0231       next = wnext == kStackMask ? nullptr : &waiters_[wnext];
0232       unsigned state;
0233       {
0234         std::unique_lock<std::mutex> lock(w->mu);
0235         state = w->state;
0236         w->state = Waiter::kSignaled;
0237       }
0238       // Avoid notifying if it wasn't waiting.
0239       if (state == Waiter::kWaiting) w->cv.notify_one();
0240     }
0241   }
0242 
0243   EventCount(const EventCount&) = delete;
0244   void operator=(const EventCount&) = delete;
0245 };
0246 
0247 }  // namespace Eigen
0248 
0249 #endif  // EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_