Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:15:30

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2022 - 2024 Wouter Deconinck, Tooba Ali, Dmitry Kalinkin
0003 
0004 #include <edm4eic/EDM4eicVersion.h>
0005 
0006 #if EDM4EIC_VERSION_MAJOR >= 8
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <fmt/core.h>
#include <gsl/pointers>
#include <iterator>
#include <onnxruntime_c_api.h>
#include <onnxruntime_cxx_api.h>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
0016 
0017 #include "ONNXInference.h"
0018 
0019 namespace eicrecon {
0020 
/// Format a tensor shape for log/error messages as "d0 x d1 x ... x dN",
/// e.g. {2, -1, 3} -> "2 x -1 x 3".
///
/// An empty shape (a rank-0 / scalar tensor) yields "(scalar)". The previous
/// implementation underflowed `v.size() - 1` and read out of bounds in that case.
///
/// @param v  tensor dimensions (dynamic dimensions appear as -1)
/// @return   human-readable representation of the shape
static std::string print_shape(const std::vector<std::int64_t>& v) {
  if (v.empty()) {
    return "(scalar)";
  }
  std::string out = std::to_string(v.front());
  for (std::size_t i = 1; i < v.size(); ++i) {
    out += " x ";
    out += std::to_string(v[i]);
  }
  return out;
}
0027 
/// Decide whether two tensor shapes are compatible.
///
/// The shapes must have the same rank. At each position the two dimensions must
/// either be equal, or at least one of them must be -1, which acts as a wildcard
/// (ONNX Runtime reports dynamic dimensions as -1).
///
/// @param shape1  first shape
/// @param shape2  second shape
/// @return true if the shapes are consistent, false otherwise
static bool check_shape_consistency(const std::vector<std::int64_t>& shape1, const std::vector<std::int64_t>& shape2) {
  // Different ranks can never match, regardless of wildcards.
  if (shape1.size() != shape2.size()) {
    return false;
  }
  const auto dims_agree = [](std::int64_t lhs, std::int64_t rhs) {
    return lhs == -1 || rhs == -1 || lhs == rhs;
  };
  return std::equal(shape1.begin(), shape1.end(), shape2.begin(), dims_agree);
}
0039 
0040   template <typename T>
0041   static Ort::Value iters_to_tensor(
0042     typename std::vector<T>::const_iterator data_begin,
0043     typename std::vector<T>::const_iterator data_end,
0044     std::vector<int64_t>::const_iterator shape_begin,
0045     std::vector<int64_t>::const_iterator shape_end
0046   ) {
0047     Ort::MemoryInfo mem_info =
0048         Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
0049     auto tensor = Ort::Value::CreateTensor<T>(mem_info, const_cast<T*>(&*data_begin), data_end - data_begin, &*shape_begin, shape_end - shape_begin);
0050     return tensor;
0051   }
0052 
0053   void ONNXInference::init() {
0054     // onnxruntime setup
0055     m_env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, name().data());
0056     Ort::SessionOptions session_options;
0057     session_options.SetInterOpNumThreads(1);
0058     session_options.SetIntraOpNumThreads(1);
0059     try {
0060       m_session = Ort::Session(m_env, m_cfg.modelPath.c_str(), session_options);
0061       Ort::AllocatorWithDefaultOptions allocator;
0062 
0063       // print name/shape of inputs
0064       debug("Input Node Name/Shape:");
0065       for (std::size_t i = 0; i < m_session.GetInputCount(); i++) {
0066         m_input_names.emplace_back(m_session.GetInputNameAllocated(i, allocator).get());
0067         m_input_shapes.emplace_back(m_session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
0068         debug("\t{} : {}", m_input_names.at(i), print_shape(m_input_shapes.at(i)));
0069       }
0070 
0071       // print name/shape of outputs
0072       debug("Output Node Name/Shape: {}", m_session.GetOutputCount());
0073       for (std::size_t i = 0; i < m_session.GetOutputCount(); i++) {
0074         m_output_names.emplace_back(m_session.GetOutputNameAllocated(i, allocator).get());
0075 
0076         if (m_session.GetOutputTypeInfo(i).GetONNXType() != ONNX_TYPE_TENSOR) {
0077           m_output_shapes.emplace_back();
0078           debug("\t{} : not a tensor", m_output_names.at(i));
0079         } else {
0080           m_output_shapes.emplace_back(m_session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
0081           debug("\t{} : {}", m_output_names.at(i), print_shape(m_output_shapes.at(i)));
0082         }
0083       }
0084 
0085       // convert names to char*
0086       m_input_names_char.resize(m_input_names.size(), nullptr);
0087       std::transform(std::begin(m_input_names), std::end(m_input_names), std::begin(m_input_names_char),
0088                      [&](const std::string& str) { return str.c_str(); });
0089       m_output_names_char.resize(m_output_names.size(), nullptr);
0090       std::transform(std::begin(m_output_names), std::end(m_output_names), std::begin(m_output_names_char),
0091                      [&](const std::string& str) { return str.c_str(); });
0092 
0093     } catch(const Ort::Exception& exception) {
0094       error("ONNX error {}", exception.what());
0095       throw;
0096     }
0097   }
0098 
0099   void ONNXInference::process(
0100       const ONNXInference::Input& input,
0101       const ONNXInference::Output& output) const {
0102 
0103     const auto [in_tensors] = input;
0104     auto [out_tensors] = output;
0105 
0106     // Require valid inputs
0107     if (in_tensors.size() != m_input_names.size()) {
0108       error("The ONNX model requires {} tensors, whereas {} were provided", m_input_names.size(), in_tensors.size());
0109       throw std::runtime_error(fmt::format("The ONNX model requires {} tensors, whereas {} were provided", m_input_names.size(), in_tensors.size()));
0110     }
0111 
0112     // Prepare input tensor
0113     std::vector<float> input_tensor_values;
0114     std::vector<Ort::Value> input_tensors;
0115 
0116     for (int ix = 0; ix < m_input_names.size(); ix++) {
0117       edm4eic::Tensor in_tensor = in_tensors[ix]->at(0);
0118       if (in_tensor.getElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
0119         input_tensors.emplace_back(iters_to_tensor<float>(
0120           in_tensor.floatData_begin(),
0121           in_tensor.floatData_end(),
0122           in_tensor.shape_begin(),
0123           in_tensor.shape_end()
0124           ));
0125       } else if (in_tensor.getElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
0126         input_tensors.emplace_back(iters_to_tensor<int64_t>(
0127           in_tensor.int64Data_begin(),
0128           in_tensor.int64Data_end(),
0129           in_tensor.shape_begin(),
0130           in_tensor.shape_end()
0131           ));
0132       }
0133 
0134       auto input_shape = input_tensors[ix].GetTensorTypeAndShapeInfo().GetShape();
0135       std::vector<std::int64_t> input_expected_shape = m_input_shapes[ix];
0136       if (!check_shape_consistency(input_shape, input_expected_shape)) {
0137         error("Input tensor shape incorrect {} != {}", print_shape(input_shape), print_shape(input_expected_shape));
0138         throw std::runtime_error(fmt::format("Input tensor shape incorrect {} != {}", print_shape(input_shape), print_shape(input_expected_shape)));
0139       }
0140     }
0141 
0142     // Attempt inference
0143     std::vector<Ort::Value> onnx_values;
0144     try {
0145       onnx_values = m_session.Run(Ort::RunOptions{nullptr}, m_input_names_char.data(), input_tensors.data(),
0146                                   m_input_names_char.size(), m_output_names_char.data(), m_output_names_char.size());
0147     } catch (const Ort::Exception& exception) {
0148       error("Error running model inference: {}", exception.what());
0149       throw;
0150     }
0151 
0152     try {
0153       for (size_t ix = 0; ix < onnx_values.size(); ix++) {
0154         Ort::Value &onnx_tensor = onnx_values[ix];
0155         if (!onnx_tensor.IsTensor()) {
0156           error("The output \"{}\" is not a tensor. ONNXType {} is not yet supported. Skipping...",
0157                 m_output_names_char[ix],
0158                 static_cast<int>(onnx_tensor.GetTypeInfo().GetONNXType()));
0159           continue;
0160         }
0161         auto onnx_tensor_type = onnx_tensor.GetTensorTypeAndShapeInfo();
0162         edm4eic::MutableTensor out_tensor = out_tensors[ix]->create();
0163         out_tensor.setElementType(static_cast<int32_t>(onnx_tensor_type.GetElementType()));
0164         size_t num_values = 1;
0165         for (int64_t dim_size : onnx_tensor_type.GetShape()) {
0166           out_tensor.addToShape(dim_size);
0167           num_values *= dim_size;
0168         }
0169         if (onnx_tensor_type.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
0170           auto *data = onnx_tensor.GetTensorMutableData<float>();
0171           for (size_t value_ix = 0; value_ix < num_values; value_ix++) {
0172             out_tensor.addToFloatData(data[value_ix]);
0173           }
0174         } else if (onnx_tensor_type.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
0175           auto *data = onnx_tensor.GetTensorMutableData<int64_t>();
0176           for (size_t value_ix = 0; value_ix < num_values; value_ix++) {
0177             out_tensor.addToInt64Data(data[value_ix]);
0178           }
0179         } else {
0180           error("Unsupported ONNXTensorElementDataType {}", static_cast<int>(onnx_tensor_type.GetElementType()));
0181         }
0182       }
0183     } catch (const Ort::Exception& exception) {
0184       error("Error running model inference: {}", exception.what());
0185       throw;
0186     }
0187   }
0188 
0189 } // namespace eicrecon
0190 #endif