File indexing completed on 2024-06-18 07:05:42
0001
0002
0003
#include <fmt/core.h>
#include <onnxruntime_c_api.h>
#include <onnxruntime_cxx_api.h>
#include <gsl/pointers>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <iterator>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>
0013
0014 #include "InclusiveKinematicsML.h"
0015
0016 namespace eicrecon {
0017
/// Render a tensor shape as a "AxBxC"-style string for log output.
///
/// @param v  dimension sizes as reported by ONNX Runtime (may contain -1
///           for dynamic dimensions; they are printed verbatim)
/// @return   the dimensions joined with 'x'; empty string for an empty shape
static std::string print_shape(const std::vector<std::int64_t>& v) {
  // Guard the empty case explicitly: the previous `v.size() - 1` loop bound
  // underflowed std::size_t for an empty vector and then read v[v.size()-1]
  // out of bounds.
  if (v.empty()) {
    return "";
  }
  std::stringstream ss;
  // Print all but the last dimension with a trailing 'x' separator.
  for (std::size_t i = 0; i + 1 < v.size(); i++) {
    ss << v[i] << "x";
  }
  ss << v.back();
  return ss.str();
}
0024
0025 template <typename T>
0026 Ort::Value vec_to_tensor(std::vector<T>& data, const std::vector<std::int64_t>& shape) {
0027 Ort::MemoryInfo mem_info =
0028 Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
0029 auto tensor = Ort::Value::CreateTensor<T>(mem_info, data.data(), data.size(), shape.data(), shape.size());
0030 return tensor;
0031 }
0032
0033 void InclusiveKinematicsML::init() {
0034
0035 Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "inclusive-kinematics-ml");
0036 Ort::SessionOptions session_options;
0037 try {
0038 m_session = Ort::Session(env, m_cfg.modelPath.c_str(), session_options);
0039
0040
0041 Ort::AllocatorWithDefaultOptions allocator;
0042 debug("Input Node Name/Shape:");
0043 for (std::size_t i = 0; i < m_session.GetInputCount(); i++) {
0044 m_input_names.emplace_back(m_session.GetInputNameAllocated(i, allocator).get());
0045 m_input_shapes.emplace_back(m_session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
0046 debug("\t{} : {}", m_input_names.at(i), print_shape(m_input_shapes.at(i)));
0047 }
0048
0049
0050 debug("Output Node Name/Shape:");
0051 for (std::size_t i = 0; i < m_session.GetOutputCount(); i++) {
0052 m_output_names.emplace_back(m_session.GetOutputNameAllocated(i, allocator).get());
0053 m_output_shapes.emplace_back(m_session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
0054 debug("\t{} : {}", m_output_names.at(i), print_shape(m_output_shapes.at(i)));
0055 }
0056
0057
0058 m_input_names_char.resize(m_input_names.size(), nullptr);
0059 std::transform(std::begin(m_input_names), std::end(m_input_names), std::begin(m_input_names_char),
0060 [&](const std::string& str) { return str.c_str(); });
0061 m_output_names_char.resize(m_output_names.size(), nullptr);
0062 std::transform(std::begin(m_output_names), std::end(m_output_names), std::begin(m_output_names_char),
0063 [&](const std::string& str) { return str.c_str(); });
0064
0065 } catch(std::exception& e) {
0066 error(e.what());
0067 }
0068 }
0069
0070 void InclusiveKinematicsML::process(
0071 const InclusiveKinematicsML::Input& input,
0072 const InclusiveKinematicsML::Output& output) const {
0073
0074 const auto [electron, da] = input;
0075 auto [ml] = output;
0076
0077
0078 if (electron->size() == 0 || da->size() == 0) {
0079 debug("skipping because input collections have no entries");
0080 return;
0081 }
0082
0083
0084 if (m_input_names.size() != 1 || m_output_names.size() != 1) {
0085 debug("skipping because model has incorrect input and output size");
0086 return;
0087 }
0088
0089
0090 std::vector<float> input_tensor_values;
0091 std::vector<Ort::Value> input_tensors;
0092 for (std::size_t i = 0; i < electron->size(); i++) {
0093 input_tensor_values.push_back(electron->at(i).getX());
0094 }
0095 input_tensors.emplace_back(vec_to_tensor<float>(input_tensor_values, m_input_shapes.front()));
0096
0097
0098 if (! input_tensors[0].IsTensor() || input_tensors[0].GetTensorTypeAndShapeInfo().GetShape() != m_input_shapes.front()) {
0099 debug("skipping because input tensor shape incorrect");
0100 return;
0101 }
0102
0103
0104 try {
0105 auto output_tensors = m_session.Run(Ort::RunOptions{nullptr}, m_input_names_char.data(), input_tensors.data(),
0106 m_input_names_char.size(), m_output_names_char.data(), m_output_names_char.size());
0107
0108
0109 if (!output_tensors[0].IsTensor() || output_tensors.size() != m_output_names.size()) {
0110 debug("skipping because output tensor shape incorrect");
0111 return;
0112 }
0113
0114
0115 float* output_tensor_data = output_tensors[0].GetTensorMutableData<float>();
0116 auto x = output_tensor_data[0];
0117 auto kin = ml->create();
0118 kin.setX(x);
0119
0120 } catch (const Ort::Exception& exception) {
0121 error("error running model inference: {}", exception.what());
0122 }
0123 }
0124
0125 }