File indexing completed on 2025-01-30 09:15:13
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp"
0010
0011 #include "Acts/Plugins/ExaTrkX/detail/buildEdges.hpp"
0012
0013 #include <onnxruntime_cxx_api.h>
0014 #include <torch/script.h>
0015
0016 #include "runSessionWithIoBinding.hpp"
0017
0018 namespace Acts {
0019
0020 OnnxMetricLearning::OnnxMetricLearning(const Config& cfg,
0021 std::unique_ptr<const Logger> logger)
0022 : m_logger(std::move(logger)),
0023 m_cfg(cfg),
0024 m_device(torch::Device(torch::kCPU)) {
0025 m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
0026 "ExaTrkX - metric learning");
0027
0028 Ort::SessionOptions session_options;
0029 session_options.SetIntraOpNumThreads(1);
0030 session_options.SetGraphOptimizationLevel(
0031 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
0032
0033 m_model = std::make_unique<Ort::Session>(*m_env, m_cfg.modelPath.c_str(),
0034 session_options);
0035 }
0036
0037 OnnxMetricLearning::~OnnxMetricLearning() {}
0038
0039 void OnnxMetricLearning::buildEdgesWrapper(std::vector<float>& embedFeatures,
0040 std::vector<std::int64_t>& edgeList,
0041 std::int64_t numSpacepoints,
0042 const Logger& logger) const {
0043 torch::Device device(torch::kCUDA);
0044 auto options =
0045 torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
0046
0047 torch::Tensor embedTensor =
0048 torch::tensor(embedFeatures, options)
0049 .reshape({numSpacepoints, m_cfg.embeddingDim});
0050
0051 auto stackedEdges = detail::buildEdges(embedTensor, m_cfg.rVal, m_cfg.knnVal);
0052
0053 stackedEdges = stackedEdges.toType(torch::kInt64).to(torch::kCPU);
0054
0055 ACTS_VERBOSE("copy edges to std::vector");
0056 std::copy(stackedEdges.data_ptr<std::int64_t>(),
0057 stackedEdges.data_ptr<std::int64_t>() + stackedEdges.numel(),
0058 std::back_inserter(edgeList));
0059 }
0060
0061 std::tuple<std::any, std::any> OnnxMetricLearning::operator()(
0062 std::vector<float>& inputValues, std::size_t, torch::Device) {
0063 Ort::AllocatorWithDefaultOptions allocator;
0064 auto memoryInfo = Ort::MemoryInfo::CreateCpu(
0065 OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
0066
0067
0068
0069
0070
0071 std::int64_t numSpacepoints = inputValues.size() / m_cfg.spacepointFeatures;
0072 std::vector<std::int64_t> eInputShape{numSpacepoints,
0073 m_cfg.spacepointFeatures};
0074
0075 std::vector<const char*> eInputNames{"sp_features"};
0076 std::vector<Ort::Value> eInputTensor;
0077 eInputTensor.push_back(Ort::Value::CreateTensor<float>(
0078 memoryInfo, inputValues.data(), inputValues.size(), eInputShape.data(),
0079 eInputShape.size()));
0080
0081 std::vector<float> eOutputData(numSpacepoints * m_cfg.embeddingDim);
0082 std::vector<const char*> eOutputNames{"embedding_output"};
0083 std::vector<std::int64_t> eOutputShape{numSpacepoints, m_cfg.embeddingDim};
0084 std::vector<Ort::Value> eOutputTensor;
0085 eOutputTensor.push_back(Ort::Value::CreateTensor<float>(
0086 memoryInfo, eOutputData.data(), eOutputData.size(), eOutputShape.data(),
0087 eOutputShape.size()));
0088 runSessionWithIoBinding(*m_model, eInputNames, eInputTensor, eOutputNames,
0089 eOutputTensor);
0090
0091 ACTS_VERBOSE("Embedding space of the first SP: ");
0092 for (std::size_t i = 0; i < 3; i++) {
0093 ACTS_VERBOSE("\t" << eOutputData[i]);
0094 }
0095
0096
0097
0098
0099 std::vector<std::int64_t> edgeList;
0100 buildEdgesWrapper(eOutputData, edgeList, numSpacepoints, logger());
0101 std::int64_t numEdges = edgeList.size() / 2;
0102 ACTS_DEBUG("Graph construction: built " << numEdges << " edges.");
0103
0104 for (std::size_t i = 0; i < 10; i++) {
0105 ACTS_VERBOSE(edgeList[i]);
0106 }
0107 for (std::size_t i = 0; i < 10; i++) {
0108 ACTS_VERBOSE(edgeList[numEdges + i]);
0109 }
0110
0111 return {std::make_shared<Ort::Value>(std::move(eInputTensor[0])), edgeList};
0112 }
0113
0114 }