Warning, file /include/root/TMVA/RModel_GNN.hxx was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_RMODEL_GNN
0002 #define TMVA_SOFIE_RMODEL_GNN
0003
#include <cstddef>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "TMVA/RModel_Base.hxx"
#include "TMVA/RModel.hxx"
#include "TMVA/RFunction.hxx"
0009
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013
0014 class RFunction_Update;
0015 class RFunction_Aggregate;
0016
0017 struct GNN_Init {
0018
0019
0020
0021 GNN_Init() {}
0022
0023
0024 std::unique_ptr<RFunction_Update> edges_update_block;
0025 std::unique_ptr<RFunction_Update> nodes_update_block;
0026 std::unique_ptr<RFunction_Update> globals_update_block;
0027
0028
0029 std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
0030 std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
0031 std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
0032
0033 std::size_t num_nodes;
0034 std::vector<std::pair<int, int>> edges;
0035
0036 std::size_t num_node_features;
0037 std::size_t num_edge_features;
0038 std::size_t num_global_features;
0039
0040 std::string filename;
0041
0042 template <typename T>
0043 void createUpdateFunction(T &updateFunction)
0044 {
0045 switch (updateFunction.GetFunctionTarget()) {
0046 case FunctionTarget::EDGES: {
0047 edges_update_block.reset(new T(updateFunction));
0048 break;
0049 }
0050 case FunctionTarget::NODES: {
0051 nodes_update_block.reset(new T(updateFunction));
0052 break;
0053 }
0054 case FunctionTarget::GLOBALS: {
0055 globals_update_block.reset(new T(updateFunction));
0056 break;
0057 }
0058 default: {
0059 throw std::runtime_error("TMVA SOFIE: Invalid Update function supplied for creating GNN function block.");
0060 }
0061 }
0062 }
0063
0064 template <typename T>
0065 void createAggregateFunction(T &aggFunction, FunctionRelation relation)
0066 {
0067 switch (relation) {
0068 case FunctionRelation::NODES_EDGES: {
0069 edge_node_agg_block.reset(new T(aggFunction));
0070 break;
0071 }
0072 case FunctionRelation::NODES_GLOBALS: {
0073 node_global_agg_block.reset(new T(aggFunction));
0074 break;
0075 }
0076 case FunctionRelation::EDGES_GLOBALS: {
0077 edge_global_agg_block.reset(new T(aggFunction));
0078 break;
0079 }
0080 default: {
0081 throw std::runtime_error("TMVA SOFIE: Invalid Aggregate function supplied for creating GNN function block.");
0082 }
0083 }
0084 }
0085 };
0086
0087 class RModel_GNN final : public RModel_GNNBase {
0088
0089 private:
0090
0091 std::unique_ptr<RFunction_Update> edges_update_block;
0092 std::unique_ptr<RFunction_Update> nodes_update_block;
0093 std::unique_ptr<RFunction_Update> globals_update_block;
0094
0095
0096 std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
0097 std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
0098 std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
0099
0100 std::size_t num_nodes;
0101 std::size_t num_edges;
0102
0103 std::size_t num_node_features;
0104 std::size_t num_edge_features;
0105 std::size_t num_global_features;
0106
0107 public:
0108 RModel_GNN(GNN_Init &graph_input_struct);
0109
0110 void Generate() final;
0111 };
0112
0113 }
0114 }
0115 }
0116
0117 #endif