#include "ActsPlugins/Onnx/OnnxRuntimeBase.hpp"

#include <cassert>
#include <stdexcept>

// Constructor: create the ONNX Runtime session for the given model and cache
// the input/output node names and shapes needed later for inference.
ActsPlugins::OnnxRuntimeBase::OnnxRuntimeBase(Ort::Env& env,
                                              const char* modelPath) {
  // Set the ONNX runtime session options
  Ort::SessionOptions sessionOptions;
  // Set graph optimization level
  sessionOptions.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_BASIC);
  // Create the Ort session
  m_session = std::make_unique<Ort::Session>(env, modelPath, sessionOptions);
  // Default allocator
  Ort::AllocatorWithDefaultOptions allocator;

  // Get the names and shapes of the input nodes of the model
  std::size_t numInputNodes = m_session->GetInputCount();
  for (std::size_t i = 0; i < numInputNodes; i++) {
    // Keep the allocated name alive and store a raw pointer for the Run() call
    m_inputNodeNamesAllocated.push_back(
        m_session->GetInputNameAllocated(i, allocator));
    m_inputNodeNames.push_back(m_inputNodeNamesAllocated.back().get());

    // Get the dimensions of the input nodes; only the shape of the last input
    // node is kept, i.e. all input nodes are assumed to share the same shape
    Ort::TypeInfo inputTypeInfo = m_session->GetInputTypeInfo(i);
    auto tensorInfo = inputTypeInfo.GetTensorTypeAndShapeInfo();
    m_inputNodeDims = tensorInfo.GetShape();
  }

  // Get the names and shapes of the output nodes of the model
  std::size_t numOutputNodes = m_session->GetOutputCount();
  for (std::size_t i = 0; i < numOutputNodes; i++) {
    m_outputNodeNamesAllocated.push_back(
        m_session->GetOutputNameAllocated(i, allocator));
    m_outputNodeNames.push_back(m_outputNodeNamesAllocated.back().get());

    // Store the shape of every output node
    Ort::TypeInfo outputTypeInfo = m_session->GetOutputTypeInfo(i);
    auto tensorInfo = outputTypeInfo.GetTensorTypeAndShapeInfo();
    m_outputNodeDims.push_back(tensorInfo.GetShape());
  }
}
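
// Construction sketch (illustrative only; the env name and model path are
// placeholders, not values used elsewhere in this file):
//
//   Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "OnnxRuntimeBase");
//   ActsPlugins::OnnxRuntimeBase model(env, "path/to/model.onnx");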

// Inference for a single sample: wrap the values into a batch of size one and
// return the first (and only) row of the batched result.
std::vector<float> ActsPlugins::OnnxRuntimeBase::runONNXInference(
    std::vector<float>& inputTensorValues) const {
  NetworkBatchInput vectorInput(1, inputTensorValues.size());
  for (std::size_t i = 0; i < inputTensorValues.size(); i++) {
    vectorInput(0, i) = inputTensorValues[i];
  }
  auto vectorOutput = runONNXInference(vectorInput);
  return vectorOutput[0];
}

// Batched inference for models with a single output node: forward to the
// multi-output variant and return only the first output.
std::vector<std::vector<float>> ActsPlugins::OnnxRuntimeBase::runONNXInference(
    NetworkBatchInput& inputTensorValues) const {
  return runONNXInferenceMultiOutput(inputTensorValues).front();
}

// Batched inference for models with a single input node and any number of
// output nodes. The result is indexed as [outputNode][batchEntry][value].
std::vector<std::vector<std::vector<float>>>
ActsPlugins::OnnxRuntimeBase::runONNXInferenceMultiOutput(
    NetworkBatchInput& inputTensorValues) const {
  int batchSize = inputTensorValues.rows();
  std::vector<std::int64_t> inputNodeDims = m_inputNodeDims;
  std::vector<std::vector<std::int64_t>> outputNodeDims = m_outputNodeDims;

  // The first dimension of each node is the batch size; if the model was
  // exported with a dynamic batch dimension it is -1 and has to be replaced
  // by the actual batch size of the input
  if (inputNodeDims[0] == -1) {
    inputNodeDims[0] = batchSize;
  }

  bool outputDimsMatch = true;
  for (std::vector<std::int64_t>& nodeDim : outputNodeDims) {
    if (nodeDim[0] == -1) {
      nodeDim[0] = batchSize;
    }
    outputDimsMatch &= batchSize == 1 || nodeDim[0] == batchSize;
  }

  if (batchSize != 1 && (inputNodeDims[0] != batchSize || !outputDimsMatch)) {
    throw std::runtime_error(
        "runONNXInference: batch size doesn't match the input or output node "
        "size");
  }

  // Create the input tensor object from the data values (no copy is made, the
  // tensor refers to the memory owned by inputTensorValues)
  Ort::MemoryInfo memoryInfo =
      Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
      memoryInfo, inputTensorValues.data(), inputTensorValues.size(),
      inputNodeDims.data(), inputNodeDims.size());

  if (!inputTensor.IsTensor()) {
    throw std::runtime_error(
        "runONNXInference: conversion of input to Tensor failed. ");
  }

  // Run the inference; the session returns one tensor per output node
  Ort::RunOptions run_options;
  std::vector<Ort::Value> outputTensors =
      m_session->Run(run_options, m_inputNodeNames.data(), &inputTensor,
                     m_inputNodeNames.size(), m_outputNodeNames.data(),
                     m_outputNodeNames.size());

  // Double-check that the output is valid and that we got one tensor per
  // output node
  if (!outputTensors[0].IsTensor() ||
      (outputTensors.size() != m_outputNodeNames.size())) {
    throw std::runtime_error(
        "runONNXInference: calculation of output failed. ");
  }

  std::vector<std::vector<std::vector<float>>> multiOutput;

  // Copy each output tensor into a nested std::vector
  for (std::size_t i_out = 0; i_out < outputTensors.size(); i_out++) {
    // Raw pointer to the float values of this output tensor
    float* outputTensor =
        outputTensors.at(i_out).GetTensorMutableData<float>();
    // Number of values per batch entry for this output node; nodes with only
    // a batch dimension yield one value per entry
    const std::int64_t nValues = (outputNodeDims.at(i_out).size() > 1)
                                     ? outputNodeDims.at(i_out)[1]
                                     : 1;
    std::vector<std::vector<float>> outputTensorValues(
        batchSize, std::vector<float>(nValues, -1));
    for (int i = 0; i < outputNodeDims.at(i_out)[0]; i++) {
      for (std::int64_t j = 0; j < nValues; j++) {
        outputTensorValues[i][j] = outputTensor[i * nValues + j];
      }
    }
    multiOutput.push_back(std::move(outputTensorValues));
  }
  return multiOutput;
}
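
// Usage sketch (illustrative only; assumes NetworkBatchInput is the row-major
// Eigen float array declared in OnnxRuntimeBase.hpp and `model` was
// constructed as in the example after the constructor above):
//
//   NetworkBatchInput batch(2, 4);  // 2 samples, 4 input features each
//   batch.setZero();                // fill with real feature values instead
//   // Single-output model: one vector of values per batch entry
//   std::vector<std::vector<float>> out = model.runONNXInference(batch);
//   // Multi-output model: indexed as [outputNode][batchEntry][value]
//   auto allOutputs = model.runONNXInferenceMultiOutput(batch);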