File indexing completed on 2025-01-18 10:00:11
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "gloo/barrier.h"
0012
0013 namespace gloo {
0014
0015 class BarrierAllToOne : public Barrier {
0016 public:
0017 explicit BarrierAllToOne(
0018 const std::shared_ptr<Context>& context,
0019 int rootRank = 0)
0020 : Barrier(context), rootRank_(rootRank) {
0021 auto slot = this->context_->nextSlot();
0022 if (this->contextRank_ == rootRank_) {
0023
0024 for (int i = 0; i < this->contextSize_; i++) {
0025
0026 if (i == this->contextRank_) {
0027 continue;
0028 }
0029
0030 auto& pair = this->getPair(i);
0031 auto sdata = std::unique_ptr<int>(new int);
0032 auto sbuf = pair->createSendBuffer(slot, sdata.get(), sizeof(int));
0033 sendBuffersData_.push_back(std::move(sdata));
0034 sendBuffers_.push_back(std::move(sbuf));
0035 auto rdata = std::unique_ptr<int>(new int);
0036 auto rbuf = pair->createRecvBuffer(slot, rdata.get(), sizeof(int));
0037 recvBuffersData_.push_back(std::move(rdata));
0038 recvBuffers_.push_back(std::move(rbuf));
0039 }
0040 } else {
0041
0042 auto& pair = this->getPair(rootRank_);
0043 auto sdata = std::unique_ptr<int>(new int);
0044 auto sbuf = pair->createSendBuffer(slot, sdata.get(), sizeof(int));
0045 sendBuffersData_.push_back(std::move(sdata));
0046 sendBuffers_.push_back(std::move(sbuf));
0047 auto rdata = std::unique_ptr<int>(new int);
0048 auto rbuf = pair->createRecvBuffer(slot, rdata.get(), sizeof(int));
0049 recvBuffersData_.push_back(std::move(rdata));
0050 recvBuffers_.push_back(std::move(rbuf));
0051 }
0052 }
0053
0054 void run() {
0055 if (this->contextRank_ == rootRank_) {
0056
0057 for (auto& b : recvBuffers_) {
0058 b->waitRecv();
0059 }
0060
0061 for (auto& b : sendBuffers_) {
0062 b->send();
0063 }
0064 } else {
0065
0066 sendBuffers_[0]->send();
0067
0068 recvBuffers_[0]->waitRecv();
0069 }
0070 }
0071
0072 protected:
0073 const int rootRank_;
0074
0075 std::vector<std::unique_ptr<int>> sendBuffersData_;
0076 std::vector<std::unique_ptr<transport::Buffer>> sendBuffers_;
0077 std::vector<std::unique_ptr<int>> recvBuffersData_;
0078 std::vector<std::unique_ptr<transport::Buffer>> recvBuffers_;
0079 };
0080
0081 }