File indexing completed on 2025-01-18 10:00:11
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <stddef.h>
0012 #include <string.h>
0013
0014 #include "gloo/algorithm.h"
0015 #include "gloo/context.h"
0016
0017 namespace gloo {
0018
0019 template <typename T>
0020 class AllreduceRingChunked : public Algorithm {
0021 public:
0022 AllreduceRingChunked(
0023 const std::shared_ptr<Context>& context,
0024 const std::vector<T*>& ptrs,
0025 const int count,
0026 const ReductionFunction<T>* fn = ReductionFunction<T>::sum)
0027 : Algorithm(context),
0028 ptrs_(ptrs),
0029 count_(count),
0030 bytes_(count_ * sizeof(T)),
0031 fn_(fn) {
0032
0033 constexpr unsigned long minSize = 256;
0034 chunks_ = this->contextSize_ * 2;
0035 #ifdef _WIN32
0036 chunkSize_ = std::max((size_t)minSize, (size_t)((count_ + chunks_ - 1) / chunks_));
0037 #else
0038 chunkSize_ = std::max(minSize, (count_ + chunks_ - 1) / chunks_);
0039 #endif
0040 chunkBytes_ = chunkSize_ * sizeof(T);
0041
0042
0043 for (int i = 0; i < 2; i++) {
0044 inbox_[i] = static_cast<T*>(malloc(bytes_));
0045 }
0046
0047 if (count_ == 0 || this->contextSize_ == 1) {
0048 return;
0049 }
0050
0051 auto& leftPair = this->getLeftPair();
0052 auto& rightPair = this->getRightPair();
0053 for (int i = 0; i < 2; i++) {
0054 auto slot = this->context_->nextSlot();
0055
0056
0057 sendDataBuf_[i] =
0058 rightPair->createSendBuffer(slot, ptrs_[0], bytes_);
0059
0060 recvDataBuf_[i] =
0061 leftPair->createRecvBuffer(slot, inbox_[i], chunkBytes_);
0062 }
0063
0064
0065
0066
0067
0068 auto notificationSlot = this->context_->nextSlot();
0069 sendNotificationBuf_ =
0070 leftPair->createSendBuffer(notificationSlot, &dummy_, sizeof(dummy_));
0071 recvNotificationBuf_ =
0072 rightPair->createRecvBuffer(notificationSlot, &dummy_, sizeof(dummy_));
0073 }
0074
0075 virtual ~AllreduceRingChunked() {
0076 for (int i = 0; i < 2; i++) {
0077 if (inbox_[i] != nullptr) {
0078 free(inbox_[i]);
0079 }
0080 }
0081 }
0082
0083 void run() {
0084 if (count_ == 0) {
0085 return;
0086 }
0087
0088
0089 for (int i = 1; i < ptrs_.size(); i++) {
0090 fn_->call(ptrs_[0], ptrs_[i], count_);
0091 }
0092
0093 if (this->contextSize_ == 1) {
0094
0095 for (int i = 1; i < ptrs_.size(); i++) {
0096 memcpy(ptrs_[i], ptrs_[0], bytes_);
0097 }
0098 return;
0099 }
0100
0101
0102 copyChunkAtOffset(2 * this->contextRank_);
0103 copyChunkAtOffset(2 * this->contextRank_ + 1);
0104
0105
0106 for (int round = 2; round < chunks_; round++) {
0107
0108
0109
0110
0111
0112
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122
0123
0124
0125 auto chunkOffset = ((2 * this->contextRank_) - (round & ~0x1) +
0126 (round & 0x1) + chunks_) %
0127 chunks_;
0128 auto offset = chunkOffset * chunkSize_;
0129 auto length = chunkSize_;
0130 if (offset + length <= count_) {
0131
0132 } else if (offset < count_) {
0133
0134 length = count_ - offset;
0135 } else {
0136
0137 length = 0;
0138 }
0139
0140
0141 recvDataBuf_[chunkOffset & 1]->waitRecv();
0142
0143
0144 if (length > 0) {
0145 fn_->call(&ptrs_[0][offset], inbox_[chunkOffset & 1], length);
0146 }
0147
0148
0149
0150 sendNotificationBuf_->send();
0151
0152
0153
0154 recvNotificationBuf_->waitRecv();
0155
0156
0157 copyChunkAtOffset(chunkOffset);
0158 }
0159
0160
0161
0162
0163 for (int round = 0; round < (chunks_ - 2); round++) {
0164 auto chunkOffset = ((2 * this->contextRank_) - (round & ~0x1) +
0165 (round & 0x1) + chunks_) %
0166 chunks_;
0167 auto offset = chunkOffset * chunkSize_;
0168 auto length = chunkSize_;
0169 if (offset + length <= count_) {
0170
0171 } else if (offset < count_) {
0172
0173 length = count_ - offset;
0174 } else {
0175
0176 length = 0;
0177 }
0178
0179
0180 recvDataBuf_[chunkOffset & 1]->waitRecv();
0181
0182
0183 if (length > 0) {
0184 memcpy(&ptrs_[0][offset], inbox_[chunkOffset & 1], length * sizeof(T));
0185 }
0186
0187
0188 if (round < (chunks_ - 4)) {
0189
0190
0191 sendNotificationBuf_->send();
0192
0193
0194
0195 recvNotificationBuf_->waitRecv();
0196
0197
0198 copyChunkAtOffset(chunkOffset);
0199 }
0200 }
0201
0202
0203
0204
0205 sendNotificationBuf_->send();
0206 recvNotificationBuf_->waitRecv();
0207
0208
0209 for (int i = 1; i < ptrs_.size(); i++) {
0210 memcpy(ptrs_[i], ptrs_[0], bytes_);
0211 }
0212 }
0213
0214 protected:
0215 void copyChunkAtOffset(int chunkOffset) {
0216
0217 auto offset = (chunkOffset % chunks_) * chunkSize_;
0218 auto length = chunkSize_;
0219 if (offset + length <= count_) {
0220
0221 } else if (offset < count_) {
0222
0223 length = count_ - offset;
0224 } else {
0225
0226
0227
0228
0229 offset = 0;
0230 length = 1;
0231 }
0232
0233
0234 sendDataBuf_[chunkOffset & 0x1]->send(
0235 offset * sizeof(T), length * sizeof(T));
0236 }
0237
0238 std::vector<T*> ptrs_;
0239 const int count_;
0240 const int bytes_;
0241 const ReductionFunction<T>* fn_;
0242
0243 size_t chunks_;
0244 size_t chunkSize_;
0245 size_t chunkBytes_;
0246
0247 T* inbox_[2];
0248 std::unique_ptr<transport::Buffer> sendDataBuf_[2];
0249 std::unique_ptr<transport::Buffer> recvDataBuf_[2];
0250
0251 int dummy_;
0252 std::unique_ptr<transport::Buffer> sendNotificationBuf_;
0253 std::unique_ptr<transport::Buffer> recvNotificationBuf_;
0254 };
0255
0256 }