Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:27:42

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2020 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
#include <cstdint>
#include <memory>
#include <vector>

#include <Eigen/Dense>
#include <onnxruntime_cxx_api.h>
0014 
0015 namespace Acts {
0016 
0017 using NetworkBatchInput =
0018     Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
0019 
0020 // General class that sets up the ONNX runtime framework for loading an ML model
0021 // and using it for inference.
0022 class OnnxRuntimeBase {
0023  public:
0024   /// @brief Default constructor
0025   OnnxRuntimeBase() = default;
0026 
0027   /// @brief Parametrized constructor
0028   ///
0029   /// @param env the ONNX runtime environment
0030   /// @param modelPath the path to the ML model in *.onnx format
0031   OnnxRuntimeBase(Ort::Env& env, const char* modelPath);
0032 
0033   /// @brief Default destructor
0034   ~OnnxRuntimeBase() = default;
0035 
0036   /// @brief Run the ONNX inference function
0037   ///
0038   /// @param inputTensorValues The input feature values used for prediction
0039   ///
0040   /// @return The output (predicted) values
0041   std::vector<float> runONNXInference(
0042       std::vector<float>& inputTensorValues) const;
0043 
0044   /// @brief Run the ONNX inference function for a batch of input
0045   ///
0046   /// @param inputTensorValues Vector of the input feature values of all the inputs used for prediction
0047   ///
0048   /// @return The vector of output (predicted) values
0049   std::vector<std::vector<float>> runONNXInference(
0050       NetworkBatchInput& inputTensorValues) const;
0051 
0052   /// @brief Run the multi-output ONNX inference function for a batch of input
0053   ///
0054   /// @param inputTensorValues Vector of the input feature values of all the inputs used for prediction
0055   ///
0056   /// @return The vector of output (predicted) values, one for each output
0057   std::vector<std::vector<std::vector<float>>> runONNXInferenceMultiOutput(
0058       NetworkBatchInput& inputTensorValues) const;
0059 
0060  private:
0061   /// ONNX runtime session / model properties
0062   std::unique_ptr<Ort::Session> m_session;
0063   std::vector<Ort::AllocatedStringPtr> m_inputNodeNamesAllocated;
0064   std::vector<const char*> m_inputNodeNames;
0065   std::vector<int64_t> m_inputNodeDims;
0066   std::vector<Ort::AllocatedStringPtr> m_outputNodeNamesAllocated;
0067   std::vector<const char*> m_outputNodeNames;
0068   std::vector<std::vector<int64_t>> m_outputNodeDims;
0069 };
0070 
0071 }  // namespace Acts