Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:11

0001 /**
0002  * Copyright (c) 2017-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
0011 #include <math.h>
0012 
0013 #include "gloo/algorithm.h"
0014 #include "gloo/common/logging.h"
0015 #include "gloo/context.h"
0016 
0017 namespace gloo {
0018 
0019 class PairwiseExchange: public Algorithm {
0020  public:
0021   explicit PairwiseExchange(
0022       const std::shared_ptr<Context>& context,
0023       const int numBytes, const int numDestinations)
0024       : Algorithm(context),
0025         numDestinations_(numDestinations),
0026         bytesPerMsg_(numBytes / numDestinations_),
0027         sendBufferData_(new char[numBytes]),
0028         recvBufferData_(new char[numBytes]) {
0029     GLOO_ENFORCE_EQ(this->contextSize_ % 2, 0);
0030     GLOO_ENFORCE_GT(bytesPerMsg_, 0);
0031     GLOO_ENFORCE_GT(numDestinations_, 0);
0032     GLOO_ENFORCE_LE(numDestinations_, log2(this->contextSize_));
0033 
0034     // Processes communicate bidirectionally in pairs
0035     size_t bitmask = 1;
0036     for (int i = 0; i < numDestinations_; i++) {
0037       auto slot = this->context_->nextSlot();
0038       int destination =
0039         this->context_->rank ^ bitmask;
0040       const auto& pair = this->getPair(destination);
0041       sendBuffers_.push_back(pair->createSendBuffer(
0042           slot, &sendBufferData_.get()[i * bytesPerMsg_], bytesPerMsg_));
0043       recvBuffers_.push_back(pair->createRecvBuffer(
0044           slot, &recvBufferData_.get()[i * bytesPerMsg_], bytesPerMsg_));
0045       slot = this->context_->nextSlot();
0046       sendNotificationBufs_.push_back(
0047           pair->createSendBuffer(slot, &dummy_, sizeof(dummy_)));
0048       recvNotificationBufs_.push_back(
0049           pair->createRecvBuffer(slot, &dummy_, sizeof(dummy_)));
0050       bitmask <<= 1;
0051     }
0052   }
0053 
0054   void run() {
0055     for (int i = 0; i < numDestinations_; i++) {
0056       sendBuffers_[i]->send();
0057       recvBuffers_[i]->waitRecv();
0058       sendNotificationBufs_[i]->send();
0059       recvNotificationBufs_[i]->waitRecv();
0060     }
0061   }
0062 
0063  protected:
0064   const int numDestinations_;
0065   const int bytesPerMsg_;
0066   std::unique_ptr<char> sendBufferData_;
0067   std::unique_ptr<char> recvBufferData_;
0068   std::vector<std::unique_ptr<transport::Buffer>> sendBuffers_;
0069   std::vector<std::unique_ptr<transport::Buffer>> recvBuffers_;
0070   int dummy_;
0071   std::vector<std::unique_ptr<transport::Buffer>> sendNotificationBufs_;
0072   std::vector<std::unique_ptr<transport::Buffer>> recvNotificationBufs_;
0073 };
0074 
0075 } // namespace gloo