File indexing completed on 2025-12-17 09:21:31
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsPlugins/Gnn/TorchEdgeClassifier.hpp"
0010
0011 #include "ActsPlugins/Gnn/detail/TensorVectorConversion.hpp"
0012 #include "ActsPlugins/Gnn/detail/Utils.hpp"
0013
0014 #include <chrono>
0015
0016 #ifndef ACTS_GNN_CPUONLY
0017 #include <c10/cuda/CUDAGuard.h>
0018 #endif
0019
0020 #include <torch/script.h>
0021 #include <torch/torch.h>
0022
0023 #include "printCudaMemInfo.hpp"
0024
0025 using namespace torch::indexing;
0026
0027 using namespace Acts;
0028
0029 namespace ActsPlugins {
0030
0031 TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg,
0032 std::unique_ptr<const Logger> _logger)
0033 : m_logger(std::move(_logger)), m_cfg(cfg) {
0034 c10::InferenceMode guard(true);
0035 torch::Device device = torch::kCPU;
0036
0037 if (!torch::cuda::is_available()) {
0038 ACTS_DEBUG("Running on CPU...");
0039 } else {
0040 if (cfg.deviceID >= 0 && cfg.deviceID < torch::cuda::device_count()) {
0041 ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used.");
0042 device = torch::Device(torch::kCUDA, cfg.deviceID);
0043 } else {
0044 ACTS_WARNING("GPU device " << cfg.deviceID
0045 << " not available, falling back to CPU.");
0046 }
0047 }
0048
0049 ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
0050 << TORCH_VERSION_MINOR << "."
0051 << TORCH_VERSION_PATCH);
0052 #ifndef ACTS_GNN_CPUONLY
0053 if (!torch::cuda::is_available()) {
0054 ACTS_INFO("CUDA not available, falling back to CPU");
0055 }
0056 #endif
0057
0058 try {
0059 m_model = std::make_unique<torch::jit::Module>();
0060 *m_model = torch::jit::load(m_cfg.modelPath, device);
0061 m_model->eval();
0062 } catch (const c10::Error& e) {
0063 throw std::invalid_argument("Failed to load models: " + e.msg());
0064 }
0065 }
0066
0067 TorchEdgeClassifier::~TorchEdgeClassifier() {}
0068
0069 PipelineTensors TorchEdgeClassifier::operator()(
0070 PipelineTensors tensors, const ExecutionContext& execContext) {
0071 const auto device =
0072 execContext.device.type == Device::Type::eCUDA
0073 ? torch::Device(torch::kCUDA, execContext.device.index)
0074 : torch::kCPU;
0075 decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4;
0076 t0 = std::chrono::high_resolution_clock::now();
0077 ACTS_DEBUG("Start edge classification, use " << device);
0078
0079 if (tensors.edgeIndex.size() == 0) {
0080 throw NoEdgesError{};
0081 }
0082
0083 c10::InferenceMode guard(true);
0084
0085
0086 #ifdef ACTS_GNN_CPUONLY
0087 assert(device == torch::Device(torch::kCPU));
0088 #else
0089 std::optional<c10::cuda::CUDAGuard> device_guard;
0090 if (device.is_cuda()) {
0091 device_guard.emplace(device.index());
0092 }
0093 #endif
0094
0095 auto nodeFeatures = detail::actsToNonOwningTorchTensor(tensors.nodeFeatures);
0096 ACTS_DEBUG("nodeFeatures: " << detail::TensorDetails{nodeFeatures});
0097
0098 auto edgeIndex = detail::actsToNonOwningTorchTensor(tensors.edgeIndex);
0099 ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex});
0100
0101 std::optional<torch::Tensor> edgeFeatures;
0102 if (tensors.edgeFeatures.has_value()) {
0103 edgeFeatures = detail::actsToNonOwningTorchTensor(*tensors.edgeFeatures);
0104 ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{*edgeFeatures});
0105 }
0106
0107 torch::Tensor output;
0108
0109
0110 {
0111 auto edgeIndexTmp = m_cfg.undirected
0112 ? torch::cat({edgeIndex, edgeIndex.flip(0)}, 1)
0113 : edgeIndex;
0114
0115 std::vector<torch::jit::IValue> inputTensors(2);
0116 auto selectedFeaturesTensor =
0117 at::tensor(at::ArrayRef<int>(m_cfg.selectedFeatures));
0118 at::Tensor selectedNodeFeatures =
0119 !m_cfg.selectedFeatures.empty()
0120 ? nodeFeatures.index({Slice{}, selectedFeaturesTensor}).clone()
0121 : nodeFeatures;
0122
0123 ACTS_DEBUG("selected nodeFeatures: "
0124 << detail::TensorDetails{selectedNodeFeatures});
0125 inputTensors[0] = selectedNodeFeatures;
0126
0127 if (edgeFeatures && m_cfg.useEdgeFeatures) {
0128 inputTensors.push_back(*edgeFeatures);
0129 }
0130
0131 t1 = std::chrono::high_resolution_clock::now();
0132
0133 if (m_cfg.nChunks > 1) {
0134 std::vector<at::Tensor> results;
0135 results.reserve(m_cfg.nChunks);
0136
0137 auto chunks = at::chunk(edgeIndexTmp, m_cfg.nChunks, 1);
0138 for (auto& chunk : chunks) {
0139 ACTS_VERBOSE("Process chunk with shape" << chunk.sizes());
0140 inputTensors[1] = chunk;
0141
0142 results.push_back(m_model->forward(inputTensors).toTensor());
0143 results.back().squeeze_();
0144 }
0145
0146 output = torch::cat(results);
0147 } else {
0148 inputTensors[1] = edgeIndexTmp;
0149 output = m_model->forward(inputTensors).toTensor().to(torch::kFloat32);
0150 output.squeeze_();
0151 }
0152 t2 = std::chrono::high_resolution_clock::now();
0153 }
0154
0155 ACTS_VERBOSE("Slice of classified output before sigmoid:\n"
0156 << output.slice(0, 0, 9));
0157
0158 output.sigmoid_();
0159
0160 if (m_cfg.undirected) {
0161 auto newSize = output.size(0) / 2;
0162 output = output.index({Slice(None, newSize)});
0163 }
0164
0165 ACTS_VERBOSE("Size after classifier: " << output.size(0));
0166 ACTS_VERBOSE("Slice of classified output:\n"
0167 << output.slice(0, 0, 9));
0168 printCudaMemInfo(logger());
0169
0170 torch::Tensor mask = output > m_cfg.cut;
0171 torch::Tensor edgesAfterCut = edgeIndex.index({Slice(), mask});
0172 edgesAfterCut = edgesAfterCut.to(torch::kInt64);
0173
0174 if (edgesAfterCut.numel() == 0) {
0175 throw NoEdgesError{};
0176 }
0177
0178 ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1));
0179 printCudaMemInfo(logger());
0180 t3 = std::chrono::high_resolution_clock::now();
0181
0182 auto milliseconds = [](const auto& a, const auto& b) {
0183 return std::chrono::duration<double, std::milli>(b - a).count();
0184 };
0185 ACTS_DEBUG("Time preparation: " << milliseconds(t0, t1));
0186 ACTS_DEBUG("Time inference: " << milliseconds(t1, t2));
0187 ACTS_DEBUG("Time postprocessing: " << milliseconds(t2, t3));
0188
0189
0190
0191 return {std::move(tensors.nodeFeatures),
0192 detail::torchToActsTensor<std::int64_t>(edgesAfterCut, execContext),
0193 {},
0194 detail::torchToActsTensor<float>(output.masked_select(mask),
0195 execContext)};
0196 }
0197
0198 }