Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-17 09:21:31

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsPlugins/Gnn/OnnxEdgeClassifier.hpp"
0010 
0011 #include <boost/container/static_vector.hpp>
0012 #include <onnxruntime_cxx_api.h>
0013 
0014 namespace bc = boost::container;
0015 
0016 namespace {
0017 
0018 template <typename T>
0019 Ort::Value toOnnx(Ort::MemoryInfo &memoryInfo, ActsPlugins::Tensor<T> &tensor,
0020                   std::size_t rank = 2) {
0021   assert(rank == 1 || rank == 2);
0022   ONNXTensorElementDataType onnxType = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
0023 
0024   if constexpr (std::is_same_v<T, float>) {
0025     onnxType = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
0026   } else if constexpr (std::is_same_v<T, std::int64_t>) {
0027     onnxType = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
0028   } else {
0029     throw std::runtime_error(
0030         "Cannot convert ActsPlugins::Tensor to Ort::Value (datatype)");
0031   }
0032 
0033   bc::static_vector<std::int64_t, 2> shape;
0034   for (auto size : tensor.shape()) {
0035     // If rank is 1 and we encounter a dimension with size 1, then we skip it
0036     if (size > 1 || rank == 2) {
0037       shape.push_back(size);
0038     }
0039   }
0040 
0041   assert(shape.size() == rank);
0042   return Ort::Value::CreateTensor(memoryInfo, tensor.data(), tensor.nbytes(),
0043                                   shape.data(), shape.size(), onnxType);
0044 }
0045 
0046 }  // namespace
0047 
0048 using namespace Acts;
0049 
0050 namespace ActsPlugins {
0051 
0052 OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg,
0053                                        std::unique_ptr<const Logger> _logger)
0054     : m_logger(std::move(_logger)), m_cfg(cfg) {
0055   ACTS_INFO("OnnxEdgeClassifier with ORT API version " << ORT_API_VERSION);
0056 
0057   OrtLoggingLevel onnxLevel = ORT_LOGGING_LEVEL_WARNING;
0058   switch (m_logger->level()) {
0059     case Logging::VERBOSE:
0060       onnxLevel = ORT_LOGGING_LEVEL_VERBOSE;
0061       break;
0062     case Logging::DEBUG:
0063       onnxLevel = ORT_LOGGING_LEVEL_INFO;
0064       break;
0065     case Logging::INFO:
0066       onnxLevel = ORT_LOGGING_LEVEL_WARNING;
0067       break;
0068     case Logging::WARNING:
0069       onnxLevel = ORT_LOGGING_LEVEL_WARNING;
0070       break;
0071     case Logging::ERROR:
0072       onnxLevel = ORT_LOGGING_LEVEL_ERROR;
0073       break;
0074     case Logging::FATAL:
0075       onnxLevel = ORT_LOGGING_LEVEL_FATAL;
0076       break;
0077     default:
0078       throw std::runtime_error("Invalid log level");
0079   }
0080 
0081   m_env = std::make_unique<Ort::Env>(onnxLevel, "Gnn - edge classifier");
0082 
0083   Ort::SessionOptions sessionOptions;
0084   sessionOptions.SetIntraOpNumThreads(1);
0085   sessionOptions.SetGraphOptimizationLevel(
0086       GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
0087 
0088 #ifndef ACTS_GNN_CPUONLY
0089   ACTS_INFO("Try to add ONNX execution provider for CUDA");
0090   OrtCUDAProviderOptions cuda_options;
0091   cuda_options.device_id = 0;
0092   sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
0093 #endif
0094 
0095   m_model = std::make_unique<Ort::Session>(*m_env, m_cfg.modelPath.c_str(),
0096                                            sessionOptions);
0097 
0098   Ort::AllocatorWithDefaultOptions allocator;
0099 
0100   if (m_model->GetInputCount() < 2 || m_model->GetInputCount() > 3) {
0101     throw std::invalid_argument("ONNX edge classifier needs 2 or 3 inputs!");
0102   }
0103 
0104   for (std::size_t i = 0; i < m_model->GetInputCount(); ++i) {
0105     m_inputNames.emplace_back(
0106         m_model->GetInputNameAllocated(i, allocator).get());
0107   }
0108 
0109   if (m_model->GetOutputCount() != 1) {
0110     throw std::invalid_argument(
0111         "ONNX edge classifier needs exactly one output!");
0112   }
0113 
0114   m_outputName =
0115       std::string(m_model->GetOutputNameAllocated(0, allocator).get());
0116 }
0117 
0118 OnnxEdgeClassifier::~OnnxEdgeClassifier() {}
0119 
0120 PipelineTensors OnnxEdgeClassifier::operator()(
0121     PipelineTensors tensors, const ExecutionContext &execContext) {
0122   const char *deviceStr = "Cpu";
0123   if (execContext.device.type == Device::Type::eCUDA) {
0124     deviceStr = "Cuda";
0125   }
0126 
0127   ACTS_DEBUG("Create ORT memory info (" << deviceStr << ")");
0128   Ort::MemoryInfo memoryInfo(deviceStr, OrtArenaAllocator,
0129                              execContext.device.index, OrtMemTypeDefault);
0130 
0131   bc::static_vector<Ort::Value, 3> inputTensors;
0132   bc::static_vector<const char *, 3> inputNames;
0133 
0134   // Node tensor
0135   inputTensors.push_back(toOnnx(memoryInfo, tensors.nodeFeatures));
0136   inputNames.push_back(m_inputNames.at(0).c_str());
0137 
0138   // Edge tensor
0139   inputTensors.push_back(toOnnx(memoryInfo, tensors.edgeIndex));
0140   inputNames.push_back(m_inputNames.at(1).c_str());
0141 
0142   // Edge feature tensor
0143   std::optional<Tensor<float>> edgeFeatures;
0144   if (m_inputNames.size() == 3 && tensors.edgeFeatures.has_value()) {
0145     inputTensors.push_back(toOnnx(memoryInfo, *tensors.edgeFeatures));
0146     inputNames.push_back(m_inputNames.at(2).c_str());
0147   }
0148 
0149   // Output score tensor
0150   ACTS_DEBUG("Create score tensor");
0151   auto scores =
0152       Tensor<float>::Create({tensors.edgeIndex.shape()[1], 1ul}, execContext);
0153 
0154   std::vector<Ort::Value> outputTensors;
0155   auto outputRank = m_model->GetOutputTypeInfo(0)
0156                         .GetTensorTypeAndShapeInfo()
0157                         .GetDimensionsCount();
0158   outputTensors.push_back(toOnnx(memoryInfo, scores, outputRank));
0159   std::vector<const char *> outputNames{m_outputName.c_str()};
0160 
0161   ACTS_DEBUG("Run model");
0162   Ort::RunOptions options;
0163   m_model->Run(options, inputNames.data(), inputTensors.data(),
0164                inputTensors.size(), outputNames.data(), outputTensors.data(),
0165                outputNames.size());
0166 
0167   sigmoid(scores, execContext.stream);
0168   auto [newScores, newEdgeIndex] =
0169       applyScoreCut(scores, tensors.edgeIndex, m_cfg.cut, execContext.stream);
0170 
0171   ACTS_DEBUG("Finished edge classification, after cut: "
0172              << newEdgeIndex.shape()[1] << " edges.");
0173 
0174   if (newEdgeIndex.shape()[1] == 0) {
0175     throw NoEdgesError{};
0176   }
0177 
0178   return {std::move(tensors.nodeFeatures),
0179           std::move(newEdgeIndex),
0180           {},
0181           std::move(newScores)};
0182 }
0183 
0184 }  // namespace ActsPlugins