// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include <boost/container/static_vector.hpp>
#include <onnxruntime_cxx_api.h>

namespace bc = boost::container;

namespace {

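// Wrap an Acts::Tensor in a non-owning Ort::Value over the same memory,
// mapping the element type to the matching ONNX data type. For rank 1,
// dimensions of size 1 are dropped from the reported shape.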
template <typename T>
Ort::Value toOnnx(Ort::MemoryInfo &memoryInfo, Acts::Tensor<T> &tensor,
                  std::size_t rank = 2) {
  assert(rank == 1 || rank == 2);
  ONNXTensorElementDataType onnxType = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;

  if constexpr (std::is_same_v<T, float>) {
    onnxType = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  } else if constexpr (std::is_same_v<T, std::int64_t>) {
    onnxType = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  } else {
    throw std::runtime_error(
        "Cannot convert Acts::Tensor to Ort::Value (datatype)");
  }

  bc::static_vector<std::int64_t, 2> shape;
  for (auto size : tensor.shape()) {
    // If rank is 1 and we encounter a dimension with size 1, then we skip it
    if (size > 1 || rank == 2) {
      shape.push_back(size);
    }
  }

  assert(shape.size() == rank);
  return Ort::Value::CreateTensor(memoryInfo, tensor.data(), tensor.nbytes(),
                                  shape.data(), shape.size(), onnxType);
}

}  // namespace

namespace Acts {

OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg,
                                       std::unique_ptr<const Logger> _logger)
    : m_logger(std::move(_logger)), m_cfg(cfg) {
  ACTS_INFO("OnnxEdgeClassifier with ORT API version " << ORT_API_VERSION);

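  // Map the Acts logging level onto the closest ONNX Runtime logging level.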
  OrtLoggingLevel onnxLevel = ORT_LOGGING_LEVEL_WARNING;
  switch (m_logger->level()) {
    case Acts::Logging::VERBOSE:
      onnxLevel = ORT_LOGGING_LEVEL_VERBOSE;
      break;
    case Acts::Logging::DEBUG:
      onnxLevel = ORT_LOGGING_LEVEL_INFO;
      break;
    case Acts::Logging::INFO:
      onnxLevel = ORT_LOGGING_LEVEL_WARNING;
      break;
    case Acts::Logging::WARNING:
      onnxLevel = ORT_LOGGING_LEVEL_WARNING;
      break;
    case Acts::Logging::ERROR:
      onnxLevel = ORT_LOGGING_LEVEL_ERROR;
      break;
    case Acts::Logging::FATAL:
      onnxLevel = ORT_LOGGING_LEVEL_FATAL;
      break;
    default:
      throw std::runtime_error("Invalid log level");
  }

  m_env = std::make_unique<Ort::Env>(onnxLevel, "ExaTrkX - edge classifier");

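  // Configure the inference session: a single intra-op thread and extended
  // graph optimizations.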
  Ort::SessionOptions sessionOptions;
  sessionOptions.SetIntraOpNumThreads(1);
  sessionOptions.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

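  // Unless the plugin is built CPU-only, register the CUDA execution provider
  // for device 0.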
#ifndef ACTS_EXATRKX_CPUONLY
  ACTS_INFO("Try to add ONNX execution provider for CUDA");
  OrtCUDAProviderOptions cuda_options;
  cuda_options.device_id = 0;
  sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
#endif

  m_model = std::make_unique<Ort::Session>(*m_env, m_cfg.modelPath.c_str(),
                                           sessionOptions);

  Ort::AllocatorWithDefaultOptions allocator;

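  // Check the model interface: two or three inputs (node features, edge index
  // and optionally edge features) and exactly one score output are expected.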
  if (m_model->GetInputCount() < 2 || m_model->GetInputCount() > 3) {
    throw std::invalid_argument("ONNX edge classifier needs 2 or 3 inputs!");
  }

  for (std::size_t i = 0; i < m_model->GetInputCount(); ++i) {
    m_inputNames.emplace_back(
        m_model->GetInputNameAllocated(i, allocator).get());
  }

  if (m_model->GetOutputCount() != 1) {
    throw std::invalid_argument(
        "ONNX edge classifier needs exactly one output!");
  }

  m_outputName =
      std::string(m_model->GetOutputNameAllocated(0, allocator).get());
}

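// The destructor is defined out of line, presumably so that the
// std::unique_ptr members holding the Ort::Env and Ort::Session are destroyed
// where the ONNX Runtime types are complete.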
OnnxEdgeClassifier::~OnnxEdgeClassifier() {}

PipelineTensors OnnxEdgeClassifier::operator()(
    PipelineTensors tensors, const ExecutionContext &execContext) {
  const char *deviceStr = "Cpu";
  if (execContext.device.type == Acts::Device::Type::eCUDA) {
    deviceStr = "Cuda";
  }

  ACTS_DEBUG("Create ORT memory info (" << deviceStr << ")");
  Ort::MemoryInfo memoryInfo(deviceStr, OrtArenaAllocator,
                             execContext.device.index, OrtMemTypeDefault);

  bc::static_vector<Ort::Value, 3> inputTensors;
  bc::static_vector<const char *, 3> inputNames;

  // Node tensor
  inputTensors.push_back(toOnnx(memoryInfo, tensors.nodeFeatures));
  inputNames.push_back(m_inputNames.at(0).c_str());

  // Edge tensor
  inputTensors.push_back(toOnnx(memoryInfo, tensors.edgeIndex));
  inputNames.push_back(m_inputNames.at(1).c_str());

  // Edge feature tensor
  if (m_inputNames.size() == 3 && tensors.edgeFeatures.has_value()) {
    inputTensors.push_back(toOnnx(memoryInfo, *tensors.edgeFeatures));
    inputNames.push_back(m_inputNames.at(2).c_str());
  }

  // Output score tensor
  ACTS_DEBUG("Create score tensor");
  auto scores = Acts::Tensor<float>::Create({tensors.edgeIndex.shape()[1], 1ul},
                                            execContext);

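  // Bind the pre-allocated score tensor as the single model output. The rank
  // declared by the model is queried so that the wrapped shape matches it,
  // since toOnnx drops size-1 dimensions for rank-1 tensors.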
  std::vector<Ort::Value> outputTensors;
  auto outputRank = m_model->GetOutputTypeInfo(0)
                        .GetTensorTypeAndShapeInfo()
                        .GetDimensionsCount();
  outputTensors.push_back(toOnnx(memoryInfo, scores, outputRank));
  std::vector<const char *> outputNames{m_outputName.c_str()};

  ACTS_DEBUG("Run model");
  Ort::RunOptions options;
  m_model->Run(options, inputNames.data(), inputTensors.data(),
               inputTensors.size(), outputNames.data(), outputTensors.data(),
               outputNames.size());

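  // The model emits raw logits: convert them to probabilities with a sigmoid
  // and discard edges whose score falls below the configured cut.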
  sigmoid(scores, execContext.stream);
  auto [newScores, newEdgeIndex] =
      applyScoreCut(scores, tensors.edgeIndex, m_cfg.cut, execContext.stream);

  ACTS_DEBUG("Finished edge classification, after cut: "
             << newEdgeIndex.shape()[1] << " edges.");

  if (newEdgeIndex.shape()[1] == 0) {
    throw Acts::NoEdgesError{};
  }

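  // Node features are passed through unchanged; edge features are consumed by
  // the classifier and not propagated to the next stage.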
  return {std::move(tensors.nodeFeatures),
          std::move(newEdgeIndex),
          {},
          std::move(newScores)};
}

}  // namespace Acts