Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-19 09:06:54

0001 // -*- C++ -*-
0002 #ifndef RIVET_RivetLWTNN_HH
0003 #define RIVET_RivetLWTNN_HH
0004 
0005 #include "Rivet/Tools/RivetPaths.hh"
0006 #include "lwtnn/LightweightNeuralNetwork.hh"
0007 #include "lwtnn/LightweightGraph.hh"
0008 #include "lwtnn/Exceptions.hh"
0009 #include "lwtnn/parse_json.hh"
0010 #include <fstream>
0011 
0012 namespace Rivet {
0013   using namespace std;
0014 
0015 
0016   /// Read a LWT DNN config from the JSON path
0017   lwt::JSONConfig readLWTNNConfig(const string& jsonpath) {
0018     ifstream input;
0019     try {
0020       // Note: a failed read here may fail quietly, and cause the filestream to
0021       // go bad, making it look like the hepmc event-read has failed.
0022       input = std::ifstream(jsonpath);
0023       return lwt::parse_json(input);
0024     } catch (lwt::LightweightNNException &e) {
0025       input.close();
0026       throw IOError("Error loading LWTNN JSON config");
0027     }
0028   }
0029 
0030 
0031   /// @brief Read a LWT Graph config from the JSON path
0032   ///
0033   /// Note graph here means "not linear" rather than a GNN
0034   lwt::GraphConfig readLWTNNGraphConfig(const string& jsonpath) {
0035     ifstream input;
0036     try {
0037       // Note: a failed read here may fail quietly, and cause the filestream to
0038       // go bad, making it look like the hepmc event-read has failed.
0039       input = std::ifstream(jsonpath);
0040       return lwt::parse_json_graph(input);
0041     } catch (lwt::LightweightNNException &e) {
0042       input.close();
0043       throw IOError("Error loading LWTNN JSON config");
0044     }
0045   }
0046 
0047   /// Make a LWT DNN from the JSON config object
0048   std::unique_ptr<lwt::LightweightNeuralNetwork> mkLWTNN(const lwt::JSONConfig& jsonconfig) {
0049     try {
0050       return std::make_unique<lwt::LightweightNeuralNetwork>(jsonconfig.inputs, jsonconfig.layers, jsonconfig.outputs);
0051     } catch (lwt::LightweightNNException &e) {
0052       throw IOError("Error initialising from LWTNN JSON config");
0053     }
0054   }
0055 
0056   /// @brief Make a LWT Graph from the JSON config object
0057   ///
0058   /// Note graph here means "not linear" rather than a GNN
0059   std::unique_ptr<lwt::LightweightGraph> mkGraphLWTNN(const lwt::GraphConfig& graphconfig) {
0060     try {
0061       return std::make_unique<lwt::LightweightGraph>(graphconfig);
0062     } catch (lwt::LightweightNNException &e) {
0063       throw IOError("Error initialising from LWTNN JSON config");
0064     }
0065   }
0066 
0067 
0068   /// Make a LWT DNN direct from the JSON config path
0069   std::unique_ptr<lwt::LightweightNeuralNetwork> mkLWTNN(const string& jsonpath) {
0070     lwt::JSONConfig config = readLWTNNConfig(jsonpath);
0071     return mkLWTNN(config);
0072   }
0073 
0074   /// @brief Make a LWT graph direct from the JSON config path
0075   ///
0076   /// Note graph here means "not linear" rather than a GNN
0077   std::unique_ptr<lwt::LightweightGraph> mkGraphLWTNN(const string& jsonpath) {
0078     lwt::GraphConfig config = readLWTNNGraphConfig(jsonpath);
0079     return mkGraphLWTNN(config);
0080   }
0081 
0082 }
0083 
0084 #endif