File indexing completed on 2025-01-18 10:11:04
0001 #ifndef TMVA_SOFIE_RMODEL_BASE
0002 #define TMVA_SOFIE_RMODEL_BASE
0003
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TBuffer.h"
0017
0018 namespace TMVA {
0019 namespace Experimental {
0020 namespace SOFIE {
0021
/// Bit flags selecting SOFIE code-generation options.
///
/// Every non-default enumerator occupies its own bit, so several flags can be
/// combined with the operator| overloads declared in this header.
enum class Options {
   kDefault = 0,                   // no flag set
   kNoSession = 1 << 0,            // 0x01
   kNoWeightFile = 1 << 1,         // 0x02
   kRootBinaryWeightFile = 1 << 2, // 0x04
   kGNN = 1 << 3,                  // 0x08
   kGNNComponent = 1 << 4,         // 0x10
};
0030
/// Format in which model weights are stored.
/// NOTE(review): None presumably means no separate weight file is written — confirm.
enum class WeightFileType {
   None = 0,
   RootBinary = 1,
   Text = 2
};
0032
0033 std::underlying_type_t<Options> operator|(Options opA, Options opB);
0034 std::underlying_type_t<Options> operator|(std::underlying_type_t<Options> opA, Options opB);
0035
0036 class RModel_Base {
0037
0038 protected:
0039 std::string fFileName;
0040 std::string fParseTime;
0041
0042 WeightFileType fWeightFile = WeightFileType::Text;
0043
0044 std::unordered_set<std::string> fNeededBlasRoutines;
0045
0046 const std::unordered_set<std::string> fAllowedStdLib = {"vector", "algorithm", "cmath"};
0047 std::unordered_set<std::string> fNeededStdLib = {"vector"};
0048 std::unordered_set<std::string> fCustomOpHeaders;
0049
0050 std::string fName = "UnnamedModel";
0051 std::string fGC;
0052 bool fUseWeightFile = true;
0053 bool fUseSession = true;
0054 bool fIsGNN = false;
0055 bool fIsGNNComponent = false;
0056
0057 public:
0058
0059
0060
0061
0062 RModel_Base() = default;
0063
0064 RModel_Base(std::string name, std::string parsedtime);
0065
0066
0067 RModel_Base(std::string function_name) : fName(function_name) {}
0068
0069 void AddBlasRoutines(std::vector<std::string> routines)
0070 {
0071 for (auto &routine : routines) {
0072 fNeededBlasRoutines.insert(routine);
0073 }
0074 }
0075 void AddNeededStdLib(std::string libname)
0076 {
0077 if (fAllowedStdLib.find(libname) != fAllowedStdLib.end()) {
0078 fNeededStdLib.insert(libname);
0079 }
0080 }
0081 void AddNeededCustomHeader(std::string filename) { fCustomOpHeaders.insert(filename); }
0082 void GenerateHeaderInfo(std::string &hgname);
0083 void PrintGenerated() { std::cout << fGC; }
0084
0085 std::string ReturnGenerated() { return fGC; }
0086 void OutputGenerated(std::string filename = "", bool append = false);
0087 void SetFilename(std::string filename) { fName = filename; }
0088 std::string GetFilename() { return fName; }
0089 };
0090
/// Kind of graph model.
/// NOTE(review): presumably distinguishes full GNN models from graph-independent blocks — confirm.
enum class GraphType {
   INVALID = 0,
   GNN = 1,
   GraphIndependent = 2
};

/// Role of a graph function: an update step or an aggregation step.
enum class FunctionType {
   UPDATE = 0,
   AGGREGATE = 1
};

/// Graph component a function operates on.
enum class FunctionTarget {
   INVALID = 0,
   NODES = 1,
   EDGES = 2,
   GLOBALS = 3
};

/// Reduction applied when combining values (sum or mean).
enum class FunctionReducer {
   INVALID = 0,
   SUM = 1,
   MEAN = 2
};

/// Pair of graph components a function relates.
enum class FunctionRelation {
   INVALID = 0,
   NODES_EDGES = 1,
   NODES_GLOBALS = 2,
   EDGES_GLOBALS = 3
};
0097
0098 class RModel_GNNBase : public RModel_Base {
0099 public:
0100
0101
0102
0103
0104 RModel_GNNBase() = default;
0105 virtual void Generate() = 0;
0106 virtual ~RModel_GNNBase() = default;
0107 };
0108
0109 }
0110 }
0111 }
0112
0113 #endif