Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:10:13

0001 /**
0002  * Copyright (c) 2018-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
0011 #include <algorithm>
0012 #include <chrono>
0013 #include <cstdint>
0014 #include <functional>
0015 #include <memory>
0016 #include <mutex>
0017 #include <unordered_map>
0018 #include <vector>
0019 
0020 #include "gloo/common/store.h"
0021 #include "gloo/transport/pair.h"
0022 #include "gloo/transport/unbound_buffer.h"
0023 
0024 namespace gloo {
0025 namespace transport {
0026 
0027 // The context represents a set of pairs that belong to the same
0028 // group. It is roughly equivalent to the top level context class
0029 // with the exception that it captures transport specifics.
0030 //
0031 // While implementing the recv-from-any functionality we realized we
0032 // realized we needed some transport-specific state shared between all
0033 // pairs in a group, to arbitrate between multiple pairs attempting to
0034 // send to the same buffer.
0035 //
0036 class Context {
0037  public:
0038   using slot_t = uint64_t;
0039   using rank_t = int;
0040 
0041   Context(int rank, int size);
0042 
0043   virtual ~Context();
0044 
0045   const int rank;
0046   const int size;
0047 
0048   virtual std::unique_ptr<Pair>& getPair(int rank);
0049 
0050   virtual std::unique_ptr<Pair>& createPair(int rank) = 0;
0051 
0052   virtual void createAndConnectAllPairs(IStore& store);
0053 
0054   // Creates unbound buffer to be used with the ranks in this context.
0055   // It is not bound to a specific rank, but still bound to this
0056   // context. This is needed to support recv-from-any semantics, where
0057   // the context is used as shared arbiter between pairs that are
0058   // ready to send and buffers that are ready to receive.
0059   virtual std::unique_ptr<transport::UnboundBuffer> createUnboundBuffer(
0060       void* ptr,
0061       size_t size) = 0;
0062 
0063   void setTimeout(std::chrono::milliseconds timeout) {
0064     timeout_ = timeout;
0065   }
0066 
0067   std::chrono::milliseconds getTimeout() const {
0068     return timeout_;
0069   }
0070 
0071  protected:
0072   // Protects access to the pending operations and expected
0073   // notifications vectors. These vectors can only be mutated by an
0074   // instance of the Context::Mutator class, which acquires this lock
0075   // upon construction.
0076   //
0077   // The vector of pairs is logically const and may be accessed
0078   // without holding this lock.
0079   //
0080   // If this lock is acquired from a function on a Pair class, ensure
0081   // that Pair's instance lock is acquired before acquiring this lock.
0082   //
0083   std::mutex mutex_;
0084 
0085   // Lifecycle of the pairs is managed by a std::unique_ptr of the
0086   // base class. This is done because the public context API dictates
0087   // that getPair() returns a reference to this type. Functions
0088   // internal to this class can cast these points to the native
0089   // transport specific type.
0090   std::vector<std::unique_ptr<Pair>> pairs_;
0091 
0092   // Default timeout for new pairs (e.g. during initialization) and
0093   // any kind of send/recv operation.
0094   std::chrono::milliseconds timeout_;
0095 
0096   std::vector<char> extractAddress(const std::vector<char>& allAddrs, int i) const;
0097 
0098  protected:
0099   // Keep track of pending send and recv notifications or operations
0100   // for a single slot.
0101   //
0102   // The order with which ranks register a pending operation is
0103   // preserved to avoid starvation for recv-from-any operations.
0104   //
0105   class Tally final {
0106    private:
0107     // This class stores the ranks with pending operations against
0108     // this slot. The parent class uses two instances: one for send
0109     // operations and one for receive operations.
0110     //
0111     // Adding ranks with a pending operation is done through the
0112     // `push` method, which adds the rank to the end of the list.
0113     // Removing ranks with a pending operation is done through the
0114     // `shift` method, which removes a specific rank from the
0115     // beginning of the list. The caller must know the rank to remove
0116     // prior to calling `shift`, because a receive operation may be
0117     // limited to a subset of ranks.
0118     //
0119     class List final {
0120      public:
0121       bool empty() const {
0122         return ranks_.empty();
0123       }
0124 
0125       const std::vector<rank_t>& list() const {
0126         return ranks_;
0127       }
0128 
0129       // Push rank to the end of the list.
0130       void push(rank_t rank) {
0131         ranks_.push_back(rank);
0132       }
0133 
0134       // Shift rank from the beginning of the list.
0135       // Returns if the rank could be found and was removed.
0136       bool shift(rank_t rank) {
0137         auto it = std::find(ranks_.begin(), ranks_.end(), rank);
0138         if (it != ranks_.end()) {
0139           ranks_.erase(it);
0140           return true;
0141         }
0142         return false;
0143       }
0144 
0145      private:
0146       std::vector<rank_t> ranks_;
0147     };
0148 
0149    public:
0150     explicit Tally(slot_t slot) : slot(slot) {}
0151 
0152     slot_t slot;
0153 
0154     bool empty() const {
0155       return send_.empty() && recv_.empty();
0156     }
0157 
0158     const std::vector<rank_t>& getSendList() const {
0159       return send_.list();
0160     }
0161 
0162     const std::vector<rank_t>& getRecvList() const {
0163       return recv_.list();
0164     }
0165 
0166     void pushSend(rank_t rank) {
0167       send_.push(rank);
0168     }
0169 
0170     void pushRecv(rank_t rank) {
0171       recv_.push(rank);
0172     }
0173 
0174     bool shiftSend(rank_t rank) {
0175       return send_.shift(rank);
0176     }
0177 
0178     bool shiftRecv(rank_t rank) {
0179       return recv_.shift(rank);
0180     }
0181 
0182    private:
0183     List send_;
0184     List recv_;
0185   };
0186 
0187   // This class is used to locate a tally for a specific slot in a
0188   // vector container. If the tally if needed and doesn't exist, it
0189   // may be lazily created. If, on destruction, the tally that this
0190   // class points to is empty, it is removed from the container.
0191   //
0192   // This functionality is needed both for the pending operation tally
0193   // and for the expected notification tally. Therefore, we chose to
0194   // use a somewhat generalized class, over duplicating the same
0195   // functionality for two identical containers.
0196   //
0197   class LazyTally final {
0198    public:
0199     LazyTally(std::vector<Tally>& vec, slot_t slot);
0200 
0201     ~LazyTally();
0202 
0203     // Returns if a Tally instance exists for the specified slot.
0204     bool exists();
0205 
0206     // Returns pointer to Tally instance for the specified slot.
0207     // Lazily constructs a new instance if needed.
0208     Tally& get();
0209 
0210    private:
0211     // Reference to underlying container.
0212     std::vector<Tally>& vec_;
0213 
0214     // Slot for the tally we're interested in.
0215     const slot_t slot_;
0216 
0217     // Iterator to tally for the specified slot.
0218     std::vector<Tally>::iterator it_;
0219 
0220     // If the iterator has been initialized.
0221     bool initialized_;
0222 
0223     // Initialize iterator to Tally instance for this slot.
0224     void initialize_iterator();
0225   };
0226 
0227   // This class is used to mutate the pending operation tally and
0228   // expected notification tally for a specific source rank against a
0229   // specific slot. An instance is expected to have a short lifetime
0230   // as it holds a lock on the parent context object.
0231   class Mutator final {
0232    public:
0233     Mutator(Context& context, slot_t slot, rank_t rank);
0234 
0235     void pushRemotePendingRecv();
0236 
0237     void pushRemotePendingSend();
0238 
0239     bool shiftRemotePendingRecv();
0240 
0241     bool shiftRemotePendingSend();
0242 
0243     // When posting a receive operation, we first check if a send
0244     // notification for the specified slot was already received. If
0245     // not, it may already been in flight, and we must take care to
0246     // ignore it when it arrives. This function ensures that the next
0247     // send notification for this slot is ignored.
0248     //
0249     // Also see `shiftExpectedSendNotification`.
0250     //
0251     void pushExpectedSendNotification();
0252 
0253     // This function returns whether or not we were expecting a send
0254     // notification for this slot. If we do, we can ignore it.
0255     //
0256     // Also see `pushExpectedSendNotification`.
0257     //
0258     bool shiftExpectedSendNotification();
0259 
0260    private:
0261     std::lock_guard<std::mutex> lock_;
0262     Context& context_;
0263     const slot_t slot_;
0264     const rank_t rank_;
0265 
0266     // Find and mutate pending operation tally.
0267     LazyTally pendingOperations_;
0268 
0269     // Find and mutate expected notification tally.
0270     LazyTally expectedNotifications_;
0271   };
0272 
0273   // The pending operation tally is stored as a vector under the
0274   // assumption that we're working with very few of them. It should be
0275   // cheaper to perform a linear search in contiguous memory than it
0276   // is to maintain a map of them and pay a higher mutation overhead.
0277   std::vector<Tally> pendingOperations_;
0278 
0279   // If a recv operation is posted before the corresponding send
0280   // notification is received, then we need to make sure the send
0281   // notification isn't added to the pending operations vector. To do
0282   // so, we maintain a structure of notifications we expect to
0283   // receive, so that they can be dropped when they are.
0284   std::vector<Tally> expectedNotifications_;
0285 
0286   // Permit the mutator class to touch the pending operation tally.
0287   friend class Mutator;
0288 
0289  protected:
0290   // Return iterator to pending operation tally for specific slot.
0291   std::vector<Tally>::iterator findPendingOperations(slot_t slot);
0292 };
0293 
0294 } // namespace transport
0295 } // namespace gloo