File indexing completed on 2025-12-16 10:14:26
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #ifndef EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
0011 #define EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
0012
0013 namespace Eigen {
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
// EventCount allows waiting for arbitrary predicates in non-blocking
// algorithms. Think of it as a condition variable, except that the wait
// predicate does not need to be protected by a mutex.
//
// Waiting thread does:
//   if (predicate)
//     return act();
//   ec.Prewait();
//   if (predicate) {
//     ec.CancelWait();
//     return act();
//   }
//   ec.CommitWait(&waiters[my_index]);
//
// Notifying thread does:
//   predicate = true;
//   ec.Notify(true);
//
// Notify is cheap if there are no waiting threads. Prewait/CommitWait are not
// cheap, but they are executed only if the preceding predicate check failed.
//
// Algorithm outline: there are two main variables, the predicate (managed by
// the user) and state_. The operation closely resembles Dekker's mutual
// exclusion algorithm: the waiter and the notifier each set their own
// variable with seq_cst ordering and then check the other one, so at least
// one of them is guaranteed to observe the other's write.
class EventCount {
 public:
  class Waiter;

  // `waiters` must outlive this EventCount; one Waiter slot per potential
  // waiting thread. Size must fit in kWaiterBits (the stack index field).
  EventCount(MaxSizeVector<Waiter>& waiters)
      : state_(kStackMask), waiters_(waiters) {
    eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
  }

  ~EventCount() {
    // Ensure there are no waiters: state must be back to "empty stack,
    // zero pre-wait count, zero signals".
    eigen_plain_assert(state_.load() == kStackMask);
  }

  // Prewait prepares for waiting.
  // After calling Prewait, the thread must re-check the wait predicate and
  // then call either CancelWait or CommitWait.
  void Prewait() {
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      CheckState(state);
      // Bump the pre-wait counter. seq_cst on success pairs with the seq_cst
      // fence in Notify so that either this thread sees the new predicate
      // value or the notifier sees the incremented waiter count.
      uint64_t newstate = state + kWaiterInc;
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_seq_cst))
        return;
    }
  }

  // CommitWait commits waiting after Prewait: either consumes an already
  // pending signal and returns immediately, or pushes `w` onto the waiter
  // stack and blocks until Notify unparks it.
  void CommitWait(Waiter* w) {
    eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
    w->state = Waiter::kNotSignaled;
    // Stack node value: this waiter's index in waiters_ combined with its
    // epoch (the ABA counter bits).
    const uint64_t me = (w - &waiters_[0]) | w->epoch;
    uint64_t state = state_.load(std::memory_order_seq_cst);
    for (;;) {
      CheckState(state, true);
      uint64_t newstate;
      if ((state & kSignalMask) != 0) {
        // Consume the signal and return immediately (no blocking needed).
        newstate = state - kWaiterInc - kSignalInc;
      } else {
        // Remove this thread from the pre-wait counter and push it onto the
        // waiter stack; `next` points at the previous stack top (index plus
        // its epoch bits).
        newstate = ((state & kWaiterMask) - kWaiterInc) | me;
        w->next.store(state & (kStackMask | kEpochMask),
                      std::memory_order_relaxed);
      }
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel)) {
        if ((state & kSignalMask) == 0) {
          // Successfully pushed: advance the epoch for the next push (ABA
          // protection) and block until Unpark signals us.
          w->epoch += kEpochInc;
          Park(w);
        }
        return;
      }
    }
  }

  // CancelWait cancels the effects of a previous Prewait call.
  void CancelWait() {
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      CheckState(state, true);
      uint64_t newstate = state - kWaiterInc;
      // We don't know if this thread was concurrently notified or not, so we
      // should not consume a signal unconditionally. Only if the number of
      // waiters equals the number of signals do we know a signal was meant
      // for this thread, and we must take it away.
      if (((state & kWaiterMask) >> kWaiterShift) ==
          ((state & kSignalMask) >> kSignalShift))
        newstate -= kSignalInc;
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel))
        return;
    }
  }

  // Notify wakes one (notifyAll == false) or all (notifyAll == true) waiting
  // threads. Must be called after changing the associated wait predicate.
  void Notify(bool notifyAll) {
    // Full fence: makes the predicate write visible before we inspect state_
    // (pairs with the seq_cst CAS in Prewait).
    std::atomic_thread_fence(std::memory_order_seq_cst);
    uint64_t state = state_.load(std::memory_order_acquire);
    for (;;) {
      CheckState(state);
      const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
      const uint64_t signals = (state & kSignalMask) >> kSignalShift;
      // Easy case: no committed waiters (empty stack) and every pre-waiter
      // already has a signal — nothing to do.
      if ((state & kStackMask) == kStackMask && waiters == signals) return;
      uint64_t newstate;
      if (notifyAll) {
        // Empty the wait stack and set signal count to the number of
        // pre-wait threads so each of them consumes one in CommitWait.
        newstate =
            (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
      } else if (signals < waiters) {
        // There is a thread in pre-wait state; unblock it with a signal
        // rather than touching the stack.
        newstate = state + kSignalInc;
      } else {
        // Pop one waiter from the stack to unpark it.
        Waiter* w = &waiters_[state & kStackMask];
        uint64_t next = w->next.load(std::memory_order_relaxed);
        newstate = (state & (kWaiterMask | kSignalMask)) | next;
      }
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel)) {
        // Signal-only path: nobody to unpark.
        if (!notifyAll && (signals < waiters))
          return;
        if ((state & kStackMask) == kStackMask) return;
        Waiter* w = &waiters_[state & kStackMask];
        // For notify-one, detach the popped node from the rest of the stack
        // so Unpark wakes only this one waiter.
        if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
        Unpark(w);
        return;
      }
    }
  }

  // Per-thread waiter slot; stored in the user-provided waiters_ vector and
  // linked into an intrusive stack via `next`.
  class Waiter {
    friend class EventCount;
    // Align to 128-byte boundary to prevent false sharing with other Waiter
    // objects in the same vector.
    EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
    std::mutex mu;
    std::condition_variable cv;
    uint64_t epoch = 0;              // ABA counter, bumped on every stack push
    unsigned state = kNotSignaled;   // protected by mu
    enum {
      kNotSignaled,
      kWaiting,
      kSignaled,
    };
  };

 private:
  // state_ layout:
  // - low kWaiterBits: stack of waiters that committed wait (indexes into the
  //   waiters_ array are the stack elements; kStackMask means empty stack).
  // - next kWaiterBits: count of waiters in pre-wait state.
  // - next kWaiterBits: count of pending signals.
  // - remaining bits: ABA counter for the stack (stored in the Waiter node
  //   and incremented on push).
  static const uint64_t kWaiterBits = 14;
  static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
  static const uint64_t kWaiterShift = kWaiterBits;
  static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                      << kWaiterShift;
  static const uint64_t kWaiterInc = 1ull << kWaiterShift;
  static const uint64_t kSignalShift = 2 * kWaiterBits;
  static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
                                      << kSignalShift;
  static const uint64_t kSignalInc = 1ull << kSignalShift;
  static const uint64_t kEpochShift = 3 * kWaiterBits;
  static const uint64_t kEpochBits = 64 - kEpochShift;
  static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
  static const uint64_t kEpochInc = 1ull << kEpochShift;
  std::atomic<uint64_t> state_;
  MaxSizeVector<Waiter>& waiters_;

  // Debug-time invariant checks on a state word; `waiter` asserts that at
  // least one pre-wait waiter must be present.
  static void CheckState(uint64_t state, bool waiter = false) {
    static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
    const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
    const uint64_t signals = (state & kSignalMask) >> kSignalShift;
    eigen_plain_assert(waiters >= signals);
    eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
    eigen_plain_assert(!waiter || waiters > 0);
    (void)waiters;
    (void)signals;
  }

  // Block the calling thread on w's condition variable until Unpark marks it
  // kSignaled (loop guards against spurious wakeups).
  void Park(Waiter* w) {
    std::unique_lock<std::mutex> lock(w->mu);
    while (w->state != Waiter::kSignaled) {
      w->state = Waiter::kWaiting;
      w->cv.wait(lock);
    }
  }

  // Wake the chain of waiters starting at w (a single node for notify-one,
  // the whole detached stack for notify-all).
  void Unpark(Waiter* w) {
    for (Waiter* next; w; w = next) {
      // Mask off the epoch bits to recover the next stack index.
      uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
      next = wnext == kStackMask ? nullptr : &waiters_[wnext];
      unsigned state;
      {
        std::unique_lock<std::mutex> lock(w->mu);
        state = w->state;
        w->state = Waiter::kSignaled;
      }
      // Avoid notifying if the waiter was not actually blocked yet — it will
      // observe kSignaled in Park's loop and skip the wait.
      if (state == Waiter::kWaiting) w->cv.notify_one();
    }
  }

  EventCount(const EventCount&) = delete;
  void operator=(const EventCount&) = delete;
};
0246
0247 }
0248
0249 #endif