Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:11

0001 /**
0002  * Copyright (c) 2017-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
#include <math.h>
#include <stddef.h>
#include <string.h>
#include <algorithm>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <unordered_map>

#include "gloo/algorithm.h"
#include "gloo/common/error.h"
#include "gloo/context.h"
0023 
0024 /**
 * This file contains the following classes:
0026  *
0027  * **Node**
0028  *
 * This is a helper class. We create one object for each node participating
 * in the allreduce operation, identified by its rank. It encapsulates
 * information related to the processing of elements, that is, how many
 * elements need to be sent from which offset or received by a particular node,
 * and reduced at which offset, etc.
0034  *
0035  * **Group**
0036  *
0037  * This is another helper class. As part of each step of processing
0038  * we divide nodes into multiple groups. This class helps track properties of
0039  * that group. Such as, which nodes are part of the group, how many elements
0040  * collectively all nodes need to process and at what offset etc.
0041  *
0042  * **AllreduceBcube**
0043  *
 * This is the main allreduce implementation. Bcube is a scheme where nodes are
 * divided into groups. In the reduce-scatter stage, within each group, a node
 * peers with `base - 1` other nodes. In the first step, data is reduced between
 * nodes within the group. In the next step, each node of a group peers with
 * `base - 1` nodes from other, mutually exclusive groups. Since each node
 * starts that step with already-reduced data, communicating with it is like
 * communicating with `base` nodes/groups from the previous step. This process
 * continues until all the groups are covered, which takes ceil(log_base(n))
 * steps. In each step a node reduces totalNumElems_ / (base^step) elements
 * (counting steps from 1). At the end of the reduce-scatter stage each node
 * holds a reduced chunk of elements. In the all-gather stage we then follow the
 * reverse of the reduce-scatter process to share the reduced data with the
 * other nodes.
0057  */
0058 namespace gloo {
0059 
0060 namespace bcube {
0061 
/**
 * Helps capture all information related to a node: its rank and, for every
 * step, the peers it exchanges data with together with the count and offset
 * of the elements it is responsible for.
 */
class Node {
 public:
  /**
   * @param rank Rank of this node
   * @param steps Number of steps. Every per-step peer list starts empty and
   *   every per-step element count and offset starts at 0.
   */
  explicit Node(int rank, int steps)
      : rank_(rank),
        peersPerStep_(steps),
        numElemsPerStep_(steps),
        ptrOffsetPerStep_(steps) {}
  /**
   * Get the rank of this node
   */
  int getRank() const {
    return rank_;
  }
  /**
   * Records, for a particular step, all the peer nodes, the number of elements
   * to process and the offset from which data in the original ptr buffer will
   * be processed by this node. This is meant to be called as part of setup
   * only; peers are appended, so calling it twice for the same step would
   * record duplicate peers.
   * @param step The step for which we are recording attributes
   * @param peerRanks All ranks in the group for this step. This may contain
   *   this node's own rank, which is filtered out.
   * @param numElems The number of elements this node will be processing in the
   *   step
   * @param offset The offset in the ptrs array
   */
  void setPerStepAttributes(
      int step,
      const std::vector<int>& peerRanks,
      int numElems,
      int offset) {
    std::vector<int>& peers = peersPerStep_[step];
    for (int peerRank : peerRanks) {
      // A node never peers with itself.
      if (peerRank != rank_) {
        peers.emplace_back(peerRank);
      }
    }
    numElemsPerStep_[step] = numElems;
    ptrOffsetPerStep_[step] = offset;
  }
  /**
   * Get all the nodes this node peers with in a particular step
   * @param step The step for which we need to get peers
   * @return List of ranks of all peer nodes (never includes this node)
   */
  const std::vector<int>& getPeersPerStep(int step) const {
    return peersPerStep_[step];
  }
  /**
   * Get count of elements this node needs to process in the specified step
   * @param step The step for which we are querying count
   */
  int getNumElemsPerStep(int step) const {
    return numElemsPerStep_[step];
  }
  /**
   * Get offset into the ptrs array this node needs to start processing from
   * in the specified step
   * @param step The step for which we are querying offset
   */
  int getPtrOffsetPerStep(int step) const {
    return ptrOffsetPerStep_[step];
  }

 private:
  /**
   * Rank of this node
   */
  const int rank_;
  /**
   * A vector of a list of ranks (value) of nodes this node would
   * peer with in a step (index)
   */
  std::vector<std::vector<int>> peersPerStep_;
  /**
   * A vector of number of elements (value) this node needs to process in a step
   * (index). This could be the number of elements to be received and reduced by
   * a node and correspondingly sent by its peers during a step of the
   * reduce-scatter stage, or, similarly, the number of elements received and
   * copied into the ptrs_ array by a node and correspondingly sent by its peer
   * during a step of the all-gather stage.
   */
  std::vector<int> numElemsPerStep_;
  /**
   * A vector of offset (value) within the ptrs_ array from which data needs to
   * be processed by this node in a step (index). This would be used by peers to
   * send data from the ptrs_ array to this node and used with the reduce
   * function during the reduce-scatter phase, or during all-gather to send
   * elements to peers from the ptrs_ array.
   */
  std::vector<int> ptrOffsetPerStep_;
};
0156 
0157 /**
0158  * Helps capture all information related to a peer group
0159  */
0160 class Group {
0161  public:
0162   Group(
0163       int step,
0164       const Node& firstNode,
0165       int peerDistance,
0166       int base,
0167       int nodes,
0168       int totalNumElems)
0169       : nodeRanks_(
0170             getNodeRanks(firstNode.getRank(), peerDistance, base, nodes)),
0171         ptrOffset_((0 == step) ? 0 : firstNode.getPtrOffsetPerStep(step - 1)),
0172         numElems_(computeNumElems(
0173             step,
0174             firstNode,
0175             nodeRanks_.size(),
0176             totalNumElems)) {}
0177   /**
0178    * Simple getter for all the nodes in the group
0179    * @return List of ranks of nodes in the group
0180    */
0181   const std::vector<int>& getNodeRanks() const {
0182     return nodeRanks_;
0183   }
0184   /**
0185    * Get the offset from which the group should process data
0186    * @return Offset in the ptrs array
0187    */
0188   int getPtrOffset() const {
0189     return ptrOffset_;
0190   }
0191   /**
0192    * Get the number of elements this group is supposed to process
0193    * @return Count of elements (in ptr or receive buffers)
0194    */
0195   int getNumElems() const {
0196     return numElems_;
0197   }
0198 
0199  private:
0200   const std::vector<int> nodeRanks_;
0201   const int ptrOffset_;
0202   const int numElems_;
0203   /**
0204    * Computes the number of elements this group needs to process. If this is the
0205    * first step we start with all elements. For subsequent steps it's number
0206    * of elements processed by single node in previous step. If this value is
0207    * smaller than number of peers in the group simply use number of peers as the
0208    * count so that at least one element is exchanged. Also, note that in this
0209    * case some nodes may end up duplicating the work as the ptrOffset wraps
0210    * around the totalNumElems_ in updateGroupNodes() function.
0211    * @param step The current step
0212    * @param firstNode The first node in the group
0213    * @param peers The total number of peers in the group
0214    * @count The total number of elements to be processed by this node
0215    * @return The number of elements to be processed by this group
0216    */
0217   static int
0218   computeNumElems(int step, const Node& firstNode, int peers, int count) {
0219     int groupCount =
0220         (0 == step) ? count : firstNode.getNumElemsPerStep(step - 1);
0221     return std::max(groupCount, peers);
0222   }
0223   /**
0224    * Determines all the nodes in a group in a particular step
0225    * @param peerDistance This is the distance between rank of each peer in the
0226    *   group
0227    * @return List of ranks of nodes in the group
0228    */
0229   std::vector<int>
0230   getNodeRanks(int firstNodeRank, int peerDistance, int base, int nodes) const {
0231     std::vector<int> groupPeers;
0232     for (int i = 0; i < base; ++i) {
0233       int peerRank = firstNodeRank + i * peerDistance;
0234       if (peerRank < nodes) {
0235         groupPeers.emplace_back(peerRank);
0236       }
0237     }
0238     return groupPeers;
0239   }
0240 };
0241 
0242 } // namespace bcube
0243 
0244 /**
0245  * This is another implemenation of allreduce algorithm where-in we divide
0246  * nodes into group of base_ nodes instead of a factor of two used by
0247  * allreduce_halving_doubling. It basically shards the data based on the base
0248  * and does a reduce-scatter followed by all-gather very much like the
0249  * allreduce_halving_doubling algorithm.
0250  *
0251  * This algorithm can handle cases where we don't really have a complete
0252  * hypercube, i.e. number of nodes != c * base ^ x where c and x are some
0253  * contants; however,  the number of nodes must be divisible by base.
0254  */
0255 template <typename T>
0256 class AllreduceBcube : public Algorithm {
0257  public:
0258   AllreduceBcube(
0259       const std::shared_ptr<Context>& context,
0260       const std::vector<T*> ptrs,
0261       const int count,
0262       const ReductionFunction<T>* fn = ReductionFunction<T>::sum)
0263       : Algorithm(context),
0264         myRank_(this->context_->rank),
0265         base_(this->context_->base),
0266         nodes_(this->contextSize_),
0267         ptrs_(ptrs),
0268         totalNumElems_(count),
0269         bytes_(totalNumElems_ * sizeof(T)),
0270         steps_(computeSteps(nodes_, base_)),
0271         fn_(fn),
0272         recvBufs_(steps_ * base_) {
0273     if (totalNumElems_ == 0 || nodes_ == 1) {
0274       return;
0275     }
0276     setupNodes();
0277     /*
0278      * Reserve max needed number of context slots. Up to 2 slots per process
0279      * pair are needed (one for regular sends and one for notifications). For
0280      * simplicity, the same mapping is used on all processes so that the slots
0281      * trivially match across processes
0282      */
0283     int slotOffset_ = this->context_->nextSlot(
0284         2 * this->contextSize_ * (this->contextSize_ - 1));
0285 
0286     int bufIdx = 0;
0287     for (int step = 0; step < steps_; ++step) {
0288       for (int destRank : getPeersPerStep(myRank_, step)) {
0289         int recvSize = std::max(
0290             getNumElemsPerStep(myRank_, step),
0291             getNumElemsPerStep(destRank, step));
0292         auto& pair = this->context_->getPair(destRank);
0293         auto slot = slotOffset_ +
0294             2 * (std::min(myRank_, destRank) * nodes_ +
0295                  std::max(myRank_, destRank));
0296         sendDataBufs_[destRank] =
0297             pair->createSendBuffer(slot, ptrs_[0], bytes_);
0298         recvBufs_[bufIdx].resize(recvSize);
0299         recvDataBufs_[destRank] = pair->createRecvBuffer(
0300             slot, &recvBufs_[bufIdx][0], recvSize * sizeof(T));
0301         recvBufIdx_[destRank] = bufIdx;
0302         ++bufIdx;
0303         ++slot;
0304         sendNotificationBufs_[destRank] =
0305             pair->createSendBuffer(slot, &dummy_, sizeof(dummy_));
0306         recvNotificationBufs_[destRank] =
0307             pair->createRecvBuffer(slot, &dummy_, sizeof(dummy_));
0308       } // nodes
0309     } // steps
0310   }
0311 
0312 #ifdef DEBUG
0313 #define DEBUG_PRINT_STAGE(stage) \
0314   do {                           \
0315     printStageBuffer(stage);     \
0316   } while (false)
0317 #define DEBUG_PRINT_SEND(stage)                                              \
0318   do {                                                                       \
0319     printStepBuffer(                                                         \
0320         stage, step, myRank_, destRank, &ptrs_[0][0], sendCount, ptrOffset); \
0321   } while (false)
0322 #define DEBUG_PRINT_RECV(stage)              \
0323   do {                                       \
0324     printStepBuffer(                         \
0325         stage,                               \
0326         step,                                \
0327         srcRank,                             \
0328         myRank_,                             \
0329         &recvBufs_[recvBufIdx_[srcRank]][0], \
0330         recvCount);                          \
0331   } while (false)
0332 #else
0333 #define DEBUG_PRINT_STAGE(stage)
0334 #define DEBUG_PRINT_SEND(stage)
0335 #define DEBUG_PRINT_RECV(stage)
0336 #endif
0337 
0338   void run() {
0339     if (totalNumElems_ == 0) {
0340       return;
0341     }
0342     // Local reduce operation
0343     for (int i = 1; i < ptrs_.size(); i++) {
0344       fn_->call(ptrs_[0], ptrs_[i], totalNumElems_);
0345     }
0346 
0347     if (nodes_ == 1) {
0348       // Broadcast ptrs_[0]
0349       for (int i = 1; i < ptrs_.size(); i++) {
0350         memcpy(ptrs_[i], ptrs_[0], bytes_);
0351       }
0352       return;
0353     }
0354 
0355     // Reduce-scatter
0356     DEBUG_PRINT_STAGE("start");
0357     for (int step = 0; step < steps_; ++step) {
0358       for (int destRank : getPeersPerStep(myRank_, step)) {
0359         int sendCount = getNumElemsPerStep(destRank, step);
0360         int ptrOffset = getPtrOffsetPerStep(destRank, step);
0361         DEBUG_PRINT_SEND("reduce-scatter");
0362         sendDataBufs_[destRank]->send(
0363             ptrOffset * sizeof(T), sendCount * sizeof(T));
0364       } // sends within group
0365 
0366       for (int srcRank : getPeersPerStep(myRank_, step)) {
0367         int recvCount = getNumElemsPerStep(myRank_, step);
0368         int ptrOffset = getPtrOffsetPerStep(myRank_, step);
0369         recvDataBufs_[srcRank]->waitRecv();
0370         DEBUG_PRINT_RECV("reduce-scatter");
0371         fn_->call(
0372             &ptrs_[0][ptrOffset],
0373             &recvBufs_[recvBufIdx_[srcRank]][0],
0374             recvCount);
0375         /*
0376          * Send notification to the pair we just received from that
0377          * we're done dealing with the receive buffer.
0378          */
0379         sendNotificationBufs_[srcRank]->send();
0380       } // recvs within group and reduces
0381     } // reduce-scatter steps
0382 
0383     DEBUG_PRINT_STAGE("reduce-scattered");
0384 
0385     // All-gather
0386     for (int step = steps_ - 1; step >= 0; --step) {
0387       for (int destRank : getPeersPerStep(myRank_, step)) {
0388         int sendCount = getNumElemsPerStep(myRank_, step);
0389         int ptrOffset = getPtrOffsetPerStep(myRank_, step);
0390         /*
0391          * Wait for notification from the peer to make sure we can send data
0392          * without risking any overwrites in its receive buffer.
0393          */
0394         recvNotificationBufs_[destRank]->waitRecv();
0395         DEBUG_PRINT_SEND("all-gather");
0396         sendDataBufs_[destRank]->send(
0397             ptrOffset * sizeof(T), sendCount * sizeof(T));
0398       }
0399 
0400       for (int srcRank : getPeersPerStep(myRank_, step)) {
0401         int recvCount = getNumElemsPerStep(srcRank, step);
0402         int ptrOffset = getPtrOffsetPerStep(srcRank, step);
0403         recvDataBufs_[srcRank]->waitRecv();
0404         DEBUG_PRINT_RECV("all-gather");
0405         std::memcpy(
0406             &ptrs_[0][ptrOffset],
0407             &recvBufs_[recvBufIdx_[srcRank]][0],
0408             recvCount * sizeof(T));
0409         if (step == 0) {
0410           /*
0411            * Send notification to the pair we just received from that
0412            * we're done dealing with the receive buffer.
0413            */
0414           sendNotificationBufs_[srcRank]->send();
0415         }
0416       } // recvs within group and reduces
0417     } // all-gather steps
0418 
0419     DEBUG_PRINT_STAGE("all-reduced");
0420 
0421     // Broadcast ptrs_[0]
0422     for (int i = 1; i < ptrs_.size(); i++) {
0423       memcpy(ptrs_[i], ptrs_[0], bytes_);
0424     }
0425 
0426     /*
0427      * Wait for notifications from our peers within the block to make
0428      * sure we can send data immediately without risking overwriting
0429      * data in its receive buffer before it consumed that data.
0430      */
0431     for (int peerRank : getPeersPerStep(myRank_, 0)) {
0432       recvNotificationBufs_[peerRank]->waitRecv();
0433     }
0434   }
0435 
0436  private:
0437   /**
0438    * Number of words to be printed per section by printElems
0439    */
0440   static constexpr int wordsPerSection = 4;
0441   /**
0442    * Number of words to be printed per line by printElems
0443    */
0444   static constexpr int wordsPerLine = 4 * wordsPerSection;
0445   /**
0446    * Just a reference to current nodes rank
0447    */
0448   const int myRank_{0};
0449   /**
0450    * Number of nodes in a typical group
0451    */
0452   const int base_{2};
0453   /**
0454    * Total number of nodes
0455    */
0456   const int nodes_{0};
0457   /**
0458    * Pointer to the elements
0459    */
0460   const std::vector<T*> ptrs_{nullptr};
0461   /**
0462    * Total number of elements to process
0463    */
0464   const int totalNumElems_{0};
0465   /**
0466    * Total number of bytes to process
0467    */
0468   const int bytes_{0};
0469   /**
0470    * Total number of steps
0471    */
0472   const size_t steps_{0};
0473   /**
0474    * The reduce operation function
0475    */
0476   const ReductionFunction<T>* fn_{nullptr};
0477   /**
0478    * List of actual buffers for incoming data
0479    */
0480   std::vector<std::vector<T>> recvBufs_;
0481   /**
0482    * Map of rank to incoming buffer index in recvBufs
0483    */
0484   std::unordered_map<int, int> recvBufIdx_;
0485   /**
0486    * Map of rank to Buffer which will be used for outgoing data
0487    */
0488   std::unordered_map<int, std::unique_ptr<transport::Buffer>> sendDataBufs_;
0489   /**
0490    * Map of rank to Buffer which will be used for incoming data
0491    */
0492   std::unordered_map<int, std::unique_ptr<transport::Buffer>> recvDataBufs_;
0493   /**
0494    * Dummy data used to signal end of one setup
0495    */
0496   int dummy_;
0497   /**
0498    * Map of rank to Buffer which will be used for outgoing synchronization data
0499    * at end of reduce-scatter and all-gather
0500    */
0501   std::unordered_map<int, std::unique_ptr<transport::Buffer>>
0502       sendNotificationBufs_;
0503   /**
0504    * Map of rank to Buffer which will be used for incoming synchronization data
0505    * at end of reduce-scatter and all-gather
0506    */
0507   std::unordered_map<int, std::unique_ptr<transport::Buffer>>
0508       recvNotificationBufs_;
0509   /**
0510    * List of all the nodes
0511    */
0512   std::vector<bcube::Node> allNodes_;
0513   /**
0514    * Compute number of steps required in reduce-scatter and all-gather (each)
0515    * @param nodes The total number of nodes
0516    * @para peers The maximum number of peers in a group
0517    */
0518   static int computeSteps(int nodes, int peers) {
0519     float lg2n = log2(nodes);
0520     float lg2p = log2(peers);
0521     return ceil(lg2n / lg2p);
0522   }
0523   /**
0524    * Basically a gate to make sure only the right node(s) print logs
0525    * @param rank Rank of the current node
0526    */
0527   static bool printCheck(int /*rank*/) {
0528     return false;
0529   }
0530   /**
0531    * Prints a break given the offset of an element about to be printed
0532    * @param p Pointer to the elements
0533    * @param x The current offset to the pointer to words
0534    */
0535   static void printBreak(T* p, int x) {
0536     if (0 == x % wordsPerLine) {
0537       std::cout << std::endl
0538                 << &p[x] << " " << std::setfill('0') << std::setw(5) << x
0539                 << ": ";
0540     } else if (0 == x % wordsPerSection) {
0541       std::cout << "- ";
0542     }
0543   }
0544   /**
0545    * Pretty prints a list of elements
0546    * @param p Pointer to the elements
0547    * @param count The number of elements to be printed
0548    * @param start The offset from which to print
0549    */
0550   static void printElems(T* p, int count, int start = 0) {
0551     auto alignedStart = (start / wordsPerLine) * wordsPerLine;
0552     for (int x = alignedStart; x < start + count; ++x) {
0553       printBreak(p, x);
0554       if (x < start) {
0555         std::cout << "..... ";
0556       } else {
0557         std::cout << std::setfill('0') << std::setw(5) << p[x] << " ";
0558       }
0559     }
0560   }
0561   /**
0562    * Prints contents in the ptrs array at a particular stage
0563    * @param msg Custom message to be printed
0564    */
0565   void printStageBuffer(const std::string& msg) {
0566     if (printCheck(myRank_)) {
0567       std::cout << "rank (" << myRank_ << ") " << msg << ": ";
0568       printElems(&ptrs_[0][0], totalNumElems_);
0569       std::cout << std::endl;
0570     }
0571   }
0572 
0573   /**
0574    * Prints specified buffer during a step
0575    * @param step The step when the buffer is being printed
0576    * @param srcRank The sender of the data
0577    * @param destRank The receiver of data
0578    * @param p Poniter to the buffer to be printed
0579    * @param count Number of elements to be printed
0580    * @param start The offset from which to print
0581    */
0582   void printStepBuffer(
0583       const std::string& stage,
0584       int step,
0585       int srcRank,
0586       int destRank,
0587       T* p,
0588       int count,
0589       int start = 0) {
0590     if (printCheck(myRank_)) {
0591       std::cout << stage << ": step (" << step << ") "
0592                 << "srcRank (" << srcRank << ") -> "
0593                 << "destRank (" << destRank << "): ";
0594       printElems(p, count, start);
0595       std::cout << std::endl;
0596     }
0597   }
0598   /**
0599    * Get all the peers of node with specified rank
0600    * @param rank Rank of the node for which peers are needed
0601    * @param step The step for which we need to get peers
0602    * @return List of ranks of all peer nodes
0603    */
0604   const std::vector<int>& getPeersPerStep(int rank, int step) {
0605     return allNodes_[rank].getPeersPerStep(step);
0606   }
0607   /**
0608    * Get count of elements specified node needs to process in specified the step
0609    * @param rank Rank of the node for which count is requested
0610    * @param step The step for which we are querying count
0611    */
0612   int getNumElemsPerStep(int rank, int step) {
0613     return allNodes_[rank].getNumElemsPerStep(step);
0614   }
0615   /**
0616    * Get offset to ptrs array specified node needs to start processing from in
0617    * the specified step
0618    * @param rank Rank of the node for which offset is requested
0619    * @param step The step for which we are querying offset
0620    */
0621   int getPtrOffsetPerStep(int rank, int step) {
0622     return allNodes_[rank].getPtrOffsetPerStep(step);
0623   }
0624   /**
0625    * Creates all the nodes with sequential ranks
0626    */
0627   void createNodes() {
0628     for (int rank = 0; rank < nodes_; ++rank) {
0629       allNodes_.emplace_back(rank, steps_);
0630     }
0631   }
0632   /**
0633    * Updates the peer, count and offset values for all the nodes in a group
0634    * @param step The step for which we are updating the values
0635    * @param groups The group object with all peer, count and offset data
0636    */
0637   void updateGroupNodes(int step, const bcube::Group& group) {
0638     const std::vector<int>& peers = group.getNodeRanks();
0639     const int peersSz = peers.size();
0640     int ptrOffset = group.getPtrOffset();
0641     int count = group.getNumElems() / peersSz;
0642     const int countRem = group.getNumElems() % peersSz;
0643     if (0 == count) {
0644       count = 1;
0645     }
0646     for (int i = 0; i < peersSz; ++i) {
0647       bcube::Node& node = allNodes_[peers[i]];
0648       if (peersSz - 1 != i) { // if not the last node in group
0649         node.setPerStepAttributes(step, peers, count, ptrOffset);
0650         ptrOffset += count;
0651       } else {
0652         /*
0653          * The last node get the remainder elements if the number of
0654          * elements is not exactly divisible by number of peers
0655          */
0656         node.setPerStepAttributes(step, peers, count + countRem, ptrOffset);
0657         ptrOffset += count + countRem;
0658       }
0659       ptrOffset %= totalNumElems_;
0660     }
0661   }
0662   /**
0663    * Setup all the nodes
0664    * Here are the things we do in this function
0665    *  - Create nodes
0666    *  - Compute and store elements per group in each step
0667    *  - Step up all the nodes
0668    */
0669   void setupNodes() {
0670     // Create all the nodes upfront
0671     createNodes();
0672 
0673     // Now we actually try to set up the nodes
0674     int peerDistance = 1;
0675     for (int step = 0; step < steps_; ++step) {
0676       std::vector<bcube::Group> groups;
0677       // Iterate over all the nodes to identify the first node of each group
0678       for (int rank = 0; rank < nodes_; ++rank) {
0679         const bcube::Node& firstNode = allNodes_[rank];
0680         // Only the ones with no peers would be first node
0681         if (0 == firstNode.getPeersPerStep(step).size()) {
0682           // Create a new group
0683           groups.emplace_back(
0684               step, firstNode, peerDistance, base_, nodes_, totalNumElems_);
0685           // Iterrate over all the peer nodes and set them up for the step
0686           updateGroupNodes(step, groups.back());
0687         } // if (0 == firstNode ...
0688       } // for (int rank = 0..
0689       // Done iterating over all the nodes. Update peerDistance for next step.
0690       peerDistance *= base_;
0691     } // for (int step ...
0692   } // setupNodes
0693 };
0694 
0695 } // namespace gloo