Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-16 09:08:50

0001 #ifndef TMVA_SOFIE_RMODEL_GraphIndependent
0002 #define TMVA_SOFIE_RMODEL_GraphIndependent
0003 
#include <cstddef>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "TMVA/RModel_Base.hxx"
#include "TMVA/RModel.hxx"
#include "TMVA/RFunction.hxx"
0009 
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013 
0014 class RFunction_Update;
0015 
0016 struct GraphIndependent_Init {
0017 
0018    // Explicitly define default constructor so cppyy doesn't attempt
0019    // aggregate initialization.
0020    GraphIndependent_Init() {}
0021 
0022    // update blocks
0023    std::unique_ptr<RFunction_Update> edges_update_block;
0024    std::unique_ptr<RFunction_Update> nodes_update_block;
0025    std::unique_ptr<RFunction_Update> globals_update_block;
0026 
0027    std::size_t num_nodes;
0028    std::vector<std::pair<int, int>> edges;
0029 
0030    int num_node_features;
0031    int num_edge_features;
0032    int num_global_features;
0033 
0034    std::string filename;
0035 
0036    template <typename T>
0037    void createUpdateFunction(T &updateFunction)
0038    {
0039       switch (updateFunction.GetFunctionTarget()) {
0040       case FunctionTarget::EDGES: {
0041          edges_update_block.reset(new T(updateFunction));
0042          break;
0043       }
0044       case FunctionTarget::NODES: {
0045          nodes_update_block.reset(new T(updateFunction));
0046          break;
0047       }
0048       case FunctionTarget::GLOBALS: {
0049          globals_update_block.reset(new T(updateFunction));
0050          break;
0051       }
0052       default: {
0053          throw std::runtime_error(
0054             "TMVA SOFIE: Invalid Update function supplied for creating GraphIndependent function block.");
0055       }
0056       }
0057    }
0058 };
0059 
0060 class RModel_GraphIndependent final : public RModel_GNNBase {
0061 
0062 private:
0063    // updation function for edges, nodes & global attributes
0064    std::unique_ptr<RFunction_Update> edges_update_block;
0065    std::unique_ptr<RFunction_Update> nodes_update_block;
0066    std::unique_ptr<RFunction_Update> globals_update_block;
0067 
0068    std::size_t num_nodes;
0069    std::size_t num_edges;
0070 
0071    std::size_t num_node_features;
0072    std::size_t num_edge_features;
0073    std::size_t num_global_features;
0074 
0075 public:
0076    RModel_GraphIndependent(GraphIndependent_Init &graph_input_struct);
0077 
0078    void Generate() final;
0079 };
0080 
0081 } // namespace SOFIE
0082 } // namespace Experimental
0083 } // namespace TMVA
0084 
#endif // TMVA_SOFIE_RMODEL_GraphIndependent