#include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"

#include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"

#include <chrono>

#ifndef ACTS_EXATRKX_CPUONLY
#include <c10/cuda/CUDAGuard.h>
#endif

#include <torch/script.h>
#include <torch/torch.h>

#include "printCudaMemInfo.hpp"

using namespace torch::indexing;

namespace Acts {

TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg,
                                         std::unique_ptr<const Logger> _logger)
    : m_logger(std::move(_logger)),
      m_cfg(cfg),
      m_device(torch::Device(torch::kCPU)) {
  c10::InferenceMode guard(true);
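  // Prefer the configured CUDA device if CUDA is available at runtime;
  // otherwise keep the default CPU device.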
  m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
  if (m_deviceType == torch::kCPU) {
    ACTS_DEBUG("Running on CPU...");
  } else {
    if (cfg.deviceID >= 0 &&
        static_cast<std::size_t>(cfg.deviceID) < torch::cuda::device_count()) {
      ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used.");
      m_device = torch::Device(torch::kCUDA, cfg.deviceID);
    } else {
      ACTS_WARNING("GPU device " << cfg.deviceID
                                 << " not available, falling back to CPU.");
    }
  }

  ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
                                    << TORCH_VERSION_MINOR << "."
                                    << TORCH_VERSION_PATCH);
#ifndef ACTS_EXATRKX_CPUONLY
  if (not torch::cuda::is_available()) {
    ACTS_INFO("CUDA not available, falling back to CPU");
  }
#endif

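  // Load the TorchScript model onto the selected device and put it into
  // evaluation mode.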
  try {
    m_model = std::make_unique<torch::jit::Module>();
    *m_model = torch::jit::load(m_cfg.modelPath.c_str(), m_device);
    m_model->eval();
  } catch (const c10::Error& e) {
    throw std::invalid_argument("Failed to load model: " + e.msg());
  }
}

TorchEdgeClassifier::~TorchEdgeClassifier() {}

std::tuple<std::any, std::any, std::any, std::any>
TorchEdgeClassifier::operator()(std::any inNodeFeatures, std::any inEdgeIndex,
                                std::any inEdgeFeatures,
                                const ExecutionContext& execContext) {
  const auto& device = execContext.device;
  decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4;
  t0 = std::chrono::high_resolution_clock::now();
  ACTS_DEBUG("Start edge classification, use " << device);
  c10::InferenceMode guard(true);

#ifdef ACTS_EXATRKX_CPUONLY
  assert(device == torch::Device(torch::kCPU));
#else
  std::optional<c10::cuda::CUDAGuard> device_guard;
  std::optional<c10::cuda::CUDAStreamGuard> streamGuard;
  if (device.is_cuda()) {
    device_guard.emplace(device.index());
    streamGuard.emplace(execContext.stream.value());
  }
#endif

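  // Move the inputs to the execution device before running the model.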
  auto nodeFeatures = std::any_cast<torch::Tensor>(inNodeFeatures).to(device);
  auto edgeIndex = std::any_cast<torch::Tensor>(inEdgeIndex).to(device);

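  // An empty edge list cannot be classified; signal this to the caller.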
  if (edgeIndex.numel() == 0) {
    throw NoEdgesError{};
  }

  ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex});

  std::optional<torch::Tensor> edgeFeatures;
  if (inEdgeFeatures.has_value()) {
    edgeFeatures = std::any_cast<torch::Tensor>(inEdgeFeatures).to(device);
    ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{*edgeFeatures});
  }
  t1 = std::chrono::high_resolution_clock::now();

  torch::Tensor output;

  {
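    // For an undirected model, append the flipped edge index so that every
    // edge is also evaluated in the reverse direction.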
    auto edgeIndexTmp = m_cfg.undirected
                            ? torch::cat({edgeIndex, edgeIndex.flip(0)}, 1)
                            : edgeIndex;

    std::vector<torch::jit::IValue> inputTensors(2);
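    // Restrict the node features to the configured columns, if a selection
    // is given; otherwise pass them through unchanged.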
    auto selectedFeaturesTensor =
        at::tensor(at::ArrayRef<int>(m_cfg.selectedFeatures));
    at::Tensor selectedNodeFeatures =
        !m_cfg.selectedFeatures.empty()
            ? nodeFeatures.index({Slice{}, selectedFeaturesTensor}).clone()
            : nodeFeatures;

    ACTS_DEBUG("selected nodeFeatures: "
               << detail::TensorDetails{selectedNodeFeatures});
    inputTensors[0] = selectedNodeFeatures;

    if (edgeFeatures && m_cfg.useEdgeFeatures) {
      inputTensors.push_back(*edgeFeatures);
    }

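    // Optionally split the edge index into chunks and run the forward pass
    // chunk by chunk to reduce peak (GPU) memory usage.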
    if (m_cfg.nChunks > 1) {
      std::vector<at::Tensor> results;
      results.reserve(m_cfg.nChunks);

      auto chunks = at::chunk(edgeIndexTmp, m_cfg.nChunks, 1);
      t2 = std::chrono::high_resolution_clock::now();
      for (auto& chunk : chunks) {
        ACTS_VERBOSE("Process chunk with shape " << chunk.sizes());
        inputTensors[1] = chunk;

        results.push_back(m_model->forward(inputTensors).toTensor());
        results.back().squeeze_();
      }
      t3 = std::chrono::high_resolution_clock::now();

      output = torch::cat(results);
    } else {
      inputTensors[1] = edgeIndexTmp;

      t2 = std::chrono::high_resolution_clock::now();
      output = m_model->forward(inputTensors).toTensor().to(torch::kFloat32);
      t3 = std::chrono::high_resolution_clock::now();
      output.squeeze_();
    }
  }

  ACTS_VERBOSE("Slice of classified output before sigmoid:\n"
               << output.slice(0, 0, 9));

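  // Map the raw network output to scores in [0, 1] in place.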
  output.sigmoid_();

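  // For an undirected model each edge was scored twice (once per direction);
  // keep only the scores belonging to the original edge order.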
  if (m_cfg.undirected) {
    auto newSize = output.size(0) / 2;
    output = output.index({Slice(None, newSize)});
  }

  ACTS_VERBOSE("Size after classifier: " << output.size(0));
  ACTS_VERBOSE("Slice of classified output:\n" << output.slice(0, 0, 9));
  printCudaMemInfo(logger());

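  // Apply the score cut and keep only the edges that pass, as 64-bit indices.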
  torch::Tensor mask = output > m_cfg.cut;
  torch::Tensor edgesAfterCut = edgeIndex.index({Slice(), mask});
  edgesAfterCut = edgesAfterCut.to(torch::kInt64);

  ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1));
  printCudaMemInfo(logger());
  t4 = std::chrono::high_resolution_clock::now();

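  // Report per-stage timings in milliseconds.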
  auto milliseconds = [](const auto& a, const auto& b) {
    return std::chrono::duration<double, std::milli>(b - a).count();
  };
  ACTS_DEBUG("Time anycast, device guard: " << milliseconds(t0, t1));
  ACTS_DEBUG("Time jit::IValue creation: " << milliseconds(t1, t2));
  ACTS_DEBUG("Time model forward: " << milliseconds(t2, t3));
  ACTS_DEBUG("Time sigmoid and cut: " << milliseconds(t3, t4));

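  // Return the node features, the edges surviving the cut, the unchanged
  // edge features, and the scores of the surviving edges.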
  return {std::move(nodeFeatures), std::move(edgesAfterCut),
          std::move(inEdgeFeatures), output.masked_select(mask)};
}

}  // namespace Acts