Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-02 09:37:02

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
#include <cstdint>
#include <memory>
#include <vector>

#include <Eigen/Dense>
#include <onnxruntime/onnxruntime_cxx_api.h>
0015 
0016 namespace Acts {
0017 
/// Batched network input: a dynamically-sized 2D float array, one row per
/// sample and one column per input feature. Storage is row-major, presumably
/// so each sample's features are contiguous when handed to the ONNX runtime
/// as a flat buffer — confirm against the implementation.
using NetworkBatchInput =
    Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
0020 
0021 // General class that sets up the ONNX runtime framework for loading an ML model
0022 // and using it for inference.
0023 class OnnxRuntimeBase {
0024  public:
0025   /// @brief Default constructor
0026   OnnxRuntimeBase() = default;
0027 
0028   /// @brief Parametrized constructor
0029   ///
0030   /// @param env the ONNX runtime environment
0031   /// @param modelPath the path to the ML model in *.onnx format
0032   OnnxRuntimeBase(Ort::Env& env, const char* modelPath);
0033 
0034   /// @brief Default destructor
0035   ~OnnxRuntimeBase() = default;
0036 
0037   /// @brief Run the ONNX inference function
0038   ///
0039   /// @param inputTensorValues The input feature values used for prediction
0040   ///
0041   /// @return The output (predicted) values
0042   std::vector<float> runONNXInference(
0043       std::vector<float>& inputTensorValues) const;
0044 
0045   /// @brief Run the ONNX inference function for a batch of input
0046   ///
0047   /// @param inputTensorValues Vector of the input feature values of all the inputs used for prediction
0048   ///
0049   /// @return The vector of output (predicted) values
0050   std::vector<std::vector<float>> runONNXInference(
0051       NetworkBatchInput& inputTensorValues) const;
0052 
0053   /// @brief Run the multi-output ONNX inference function for a batch of input
0054   ///
0055   /// @param inputTensorValues Vector of the input feature values of all the inputs used for prediction
0056   ///
0057   /// @return The vector of output (predicted) values, one for each output
0058   std::vector<std::vector<std::vector<float>>> runONNXInferenceMultiOutput(
0059       NetworkBatchInput& inputTensorValues) const;
0060 
0061  private:
0062   /// ONNX runtime session / model properties
0063   std::unique_ptr<Ort::Session> m_session;
0064   std::vector<Ort::AllocatedStringPtr> m_inputNodeNamesAllocated;
0065   std::vector<const char*> m_inputNodeNames;
0066   std::vector<std::int64_t> m_inputNodeDims;
0067   std::vector<Ort::AllocatedStringPtr> m_outputNodeNamesAllocated;
0068   std::vector<const char*> m_outputNodeNames;
0069   std::vector<std::vector<std::int64_t>> m_outputNodeDims;
0070 };
0071 
0072 }  // namespace Acts