File indexing completed on 2025-01-30 09:15:13
0001
0002
0003
0004
0005
0006
0007
0008
#include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"

#include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"

#include <stdexcept>
#include <string>

#include <onnxruntime_cxx_api.h>
#include <torch/script.h>
0015
0016 using namespace torch::indexing;
0017
0018 namespace Acts {
0019
0020 OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg,
0021 std::unique_ptr<const Logger> logger)
0022 : m_logger(std::move(logger)),
0023 m_cfg(cfg),
0024 m_device(torch::Device(torch::kCPU)) {
0025 m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
0026 "ExaTrkX - edge classifier");
0027
0028 Ort::SessionOptions session_options;
0029 session_options.SetIntraOpNumThreads(1);
0030
0031
0032
0033 OrtCUDAProviderOptions cuda_options;
0034 cuda_options.device_id = 0;
0035
0036
0037 m_model = std::make_unique<Ort::Session>(*m_env, m_cfg.modelPath.c_str(),
0038 session_options);
0039
0040 Ort::AllocatorWithDefaultOptions allocator;
0041
0042 for (std::size_t i = 0; i < m_model->GetInputCount(); ++i) {
0043 m_inputNames.emplace_back(
0044 m_model->GetInputNameAllocated(i, allocator).get());
0045 }
0046 m_outputName =
0047 std::string(m_model->GetOutputNameAllocated(0, allocator).get());
0048 }
0049
0050 OnnxEdgeClassifier::~OnnxEdgeClassifier() {}
0051
0052 template <typename T>
0053 auto torchToOnnx(Ort::MemoryInfo &memInfo, at::Tensor &tensor) {
0054 std::vector<std::int64_t> shape{tensor.size(0), tensor.size(1)};
0055 return Ort::Value::CreateTensor<T>(memInfo, tensor.data_ptr<T>(),
0056 tensor.numel(), shape.data(),
0057 shape.size());
0058 }
0059
0060 std::ostream &operator<<(std::ostream &os, Ort::Value &v) {
0061 if (!v.IsTensor()) {
0062 os << "no tensor";
0063 return os;
0064 }
0065
0066 auto shape = v.GetTensorTypeAndShapeInfo().GetShape();
0067
0068 auto printVal = [&]<typename T>() {
0069 for (int i = 0; i < shape.at(0); ++i) {
0070 for (int j = 0; j < shape.at(1); ++j) {
0071 os << v.At<T>({i, j}) << " ";
0072 }
0073 os << "\n";
0074 }
0075 };
0076
0077 auto type = v.GetTensorTypeAndShapeInfo().GetElementType();
0078 if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
0079 os << "[float tensor]\n";
0080 printVal.operator()<float>();
0081 } else if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
0082 os << "[int64 tensor]\n";
0083 printVal.operator()<std::int64_t>();
0084 } else {
0085 os << "not implemented datatype";
0086 }
0087
0088 return os;
0089 }
0090
0091 std::tuple<std::any, std::any, std::any, std::any>
0092 OnnxEdgeClassifier::operator()(std::any inputNodes, std::any inputEdges,
0093 std::any inEdgeFeatures,
0094 const ExecutionContext & ) {
0095 auto torchDevice = torch::kCPU;
0096 Ort::MemoryInfo memoryInfo("Cpu", OrtArenaAllocator, 0,
0097 OrtMemTypeDefault);
0098
0099 Ort::Allocator allocator(*m_model, memoryInfo);
0100
0101 auto nodeTensor =
0102 std::any_cast<torch::Tensor>(inputNodes).to(torchDevice).clone();
0103 auto edgeList = std::any_cast<torch::Tensor>(inputEdges).to(torchDevice);
0104 const int numEdges = edgeList.size(1);
0105
0106 std::vector<const char *> inputNames{m_inputNames.at(0).c_str(),
0107 m_inputNames.at(1).c_str()};
0108
0109
0110 auto edgeListClone = edgeList.clone().contiguous();
0111 ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeListClone});
0112 auto nodeTensorClone = nodeTensor.clone();
0113 ACTS_DEBUG("nodes: " << detail::TensorDetails{nodeTensorClone});
0114 std::vector<Ort::Value> inputTensors;
0115 inputTensors.push_back(torchToOnnx<float>(memoryInfo, nodeTensorClone));
0116 inputTensors.push_back(torchToOnnx<std::int64_t>(memoryInfo, edgeListClone));
0117
0118 std::optional<at::Tensor> edgeAttrTensor;
0119 if (inEdgeFeatures.has_value()) {
0120 inputNames.push_back(m_inputNames.at(2).c_str());
0121 edgeAttrTensor =
0122 std::any_cast<torch::Tensor>(inEdgeFeatures).to(torchDevice).clone();
0123 inputTensors.push_back(torchToOnnx<float>(memoryInfo, *edgeAttrTensor));
0124 }
0125
0126 std::vector<const char *> outputNames{m_outputName.c_str()};
0127
0128 auto outputTensor =
0129 m_model->Run({}, inputNames.data(), inputTensors.data(),
0130 inputTensors.size(), outputNames.data(), outputNames.size());
0131
0132 float *rawOutData = nullptr;
0133 if (outputTensor.at(0).GetTensorTypeAndShapeInfo().GetElementType() ==
0134 ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
0135 rawOutData = outputTensor.at(0).GetTensorMutableData<float>();
0136 } else {
0137 throw std::runtime_error("Invalid output datatype");
0138 }
0139
0140 ACTS_DEBUG("Get scores for " << numEdges << " edges.");
0141 auto scores =
0142 torch::from_blob(
0143 rawOutData, {numEdges},
0144 torch::TensorOptions().device(torchDevice).dtype(torch::kFloat32))
0145 .clone();
0146
0147 ACTS_VERBOSE("Slice of classified output before sigmoid:\n"
0148 << scores.slice(0, 0, 9));
0149
0150 scores.sigmoid_();
0151
0152 ACTS_DEBUG("scores: " << detail::TensorDetails{scores});
0153 ACTS_VERBOSE("Slice of classified output:\n"
0154 << scores.slice(0, 0, 9));
0155
0156 torch::Tensor filterMask = scores > m_cfg.cut;
0157 torch::Tensor edgesAfterCut = edgeList.index({Slice(), filterMask});
0158
0159 ACTS_DEBUG("Finished edge classification, after cut: "
0160 << edgesAfterCut.size(1) << " edges.");
0161
0162 if (edgesAfterCut.size(1) == 0) {
0163 throw Acts::NoEdgesError{};
0164 }
0165
0166 return {std::move(nodeTensor), edgesAfterCut.clone(),
0167 std::move(inEdgeFeatures), std::move(scores)};
0168 }
0169
0170 }