File indexing completed on 2025-01-30 09:15:13
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp"
0010
0011 #include "Acts/Plugins/ExaTrkX/detail/CudaUtils.cuh"
0012 #include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"
0013
0014 #include <chrono>
0015 #include <filesystem>
0016 #include <fstream>
0017
0018 #include <NvInfer.h>
0019 #include <NvInferPlugin.h>
0020 #include <NvInferRuntimeBase.h>
0021 #include <c10/cuda/CUDAGuard.h>
0022 #include <cuda_runtime.h>
0023
0024 #include "printCudaMemInfo.hpp"
0025
0026 using namespace torch::indexing;
0027
0028 namespace {
0029
0030 class TensorRTLogger : public nvinfer1::ILogger {
0031 std::unique_ptr<const Acts::Logger> m_logger;
0032
0033 public:
0034 TensorRTLogger(Acts::Logging::Level lvl)
0035 : m_logger(Acts::getDefaultLogger("TensorRT", lvl)) {}
0036
0037 void log(Severity severity, const char *msg) noexcept override {
0038 const auto &logger = *m_logger;
0039 switch (severity) {
0040 case Severity::kVERBOSE:
0041 ACTS_DEBUG(msg);
0042 break;
0043 case Severity::kINFO:
0044 ACTS_INFO(msg);
0045 break;
0046 case Severity::kWARNING:
0047 ACTS_WARNING(msg);
0048 break;
0049 case Severity::kERROR:
0050 ACTS_ERROR(msg);
0051 break;
0052 case Severity::kINTERNAL_ERROR:
0053 ACTS_FATAL(msg);
0054 break;
0055 }
0056 }
0057 };
0058
0059 }
0060
0061 namespace Acts {
0062
0063 TensorRTEdgeClassifier::TensorRTEdgeClassifier(
0064 const Config &cfg, std::unique_ptr<const Logger> _logger)
0065 : m_logger(std::move(_logger)),
0066 m_cfg(cfg),
0067 m_trtLogger(std::make_unique<TensorRTLogger>(m_logger->level())) {
0068 auto status = initLibNvInferPlugins(m_trtLogger.get(), "");
0069 if (!status) {
0070 throw std::runtime_error("Failed to initialize TensorRT plugins");
0071 }
0072
0073 std::size_t fsize =
0074 std::filesystem::file_size(std::filesystem::path(m_cfg.modelPath));
0075 std::vector<char> engineData(fsize);
0076
0077 ACTS_DEBUG("Load '" << m_cfg.modelPath << "' with size " << fsize);
0078
0079 std::ifstream engineFile(m_cfg.modelPath);
0080 if (!engineFile) {
0081 throw std::runtime_error("Failed to open engine file");
0082 } else if (!engineFile.read(engineData.data(), fsize)) {
0083 throw std::runtime_error("Failed to read engine file");
0084 }
0085
0086 m_runtime.reset(nvinfer1::createInferRuntime(*m_trtLogger));
0087 if (!m_runtime) {
0088 throw std::runtime_error("Failed to create TensorRT runtime");
0089 }
0090
0091 m_engine.reset(m_runtime->deserializeCudaEngine(engineData.data(), fsize));
0092 if (!m_engine) {
0093 throw std::runtime_error("Failed to deserialize CUDA engine");
0094 }
0095
0096 for (auto i = 0ul; i < m_cfg.numExecutionContexts; ++i) {
0097 ACTS_DEBUG("Create execution context " << i);
0098 m_contexts.emplace_back(m_engine->createExecutionContext());
0099 if (!m_contexts.back()) {
0100 throw std::runtime_error("Failed to create execution context");
0101 }
0102 }
0103
0104 std::size_t freeMem{}, totalMem{};
0105 ACTS_CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem));
0106 ACTS_DEBUG("Used CUDA memory after TensorRT initialization: "
0107 << (totalMem - freeMem) * 1e-9 << " / " << totalMem * 1e-9
0108 << " GB");
0109 }
0110
0111 TensorRTEdgeClassifier::~TensorRTEdgeClassifier() {}
0112
0113 std::tuple<std::any, std::any, std::any, std::any>
0114 TensorRTEdgeClassifier::operator()(std::any inNodeFeatures,
0115 std::any inEdgeIndex,
0116 std::any inEdgeFeatures,
0117 const ExecutionContext &execContext) {
0118 assert(execContext.device.is_cuda());
0119 decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4;
0120 t0 = std::chrono::high_resolution_clock::now();
0121
0122 c10::cuda::CUDAStreamGuard(execContext.stream.value());
0123
0124 auto nodeFeatures =
0125 std::any_cast<torch::Tensor>(inNodeFeatures).to(execContext.device);
0126
0127 auto edgeIndex =
0128 std::any_cast<torch::Tensor>(inEdgeIndex).to(execContext.device);
0129 ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex});
0130
0131 auto edgeFeatures =
0132 std::any_cast<torch::Tensor>(inEdgeFeatures).to(execContext.device);
0133 ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{edgeFeatures});
0134
0135 t1 = std::chrono::high_resolution_clock::now();
0136
0137
0138 std::unique_ptr<nvinfer1::IExecutionContext> context;
0139 while (context == nullptr) {
0140 std::lock_guard<std::mutex> lock(m_contextMutex);
0141 if (!m_contexts.empty()) {
0142 context = std::move(m_contexts.back());
0143 m_contexts.pop_back();
0144 }
0145 }
0146 assert(context != nullptr);
0147
0148 context->setInputShape(
0149 "x", nvinfer1::Dims2{nodeFeatures.size(0), nodeFeatures.size(1)});
0150 context->setTensorAddress("x", nodeFeatures.data_ptr());
0151
0152 context->setInputShape("edge_index",
0153 nvinfer1::Dims2{edgeIndex.size(0), edgeIndex.size(1)});
0154 context->setTensorAddress("edge_index", edgeIndex.data_ptr());
0155
0156 context->setInputShape(
0157 "edge_attr", nvinfer1::Dims2{edgeFeatures.size(0), edgeFeatures.size(1)});
0158 context->setTensorAddress("edge_attr", edgeFeatures.data_ptr());
0159
0160 auto scores = torch::empty(
0161 edgeIndex.size(1),
0162 torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));
0163 context->setTensorAddress("output", scores.data_ptr());
0164
0165 t2 = std::chrono::high_resolution_clock::now();
0166
0167 auto stream = execContext.stream.value().stream();
0168 auto status = context->enqueueV3(stream);
0169 if (!status) {
0170 throw std::runtime_error("Failed to execute TensorRT model");
0171 }
0172 ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
0173
0174 t3 = std::chrono::high_resolution_clock::now();
0175
0176 {
0177 std::lock_guard<std::mutex> lock(m_contextMutex);
0178 m_contexts.push_back(std::move(context));
0179 }
0180
0181 scores.sigmoid_();
0182
0183 ACTS_VERBOSE("Size after classifier: " << scores.size(0));
0184 ACTS_VERBOSE("Slice of classified output:\n"
0185 << scores.slice(0, 0, 9));
0186 printCudaMemInfo(logger());
0187
0188 torch::Tensor mask = scores > m_cfg.cut;
0189 torch::Tensor edgesAfterCut = edgeIndex.index({Slice(), mask});
0190
0191 scores = scores.masked_select(mask);
0192 ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1));
0193 printCudaMemInfo(logger());
0194
0195 t4 = std::chrono::high_resolution_clock::now();
0196
0197 auto milliseconds = [](const auto &a, const auto &b) {
0198 return std::chrono::duration<double, std::milli>(b - a).count();
0199 };
0200 ACTS_DEBUG("Time anycast: " << milliseconds(t0, t1));
0201 ACTS_DEBUG("Time alloc, set shape " << milliseconds(t1, t2));
0202 ACTS_DEBUG("Time inference: " << milliseconds(t2, t3));
0203 ACTS_DEBUG("Time sigmoid and cut: " << milliseconds(t3, t4));
0204
0205 return {nodeFeatures, edgesAfterCut, edgeFeatures, scores};
0206 }
0207
0208 }