File indexing completed on 2025-01-18 10:11:04
0001 #ifndef TMVA_SOFIE_RFUNCTION
0002 #define TMVA_SOFIE_RFUNCTION
0003
0004 #include "TMVA/RModel_Base.hxx"
0005 #include "TMVA/SOFIE_common.hxx"
0006
0007 #include <memory>
0008 #include <string>
0009
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013
0014 class RModel;
0015
0016
0017 class RFunction {
0018 protected:
0019 std::string fFuncName;
0020 FunctionType fType;
0021 public:
0022 RFunction() {}
0023 virtual ~RFunction() {}
0024 FunctionType GetFunctionType() {
0025 return fType;
0026 }
0027
0028 RFunction(std::string funcName, FunctionType type):
0029 fFuncName(UTILITY::Clean_name(funcName)),fType(type) {}
0030
0031 };
0032
0033 class RFunction_Update: public RFunction {
0034 protected:
0035 std::shared_ptr<RModel> function_block;
0036 FunctionTarget fTarget;
0037 GraphType fGraphType;
0038 std::vector<std::string> fInputTensors;
0039 std::vector<ROperator*> fAddlOp;
0040
0041 public:
0042 virtual ~RFunction_Update() {}
0043 RFunction_Update() {}
0044 RFunction_Update(FunctionTarget target, GraphType gType);
0045
0046 virtual void AddInitializedTensors(const std::vector<std::vector<std::string>>&) {};
0047 virtual void Initialize() {};
0048 virtual void AddLayerNormalization(int, float, size_t, const std::string&,
0049 const std::string&, const std::string&, const std::string&) {};
0050 void AddInputTensors(const std::vector<std::vector<std::size_t>>& inputShapes);
0051 void AddInputTensors(const std::vector<std::vector<Dim>>& inputShapes);
0052 std::shared_ptr<RModel> GetFunctionBlock() {
0053 return function_block;
0054 }
0055 std::string GenerateModel(const std::string& filename, long read_pos = 0, long block_size = -1);
0056 std::string Generate(const std::vector<std::string>& inputPtrs);
0057 FunctionTarget GetFunctionTarget() {
0058 return fTarget;
0059 }
0060 };
0061
0062 class RFunction_Aggregate: public RFunction {
0063 protected:
0064 FunctionReducer fReducer;
0065 public:
0066 virtual ~RFunction_Aggregate() {}
0067 RFunction_Aggregate() {}
0068 RFunction_Aggregate(FunctionReducer reducer): fReducer(reducer) {
0069 fType = FunctionType::AGGREGATE;
0070 }
0071 virtual std::string GenerateModel() = 0;
0072 std::string GetFunctionName() {
0073 return fFuncName;
0074 }
0075 FunctionReducer GetFunctionReducer() {
0076 return fReducer;
0077 }
0078 std::string Generate(std::size_t num_features, const std::vector<std::string>& inputTensors);
0079 std::string Generate(std::size_t num_features, const std::string & inputTensors);
0080
0081 };
0082
0083
0084 }
0085 }
0086 }
0087
0088
0089 #endif