// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"

#include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
#include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"

#include <chrono>

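// The CUDA device guard header is only available in CUDA-enabled builds,
// hence the ACTS_EXATRKX_CPUONLY compile-time switch around it.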
#ifndef ACTS_EXATRKX_CPUONLY
#include <c10/cuda/CUDAGuard.h>
#endif

#include <torch/script.h>
#include <torch/torch.h>

#include "printCudaMemInfo.hpp"

using namespace torch::indexing;

namespace Acts {

TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg,
                                         std::unique_ptr<const Logger> _logger)
    : m_logger(std::move(_logger)), m_cfg(cfg) {
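  // InferenceMode disables autograd tracking for all tensor operations in
  // this scope; the model is only ever used for inference here.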
  c10::InferenceMode guard(true);
  torch::Device device = torch::kCPU;

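  // Prefer the configured CUDA device if it is available; otherwise stay on
  // the CPU.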
  if (!torch::cuda::is_available()) {
    ACTS_DEBUG("Running on CPU...");
  } else {
    if (cfg.deviceID >= 0 &&
        static_cast<std::size_t>(cfg.deviceID) < torch::cuda::device_count()) {
      ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used.");
      device = torch::Device(torch::kCUDA, cfg.deviceID);
    } else {
      ACTS_WARNING("GPU device " << cfg.deviceID
                                 << " not available, falling back to CPU.");
    }
  }

  ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
                                    << TORCH_VERSION_MINOR << "."
                                    << TORCH_VERSION_PATCH);
#ifndef ACTS_EXATRKX_CPUONLY
  if (!torch::cuda::is_available()) {
    ACTS_INFO("CUDA not available, falling back to CPU");
  }
#endif

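  // Load the TorchScript module directly onto the selected device and switch
  // it to evaluation mode.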
  try {
    m_model = std::make_unique<torch::jit::Module>();
    *m_model = torch::jit::load(m_cfg.modelPath, device);
    m_model->eval();
  } catch (const c10::Error& e) {
    throw std::invalid_argument("Failed to load model: " + e.msg());
  }
}

TorchEdgeClassifier::~TorchEdgeClassifier() = default;

PipelineTensors TorchEdgeClassifier::operator()(
    PipelineTensors tensors, const ExecutionContext& execContext) {
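  // Map the ACTS execution context onto the corresponding torch device.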
  const auto device =
      execContext.device.type == Acts::Device::Type::eCUDA
          ? torch::Device(torch::kCUDA, execContext.device.index)
          : torch::kCPU;
  std::chrono::high_resolution_clock::time_point t0, t1, t2, t3;
  t0 = std::chrono::high_resolution_clock::now();
  ACTS_DEBUG("Start edge classification, use " << device);

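  // Bail out early: an empty edge list cannot be classified.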
  if (tensors.edgeIndex.size() == 0) {
    throw NoEdgesError{};
  }

  c10::InferenceMode guard(true);

  // Protect against instantiating the CUDA device guard when running on CPU
#ifdef ACTS_EXATRKX_CPUONLY
  assert(device == torch::Device(torch::kCPU));
#else
  std::optional<c10::cuda::CUDAGuard> device_guard;
  if (device.is_cuda()) {
    device_guard.emplace(device.index());
  }
#endif

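  // Wrap the pipeline tensors as non-owning torch tensors (no copy of the
  // underlying data, as suggested by the helper's name).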
  auto nodeFeatures = detail::actsToNonOwningTorchTensor(tensors.nodeFeatures);
  ACTS_DEBUG("nodeFeatures: " << detail::TensorDetails{nodeFeatures});

  auto edgeIndex = detail::actsToNonOwningTorchTensor(tensors.edgeIndex);
  ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex});

  std::optional<torch::Tensor> edgeFeatures;
  if (tensors.edgeFeatures.has_value()) {
    edgeFeatures = detail::actsToNonOwningTorchTensor(*tensors.edgeFeatures);
    ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{*edgeFeatures});
  }

  torch::Tensor output;

  // Scope this to keep inference objects separate
  {
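    // For undirected graphs, append each edge in reversed direction so that
    // the model scores both orientations.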
    auto edgeIndexTmp = m_cfg.undirected
                            ? torch::cat({edgeIndex, edgeIndex.flip(0)}, 1)
                            : edgeIndex;

    std::vector<torch::jit::IValue> inputTensors(2);
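    // Optionally restrict the node features to the configured subset of
    // columns before handing them to the model.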
    auto selectedFeaturesTensor =
        at::tensor(at::ArrayRef<int>(m_cfg.selectedFeatures));
    at::Tensor selectedNodeFeatures =
        !m_cfg.selectedFeatures.empty()
            ? nodeFeatures.index({Slice{}, selectedFeaturesTensor}).clone()
            : nodeFeatures;

    ACTS_DEBUG("selected nodeFeatures: "
               << detail::TensorDetails{selectedNodeFeatures});
    inputTensors[0] = selectedNodeFeatures;

    if (edgeFeatures && m_cfg.useEdgeFeatures) {
      inputTensors.push_back(*edgeFeatures);
    }

    t1 = std::chrono::high_resolution_clock::now();

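    // Optionally split the edge list into chunks and run the model on each
    // chunk separately to bound peak memory during inference.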
    if (m_cfg.nChunks > 1) {
      std::vector<at::Tensor> results;
      results.reserve(m_cfg.nChunks);

      auto chunks = at::chunk(edgeIndexTmp, m_cfg.nChunks, 1);
      for (auto& chunk : chunks) {
        ACTS_VERBOSE("Process chunk with shape " << chunk.sizes());
        inputTensors[1] = chunk;

        results.push_back(m_model->forward(inputTensors).toTensor());
        results.back().squeeze_();
      }

      output = torch::cat(results);
    } else {
      inputTensors[1] = edgeIndexTmp;
      output = m_model->forward(inputTensors).toTensor().to(torch::kFloat32);
      output.squeeze_();
    }
    t2 = std::chrono::high_resolution_clock::now();
  }

  ACTS_VERBOSE("Slice of classified output before sigmoid:\n"
               << output.slice(/*dim=*/0, /*start=*/0, /*end=*/9));

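  // Convert the raw model output (logits) to probabilities in place.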
  output.sigmoid_();

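  // For undirected inference the reversed edge copies occupy the second half
  // of the output; keep only the scores of the original edges.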
  if (m_cfg.undirected) {
    auto newSize = output.size(0) / 2;
    output = output.index({Slice(None, newSize)});
  }

  ACTS_VERBOSE("Size after classifier: " << output.size(0));
  ACTS_VERBOSE("Slice of classified output:\n"
               << output.slice(/*dim=*/0, /*start=*/0, /*end=*/9));
  printCudaMemInfo(logger());

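  // Apply the score cut: keep only edges whose score exceeds the configured
  // threshold, and convert the surviving edge indices to 64-bit integers.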
  torch::Tensor mask = output > m_cfg.cut;
  torch::Tensor edgesAfterCut = edgeIndex.index({Slice(), mask});
  edgesAfterCut = edgesAfterCut.to(torch::kInt64);

  if (edgesAfterCut.numel() == 0) {
    throw NoEdgesError{};
  }

  ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1));
  printCudaMemInfo(logger());
  t3 = std::chrono::high_resolution_clock::now();

  auto milliseconds = [](const auto& a, const auto& b) {
    return std::chrono::duration<double, std::milli>(b - a).count();
  };
  ACTS_DEBUG("Time preparation:    " << milliseconds(t0, t1));
  ACTS_DEBUG("Time inference:      " << milliseconds(t1, t2));
  ACTS_DEBUG("Time postprocessing: " << milliseconds(t2, t3));

  // Don't propagate edge features right now since they are not needed by any
  // track building algorithm
  return {std::move(tensors.nodeFeatures),
          detail::torchToActsTensor<std::int64_t>(edgesAfterCut, execContext),
          {},
          detail::torchToActsTensor<float>(output.masked_select(mask),
                                           execContext)};
}

}  // namespace Acts
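
// A minimal usage sketch (assuming the caller already has a filled
// PipelineTensors instance `tensors` and an ExecutionContext `execContext`;
// Acts::getDefaultLogger is the standard ACTS logger factory, and the Config
// fields shown are the ones this file reads):
//
//   Acts::TorchEdgeClassifier::Config cfg;
//   cfg.modelPath = "path/to/edge_classifier.pt";
//   cfg.cut = 0.5;
//   Acts::TorchEdgeClassifier classifier(
//       cfg, Acts::getDefaultLogger("TorchEdgeClassifier", Acts::Logging::INFO));
//   tensors = classifier(std::move(tensors), execContext);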