File indexing completed on 2025-07-14 08:11:49
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"
0010
0011 #include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
0012 #include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"
0013
0014 #include <chrono>
0015
0016 #ifndef ACTS_EXATRKX_CPUONLY
0017 #include <c10/cuda/CUDAGuard.h>
0018 #endif
0019
0020 #include <torch/script.h>
0021 #include <torch/torch.h>
0022
0023 #include "printCudaMemInfo.hpp"
0024
0025 using namespace torch::indexing;
0026
0027 namespace Acts {
0028
0029 TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg,
0030 std::unique_ptr<const Logger> _logger)
0031 : m_logger(std::move(_logger)), m_cfg(cfg) {
0032 c10::InferenceMode guard(true);
0033 torch::Device device = torch::kCPU;
0034
0035 if (!torch::cuda::is_available()) {
0036 ACTS_DEBUG("Running on CPU...");
0037 } else {
0038 if (cfg.deviceID >= 0 &&
0039 static_cast<std::size_t>(cfg.deviceID) < torch::cuda::device_count()) {
0040 ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used.");
0041 device = torch::Device(torch::kCUDA, cfg.deviceID);
0042 } else {
0043 ACTS_WARNING("GPU device " << cfg.deviceID
0044 << " not available, falling back to CPU.");
0045 }
0046 }
0047
0048 ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
0049 << TORCH_VERSION_MINOR << "."
0050 << TORCH_VERSION_PATCH);
0051 #ifndef ACTS_EXATRKX_CPUONLY
0052 if (!torch::cuda::is_available()) {
0053 ACTS_INFO("CUDA not available, falling back to CPU");
0054 }
0055 #endif
0056
0057 try {
0058 m_model = std::make_unique<torch::jit::Module>();
0059 *m_model = torch::jit::load(m_cfg.modelPath, device);
0060 m_model->eval();
0061 } catch (const c10::Error& e) {
0062 throw std::invalid_argument("Failed to load models: " + e.msg());
0063 }
0064 }
0065
0066 TorchEdgeClassifier::~TorchEdgeClassifier() {}
0067
0068 PipelineTensors TorchEdgeClassifier::operator()(
0069 PipelineTensors tensors, const ExecutionContext& execContext) {
0070 const auto device =
0071 execContext.device.type == Acts::Device::Type::eCUDA
0072 ? torch::Device(torch::kCUDA, execContext.device.index)
0073 : torch::kCPU;
0074 decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4;
0075 t0 = std::chrono::high_resolution_clock::now();
0076 ACTS_DEBUG("Start edge classification, use " << device);
0077
0078 if (tensors.edgeIndex.size() == 0) {
0079 throw NoEdgesError{};
0080 }
0081
0082 c10::InferenceMode guard(true);
0083
0084
0085 #ifdef ACTS_EXATRKX_CPUONLY
0086 assert(device == torch::Device(torch::kCPU));
0087 #else
0088 std::optional<c10::cuda::CUDAGuard> device_guard;
0089 if (device.is_cuda()) {
0090 device_guard.emplace(device.index());
0091 }
0092 #endif
0093
0094 auto nodeFeatures = detail::actsToNonOwningTorchTensor(tensors.nodeFeatures);
0095 ACTS_DEBUG("nodeFeatures: " << detail::TensorDetails{nodeFeatures});
0096
0097 auto edgeIndex = detail::actsToNonOwningTorchTensor(tensors.edgeIndex);
0098 ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex});
0099
0100 std::optional<torch::Tensor> edgeFeatures;
0101 if (tensors.edgeFeatures.has_value()) {
0102 edgeFeatures = detail::actsToNonOwningTorchTensor(*tensors.edgeFeatures);
0103 ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{*edgeFeatures});
0104 }
0105
0106 torch::Tensor output;
0107
0108
0109 {
0110 auto edgeIndexTmp = m_cfg.undirected
0111 ? torch::cat({edgeIndex, edgeIndex.flip(0)}, 1)
0112 : edgeIndex;
0113
0114 std::vector<torch::jit::IValue> inputTensors(2);
0115 auto selectedFeaturesTensor =
0116 at::tensor(at::ArrayRef<int>(m_cfg.selectedFeatures));
0117 at::Tensor selectedNodeFeatures =
0118 !m_cfg.selectedFeatures.empty()
0119 ? nodeFeatures.index({Slice{}, selectedFeaturesTensor}).clone()
0120 : nodeFeatures;
0121
0122 ACTS_DEBUG("selected nodeFeatures: "
0123 << detail::TensorDetails{selectedNodeFeatures});
0124 inputTensors[0] = selectedNodeFeatures;
0125
0126 if (edgeFeatures && m_cfg.useEdgeFeatures) {
0127 inputTensors.push_back(*edgeFeatures);
0128 }
0129
0130 t1 = std::chrono::high_resolution_clock::now();
0131
0132 if (m_cfg.nChunks > 1) {
0133 std::vector<at::Tensor> results;
0134 results.reserve(m_cfg.nChunks);
0135
0136 auto chunks = at::chunk(edgeIndexTmp, m_cfg.nChunks, 1);
0137 for (auto& chunk : chunks) {
0138 ACTS_VERBOSE("Process chunk with shape" << chunk.sizes());
0139 inputTensors[1] = chunk;
0140
0141 results.push_back(m_model->forward(inputTensors).toTensor());
0142 results.back().squeeze_();
0143 }
0144
0145 output = torch::cat(results);
0146 } else {
0147 inputTensors[1] = edgeIndexTmp;
0148 output = m_model->forward(inputTensors).toTensor().to(torch::kFloat32);
0149 output.squeeze_();
0150 }
0151 t2 = std::chrono::high_resolution_clock::now();
0152 }
0153
0154 ACTS_VERBOSE("Slice of classified output before sigmoid:\n"
0155 << output.slice(0, 0, 9));
0156
0157 output.sigmoid_();
0158
0159 if (m_cfg.undirected) {
0160 auto newSize = output.size(0) / 2;
0161 output = output.index({Slice(None, newSize)});
0162 }
0163
0164 ACTS_VERBOSE("Size after classifier: " << output.size(0));
0165 ACTS_VERBOSE("Slice of classified output:\n"
0166 << output.slice(0, 0, 9));
0167 printCudaMemInfo(logger());
0168
0169 torch::Tensor mask = output > m_cfg.cut;
0170 torch::Tensor edgesAfterCut = edgeIndex.index({Slice(), mask});
0171 edgesAfterCut = edgesAfterCut.to(torch::kInt64);
0172
0173 if (edgesAfterCut.numel() == 0) {
0174 throw NoEdgesError{};
0175 }
0176
0177 ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1));
0178 printCudaMemInfo(logger());
0179 t3 = std::chrono::high_resolution_clock::now();
0180
0181 auto milliseconds = [](const auto& a, const auto& b) {
0182 return std::chrono::duration<double, std::milli>(b - a).count();
0183 };
0184 ACTS_DEBUG("Time preparation: " << milliseconds(t0, t1));
0185 ACTS_DEBUG("Time inference: " << milliseconds(t1, t2));
0186 ACTS_DEBUG("Time postprocessing: " << milliseconds(t2, t3));
0187
0188
0189
0190 return {std::move(tensors.nodeFeatures),
0191 detail::torchToActsTensor<std::int64_t>(edgesAfterCut, execContext),
0192 {},
0193 detail::torchToActsTensor<float>(output.masked_select(mask),
0194 execContext)};
0195 }
0196
0197 }