Back to home page

EIC code displayed by LXR

 
 

    


Warning: file /include/root/TMVA/RModel_GNN.hxx was not indexed or was modified since the last indexation (in which case cross-reference links may be missing, inaccurate, or erroneous).

0001 #ifndef TMVA_SOFIE_RMODEL_GNN
0002 #define TMVA_SOFIE_RMODEL_GNN
0003 
#include <cstddef>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "TMVA/RModel_Base.hxx"
#include "TMVA/RModel.hxx"
#include "TMVA/RFunction.hxx"
0009 
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013 
0014 class RFunction_Update;
0015 class RFunction_Aggregate;
0016 
0017 struct GNN_Init {
0018 
0019    // Explicitly define default constructor so cppyy doesn't attempt
0020    // aggregate initialization.
0021    GNN_Init() {}
0022 
0023    // update blocks
0024    std::unique_ptr<RFunction_Update> edges_update_block;
0025    std::unique_ptr<RFunction_Update> nodes_update_block;
0026    std::unique_ptr<RFunction_Update> globals_update_block;
0027 
0028    // aggregation blocks
0029    std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
0030    std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
0031    std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
0032 
0033    std::size_t num_nodes;
0034    std::vector<std::pair<int, int>> edges;
0035 
0036    std::size_t num_node_features;
0037    std::size_t num_edge_features;
0038    std::size_t num_global_features;
0039 
0040    std::string filename;
0041 
0042    template <typename T>
0043    void createUpdateFunction(T &updateFunction)
0044    {
0045       switch (updateFunction.GetFunctionTarget()) {
0046       case FunctionTarget::EDGES: {
0047          edges_update_block.reset(new T(updateFunction));
0048          break;
0049       }
0050       case FunctionTarget::NODES: {
0051          nodes_update_block.reset(new T(updateFunction));
0052          break;
0053       }
0054       case FunctionTarget::GLOBALS: {
0055          globals_update_block.reset(new T(updateFunction));
0056          break;
0057       }
0058       default: {
0059          throw std::runtime_error("TMVA SOFIE: Invalid Update function supplied for creating GNN function block.");
0060       }
0061       }
0062    }
0063 
0064    template <typename T>
0065    void createAggregateFunction(T &aggFunction, FunctionRelation relation)
0066    {
0067       switch (relation) {
0068       case FunctionRelation::NODES_EDGES: {
0069          edge_node_agg_block.reset(new T(aggFunction));
0070          break;
0071       }
0072       case FunctionRelation::NODES_GLOBALS: {
0073          node_global_agg_block.reset(new T(aggFunction));
0074          break;
0075       }
0076       case FunctionRelation::EDGES_GLOBALS: {
0077          edge_global_agg_block.reset(new T(aggFunction));
0078          break;
0079       }
0080       default: {
0081          throw std::runtime_error("TMVA SOFIE: Invalid Aggregate function supplied for creating GNN function block.");
0082       }
0083       }
0084    }
0085 };
0086 
0087 class RModel_GNN final : public RModel_GNNBase {
0088 
0089 private:
0090    // update function for edges, nodes & global attributes
0091    std::unique_ptr<RFunction_Update> edges_update_block;
0092    std::unique_ptr<RFunction_Update> nodes_update_block;
0093    std::unique_ptr<RFunction_Update> globals_update_block;
0094 
0095    // aggregation function for edges, nodes & global attributes
0096    std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
0097    std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
0098    std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
0099 
0100    std::size_t num_nodes; // maximum number of nodes
0101    std::size_t num_edges; // maximum number of edges
0102 
0103    std::size_t num_node_features;
0104    std::size_t num_edge_features;
0105    std::size_t num_global_features;
0106 
0107 public:
0108    RModel_GNN(GNN_Init &graph_input_struct);
0109 
0110    void Generate() final;
0111 };
0112 
0113 } // namespace SOFIE
0114 } // namespace Experimental
0115 } // namespace TMVA
0116 
0117 #endif // TMVA_SOFIE_RMODEL_GNN