Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:11

0001 /**
0002  * Copyright (c) 2017-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
0011 #include "gloo/barrier.h"
0012 
0013 namespace gloo {
0014 
0015 class BarrierAllToOne : public Barrier {
0016  public:
0017   explicit BarrierAllToOne(
0018       const std::shared_ptr<Context>& context,
0019       int rootRank = 0)
0020       : Barrier(context), rootRank_(rootRank) {
0021     auto slot = this->context_->nextSlot();
0022     if (this->contextRank_ == rootRank_) {
0023       // Create send/recv buffers for every peer
0024       for (int i = 0; i < this->contextSize_; i++) {
0025         // Skip self
0026         if (i == this->contextRank_) {
0027           continue;
0028         }
0029 
0030         auto& pair = this->getPair(i);
0031         auto sdata = std::unique_ptr<int>(new int);
0032         auto sbuf = pair->createSendBuffer(slot, sdata.get(), sizeof(int));
0033         sendBuffersData_.push_back(std::move(sdata));
0034         sendBuffers_.push_back(std::move(sbuf));
0035         auto rdata = std::unique_ptr<int>(new int);
0036         auto rbuf = pair->createRecvBuffer(slot, rdata.get(), sizeof(int));
0037         recvBuffersData_.push_back(std::move(rdata));
0038         recvBuffers_.push_back(std::move(rbuf));
0039       }
0040     } else {
0041       // Create send/recv buffers to/from the root
0042       auto& pair = this->getPair(rootRank_);
0043       auto sdata = std::unique_ptr<int>(new int);
0044       auto sbuf = pair->createSendBuffer(slot, sdata.get(), sizeof(int));
0045       sendBuffersData_.push_back(std::move(sdata));
0046       sendBuffers_.push_back(std::move(sbuf));
0047       auto rdata = std::unique_ptr<int>(new int);
0048       auto rbuf = pair->createRecvBuffer(slot, rdata.get(), sizeof(int));
0049       recvBuffersData_.push_back(std::move(rdata));
0050       recvBuffers_.push_back(std::move(rbuf));
0051     }
0052   }
0053 
0054   void run() {
0055     if (this->contextRank_ == rootRank_) {
0056       // Wait for message from all peers
0057       for (auto& b : recvBuffers_) {
0058         b->waitRecv();
0059       }
0060       // Notify all peers
0061       for (auto& b : sendBuffers_) {
0062         b->send();
0063       }
0064     } else {
0065       // Send message to root
0066       sendBuffers_[0]->send();
0067       // Wait for acknowledgement from root
0068       recvBuffers_[0]->waitRecv();
0069     }
0070   }
0071 
0072  protected:
0073   const int rootRank_;
0074 
0075   std::vector<std::unique_ptr<int>> sendBuffersData_;
0076   std::vector<std::unique_ptr<transport::Buffer>> sendBuffers_;
0077   std::vector<std::unique_ptr<int>> recvBuffersData_;
0078   std::vector<std::unique_ptr<transport::Buffer>> recvBuffers_;
0079 };
0080 
0081 } // namespace gloo