Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-19 09:06:54

0001 // -*- C++ -*-
0002 #ifndef RIVET_RivetONNXrt_HH
0003 #define RIVET_RivetONNXrt_HH
0004 
0005 #include <iostream>
0006 #include <functional>
0007 #include <numeric>
0008 
0009 #include "Rivet/Tools/RivetPaths.hh"
0010 #include "Rivet/Tools/Utils.hh"
0011 #include "onnxruntime/onnxruntime_cxx_api.h"
0012 
0013 namespace Rivet {
0014 
0015 
0016   /// @brief Simple interface class to take care of basic ONNX networks
0017   ///
0018   /// See analyses/examples/EXAMPLE_ONNX.cc for how to use this.
0019   ///
0020   /// @note A node is not a neuron but a single tensor of arbitrary dimension size
0021   class RivetONNXrt {
0022 
0023   public:
0024 
0025     // Suppress default constructor
0026     RivetONNXrt() = delete;
0027 
0028     /// Constructor
0029     RivetONNXrt(const string& filename, const string& runname = "RivetONNXrt") {
0030 
0031       // Set some ORT variables that need to be kept in memory
0032       _env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, runname.c_str());
0033 
0034       // Load the model
0035       Ort::SessionOptions sessionopts;
0036       _session = std::make_unique<Ort::Session> (*_env, filename.c_str(), sessionopts);
0037 
0038       // Store network hyperparameters (input/output shape, etc.)
0039       getNetworkInfo();
0040 
0041       MSG_DEBUG(*this);
0042     }
0043 
0044     /// Given a multi-node input vector, populate and return the multi-node output vector
0045     vector<vector<float>> compute(vector<vector<float>>& inputs) const {
0046 
0047       /// Check that number of input nodes matches what the model expects
0048       if (inputs.size() != _inDims.size()) {
0049         throw("Expected " + to_string(_inDims.size())
0050               + " input nodes, received " + to_string(inputs.size()));
0051       }
0052 
0053       // Create input tensor objects from input data
0054       vector<Ort::Value> ort_input;
0055       ort_input.reserve(_inDims.size());
0056       auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
0057       for (size_t i=0; i < _inDims.size(); ++i) {
0058 
0059         // Check that input data matches expected input node dimension
0060         if (inputs[i].size() != _inDimsFlat[i]) {
0061           throw("Expected flattened input node dimension " + to_string(_inDimsFlat[i])
0062                  + ", received " + to_string(inputs[i].size()));
0063         }
0064 
0065         ort_input.emplace_back(Ort::Value::CreateTensor<float>(memory_info,
0066                                                                inputs[i].data(), inputs[i].size(),
0067                                                                _inDims[i].data(), _inDims[i].size()));
0068       }
0069 
0070       // retrieve output tensors
0071       auto ort_output = _session->Run(Ort::RunOptions{nullptr}, _inNames.data(),
0072                                       ort_input.data(), ort_input.size(),
0073                                       _outNames.data(), _outNames.size());
0074 
0075       // construct flattened values and return
0076       vector<vector<float>> outputs; outputs.resize(_outDims.size());
0077       for (size_t i = 0; i < _outDims.size(); ++i) {
0078         float* floatarr = ort_output[i].GetTensorMutableData<float>();
0079         outputs[i].assign(floatarr, floatarr + _outDimsFlat[i]);
0080       }
0081       return outputs;
0082     }
0083 
0084     /// Given a single-node input vector, populate and return the single-node output vector
0085     vector<float> compute(const vector<float>& inputs) const {
0086       if (_inDims.size() != 1 || _outDims.size() != 1) {
0087         throw("This method assumes a single input/output node!");
0088       }
0089       vector<vector<float>> wrapped_inputs = { inputs };
0090       vector<vector<float>> outputs = compute(wrapped_inputs);
0091       return outputs[0];
0092     }
0093 
0094     /// Method to check if @a key exists in network metatdata
0095     bool hasKey(const std::string& key) const {
0096       Ort::AllocatorWithDefaultOptions allocator;
0097       return (bool)_metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
0098     }
0099 
0100     /// Method to retrieve value associated with @a key
0101     /// from network metadata and return value as type T
0102     template <typename T,
0103       typename std::enable_if_t<!is_iterable_v<T> | is_cstring_v<T> >>
0104     T retrieve(const std::string& key) const {
0105       Ort::AllocatorWithDefaultOptions allocator;
0106       Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
0107       if (!res) {
0108         throw("Key '"+key+"' not found in network metadata!");
0109       }
0110       /*if constexpr (std::is_same<T, std::string>::value) {
0111         return res.get();
0112       }*/
0113       return lexical_cast<T>(res.get());
0114     }
0115 
0116     /// Template specialisation of retrieve for std::string
0117     std::string retrieve(const std::string& key) const {
0118       Ort::AllocatorWithDefaultOptions allocator;
0119       Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
0120       if (!res) {
0121         throw("Key '"+key+"' not found in network metadata!");
0122       }
0123       return res.get();
0124     }
0125 
0126     /// Overload of retrieve for vector<T>
0127     template <typename T>
0128     vector<T> retrieve(const std::string & key) const {
0129       const vector<string> stringvec = split(retrieve(key), ",");
0130       vector<T> returnvec = {};
0131       for (const string & s : stringvec){
0132         returnvec.push_back(lexical_cast<T>(s));
0133       }
0134       return returnvec;
0135     }
0136 
0137     /// Overload of retrieve for vector<T>, with a default return
0138     template <typename T>
0139     vector<T> retrieve(const std::string & key, const vector<T> & defaultreturn) const {
0140       try {
0141         return retrieve<T>(key);
0142       } catch (...) {
0143         return defaultreturn;
0144       }
0145     }
0146 
0147     std::string retrieve(const std::string& key, const std::string& defaultreturn) const {
0148       try {
0149         return retrieve(key);
0150       } catch (...) {
0151         return defaultreturn;
0152       }
0153     }
0154 
0155     /// Variation of retrieve method that falls back
0156     /// to @a defaultreturn if @a key cannot be found
0157     template <typename T,
0158       typename std::enable_if_t<!is_iterable_v<T> | is_cstring_v<T> >>
0159     T retrieve(const std::string& key, const T& defaultreturn) const {
0160       try {
0161         return retrieve<T>(key);
0162       } catch (...) {
0163         return defaultreturn;
0164       }
0165     }
0166 
0167     /// Printing function for debugging.
0168     friend std::ostream& operator <<(std::ostream& os, const RivetONNXrt& rort){
0169       os << "RivetONNXrt Network Summary: \n";
0170       for (size_t i=0; i < rort._inNames.size(); ++i) {
0171         os << "- Input node " << i << " name: " << rort._inNames[i];
0172         os << ", dimensions: (";
0173         for (size_t j=0; j < rort._inDims[i].size(); ++j){
0174           if (j)  os << ", ";
0175           os << rort._inDims[i][j];
0176         }
0177         os << "), type (as ONNX enums): " << rort._inTypes[i] << "\n";
0178       }
0179       for (size_t i=0; i < rort._outNames.size(); ++i) {
0180         os << "- Output node " << i << " name: " << rort._outNames[i];
0181         os << ", dimensions: (";
0182         for (size_t j=0; j < rort._outDims[i].size(); ++j){
0183           if (j)  os << ", ";
0184           os << rort._outDims[i][j];
0185         }
0186         os << "), type (as ONNX enums): (" << rort._outTypes[i] << "\n";
0187       }
0188       return os;
0189     }
0190 
0191     /// Logger
0192     Log& getLog() const {
0193       string logname = "Rivet.RivetONNXrt";
0194       return Log::getLog(logname);
0195     }
0196 
0197 
0198   private:
0199 
0200     void getNetworkInfo() {
0201 
0202       Ort::AllocatorWithDefaultOptions allocator;
0203 
0204       // Retrieve network metadat
0205       _metadata = std::make_unique<Ort::ModelMetadata>(_session->GetModelMetadata());
0206 
0207       // find out how many input nodes the model expects
0208       const size_t num_input_nodes = _session->GetInputCount();
0209       _inDimsFlat.reserve(num_input_nodes);
0210       _inTypes.reserve(num_input_nodes);
0211       _inDims.reserve(num_input_nodes);
0212       _inNames.reserve(num_input_nodes);
0213       _inNamesPtr.reserve(num_input_nodes);
0214       for (size_t i = 0; i < num_input_nodes; ++i) {
0215         // retrieve input node name
0216         auto input_name = _session->GetInputNameAllocated(i, allocator);
0217         _inNames.push_back(input_name.get());
0218         _inNamesPtr.push_back(std::move(input_name));
0219 
0220         // retrieve input node type
0221         auto in_type_info = _session->GetInputTypeInfo(i);
0222         auto in_tensor_info = in_type_info.GetTensorTypeAndShapeInfo();
0223         _inTypes.push_back(in_tensor_info.GetElementType());
0224         _inDims.push_back(in_tensor_info.GetShape());
0225       }
0226 
0227       // Fix negative shape values - appears to be an artefact of batch size issues.
0228       for (auto& dims : _inDims) {
0229         int64_t n = 1;
0230         for (auto& dim : dims) {
0231           if (dim < 0)  dim = abs(dim);
0232           n *= dim;
0233         }
0234         _inDimsFlat.push_back(n);
0235       }
0236 
0237       // find out how many output nodes the model expects
0238       const size_t num_output_nodes = _session->GetOutputCount();
0239       _outDimsFlat.reserve(num_output_nodes);
0240       _outTypes.reserve(num_output_nodes);
0241       _outDims.reserve(num_output_nodes);
0242       _outNames.reserve(num_output_nodes);
0243       _outNamesPtr.reserve(num_output_nodes);
0244       for (size_t i = 0; i < num_output_nodes; ++i) {
0245         // retrieve output node name
0246         auto output_name = _session->GetOutputNameAllocated(i, allocator);
0247         _outNames.push_back(output_name.get());
0248         _outNamesPtr.push_back(std::move(output_name));
0249 
0250         // retrieve input node type
0251         auto out_type_info = _session->GetOutputTypeInfo(i);
0252         auto out_tensor_info = out_type_info.GetTensorTypeAndShapeInfo();
0253         _outTypes.push_back(out_tensor_info.GetElementType());
0254         _outDims.push_back(out_tensor_info.GetShape());
0255       }
0256 
0257       // Fix negative shape values - appears to be an artefact of batch size issues.
0258       for (auto& dims : _outDims) {
0259         int64_t n = 1;
0260         for (auto& dim : dims) {
0261           if (dim < 0)  dim = abs(dim);
0262           n *= dim;
0263         }
0264         _outDimsFlat.push_back(n);
0265       }
0266     }
0267 
0268   private:
0269 
0270     /// ONNXrt environment for this session
0271     std::unique_ptr<Ort::Env> _env;
0272 
0273     /// ONNXrt session holiding the network
0274     std::unique_ptr<Ort::Session> _session;
0275 
0276     /// Network metadata
0277     std::unique_ptr<Ort::ModelMetadata> _metadata;
0278 
0279     /// Input/output node dimensions
0280     ///
0281     /// @note Each could be a multidimensional tensor
0282     vector<vector<int64_t>> _inDims, _outDims;
0283 
0284     /// Equivalent length for flattened input/ouput node structure
0285     vector<int64_t> _inDimsFlat, _outDimsFlat;
0286 
0287     /// Types of input/output nodes (as ONNX enums)
0288     vector<ONNXTensorElementDataType> _inTypes, _outTypes;
0289 
0290     /// Pointers to the ONNXrt inout/output node names
0291     vector<Ort::AllocatedStringPtr> _inNamesPtr, _outNamesPtr;
0292 
0293     /// C-style arrays of the input/output node names
0294     vector<const char*> _inNames, _outNames;
0295   };
0296 
0297 
0298   /// @brief Useful function for getting ONNX file paths
0299   ///
0300   /// Based on getDatafilePath from RivetYODA.cc
0301   inline string getONNXFilePath(const string& filename) {
0302     /// Try to find an ONNX file matching this analysis name
0303     const string path1 = findAnalysisDataFile(filename);
0304     if (!path1.empty()) return path1;
0305     throw Rivet::Error("Couldn't find an ONNX data file for '" + filename + "' " +
0306                        "in the path " + toString(getRivetDataPath()));
0307   }
0308 
0309 
0310   /// Function to get a RivetONNXrt object from an analysis name
0311   /// Use suffix to help disambiguate if an analysis requires 
0312   /// multiple networks.
0313   /// @todo: If ONNX is ever fully integrated into rivet, move
0314   /// to analysis class.
0315   inline unique_ptr<RivetONNXrt> getONNX(const string& analysisname, const string& suffix = ".onnx"){
0316     return make_unique<RivetONNXrt>(getONNXFilePath(analysisname+suffix));
0317   }
0318 
0319 
0320 }
0321 
0322 #endif