File indexing completed on 2025-01-18 10:00:11
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <cstring>
0012 #include <vector>
0013
0014 #include "gloo/algorithm.h"
0015 #include "gloo/common/common.h"
0016 #include "gloo/common/logging.h"
0017
0018 namespace gloo {
0019
0020 template <typename T>
0021 class BroadcastOneToAll : public Algorithm {
0022 public:
0023 BroadcastOneToAll(
0024 const std::shared_ptr<Context>& context,
0025 const std::vector<T*>& ptrs,
0026 size_t count,
0027 int rootRank = 0,
0028 int rootPointerRank = 0)
0029 : Algorithm(context),
0030 ptrs_(ptrs),
0031 count_(count),
0032 bytes_(count * sizeof(T)),
0033 rootRank_(rootRank),
0034 rootPointerRank_(rootPointerRank) {
0035 GLOO_ENFORCE_GE(rootRank_, 0);
0036 GLOO_ENFORCE_LT(rootRank_, contextSize_);
0037 GLOO_ENFORCE_GE(rootPointerRank_, 0);
0038 GLOO_ENFORCE_LT(rootPointerRank_, ptrs_.size());
0039
0040
0041 if (contextSize_ > 1) {
0042 auto ptr = ptrs_[rootPointerRank_];
0043 auto slot = context_->nextSlot();
0044 if (contextRank_ == rootRank_) {
0045 sender_.resize(contextSize_);
0046 for (auto i = 0; i < contextSize_; i++) {
0047 if (i == contextRank_) {
0048 continue;
0049 }
0050
0051 sender_[i] = make_unique<forSender>();
0052 auto& pair = context_->getPair(i);
0053 sender_[i]->clearToSendBuffer = pair->createRecvBuffer(
0054 slot, &sender_[i]->dummy, sizeof(sender_[i]->dummy));
0055 sender_[i]->sendBuffer = pair->createSendBuffer(slot, ptr, bytes_);
0056 }
0057 } else {
0058 receiver_ = make_unique<forReceiver>();
0059 auto& rootPair = context_->getPair(rootRank_);
0060 receiver_->clearToSendBuffer = rootPair->createSendBuffer(
0061 slot, &receiver_->dummy, sizeof(receiver_->dummy));
0062 receiver_->recvBuffer = rootPair->createRecvBuffer(slot, ptr, bytes_);
0063 }
0064 }
0065 }
0066
0067 void run() {
0068 if (contextSize_ == 1) {
0069 broadcastLocally();
0070 return;
0071 }
0072
0073 if (contextRank_ == rootRank_) {
0074
0075 for (auto i = 0; i < contextSize_; i++) {
0076 if (i == contextRank_) {
0077 continue;
0078 }
0079 sender_[i]->clearToSendBuffer->waitRecv();
0080 sender_[i]->sendBuffer->send();
0081 }
0082
0083
0084 broadcastLocally();
0085
0086
0087 for (auto i = 0; i < contextSize_; i++) {
0088 if (i == contextRank_) {
0089 continue;
0090 }
0091 sender_[i]->sendBuffer->waitSend();
0092 }
0093 } else {
0094 receiver_->clearToSendBuffer->send();
0095 receiver_->recvBuffer->waitRecv();
0096
0097
0098 broadcastLocally();
0099 }
0100 }
0101
0102 protected:
0103
0104 void broadcastLocally() {
0105 for (auto i = 0; i < ptrs_.size(); i++) {
0106 if (i == rootPointerRank_) {
0107 continue;
0108 }
0109
0110 memcpy(ptrs_[i], ptrs_[rootPointerRank_], bytes_);
0111 }
0112 }
0113
0114 std::vector<T*> ptrs_;
0115 const size_t count_;
0116 const size_t bytes_;
0117 const int rootRank_;
0118 const int rootPointerRank_;
0119
0120
0121 struct forSender {
0122 int dummy;
0123 std::unique_ptr<transport::Buffer> clearToSendBuffer;
0124 std::unique_ptr<transport::Buffer> sendBuffer;
0125 };
0126
0127 std::vector<std::unique_ptr<forSender>> sender_;
0128
0129
0130 struct forReceiver {
0131 int dummy;
0132 std::unique_ptr<transport::Buffer> clearToSendBuffer;
0133 std::unique_ptr<transport::Buffer> recvBuffer;
0134 };
0135
0136 std::unique_ptr<forReceiver> receiver_;
0137 };
0138
0139 }