Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:04

0001 #ifndef TMVA_SOFIE_RMODEL_BASE
0002 #define TMVA_SOFIE_RMODEL_BASE
0003 
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TBuffer.h"
0017 
0018 namespace TMVA {
0019 namespace Experimental {
0020 namespace SOFIE {
0021 
/// Generation options expressed as bit flags.
/// Every non-default enumerator is a distinct power of two, so flags can be
/// OR-ed together (see the operator| overloads declared below).
enum class Options {
   kDefault = 0,
   kNoSession = 1 << 0,            // 0x01
   kNoWeightFile = 1 << 1,         // 0x02
   kRootBinaryWeightFile = 1 << 2, // 0x04
   kGNN = 1 << 3,                  // 0x08
   kGNNComponent = 1 << 4,         // 0x10
};

/// Kind of weight file associated with a model (cf. RModel_Base::fWeightFile).
enum class WeightFileType { None = 0, RootBinary = 1, Text = 2 };

// Both overloads yield the raw underlying integer rather than an Options value.
// The definitions are out of view; presumably they perform a bitwise OR of the
// underlying values. Verify against the implementation file.
std::underlying_type_t<Options> operator|(Options lhs, Options rhs);
std::underlying_type_t<Options> operator|(std::underlying_type_t<Options> mask, Options flag);
0035 
0036 class RModel_Base {
0037 
0038 protected:
0039    std::string fFileName;  // file name of original model file for identification
0040    std::string fParseTime; // UTC date and time string at parsing
0041 
0042    WeightFileType fWeightFile = WeightFileType::Text;
0043 
0044    std::unordered_set<std::string> fNeededBlasRoutines;
0045 
0046    const std::unordered_set<std::string> fAllowedStdLib = {"vector", "algorithm", "cmath"};
0047    std::unordered_set<std::string> fNeededStdLib = {"vector"};
0048    std::unordered_set<std::string> fCustomOpHeaders;
0049 
0050    std::string fName = "UnnamedModel";
0051    std::string fGC; // generated code
0052    bool fUseWeightFile = true;
0053    bool fUseSession = true;
0054    bool fIsGNN = false;
0055    bool fIsGNNComponent = false;
0056 
0057 public:
0058    /**
0059        Default constructor. Needed to allow serialization of ROOT objects. See
0060        https://root.cern/manual/io_custom_classes/#restrictions-on-types-root-io-can-handle
0061    */
0062    RModel_Base() = default;
0063 
0064    RModel_Base(std::string name, std::string parsedtime);
0065 
0066    // For GNN Functions usage
0067    RModel_Base(std::string function_name) : fName(function_name) {}
0068 
0069    void AddBlasRoutines(std::vector<std::string> routines)
0070    {
0071       for (auto &routine : routines) {
0072          fNeededBlasRoutines.insert(routine);
0073       }
0074    }
0075    void AddNeededStdLib(std::string libname)
0076    {
0077       if (fAllowedStdLib.find(libname) != fAllowedStdLib.end()) {
0078          fNeededStdLib.insert(libname);
0079       }
0080    }
0081    void AddNeededCustomHeader(std::string filename) { fCustomOpHeaders.insert(filename); }
0082    void GenerateHeaderInfo(std::string &hgname);
0083    void PrintGenerated() { std::cout << fGC; }
0084 
0085    std::string ReturnGenerated() { return fGC; }
0086    void OutputGenerated(std::string filename = "", bool append = false);
0087    void SetFilename(std::string filename) { fName = filename; }
0088    std::string GetFilename() { return fName; }
0089 };
0090 
/// Graph-network variant of a GNN model.
enum class GraphType : int {
   INVALID = 0,
   GNN = 1,
   GraphIndependent = 2,
};

/// Kind of a GNN function: update or aggregate.
enum class FunctionType : int {
   UPDATE = 0,
   AGGREGATE = 1,
};

/// Graph component a function operates on.
enum class FunctionTarget : int {
   INVALID = 0,
   NODES = 1,
   EDGES = 2,
   GLOBALS = 3,
};

/// Reduction applied by an aggregate function.
enum class FunctionReducer : int {
   INVALID = 0,
   SUM = 1,
   MEAN = 2,
};

/// Pair of graph components related by a function.
enum class FunctionRelation : int {
   INVALID = 0,
   NODES_EDGES = 1,
   NODES_GLOBALS = 2,
   EDGES_GLOBALS = 3,
};
0097 
0098 class RModel_GNNBase : public RModel_Base {
0099 public:
0100    /**
0101        Default constructor. Needed to allow serialization of ROOT objects. See
0102        https://root.cern/manual/io_custom_classes/#restrictions-on-types-root-io-can-handle
0103    */
0104    RModel_GNNBase() = default;
0105    virtual void Generate() = 0;
0106    virtual ~RModel_GNNBase() = default;
0107 };
0108 
0109 } // namespace SOFIE
0110 } // namespace Experimental
0111 } // namespace TMVA
0112 
0113 #endif // TMVA_SOFIE_RMODEL_BASE