#include "ActsPlugins/Gnn/OnnxEdgeClassifier.hpp"

#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <type_traits>

#include <boost/container/static_vector.hpp>
#include <onnxruntime_cxx_api.h>

namespace bc = boost::container;

namespace {
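
// Wraps an ActsPlugins::Tensor in an Ort::Value without copying: the returned
// value only references the tensor's memory, which therefore must outlive it.
// With rank == 1, size-1 dimensions are squeezed away, so e.g. a (N, 1)
// tensor is presented to ONNX Runtime as (N,).
//
// A minimal usage sketch (the tensor shown here is hypothetical):
//
//   Ort::MemoryInfo info("Cpu", OrtArenaAllocator, 0, OrtMemTypeDefault);
//   auto t = ActsPlugins::Tensor<float>::Create({5ul, 3ul}, execContext);
//   Ort::Value v = toOnnx(info, t);  // rank-2 view of the same memory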
template <typename T>
Ort::Value toOnnx(Ort::MemoryInfo &memoryInfo, ActsPlugins::Tensor<T> &tensor,
                  std::size_t rank = 2) {
  assert(rank == 1 || rank == 2);
  ONNXTensorElementDataType onnxType = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;

  if constexpr (std::is_same_v<T, float>) {
    onnxType = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  } else if constexpr (std::is_same_v<T, std::int64_t>) {
    onnxType = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  } else {
    throw std::runtime_error(
        "Cannot convert ActsPlugins::Tensor to Ort::Value (datatype)");
  }

  // Keep all dimensions for rank 2; for rank 1, squeeze out size-1
  // dimensions so that a (N, 1) tensor is exposed as (N,).
  bc::static_vector<std::int64_t, 2> shape;
  for (auto size : tensor.shape()) {
    if (size > 1 || rank == 2) {
      shape.push_back(size);
    }
  }

  assert(shape.size() == rank);
  return Ort::Value::CreateTensor(memoryInfo, tensor.data(), tensor.nbytes(),
                                  shape.data(), shape.size(), onnxType);
}

}  // namespace

using namespace Acts;

namespace ActsPlugins {

OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg,
                                       std::unique_ptr<const Logger> _logger)
    : m_logger(std::move(_logger)), m_cfg(cfg) {
  ACTS_INFO("OnnxEdgeClassifier with ORT API version " << ORT_API_VERSION);

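  // Map the ACTS log level onto the closest ONNX Runtime level. ORT has one
  // level fewer than ACTS, so the mapping shifts by one step: DEBUG becomes
  // ORT INFO, and both INFO and WARNING become ORT WARNING.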
  OrtLoggingLevel onnxLevel = ORT_LOGGING_LEVEL_WARNING;
  switch (m_logger->level()) {
    case Logging::VERBOSE:
      onnxLevel = ORT_LOGGING_LEVEL_VERBOSE;
      break;
    case Logging::DEBUG:
      onnxLevel = ORT_LOGGING_LEVEL_INFO;
      break;
    case Logging::INFO:
      onnxLevel = ORT_LOGGING_LEVEL_WARNING;
      break;
    case Logging::WARNING:
      onnxLevel = ORT_LOGGING_LEVEL_WARNING;
      break;
    case Logging::ERROR:
      onnxLevel = ORT_LOGGING_LEVEL_ERROR;
      break;
    case Logging::FATAL:
      onnxLevel = ORT_LOGGING_LEVEL_FATAL;
      break;
    default:
      throw std::runtime_error("Invalid log level");
  }

  m_env = std::make_unique<Ort::Env>(onnxLevel, "Gnn - edge classifier");

  Ort::SessionOptions sessionOptions;
  sessionOptions.SetIntraOpNumThreads(1);
  sessionOptions.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

#ifndef ACTS_GNN_CPUONLY
  ACTS_INFO("Try to add ONNX execution provider for CUDA");
  OrtCUDAProviderOptions cuda_options;
  cuda_options.device_id = 0;
  sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
#endif

  m_model = std::make_unique<Ort::Session>(*m_env, m_cfg.modelPath.c_str(),
                                           sessionOptions);

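  // Validate the model signature: node features and edge index, plus an
  // optional third edge-feature input, producing a single score output.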
  Ort::AllocatorWithDefaultOptions allocator;

  if (m_model->GetInputCount() < 2 || m_model->GetInputCount() > 3) {
    throw std::invalid_argument("ONNX edge classifier needs 2 or 3 inputs!");
  }

  for (std::size_t i = 0; i < m_model->GetInputCount(); ++i) {
    m_inputNames.emplace_back(
        m_model->GetInputNameAllocated(i, allocator).get());
  }

  if (m_model->GetOutputCount() != 1) {
    throw std::invalid_argument(
        "ONNX edge classifier needs exactly one output!");
  }

  m_outputName =
      std::string(m_model->GetOutputNameAllocated(0, allocator).get());
}
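
// A minimal construction sketch (the model path and cut value below are
// illustrative, not defaults):
//
//   OnnxEdgeClassifier::Config cfg;
//   cfg.modelPath = "edge_classifier.onnx";
//   cfg.cut = 0.5f;
//   OnnxEdgeClassifier classifier(
//       cfg, getDefaultLogger("OnnxEdgeClassifier", Logging::INFO));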

OnnxEdgeClassifier::~OnnxEdgeClassifier() = default;

PipelineTensors OnnxEdgeClassifier::operator()(
    PipelineTensors tensors, const ExecutionContext &execContext) {
  const char *deviceStr = "Cpu";
  if (execContext.device.type == Device::Type::eCUDA) {
    deviceStr = "Cuda";
  }

  ACTS_DEBUG("Create ORT memory info (" << deviceStr << ")");
  Ort::MemoryInfo memoryInfo(deviceStr, OrtArenaAllocator,
                             execContext.device.index, OrtMemTypeDefault);

  bc::static_vector<Ort::Value, 3> inputTensors;
  bc::static_vector<const char *, 3> inputNames;

  // Node features
  inputTensors.push_back(toOnnx(memoryInfo, tensors.nodeFeatures));
  inputNames.push_back(m_inputNames.at(0).c_str());

  // Edge index
  inputTensors.push_back(toOnnx(memoryInfo, tensors.edgeIndex));
  inputNames.push_back(m_inputNames.at(1).c_str());

  // Optional edge features (only for three-input models)
  if (m_inputNames.size() == 3 && tensors.edgeFeatures.has_value()) {
    inputTensors.push_back(toOnnx(memoryInfo, *tensors.edgeFeatures));
    inputNames.push_back(m_inputNames.at(2).c_str());
  }

  ACTS_DEBUG("Create score tensor");
  auto scores =
      Tensor<float>::Create({tensors.edgeIndex.shape()[1], 1ul}, execContext);

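  // The model may declare its score output as rank 1 (N,) or rank 2 (N, 1);
  // query the declared rank so that toOnnx squeezes the score tensor to the
  // shape ONNX Runtime expects.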
  std::vector<Ort::Value> outputTensors;
  auto outputRank = m_model->GetOutputTypeInfo(0)
                        .GetTensorTypeAndShapeInfo()
                        .GetDimensionsCount();
  outputTensors.push_back(toOnnx(memoryInfo, scores, outputRank));
  std::vector<const char *> outputNames{m_outputName.c_str()};

  ACTS_DEBUG("Run model");
  Ort::RunOptions options;
  m_model->Run(options, inputNames.data(), inputTensors.data(),
               inputTensors.size(), outputNames.data(), outputTensors.data(),
               outputNames.size());

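  // The model outputs raw logits: apply a sigmoid to obtain probabilities,
  // then drop all edges whose score falls below the configured cut.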
  sigmoid(scores, execContext.stream);
  auto [newScores, newEdgeIndex] =
      applyScoreCut(scores, tensors.edgeIndex, m_cfg.cut, execContext.stream);

  ACTS_DEBUG("Finished edge classification, after cut: "
             << newEdgeIndex.shape()[1] << " edges.");

  if (newEdgeIndex.shape()[1] == 0) {
    throw NoEdgesError{};
  }

  return {std::move(tensors.nodeFeatures),
          std::move(newEdgeIndex),
          {},
          std::move(newScores)};
}
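
// A rough invocation sketch (setup of the pipeline tensors is schematic):
//
//   // tensors holds node features and a (2, nEdges) edge index; edge
//   // features are optional and only consumed by three-input models.
//   auto out = classifier(std::move(tensors), execContext);
//   // out now carries only the edges and scores that passed cfg.cut;
//   // NoEdgesError is thrown if nothing survives.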

}  // namespace ActsPlugins