File indexing completed on 2025-01-18 10:00:10
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012
0013 #include "gloo/context.h"
0014 #include "gloo/math.h"
0015
0016 namespace gloo {
0017
// Only declared here (extern); the definition lives in another translation
// unit. NOTE(review): judging by its name this is a size threshold used to
// decide whether work is done on-device — confirm against its definition.
extern const size_t kOnDeviceThreshold;
0019
0020 class Algorithm {
0021 public:
0022 explicit Algorithm(const std::shared_ptr<Context>&);
0023 virtual ~Algorithm() noexcept(false) = 0;
0024
0025 virtual void run() = 0;
0026
0027 protected:
0028 std::shared_ptr<Context> context_;
0029
0030 const int contextRank_;
0031 const int contextSize_;
0032
0033 std::unique_ptr<transport::Pair>& getPair(int i);
0034
0035
0036 std::unique_ptr<transport::Pair>& getLeftPair();
0037 std::unique_ptr<transport::Pair>& getRightPair();
0038 };
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
// Tag identifying which reduction a ReductionFunction performs.
//
// Values are fixed explicitly. CUSTOM is placed well above the built-in
// kinds (presumably to leave room for further built-ins without
// renumbering).
enum ReductionType {
  SUM = 1,      // tag of ReductionFunction<T>::sum
  PRODUCT = 2,  // tag of ReductionFunction<T>::product
  MAX = 3,      // tag of ReductionFunction<T>::max
  MIN = 4,      // tag of ReductionFunction<T>::min
  CUSTOM = 1000 // presumably for reductions not covered by the built-ins
};
0058
0059 template <typename T>
0060 class ReductionFunction {
0061 public:
0062 using Function = void(T*, const T*, size_t n);
0063
0064 static const ReductionFunction<T>* sum;
0065 static const ReductionFunction<T>* product;
0066 static const ReductionFunction<T>* min;
0067 static const ReductionFunction<T>* max;
0068
0069 ReductionFunction(ReductionType type, Function* fn)
0070 : type_(type), fn_(fn) {}
0071
0072 ReductionType type() const {
0073 return type_;
0074 }
0075
0076 void call(T* x, const T* y, size_t n) const {
0077 fn_(x, y, n);
0078 }
0079
0080 protected:
0081 ReductionType type_;
0082 Function* fn_;
0083 };
0084
0085 template <typename T>
0086 const ReductionFunction<T>* ReductionFunction<T>::sum =
0087 new ReductionFunction<T>(SUM, &::gloo::sum<T>);
0088 template <typename T>
0089 const ReductionFunction<T>* ReductionFunction<T>::product =
0090 new ReductionFunction<T>(PRODUCT, &::gloo::product<T>);
0091 template <typename T>
0092 const ReductionFunction<T>* ReductionFunction<T>::min =
0093 new ReductionFunction<T>(MIN, &::gloo::min<T>);
0094 template <typename T>
0095 const ReductionFunction<T>* ReductionFunction<T>::max =
0096 new ReductionFunction<T>(MAX, &::gloo::max<T>);
0097
0098
0099
0100
// Interface for an operation driven in two phases: runAsync() to begin it
// and wait() to complete it. run() is the synchronous convenience form.
template <typename T>
class LocalOp {
 public:
  // noexcept(false) allows derived destructors to propagate exceptions.
  virtual ~LocalOp() noexcept(false) {}

  // Synchronous form: runAsync() immediately followed by wait().
  void run() {
    this->runAsync();
    this->wait();
  }

  // Begins the operation (by its name, presumably without blocking).
  virtual void runAsync() = 0;

  // Completes the operation started by runAsync().
  virtual void wait() = 0;
};
0114
0115 }