File indexing completed on 2025-01-18 10:00:11
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

#include <memory>
#include <new>
#include <vector>

#include "gloo/algorithm.h"
#include "gloo/context.h"
0016
0017 namespace gloo {
0018
0019 template <typename T>
0020 class AllreduceRing : public Algorithm {
0021 public:
0022 AllreduceRing(
0023 const std::shared_ptr<Context>& context,
0024 const std::vector<T*>& ptrs,
0025 const int count,
0026 const ReductionFunction<T>* fn = ReductionFunction<T>::sum)
0027 : Algorithm(context),
0028 ptrs_(ptrs),
0029 count_(count),
0030 bytes_(count_ * sizeof(T)),
0031 fn_(fn) {
0032 inbox_ = static_cast<T*>(malloc(bytes_));
0033 outbox_ = static_cast<T*>(malloc(bytes_));
0034
0035 if (this->contextSize_ == 1) {
0036 return;
0037 }
0038
0039 auto& leftPair = this->getLeftPair();
0040 auto& rightPair = this->getRightPair();
0041 auto slot = this->context_->nextSlot();
0042
0043
0044 sendDataBuf_ = rightPair->createSendBuffer(slot, outbox_, bytes_);
0045
0046
0047 recvDataBuf_ = leftPair->createRecvBuffer(slot, inbox_, bytes_);
0048
0049
0050
0051
0052
0053 auto notificationSlot = this->context_->nextSlot();
0054 sendNotificationBuf_ =
0055 leftPair->createSendBuffer(notificationSlot, &dummy_, sizeof(dummy_));
0056 recvNotificationBuf_ =
0057 rightPair->createRecvBuffer(notificationSlot, &dummy_, sizeof(dummy_));
0058 }
0059
0060 virtual ~AllreduceRing() {
0061 if (inbox_ != nullptr) {
0062 free(inbox_);
0063 }
0064 if (outbox_ != nullptr) {
0065 free(outbox_);
0066 }
0067 }
0068
0069 void run() {
0070 if (count_ == 0) {
0071 return;
0072 }
0073
0074
0075 for (int i = 1; i < ptrs_.size(); i++) {
0076 fn_->call(ptrs_[0], ptrs_[i], count_);
0077 }
0078
0079
0080 memcpy(outbox_, ptrs_[0], bytes_);
0081
0082 int numRounds = this->contextSize_ - 1;
0083 for (int round = 0; round < numRounds; round++) {
0084
0085 sendDataBuf_->send();
0086
0087
0088 recvDataBuf_->waitRecv();
0089
0090
0091 fn_->call(ptrs_[0], inbox_, count_);
0092
0093
0094 sendDataBuf_->waitSend();
0095
0096
0097 if (round < (numRounds - 1)) {
0098 memcpy(outbox_, inbox_, bytes_);
0099 }
0100
0101
0102
0103 sendNotificationBuf_->send();
0104
0105
0106 recvNotificationBuf_->waitRecv();
0107 }
0108
0109
0110 for (int i = 1; i < ptrs_.size(); i++) {
0111 memcpy(ptrs_[i], ptrs_[0], bytes_);
0112 }
0113 }
0114
0115 protected:
0116 std::vector<T*> ptrs_;
0117 const int count_;
0118 const int bytes_;
0119 const ReductionFunction<T>* fn_;
0120
0121 T* inbox_;
0122 T* outbox_;
0123 std::unique_ptr<transport::Buffer> sendDataBuf_;
0124 std::unique_ptr<transport::Buffer> recvDataBuf_;
0125
0126 int dummy_;
0127 std::unique_ptr<transport::Buffer> sendNotificationBuf_;
0128 std::unique_ptr<transport::Buffer> recvNotificationBuf_;
0129 };
0130
0131 }