File indexing completed on 2025-01-18 10:00:12
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <math.h>
0012 #include <stddef.h>
0013 #include <string.h>
0014
0015 #include "gloo/algorithm.h"
0016 #include "gloo/common/error.h"
0017 #include "gloo/context.h"
0018
0019 namespace gloo {
0020
0021 template <typename T>
0022 class ReduceScatterHalvingDoubling : public Algorithm {
0023 void initBinaryBlocks() {
0024 uint32_t offset = this->contextSize_;
0025 uint32_t blockSize = 1;
0026 uint32_t currentBlockSize = 0;
0027 uint32_t prevBlockSize = 0;
0028 do {
0029 if (this->contextSize_ & blockSize) {
0030 prevBlockSize = currentBlockSize;
0031 currentBlockSize = blockSize;
0032 offset -= blockSize;
0033 if (myBinaryBlockSize_ != 0) {
0034 nextLargerBlockSize_ = currentBlockSize;
0035 break;
0036 }
0037 if (offset <= this->context_->rank) {
0038 offsetToMyBinaryBlock_ = offset;
0039 myBinaryBlockSize_ = currentBlockSize;
0040 nextSmallerBlockSize_ = prevBlockSize;
0041 }
0042 }
0043 blockSize <<= 1;
0044 } while (offset != 0);
0045
0046 stepsWithinBlock_ = log2(myBinaryBlockSize_);
0047 rankInBinaryBlock_ = this->context_->rank % myBinaryBlockSize_;
0048 }
0049
0050
0051 uint32_t reverseLastNBits(uint32_t ctr, uint32_t n) {
0052 uint32_t bitMask = 1;
0053 uint32_t reversed = 0;
0054 while (bitMask < (static_cast<uint32_t>(1) << n)) {
0055 reversed <<= 1;
0056 if (ctr & bitMask) {
0057 reversed |= 1;
0058 }
0059 bitMask <<= 1;
0060 }
0061 return reversed;
0062 }
0063
0064 struct DistributionMap {
0065 int rank;
0066 size_t offset;
0067 size_t itemCount;
0068 DistributionMap(int dRank, size_t dOffset, size_t dItemCount)
0069 : rank(dRank), offset(dOffset), itemCount(dItemCount) {}
0070
0071 };
0072
0073 void getDistributionMap(
0074 size_t srcOffset, size_t srcCount, const std::vector<int>& recvCounts,
0075 bool reorder, std::vector<DistributionMap>& distributionMap) {
0076 if (srcCount == 0) {
0077 return;
0078 }
0079
0080 size_t destOffset = 0;
0081 auto size =
0082 reorder ? 1 << (int)log2(this->contextSize_) : this->contextSize_;
0083 int start = 0;
0084 for (; start < size; ++start) {
0085 if (destOffset + recvCounts[start] > srcOffset) break;
0086 destOffset += recvCounts[start];
0087 }
0088 destOffset = srcOffset - destOffset;
0089
0090 auto totalCount = srcCount;
0091 for (int i = start; i < size; ++i) {
0092 auto recvCount = recvCounts[i];
0093 if (destOffset != 0) {
0094 recvCount -= destOffset;
0095 destOffset = 0;
0096 }
0097 auto rank =
0098 reorder ? reverseLastNBits(i, log2(this->contextSize_)) : i;
0099 recvCount = recvCount < totalCount ? recvCount : totalCount;
0100 distributionMap.emplace_back(rank, srcOffset, recvCount);
0101 srcOffset += recvCount;
0102 totalCount -= recvCount;
0103 if (totalCount <= 0) {
0104 break;
0105 }
0106 }
0107 }
0108
0109 public:
0110 ReduceScatterHalvingDoubling(
0111 const std::shared_ptr<Context>& context,
0112 const std::vector<T*> ptrs,
0113 const int count,
0114 const std::vector<int> recvElems,
0115 const ReductionFunction<T>* fn = ReductionFunction<T>::sum)
0116 : Algorithm(context),
0117 ptrs_(ptrs),
0118 count_(count),
0119 recvElems_(recvElems),
0120 bytes_(count_ * sizeof(T)),
0121 steps_(log2(this->contextSize_)),
0122 chunks_(1 << steps_),
0123 chunkSize_((count_ + chunks_ - 1) / chunks_),
0124 chunkBytes_(chunkSize_ * sizeof(T)),
0125 fn_(fn),
0126 recvBuf_(chunkSize_ << steps_),
0127 recvBufDist_(count_),
0128 sendOffsets_(steps_),
0129 recvOffsets_(steps_),
0130 sendCounts_(steps_, 0),
0131 recvCounts_(steps_, 0),
0132 sendCountToLargerBlock_(0),
0133 offsetToMyBinaryBlock_(0),
0134 myBinaryBlockSize_(0),
0135 stepsWithinBlock_(0),
0136 rankInBinaryBlock_(0),
0137 nextSmallerBlockSize_(0),
0138 nextLargerBlockSize_(0) {
0139 if (this->contextSize_ == 1) {
0140 return;
0141 }
0142
0143 initBinaryBlocks();
0144 sendDataBufs_.reserve(stepsWithinBlock_);
0145 recvDataBufs_.reserve(stepsWithinBlock_);
0146
0147
0148
0149
0150 slotOffset_ = this->context_->nextSlot(
0151 4 * this->contextSize_ * (this->contextSize_ - 1));
0152
0153 size_t bitmask = 1;
0154 size_t stepChunkSize = chunkSize_ << (steps_ - 1);
0155 size_t stepChunkBytes = stepChunkSize * sizeof(T);
0156 size_t sendOffset = 0;
0157 size_t recvOffset = 0;
0158 size_t bufferOffset = 0;
0159 for (int i = 0; i < stepsWithinBlock_; i++) {
0160 const int destRank = (this->context_->rank) ^ bitmask;
0161 auto& pair = this->context_->getPair(destRank);
0162 sendOffsets_[i] = sendOffset + ((destRank & bitmask) ? stepChunkSize : 0);
0163 recvOffsets_[i] =
0164 recvOffset + ((this->context_->rank & bitmask) ? stepChunkSize : 0);
0165 if (sendOffsets_[i] < count_) {
0166
0167 if (sendOffsets_[i] + stepChunkSize > count_) {
0168 sendCounts_[i] = count_ - sendOffsets_[i];
0169 } else {
0170 sendCounts_[i] = stepChunkSize;
0171 }
0172 }
0173 int myRank = this->context_->rank;
0174 auto slot = slotOffset_ +
0175 2 * (std::min(myRank, destRank) * this->contextSize_ +
0176 std::max(myRank, destRank));
0177 sendDataBufs_.push_back(pair->createSendBuffer(slot, ptrs_[0], bytes_));
0178 if (recvOffsets_[i] < count_) {
0179
0180 if (recvOffsets_[i] + stepChunkSize > count_) {
0181 recvCounts_[i] = count_ - recvOffsets_[i];
0182 } else {
0183 recvCounts_[i] = stepChunkSize;
0184 }
0185 }
0186 recvDataBufs_.push_back(
0187 pair->createRecvBuffer(
0188 slot, &recvBuf_[bufferOffset], stepChunkBytes));
0189 bufferOffset += stepChunkSize;
0190 if (this->context_->rank & bitmask) {
0191 sendOffset += stepChunkSize;
0192 recvOffset += stepChunkSize;
0193 }
0194 bitmask <<= 1;
0195 stepChunkSize >>= 1;
0196 stepChunkBytes >>= 1;
0197
0198 ++slot;
0199 sendNotificationBufs_.push_back(
0200 pair->createSendBuffer(slot, &dummy_, sizeof(dummy_)));
0201 recvNotificationBufs_.push_back(
0202 pair->createRecvBuffer(slot, &dummy_, sizeof(dummy_)));
0203 }
0204
0205 const auto myRank = this->context_->rank;
0206 if (nextSmallerBlockSize_ != 0) {
0207 const auto offsetToSmallerBlock =
0208 offsetToMyBinaryBlock_ + myBinaryBlockSize_;
0209 const int destRank =
0210 offsetToSmallerBlock + rankInBinaryBlock_ % nextSmallerBlockSize_;
0211 auto& destPair = this->context_->getPair(destRank);
0212 auto slot = slotOffset_ +
0213 2 * (std::min(myRank, destRank) * this->contextSize_ +
0214 std::max(myRank, destRank));
0215 const auto itemCount = recvCounts_[stepsWithinBlock_ - 1];
0216 if (itemCount > 0) {
0217 smallerBlockRecvDataBuf_ = destPair->createRecvBuffer(
0218 slot, &recvBuf_[bufferOffset], itemCount * sizeof(T));
0219 }
0220 }
0221 if (nextLargerBlockSize_ != 0) {
0222
0223
0224
0225
0226
0227
0228
0229
0230
0231
0232
0233 const auto offsetToLargerBlock =
0234 offsetToMyBinaryBlock_ - nextLargerBlockSize_;
0235 const auto numSendsAndReceivesToLargerBlock =
0236 nextLargerBlockSize_ / myBinaryBlockSize_;
0237 sendCountToLargerBlock_ = stepChunkSize >>
0238 (static_cast<size_t>(log2(numSendsAndReceivesToLargerBlock)) - 1);
0239 auto srcOrdinal =
0240 reverseLastNBits(rankInBinaryBlock_, log2(myBinaryBlockSize_));
0241 auto destOrdinal = srcOrdinal * numSendsAndReceivesToLargerBlock;
0242 for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
0243 const int destRank = offsetToLargerBlock +
0244 reverseLastNBits(destOrdinal, log2(nextLargerBlockSize_));
0245 auto& destPair = this->context_->getPair(destRank);
0246 auto slot = slotOffset_ +
0247 2 * (std::min(myRank, destRank) * this->contextSize_ +
0248 std::max(myRank, destRank));
0249 largerBlockSendDataBufs_.push_back(
0250 destPair->createSendBuffer(slot, ptrs[0], bytes_));
0251 destOrdinal++;
0252 }
0253 }
0254
0255
0256
0257
0258
0259
0260
0261
0262
0263 if (nextLargerBlockSize_ == 0 && stepsWithinBlock_ > 0) {
0264 getDistributionMap(
0265 recvOffsets_[stepsWithinBlock_ - 1],
0266 recvCounts_[stepsWithinBlock_ - 1],
0267 recvElems_, false, distMapForSend_);
0268 for (const auto& distMap : distMapForSend_) {
0269 const int destRank = distMap.rank;
0270 if (myRank != destRank) {
0271 auto& destPair = this->context_->getPair(destRank);
0272 auto slot = slotOffset_ + 2 +
0273 2 * (std::min(myRank, destRank) * this->contextSize_ +
0274 std::max(myRank, destRank));
0275 distSendDataBufs_.push_back(
0276 destPair->createSendBuffer(slot, ptrs_[0], bytes_));
0277 ++slot;
0278 recvNotificationBufs_.push_back(
0279 destPair->createRecvBuffer(slot, &dummy_, sizeof(dummy_)));
0280 }
0281 }
0282 }
0283
0284
0285
0286
0287
0288
0289 if (recvElems_[myRank] > 0) {
0290 std::vector<int> srcCounts;
0291 size_t rem = count_;
0292 for (int i = 0; i < this->contextSize_; ++i) {
0293 srcCounts.push_back(std::min(chunkSize_, rem));
0294 rem = rem > chunkSize_ ? rem - chunkSize_ : 0;
0295 }
0296 size_t offset = 0;
0297 for (int i = 0; i < myRank; ++i) {
0298 offset += recvElems_[i];
0299 }
0300 getDistributionMap(
0301 offset, recvElems_[myRank], srcCounts, true, distMapForRecv_);
0302 for (const auto& distMap : distMapForRecv_) {
0303 const int srcRank = distMap.rank;
0304 if (myRank != srcRank) {
0305 auto& destPair = this->context_->getPair(srcRank);
0306 auto slot = slotOffset_ + 2 +
0307 2 * (std::min(myRank, srcRank) * this->contextSize_ +
0308 std::max(myRank, srcRank));
0309 distRecvDataBufs_.push_back(
0310 destPair->createRecvBuffer(
0311 slot, &recvBufDist_[distMap.offset],
0312 distMap.itemCount * sizeof(T)));
0313 ++slot;
0314 sendNotificationBufs_.push_back(
0315 destPair->createSendBuffer(slot, &dummy_, sizeof(dummy_)));
0316 }
0317 }
0318 }
0319
0320 }
0321
0322 void run() {
0323 size_t bufferOffset = 0;
0324 size_t numItems =
0325 stepsWithinBlock_ > 0 ? chunkSize_ << (steps_ - 1) : count_;
0326
0327 for (int i = 1; i < ptrs_.size(); i++) {
0328 fn_->call(ptrs_[0], ptrs_[i], count_);
0329 }
0330 if (this->contextSize_ == 1) {
0331
0332 for (int i = 1; i < ptrs_.size(); i++) {
0333 memcpy(ptrs_[i], ptrs_[0], bytes_);
0334 }
0335 return;
0336 }
0337
0338
0339 for (int i = 0; i < stepsWithinBlock_; i++) {
0340 if (sendOffsets_[i] < count_) {
0341 sendDataBufs_[i]->send(
0342 sendOffsets_[i] * sizeof(T), sendCounts_[i] * sizeof(T));
0343 }
0344 if (recvOffsets_[i] < count_) {
0345 recvDataBufs_[i]->waitRecv();
0346 fn_->call(
0347 &ptrs_[0][recvOffsets_[i]],
0348 &recvBuf_[bufferOffset],
0349 recvCounts_[i]);
0350 }
0351 bufferOffset += numItems;
0352 sendNotificationBufs_[i]->send();
0353 numItems >>= 1;
0354 }
0355
0356
0357
0358
0359
0360
0361 int sendNotifyOffset = stepsWithinBlock_;
0362 if (nextSmallerBlockSize_ != 0 && smallerBlockRecvDataBuf_ != nullptr) {
0363 smallerBlockRecvDataBuf_->waitRecv();
0364 fn_->call(
0365 &ptrs_[0][recvOffsets_[stepsWithinBlock_ - 1]],
0366 &recvBuf_[bufferOffset],
0367 recvCounts_[stepsWithinBlock_ - 1]);
0368 }
0369
0370 const auto totalItemsToSend =
0371 stepsWithinBlock_ > 0 ? recvCounts_[stepsWithinBlock_ - 1] : count_;
0372 if (nextLargerBlockSize_ != 0 && totalItemsToSend != 0) {
0373
0374 const auto offset =
0375 stepsWithinBlock_ > 0 ? recvOffsets_[stepsWithinBlock_ - 1] : 0;
0376 const auto numSendsAndReceivesToLargerBlock =
0377 nextLargerBlockSize_ / myBinaryBlockSize_;
0378 for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
0379 if (sendCountToLargerBlock_ * i < totalItemsToSend) {
0380 largerBlockSendDataBufs_[i]->send(
0381 (offset + i * sendCountToLargerBlock_) * sizeof(T),
0382 std::min(
0383 sendCountToLargerBlock_,
0384 totalItemsToSend - sendCountToLargerBlock_ * i) *
0385 sizeof(T));
0386 }
0387 }
0388 }
0389
0390
0391
0392 int index = 0;
0393 for (const auto& distMap : distMapForSend_) {
0394 const auto myRank = this->context_->rank;
0395 const int destRank = distMap.rank;
0396 if (myRank != destRank) {
0397 distSendDataBufs_[index++]->send(
0398 distMap.offset * sizeof(T), distMap.itemCount * sizeof(T));
0399 }
0400 }
0401 index = 0;
0402 bufferOffset = 0;
0403 for (const auto& distMap : distMapForRecv_) {
0404 const auto myRank = this->context_->rank;
0405 const int srcRank = distMap.rank;
0406 if (myRank != srcRank) {
0407 distRecvDataBufs_[index++]->waitRecv();
0408 memcpy(
0409 &ptrs_[0][bufferOffset],
0410 &recvBufDist_[distMap.offset],
0411 distMap.itemCount * sizeof(T));
0412 sendNotificationBufs_[sendNotifyOffset++]->send();
0413 } else {
0414 if (myRank != 0) {
0415 memcpy(
0416 &ptrs_[0][bufferOffset],
0417 &ptrs_[0][distMap.offset],
0418 distMap.itemCount * sizeof(T));
0419 }
0420 }
0421 bufferOffset += distMap.itemCount;
0422 }
0423
0424
0425 for (int i = 1; i < ptrs_.size(); i++) {
0426 memcpy(ptrs_[i], ptrs_[0], bytes_);
0427 }
0428
0429
0430
0431
0432 for (auto& recvNotificationBuf : recvNotificationBufs_) {
0433 recvNotificationBuf->waitRecv();
0434 }
0435 }
0436
0437 protected:
0438 std::vector<T*> ptrs_;
0439 const int count_;
0440 const std::vector<int> recvElems_;
0441 const int bytes_;
0442 const size_t steps_;
0443 const size_t chunks_;
0444 const size_t chunkSize_;
0445 const size_t chunkBytes_;
0446 const ReductionFunction<T>* fn_;
0447
0448
0449 std::vector<T> recvBuf_;
0450
0451
0452 std::vector<T> recvBufDist_;
0453
0454
0455
0456
0457 std::vector<size_t> sendOffsets_;
0458
0459
0460
0461 std::vector<size_t> recvOffsets_;
0462
0463 std::vector<std::unique_ptr<transport::Buffer>> sendDataBufs_;
0464 std::vector<std::unique_ptr<transport::Buffer>> recvDataBufs_;
0465
0466 std::unique_ptr<transport::Buffer> smallerBlockRecvDataBuf_;
0467 std::vector<std::unique_ptr<transport::Buffer>> largerBlockSendDataBufs_;
0468
0469 std::unique_ptr<transport::Buffer> xchgBlockSendDataBuf_;
0470 std::unique_ptr<transport::Buffer> xchgBlockRecvDataBuf_;
0471
0472 std::vector<std::unique_ptr<transport::Buffer>> distSendDataBufs_;
0473 std::vector<std::unique_ptr<transport::Buffer>> distRecvDataBufs_;
0474
0475 std::vector<DistributionMap> distMapForSend_;
0476 std::vector<DistributionMap> distMapForRecv_;
0477
0478 std::vector<size_t> sendCounts_;
0479 std::vector<size_t> recvCounts_;
0480 size_t sendCountToLargerBlock_;
0481
0482 int dummy_;
0483 std::vector<std::unique_ptr<transport::Buffer>> sendNotificationBufs_;
0484 std::vector<std::unique_ptr<transport::Buffer>> recvNotificationBufs_;
0485
0486
0487
0488
0489
0490 uint32_t offsetToMyBinaryBlock_;
0491 uint32_t myBinaryBlockSize_;
0492 uint32_t stepsWithinBlock_;
0493 uint32_t rankInBinaryBlock_;
0494 uint32_t nextSmallerBlockSize_;
0495 uint32_t nextLargerBlockSize_;
0496
0497 int slotOffset_;
0498 };
0499
0500 }