Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 09:15:13

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"
0010 
0011 #include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"
0012 
0013 #include <onnxruntime_cxx_api.h>
0014 #include <torch/script.h>
0015 
0016 using namespace torch::indexing;
0017 
0018 namespace Acts {
0019 
0020 OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg,
0021                                        std::unique_ptr<const Logger> logger)
0022     : m_logger(std::move(logger)),
0023       m_cfg(cfg),
0024       m_device(torch::Device(torch::kCPU)) {
0025   m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
0026                                      "ExaTrkX - edge classifier");
0027 
0028   Ort::SessionOptions session_options;
0029   session_options.SetIntraOpNumThreads(1);
0030   // session_options.SetGraphOptimizationLevel(
0031   //     GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
0032 
0033   OrtCUDAProviderOptions cuda_options;
0034   cuda_options.device_id = 0;
0035   // session_options.AppendExecutionProvider_CUDA(cuda_options);
0036 
0037   m_model = std::make_unique<Ort::Session>(*m_env, m_cfg.modelPath.c_str(),
0038                                            session_options);
0039 
0040   Ort::AllocatorWithDefaultOptions allocator;
0041 
0042   for (std::size_t i = 0; i < m_model->GetInputCount(); ++i) {
0043     m_inputNames.emplace_back(
0044         m_model->GetInputNameAllocated(i, allocator).get());
0045   }
0046   m_outputName =
0047       std::string(m_model->GetOutputNameAllocated(0, allocator).get());
0048 }
0049 
0050 OnnxEdgeClassifier::~OnnxEdgeClassifier() {}
0051 
0052 template <typename T>
0053 auto torchToOnnx(Ort::MemoryInfo &memInfo, at::Tensor &tensor) {
0054   std::vector<std::int64_t> shape{tensor.size(0), tensor.size(1)};
0055   return Ort::Value::CreateTensor<T>(memInfo, tensor.data_ptr<T>(),
0056                                      tensor.numel(), shape.data(),
0057                                      shape.size());
0058 }
0059 
0060 std::ostream &operator<<(std::ostream &os, Ort::Value &v) {
0061   if (!v.IsTensor()) {
0062     os << "no tensor";
0063     return os;
0064   }
0065 
0066   auto shape = v.GetTensorTypeAndShapeInfo().GetShape();
0067 
0068   auto printVal = [&]<typename T>() {
0069     for (int i = 0; i < shape.at(0); ++i) {
0070       for (int j = 0; j < shape.at(1); ++j) {
0071         os << v.At<T>({i, j}) << " ";
0072       }
0073       os << "\n";
0074     }
0075   };
0076 
0077   auto type = v.GetTensorTypeAndShapeInfo().GetElementType();
0078   if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
0079     os << "[float tensor]\n";
0080     printVal.operator()<float>();
0081   } else if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
0082     os << "[int64 tensor]\n";
0083     printVal.operator()<std::int64_t>();
0084   } else {
0085     os << "not implemented datatype";
0086   }
0087 
0088   return os;
0089 }
0090 
0091 std::tuple<std::any, std::any, std::any, std::any>
0092 OnnxEdgeClassifier::operator()(std::any inputNodes, std::any inputEdges,
0093                                std::any inEdgeFeatures,
0094                                const ExecutionContext & /*unused*/) {
0095   auto torchDevice = torch::kCPU;
0096   Ort::MemoryInfo memoryInfo("Cpu", OrtArenaAllocator, /*device_id*/ 0,
0097                              OrtMemTypeDefault);
0098 
0099   Ort::Allocator allocator(*m_model, memoryInfo);
0100 
0101   auto nodeTensor =
0102       std::any_cast<torch::Tensor>(inputNodes).to(torchDevice).clone();
0103   auto edgeList = std::any_cast<torch::Tensor>(inputEdges).to(torchDevice);
0104   const int numEdges = edgeList.size(1);
0105 
0106   std::vector<const char *> inputNames{m_inputNames.at(0).c_str(),
0107                                        m_inputNames.at(1).c_str()};
0108 
0109   // TODO move this contiguous to graph construction
0110   auto edgeListClone = edgeList.clone().contiguous();
0111   ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeListClone});
0112   auto nodeTensorClone = nodeTensor.clone();
0113   ACTS_DEBUG("nodes: " << detail::TensorDetails{nodeTensorClone});
0114   std::vector<Ort::Value> inputTensors;
0115   inputTensors.push_back(torchToOnnx<float>(memoryInfo, nodeTensorClone));
0116   inputTensors.push_back(torchToOnnx<std::int64_t>(memoryInfo, edgeListClone));
0117 
0118   std::optional<at::Tensor> edgeAttrTensor;
0119   if (inEdgeFeatures.has_value()) {
0120     inputNames.push_back(m_inputNames.at(2).c_str());
0121     edgeAttrTensor =
0122         std::any_cast<torch::Tensor>(inEdgeFeatures).to(torchDevice).clone();
0123     inputTensors.push_back(torchToOnnx<float>(memoryInfo, *edgeAttrTensor));
0124   }
0125 
0126   std::vector<const char *> outputNames{m_outputName.c_str()};
0127 
0128   auto outputTensor =
0129       m_model->Run({}, inputNames.data(), inputTensors.data(),
0130                    inputTensors.size(), outputNames.data(), outputNames.size());
0131 
0132   float *rawOutData = nullptr;
0133   if (outputTensor.at(0).GetTensorTypeAndShapeInfo().GetElementType() ==
0134       ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
0135     rawOutData = outputTensor.at(0).GetTensorMutableData<float>();
0136   } else {
0137     throw std::runtime_error("Invalid output datatype");
0138   }
0139 
0140   ACTS_DEBUG("Get scores for " << numEdges << " edges.");
0141   auto scores =
0142       torch::from_blob(
0143           rawOutData, {numEdges},
0144           torch::TensorOptions().device(torchDevice).dtype(torch::kFloat32))
0145           .clone();
0146 
0147   ACTS_VERBOSE("Slice of classified output before sigmoid:\n"
0148                << scores.slice(/*dim=*/0, /*start=*/0, /*end=*/9));
0149 
0150   scores.sigmoid_();
0151 
0152   ACTS_DEBUG("scores: " << detail::TensorDetails{scores});
0153   ACTS_VERBOSE("Slice of classified output:\n"
0154                << scores.slice(/*dim=*/0, /*start=*/0, /*end=*/9));
0155 
0156   torch::Tensor filterMask = scores > m_cfg.cut;
0157   torch::Tensor edgesAfterCut = edgeList.index({Slice(), filterMask});
0158 
0159   ACTS_DEBUG("Finished edge classification, after cut: "
0160              << edgesAfterCut.size(1) << " edges.");
0161 
0162   if (edgesAfterCut.size(1) == 0) {
0163     throw Acts::NoEdgesError{};
0164   }
0165 
0166   return {std::move(nodeTensor), edgesAfterCut.clone(),
0167           std::move(inEdgeFeatures), std::move(scores)};
0168 }
0169 
0170 }  // namespace Acts