File indexing completed on 2025-09-16 09:08:50
0001 #ifndef TMVA_SOFIE_RMODEL_GraphIndependent
0002 #define TMVA_SOFIE_RMODEL_GraphIndependent
0003
0004 #include <ctime>
0005
0006 #include "TMVA/RModel_Base.hxx"
0007 #include "TMVA/RModel.hxx"
0008 #include "TMVA/RFunction.hxx"
0009
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013
0014 class RFunction_Update;
0015
0016 struct GraphIndependent_Init {
0017
0018
0019
0020 GraphIndependent_Init() {}
0021
0022
0023 std::unique_ptr<RFunction_Update> edges_update_block;
0024 std::unique_ptr<RFunction_Update> nodes_update_block;
0025 std::unique_ptr<RFunction_Update> globals_update_block;
0026
0027 std::size_t num_nodes;
0028 std::vector<std::pair<int, int>> edges;
0029
0030 int num_node_features;
0031 int num_edge_features;
0032 int num_global_features;
0033
0034 std::string filename;
0035
0036 template <typename T>
0037 void createUpdateFunction(T &updateFunction)
0038 {
0039 switch (updateFunction.GetFunctionTarget()) {
0040 case FunctionTarget::EDGES: {
0041 edges_update_block.reset(new T(updateFunction));
0042 break;
0043 }
0044 case FunctionTarget::NODES: {
0045 nodes_update_block.reset(new T(updateFunction));
0046 break;
0047 }
0048 case FunctionTarget::GLOBALS: {
0049 globals_update_block.reset(new T(updateFunction));
0050 break;
0051 }
0052 default: {
0053 throw std::runtime_error(
0054 "TMVA SOFIE: Invalid Update function supplied for creating GraphIndependent function block.");
0055 }
0056 }
0057 }
0058 };
0059
0060 class RModel_GraphIndependent final : public RModel_GNNBase {
0061
0062 private:
0063
0064 std::unique_ptr<RFunction_Update> edges_update_block;
0065 std::unique_ptr<RFunction_Update> nodes_update_block;
0066 std::unique_ptr<RFunction_Update> globals_update_block;
0067
0068 std::size_t num_nodes;
0069 std::size_t num_edges;
0070
0071 std::size_t num_node_features;
0072 std::size_t num_edge_features;
0073 std::size_t num_global_features;
0074
0075 public:
0076 RModel_GraphIndependent(GraphIndependent_Init &graph_input_struct);
0077
0078 void Generate() final;
0079 };
0080
0081 }
0082 }
0083 }
0084
0085 #endif