File indexing completed on 2025-04-19 09:06:54
0001
0002 #ifndef RIVET_RivetLWTNN_HH
0003 #define RIVET_RivetLWTNN_HH
0004
0005 #include "Rivet/Tools/RivetPaths.hh"
0006 #include "lwtnn/LightweightNeuralNetwork.hh"
0007 #include "lwtnn/LightweightGraph.hh"
0008 #include "lwtnn/Exceptions.hh"
0009 #include "lwtnn/parse_json.hh"
0010 #include <fstream>
0011
0012 namespace Rivet {
0013 using namespace std;
0014
0015
0016
0017 lwt::JSONConfig readLWTNNConfig(const string& jsonpath) {
0018 ifstream input;
0019 try {
0020
0021
0022 input = std::ifstream(jsonpath);
0023 return lwt::parse_json(input);
0024 } catch (lwt::LightweightNNException &e) {
0025 input.close();
0026 throw IOError("Error loading LWTNN JSON config");
0027 }
0028 }
0029
0030
0031
0032
0033
0034 lwt::GraphConfig readLWTNNGraphConfig(const string& jsonpath) {
0035 ifstream input;
0036 try {
0037
0038
0039 input = std::ifstream(jsonpath);
0040 return lwt::parse_json_graph(input);
0041 } catch (lwt::LightweightNNException &e) {
0042 input.close();
0043 throw IOError("Error loading LWTNN JSON config");
0044 }
0045 }
0046
0047
0048 std::unique_ptr<lwt::LightweightNeuralNetwork> mkLWTNN(const lwt::JSONConfig& jsonconfig) {
0049 try {
0050 return std::make_unique<lwt::LightweightNeuralNetwork>(jsonconfig.inputs, jsonconfig.layers, jsonconfig.outputs);
0051 } catch (lwt::LightweightNNException &e) {
0052 throw IOError("Error initialising from LWTNN JSON config");
0053 }
0054 }
0055
0056
0057
0058
0059 std::unique_ptr<lwt::LightweightGraph> mkGraphLWTNN(const lwt::GraphConfig& graphconfig) {
0060 try {
0061 return std::make_unique<lwt::LightweightGraph>(graphconfig);
0062 } catch (lwt::LightweightNNException &e) {
0063 throw IOError("Error initialising from LWTNN JSON config");
0064 }
0065 }
0066
0067
0068
0069 std::unique_ptr<lwt::LightweightNeuralNetwork> mkLWTNN(const string& jsonpath) {
0070 lwt::JSONConfig config = readLWTNNConfig(jsonpath);
0071 return mkLWTNN(config);
0072 }
0073
0074
0075
0076
0077 std::unique_ptr<lwt::LightweightGraph> mkGraphLWTNN(const string& jsonpath) {
0078 lwt::GraphConfig config = readLWTNNGraphConfig(jsonpath);
0079 return mkGraphLWTNN(config);
0080 }
0081
0082 }
0083
0084 #endif