Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:11

0001 /**
0002  * Copyright (c) 2017-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
0011 #include <cstring>
0012 #include <vector>
0013 
0014 #include "gloo/algorithm.h"
0015 #include "gloo/common/common.h"
0016 #include "gloo/common/logging.h"
0017 
0018 namespace gloo {
0019 
0020 template <typename T>
0021 class BroadcastOneToAll : public Algorithm {
0022  public:
0023   BroadcastOneToAll(
0024       const std::shared_ptr<Context>& context,
0025       const std::vector<T*>& ptrs,
0026       size_t count,
0027       int rootRank = 0,
0028       int rootPointerRank = 0)
0029       : Algorithm(context),
0030         ptrs_(ptrs),
0031         count_(count),
0032         bytes_(count * sizeof(T)),
0033         rootRank_(rootRank),
0034         rootPointerRank_(rootPointerRank) {
0035     GLOO_ENFORCE_GE(rootRank_, 0);
0036     GLOO_ENFORCE_LT(rootRank_, contextSize_);
0037     GLOO_ENFORCE_GE(rootPointerRank_, 0);
0038     GLOO_ENFORCE_LT(rootPointerRank_, ptrs_.size());
0039 
0040     // Setup pairs/buffers for sender/receivers
0041     if (contextSize_ > 1) {
0042       auto ptr = ptrs_[rootPointerRank_];
0043       auto slot = context_->nextSlot();
0044       if (contextRank_ == rootRank_) {
0045         sender_.resize(contextSize_);
0046         for (auto i = 0; i < contextSize_; i++) {
0047           if (i == contextRank_) {
0048             continue;
0049           }
0050 
0051           sender_[i] = make_unique<forSender>();
0052           auto& pair = context_->getPair(i);
0053           sender_[i]->clearToSendBuffer = pair->createRecvBuffer(
0054               slot, &sender_[i]->dummy, sizeof(sender_[i]->dummy));
0055           sender_[i]->sendBuffer = pair->createSendBuffer(slot, ptr, bytes_);
0056         }
0057       } else {
0058         receiver_ = make_unique<forReceiver>();
0059         auto& rootPair = context_->getPair(rootRank_);
0060         receiver_->clearToSendBuffer = rootPair->createSendBuffer(
0061             slot, &receiver_->dummy, sizeof(receiver_->dummy));
0062         receiver_->recvBuffer = rootPair->createRecvBuffer(slot, ptr, bytes_);
0063       }
0064     }
0065   }
0066 
0067   void run() {
0068     if (contextSize_ == 1) {
0069       broadcastLocally();
0070       return;
0071     }
0072 
0073     if (contextRank_ == rootRank_) {
0074       // Fire off send operations after receiving clear to send
0075       for (auto i = 0; i < contextSize_; i++) {
0076         if (i == contextRank_) {
0077           continue;
0078         }
0079         sender_[i]->clearToSendBuffer->waitRecv();
0080         sender_[i]->sendBuffer->send();
0081       }
0082 
0083       // Broadcast locally while sends are happening
0084       broadcastLocally();
0085 
0086       // Wait for all send operations to complete
0087       for (auto i = 0; i < contextSize_; i++) {
0088         if (i == contextRank_) {
0089           continue;
0090         }
0091         sender_[i]->sendBuffer->waitSend();
0092       }
0093     } else {
0094       receiver_->clearToSendBuffer->send();
0095       receiver_->recvBuffer->waitRecv();
0096 
0097       // Broadcast locally after receiving from root
0098       broadcastLocally();
0099     }
0100   }
0101 
0102  protected:
0103   // Broadcast from root pointer to other pointers
0104   void broadcastLocally() {
0105     for (auto i = 0; i < ptrs_.size(); i++) {
0106       if (i == rootPointerRank_) {
0107         continue;
0108       }
0109 
0110       memcpy(ptrs_[i], ptrs_[rootPointerRank_], bytes_);
0111     }
0112   }
0113 
0114   std::vector<T*> ptrs_;
0115   const size_t count_;
0116   const size_t bytes_;
0117   const int rootRank_;
0118   const int rootPointerRank_;
0119 
0120   // For the sender (root)
0121   struct forSender {
0122     int dummy;
0123     std::unique_ptr<transport::Buffer> clearToSendBuffer;
0124     std::unique_ptr<transport::Buffer> sendBuffer;
0125   };
0126 
0127   std::vector<std::unique_ptr<forSender>> sender_;
0128 
0129   // For all receivers
0130   struct forReceiver {
0131     int dummy;
0132     std::unique_ptr<transport::Buffer> clearToSendBuffer;
0133     std::unique_ptr<transport::Buffer> recvBuffer;
0134   };
0135 
0136   std::unique_ptr<forReceiver> receiver_;
0137 };
0138 
0139 } // namespace gloo