Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:11

0001 /**
0002  * Copyright (c) 2017-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

#include <memory>
#include <new>
#include <stdexcept>
#include <vector>

#include "gloo/algorithm.h"
#include "gloo/context.h"
0016 
0017 namespace gloo {
0018 
0019 template <typename T>
0020 class AllreduceRing : public Algorithm {
0021  public:
0022   AllreduceRing(
0023       const std::shared_ptr<Context>& context,
0024       const std::vector<T*>& ptrs,
0025       const int count,
0026       const ReductionFunction<T>* fn = ReductionFunction<T>::sum)
0027       : Algorithm(context),
0028         ptrs_(ptrs),
0029         count_(count),
0030         bytes_(count_ * sizeof(T)),
0031         fn_(fn) {
0032     inbox_ = static_cast<T*>(malloc(bytes_));
0033     outbox_ = static_cast<T*>(malloc(bytes_));
0034 
0035     if (this->contextSize_ == 1) {
0036       return;
0037     }
0038 
0039     auto& leftPair = this->getLeftPair();
0040     auto& rightPair = this->getRightPair();
0041     auto slot = this->context_->nextSlot();
0042 
0043     // Buffer to send to (rank+1).
0044     sendDataBuf_ = rightPair->createSendBuffer(slot, outbox_, bytes_);
0045 
0046     // Buffer that (rank-1) writes to.
0047     recvDataBuf_ = leftPair->createRecvBuffer(slot, inbox_, bytes_);
0048 
0049     // Dummy buffers for localized barrier.
0050     // Before sending to the right, we only need to know that the node
0051     // on the right is done using the inbox that's about to be written
0052     // into. No need for a global barrier.
0053     auto notificationSlot = this->context_->nextSlot();
0054     sendNotificationBuf_ =
0055       leftPair->createSendBuffer(notificationSlot, &dummy_, sizeof(dummy_));
0056     recvNotificationBuf_ =
0057       rightPair->createRecvBuffer(notificationSlot, &dummy_, sizeof(dummy_));
0058   }
0059 
0060   virtual ~AllreduceRing() {
0061     if (inbox_ != nullptr) {
0062       free(inbox_);
0063     }
0064     if (outbox_ != nullptr) {
0065       free(outbox_);
0066     }
0067   }
0068 
0069   void run() {
0070     if (count_ == 0) {
0071       return;
0072     }
0073 
0074     // Reduce specified pointers into ptrs_[0]
0075     for (int i = 1; i < ptrs_.size(); i++) {
0076       fn_->call(ptrs_[0], ptrs_[i], count_);
0077     }
0078 
0079     // Intialize outbox with locally reduced values
0080     memcpy(outbox_, ptrs_[0], bytes_);
0081 
0082     int numRounds = this->contextSize_ - 1;
0083     for (int round = 0; round < numRounds; round++) {
0084       // Initiate write to inbox of node on the right
0085       sendDataBuf_->send();
0086 
0087       // Wait for inbox write from node on the left
0088       recvDataBuf_->waitRecv();
0089 
0090       // Reduce
0091       fn_->call(ptrs_[0], inbox_, count_);
0092 
0093       // Wait for outbox write to complete
0094       sendDataBuf_->waitSend();
0095 
0096       // Prepare for next round if necessary
0097       if (round < (numRounds - 1)) {
0098         memcpy(outbox_, inbox_, bytes_);
0099       }
0100 
0101       // Send notification to node on the left that
0102       // this node is ready for an inbox write.
0103       sendNotificationBuf_->send();
0104 
0105       // Wait for notification from node on the right
0106       recvNotificationBuf_->waitRecv();
0107     }
0108 
0109     // Broadcast ptrs_[0]
0110     for (int i = 1; i < ptrs_.size(); i++) {
0111       memcpy(ptrs_[i], ptrs_[0], bytes_);
0112     }
0113   }
0114 
0115  protected:
0116   std::vector<T*> ptrs_;
0117   const int count_;
0118   const int bytes_;
0119   const ReductionFunction<T>* fn_;
0120 
0121   T* inbox_;
0122   T* outbox_;
0123   std::unique_ptr<transport::Buffer> sendDataBuf_;
0124   std::unique_ptr<transport::Buffer> recvDataBuf_;
0125 
0126   int dummy_;
0127   std::unique_ptr<transport::Buffer> sendNotificationBuf_;
0128   std::unique_ptr<transport::Buffer> recvNotificationBuf_;
0129 };
0130 
0131 } // namespace gloo