// This file is part of the Acts project.
//
// Copyright (C) 2020 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include <cstdint>
#include <memory>
#include <vector>

#include <Eigen/Dense>
#include <onnxruntime_cxx_api.h>

namespace Acts {

using NetworkBatchInput =
    Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

/// @brief General class that sets up the ONNX runtime framework for loading
/// an ML model and using it for inference.
class OnnxRuntimeBase {
 public:
  /// @brief Default constructor
  OnnxRuntimeBase() = default;

  /// @brief Parametrized constructor
  ///
  /// @param env the ONNX runtime environment
  /// @param modelPath the path to the ML model in *.onnx format
  OnnxRuntimeBase(Ort::Env& env, const char* modelPath);

  /// @brief Default destructor
  ~OnnxRuntimeBase() = default;

  /// @brief Run the ONNX inference function
  ///
  /// @param inputTensorValues The input feature values used for the prediction
  ///
  /// @return The output (predicted) values
  std::vector<float> runONNXInference(
      std::vector<float>& inputTensorValues) const;

  /// @brief Run the ONNX inference function for a batch of inputs
  ///
  /// @param inputTensorValues The input feature values of all inputs used for
  /// the prediction, one row per input
  ///
  /// @return The vector of output (predicted) values, one entry per input
  std::vector<std::vector<float>> runONNXInference(
      NetworkBatchInput& inputTensorValues) const;

  /// @brief Run the multi-output ONNX inference function for a batch of inputs
  ///
  /// @param inputTensorValues The input feature values of all inputs used for
  /// the prediction, one row per input
  ///
  /// @return The vector of output (predicted) values, one entry per output node
  std::vector<std::vector<std::vector<float>>> runONNXInferenceMultiOutput(
      NetworkBatchInput& inputTensorValues) const;

 private:
  /// ONNX runtime session / model properties
  std::unique_ptr<Ort::Session> m_session;
  std::vector<Ort::AllocatedStringPtr> m_inputNodeNamesAllocated;
  std::vector<const char*> m_inputNodeNames;
  std::vector<int64_t> m_inputNodeDims;
  std::vector<Ort::AllocatedStringPtr> m_outputNodeNamesAllocated;
  std::vector<const char*> m_outputNodeNames;
  std::vector<std::vector<int64_t>> m_outputNodeDims;
};

}  // namespace Acts
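
// Example usage (a minimal sketch, not part of the Acts sources; the model
// path "model.onnx", the feature count, and the values below are
// illustrative assumptions, and the input layout must match the model's
// input node):
//
//   Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "acts-onnx");
//   Acts::OnnxRuntimeBase model(env, "model.onnx");
//
//   // Single input: one flat feature vector in, one prediction vector out.
//   std::vector<float> features = {0.1f, 0.5f, -0.3f};
//   std::vector<float> prediction = model.runONNXInference(features);
//
//   // Batched input: one row per sample, one column per feature.
//   Acts::NetworkBatchInput batch(2, 3);
//   batch << 0.1f, 0.5f, -0.3f,
//            0.2f, 0.4f, -0.1f;
//   std::vector<std::vector<float>> predictions =
//       model.runONNXInference(batch);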