#include <fmt/core.h>
#include <onnxruntime_c_api.h>
#include <onnxruntime_cxx_api.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <gsl/pointers>
#include <iterator>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "ONNXInference.h"

namespace eicrecon {

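// Helper: render a tensor shape such as {1, -1, 3} as "1 x -1 x 3" for log messages.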
static std::string print_shape(const std::vector<std::int64_t>& v) {
  // guard against empty shapes (e.g. scalar tensors) to avoid out-of-range access below
  if (v.empty()) {
    return "";
  }
  std::stringstream ss("");
  for (std::size_t i = 0; i < v.size() - 1; i++) {
    ss << v[i] << " x ";
  }
  ss << v[v.size() - 1];
  return ss.str();
}

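// Helper: two shapes are consistent if they have the same rank and each pair of
// dimensions either matches or has -1 (dynamic size) on at least one side.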
static bool check_shape_consistency(const std::vector<std::int64_t>& shape1,
                                    const std::vector<std::int64_t>& shape2) {
  if (shape2.size() != shape1.size()) {
    return false;
  }
  for (std::size_t ix = 0; ix < shape1.size(); ix++) {
    if ((shape1[ix] != -1) && (shape2[ix] != -1) && (shape1[ix] != shape2[ix])) {
      return false;
    }
  }
  return true;
}

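// Helper: wrap an existing contiguous range of values in an Ort::Value tensor.
// Ort::Value::CreateTensor does not copy the data, so the underlying container
// must outlive the returned tensor.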
template <typename T>
static Ort::Value iters_to_tensor(typename std::vector<T>::const_iterator data_begin,
                                  typename std::vector<T>::const_iterator data_end,
                                  std::vector<int64_t>::const_iterator shape_begin,
                                  std::vector<int64_t>::const_iterator shape_end) {
  Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator,
                                                        OrtMemType::OrtMemTypeDefault);
  auto tensor =
      Ort::Value::CreateTensor<T>(mem_info, const_cast<T*>(&*data_begin), data_end - data_begin,
                                  &*shape_begin, shape_end - shape_begin);
  return tensor;
}

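// Set up the ONNX Runtime environment, load the model from m_cfg.modelPath, and
// cache the input/output node names and shapes reported by the session.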
void ONNXInference::init() {

  m_env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, name().data());
  Ort::SessionOptions session_options;
  session_options.SetInterOpNumThreads(1);
  session_options.SetIntraOpNumThreads(1);
  try {
    m_session = Ort::Session(m_env, m_cfg.modelPath.c_str(), session_options);
    Ort::AllocatorWithDefaultOptions allocator;

    // print name/shape of inputs
    debug("Input Node Name/Shape:");
    for (std::size_t i = 0; i < m_session.GetInputCount(); i++) {
      m_input_names.emplace_back(m_session.GetInputNameAllocated(i, allocator).get());
      m_input_shapes.emplace_back(
          m_session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
      debug("\t{} : {}", m_input_names.at(i), print_shape(m_input_shapes.at(i)));
    }

    // print name/shape of outputs
    debug("Output Node Name/Shape: {}", m_session.GetOutputCount());
    for (std::size_t i = 0; i < m_session.GetOutputCount(); i++) {
      m_output_names.emplace_back(m_session.GetOutputNameAllocated(i, allocator).get());

      if (m_session.GetOutputTypeInfo(i).GetONNXType() != ONNX_TYPE_TENSOR) {
        m_output_shapes.emplace_back();
        debug("\t{} : not a tensor", m_output_names.at(i));
      } else {
        m_output_shapes.emplace_back(
            m_session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
        debug("\t{} : {}", m_output_names.at(i), print_shape(m_output_shapes.at(i)));
      }
    }

    // cache C string pointers in the layout expected by Ort::Session::Run
    m_input_names_char.resize(m_input_names.size(), nullptr);
    std::transform(std::begin(m_input_names), std::end(m_input_names),
                   std::begin(m_input_names_char),
                   [&](const std::string& str) { return str.c_str(); });
    m_output_names_char.resize(m_output_names.size(), nullptr);
    std::transform(std::begin(m_output_names), std::end(m_output_names),
                   std::begin(m_output_names_char),
                   [&](const std::string& str) { return str.c_str(); });

  } catch (const Ort::Exception& exception) {
    error("ONNX error {}", exception.what());
    throw;
  }
}

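// Convert the input edm4eic::Tensor collections into ONNX Runtime tensors, run the
// model, and copy each resulting output tensor into the corresponding output collection.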
void ONNXInference::process(const ONNXInference::Input& input,
                            const ONNXInference::Output& output) const {

  const auto [in_tensors] = input;
  auto [out_tensors] = output;

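  // the number of provided input collections must match the number of model inputs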
  if (in_tensors.size() != m_input_names.size()) {
    error("The ONNX model requires {} tensors, whereas {} were provided", m_input_names.size(),
          in_tensors.size());
    throw std::runtime_error(
        fmt::format("The ONNX model requires {} tensors, whereas {} were provided",
                    m_input_names.size(), in_tensors.size()));
  }

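  // wrap the first tensor of each input collection in an Ort::Value and check its
  // shape against the shape declared by the model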
  std::vector<Ort::Value> input_tensors;

  for (std::size_t ix = 0; ix < m_input_names.size(); ix++) {
    edm4eic::Tensor in_tensor = in_tensors[ix]->at(0);
    if (in_tensor.getElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
      input_tensors.emplace_back(
          iters_to_tensor<float>(in_tensor.floatData_begin(), in_tensor.floatData_end(),
                                 in_tensor.shape_begin(), in_tensor.shape_end()));
    } else if (in_tensor.getElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
      input_tensors.emplace_back(
          iters_to_tensor<int64_t>(in_tensor.int64Data_begin(), in_tensor.int64Data_end(),
                                   in_tensor.shape_begin(), in_tensor.shape_end()));
    } else {
      // an unhandled element type would leave input_tensors short and cause an
      // out-of-range access below, so fail explicitly
      error("Unsupported input tensor element type {}", in_tensor.getElementType());
      throw std::runtime_error(
          fmt::format("Unsupported input tensor element type {}", in_tensor.getElementType()));
    }

    auto input_shape = input_tensors[ix].GetTensorTypeAndShapeInfo().GetShape();
    std::vector<std::int64_t> input_expected_shape = m_input_shapes[ix];
    if (!check_shape_consistency(input_shape, input_expected_shape)) {
      error("Input tensor shape incorrect {} != {}", print_shape(input_shape),
            print_shape(input_expected_shape));
      throw std::runtime_error(fmt::format("Input tensor shape incorrect {} != {}",
                                           print_shape(input_shape),
                                           print_shape(input_expected_shape)));
    }
  }

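  // run the inference session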
  std::vector<Ort::Value> onnx_values;
  try {
    onnx_values = m_session.Run(Ort::RunOptions{nullptr}, m_input_names_char.data(),
                                input_tensors.data(), m_input_names_char.size(),
                                m_output_names_char.data(), m_output_names_char.size());
  } catch (const Ort::Exception& exception) {
    error("Error running model inference: {}", exception.what());
    throw;
  }

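  // copy each ONNX output tensor into a new edm4eic::Tensor in the matching output collection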
  try {
    for (std::size_t ix = 0; ix < onnx_values.size(); ix++) {
      Ort::Value& onnx_tensor = onnx_values[ix];
      if (!onnx_tensor.IsTensor()) {
        error("The output \"{}\" is not a tensor. ONNXType {} is not yet supported. Skipping...",
              m_output_names_char[ix], static_cast<int>(onnx_tensor.GetTypeInfo().GetONNXType()));
        continue;
      }
      auto onnx_tensor_type = onnx_tensor.GetTensorTypeAndShapeInfo();
      edm4eic::MutableTensor out_tensor = out_tensors[ix]->create();
      out_tensor.setElementType(static_cast<int32_t>(onnx_tensor_type.GetElementType()));
      std::size_t num_values = 1;
      for (int64_t dim_size : onnx_tensor_type.GetShape()) {
        out_tensor.addToShape(dim_size);
        num_values *= dim_size;
      }
      if (onnx_tensor_type.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
        auto* data = onnx_tensor.GetTensorMutableData<float>();
        for (std::size_t value_ix = 0; value_ix < num_values; value_ix++) {
          out_tensor.addToFloatData(data[value_ix]);
        }
      } else if (onnx_tensor_type.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
        auto* data = onnx_tensor.GetTensorMutableData<int64_t>();
        for (std::size_t value_ix = 0; value_ix < num_values; value_ix++) {
          out_tensor.addToInt64Data(data[value_ix]);
        }
      } else {
        error("Unsupported ONNXTensorElementDataType {}",
              static_cast<int>(onnx_tensor_type.GetElementType()));
      }
    }
  } catch (const Ort::Exception& exception) {
    error("Error converting model output: {}", exception.what());
    throw;
  }
}

} // namespace eicrecon