|
|
|||
File indexing completed on 2025-12-02 09:37:02
0001 // This file is part of the ACTS project. 0002 // 0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project 0004 // 0005 // This Source Code Form is subject to the terms of the Mozilla Public 0006 // License, v. 2.0. If a copy of the MPL was not distributed with this 0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/. 0008 0009 #pragma once 0010 0011 #include <vector> 0012 0013 #include <Eigen/Dense> 0014 #include <onnxruntime/onnxruntime_cxx_api.h> 0015 0016 namespace Acts { 0017 0018 using NetworkBatchInput = 0019 Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; 0020 0021 // General class that sets up the ONNX runtime framework for loading an ML model 0022 // and using it for inference. 0023 class OnnxRuntimeBase { 0024 public: 0025 /// @brief Default constructor 0026 OnnxRuntimeBase() = default; 0027 0028 /// @brief Parametrized constructor 0029 /// 0030 /// @param env the ONNX runtime environment 0031 /// @param modelPath the path to the ML model in *.onnx format 0032 OnnxRuntimeBase(Ort::Env& env, const char* modelPath); 0033 0034 /// @brief Default destructor 0035 ~OnnxRuntimeBase() = default; 0036 0037 /// @brief Run the ONNX inference function 0038 /// 0039 /// @param inputTensorValues The input feature values used for prediction 0040 /// 0041 /// @return The output (predicted) values 0042 std::vector<float> runONNXInference( 0043 std::vector<float>& inputTensorValues) const; 0044 0045 /// @brief Run the ONNX inference function for a batch of input 0046 /// 0047 /// @param inputTensorValues Vector of the input feature values of all the inputs used for prediction 0048 /// 0049 /// @return The vector of output (predicted) values 0050 std::vector<std::vector<float>> runONNXInference( 0051 NetworkBatchInput& inputTensorValues) const; 0052 0053 /// @brief Run the multi-output ONNX inference function for a batch of input 0054 /// 0055 /// @param inputTensorValues Vector of the input feature values of all the inputs used for prediction 0056 /// 0057 /// @return The vector of output (predicted) values, one for each output 0058 std::vector<std::vector<std::vector<float>>> runONNXInferenceMultiOutput( 0059 NetworkBatchInput& inputTensorValues) const; 0060 0061 private: 0062 /// ONNX runtime session / model properties 0063 std::unique_ptr<Ort::Session> m_session; 0064 std::vector<Ort::AllocatedStringPtr> m_inputNodeNamesAllocated; 0065 std::vector<const char*> m_inputNodeNames; 0066 std::vector<std::int64_t> m_inputNodeDims; 0067 std::vector<Ort::AllocatedStringPtr> m_outputNodeNamesAllocated; 0068 std::vector<const char*> m_outputNodeNames; 0069 std::vector<std::vector<std::int64_t>> m_outputNodeDims; 0070 }; 0071 0072 } // namespace Acts
| [ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
|
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |
|