Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-15 08:16:13

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2022 - 2024 Wouter Deconinck, Tooba Ali, Dmitry Kalinkin
0003 
0004 #include <edm4eic/EDM4eicVersion.h>
0005 
0006 #if EDM4EIC_VERSION_MAJOR >= 8
0007 #include <algorithm>
0008 #include <cstddef>
0009 #include <fmt/core.h>
0010 #include <gsl/pointers>
0011 #include <iterator>
0012 #include <onnxruntime_c_api.h>
0013 #include <onnxruntime_cxx_api.h>
0014 #include <ostream>
0015 #include <stdexcept>
0016 
0017 #include "ONNXInference.h"
0018 
0019 namespace eicrecon {
0020 
/// Format a tensor shape as a human-readable string, e.g. "1 x 3 x 224".
///
/// @param v dimension sizes; values are printed verbatim (negative values included)
/// @return the dimensions joined by " x ", or an empty string for an empty shape
static std::string print_shape(const std::vector<std::int64_t>& v) {
  // An empty shape must be handled explicitly: v.size() is unsigned, so
  // v.size() - 1 would wrap around and index far outside the vector.
  if (v.empty()) {
    return "";
  }
  std::stringstream ss;
  ss << v.front();
  for (std::size_t i = 1; i < v.size(); i++) {
    ss << " x " << v[i];
  }
  return ss.str();
}
0029 
/// Decide whether two tensor shapes are compatible.
///
/// Shapes are compatible when they have the same number of dimensions and every
/// pair of corresponding dimensions either matches exactly or has -1 on at
/// least one side (-1 is treated as a wildcard that matches any size).
///
/// @param shape1 first shape
/// @param shape2 second shape
/// @return true if the shapes are compatible, false otherwise
static bool check_shape_consistency(const std::vector<std::int64_t>& shape1,
                                    const std::vector<std::int64_t>& shape2) {
  // Different rank can never be consistent.
  if (shape1.size() != shape2.size()) {
    return false;
  }
  const auto dims_agree = [](std::int64_t lhs, std::int64_t rhs) {
    return lhs == -1 || rhs == -1 || lhs == rhs;
  };
  return std::equal(shape1.cbegin(), shape1.cend(), shape2.cbegin(), dims_agree);
}
0042 
0043 template <typename T>
0044 static Ort::Value iters_to_tensor(typename std::vector<T>::const_iterator data_begin,
0045                                   typename std::vector<T>::const_iterator data_end,
0046                                   std::vector<int64_t>::const_iterator shape_begin,
0047                                   std::vector<int64_t>::const_iterator shape_end) {
0048   Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator,
0049                                                         OrtMemType::OrtMemTypeDefault);
0050   auto tensor =
0051       Ort::Value::CreateTensor<T>(mem_info, const_cast<T*>(&*data_begin), data_end - data_begin,
0052                                   &*shape_begin, shape_end - shape_begin);
0053   return tensor;
0054 }
0055 
0056 void ONNXInference::init() {
0057   // onnxruntime setup
0058   m_env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, name().data());
0059   Ort::SessionOptions session_options;
0060   session_options.SetInterOpNumThreads(1);
0061   session_options.SetIntraOpNumThreads(1);
0062   try {
0063     m_session = Ort::Session(m_env, m_cfg.modelPath.c_str(), session_options);
0064     Ort::AllocatorWithDefaultOptions allocator;
0065 
0066     // print name/shape of inputs
0067     debug("Input Node Name/Shape:");
0068     for (std::size_t i = 0; i < m_session.GetInputCount(); i++) {
0069       m_input_names.emplace_back(m_session.GetInputNameAllocated(i, allocator).get());
0070       m_input_shapes.emplace_back(
0071           m_session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
0072       debug("\t{} : {}", m_input_names.at(i), print_shape(m_input_shapes.at(i)));
0073     }
0074 
0075     // print name/shape of outputs
0076     debug("Output Node Name/Shape: {}", m_session.GetOutputCount());
0077     for (std::size_t i = 0; i < m_session.GetOutputCount(); i++) {
0078       m_output_names.emplace_back(m_session.GetOutputNameAllocated(i, allocator).get());
0079 
0080       if (m_session.GetOutputTypeInfo(i).GetONNXType() != ONNX_TYPE_TENSOR) {
0081         m_output_shapes.emplace_back();
0082         debug("\t{} : not a tensor", m_output_names.at(i));
0083       } else {
0084         m_output_shapes.emplace_back(
0085             m_session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
0086         debug("\t{} : {}", m_output_names.at(i), print_shape(m_output_shapes.at(i)));
0087       }
0088     }
0089 
0090     // convert names to char*
0091     m_input_names_char.resize(m_input_names.size(), nullptr);
0092     std::transform(std::begin(m_input_names), std::end(m_input_names),
0093                    std::begin(m_input_names_char),
0094                    [&](const std::string& str) { return str.c_str(); });
0095     m_output_names_char.resize(m_output_names.size(), nullptr);
0096     std::transform(std::begin(m_output_names), std::end(m_output_names),
0097                    std::begin(m_output_names_char),
0098                    [&](const std::string& str) { return str.c_str(); });
0099 
0100   } catch (const Ort::Exception& exception) {
0101     error("ONNX error {}", exception.what());
0102     throw;
0103   }
0104 }
0105 
0106 void ONNXInference::process(const ONNXInference::Input& input,
0107                             const ONNXInference::Output& output) const {
0108 
0109   const auto [in_tensors] = input;
0110   auto [out_tensors]      = output;
0111 
0112   // Require valid inputs
0113   if (in_tensors.size() != m_input_names.size()) {
0114     error("The ONNX model requires {} tensors, whereas {} were provided", m_input_names.size(),
0115           in_tensors.size());
0116     throw std::runtime_error(
0117         fmt::format("The ONNX model requires {} tensors, whereas {} were provided",
0118                     m_input_names.size(), in_tensors.size()));
0119   }
0120 
0121   // Prepare input tensor
0122   std::vector<float> input_tensor_values;
0123   std::vector<Ort::Value> input_tensors;
0124 
0125   for (std::size_t ix = 0; ix < m_input_names.size(); ix++) {
0126     edm4eic::Tensor in_tensor = in_tensors[ix]->at(0);
0127     if (in_tensor.getElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
0128       input_tensors.emplace_back(
0129           iters_to_tensor<float>(in_tensor.floatData_begin(), in_tensor.floatData_end(),
0130                                  in_tensor.shape_begin(), in_tensor.shape_end()));
0131     } else if (in_tensor.getElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
0132       input_tensors.emplace_back(
0133           iters_to_tensor<int64_t>(in_tensor.int64Data_begin(), in_tensor.int64Data_end(),
0134                                    in_tensor.shape_begin(), in_tensor.shape_end()));
0135     }
0136 
0137     auto input_shape = input_tensors[ix].GetTensorTypeAndShapeInfo().GetShape();
0138     std::vector<std::int64_t> input_expected_shape = m_input_shapes[ix];
0139     if (!check_shape_consistency(input_shape, input_expected_shape)) {
0140       error("Input tensor shape incorrect {} != {}", print_shape(input_shape),
0141             print_shape(input_expected_shape));
0142       throw std::runtime_error(fmt::format("Input tensor shape incorrect {} != {}",
0143                                            print_shape(input_shape),
0144                                            print_shape(input_expected_shape)));
0145     }
0146   }
0147 
0148   // Attempt inference
0149   std::vector<Ort::Value> onnx_values;
0150   try {
0151     onnx_values = m_session.Run(Ort::RunOptions{nullptr}, m_input_names_char.data(),
0152                                 input_tensors.data(), m_input_names_char.size(),
0153                                 m_output_names_char.data(), m_output_names_char.size());
0154   } catch (const Ort::Exception& exception) {
0155     error("Error running model inference: {}", exception.what());
0156     throw;
0157   }
0158 
0159   try {
0160     for (std::size_t ix = 0; ix < onnx_values.size(); ix++) {
0161       Ort::Value& onnx_tensor = onnx_values[ix];
0162       if (!onnx_tensor.IsTensor()) {
0163         error("The output \"{}\" is not a tensor. ONNXType {} is not yet supported. Skipping...",
0164               m_output_names_char[ix], static_cast<int>(onnx_tensor.GetTypeInfo().GetONNXType()));
0165         continue;
0166       }
0167       auto onnx_tensor_type             = onnx_tensor.GetTensorTypeAndShapeInfo();
0168       edm4eic::MutableTensor out_tensor = out_tensors[ix]->create();
0169       out_tensor.setElementType(static_cast<int32_t>(onnx_tensor_type.GetElementType()));
0170       std::size_t num_values = 1;
0171       for (int64_t dim_size : onnx_tensor_type.GetShape()) {
0172         out_tensor.addToShape(dim_size);
0173         num_values *= dim_size;
0174       }
0175       if (onnx_tensor_type.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
0176         auto* data = onnx_tensor.GetTensorMutableData<float>();
0177         for (std::size_t value_ix = 0; value_ix < num_values; value_ix++) {
0178           out_tensor.addToFloatData(data[value_ix]);
0179         }
0180       } else if (onnx_tensor_type.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
0181         auto* data = onnx_tensor.GetTensorMutableData<int64_t>();
0182         for (std::size_t value_ix = 0; value_ix < num_values; value_ix++) {
0183           out_tensor.addToInt64Data(data[value_ix]);
0184         }
0185       } else {
0186         error("Unsupported ONNXTensorElementDataType {}",
0187               static_cast<int>(onnx_tensor_type.GetElementType()));
0188       }
0189     }
0190   } catch (const Ort::Exception& exception) {
0191     error("Error running model inference: {}", exception.what());
0192     throw;
0193   }
0194 }
0195 
0196 } // namespace eicrecon
0197 #endif