Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:12

0001 /**
0002  * Copyright (c) 2018-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
0011 #include <math.h>
0012 #include <stddef.h>
0013 #include <string.h>
0014 
0015 #include "gloo/algorithm.h"
0016 #include "gloo/common/error.h"
0017 #include "gloo/context.h"
0018 
0019 namespace gloo {
0020 
0021 template <typename T>
0022 class ReduceScatterHalvingDoubling : public Algorithm {
0023   void initBinaryBlocks() {
0024     uint32_t offset = this->contextSize_;
0025     uint32_t blockSize = 1;
0026     uint32_t currentBlockSize = 0;
0027     uint32_t prevBlockSize = 0;
0028     do {
0029       if (this->contextSize_ & blockSize) {
0030         prevBlockSize = currentBlockSize;
0031         currentBlockSize = blockSize;
0032         offset -= blockSize;
0033         if (myBinaryBlockSize_ != 0) {
0034           nextLargerBlockSize_ = currentBlockSize;
0035           break;
0036         }
0037         if (offset <= this->context_->rank) {
0038           offsetToMyBinaryBlock_ = offset;
0039           myBinaryBlockSize_ = currentBlockSize;
0040           nextSmallerBlockSize_ = prevBlockSize;
0041         }
0042       }
0043       blockSize <<= 1;
0044     } while (offset != 0);
0045 
0046     stepsWithinBlock_ = log2(myBinaryBlockSize_);
0047     rankInBinaryBlock_ = this->context_->rank % myBinaryBlockSize_;
0048   }
0049 
0050   // returns the last n bits of ctr reversed
0051   uint32_t reverseLastNBits(uint32_t ctr, uint32_t n) {
0052     uint32_t bitMask = 1;
0053     uint32_t reversed = 0;
0054     while (bitMask < (static_cast<uint32_t>(1) << n)) {
0055       reversed <<= 1;
0056       if (ctr & bitMask) {
0057         reversed |= 1;
0058       }
0059       bitMask <<= 1;
0060     }
0061     return reversed;
0062   }
0063 
0064   struct DistributionMap {
0065     int rank;
0066     size_t offset;
0067     size_t itemCount;
0068     DistributionMap(int dRank, size_t dOffset, size_t dItemCount)
0069         : rank(dRank), offset(dOffset), itemCount(dItemCount) {}
0070 
0071   };
0072 
0073   void getDistributionMap(
0074       size_t srcOffset, size_t srcCount, const std::vector<int>& recvCounts,
0075       bool reorder, std::vector<DistributionMap>& distributionMap) {
0076     if (srcCount == 0) {
0077       return;
0078     }
0079 
0080     size_t destOffset = 0;
0081     auto size =
0082         reorder ? 1 << (int)log2(this->contextSize_) : this->contextSize_;
0083     int start = 0;
0084     for (; start < size; ++start) {
0085       if (destOffset + recvCounts[start] > srcOffset) break;
0086       destOffset += recvCounts[start];
0087     }
0088     destOffset = srcOffset - destOffset;
0089 
0090     auto totalCount = srcCount;
0091     for (int i = start; i < size; ++i) {
0092       auto recvCount = recvCounts[i];
0093       if (destOffset != 0) {
0094         recvCount -= destOffset;
0095         destOffset = 0;
0096       }
0097       auto rank =
0098           reorder ? reverseLastNBits(i, log2(this->contextSize_)) : i;
0099       recvCount = recvCount < totalCount ? recvCount : totalCount;
0100       distributionMap.emplace_back(rank, srcOffset, recvCount);
0101       srcOffset += recvCount;
0102       totalCount -= recvCount;
0103       if (totalCount <= 0) {
0104         break;
0105       }
0106     }
0107   }
0108 
0109  public:
0110   ReduceScatterHalvingDoubling(
0111       const std::shared_ptr<Context>& context,
0112       const std::vector<T*> ptrs,
0113       const int count,
0114       const std::vector<int> recvElems,
0115       const ReductionFunction<T>* fn = ReductionFunction<T>::sum)
0116       : Algorithm(context),
0117         ptrs_(ptrs),
0118         count_(count),
0119         recvElems_(recvElems),
0120         bytes_(count_ * sizeof(T)),
0121         steps_(log2(this->contextSize_)),
0122         chunks_(1 << steps_),
0123         chunkSize_((count_ + chunks_ - 1) / chunks_),
0124         chunkBytes_(chunkSize_ * sizeof(T)),
0125         fn_(fn),
0126         recvBuf_(chunkSize_ << steps_),
0127         recvBufDist_(count_),
0128         sendOffsets_(steps_),
0129         recvOffsets_(steps_),
0130         sendCounts_(steps_, 0),
0131         recvCounts_(steps_, 0),
0132         sendCountToLargerBlock_(0),
0133         offsetToMyBinaryBlock_(0),
0134         myBinaryBlockSize_(0),
0135         stepsWithinBlock_(0),
0136         rankInBinaryBlock_(0),
0137         nextSmallerBlockSize_(0),
0138         nextLargerBlockSize_(0) {
0139     if (this->contextSize_ == 1) {
0140         return;
0141     }
0142 
0143     initBinaryBlocks();
0144     sendDataBufs_.reserve(stepsWithinBlock_);
0145     recvDataBufs_.reserve(stepsWithinBlock_);
0146     // Reserve max needed number of context slots. Up to 4 slots per process
0147     // pair are needed (two for regular sends and two for notifications). For
0148     // simplicity, the same mapping is used on all processes so that the slots
0149     // trivially match across processes
0150     slotOffset_ = this->context_->nextSlot(
0151         4 * this->contextSize_ * (this->contextSize_ - 1));
0152 
0153     size_t bitmask = 1;
0154     size_t stepChunkSize = chunkSize_ << (steps_ - 1);
0155     size_t stepChunkBytes = stepChunkSize * sizeof(T);
0156     size_t sendOffset = 0;
0157     size_t recvOffset = 0;
0158     size_t bufferOffset = 0; // offset into recvBuf_
0159     for (int i = 0; i < stepsWithinBlock_; i++) {
0160       const int destRank = (this->context_->rank) ^ bitmask;
0161       auto& pair = this->context_->getPair(destRank);
0162       sendOffsets_[i] = sendOffset + ((destRank & bitmask) ? stepChunkSize : 0);
0163       recvOffsets_[i] =
0164           recvOffset + ((this->context_->rank & bitmask) ? stepChunkSize : 0);
0165       if (sendOffsets_[i] < count_) {
0166         // specifies number of elements to send in each step
0167         if (sendOffsets_[i] + stepChunkSize > count_) {
0168           sendCounts_[i] = count_ - sendOffsets_[i];
0169         } else {
0170           sendCounts_[i] = stepChunkSize;
0171         }
0172       }
0173       int myRank = this->context_->rank;
0174       auto slot = slotOffset_ +
0175           2 * (std::min(myRank, destRank) * this->contextSize_ +
0176                std::max(myRank, destRank));
0177       sendDataBufs_.push_back(pair->createSendBuffer(slot, ptrs_[0], bytes_));
0178       if (recvOffsets_[i] < count_) {
0179         // specifies number of elements received in each step
0180         if (recvOffsets_[i] + stepChunkSize > count_) {
0181           recvCounts_[i] = count_ - recvOffsets_[i];
0182         } else {
0183           recvCounts_[i] = stepChunkSize;
0184         }
0185       }
0186       recvDataBufs_.push_back(
0187           pair->createRecvBuffer(
0188               slot, &recvBuf_[bufferOffset], stepChunkBytes));
0189       bufferOffset += stepChunkSize;
0190       if (this->context_->rank & bitmask) {
0191         sendOffset += stepChunkSize;
0192         recvOffset += stepChunkSize;
0193       }
0194       bitmask <<= 1;
0195       stepChunkSize >>= 1;
0196       stepChunkBytes >>= 1;
0197 
0198       ++slot;
0199       sendNotificationBufs_.push_back(
0200           pair->createSendBuffer(slot, &dummy_, sizeof(dummy_)));
0201       recvNotificationBufs_.push_back(
0202           pair->createRecvBuffer(slot, &dummy_, sizeof(dummy_)));
0203     }
0204 
0205     const auto myRank = this->context_->rank;
0206     if (nextSmallerBlockSize_ != 0) {
0207       const auto offsetToSmallerBlock =
0208           offsetToMyBinaryBlock_ + myBinaryBlockSize_;
0209       const int destRank =
0210           offsetToSmallerBlock + rankInBinaryBlock_ % nextSmallerBlockSize_;
0211       auto& destPair = this->context_->getPair(destRank);
0212       auto slot = slotOffset_ +
0213           2 * (std::min(myRank, destRank) * this->contextSize_ +
0214                std::max(myRank, destRank));
0215       const auto itemCount = recvCounts_[stepsWithinBlock_ - 1];
0216       if (itemCount > 0) {
0217         smallerBlockRecvDataBuf_ = destPair->createRecvBuffer(
0218             slot, &recvBuf_[bufferOffset], itemCount * sizeof(T));
0219       }
0220     }
0221     if (nextLargerBlockSize_ != 0) {
0222       // Due to the design decision of sending large messages to nearby ranks,
0223       // after the reduce-scatter the reduced chunks end up in an order
0224       // according to the reversed bit pattern of each proc's rank within the
0225       // block. So, instead of ranks 0, 1, 2, ... 7 having blocks A, B, C, D, E,
0226       // F, G, H etc. what you get is A, E, C, G, B, F, D, H. Taking this
0227       // example further, if there is also a smaller binary block of size 2
0228       // (with the reduced blocks A - D, E - H), rank 0 within the smaller block
0229       // will need to send chunks of its buffer to ranks 0, 4, 2, 6 within the
0230       // larger block (in that order) and rank 1 will send to 1, 5, 3, 7. Within
0231       // the reversed bit patterns, this communication is actually 0 to [0, 1,
0232       // 2, 3] and 1 to [4, 5, 6, 7].
0233       const auto offsetToLargerBlock =
0234           offsetToMyBinaryBlock_ - nextLargerBlockSize_;
0235       const auto numSendsAndReceivesToLargerBlock =
0236           nextLargerBlockSize_ / myBinaryBlockSize_;
0237       sendCountToLargerBlock_ = stepChunkSize >>
0238           (static_cast<size_t>(log2(numSendsAndReceivesToLargerBlock)) - 1);
0239       auto srcOrdinal =
0240           reverseLastNBits(rankInBinaryBlock_, log2(myBinaryBlockSize_));
0241       auto destOrdinal = srcOrdinal * numSendsAndReceivesToLargerBlock;
0242       for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
0243         const int destRank = offsetToLargerBlock +
0244             reverseLastNBits(destOrdinal, log2(nextLargerBlockSize_));
0245         auto& destPair = this->context_->getPair(destRank);
0246         auto slot = slotOffset_ +
0247             2 * (std::min(myRank, destRank) * this->contextSize_ +
0248                  std::max(myRank, destRank));
0249         largerBlockSendDataBufs_.push_back(
0250             destPair->createSendBuffer(slot, ptrs[0], bytes_));
0251         destOrdinal++;
0252       }
0253     }
0254 
0255     // Distribution phase: Scatter/distribute based on user-specified
0256     // distribution. Note that, due to nature of recursive halving algorithm
0257     // in the largest binary block, the blocks are not ordered in correct order.
0258     // Enforce correct order by exchanging data between processes p and p',
0259     // where p' is the bit-reverse of p.
0260 
0261     // Sends: The largest binary block ends up having the scattered data.
0262     // Therefore, only those ranks participate in sending messages.
0263     if (nextLargerBlockSize_ == 0 && stepsWithinBlock_ > 0) {
0264       getDistributionMap(
0265           recvOffsets_[stepsWithinBlock_ - 1],
0266           recvCounts_[stepsWithinBlock_ - 1],
0267           recvElems_, false, distMapForSend_);
0268       for (const auto& distMap : distMapForSend_) {
0269         const int destRank = distMap.rank;
0270         if (myRank != destRank) {
0271           auto& destPair = this->context_->getPair(destRank);
0272           auto slot = slotOffset_ + 2 +
0273               2 * (std::min(myRank, destRank) * this->contextSize_ +
0274                    std::max(myRank, destRank));
0275           distSendDataBufs_.push_back(
0276               destPair->createSendBuffer(slot, ptrs_[0], bytes_));
0277           ++slot;
0278           recvNotificationBufs_.push_back(
0279               destPair->createRecvBuffer(slot, &dummy_, sizeof(dummy_)));
0280         }
0281       }
0282     }
0283 
0284     // Recvs: Recv the data from the largest binary block. Based on the
0285     // user-specified distribution, the receivers identify which ranks in the
0286     // binary block they should receive from. Since the data in
0287     // largest binary block is reordered after recursive-halving, the receivers
0288     // reorder the sender info here.
0289     if (recvElems_[myRank] > 0) {
0290       std::vector<int> srcCounts;
0291       size_t rem = count_;
0292       for (int i = 0; i < this->contextSize_; ++i) {
0293         srcCounts.push_back(std::min(chunkSize_, rem));
0294         rem = rem > chunkSize_ ? rem - chunkSize_ : 0;
0295       }
0296       size_t offset = 0;
0297       for (int i = 0; i < myRank; ++i) {
0298         offset += recvElems_[i];
0299       }
0300       getDistributionMap(
0301         offset, recvElems_[myRank], srcCounts, true, distMapForRecv_);
0302       for (const auto& distMap : distMapForRecv_) {
0303         const int srcRank = distMap.rank;
0304         if (myRank != srcRank) {
0305           auto& destPair = this->context_->getPair(srcRank);
0306           auto slot = slotOffset_ + 2 +
0307               2 * (std::min(myRank, srcRank) * this->contextSize_ +
0308                    std::max(myRank, srcRank));
0309           distRecvDataBufs_.push_back(
0310               destPair->createRecvBuffer(
0311                   slot, &recvBufDist_[distMap.offset],
0312                   distMap.itemCount * sizeof(T)));
0313           ++slot;
0314           sendNotificationBufs_.push_back(
0315               destPair->createSendBuffer(slot, &dummy_, sizeof(dummy_)));
0316         }
0317       }
0318     }
0319 
0320   }
0321 
0322   void run() {
0323     size_t bufferOffset = 0;
0324     size_t numItems =
0325         stepsWithinBlock_ > 0 ? chunkSize_ << (steps_ - 1) : count_;
0326 
0327     for (int i = 1; i < ptrs_.size(); i++) {
0328       fn_->call(ptrs_[0], ptrs_[i], count_);
0329     }
0330     if (this->contextSize_ == 1) {
0331       // Broadcast ptrs_[0]
0332       for (int i = 1; i < ptrs_.size(); i++) {
0333         memcpy(ptrs_[i], ptrs_[0], bytes_);
0334       }
0335       return;
0336     }
0337 
0338     // Reduce-scatter (within binary block).
0339     for (int i = 0; i < stepsWithinBlock_; i++) {
0340       if (sendOffsets_[i] < count_) {
0341         sendDataBufs_[i]->send(
0342             sendOffsets_[i] * sizeof(T), sendCounts_[i] * sizeof(T));
0343       }
0344       if (recvOffsets_[i] < count_) {
0345         recvDataBufs_[i]->waitRecv();
0346         fn_->call(
0347             &ptrs_[0][recvOffsets_[i]],
0348             &recvBuf_[bufferOffset],
0349             recvCounts_[i]);
0350       }
0351       bufferOffset += numItems;
0352       sendNotificationBufs_[i]->send();
0353       numItems >>= 1;
0354     }
0355 
0356     // Communication across binary blocks for non-power-of-two number of
0357     // processes
0358 
0359     // receive from smaller block
0360     // data sizes same as in the last step of intrablock reduce-scatter above
0361     int sendNotifyOffset = stepsWithinBlock_;
0362     if (nextSmallerBlockSize_ != 0 && smallerBlockRecvDataBuf_ != nullptr) {
0363       smallerBlockRecvDataBuf_->waitRecv();
0364       fn_->call(
0365           &ptrs_[0][recvOffsets_[stepsWithinBlock_ - 1]],
0366           &recvBuf_[bufferOffset],
0367           recvCounts_[stepsWithinBlock_ - 1]);
0368     }
0369 
0370     const auto totalItemsToSend =
0371         stepsWithinBlock_ > 0 ? recvCounts_[stepsWithinBlock_ - 1] : count_;
0372     if (nextLargerBlockSize_ != 0 && totalItemsToSend != 0) {
0373       // scatter to larger block
0374       const auto offset =
0375           stepsWithinBlock_ > 0 ? recvOffsets_[stepsWithinBlock_ - 1] : 0;
0376       const auto numSendsAndReceivesToLargerBlock =
0377           nextLargerBlockSize_ / myBinaryBlockSize_;
0378       for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
0379         if (sendCountToLargerBlock_ * i < totalItemsToSend) {
0380           largerBlockSendDataBufs_[i]->send(
0381               (offset + i * sendCountToLargerBlock_) * sizeof(T),
0382               std::min(
0383                   sendCountToLargerBlock_,
0384                   totalItemsToSend - sendCountToLargerBlock_ * i) *
0385                   sizeof(T));
0386         }
0387       }
0388     }
0389 
0390     // Distribution phase: Scatter/distribute based on user specified
0391     // distribution.
0392     int index = 0;
0393     for (const auto& distMap : distMapForSend_) {
0394       const auto myRank = this->context_->rank;
0395       const int destRank = distMap.rank;
0396       if (myRank != destRank) {
0397         distSendDataBufs_[index++]->send(
0398           distMap.offset * sizeof(T), distMap.itemCount * sizeof(T));
0399       }
0400     }
0401     index = 0;
0402     bufferOffset = 0;
0403     for (const auto& distMap : distMapForRecv_) {
0404       const auto myRank = this->context_->rank;
0405       const int srcRank = distMap.rank;
0406       if (myRank != srcRank) {
0407         distRecvDataBufs_[index++]->waitRecv();
0408         memcpy(
0409             &ptrs_[0][bufferOffset],
0410             &recvBufDist_[distMap.offset],
0411             distMap.itemCount * sizeof(T));
0412         sendNotificationBufs_[sendNotifyOffset++]->send();
0413       } else {
0414         if (myRank != 0) { // Data already in-place for rank 0.
0415           memcpy(
0416               &ptrs_[0][bufferOffset],
0417               &ptrs_[0][distMap.offset],
0418               distMap.itemCount * sizeof(T));
0419         }
0420       }
0421       bufferOffset += distMap.itemCount;
0422     }
0423 
0424     // Broadcast ptrs_[0]
0425     for (int i = 1; i < ptrs_.size(); i++) {
0426       memcpy(ptrs_[i], ptrs_[0], bytes_);
0427     }
0428 
0429     // Wait for all notifications to make sure we can send data immediately
0430     // without risking overwriting data in its receive buffer before it
0431     // consumed that data.
0432     for (auto& recvNotificationBuf : recvNotificationBufs_) {
0433       recvNotificationBuf->waitRecv();
0434     }
0435   }
0436 
0437  protected:
0438   std::vector<T*> ptrs_;
0439   const int count_;
0440   const std::vector<int> recvElems_;
0441   const int bytes_;
0442   const size_t steps_;
0443   const size_t chunks_;
0444   const size_t chunkSize_;
0445   const size_t chunkBytes_;
0446   const ReductionFunction<T>* fn_;
0447 
0448   // buffer where data is received prior to being reduced
0449   std::vector<T> recvBuf_;
0450 
0451   // buffer where data is received during distribution phase
0452   std::vector<T> recvBufDist_;
0453 
0454   // offsets into the data buffer from which to send during the reduce-scatter
0455   // these become the offsets at which the process receives during the allgather
0456   // indexed by step
0457   std::vector<size_t> sendOffsets_;
0458 
0459   // offsets at which data is reduced during the reduce-scatter and sent from in
0460   // the allgather
0461   std::vector<size_t> recvOffsets_;
0462 
0463   std::vector<std::unique_ptr<transport::Buffer>> sendDataBufs_;
0464   std::vector<std::unique_ptr<transport::Buffer>> recvDataBufs_;
0465 
0466   std::unique_ptr<transport::Buffer> smallerBlockRecvDataBuf_;
0467   std::vector<std::unique_ptr<transport::Buffer>> largerBlockSendDataBufs_;
0468 
0469   std::unique_ptr<transport::Buffer> xchgBlockSendDataBuf_;
0470   std::unique_ptr<transport::Buffer> xchgBlockRecvDataBuf_;
0471 
0472   std::vector<std::unique_ptr<transport::Buffer>> distSendDataBufs_;
0473   std::vector<std::unique_ptr<transport::Buffer>> distRecvDataBufs_;
0474 
0475   std::vector<DistributionMap> distMapForSend_;
0476   std::vector<DistributionMap> distMapForRecv_;
0477 
0478   std::vector<size_t> sendCounts_;
0479   std::vector<size_t> recvCounts_;
0480   size_t sendCountToLargerBlock_;
0481 
0482   int dummy_;
0483   std::vector<std::unique_ptr<transport::Buffer>> sendNotificationBufs_;
0484   std::vector<std::unique_ptr<transport::Buffer>> recvNotificationBufs_;
0485 
0486   // for non-power-of-two number of processes, partition the processes into
0487   // binary blocks and keep track of which block each process is in, as well as
0488   // the adjoining larger and smaller blocks (with which communication will be
0489   // required)
0490   uint32_t offsetToMyBinaryBlock_;
0491   uint32_t myBinaryBlockSize_;
0492   uint32_t stepsWithinBlock_;
0493   uint32_t rankInBinaryBlock_;
0494   uint32_t nextSmallerBlockSize_;
0495   uint32_t nextLargerBlockSize_;
0496 
0497   int slotOffset_;
0498 };
0499 
0500 } // namespace gloo