Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 09:15:13

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp"
0010 
0011 #include "Acts/Plugins/ExaTrkX/detail/CudaUtils.cuh"
0012 #include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"
0013 
#include <chrono>
#include <filesystem>
#include <fstream>
#include <thread>
0017 
0018 #include <NvInfer.h>
0019 #include <NvInferPlugin.h>
0020 #include <NvInferRuntimeBase.h>
0021 #include <c10/cuda/CUDAGuard.h>
0022 #include <cuda_runtime.h>
0023 
0024 #include "printCudaMemInfo.hpp"
0025 
0026 using namespace torch::indexing;
0027 
0028 namespace {
0029 
0030 class TensorRTLogger : public nvinfer1::ILogger {
0031   std::unique_ptr<const Acts::Logger> m_logger;
0032 
0033  public:
0034   TensorRTLogger(Acts::Logging::Level lvl)
0035       : m_logger(Acts::getDefaultLogger("TensorRT", lvl)) {}
0036 
0037   void log(Severity severity, const char *msg) noexcept override {
0038     const auto &logger = *m_logger;
0039     switch (severity) {
0040       case Severity::kVERBOSE:
0041         ACTS_DEBUG(msg);
0042         break;
0043       case Severity::kINFO:
0044         ACTS_INFO(msg);
0045         break;
0046       case Severity::kWARNING:
0047         ACTS_WARNING(msg);
0048         break;
0049       case Severity::kERROR:
0050         ACTS_ERROR(msg);
0051         break;
0052       case Severity::kINTERNAL_ERROR:
0053         ACTS_FATAL(msg);
0054         break;
0055     }
0056   }
0057 };
0058 
0059 }  // namespace
0060 
0061 namespace Acts {
0062 
0063 TensorRTEdgeClassifier::TensorRTEdgeClassifier(
0064     const Config &cfg, std::unique_ptr<const Logger> _logger)
0065     : m_logger(std::move(_logger)),
0066       m_cfg(cfg),
0067       m_trtLogger(std::make_unique<TensorRTLogger>(m_logger->level())) {
0068   auto status = initLibNvInferPlugins(m_trtLogger.get(), "");
0069   if (!status) {
0070     throw std::runtime_error("Failed to initialize TensorRT plugins");
0071   }
0072 
0073   std::size_t fsize =
0074       std::filesystem::file_size(std::filesystem::path(m_cfg.modelPath));
0075   std::vector<char> engineData(fsize);
0076 
0077   ACTS_DEBUG("Load '" << m_cfg.modelPath << "' with size " << fsize);
0078 
0079   std::ifstream engineFile(m_cfg.modelPath);
0080   if (!engineFile) {
0081     throw std::runtime_error("Failed to open engine file");
0082   } else if (!engineFile.read(engineData.data(), fsize)) {
0083     throw std::runtime_error("Failed to read engine file");
0084   }
0085 
0086   m_runtime.reset(nvinfer1::createInferRuntime(*m_trtLogger));
0087   if (!m_runtime) {
0088     throw std::runtime_error("Failed to create TensorRT runtime");
0089   }
0090 
0091   m_engine.reset(m_runtime->deserializeCudaEngine(engineData.data(), fsize));
0092   if (!m_engine) {
0093     throw std::runtime_error("Failed to deserialize CUDA engine");
0094   }
0095 
0096   for (auto i = 0ul; i < m_cfg.numExecutionContexts; ++i) {
0097     ACTS_DEBUG("Create execution context " << i);
0098     m_contexts.emplace_back(m_engine->createExecutionContext());
0099     if (!m_contexts.back()) {
0100       throw std::runtime_error("Failed to create execution context");
0101     }
0102   }
0103 
0104   std::size_t freeMem{}, totalMem{};
0105   ACTS_CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem));
0106   ACTS_DEBUG("Used CUDA memory after TensorRT initialization: "
0107              << (totalMem - freeMem) * 1e-9 << " / " << totalMem * 1e-9
0108              << " GB");
0109 }
0110 
0111 TensorRTEdgeClassifier::~TensorRTEdgeClassifier() {}
0112 
0113 std::tuple<std::any, std::any, std::any, std::any>
0114 TensorRTEdgeClassifier::operator()(std::any inNodeFeatures,
0115                                    std::any inEdgeIndex,
0116                                    std::any inEdgeFeatures,
0117                                    const ExecutionContext &execContext) {
0118   assert(execContext.device.is_cuda());
0119   decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4;
0120   t0 = std::chrono::high_resolution_clock::now();
0121 
0122   c10::cuda::CUDAStreamGuard(execContext.stream.value());
0123 
0124   auto nodeFeatures =
0125       std::any_cast<torch::Tensor>(inNodeFeatures).to(execContext.device);
0126 
0127   auto edgeIndex =
0128       std::any_cast<torch::Tensor>(inEdgeIndex).to(execContext.device);
0129   ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex});
0130 
0131   auto edgeFeatures =
0132       std::any_cast<torch::Tensor>(inEdgeFeatures).to(execContext.device);
0133   ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{edgeFeatures});
0134 
0135   t1 = std::chrono::high_resolution_clock::now();
0136 
0137   // get a context from the list of contexts
0138   std::unique_ptr<nvinfer1::IExecutionContext> context;
0139   while (context == nullptr) {
0140     std::lock_guard<std::mutex> lock(m_contextMutex);
0141     if (!m_contexts.empty()) {
0142       context = std::move(m_contexts.back());
0143       m_contexts.pop_back();
0144     }
0145   }
0146   assert(context != nullptr);
0147 
0148   context->setInputShape(
0149       "x", nvinfer1::Dims2{nodeFeatures.size(0), nodeFeatures.size(1)});
0150   context->setTensorAddress("x", nodeFeatures.data_ptr());
0151 
0152   context->setInputShape("edge_index",
0153                          nvinfer1::Dims2{edgeIndex.size(0), edgeIndex.size(1)});
0154   context->setTensorAddress("edge_index", edgeIndex.data_ptr());
0155 
0156   context->setInputShape(
0157       "edge_attr", nvinfer1::Dims2{edgeFeatures.size(0), edgeFeatures.size(1)});
0158   context->setTensorAddress("edge_attr", edgeFeatures.data_ptr());
0159 
0160   auto scores = torch::empty(
0161       edgeIndex.size(1),
0162       torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));
0163   context->setTensorAddress("output", scores.data_ptr());
0164 
0165   t2 = std::chrono::high_resolution_clock::now();
0166 
0167   auto stream = execContext.stream.value().stream();
0168   auto status = context->enqueueV3(stream);
0169   if (!status) {
0170     throw std::runtime_error("Failed to execute TensorRT model");
0171   }
0172   ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
0173 
0174   t3 = std::chrono::high_resolution_clock::now();
0175 
0176   {
0177     std::lock_guard<std::mutex> lock(m_contextMutex);
0178     m_contexts.push_back(std::move(context));
0179   }
0180 
0181   scores.sigmoid_();
0182 
0183   ACTS_VERBOSE("Size after classifier: " << scores.size(0));
0184   ACTS_VERBOSE("Slice of classified output:\n"
0185                << scores.slice(/*dim=*/0, /*start=*/0, /*end=*/9));
0186   printCudaMemInfo(logger());
0187 
0188   torch::Tensor mask = scores > m_cfg.cut;
0189   torch::Tensor edgesAfterCut = edgeIndex.index({Slice(), mask});
0190 
0191   scores = scores.masked_select(mask);
0192   ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1));
0193   printCudaMemInfo(logger());
0194 
0195   t4 = std::chrono::high_resolution_clock::now();
0196 
0197   auto milliseconds = [](const auto &a, const auto &b) {
0198     return std::chrono::duration<double, std::milli>(b - a).count();
0199   };
0200   ACTS_DEBUG("Time anycast:  " << milliseconds(t0, t1));
0201   ACTS_DEBUG("Time alloc, set shape " << milliseconds(t1, t2));
0202   ACTS_DEBUG("Time inference:       " << milliseconds(t2, t3));
0203   ACTS_DEBUG("Time sigmoid and cut: " << milliseconds(t3, t4));
0204 
0205   return {nodeFeatures, edgesAfterCut, edgeFeatures, scores};
0206 }
0207 
0208 }  // namespace Acts