#ifndef RIVET_RivetONNXrt_HH
#define RIVET_RivetONNXrt_HH

#include <iostream>
#include <functional>
#include <numeric>

#include "Rivet/Tools/RivetPaths.hh"
#include "Rivet/Tools/Utils.hh"
#include "onnxruntime/onnxruntime_cxx_api.h"

namespace Rivet {


  /// @brief Simple class for handling ONNX Runtime models from within Rivet
  ///
  /// Loads an ONNX model from file, caches the input/output node names,
  /// shapes and types, and provides a compute() interface for inference
  /// as well as access to the model's custom metadata.
  class RivetONNXrt {

  public:

    /// No default constructor: a model file is always required
    RivetONNXrt() = delete;

    /// Constructor, loading the model from @a filename
    RivetONNXrt(const string& filename, const string& runname = "RivetONNXrt") {

      // Set up the ONNX Runtime environment, using runname as the logger ID
      _env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, runname.c_str());

      // Create the inference session from the model file
      Ort::SessionOptions sessionopts;
      _session = std::make_unique<Ort::Session>(*_env, filename.c_str(), sessionopts);

      // Cache the input/output node names, shapes, types and metadata
      getNetworkInfo();

      MSG_DEBUG(*this);
    }
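
    // Illustrative construction sketch (comment only, not part of the class);
    // the model file name and run name below are hypothetical:
    //
    //   const string path = getONNXFilePath("MyModel.onnx");
    //   RivetONNXrt net(path, "MyAnalysisRun");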


    /// Given one (flattened) input vector per input node, run the network
    /// and return one (flattened) output vector per output node
    vector<vector<float>> compute(vector<vector<float>>& inputs) const {

      // Check that one input vector has been supplied per input node
      if (inputs.size() != _inDims.size()) {
        throw("Expected " + to_string(_inDims.size())
              + " input nodes, received " + to_string(inputs.size()));
      }

      // Wrap each flattened input as an ONNX Runtime tensor of the expected shape
      vector<Ort::Value> ort_input;
      ort_input.reserve(_inDims.size());
      auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
      for (size_t i = 0; i < _inDims.size(); ++i) {

        // Check that the flattened input has the expected number of entries
        if (inputs[i].size() != _inDimsFlat[i]) {
          throw("Expected flattened input node dimension " + to_string(_inDimsFlat[i])
                + ", received " + to_string(inputs[i].size()));
        }

        ort_input.emplace_back(Ort::Value::CreateTensor<float>(memory_info,
                                                               inputs[i].data(), inputs[i].size(),
                                                               _inDims[i].data(), _inDims[i].size()));
      }

      // Run the inference over all output nodes
      auto ort_output = _session->Run(Ort::RunOptions{nullptr}, _inNames.data(),
                                      ort_input.data(), ort_input.size(),
                                      _outNames.data(), _outNames.size());

      // Copy each output tensor into a flattened float vector
      vector<vector<float>> outputs; outputs.resize(_outDims.size());
      for (size_t i = 0; i < _outDims.size(); ++i) {
        float* floatarr = ort_output[i].GetTensorMutableData<float>();
        outputs[i].assign(floatarr, floatarr + _outDimsFlat[i]);
      }
      return outputs;
    }


    /// Convenience overload for networks with a single input and output node
    vector<float> compute(const vector<float>& inputs) const {
      if (_inDims.size() != 1 || _outDims.size() != 1) {
        throw("This method assumes a single input/output node!");
      }
      vector<vector<float>> wrapped_inputs = { inputs };
      vector<vector<float>> outputs = compute(wrapped_inputs);
      return outputs[0];
    }
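
    // Illustrative inference sketch (comment only); the feature values are
    // hypothetical and must match the flattened input-node dimension:
    //
    //   vector<float> features = { 0.5f, 1.2f, -0.3f };    // single input node
    //   vector<float> scores = net.compute(features);
    //
    //   vector<vector<float>> multi = { features };         // one vector per input node
    //   vector<vector<float>> outs = net.compute(multi);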


    /// Check whether a given key exists in the network metadata
    bool hasKey(const std::string& key) const {
      Ort::AllocatorWithDefaultOptions allocator;
      return (bool)_metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
    }

    /// Given a key, retrieve the metadata value as type @a T
    /// (for non-iterable types and C-strings)
    template <typename T,
              typename std::enable_if_t<!is_iterable_v<T> || is_cstring_v<T>>* = nullptr>
    T retrieve(const std::string& key) const {
      Ort::AllocatorWithDefaultOptions allocator;
      Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
      if (!res) {
        throw("Key '"+key+"' not found in network metadata!");
      }
      return lexical_cast<T>(res.get());
    }

    /// Given a key, retrieve the metadata value as a string
    std::string retrieve(const std::string& key) const {
      Ort::AllocatorWithDefaultOptions allocator;
      Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
      if (!res) {
        throw("Key '"+key+"' not found in network metadata!");
      }
      return res.get();
    }


    /// Given a key, retrieve the metadata value as a vector of type @a T,
    /// assuming a comma-separated string representation
    template <typename T>
    vector<T> retrieve(const std::string& key) const {
      const vector<string> stringvec = split(retrieve(key), ",");
      vector<T> returnvec = {};
      for (const string& s : stringvec) {
        returnvec.push_back(lexical_cast<T>(s));
      }
      return returnvec;
    }

    /// Vector-valued retrieval with a default to fall back on
    template <typename T>
    vector<T> retrieve(const std::string& key, const vector<T>& defaultreturn) const {
      try {
        return retrieve<T>(key);
      } catch (...) {
        return defaultreturn;
      }
    }

    /// String-valued retrieval with a default to fall back on
    std::string retrieve(const std::string& key, const std::string& defaultreturn) const {
      try {
        return retrieve(key);
      } catch (...) {
        return defaultreturn;
      }
    }

    /// Scalar-valued retrieval with a default to fall back on
    /// (for non-iterable types and C-strings)
    template <typename T,
              typename std::enable_if_t<!is_iterable_v<T> || is_cstring_v<T>>* = nullptr>
    T retrieve(const std::string& key, const T& defaultreturn) const {
      try {
        return retrieve<T>(key);
      } catch (...) {
        return defaultreturn;
      }
    }
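
    // Illustrative metadata sketch (comment only); the key names are
    // hypothetical and depend on what was stored in the ONNX model:
    //
    //   if (net.hasKey("ModelName")) {
    //     const string name = net.retrieve("ModelName");
    //   }
    //   const double thres = net.retrieve<double>("Threshold", 0.5);
    //   const vector<double> means = net.retrieve<double>("Means", vector<double>{});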


    /// Print a summary of the input/output nodes to the given stream
    friend std::ostream& operator <<(std::ostream& os, const RivetONNXrt& rort) {
      os << "RivetONNXrt Network Summary: \n";
      for (size_t i=0; i < rort._inNames.size(); ++i) {
        os << "- Input node " << i << " name: " << rort._inNames[i];
        os << ", dimensions: (";
        for (size_t j=0; j < rort._inDims[i].size(); ++j) {
          if (j) os << ", ";
          os << rort._inDims[i][j];
        }
        os << "), type (as ONNX enums): " << rort._inTypes[i] << "\n";
      }
      for (size_t i=0; i < rort._outNames.size(); ++i) {
        os << "- Output node " << i << " name: " << rort._outNames[i];
        os << ", dimensions: (";
        for (size_t j=0; j < rort._outDims[i].size(); ++j) {
          if (j) os << ", ";
          os << rort._outDims[i][j];
        }
        os << "), type (as ONNX enums): " << rort._outTypes[i] << "\n";
      }
      return os;
    }


    /// Logger access, needed for the MSG_* macros
    Log& getLog() const {
      string logname = "Rivet.RivetONNXrt";
      return Log::getLog(logname);
    }


  private:

    /// Retrieve and cache the network's node information and metadata
    void getNetworkInfo() {

      Ort::AllocatorWithDefaultOptions allocator;

      // Cache the model metadata for later key lookups
      _metadata = std::make_unique<Ort::ModelMetadata>(_session->GetModelMetadata());

      // Cache the names, types and shapes of the input nodes
      const size_t num_input_nodes = _session->GetInputCount();
      _inDimsFlat.reserve(num_input_nodes);
      _inTypes.reserve(num_input_nodes);
      _inDims.reserve(num_input_nodes);
      _inNames.reserve(num_input_nodes);
      _inNamesPtr.reserve(num_input_nodes);
      for (size_t i = 0; i < num_input_nodes; ++i) {

        // Keep the allocated name alive and store a raw pointer for the Run() call
        auto input_name = _session->GetInputNameAllocated(i, allocator);
        _inNames.push_back(input_name.get());
        _inNamesPtr.push_back(std::move(input_name));

        auto in_type_info = _session->GetInputTypeInfo(i);
        auto in_tensor_info = in_type_info.GetTensorTypeAndShapeInfo();
        _inTypes.push_back(in_tensor_info.GetElementType());
        _inDims.push_back(in_tensor_info.GetShape());
      }

      // Flattened input lengths: dynamic (negative) dimensions are replaced
      // by their absolute value, i.e. a batch dimension of -1 counts as 1
      for (auto& dims : _inDims) {
        int64_t n = 1;
        for (auto& dim : dims) {
          if (dim < 0) dim = abs(dim);
          n *= dim;
        }
        _inDimsFlat.push_back(n);
      }

      // Cache the names, types and shapes of the output nodes
      const size_t num_output_nodes = _session->GetOutputCount();
      _outDimsFlat.reserve(num_output_nodes);
      _outTypes.reserve(num_output_nodes);
      _outDims.reserve(num_output_nodes);
      _outNames.reserve(num_output_nodes);
      _outNamesPtr.reserve(num_output_nodes);
      for (size_t i = 0; i < num_output_nodes; ++i) {

        auto output_name = _session->GetOutputNameAllocated(i, allocator);
        _outNames.push_back(output_name.get());
        _outNamesPtr.push_back(std::move(output_name));

        auto out_type_info = _session->GetOutputTypeInfo(i);
        auto out_tensor_info = out_type_info.GetTensorTypeAndShapeInfo();
        _outTypes.push_back(out_tensor_info.GetElementType());
        _outDims.push_back(out_tensor_info.GetShape());
      }

      // Flattened output lengths, with the same treatment of dynamic dimensions
      for (auto& dims : _outDims) {
        int64_t n = 1;
        for (auto& dim : dims) {
          if (dim < 0) dim = abs(dim);
          n *= dim;
        }
        _outDimsFlat.push_back(n);
      }
    }

  private:

    /// ONNX Runtime environment
    std::unique_ptr<Ort::Env> _env;

    /// ONNX Runtime inference session
    std::unique_ptr<Ort::Session> _session;

    /// Model metadata, used for the key-value lookups
    std::unique_ptr<Ort::ModelMetadata> _metadata;

    /// Input/output node shapes
    vector<vector<int64_t>> _inDims, _outDims;

    /// Flattened input/output node lengths
    vector<int64_t> _inDimsFlat, _outDimsFlat;

    /// Input/output element types (as ONNX enums)
    vector<ONNXTensorElementDataType> _inTypes, _outTypes;

    /// Owning pointers that keep the node-name strings alive
    vector<Ort::AllocatedStringPtr> _inNamesPtr, _outNamesPtr;

    /// Raw node-name pointers, as needed by Ort::Session::Run()
    vector<const char*> _inNames, _outNames;

  };


  /// Function to locate an ONNX model file on the Rivet data path
  inline string getONNXFilePath(const string& filename) {
    const string path1 = findAnalysisDataFile(filename);
    if (!path1.empty()) return path1;
    throw Rivet::Error("Couldn't find an ONNX data file for '" + filename + "' " +
                       "in the path " + toString(getRivetDataPath()));
  }


  /// Convenience function to load the ONNX model associated with an analysis,
  /// assuming the file is named <analysisname><suffix>
  inline unique_ptr<RivetONNXrt> getONNX(const string& analysisname, const string& suffix = ".onnx") {
    return make_unique<RivetONNXrt>(getONNXFilePath(analysisname+suffix));
  }
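
  // Illustrative analysis-level sketch (comment only); the analysis name is
  // hypothetical, and the corresponding MyAnalysis.onnx file is assumed to be
  // installed on the Rivet data path:
  //
  //   unique_ptr<RivetONNXrt> _net;            // analysis member
  //   void init() { _net = getONNX("MyAnalysis"); }
  //   // ... later: vector<float> out = _net->compute(features);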

}

#endif