Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:04

0001 #ifndef TMVA_SOFIE_RMODEL_GNN
0002 #define TMVA_SOFIE_RMODEL_GNN
0003 
#include <cstddef>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "TMVA/RModel_Base.hxx"
#include "TMVA/RModel.hxx"
#include "TMVA/RFunction.hxx"
0009 
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013 
0014 class RFunction_Update;
0015 class RFunction_Aggregate;
0016 
0017 struct GNN_Init {
0018    // update blocks
0019    std::unique_ptr<RFunction_Update> edges_update_block;
0020    std::unique_ptr<RFunction_Update> nodes_update_block;
0021    std::unique_ptr<RFunction_Update> globals_update_block;
0022 
0023    // aggregation blocks
0024    std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
0025    std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
0026    std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
0027 
0028    std::size_t num_nodes;
0029    std::vector<std::pair<int, int>> edges;
0030 
0031    std::size_t num_node_features;
0032    std::size_t num_edge_features;
0033    std::size_t num_global_features;
0034 
0035    std::string filename;
0036 
0037    ~GNN_Init()
0038    {
0039       edges_update_block.reset();
0040       nodes_update_block.reset();
0041       globals_update_block.reset();
0042 
0043       edge_node_agg_block.reset();
0044       edge_global_agg_block.reset();
0045       node_global_agg_block.reset();
0046    }
0047 
0048    template <typename T>
0049    void createUpdateFunction(T &updateFunction)
0050    {
0051       switch (updateFunction.GetFunctionTarget()) {
0052       case FunctionTarget::EDGES: {
0053          edges_update_block.reset(new T(updateFunction));
0054          break;
0055       }
0056       case FunctionTarget::NODES: {
0057          nodes_update_block.reset(new T(updateFunction));
0058          break;
0059       }
0060       case FunctionTarget::GLOBALS: {
0061          globals_update_block.reset(new T(updateFunction));
0062          break;
0063       }
0064       default: {
0065          throw std::runtime_error("TMVA SOFIE: Invalid Update function supplied for creating GNN function block.");
0066       }
0067       }
0068    }
0069 
0070    template <typename T>
0071    void createAggregateFunction(T &aggFunction, FunctionRelation relation)
0072    {
0073       switch (relation) {
0074       case FunctionRelation::NODES_EDGES: {
0075          edge_node_agg_block.reset(new T(aggFunction));
0076          break;
0077       }
0078       case FunctionRelation::NODES_GLOBALS: {
0079          node_global_agg_block.reset(new T(aggFunction));
0080          break;
0081       }
0082       case FunctionRelation::EDGES_GLOBALS: {
0083          edge_global_agg_block.reset(new T(aggFunction));
0084          break;
0085       }
0086       default: {
0087          throw std::runtime_error("TMVA SOFIE: Invalid Aggregate function supplied for creating GNN function block.");
0088       }
0089       }
0090    }
0091 };
0092 
0093 class RModel_GNN final : public RModel_GNNBase {
0094 
0095 private:
0096    // update function for edges, nodes & global attributes
0097    std::unique_ptr<RFunction_Update> edges_update_block;
0098    std::unique_ptr<RFunction_Update> nodes_update_block;
0099    std::unique_ptr<RFunction_Update> globals_update_block;
0100 
0101    // aggregation function for edges, nodes & global attributes
0102    std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
0103    std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
0104    std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
0105 
0106    std::size_t num_nodes; // maximum number of nodes
0107    std::size_t num_edges; // maximum number of edges
0108 
0109    std::size_t num_node_features;
0110    std::size_t num_edge_features;
0111    std::size_t num_global_features;
0112 
0113 public:
0114    /**
0115        Default constructor. Needed to allow serialization of ROOT objects. See
0116        https://root.cern/manual/io_custom_classes/#restrictions-on-types-root-io-can-handle
0117    */
0118    RModel_GNN() = default;
0119    RModel_GNN(GNN_Init &graph_input_struct);
0120 
0121    // Rule of five: explicitly define move semantics, disallow copy
0122    RModel_GNN(RModel_GNN &&other);
0123    RModel_GNN &operator=(RModel_GNN &&other);
0124    RModel_GNN(const RModel_GNN &other) = delete;
0125    RModel_GNN &operator=(const RModel_GNN &other) = delete;
0126    ~RModel_GNN() final = default;
0127 
0128    void Generate() final;
0129 };
0130 
0131 } // namespace SOFIE
0132 } // namespace Experimental
0133 } // namespace TMVA
0134 
0135 #endif // TMVA_SOFIE_RMODEL_GNN