Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:11

0001 /**
0002  * Copyright (c) 2017-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
0011 #include <math.h>
0012 #include <stddef.h>
0013 #include <string.h>
0014 
0015 #include "gloo/algorithm.h"
0016 #include "gloo/common/error.h"
0017 #include "gloo/context.h"
0018 
0019 namespace gloo {
0020 
0021 namespace {
// Returns the n least significant bits of ctr in reversed order: bit 0 of
// ctr becomes bit (n - 1) of the result, bit 1 becomes bit (n - 2), and so on.
// Bits of ctr at position n and above are ignored. n == 0 yields 0.
//
// Only 32 bits exist, so n is treated as at most 32. Iterating over bit
// positions (instead of comparing a mask against 1 << n) keeps n >= 32 well
// defined; shifting a 32-bit value by 32 or more is undefined behaviour.
uint32_t reverseLastNBits(uint32_t ctr, uint32_t n) {
  const uint32_t numBits = n < 32 ? n : 32;
  uint32_t reversed = 0;
  for (uint32_t bit = 0; bit < numBits; ++bit) {
    // Move the current lowest bit of ctr onto the low end of the result.
    reversed = (reversed << 1) | (ctr & 1u);
    ctr >>= 1;
  }
  return reversed;
}
0035 }
0036 
/**
 * Allreduce that runs a reduce-scatter followed by an allgather. Inside a
 * power-of-two sized "binary block" of ranks, step i pairs this rank with
 * rank ^ (1 << i); the exchanged chunk is halved after every reduce-scatter
 * step and doubled after every allgather step. When the context size is not a
 * power of two, the ranks are partitioned into binary blocks (one per set bit
 * of the context size) and data is additionally exchanged with the adjoining
 * smaller and larger blocks.
 *
 * T is the element type of the buffers being reduced.
 */
0037 template <typename T>
0038 class AllreduceHalvingDoubling : public Algorithm {
       // Locates the binary block that contains this rank and records its
       // offset, its size and the sizes of the adjoining smaller and larger
       // blocks. The set bits of contextSize_ are visited from least to most
       // significant, each one standing for a block of that size. `offset`
       // starts at contextSize_ and is lowered by every block visited, so the
       // smaller blocks cover the higher ranks. Relies on myBinaryBlockSize_
       // being 0 on entry (the constructor initializes it that way).
0039   void initBinaryBlocks() {
0040     uint32_t offset = this->contextSize_;
0041     uint32_t blockSize = 1;
0042     uint32_t currentBlockSize = 0;
0043     uint32_t prevBlockSize = 0;
0044     do {
0045       if (this->contextSize_ & blockSize) {
0046         prevBlockSize = currentBlockSize;
0047         currentBlockSize = blockSize;
0048         offset -= blockSize;
             // Our block was found on an earlier iteration, so this set bit
             // is the next larger block and the search is finished.
0049         if (myBinaryBlockSize_ != 0) {
0050           nextLargerBlockSize_ = currentBlockSize;
0051           break;
0052         }
             // First block whose starting rank is at or below our rank: this
             // is the block we belong to. prevBlockSize is 0 when no smaller
             // block exists.
0053         if (offset <= this->context_->rank) {
0054           offsetToMyBinaryBlock_ = offset;
0055           myBinaryBlockSize_ = currentBlockSize;
0056           nextSmallerBlockSize_ = prevBlockSize;
0057         }
0058       }
0059       blockSize <<= 1;
0060     } while (offset != 0);
0061 
         // myBinaryBlockSize_ is a power of two; the floating point result of
         // log2() is converted to uint32_t on assignment.
0062     stepsWithinBlock_ = log2(myBinaryBlockSize_);
0063     rankInBinaryBlock_ = this->context_->rank % myBinaryBlockSize_;
0064   }
0065 
0066  public:
       // Computes the per-step offsets and counts and creates every transport
       // buffer that run() uses.
       //
       // context: communication context; rank and pairs are taken from it
       // ptrs:    local buffers; run() reduces them into ptrs[0] and then
       //          copies ptrs[0] back into all the others
       // count:   number of elements of type T in each buffer
       // fn:      reduction, invoked as fn->call(dst, src, n); defaults to
       //          ReductionFunction<T>::sum
0067   AllreduceHalvingDoubling(
0068       const std::shared_ptr<Context>& context,
0069       const std::vector<T*> ptrs,
0070       const int count,
0071       const ReductionFunction<T>* fn = ReductionFunction<T>::sum)
0072       : Algorithm(context),
0073         ptrs_(ptrs),
0074         count_(count),
0075         bytes_(count_ * sizeof(T)),
             // log2() result is truncated when converted to size_t
0076         steps_(log2(this->contextSize_)),
0077         chunks_(1 << steps_),
             // count_ / chunks_, rounded up
0078         chunkSize_((count_ + chunks_ - 1) / chunks_),
0079         chunkBytes_(chunkSize_ * sizeof(T)),
0080         fn_(fn),
             // chunkSize_ * chunks_ elements
0081         recvBuf_(chunkSize_ << steps_),
0082         sendOffsets_(steps_),
0083         recvOffsets_(steps_),
0084         sendCounts_(steps_, 0),
0085         recvCounts_(steps_, 0),
0086         sendCountToLargerBlock_(0),
0087         offsetToMyBinaryBlock_(0),
0088         myBinaryBlockSize_(0),
0089         stepsWithinBlock_(0),
0090         rankInBinaryBlock_(0),
0091         nextSmallerBlockSize_(0),
0092         nextLargerBlockSize_(0) {
         // Nothing to communicate: no buffers are created and slotOffset_ is
         // left unset.
0093     if (count_ == 0 || this->contextSize_ == 1) {
0094         return;
0095     }
0096 
0097     initBinaryBlocks();
0098     sendDataBufs_.reserve(stepsWithinBlock_);
0099     recvDataBufs_.reserve(stepsWithinBlock_);
0100     // Reserve max needed number of context slots. Up to 2 slots per process
0101     // pair are needed (one for regular sends and one for notifications). For
0102     // simplicity, the same mapping is used on all processes so that the slots
0103     // trivially match across processes
0104     slotOffset_ = this->context_->nextSlot(
0105         2 * this->contextSize_ * (this->contextSize_ - 1));
0106 
         // Intra-block steps. Every step splits the current region of
         // 2 * stepChunkSize elements into two halves: the rank that has bit
         // `bitmask` set receives into the upper half and sends the lower
         // half, its peer does the opposite. sendOffset and recvOffset are
         // always advanced together and mark the start of the current region.
0107     size_t bitmask = 1;
0108     size_t stepChunkSize = chunkSize_ << (steps_ - 1);
0109     size_t stepChunkBytes = stepChunkSize * sizeof(T);
0110     size_t sendOffset = 0;
0111     size_t recvOffset = 0;
0112     size_t bufferOffset = 0; // offset into recvBuf_
0113     for (int i = 0; i < stepsWithinBlock_; i++) {
0114       const int destRank = (this->context_->rank) ^ bitmask;
0115       auto& pair = this->context_->getPair(destRank);
0116       sendOffsets_[i] = sendOffset + ((destRank & bitmask) ? stepChunkSize : 0);
0117       recvOffsets_[i] =
0118           recvOffset + ((this->context_->rank & bitmask) ? stepChunkSize : 0);
           // Counts stay 0 when the offset lies at or past count_ (the region
           // is padded up to chunkSize_ << steps_ elements).
0119       if (sendOffsets_[i] < count_) {
0120         // specifies number of elements to send in each step
0121         if (sendOffsets_[i] + stepChunkSize > count_) {
0122           sendCounts_[i] = count_ - sendOffsets_[i];
0123         } else {
0124           sendCounts_[i] = stepChunkSize;
0125         }
0126       }
0127       int myRank = this->context_->rank;
           // Both ends of a pair compute the same value from (min, max) rank.
           // `slot` carries data, `slot + 1` (see ++slot below) notifications.
0128       auto slot = slotOffset_ +
0129           2 * (std::min(myRank, destRank) * this->contextSize_ +
0130                std::max(myRank, destRank));
           // The send buffer spans all of ptrs_[0]; run() selects the range
           // to send through the offset/size arguments of send().
0131       sendDataBufs_.push_back(pair->createSendBuffer(slot, ptrs_[0], bytes_));
0132       if (recvOffsets_[i] < count_) {
0133         // specifies number of elements received in each step
0134         if (recvOffsets_[i] + stepChunkSize > count_) {
0135           recvCounts_[i] = count_ - recvOffsets_[i];
0136         } else {
0137           recvCounts_[i] = stepChunkSize;
0138         }
0139       }
           // Each step receives into its own consecutive slice of recvBuf_.
0140       recvDataBufs_.push_back(
0141           pair->createRecvBuffer(
0142               slot, &recvBuf_[bufferOffset], stepChunkBytes));
0143       bufferOffset += stepChunkSize;
           // Ranks with the bit set continue with the upper half.
0144       if (this->context_->rank & bitmask) {
0145         sendOffset += stepChunkSize;
0146         recvOffset += stepChunkSize;
0147       }
0148       bitmask <<= 1;
0149       stepChunkSize >>= 1;
0150       stepChunkBytes >>= 1;
0151 
           // Notification buffers use the odd slot of the pair and the shared
           // dummy_ member as payload.
0152       ++slot;
0153       sendNotificationBufs_.push_back(
0154           pair->createSendBuffer(slot, &dummy_, sizeof(dummy_)));
0155       recvNotificationBufs_.push_back(
0156           pair->createRecvBuffer(slot, &dummy_, sizeof(dummy_)));
0157     }
0158 
         // Buffers for the exchange with the next smaller block, which starts
         // right after our block. Our peer there is picked by taking our rank
         // within the block modulo the smaller block's size.
0159     if (nextSmallerBlockSize_ != 0) {
0160       const auto offsetToSmallerBlock =
0161           offsetToMyBinaryBlock_ + myBinaryBlockSize_;
0162       const int destRank =
0163           offsetToSmallerBlock + rankInBinaryBlock_ % nextSmallerBlockSize_;
0164       auto& destPair = this->context_->getPair(destRank);
0165       const auto myRank = this->context_->rank;
0166       const auto slot = slotOffset_ +
0167           2 * (std::min(myRank, destRank) * this->contextSize_ +
0168                std::max(myRank, destRank));
0169       smallerBlockSendDataBuf_ = destPair->createSendBuffer(
0170           slot, ptrs_[0], bytes_);
           // Same amount as received in the last intra-block step. The
           // receive buffer is left unset (null) when that amount is 0.
           // NOTE(review): bufferOffset is not advanced past this region, so
           // the larger-block receive buffers created below start at the
           // same recvBuf_ position.
0171       const auto itemCount = recvCounts_[stepsWithinBlock_ - 1];
0172       if (itemCount > 0) {
0173         smallerBlockRecvDataBuf_ = destPair->createRecvBuffer(
0174             slot, &recvBuf_[bufferOffset], itemCount * sizeof(T));
0175       }
0176     }
0177     if (nextLargerBlockSize_ != 0) {
0178       // Due to the design decision of sending large messages to nearby ranks,
0179       // after the reduce-scatter the reduced chunks end up in an order
0180       // according to the reversed bit pattern of each proc's rank within the
0181       // block. So, instead of ranks 0, 1, 2, ... 7 having blocks A, B, C, D, E,
0182       // F, G, H etc. what you get is A, E, C, G, B, F, D, H. Taking this
0183       // example further, if there is also a smaller binary block of size 2
0184       // (with the reduced blocks A - D, E - H), rank 0 within the smaller block
0185       // will need to send chunks of its buffer to ranks 0, 4, 2, 6 within the
0186       // larger block (in that order) and rank 1 will send to 1, 5, 3, 7. Within
0187       // the reversed bit patterns, this communication is actually 0 to [0, 1,
0188       // 2, 3] and 1 to [4, 5, 6, 7].
0189       const auto offsetToLargerBlock =
0190           offsetToMyBinaryBlock_ - nextLargerBlockSize_;
0191       const auto numSendsAndReceivesToLargerBlock =
0192           nextLargerBlockSize_ / myBinaryBlockSize_;
           // What this rank holds after the intra-block reduce-scatter: the
           // last step's receive count, or everything if there were no steps.
0193       const auto totalItemsToSend =
0194           stepsWithinBlock_ > 0 ? recvCounts_[stepsWithinBlock_ - 1] : count_;
           // stepChunkSize was halved once more at the end of the loop above
           // (or is chunkSize_ << (steps_ - 1) if the loop did not run), i.e.
           // it is half of the region owned by this rank; hence the "- 1"
           // when dividing that region by the number of peers.
0195       sendCountToLargerBlock_ = stepChunkSize >>
0196           (static_cast<size_t>(log2(numSendsAndReceivesToLargerBlock)) - 1);
0197       auto srcOrdinal =
0198           reverseLastNBits(rankInBinaryBlock_, log2(myBinaryBlockSize_));
0199       auto destOrdinal = srcOrdinal * numSendsAndReceivesToLargerBlock;
0200       for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
0201         const int destRank = offsetToLargerBlock +
0202             reverseLastNBits(destOrdinal, log2(nextLargerBlockSize_));
0203         auto& destPair = this->context_->getPair(destRank);
0204         const auto myRank = this->context_->rank;
0205         const auto slot = slotOffset_ +
0206             2 * (std::min(myRank, destRank) * this->contextSize_ +
0207                  std::max(myRank, destRank));
             // `ptrs` is the constructor argument that ptrs_ was copied
             // from, so ptrs[0] is the same pointer as ptrs_[0].
0208         largerBlockSendDataBufs_.push_back(
0209             destPair->createSendBuffer(slot, ptrs[0], bytes_));
             // A receive buffer is only created for peers that actually get
             // data; the last one may be shorter than the others.
0210         if (sendCountToLargerBlock_ * i < totalItemsToSend) {
0211           const auto toSend = std::min(
0212               sendCountToLargerBlock_,
0213               totalItemsToSend - sendCountToLargerBlock_ * i);
0214           largerBlockRecvDataBufs_.push_back(
0215               destPair->createRecvBuffer(
0216                   slot, &recvBuf_[bufferOffset], toSend * sizeof(T)));
0217           bufferOffset += toSend;
0218         }
0219         destOrdinal++;
0220       }
0221     }
0222   }
0223 
       // Performs one allreduce: reduces all local buffers into ptrs_[0],
       // runs the reduce-scatter, the exchanges between binary blocks and the
       // allgather, and finally copies ptrs_[0] into the other local buffers.
       // Returns immediately when count_ is 0.
0224   void run() {
0225     if (count_ == 0) {
0226       return;
0227     }
         // bufferOffset/numItems retrace the recvBuf_ layout that the
         // constructor used when it created the receive buffers.
0228     size_t bufferOffset = 0;
0229     size_t numItems =
0230         stepsWithinBlock_ > 0 ? chunkSize_ << (steps_ - 1) : count_;
0231 
         // Local reduction of all buffers into ptrs_[0]
0232     for (int i = 1; i < ptrs_.size(); i++) {
0233       fn_->call(ptrs_[0], ptrs_[i], count_);
0234     }
0235     if (this->contextSize_ == 1) {
0236       // Broadcast ptrs_[0]
0237       for (int i = 1; i < ptrs_.size(); i++) {
0238         memcpy(ptrs_[i], ptrs_[0], bytes_);
0239       }
0240       return;
0241     }
0242 
0243     // Reduce-scatter
0244     for (int i = 0; i < stepsWithinBlock_; i++) {
0245       if (sendOffsets_[i] < count_) {
0246         sendDataBufs_[i]->send(
0247             sendOffsets_[i] * sizeof(T), sendCounts_[i] * sizeof(T));
0248       }
0249       if (recvOffsets_[i] < count_) {
0250         recvDataBufs_[i]->waitRecv();
0251         fn_->call(
0252             &ptrs_[0][recvOffsets_[i]],
0253             &recvBuf_[bufferOffset],
0254             recvCounts_[i]);
0255       }
0256       bufferOffset += numItems;
           // Tell the peer that its data for this step has been handled; the
           // matching waitRecv() is in the allgather loop below.
0257       sendNotificationBufs_[i]->send();
0258       numItems >>= 1;
0259     }
0260 
0261     // Communication across binary blocks for non-power-of-two number of
0262     // processes
0263 
0264     // receive from smaller block
0265     // data sizes same as in the last step of intrablock reduce-scatter above
0266     if (nextSmallerBlockSize_ != 0 && smallerBlockRecvDataBuf_ != nullptr) {
0267       smallerBlockRecvDataBuf_->waitRecv();
0268       fn_->call(
0269           &ptrs_[0][recvOffsets_[stepsWithinBlock_ - 1]],
0270           &recvBuf_[bufferOffset],
0271           recvCounts_[stepsWithinBlock_ - 1]);
0272     }
0273 
0274     const auto totalItemsToSend =
0275         stepsWithinBlock_ > 0 ? recvCounts_[stepsWithinBlock_ - 1] : count_;
0276     if (nextLargerBlockSize_ != 0 && totalItemsToSend != 0) {
0277       // scatter to larger block
0278       const auto offset =
0279           stepsWithinBlock_ > 0 ? recvOffsets_[stepsWithinBlock_ - 1] : 0;
0280       const auto numSendsAndReceivesToLargerBlock =
0281           nextLargerBlockSize_ / myBinaryBlockSize_;
0282       for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
0283         if (sendCountToLargerBlock_ * i < totalItemsToSend) {
0284           largerBlockSendDataBufs_[i]->send(
0285               (offset + i * sendCountToLargerBlock_) * sizeof(T),
0286               std::min(
0287                   sendCountToLargerBlock_,
0288                   totalItemsToSend - sendCountToLargerBlock_ * i) *
0289                   sizeof(T));
0290         }
0291       }
0292       // no notification is needed because the forward and backward messages
0293       // across blocks are serialized in relation to each other
0294 
0295       // receive from larger blocks
0296       for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
0297         if (sendCountToLargerBlock_ * i < totalItemsToSend) {
0298           largerBlockRecvDataBufs_[i]->waitRecv();
0299         }
0300       }
           // The larger-block receive buffers were laid out back to back in
           // recvBuf_ by the constructor, so one copy moves all of them over
           // the region this rank sent.
0301       memcpy(
0302           &ptrs_[0][offset],
0303           &recvBuf_[bufferOffset],
0304           totalItemsToSend * sizeof(T));
0305     }
0306 
0307     // Send to smaller block (technically the beginning of allgather)
0308     bool sentToSmallerBlock = false;
0309     if (nextSmallerBlockSize_ != 0) {
0310       if (recvOffsets_[stepsWithinBlock_ - 1] < count_) {
0311         sentToSmallerBlock = true;
0312         smallerBlockSendDataBuf_->send(
0313             recvOffsets_[stepsWithinBlock_ - 1] * sizeof(T),
0314             recvCounts_[stepsWithinBlock_ - 1] * sizeof(T));
0315       }
0316     }
0317 
0318     // Allgather
         // Replays the reduce-scatter steps in reverse order with the roles
         // of the offsets swapped: the range received then is sent now, and
         // the range sent then is received now. numItems starts at the chunk
         // size of the last reduce-scatter step and doubles every iteration.
0319     numItems = chunkSize_ << (steps_ - stepsWithinBlock_);
0320     for (int i = stepsWithinBlock_ - 1; i >= 0; i--) {
0321       // verify that destination rank has received and processed this rank's
0322       // message during the reduce-scatter phase
0323       recvNotificationBufs_[i]->waitRecv();
0324       if (recvOffsets_[i] < count_) {
0325         sendDataBufs_[i]->send(
0326             recvOffsets_[i] * sizeof(T), recvCounts_[i] * sizeof(T));
0327       }
0328       bufferOffset -= numItems;
0329       if (sendOffsets_[i] < count_) {
0330         recvDataBufs_[i]->waitRecv();
0331         memcpy(
0332             &ptrs_[0][sendOffsets_[i]],
0333             &recvBuf_[bufferOffset],
0334             sendCounts_[i] * sizeof(T));
0335       }
0336       numItems <<= 1;
0337 
0338       // Send notification to the pair we just received from that
0339       // we're done dealing with the receive buffer.
0340       sendNotificationBufs_[i]->send();
0341     }
0342 
0343     // Broadcast ptrs_[0]
0344     for (int i = 1; i < ptrs_.size(); i++) {
0345       memcpy(ptrs_[i], ptrs_[0], bytes_);
0346     }
0347 
0348     // Wait for notifications from our peers within the block to make
0349     // sure we can send data immediately without risking overwriting
0350     // data in its receive buffer before it consumed that data.
0351     for (int i = stepsWithinBlock_ - 1; i >= 0; i--) {
0352       recvNotificationBufs_[i]->waitRecv();
0353     }
0354 
0355     // We have to be sure the send to the smaller block (if any) has
0356     // completed before returning. If we don't, the buffer contents may
0357     // be modified by our caller.
0358     if (sentToSmallerBlock) {
0359       smallerBlockSendDataBuf_->waitSend();
0360     }
0361   }
0362 
0363  protected:
       // local buffers (copy of the constructor argument); ptrs_[0] is the
       // working buffer for all communication
0364   std::vector<T*> ptrs_;
       // number of elements per buffer
0365   const int count_;
       // count_ * sizeof(T)
       // NOTE(review): the product is computed in size_t and narrowed to int
0366   const int bytes_;
       // log2 of the context size, truncated
0367   const size_t steps_;
       // 1 << steps_
0368   const size_t chunks_;
       // count_ / chunks_, rounded up
0369   const size_t chunkSize_;
       // chunkSize_ * sizeof(T)
0370   const size_t chunkBytes_;
       // reduction function, called as fn_->call(dst, src, n)
0371   const ReductionFunction<T>* fn_;
0372 
0373   // buffer where data is received prior to being reduced
0374   std::vector<T> recvBuf_;
0375 
0376   // offsets into the data buffer from which to send during the reduce-scatter
0377   // these become the offsets at which the process receives during the allgather
0378   // indexed by step
0379   std::vector<size_t> sendOffsets_;
0380 
0381   // offsets at which data is reduced during the reduce-scatter and sent from in
0382   // the allgather
0383   std::vector<size_t> recvOffsets_;
0384 
       // data buffers for the intra-block steps, indexed by step
0385   std::vector<std::unique_ptr<transport::Buffer>> sendDataBufs_;
0386   std::vector<std::unique_ptr<transport::Buffer>> recvDataBufs_;
0387 
       // data buffers for the peer in the next smaller block; the receive
       // buffer stays null when there is nothing to receive
0388   std::unique_ptr<transport::Buffer> smallerBlockSendDataBuf_;
0389   std::unique_ptr<transport::Buffer> smallerBlockRecvDataBuf_;
0390 
       // data buffers for the peers in the next larger block, in send order
0391   std::vector<std::unique_ptr<transport::Buffer>> largerBlockSendDataBufs_;
0392   std::vector<std::unique_ptr<transport::Buffer>> largerBlockRecvDataBufs_;
0393 
       // number of elements sent/received in each intra-block step (0 when
       // the step's offset lies at or past count_)
0394   std::vector<size_t> sendCounts_;
0395   std::vector<size_t> recvCounts_;
       // number of elements sent to each peer in the next larger block (the
       // last peer may get fewer)
0396   size_t sendCountToLargerBlock_;
0397 
       // payload location shared by all notification buffers; its value is
       // never read or written by this class
0398   int dummy_;
0399   std::vector<std::unique_ptr<transport::Buffer>> sendNotificationBufs_;
0400   std::vector<std::unique_ptr<transport::Buffer>> recvNotificationBufs_;
0401 
0402   // for non-power-of-two number of processes, partition the processes into
0403   // binary blocks and keep track of which block each process is in, as well as
0404   // the adjoining larger and smaller blocks (with which communication will be
0405   // required)
0406   uint32_t offsetToMyBinaryBlock_;
0407   uint32_t myBinaryBlockSize_;
0408   uint32_t stepsWithinBlock_;
0409   uint32_t rankInBinaryBlock_;
0410   uint32_t nextSmallerBlockSize_;
0411   uint32_t nextLargerBlockSize_;
0412 
       // first context slot reserved for this instance; only set when the
       // constructor did not return early
0413   int slotOffset_;
0414 };
0415 
0416 } // namespace gloo