File indexing completed on 2025-01-18 10:00:11
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <math.h>
0012
0013 #include "gloo/algorithm.h"
0014 #include "gloo/common/logging.h"
0015 #include "gloo/context.h"
0016
0017 namespace gloo {
0018
0019 class PairwiseExchange: public Algorithm {
0020 public:
0021 explicit PairwiseExchange(
0022 const std::shared_ptr<Context>& context,
0023 const int numBytes, const int numDestinations)
0024 : Algorithm(context),
0025 numDestinations_(numDestinations),
0026 bytesPerMsg_(numBytes / numDestinations_),
0027 sendBufferData_(new char[numBytes]),
0028 recvBufferData_(new char[numBytes]) {
0029 GLOO_ENFORCE_EQ(this->contextSize_ % 2, 0);
0030 GLOO_ENFORCE_GT(bytesPerMsg_, 0);
0031 GLOO_ENFORCE_GT(numDestinations_, 0);
0032 GLOO_ENFORCE_LE(numDestinations_, log2(this->contextSize_));
0033
0034
0035 size_t bitmask = 1;
0036 for (int i = 0; i < numDestinations_; i++) {
0037 auto slot = this->context_->nextSlot();
0038 int destination =
0039 this->context_->rank ^ bitmask;
0040 const auto& pair = this->getPair(destination);
0041 sendBuffers_.push_back(pair->createSendBuffer(
0042 slot, &sendBufferData_.get()[i * bytesPerMsg_], bytesPerMsg_));
0043 recvBuffers_.push_back(pair->createRecvBuffer(
0044 slot, &recvBufferData_.get()[i * bytesPerMsg_], bytesPerMsg_));
0045 slot = this->context_->nextSlot();
0046 sendNotificationBufs_.push_back(
0047 pair->createSendBuffer(slot, &dummy_, sizeof(dummy_)));
0048 recvNotificationBufs_.push_back(
0049 pair->createRecvBuffer(slot, &dummy_, sizeof(dummy_)));
0050 bitmask <<= 1;
0051 }
0052 }
0053
0054 void run() {
0055 for (int i = 0; i < numDestinations_; i++) {
0056 sendBuffers_[i]->send();
0057 recvBuffers_[i]->waitRecv();
0058 sendNotificationBufs_[i]->send();
0059 recvNotificationBufs_[i]->waitRecv();
0060 }
0061 }
0062
0063 protected:
0064 const int numDestinations_;
0065 const int bytesPerMsg_;
0066 std::unique_ptr<char> sendBufferData_;
0067 std::unique_ptr<char> recvBufferData_;
0068 std::vector<std::unique_ptr<transport::Buffer>> sendBuffers_;
0069 std::vector<std::unique_ptr<transport::Buffer>> recvBuffers_;
0070 int dummy_;
0071 std::vector<std::unique_ptr<transport::Buffer>> sendNotificationBufs_;
0072 std::vector<std::unique_ptr<transport::Buffer>> recvNotificationBufs_;
0073 };
0074
0075 }