Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-01 07:53:39

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp"
0010 
0011 #include "Acts/Plugins/ExaTrkX/detail/CudaUtils.hpp"
0012 
#include <chrono>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <thread>

#include <NvInfer.h>
#include <NvInferPlugin.h>
#include <NvInferRuntimeBase.h>
#include <cuda_runtime.h>

#include "printCudaMemInfo.hpp"
0023 
0024 namespace {
0025 
0026 class TensorRTLogger : public nvinfer1::ILogger {
0027   std::unique_ptr<const Acts::Logger> m_logger;
0028 
0029  public:
0030   TensorRTLogger(Acts::Logging::Level lvl)
0031       : m_logger(Acts::getDefaultLogger("TensorRT", lvl)) {}
0032 
0033   void log(Severity severity, const char *msg) noexcept override {
0034     const auto &logger = *m_logger;
0035     switch (severity) {
0036       case Severity::kVERBOSE:
0037         ACTS_DEBUG(msg);
0038         break;
0039       case Severity::kINFO:
0040         ACTS_INFO(msg);
0041         break;
0042       case Severity::kWARNING:
0043         ACTS_WARNING(msg);
0044         break;
0045       case Severity::kERROR:
0046         ACTS_ERROR(msg);
0047         break;
0048       case Severity::kINTERNAL_ERROR:
0049         ACTS_FATAL(msg);
0050         break;
0051     }
0052   }
0053 };
0054 
0055 }  // namespace
0056 
0057 namespace Acts {
0058 
0059 TensorRTEdgeClassifier::TensorRTEdgeClassifier(
0060     const Config &cfg, std::unique_ptr<const Logger> _logger)
0061     : m_logger(std::move(_logger)),
0062       m_cfg(cfg),
0063       m_trtLogger(std::make_unique<TensorRTLogger>(m_logger->level())) {
0064   auto status = initLibNvInferPlugins(m_trtLogger.get(), "");
0065   if (!status) {
0066     throw std::runtime_error("Failed to initialize TensorRT plugins");
0067   }
0068 
0069   std::size_t fsize =
0070       std::filesystem::file_size(std::filesystem::path(m_cfg.modelPath));
0071   std::vector<char> engineData(fsize);
0072 
0073   ACTS_DEBUG("Load '" << m_cfg.modelPath << "' with size " << fsize);
0074 
0075   std::ifstream engineFile(m_cfg.modelPath);
0076   if (!engineFile) {
0077     throw std::runtime_error("Failed to open engine file");
0078   } else if (!engineFile.read(engineData.data(), fsize)) {
0079     throw std::runtime_error("Failed to read engine file");
0080   }
0081 
0082   m_runtime.reset(nvinfer1::createInferRuntime(*m_trtLogger));
0083   if (!m_runtime) {
0084     throw std::runtime_error("Failed to create TensorRT runtime");
0085   }
0086 
0087   m_engine.reset(m_runtime->deserializeCudaEngine(engineData.data(), fsize));
0088   if (!m_engine) {
0089     throw std::runtime_error("Failed to deserialize CUDA engine");
0090   }
0091 
0092   for (auto i = 0ul; i < m_cfg.numExecutionContexts; ++i) {
0093     ACTS_DEBUG("Create execution context " << i);
0094     m_contexts.emplace_back(m_engine->createExecutionContext());
0095     if (!m_contexts.back()) {
0096       throw std::runtime_error("Failed to create execution context");
0097     }
0098   }
0099 
0100   std::size_t freeMem{}, totalMem{};
0101   ACTS_CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem));
0102   ACTS_DEBUG("Used CUDA memory after TensorRT initialization: "
0103              << (totalMem - freeMem) * 1e-9 << " / " << totalMem * 1e-9
0104              << " GB");
0105 }
0106 
0107 TensorRTEdgeClassifier::~TensorRTEdgeClassifier() {}
0108 
0109 PipelineTensors TensorRTEdgeClassifier::operator()(
0110     PipelineTensors tensors, const ExecutionContext &execContext) {
0111   assert(execContext.device.type == Acts::Device::Type::eCUDA);
0112 
0113   decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4;
0114   t0 = std::chrono::high_resolution_clock::now();
0115 
0116   // get a context from the list of contexts
0117   std::unique_ptr<nvinfer1::IExecutionContext> context;
0118   while (context == nullptr) {
0119     std::lock_guard<std::mutex> lock(m_contextMutex);
0120     if (!m_contexts.empty()) {
0121       context = std::move(m_contexts.back());
0122       m_contexts.pop_back();
0123     }
0124   }
0125   assert(context != nullptr);
0126 
0127   context->setInputShape(
0128       "x", nvinfer1::Dims2{static_cast<long>(tensors.nodeFeatures.shape()[0]),
0129                            static_cast<long>(tensors.nodeFeatures.shape()[1])});
0130   context->setTensorAddress("x", tensors.nodeFeatures.data());
0131 
0132   context->setInputShape(
0133       "edge_index",
0134       nvinfer1::Dims2{static_cast<long>(tensors.edgeIndex.shape()[0]),
0135                       static_cast<long>(tensors.edgeIndex.shape()[1])});
0136   context->setTensorAddress("edge_index", tensors.edgeIndex.data());
0137 
0138   if (tensors.edgeFeatures.has_value()) {
0139     context->setInputShape(
0140         "edge_attr",
0141         nvinfer1::Dims2{static_cast<long>(tensors.edgeFeatures->shape()[0]),
0142                         static_cast<long>(tensors.edgeFeatures->shape()[1])});
0143     context->setTensorAddress("edge_attr", tensors.edgeFeatures->data());
0144   }
0145 
0146   auto scores =
0147       Tensor<float>::Create({tensors.edgeIndex.shape()[1], 1ul}, execContext);
0148   context->setTensorAddress("output", scores.data());
0149 
0150   t2 = std::chrono::high_resolution_clock::now();
0151 
0152   auto stream = execContext.stream.value();
0153   auto status = context->enqueueV3(stream);
0154   if (!status) {
0155     throw std::runtime_error("Failed to execute TensorRT model");
0156   }
0157   ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
0158 
0159   t3 = std::chrono::high_resolution_clock::now();
0160 
0161   {
0162     std::lock_guard<std::mutex> lock(m_contextMutex);
0163     m_contexts.push_back(std::move(context));
0164   }
0165 
0166   sigmoid(scores, execContext.stream);
0167 
0168   ACTS_VERBOSE("Size after classifier: " << scores.shape()[0]);
0169   printCudaMemInfo(logger());
0170 
0171   auto [newScores, newEdgeIndex] =
0172       applyScoreCut(scores, tensors.edgeIndex, m_cfg.cut, execContext.stream);
0173   ACTS_VERBOSE("Size after score cut: " << newEdgeIndex.shape()[1]);
0174   printCudaMemInfo(logger());
0175 
0176   t4 = std::chrono::high_resolution_clock::now();
0177 
0178   auto milliseconds = [](const auto &a, const auto &b) {
0179     return std::chrono::duration<double, std::milli>(b - a).count();
0180   };
0181   ACTS_DEBUG("Time anycast:  " << milliseconds(t0, t1));
0182   ACTS_DEBUG("Time alloc, set shape " << milliseconds(t1, t2));
0183   ACTS_DEBUG("Time inference:       " << milliseconds(t2, t3));
0184   ACTS_DEBUG("Time sigmoid and cut: " << milliseconds(t3, t4));
0185 
0186   return {std::move(tensors.nodeFeatures), std::move(newEdgeIndex),
0187           std::nullopt, std::move(newScores)};
0188 }
0189 
0190 }  // namespace Acts