Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 09:15:13

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
#include "Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp"

#include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
#include "Acts/Plugins/ExaTrkX/detail/buildEdges.hpp"

#ifndef ACTS_EXATRKX_CPUONLY
#include <c10/cuda/CUDAGuard.h>
#endif

#include <numbers>
#include <sstream>
#include <stdexcept>

#include <torch/script.h>
#include <torch/torch.h>

#include "printCudaMemInfo.hpp"
0024 
0025 using namespace torch::indexing;
0026 
0027 namespace Acts {
0028 
0029 TorchMetricLearning::TorchMetricLearning(const Config &cfg,
0030                                          std::unique_ptr<const Logger> _logger)
0031     : m_logger(std::move(_logger)),
0032       m_cfg(cfg),
0033       m_device(torch::Device(torch::kCPU)) {
0034   c10::InferenceMode guard(true);
0035   m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
0036 
0037   if (m_deviceType == torch::kCPU) {
0038     ACTS_DEBUG("Running on CPU...");
0039   } else {
0040     if (cfg.deviceID >= 0 &&
0041         static_cast<std::size_t>(cfg.deviceID) < torch::cuda::device_count()) {
0042       ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used.");
0043       m_device = torch::Device(torch::kCUDA, cfg.deviceID);
0044     } else {
0045       ACTS_WARNING("GPU device " << cfg.deviceID
0046                                  << " not available, falling back to CPU.");
0047     }
0048   }
0049 
0050   ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
0051                                     << TORCH_VERSION_MINOR << "."
0052                                     << TORCH_VERSION_PATCH);
0053 #ifndef ACTS_EXATRKX_CPUONLY
0054   if (not torch::cuda::is_available()) {
0055     ACTS_INFO("CUDA not available, falling back to CPU");
0056   }
0057 #endif
0058 
0059   try {
0060     m_model = std::make_unique<torch::jit::Module>();
0061     *m_model = torch::jit::load(m_cfg.modelPath, m_device);
0062     m_model->eval();
0063   } catch (const c10::Error &e) {
0064     throw std::invalid_argument("Failed to load models: " + e.msg());
0065   }
0066 }
0067 
0068 TorchMetricLearning::~TorchMetricLearning() {}
0069 
0070 std::tuple<std::any, std::any, std::any> TorchMetricLearning::operator()(
0071     std::vector<float> &inputValues, std::size_t numNodes,
0072     const std::vector<std::uint64_t> & /*moduleIds*/,
0073     const ExecutionContext &execContext) {
0074   const auto &device = execContext.device;
0075   ACTS_DEBUG("Start graph construction");
0076   c10::InferenceMode guard(true);
0077 
0078   // add a protection to avoid calling for kCPU
0079 #ifdef ACTS_EXATRKX_CPUONLY
0080   assert(device == torch::Device(torch::kCPU));
0081 #else
0082   std::optional<c10::cuda::CUDAGuard> device_guard;
0083   std::optional<c10::cuda::CUDAStreamGuard> streamGuard;
0084   if (device.is_cuda()) {
0085     device_guard.emplace(device.index());
0086     streamGuard.emplace(execContext.stream.value());
0087   }
0088 #endif
0089 
0090   const std::int64_t numAllFeatures = inputValues.size() / numNodes;
0091 
0092   // printout the r,phi,z of the first spacepoint
0093   ACTS_VERBOSE("First spacepoint information: " << [&]() {
0094     std::stringstream ss;
0095     for (int i = 0; i < numAllFeatures; ++i) {
0096       ss << inputValues[i] << "  ";
0097     }
0098     return ss.str();
0099   }());
0100   printCudaMemInfo(logger());
0101 
0102   auto inputTensor = detail::vectorToTensor2D(inputValues, numAllFeatures);
0103 
0104   // If we are on CPU, clone to get ownership (is this necessary?), else bring
0105   // to device.
0106   if (inputTensor.options().device() == device) {
0107     inputTensor = inputTensor.clone();
0108   } else {
0109     inputTensor = inputTensor.to(device);
0110   }
0111 
0112   // **********
0113   // Embedding
0114   // **********
0115 
0116   // Clone models (solve memory leak? members can be const...)
0117   auto model = m_model->clone();
0118   model.to(device);
0119 
0120   std::vector<torch::jit::IValue> inputTensors;
0121   auto selectedFeaturesTensor =
0122       at::tensor(at::ArrayRef<int>(m_cfg.selectedFeatures));
0123   inputTensors.push_back(
0124       !m_cfg.selectedFeatures.empty()
0125           ? inputTensor.index({Slice{}, selectedFeaturesTensor})
0126           : std::move(inputTensor));
0127 
0128   ACTS_DEBUG("embedding input tensor shape "
0129              << inputTensors[0].toTensor().size(0) << ", "
0130              << inputTensors[0].toTensor().size(1));
0131 
0132   auto output = model.forward(inputTensors).toTensor();
0133 
0134   ACTS_VERBOSE("Embedding space of the first SP:\n"
0135                << output.slice(/*dim=*/0, /*start=*/0, /*end=*/1));
0136   printCudaMemInfo(logger());
0137 
0138   // ****************
0139   // Building Edges
0140   // ****************
0141 
0142   auto edgeList = detail::buildEdges(output, m_cfg.rVal, m_cfg.knnVal,
0143                                      m_cfg.shuffleDirections);
0144 
0145   ACTS_VERBOSE("Shape of built edges: (" << edgeList.size(0) << ", "
0146                                          << edgeList.size(1));
0147   ACTS_VERBOSE("Slice of edgelist:\n" << edgeList.slice(1, 0, 5));
0148   printCudaMemInfo(logger());
0149 
0150   // TODO add real edge features for this workflow later
0151   std::any edgeFeatures;
0152   return {std::move(inputTensors[0]).toTensor(), std::move(edgeList),
0153           std::move(edgeFeatures)};
0154 }
0155 }  // namespace Acts