File indexing completed on 2025-07-11 07:50:57
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"
0010
0011 #include <boost/container/static_vector.hpp>
0012 #include <onnxruntime_cxx_api.h>
0013
0014 namespace bc = boost::container;
0015
0016 namespace {
0017
0018 template <typename T>
0019 Ort::Value toOnnx(Ort::MemoryInfo &memoryInfo, Acts::Tensor<T> &tensor,
0020 std::size_t rank = 2) {
0021 assert(rank == 1 || rank == 2);
0022 ONNXTensorElementDataType onnxType = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
0023
0024 if constexpr (std::is_same_v<T, float>) {
0025 onnxType = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
0026 } else if constexpr (std::is_same_v<T, std::int64_t>) {
0027 onnxType = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
0028 } else {
0029 throw std::runtime_error(
0030 "Cannot convert Acts::Tensor to Ort::Value (datatype)");
0031 }
0032
0033 bc::static_vector<std::int64_t, 2> shape;
0034 for (auto size : tensor.shape()) {
0035
0036 if (size > 1 || rank == 2) {
0037 shape.push_back(size);
0038 }
0039 }
0040
0041 assert(shape.size() == rank);
0042 return Ort::Value::CreateTensor(memoryInfo, tensor.data(), tensor.nbytes(),
0043 shape.data(), shape.size(), onnxType);
0044 }
0045
0046 }
0047
0048 namespace Acts {
0049
0050 OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg,
0051 std::unique_ptr<const Logger> _logger)
0052 : m_logger(std::move(_logger)), m_cfg(cfg) {
0053 ACTS_INFO("OnnxEdgeClassifier with ORT API version " << ORT_API_VERSION);
0054
0055 OrtLoggingLevel onnxLevel = ORT_LOGGING_LEVEL_WARNING;
0056 switch (m_logger->level()) {
0057 case Acts::Logging::VERBOSE:
0058 onnxLevel = ORT_LOGGING_LEVEL_VERBOSE;
0059 break;
0060 case Acts::Logging::DEBUG:
0061 onnxLevel = ORT_LOGGING_LEVEL_INFO;
0062 break;
0063 case Acts::Logging::INFO:
0064 onnxLevel = ORT_LOGGING_LEVEL_WARNING;
0065 break;
0066 case Acts::Logging::WARNING:
0067 onnxLevel = ORT_LOGGING_LEVEL_WARNING;
0068 break;
0069 case Acts::Logging::ERROR:
0070 onnxLevel = ORT_LOGGING_LEVEL_ERROR;
0071 break;
0072 case Acts::Logging::FATAL:
0073 onnxLevel = ORT_LOGGING_LEVEL_FATAL;
0074 break;
0075 default:
0076 throw std::runtime_error("Invalid log level");
0077 }
0078
0079 m_env = std::make_unique<Ort::Env>(onnxLevel, "ExaTrkX - edge classifier");
0080
0081 Ort::SessionOptions sessionOptions;
0082 sessionOptions.SetIntraOpNumThreads(1);
0083 sessionOptions.SetGraphOptimizationLevel(
0084 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
0085
0086 #ifndef ACTS_EXATRKX_CPUONLY
0087 ACTS_INFO("Try to add ONNX execution provider for CUDA");
0088 OrtCUDAProviderOptions cuda_options;
0089 cuda_options.device_id = 0;
0090 sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
0091 #endif
0092
0093 m_model = std::make_unique<Ort::Session>(*m_env, m_cfg.modelPath.c_str(),
0094 sessionOptions);
0095
0096 Ort::AllocatorWithDefaultOptions allocator;
0097
0098 if (m_model->GetInputCount() < 2 || m_model->GetInputCount() > 3) {
0099 throw std::invalid_argument("ONNX edge classifier needs 2 or 3 inputs!");
0100 }
0101
0102 for (std::size_t i = 0; i < m_model->GetInputCount(); ++i) {
0103 m_inputNames.emplace_back(
0104 m_model->GetInputNameAllocated(i, allocator).get());
0105 }
0106
0107 if (m_model->GetOutputCount() != 1) {
0108 throw std::invalid_argument(
0109 "ONNX edge classifier needs exactly one output!");
0110 }
0111
0112 m_outputName =
0113 std::string(m_model->GetOutputNameAllocated(0, allocator).get());
0114 }
0115
0116 OnnxEdgeClassifier::~OnnxEdgeClassifier() {}
0117
0118 PipelineTensors OnnxEdgeClassifier::operator()(
0119 PipelineTensors tensors, const ExecutionContext &execContext) {
0120 const char *deviceStr = "Cpu";
0121 if (execContext.device.type == Acts::Device::Type::eCUDA) {
0122 deviceStr = "Cuda";
0123 }
0124
0125 ACTS_DEBUG("Create ORT memory info (" << deviceStr << ")");
0126 Ort::MemoryInfo memoryInfo(deviceStr, OrtArenaAllocator,
0127 execContext.device.index, OrtMemTypeDefault);
0128
0129 bc::static_vector<Ort::Value, 3> inputTensors;
0130 bc::static_vector<const char *, 3> inputNames;
0131
0132
0133 inputTensors.push_back(toOnnx(memoryInfo, tensors.nodeFeatures));
0134 inputNames.push_back(m_inputNames.at(0).c_str());
0135
0136
0137 inputTensors.push_back(toOnnx(memoryInfo, tensors.edgeIndex));
0138 inputNames.push_back(m_inputNames.at(1).c_str());
0139
0140
0141 std::optional<Acts::Tensor<float>> edgeFeatures;
0142 if (m_inputNames.size() == 3 && tensors.edgeFeatures.has_value()) {
0143 inputTensors.push_back(toOnnx(memoryInfo, *tensors.edgeFeatures));
0144 inputNames.push_back(m_inputNames.at(2).c_str());
0145 }
0146
0147
0148 ACTS_DEBUG("Create score tensor");
0149 auto scores = Acts::Tensor<float>::Create({tensors.edgeIndex.shape()[1], 1ul},
0150 execContext);
0151
0152 std::vector<Ort::Value> outputTensors;
0153 auto outputRank = m_model->GetOutputTypeInfo(0)
0154 .GetTensorTypeAndShapeInfo()
0155 .GetDimensionsCount();
0156 outputTensors.push_back(toOnnx(memoryInfo, scores, outputRank));
0157 std::vector<const char *> outputNames{m_outputName.c_str()};
0158
0159 ACTS_DEBUG("Run model");
0160 Ort::RunOptions options;
0161 m_model->Run(options, inputNames.data(), inputTensors.data(),
0162 inputTensors.size(), outputNames.data(), outputTensors.data(),
0163 outputNames.size());
0164
0165 sigmoid(scores, execContext.stream);
0166 auto [newScores, newEdgeIndex] =
0167 applyScoreCut(scores, tensors.edgeIndex, m_cfg.cut, execContext.stream);
0168
0169 ACTS_DEBUG("Finished edge classification, after cut: "
0170 << newEdgeIndex.shape()[1] << " edges.");
0171
0172 if (newEdgeIndex.shape()[1] == 0) {
0173 throw Acts::NoEdgesError{};
0174 }
0175
0176 return {std::move(tensors.nodeFeatures),
0177 std::move(newEdgeIndex),
0178 {},
0179 std::move(newScores)};
0180 }
0181
0182 }