File indexing completed on 2025-01-18 10:11:04
0001 #ifndef TMVA_SOFIE_RMODEL_GraphIndependent
0002 #define TMVA_SOFIE_RMODEL_GraphIndependent
0003
#include <cstddef>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "TMVA/RModel_Base.hxx"
#include "TMVA/RModel.hxx"
#include "TMVA/RFunction.hxx"
0009
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013
0014 class RFunction_Update;
0015
0016 struct GraphIndependent_Init {
0017
0018 std::unique_ptr<RFunction_Update> edges_update_block;
0019 std::unique_ptr<RFunction_Update> nodes_update_block;
0020 std::unique_ptr<RFunction_Update> globals_update_block;
0021
0022 std::size_t num_nodes;
0023 std::vector<std::pair<int, int>> edges;
0024
0025 int num_node_features;
0026 int num_edge_features;
0027 int num_global_features;
0028
0029 std::string filename;
0030
0031 template <typename T>
0032 void createUpdateFunction(T &updateFunction)
0033 {
0034 switch (updateFunction.GetFunctionTarget()) {
0035 case FunctionTarget::EDGES: {
0036 edges_update_block.reset(new T(updateFunction));
0037 break;
0038 }
0039 case FunctionTarget::NODES: {
0040 nodes_update_block.reset(new T(updateFunction));
0041 break;
0042 }
0043 case FunctionTarget::GLOBALS: {
0044 globals_update_block.reset(new T(updateFunction));
0045 break;
0046 }
0047 default: {
0048 throw std::runtime_error(
0049 "TMVA SOFIE: Invalid Update function supplied for creating GraphIndependent function block.");
0050 }
0051 }
0052 }
0053
0054 ~GraphIndependent_Init()
0055 {
0056 edges_update_block.reset();
0057 nodes_update_block.reset();
0058 globals_update_block.reset();
0059 }
0060 };
0061
0062 class RModel_GraphIndependent final : public RModel_GNNBase {
0063
0064 private:
0065
0066 std::unique_ptr<RFunction_Update> edges_update_block;
0067 std::unique_ptr<RFunction_Update> nodes_update_block;
0068 std::unique_ptr<RFunction_Update> globals_update_block;
0069
0070 std::size_t num_nodes;
0071 std::size_t num_edges;
0072
0073 std::size_t num_node_features;
0074 std::size_t num_edge_features;
0075 std::size_t num_global_features;
0076
0077 public:
0078
0079
0080
0081
0082 RModel_GraphIndependent() = default;
0083 RModel_GraphIndependent(GraphIndependent_Init &graph_input_struct);
0084
0085
0086 RModel_GraphIndependent(RModel_GraphIndependent &&other);
0087 RModel_GraphIndependent &operator=(RModel_GraphIndependent &&other);
0088 RModel_GraphIndependent(const RModel_GraphIndependent &other) = delete;
0089 RModel_GraphIndependent &operator=(const RModel_GraphIndependent &other) = delete;
0090 ~RModel_GraphIndependent() final = default;
0091
0092 void Generate() final;
0093 };
0094
0095 }
0096 }
0097 }
0098
0099 #endif