EIC code displayed by LXR

// SPDX-License-Identifier: LGPL-3.0-or-later
// Copyright (C) 2022, 2023 Wouter Deconinck, Tooba Ali

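// Reconstruction of inclusive kinematics with a machine-learning model: an
// ONNX model is evaluated with ONNX Runtime to predict the kinematic variable
// x from previously reconstructed inclusive kinematics.
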
#include <fmt/core.h>
#include <onnxruntime_c_api.h>
#include <onnxruntime_cxx_api.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <gsl/pointers>
#include <iterator>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

#include "InclusiveKinematicsML.h"

namespace eicrecon {

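// Format a tensor shape such as {1, 4} as a compact string like "1x4" for
// debug printing.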
static std::string print_shape(const std::vector<std::int64_t>& v) {
  if (v.empty()) {
    return "";
  }
  std::stringstream ss;
  for (std::size_t i = 0; i < v.size() - 1; i++) {
    ss << v[i] << "x";
  }
  ss << v[v.size() - 1];
  return ss.str();
}

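// Wrap an existing std::vector in an Ort::Value tensor of the given shape. The
// tensor refers to the vector's buffer rather than copying it, so the vector
// must remain alive for as long as the tensor is used.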
template <typename T>
Ort::Value vec_to_tensor(std::vector<T>& data, const std::vector<std::int64_t>& shape) {
  Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator,
                                                        OrtMemType::OrtMemTypeDefault);
  auto tensor =
      Ort::Value::CreateTensor<T>(mem_info, data.data(), data.size(), shape.data(), shape.size());
  return tensor;
}

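// Create the ONNX Runtime environment and session from the configured model
// path, and cache the names and shapes of the model's input and output nodes.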
void InclusiveKinematicsML::init() {
  // ONNX Runtime setup
  m_env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "inclusive-kinematics-ml");
  Ort::SessionOptions session_options;
  session_options.SetInterOpNumThreads(1);
  session_options.SetIntraOpNumThreads(1);
  try {
    m_session = Ort::Session(m_env, m_cfg.modelPath.c_str(), session_options);

    // print name/shape of inputs
    Ort::AllocatorWithDefaultOptions allocator;
    debug("Input Node Name/Shape:");
    for (std::size_t i = 0; i < m_session.GetInputCount(); i++) {
      m_input_names.emplace_back(m_session.GetInputNameAllocated(i, allocator).get());
      if (m_session.GetInputTypeInfo(i).GetONNXType() == ONNX_TYPE_TENSOR) {
        m_input_shapes.emplace_back(
            m_session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
        debug("\t{} : {}", m_input_names.at(i), print_shape(m_input_shapes.at(i)));
      } else {
        m_input_shapes.emplace_back();
        debug("\t{} : not a tensor", m_input_names.at(i));
      }
    }

    // print name/shape of outputs
    debug("Output Node Name/Shape:");
    for (std::size_t i = 0; i < m_session.GetOutputCount(); i++) {
      m_output_names.emplace_back(m_session.GetOutputNameAllocated(i, allocator).get());
      if (m_session.GetOutputTypeInfo(i).GetONNXType() == ONNX_TYPE_TENSOR) {
        m_output_shapes.emplace_back(
            m_session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
        debug("\t{} : {}", m_output_names.at(i), print_shape(m_output_shapes.at(i)));
      } else {
        m_output_shapes.emplace_back();
        debug("\t{} : not a tensor", m_output_names.at(i));
      }
    }

    // convert node names to const char* for Ort::Session::Run()
    m_input_names_char.resize(m_input_names.size(), nullptr);
    std::transform(std::begin(m_input_names), std::end(m_input_names),
                   std::begin(m_input_names_char),
                   [&](const std::string& str) { return str.c_str(); });
    m_output_names_char.resize(m_output_names.size(), nullptr);
    std::transform(std::begin(m_output_names), std::end(m_output_names),
                   std::begin(m_output_names_char),
                   [&](const std::string& str) { return str.c_str(); });

  } catch (const std::exception& e) {
    error(e.what());
  }
}

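// Fill the model input tensor with the x values of the electron-method
// kinematics, run inference, and store the predicted x in the output
// collection.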
void InclusiveKinematicsML::process(const InclusiveKinematicsML::Input& input,
                                    const InclusiveKinematicsML::Output& output) const {

  const auto [electron, da] = input;
  auto [ml]                 = output;

  // Require valid inputs
  if (electron->empty() || da->empty()) {
    debug("skipping because input collections have no entries");
    return;
  }

  // Assume the model has exactly one input node and one output node
  if (m_input_names.size() != 1 || m_output_names.size() != 1) {
    debug("skipping because model does not have exactly one input and one output node");
    return;
  }

  // Prepare input tensor
  std::vector<float> input_tensor_values;
  std::vector<Ort::Value> input_tensors;
  for (auto&& i : *electron) {
    input_tensor_values.push_back(i.getX());
  }
  input_tensors.emplace_back(vec_to_tensor<float>(input_tensor_values, m_input_shapes.front()));

  // Double-check the dimensions of the input tensor
  if (!input_tensors[0].IsTensor() ||
      input_tensors[0].GetTensorTypeAndShapeInfo().GetShape() != m_input_shapes.front()) {
    debug("skipping because input tensor shape incorrect");
    return;
  }

  // Attempt inference
  try {
    auto output_tensors = m_session.Run(Ort::RunOptions{nullptr}, m_input_names_char.data(),
                                        input_tensors.data(), m_input_names_char.size(),
                                        m_output_names_char.data(), m_output_names_char.size());

    // Double-check the number and type of the output tensors
    if (output_tensors.size() != m_output_names.size() || !output_tensors[0].IsTensor()) {
      debug("skipping because output tensors are invalid");
      return;
    }

    // Convert output tensor
    auto* output_tensor_data = output_tensors[0].GetTensorMutableData<float>();
    auto x   = output_tensor_data[0]; // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic)
    auto kin = ml->create();
    kin.setX(x);

  } catch (const Ort::Exception& exception) {
    error("error running model inference: {}", exception.what());
  }
}

} // namespace eicrecon