Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:04

0001 #ifndef TMVA_SOFIE_RMODEL_GraphIndependent
0002 #define TMVA_SOFIE_RMODEL_GraphIndependent
0003 
#include <cstddef>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "TMVA/RModel_Base.hxx"
#include "TMVA/RModel.hxx"
#include "TMVA/RFunction.hxx"
0009 
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013 
0014 class RFunction_Update;
0015 
0016 struct GraphIndependent_Init {
0017    // update blocks
0018    std::unique_ptr<RFunction_Update> edges_update_block;
0019    std::unique_ptr<RFunction_Update> nodes_update_block;
0020    std::unique_ptr<RFunction_Update> globals_update_block;
0021 
0022    std::size_t num_nodes;
0023    std::vector<std::pair<int, int>> edges;
0024 
0025    int num_node_features;
0026    int num_edge_features;
0027    int num_global_features;
0028 
0029    std::string filename;
0030 
0031    template <typename T>
0032    void createUpdateFunction(T &updateFunction)
0033    {
0034       switch (updateFunction.GetFunctionTarget()) {
0035       case FunctionTarget::EDGES: {
0036          edges_update_block.reset(new T(updateFunction));
0037          break;
0038       }
0039       case FunctionTarget::NODES: {
0040          nodes_update_block.reset(new T(updateFunction));
0041          break;
0042       }
0043       case FunctionTarget::GLOBALS: {
0044          globals_update_block.reset(new T(updateFunction));
0045          break;
0046       }
0047       default: {
0048          throw std::runtime_error(
0049             "TMVA SOFIE: Invalid Update function supplied for creating GraphIndependent function block.");
0050       }
0051       }
0052    }
0053 
0054    ~GraphIndependent_Init()
0055    {
0056       edges_update_block.reset();
0057       nodes_update_block.reset();
0058       globals_update_block.reset();
0059    }
0060 };
0061 
0062 class RModel_GraphIndependent final : public RModel_GNNBase {
0063 
0064 private:
0065    // updation function for edges, nodes & global attributes
0066    std::unique_ptr<RFunction_Update> edges_update_block;
0067    std::unique_ptr<RFunction_Update> nodes_update_block;
0068    std::unique_ptr<RFunction_Update> globals_update_block;
0069 
0070    std::size_t num_nodes;
0071    std::size_t num_edges;
0072 
0073    std::size_t num_node_features;
0074    std::size_t num_edge_features;
0075    std::size_t num_global_features;
0076 
0077 public:
0078    /**
0079        Default constructor. Needed to allow serialization of ROOT objects. See
0080        https://root.cern/manual/io_custom_classes/#restrictions-on-types-root-io-can-handle
0081    */
0082    RModel_GraphIndependent() = default;
0083    RModel_GraphIndependent(GraphIndependent_Init &graph_input_struct);
0084 
0085    // Rule of five: explicitly define move semantics, disallow copy
0086    RModel_GraphIndependent(RModel_GraphIndependent &&other);
0087    RModel_GraphIndependent &operator=(RModel_GraphIndependent &&other);
0088    RModel_GraphIndependent(const RModel_GraphIndependent &other) = delete;
0089    RModel_GraphIndependent &operator=(const RModel_GraphIndependent &other) = delete;
0090    ~RModel_GraphIndependent() final = default;
0091 
0092    void Generate() final;
0093 };
0094 
0095 } // namespace SOFIE
0096 } // namespace Experimental
0097 } // namespace TMVA
0098 
#endif // TMVA_SOFIE_RMODEL_GraphIndependent