Warning: file /include/root/TMVA/RModel_Base.hxx was not indexed,
or was modified since the last indexation (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001 #ifndef TMVA_SOFIE_RMODEL_BASE
0002 #define TMVA_SOFIE_RMODEL_BASE
0003
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TBuffer.h"
0017
0018 namespace TMVA {
0019 namespace Experimental {
0020 namespace SOFIE {
0021
/// Code-generation option flags.
/// Each non-default enumerator is a distinct power of two, so flags can be
/// combined through the operator| overloads declared below.
enum class Options {
   kDefault = 0,                   ///< no flag set
   kNoSession = 1 << 0,            ///< 0x01
   kNoWeightFile = 1 << 1,         ///< 0x02
   kRootBinaryWeightFile = 1 << 2, ///< 0x04
   kGNN = 1 << 3,                  ///< 0x08
   kGNNComponent = 1 << 4          ///< 0x10
};

/// Optimization level selector (kBasic = 0, kExtended = 1).
enum class OptimizationLevel {
   kBasic = 0,
   kExtended = 1
};

/// Kind of weight file; RModel_Base keeps one in fWeightFile (Text by default).
enum class WeightFileType {
   None,
   RootBinary,
   Text
};

/// Combine two Options into their underlying integer representation.
/// Declared here, defined out of line; presumably a bitwise OR of the flag values.
auto operator|(Options opA, Options opB) -> std::underlying_type_t<Options>;

/// Combine an already-combined flag value with one more Options flag.
/// Declared here, defined out of line; presumably a bitwise OR of the flag values.
auto operator|(std::underlying_type_t<Options> opA, Options opB) -> std::underlying_type_t<Options>;
0044
0045 class RModel_Base {
0046
0047 protected:
0048 std::string fFileName;
0049 std::string fParseTime;
0050
0051 WeightFileType fWeightFile = WeightFileType::Text;
0052
0053 std::unordered_set<std::string> fNeededBlasRoutines;
0054
0055 std::unordered_set<std::string> fNeededStdLib = {"vector"};
0056 std::unordered_set<std::string> fCustomOpHeaders;
0057
0058 std::string fName = "UnnamedModel";
0059 std::string fGC;
0060 bool fUseWeightFile = true;
0061 bool fUseSession = true;
0062 bool fIsGNN = false;
0063 bool fIsGNNComponent = false;
0064
0065 public:
0066
0067
0068
0069
0070 RModel_Base() = default;
0071
0072 RModel_Base(std::string name, std::string parsedtime);
0073
0074
0075 RModel_Base(std::string function_name) : fName(function_name) {}
0076
0077 void AddBlasRoutines(std::vector<std::string> routines)
0078 {
0079 for (auto &routine : routines) {
0080 fNeededBlasRoutines.insert(routine);
0081 }
0082 }
0083 void AddNeededStdLib(std::string libname)
0084 {
0085 static const std::unordered_set<std::string> allowedStdLib = {"vector", "algorithm", "cmath", "memory", "span"};
0086 if (allowedStdLib.find(libname) != allowedStdLib.end()) {
0087 fNeededStdLib.insert(libname);
0088 }
0089 }
0090 void AddNeededCustomHeader(std::string filename)
0091 {
0092 fCustomOpHeaders.insert(filename);
0093 }
0094 void GenerateHeaderInfo(std::string &hgname);
0095 void PrintGenerated() { std::cout << fGC; }
0096
0097 std::string ReturnGenerated() { return fGC; }
0098 void OutputGenerated(std::string filename = "", bool append = false);
0099 void SetFilename(std::string filename) { fName = filename; }
0100 std::string GetFilename() { return fName; }
0101 const std::string & GetName() const { return fName;}
0102 };
0103
/// Category of a graph model.
enum class GraphType {
   INVALID = 0,
   GNN = 1,
   GraphIndependent = 2
};

/// Role of a graph-model function.
enum class FunctionType {
   UPDATE = 0,
   AGGREGATE = 1
};

/// Graph component a function acts on.
enum class FunctionTarget {
   INVALID = 0,
   NODES = 1,
   EDGES = 2,
   GLOBALS = 3
};

/// Reduction applied by an aggregate function.
enum class FunctionReducer {
   INVALID = 0,
   SUM = 1,
   MEAN = 2
};

/// Pair of graph components a function relates.
enum class FunctionRelation {
   INVALID = 0,
   NODES_EDGES = 1,
   NODES_GLOBALS = 2,
   EDGES_GLOBALS = 3
};
0110
0111 class RModel_GNNBase : public RModel_Base {
0112 public:
0113 virtual void Generate() = 0;
0114 virtual ~RModel_GNNBase() = default;
0115 };
0116
0117 }
0118 }
0119 }
0120
0121 #endif