File indexing completed on 2025-01-18 10:11:04
0001 #ifndef TMVA_SOFIE_RMODEL_GNN
0002 #define TMVA_SOFIE_RMODEL_GNN
0003
#include <cstddef>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "TMVA/RModel_Base.hxx"
#include "TMVA/RModel.hxx"
#include "TMVA/RFunction.hxx"
0009
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013
0014 class RFunction_Update;
0015 class RFunction_Aggregate;
0016
0017 struct GNN_Init {
0018
0019 std::unique_ptr<RFunction_Update> edges_update_block;
0020 std::unique_ptr<RFunction_Update> nodes_update_block;
0021 std::unique_ptr<RFunction_Update> globals_update_block;
0022
0023
0024 std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
0025 std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
0026 std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
0027
0028 std::size_t num_nodes;
0029 std::vector<std::pair<int, int>> edges;
0030
0031 std::size_t num_node_features;
0032 std::size_t num_edge_features;
0033 std::size_t num_global_features;
0034
0035 std::string filename;
0036
0037 ~GNN_Init()
0038 {
0039 edges_update_block.reset();
0040 nodes_update_block.reset();
0041 globals_update_block.reset();
0042
0043 edge_node_agg_block.reset();
0044 edge_global_agg_block.reset();
0045 node_global_agg_block.reset();
0046 }
0047
0048 template <typename T>
0049 void createUpdateFunction(T &updateFunction)
0050 {
0051 switch (updateFunction.GetFunctionTarget()) {
0052 case FunctionTarget::EDGES: {
0053 edges_update_block.reset(new T(updateFunction));
0054 break;
0055 }
0056 case FunctionTarget::NODES: {
0057 nodes_update_block.reset(new T(updateFunction));
0058 break;
0059 }
0060 case FunctionTarget::GLOBALS: {
0061 globals_update_block.reset(new T(updateFunction));
0062 break;
0063 }
0064 default: {
0065 throw std::runtime_error("TMVA SOFIE: Invalid Update function supplied for creating GNN function block.");
0066 }
0067 }
0068 }
0069
0070 template <typename T>
0071 void createAggregateFunction(T &aggFunction, FunctionRelation relation)
0072 {
0073 switch (relation) {
0074 case FunctionRelation::NODES_EDGES: {
0075 edge_node_agg_block.reset(new T(aggFunction));
0076 break;
0077 }
0078 case FunctionRelation::NODES_GLOBALS: {
0079 node_global_agg_block.reset(new T(aggFunction));
0080 break;
0081 }
0082 case FunctionRelation::EDGES_GLOBALS: {
0083 edge_global_agg_block.reset(new T(aggFunction));
0084 break;
0085 }
0086 default: {
0087 throw std::runtime_error("TMVA SOFIE: Invalid Aggregate function supplied for creating GNN function block.");
0088 }
0089 }
0090 }
0091 };
0092
0093 class RModel_GNN final : public RModel_GNNBase {
0094
0095 private:
0096
0097 std::unique_ptr<RFunction_Update> edges_update_block;
0098 std::unique_ptr<RFunction_Update> nodes_update_block;
0099 std::unique_ptr<RFunction_Update> globals_update_block;
0100
0101
0102 std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
0103 std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
0104 std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
0105
0106 std::size_t num_nodes;
0107 std::size_t num_edges;
0108
0109 std::size_t num_node_features;
0110 std::size_t num_edge_features;
0111 std::size_t num_global_features;
0112
0113 public:
0114
0115
0116
0117
0118 RModel_GNN() = default;
0119 RModel_GNN(GNN_Init &graph_input_struct);
0120
0121
0122 RModel_GNN(RModel_GNN &&other);
0123 RModel_GNN &operator=(RModel_GNN &&other);
0124 RModel_GNN(const RModel_GNN &other) = delete;
0125 RModel_GNN &operator=(const RModel_GNN &other) = delete;
0126 ~RModel_GNN() final = default;
0127
0128 void Generate() final;
0129 };
0130
0131 }
0132 }
0133 }
0134
0135 #endif