Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:10

0001 /**
0002  * Copyright (c) 2017-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
0011 #include <stddef.h>
0012 #include <string.h>
0013 
0014 #include "gloo/algorithm.h"
0015 #include "gloo/context.h"
0016 
0017 namespace gloo {
0018 
0019 // AllgatherRing is similar to MPI_Allgather where all processes receive the
0020 // buffers (inPtrs) from all other processes.
0021 // The caller needs to pass a preallocated receive buffer (outPtr) of size equal
0022 // to the context size x the total size of the send buffers (inPtrs) where the
0023 // send buffers of the process with rank = k will be written to
0024 // outPtr[k * number of input buffers * count] consecutively.
0025 template <typename T>
0026 class AllgatherRing : public Algorithm {
0027  public:
0028   AllgatherRing(
0029       const std::shared_ptr<Context>& context,
0030       const std::vector<const T*>& inPtrs,
0031       T* outPtr,
0032       int count)
0033       : Algorithm(context),
0034         inPtrs_(inPtrs),
0035         outPtr_(outPtr),
0036         count_(count),
0037         bytes_(count * sizeof(T)),
0038         inputStride_(count_ * inPtrs_.size()),
0039         leftPair_(this->getLeftPair()),
0040         rightPair_(this->getRightPair()) {
0041     auto slot = this->context_->nextSlot();
0042 
0043     sendDataBuf_ = rightPair_->createSendBuffer(
0044         slot, outPtr_, inPtrs_.size() * context_->size * bytes_);
0045     recvDataBuf_ = leftPair_->createRecvBuffer(
0046         slot, outPtr_, inPtrs_.size() * context_->size * bytes_);
0047 
0048     auto notificationSlot = this->context_->nextSlot();
0049     sendNotificationBuf_ =
0050         leftPair_->createSendBuffer(notificationSlot, &dummy_, sizeof(dummy_));
0051     recvNotificationBuf_ =
0052         rightPair_->createRecvBuffer(notificationSlot, &dummy_, sizeof(dummy_));
0053   }
0054 
0055   virtual ~AllgatherRing() {}
0056 
0057   void run() {
0058     // Short circuit if there is only a single process or the output is empty.
0059     if (this->contextSize_ == 1 || count_ == 0) {
0060       return;
0061     }
0062     const int rank = this->contextRank_;
0063     const int numRounds = this->contextSize_ - 1;
0064 
0065     // Copy local buffers.
0066     for (int i = 0; i < inPtrs_.size(); i++) {
0067       memcpy(outPtr_ + rank * inputStride_ + i * count_, inPtrs_[i], bytes_);
0068     }
0069 
0070     // We send input buffers in order.
0071     for (int i = 0; i < inPtrs_.size(); i++) {
0072       // We start every iteration by sending local buffer.
0073       int inRank = rank;
0074       for (int round = 0; round < numRounds; round++) {
0075         const int sendOffset = inRank * inputStride_ + i * count_;
0076         sendDataBuf_->send(
0077             sendOffset * sizeof(T), bytes_, sendOffset * sizeof(T));
0078         recvDataBuf_->waitRecv();
0079 
0080         // Nodes receive data from the left node in every round and forward it
0081         // to the right node.
0082         inRank = (numRounds - round + rank) % this->contextSize_;
0083 
0084         // Send notification to node on the left that this node is ready for an
0085         // inbox write.
0086         sendNotificationBuf_->send();
0087 
0088         // Wait for notification from node on the right.
0089         recvNotificationBuf_->waitRecv();
0090       }
0091     }
0092   }
0093 
0094  private:
0095   const std::vector<const T*> inPtrs_;
0096   T* outPtr_;
0097   const int count_;
0098   const int bytes_;
0099   const int inputStride_;
0100 
0101   std::unique_ptr<transport::Pair>& leftPair_;
0102   std::unique_ptr<transport::Pair>& rightPair_;
0103 
0104   std::unique_ptr<transport::Buffer> sendDataBuf_;
0105   std::unique_ptr<transport::Buffer> recvDataBuf_;
0106 
0107   int dummy_;
0108 
0109   std::unique_ptr<transport::Buffer> sendNotificationBuf_;
0110   std::unique_ptr<transport::Buffer> recvNotificationBuf_;
0111 };
0112 
0113 }  // namespace gloo