Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:04

0001 #ifndef TMVA_SOFIE_RFUNCTION
0002 #define TMVA_SOFIE_RFUNCTION
0003 
0004 #include "TMVA/RModel_Base.hxx"
0005 #include "TMVA/SOFIE_common.hxx"
0006 
0007 #include <memory>
0008 #include <string>
0009 
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013 
0014 class RModel;
0015 
0016 
0017 class RFunction {
0018 protected:
0019     std::string fFuncName;
0020     FunctionType fType;
0021 public:
0022     RFunction() {}
0023     virtual ~RFunction() {}
0024     FunctionType GetFunctionType() {
0025         return fType;
0026     }
0027 
0028     RFunction(std::string funcName, FunctionType type):
0029         fFuncName(UTILITY::Clean_name(funcName)),fType(type) {}
0030 
0031 };
0032 
0033 class RFunction_Update: public RFunction {
0034 protected:
0035     std::shared_ptr<RModel> function_block;
0036     FunctionTarget fTarget;
0037     GraphType fGraphType;
0038     std::vector<std::string> fInputTensors;
0039     std::vector<ROperator*> fAddlOp;  // temporary vector to store pointer that will be moved in a unique_ptr
0040 
0041 public:
0042     virtual ~RFunction_Update() {}
0043     RFunction_Update() {}
0044     RFunction_Update(FunctionTarget target, GraphType gType);
0045 
0046     virtual void AddInitializedTensors(const std::vector<std::vector<std::string>>&) {};
0047     virtual void Initialize() {};
0048     virtual void AddLayerNormalization(int, float, size_t, const std::string&,
0049                                        const std::string&, const std::string&, const std::string&) {};
0050     void AddInputTensors(const std::vector<std::vector<std::size_t>>& inputShapes);
0051     void AddInputTensors(const std::vector<std::vector<Dim>>& inputShapes);
0052     std::shared_ptr<RModel> GetFunctionBlock() {
0053         return function_block;
0054     }
0055     std::string GenerateModel(const std::string& filename, long read_pos = 0, long block_size = -1);
0056     std::string Generate(const std::vector<std::string>& inputPtrs);
0057     FunctionTarget GetFunctionTarget() {
0058         return fTarget;
0059     }
0060 };
0061 
0062 class RFunction_Aggregate: public RFunction {
0063 protected:
0064     FunctionReducer fReducer;
0065 public:
0066     virtual ~RFunction_Aggregate() {}
0067     RFunction_Aggregate() {}
0068     RFunction_Aggregate(FunctionReducer reducer): fReducer(reducer) {
0069         fType = FunctionType::AGGREGATE;
0070     }
0071     virtual std::string GenerateModel() = 0;
0072     std::string GetFunctionName() {
0073         return fFuncName;
0074     }
0075     FunctionReducer GetFunctionReducer() {
0076         return fReducer;
0077     }
0078     std::string Generate(std::size_t num_features, const std::vector<std::string>& inputTensors);
0079     std::string Generate(std::size_t num_features, const std::string & inputTensors);
0080 
0081 };
0082 
0083 
0084 }//SOFIE
0085 }//Experimental
0086 }//TMVA
0087 
0088 
0089 #endif //TMVA_SOFIE_RFUNCTION