// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#pragma once

#include <cstdint>
#include <memory>
#include <vector>

#include <Eigen/Dense>
#include <onnxruntime_cxx_api.h>
namespace ActsPlugins {

/// Type alias for network batch input data structure
using NetworkBatchInput =
    Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

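// Example (illustrative only, not part of the interface): a batch of two
// inputs with three features each can be filled with Eigen's comma
// initializer, one row per input:
//
//   NetworkBatchInput batch(2, 3);
//   batch << 0.1f, 0.2f, 0.3f,
//            0.4f, 0.5f, 0.6f;
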
/// General class that sets up the ONNX runtime framework for loading an ML
/// model and using it for inference.
class OnnxRuntimeBase {
 public:
  /// @brief Default constructor
  OnnxRuntimeBase() = default;

  /// @brief Parametrized constructor
  ///
  /// @param env the ONNX runtime environment
  /// @param modelPath the path to the ML model in *.onnx format
  OnnxRuntimeBase(Ort::Env& env, const char* modelPath);

  /// @brief Default destructor
  ~OnnxRuntimeBase() = default;

  /// @brief Run the ONNX inference function
  ///
  /// @param inputTensorValues The input feature values used for prediction
  ///
  /// @return The output (predicted) values
  std::vector<float> runONNXInference(
      std::vector<float>& inputTensorValues) const;

  /// @brief Run the ONNX inference function for a batch of inputs
  ///
  /// @param inputTensorValues Batch of input feature values used for prediction, one row per input
  ///
  /// @return The output (predicted) values, one vector per input
  std::vector<std::vector<float>> runONNXInference(
      NetworkBatchInput& inputTensorValues) const;
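
  // Example (illustrative sketch, assuming 'model' is a constructed
  // OnnxRuntimeBase wrapping a network with three input features and one
  // output vector per input row):
  //   NetworkBatchInput batch(2, 3);
  //   auto outputs = model.runONNXInference(batch);  // outputs.size() == 2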

  /// @brief Run the multi-output ONNX inference function for a batch of inputs
  ///
  /// @param inputTensorValues Batch of input feature values used for prediction, one row per input
  ///
  /// @return The vector of output (predicted) values, one for each output
  std::vector<std::vector<std::vector<float>>> runONNXInferenceMultiOutput(
      NetworkBatchInput& inputTensorValues) const;

 private:
  /// ONNX runtime session / model properties
  std::unique_ptr<Ort::Session> m_session;
  std::vector<Ort::AllocatedStringPtr> m_inputNodeNamesAllocated;
  std::vector<const char*> m_inputNodeNames;
  std::vector<std::int64_t> m_inputNodeDims;
  std::vector<Ort::AllocatedStringPtr> m_outputNodeNamesAllocated;
  std::vector<const char*> m_outputNodeNames;
  std::vector<std::vector<std::int64_t>> m_outputNodeDims;
};

}  // namespace ActsPlugins
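
A minimal usage sketch (not part of the header): it assumes a single-output model stored at the hypothetical path "model.onnx" that takes three input features, and a hypothetical include path for this header; both names are illustrative, not taken from the ACTS sources.

#include <iostream>
#include <vector>

#include <onnxruntime_cxx_api.h>

#include "OnnxRuntimeBase.hpp"  // hypothetical include path for this header

int main() {
  // The Ort::Env must stay alive for as long as the session created from it.
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "example");
  ActsPlugins::OnnxRuntimeBase model(env, "model.onnx");

  // Single-input inference: one flat feature vector in, one prediction out.
  std::vector<float> features = {0.1f, 0.2f, 0.3f};
  std::vector<float> prediction = model.runONNXInference(features);

  for (float value : prediction) {
    std::cout << value << '\n';
  }
  return 0;
}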