|
|
|||
File indexing completed on 2025-12-11 09:40:23
0001 // This file is part of the ACTS project. 0002 // 0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project 0004 // 0005 // This Source Code Form is subject to the terms of the Mozilla Public 0006 // License, v. 2.0. If a copy of the MPL was not distributed with this 0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/. 0008 0009 #pragma once 0010 0011 #include <vector> 0012 0013 #include <Eigen/Dense> 0014 #include <onnxruntime_cxx_api.h> 0015 0016 namespace ActsPlugins { 0017 0018 /// Type alias for network batch input data structure 0019 using NetworkBatchInput = 0020 Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; 0021 0022 // General class that sets up the ONNX runtime framework for loading an ML model 0023 // and using it for inference. 0024 class OnnxRuntimeBase { 0025 public: 0026 /// @brief Default constructor 0027 OnnxRuntimeBase() = default; 0028 0029 /// @brief Parametrized constructor 0030 /// 0031 /// @param env the ONNX runtime environment 0032 /// @param modelPath the path to the ML model in *.onnx format 0033 OnnxRuntimeBase(Ort::Env& env, const char* modelPath); 0034 0035 /// @brief Default destructor 0036 ~OnnxRuntimeBase() = default; 0037 0038 /// @brief Run the ONNX inference function 0039 /// 0040 /// @param inputTensorValues The input feature values used for prediction 0041 /// 0042 /// @return The output (predicted) values 0043 std::vector<float> runONNXInference( 0044 std::vector<float>& inputTensorValues) const; 0045 0046 /// @brief Run the ONNX inference function for a batch of input 0047 /// 0048 /// @param inputTensorValues Vector of the input feature values of all the inputs used for prediction 0049 /// 0050 /// @return The vector of output (predicted) values 0051 std::vector<std::vector<float>> runONNXInference( 0052 NetworkBatchInput& inputTensorValues) const; 0053 0054 /// @brief Run the multi-output ONNX inference function for a batch of input 0055 /// 0056 /// @param inputTensorValues Vector of the input feature values of all the inputs used for prediction 0057 /// 0058 /// @return The vector of output (predicted) values, one for each output 0059 std::vector<std::vector<std::vector<float>>> runONNXInferenceMultiOutput( 0060 NetworkBatchInput& inputTensorValues) const; 0061 0062 private: 0063 /// ONNX runtime session / model properties 0064 std::unique_ptr<Ort::Session> m_session; 0065 std::vector<Ort::AllocatedStringPtr> m_inputNodeNamesAllocated; 0066 std::vector<const char*> m_inputNodeNames; 0067 std::vector<std::int64_t> m_inputNodeDims; 0068 std::vector<Ort::AllocatedStringPtr> m_outputNodeNamesAllocated; 0069 std::vector<const char*> m_outputNodeNames; 0070 std::vector<std::vector<std::int64_t>> m_outputNodeDims; 0071 }; 0072 0073 } // namespace ActsPlugins
| [ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
|
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |
|