Back to home page

EIC code displayed by LXR

 
 

    


Warning: the file /include/root/TMVA/RModel_Base.hxx was not indexed, or was modified since the last indexing (in which case cross-reference links may be missing, inaccurate, or erroneous).

0001 #ifndef TMVA_SOFIE_RMODEL_BASE
0002 #define TMVA_SOFIE_RMODEL_BASE
0003 
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TBuffer.h"
0017 
0018 namespace TMVA {
0019 namespace Experimental {
0020 namespace SOFIE {
0021 
/// Code-generation option flags. Each enumerator except kDefault occupies a
/// distinct bit, so flags can be combined with the operator| overloads
/// declared in this header (which yield the underlying integer type).
enum class Options {
   kDefault = 0,                   // no flag set
   kNoSession = 1 << 0,            // 0x01
   kNoWeightFile = 1 << 1,         // 0x02
   kRootBinaryWeightFile = 1 << 2, // 0x04
   kGNN = 1 << 3,                  // 0x08
   kGNNComponent = 1 << 4          // 0x10
};
0030 
0031 // Optimization levels inspired by ONNXRuntime.
0032 // We only get Operator Fusion with the Basic, and
0033 // memory reuse with Extended. kExtended is enabled
0034 // by default
0035 enum class OptimizationLevel {
0036    kBasic = 0x0,
0037    kExtended = 0x1,
0038 };
0039 
/// Selects the weight-file format: none, ROOT binary, or plain text.
/// (Presumably the file the generated code loads its weights from; confirm in
/// the code-generation implementation.)
enum class WeightFileType {
   None = 0,
   RootBinary = 1,
   Text = 2
};
0041 
0042 std::underlying_type_t<Options> operator|(Options opA, Options opB);
0043 std::underlying_type_t<Options> operator|(std::underlying_type_t<Options> opA, Options opB);
0044 
0045 class RModel_Base {
0046 
0047 protected:
0048    std::string fFileName;  // file name of original model file for identification
0049    std::string fParseTime; // UTC date and time string at parsing
0050 
0051    WeightFileType fWeightFile = WeightFileType::Text;
0052 
0053    std::unordered_set<std::string> fNeededBlasRoutines;
0054 
0055    std::unordered_set<std::string> fNeededStdLib = {"vector"};
0056    std::unordered_set<std::string> fCustomOpHeaders;
0057 
0058    std::string fName = "UnnamedModel";
0059    std::string fGC; // generated code
0060    bool fUseWeightFile = true;
0061    bool fUseSession = true;
0062    bool fIsGNN = false;
0063    bool fIsGNNComponent = false;
0064 
0065 public:
0066    /**
0067        Default constructor. Needed to allow serialization of ROOT objects. See
0068        https://root.cern/manual/io_custom_classes/#restrictions-on-types-root-io-can-handle
0069    */
0070    RModel_Base() = default;
0071 
0072    RModel_Base(std::string name, std::string parsedtime);
0073 
0074    // For GNN Functions usage
0075    RModel_Base(std::string function_name) : fName(function_name) {}
0076 
0077    void AddBlasRoutines(std::vector<std::string> routines)
0078    {
0079       for (auto &routine : routines) {
0080          fNeededBlasRoutines.insert(routine);
0081       }
0082    }
0083    void AddNeededStdLib(std::string libname)
0084    {
0085       static const std::unordered_set<std::string> allowedStdLib = {"vector", "algorithm", "cmath", "memory", "span"};
0086       if (allowedStdLib.find(libname) != allowedStdLib.end()) {
0087          fNeededStdLib.insert(libname);
0088       }
0089    }
0090    void AddNeededCustomHeader(std::string filename)
0091    {
0092        fCustomOpHeaders.insert(filename);
0093    }
0094    void GenerateHeaderInfo(std::string &hgname);
0095    void PrintGenerated() { std::cout << fGC; }
0096 
0097    std::string ReturnGenerated() { return fGC; }
0098    void OutputGenerated(std::string filename = "", bool append = false);
0099    void SetFilename(std::string filename) { fName = filename; }
0100    std::string GetFilename() { return fName; }
0101    const std::string & GetName() const { return fName;}
0102 };
0103 
/// Kind of graph-network model; INVALID marks an unset value.
enum class GraphType {
   INVALID = 0,
   GNN = 1,
   GraphIndependent = 2
};

/// Role of a graph-network function.
enum class FunctionType {
   UPDATE = 0,
   AGGREGATE = 1
};

/// Graph component a function acts on; INVALID marks an unset value.
enum class FunctionTarget {
   INVALID = 0,
   NODES = 1,
   EDGES = 2,
   GLOBALS = 3
};

/// Reduction applied by an aggregate function; INVALID marks an unset value.
enum class FunctionReducer {
   INVALID = 0,
   SUM = 1,
   MEAN = 2
};

/// Pair of graph components a function relates; INVALID marks an unset value.
enum class FunctionRelation {
   INVALID = 0,
   NODES_EDGES = 1,
   NODES_GLOBALS = 2,
   EDGES_GLOBALS = 3
};
0110 
0111 class RModel_GNNBase : public RModel_Base {
0112 public:
0113    virtual void Generate() = 0;
0114    virtual ~RModel_GNNBase() = default;
0115 };
0116 
0117 } // namespace SOFIE
0118 } // namespace Experimental
0119 } // namespace TMVA
0120 
0121 #endif // TMVA_SOFIE_RMODEL_BASE