File indexing completed on 2025-07-01 07:53:39
0001
0002
0003
0004
0005
0006
0007
0008
#include "Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp"

#include "Acts/Plugins/ExaTrkX/detail/CudaUtils.hpp"

#include <chrono>
#include <filesystem>
#include <fstream>
#include <mutex>
#include <thread>

#include <NvInfer.h>
#include <NvInferPlugin.h>
#include <NvInferRuntimeBase.h>
#include <cuda_runtime.h>

#include "printCudaMemInfo.hpp"
0023
0024 namespace {
0025
0026 class TensorRTLogger : public nvinfer1::ILogger {
0027 std::unique_ptr<const Acts::Logger> m_logger;
0028
0029 public:
0030 TensorRTLogger(Acts::Logging::Level lvl)
0031 : m_logger(Acts::getDefaultLogger("TensorRT", lvl)) {}
0032
0033 void log(Severity severity, const char *msg) noexcept override {
0034 const auto &logger = *m_logger;
0035 switch (severity) {
0036 case Severity::kVERBOSE:
0037 ACTS_DEBUG(msg);
0038 break;
0039 case Severity::kINFO:
0040 ACTS_INFO(msg);
0041 break;
0042 case Severity::kWARNING:
0043 ACTS_WARNING(msg);
0044 break;
0045 case Severity::kERROR:
0046 ACTS_ERROR(msg);
0047 break;
0048 case Severity::kINTERNAL_ERROR:
0049 ACTS_FATAL(msg);
0050 break;
0051 }
0052 }
0053 };
0054
0055 }
0056
0057 namespace Acts {
0058
0059 TensorRTEdgeClassifier::TensorRTEdgeClassifier(
0060 const Config &cfg, std::unique_ptr<const Logger> _logger)
0061 : m_logger(std::move(_logger)),
0062 m_cfg(cfg),
0063 m_trtLogger(std::make_unique<TensorRTLogger>(m_logger->level())) {
0064 auto status = initLibNvInferPlugins(m_trtLogger.get(), "");
0065 if (!status) {
0066 throw std::runtime_error("Failed to initialize TensorRT plugins");
0067 }
0068
0069 std::size_t fsize =
0070 std::filesystem::file_size(std::filesystem::path(m_cfg.modelPath));
0071 std::vector<char> engineData(fsize);
0072
0073 ACTS_DEBUG("Load '" << m_cfg.modelPath << "' with size " << fsize);
0074
0075 std::ifstream engineFile(m_cfg.modelPath);
0076 if (!engineFile) {
0077 throw std::runtime_error("Failed to open engine file");
0078 } else if (!engineFile.read(engineData.data(), fsize)) {
0079 throw std::runtime_error("Failed to read engine file");
0080 }
0081
0082 m_runtime.reset(nvinfer1::createInferRuntime(*m_trtLogger));
0083 if (!m_runtime) {
0084 throw std::runtime_error("Failed to create TensorRT runtime");
0085 }
0086
0087 m_engine.reset(m_runtime->deserializeCudaEngine(engineData.data(), fsize));
0088 if (!m_engine) {
0089 throw std::runtime_error("Failed to deserialize CUDA engine");
0090 }
0091
0092 for (auto i = 0ul; i < m_cfg.numExecutionContexts; ++i) {
0093 ACTS_DEBUG("Create execution context " << i);
0094 m_contexts.emplace_back(m_engine->createExecutionContext());
0095 if (!m_contexts.back()) {
0096 throw std::runtime_error("Failed to create execution context");
0097 }
0098 }
0099
0100 std::size_t freeMem{}, totalMem{};
0101 ACTS_CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem));
0102 ACTS_DEBUG("Used CUDA memory after TensorRT initialization: "
0103 << (totalMem - freeMem) * 1e-9 << " / " << totalMem * 1e-9
0104 << " GB");
0105 }
0106
0107 TensorRTEdgeClassifier::~TensorRTEdgeClassifier() {}
0108
0109 PipelineTensors TensorRTEdgeClassifier::operator()(
0110 PipelineTensors tensors, const ExecutionContext &execContext) {
0111 assert(execContext.device.type == Acts::Device::Type::eCUDA);
0112
0113 decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4;
0114 t0 = std::chrono::high_resolution_clock::now();
0115
0116
0117 std::unique_ptr<nvinfer1::IExecutionContext> context;
0118 while (context == nullptr) {
0119 std::lock_guard<std::mutex> lock(m_contextMutex);
0120 if (!m_contexts.empty()) {
0121 context = std::move(m_contexts.back());
0122 m_contexts.pop_back();
0123 }
0124 }
0125 assert(context != nullptr);
0126
0127 context->setInputShape(
0128 "x", nvinfer1::Dims2{static_cast<long>(tensors.nodeFeatures.shape()[0]),
0129 static_cast<long>(tensors.nodeFeatures.shape()[1])});
0130 context->setTensorAddress("x", tensors.nodeFeatures.data());
0131
0132 context->setInputShape(
0133 "edge_index",
0134 nvinfer1::Dims2{static_cast<long>(tensors.edgeIndex.shape()[0]),
0135 static_cast<long>(tensors.edgeIndex.shape()[1])});
0136 context->setTensorAddress("edge_index", tensors.edgeIndex.data());
0137
0138 if (tensors.edgeFeatures.has_value()) {
0139 context->setInputShape(
0140 "edge_attr",
0141 nvinfer1::Dims2{static_cast<long>(tensors.edgeFeatures->shape()[0]),
0142 static_cast<long>(tensors.edgeFeatures->shape()[1])});
0143 context->setTensorAddress("edge_attr", tensors.edgeFeatures->data());
0144 }
0145
0146 auto scores =
0147 Tensor<float>::Create({tensors.edgeIndex.shape()[1], 1ul}, execContext);
0148 context->setTensorAddress("output", scores.data());
0149
0150 t2 = std::chrono::high_resolution_clock::now();
0151
0152 auto stream = execContext.stream.value();
0153 auto status = context->enqueueV3(stream);
0154 if (!status) {
0155 throw std::runtime_error("Failed to execute TensorRT model");
0156 }
0157 ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
0158
0159 t3 = std::chrono::high_resolution_clock::now();
0160
0161 {
0162 std::lock_guard<std::mutex> lock(m_contextMutex);
0163 m_contexts.push_back(std::move(context));
0164 }
0165
0166 sigmoid(scores, execContext.stream);
0167
0168 ACTS_VERBOSE("Size after classifier: " << scores.shape()[0]);
0169 printCudaMemInfo(logger());
0170
0171 auto [newScores, newEdgeIndex] =
0172 applyScoreCut(scores, tensors.edgeIndex, m_cfg.cut, execContext.stream);
0173 ACTS_VERBOSE("Size after score cut: " << newEdgeIndex.shape()[1]);
0174 printCudaMemInfo(logger());
0175
0176 t4 = std::chrono::high_resolution_clock::now();
0177
0178 auto milliseconds = [](const auto &a, const auto &b) {
0179 return std::chrono::duration<double, std::milli>(b - a).count();
0180 };
0181 ACTS_DEBUG("Time anycast: " << milliseconds(t0, t1));
0182 ACTS_DEBUG("Time alloc, set shape " << milliseconds(t1, t2));
0183 ACTS_DEBUG("Time inference: " << milliseconds(t2, t3));
0184 ACTS_DEBUG("Time sigmoid and cut: " << milliseconds(t3, t4));
0185
0186 return {std::move(tensors.nodeFeatures), std::move(newEdgeIndex),
0187 std::nullopt, std::move(newScores)};
0188 }
0189
0190 }