Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-03-13 09:13:00

0001 /**
0002  * Copyright (c) 2017-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "gloo/common/logging.h"
#include "gloo/transport/address.h"
#include "gloo/transport/buffer.h"
#include "gloo/transport/unbound_buffer.h"
0017 
0018 namespace gloo {
0019 namespace transport {
0020 
0021 class Pair {
0022  public:
0023   virtual ~Pair() = 0;
0024 
0025   virtual const Address& address() const = 0;
0026 
0027   virtual void connect(const std::vector<char>& bytes) = 0;
0028 
0029   virtual void close() = 0;
0030 
0031   virtual void setSync(bool enable, bool busyPoll) = 0;
0032 
0033   virtual std::unique_ptr<Buffer>
0034   createSendBuffer(int slot, void* ptr, size_t size) = 0;
0035 
0036   virtual std::unique_ptr<Buffer>
0037   createRecvBuffer(int slot, void* ptr, size_t size) = 0;
0038 
0039   // Send from the specified buffer to remote side of pair.
0040   virtual void send(
0041       UnboundBuffer* buf,
0042       uint64_t tag,
0043       size_t offset = 0,
0044       size_t nbytes = 0) = 0;
0045 
0046   // Receive into the specified buffer from the remote side of pair.
0047   virtual void recv(
0048       UnboundBuffer* buf,
0049       uint64_t tag,
0050       size_t offset = 0,
0051       size_t nbytes = 0) = 0;
0052 
0053   // Sets the local rank of the process to be localRank
0054   // (See below for description of local rank)
0055   void setLocalRank(int localRank) {
0056     // Local rank should be a non-negative number
0057     GLOO_ENFORCE(localRank >= 0, "LocalRank must be non-negative");
0058 
0059     localRank_ = localRank;
0060   }
0061 
0062   // Returns the local rank of the process
0063   // (See below for description of local rank)
0064   int getLocalRank() const {
0065     return localRank_;
0066   }
0067 
0068  protected:
0069   // Rank of the process on the local machine
0070   // e.g. Suppose we have 2 machines with 8 GPUs per machine.
0071   //      This means we have a total of 16 processes with
0072   //      global ranks 0 to 15. The local ranks would then
0073   //      be 0 to 7 on each machine.
0074   int localRank_;
0075 };
0076 
0077 } // namespace transport
0078 } // namespace gloo