EIC code displayed by LXR

// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"

#include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"

#include <chrono>

#ifndef ACTS_EXATRKX_CPUONLY
#include <c10/cuda/CUDAGuard.h>
#endif

#include <torch/script.h>
#include <torch/torch.h>

#include "printCudaMemInfo.hpp"

using namespace torch::indexing;

namespace Acts {

TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg,
                                         std::unique_ptr<const Logger> _logger)
    : m_logger(std::move(_logger)),
      m_cfg(cfg),
      m_device(torch::Device(torch::kCPU)) {
  c10::InferenceMode guard(true);
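  // Prefer a CUDA device when one is available and the configured device ID
  // is valid; otherwise stay on the default CPU device.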
  m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
  if (m_deviceType == torch::kCPU) {
    ACTS_DEBUG("Running on CPU...");
  } else {
    if (cfg.deviceID >= 0 &&
        static_cast<std::size_t>(cfg.deviceID) < torch::cuda::device_count()) {
      ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used.");
      m_device = torch::Device(torch::kCUDA, cfg.deviceID);
    } else {
      ACTS_WARNING("GPU device " << cfg.deviceID
                                 << " not available, falling back to CPU.");
    }
  }

  ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
                                    << TORCH_VERSION_MINOR << "."
                                    << TORCH_VERSION_PATCH);
#ifndef ACTS_EXATRKX_CPUONLY
  if (not torch::cuda::is_available()) {
    ACTS_INFO("CUDA not available, falling back to CPU");
  }
#endif

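  // Load the TorchScript module onto the selected device and switch it to
  // evaluation mode; libtorch errors are rethrown as std::invalid_argument.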
  try {
    m_model = std::make_unique<torch::jit::Module>();
    *m_model = torch::jit::load(m_cfg.modelPath.c_str(), m_device);
    m_model->eval();
  } catch (const c10::Error& e) {
    throw std::invalid_argument("Failed to load model: " + e.msg());
  }
}

TorchEdgeClassifier::~TorchEdgeClassifier() {}

std::tuple<std::any, std::any, std::any, std::any>
TorchEdgeClassifier::operator()(std::any inNodeFeatures, std::any inEdgeIndex,
                                std::any inEdgeFeatures,
                                const ExecutionContext& execContext) {
  const auto& device = execContext.device;
  decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4;
  t0 = std::chrono::high_resolution_clock::now();
  ACTS_DEBUG("Start edge classification, use " << device);
  c10::InferenceMode guard(true);

  // In a CPU-only build, make sure only the CPU device is requested.
#ifdef ACTS_EXATRKX_CPUONLY
  assert(device == torch::Device(torch::kCPU));
#else
  std::optional<c10::cuda::CUDAGuard> device_guard;
  std::optional<c10::cuda::CUDAStreamGuard> streamGuard;
  if (device.is_cuda()) {
    device_guard.emplace(device.index());
    streamGuard.emplace(execContext.stream.value());
  }
#endif

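  // Unpack the type-erased inputs and move them to the execution device.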
  auto nodeFeatures = std::any_cast<torch::Tensor>(inNodeFeatures).to(device);
  auto edgeIndex = std::any_cast<torch::Tensor>(inEdgeIndex).to(device);

  if (edgeIndex.numel() == 0) {
    throw NoEdgesError{};
  }

  ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex});

  std::optional<torch::Tensor> edgeFeatures;
  if (inEdgeFeatures.has_value()) {
    edgeFeatures = std::any_cast<torch::Tensor>(inEdgeFeatures).to(device);
    ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{*edgeFeatures});
  }
  t1 = std::chrono::high_resolution_clock::now();

  torch::Tensor output;

  // Scope this to keep inference objects separate
  {
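    // For undirected models, score each edge in both directions by appending
    // the flipped edge index before the forward pass.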
    auto edgeIndexTmp = m_cfg.undirected
                            ? torch::cat({edgeIndex, edgeIndex.flip(0)}, 1)
                            : edgeIndex;

    std::vector<torch::jit::IValue> inputTensors(2);
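    // Optionally restrict the node features to the configured columns before
    // handing them to the model.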
    auto selectedFeaturesTensor =
        at::tensor(at::ArrayRef<int>(m_cfg.selectedFeatures));
    at::Tensor selectedNodeFeatures =
        !m_cfg.selectedFeatures.empty()
            ? nodeFeatures.index({Slice{}, selectedFeaturesTensor}).clone()
            : nodeFeatures;

    ACTS_DEBUG("selected nodeFeatures: "
               << detail::TensorDetails{selectedNodeFeatures});
    inputTensors[0] = selectedNodeFeatures;

    if (edgeFeatures && m_cfg.useEdgeFeatures) {
      inputTensors.push_back(*edgeFeatures);
    }
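    // With nChunks > 1, run the forward pass over chunks of the edge list and
    // concatenate the per-chunk scores; otherwise run a single forward pass.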
    if (m_cfg.nChunks > 1) {
      std::vector<at::Tensor> results;
      results.reserve(m_cfg.nChunks);

      auto chunks = at::chunk(edgeIndexTmp, m_cfg.nChunks, 1);
      // Record the forward-pass timestamps here as well, so the timing
      // summary below is well defined when chunking is enabled.
      t2 = std::chrono::high_resolution_clock::now();
      for (auto& chunk : chunks) {
        ACTS_VERBOSE("Process chunk with shape " << chunk.sizes());
        inputTensors[1] = chunk;

        results.push_back(m_model->forward(inputTensors).toTensor());
        results.back().squeeze_();
      }

      output = torch::cat(results);
      t3 = std::chrono::high_resolution_clock::now();
    } else {
      inputTensors[1] = edgeIndexTmp;

      t2 = std::chrono::high_resolution_clock::now();
      output = m_model->forward(inputTensors).toTensor().to(torch::kFloat32);
      t3 = std::chrono::high_resolution_clock::now();
      output.squeeze_();
    }
  }

  ACTS_VERBOSE("Slice of classified output before sigmoid:\n"
               << output.slice(/*dim=*/0, /*start=*/0, /*end=*/9));

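  // Map the raw model output to edge scores in [0, 1].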
  output.sigmoid_();

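  // For undirected inference every edge was scored twice (once per direction);
  // keep only the scores belonging to the original edge direction.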
  if (m_cfg.undirected) {
    auto newSize = output.size(0) / 2;
    output = output.index({Slice(None, newSize)});
  }

  ACTS_VERBOSE("Size after classifier: " << output.size(0));
  ACTS_VERBOSE("Slice of classified output:\n"
               << output.slice(/*dim=*/0, /*start=*/0, /*end=*/9));
  printCudaMemInfo(logger());

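  // Apply the score cut and keep only the edges that pass it.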
  torch::Tensor mask = output > m_cfg.cut;
  torch::Tensor edgesAfterCut = edgeIndex.index({Slice(), mask});
  edgesAfterCut = edgesAfterCut.to(torch::kInt64);

  ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1));
  printCudaMemInfo(logger());
  t4 = std::chrono::high_resolution_clock::now();

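  // Summarize the wall-clock time spent in the main stages of the call.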
  auto milliseconds = [](const auto& a, const auto& b) {
    return std::chrono::duration<double, std::milli>(b - a).count();
  };
  ACTS_DEBUG("Time anycast, device guard:  " << milliseconds(t0, t1));
  ACTS_DEBUG("Time jit::IValue creation:   " << milliseconds(t1, t2));
  ACTS_DEBUG("Time model forward:          " << milliseconds(t2, t3));
  ACTS_DEBUG("Time sigmoid and cut:        " << milliseconds(t3, t4));

  return {std::move(nodeFeatures), std::move(edgesAfterCut),
          std::move(inEdgeFeatures), output.masked_select(mask)};
}

}  // namespace Acts