File indexing completed on 2025-01-18 10:00:10
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <stddef.h>
0012 #include <string.h>
0013
0014 #include "gloo/algorithm.h"
0015 #include "gloo/context.h"
0016
0017 namespace gloo {
0018
0019
0020
0021
0022
0023
0024
0025 template <typename T>
0026 class AllgatherRing : public Algorithm {
0027 public:
0028 AllgatherRing(
0029 const std::shared_ptr<Context>& context,
0030 const std::vector<const T*>& inPtrs,
0031 T* outPtr,
0032 int count)
0033 : Algorithm(context),
0034 inPtrs_(inPtrs),
0035 outPtr_(outPtr),
0036 count_(count),
0037 bytes_(count * sizeof(T)),
0038 inputStride_(count_ * inPtrs_.size()),
0039 leftPair_(this->getLeftPair()),
0040 rightPair_(this->getRightPair()) {
0041 auto slot = this->context_->nextSlot();
0042
0043 sendDataBuf_ = rightPair_->createSendBuffer(
0044 slot, outPtr_, inPtrs_.size() * context_->size * bytes_);
0045 recvDataBuf_ = leftPair_->createRecvBuffer(
0046 slot, outPtr_, inPtrs_.size() * context_->size * bytes_);
0047
0048 auto notificationSlot = this->context_->nextSlot();
0049 sendNotificationBuf_ =
0050 leftPair_->createSendBuffer(notificationSlot, &dummy_, sizeof(dummy_));
0051 recvNotificationBuf_ =
0052 rightPair_->createRecvBuffer(notificationSlot, &dummy_, sizeof(dummy_));
0053 }
0054
0055 virtual ~AllgatherRing() {}
0056
0057 void run() {
0058
0059 if (this->contextSize_ == 1 || count_ == 0) {
0060 return;
0061 }
0062 const int rank = this->contextRank_;
0063 const int numRounds = this->contextSize_ - 1;
0064
0065
0066 for (int i = 0; i < inPtrs_.size(); i++) {
0067 memcpy(outPtr_ + rank * inputStride_ + i * count_, inPtrs_[i], bytes_);
0068 }
0069
0070
0071 for (int i = 0; i < inPtrs_.size(); i++) {
0072
0073 int inRank = rank;
0074 for (int round = 0; round < numRounds; round++) {
0075 const int sendOffset = inRank * inputStride_ + i * count_;
0076 sendDataBuf_->send(
0077 sendOffset * sizeof(T), bytes_, sendOffset * sizeof(T));
0078 recvDataBuf_->waitRecv();
0079
0080
0081
0082 inRank = (numRounds - round + rank) % this->contextSize_;
0083
0084
0085
0086 sendNotificationBuf_->send();
0087
0088
0089 recvNotificationBuf_->waitRecv();
0090 }
0091 }
0092 }
0093
0094 private:
0095 const std::vector<const T*> inPtrs_;
0096 T* outPtr_;
0097 const int count_;
0098 const int bytes_;
0099 const int inputStride_;
0100
0101 std::unique_ptr<transport::Pair>& leftPair_;
0102 std::unique_ptr<transport::Pair>& rightPair_;
0103
0104 std::unique_ptr<transport::Buffer> sendDataBuf_;
0105 std::unique_ptr<transport::Buffer> recvDataBuf_;
0106
0107 int dummy_;
0108
0109 std::unique_ptr<transport::Buffer> sendNotificationBuf_;
0110 std::unique_ptr<transport::Buffer> recvNotificationBuf_;
0111 };
0112
0113 }