Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:10

0001 /**
0002  * Copyright (c) 2017-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 
0013 #include "gloo/context.h"
0014 #include "gloo/math.h"
0015 
0016 namespace gloo {
0017 
// Defined in another translation unit; not referenced anywhere in this header.
// NOTE(review): presumably a size threshold above which work is done
// on-device rather than on the host (units unknown) — confirm at the definition.
extern const size_t kOnDeviceThreshold;

// Abstract base class for gloo algorithms.
//
// Holds the Context the algorithm runs in; subclasses implement run().
class Algorithm {
 public:
  // The constructor is defined out of line (not in this header).
  explicit Algorithm(const std::shared_ptr<Context>&);

  // Pure virtual destructor: keeps the class abstract. Like any destructor it
  // must still have a definition (not in this header), since derived
  // destructors call it. noexcept(false) permits derived destructors to throw.
  virtual ~Algorithm() noexcept(false) = 0;

  // Executes the algorithm; provided by each concrete subclass.
  virtual void run() = 0;

 protected:
  // Shared ownership of the context this algorithm operates in.
  std::shared_ptr<Context> context_;

  // NOTE(review): presumably copied from context_'s rank and size by the
  // out-of-line constructor — confirm. Declaration order here determines
  // member initialization order, so do not reorder.
  const int contextRank_;
  const int contextSize_;

  // Returns a reference to the transport pair with index i.
  // NOTE(review): i is presumably a peer rank in the context — confirm.
  std::unique_ptr<transport::Pair>& getPair(int i);

  // Helpers for ring algorithms
  // NOTE(review): presumably the pairs to the left/right ring neighbours
  // (rank - 1 and rank + 1, modulo size) — confirm in the implementation.
  std::unique_ptr<transport::Pair>& getLeftPair();
  std::unique_ptr<transport::Pair>& getRightPair();
};
0039 
// Type of reduction function.
//
// If the reduction type is one of the built-ins, algorithm
// implementations may use accelerated versions if available.
//
// For example, if a ReductionFunction with ReductionType equal to
// SUM is passed to CUDA-aware Allreduce, it knows it can
// use an NCCL implementation instead of the specified function.
//
enum ReductionType {
  // Built-in reductions; ReductionFunction<T> exposes sum/product/min/max
  // instances tagged with these values.
  SUM = 1,
  PRODUCT = 2,
  MAX = 3,
  MIN = 4,

  // Use larger number so we have plenty of room to add built-ins
  // (below this value) without renumbering.
  // NOTE(review): presumably the tag for user-supplied functions — confirm.
  CUSTOM = 1000,
};
0058 
0059 template <typename T>
0060 class ReductionFunction {
0061  public:
0062   using Function = void(T*, const T*, size_t n);
0063 
0064   static const ReductionFunction<T>* sum;
0065   static const ReductionFunction<T>* product;
0066   static const ReductionFunction<T>* min;
0067   static const ReductionFunction<T>* max;
0068 
0069   ReductionFunction(ReductionType type, Function* fn)
0070       : type_(type), fn_(fn) {}
0071 
0072   ReductionType type() const {
0073     return type_;
0074   }
0075 
0076   void call(T* x, const T* y, size_t n) const {
0077     fn_(x, y, n);
0078   }
0079 
0080  protected:
0081   ReductionType type_;
0082   Function* fn_;
0083 };
0084 
0085 template <typename T>
0086 const ReductionFunction<T>* ReductionFunction<T>::sum =
0087   new ReductionFunction<T>(SUM, &::gloo::sum<T>);
0088 template <typename T>
0089 const ReductionFunction<T>* ReductionFunction<T>::product =
0090   new ReductionFunction<T>(PRODUCT, &::gloo::product<T>);
0091 template <typename T>
0092 const ReductionFunction<T>* ReductionFunction<T>::min =
0093   new ReductionFunction<T>(MIN, &::gloo::min<T>);
0094 template <typename T>
0095 const ReductionFunction<T>* ReductionFunction<T>::max =
0096   new ReductionFunction<T>(MAX, &::gloo::max<T>);
0097 
// Interface for an operation over local buffers.
//
// When an algorithm works with several local pointers, a LocalOp can carry out
// the local step (reduction, broadcast, gathering, etc.). Implementations
// provide runAsync() to start the work and wait() to complete it; run() is the
// synchronous form that chains the two.
template <typename T>
class LocalOp {
 public:
  // noexcept(false): destructors of implementations are permitted to throw.
  virtual ~LocalOp() noexcept(false) {}

  // Starts the operation.
  virtual void runAsync() = 0;

  // Completes the operation started by runAsync().
  virtual void wait() = 0;

  // Synchronous execution: exactly runAsync() followed by wait().
  void run() {
    this->runAsync();
    this->wait();
  }
};
0114 
0115 } // namespace gloo