File indexing completed on 2025-01-30 09:15:13
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp"
0010
0011 #include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
0012 #include "Acts/Plugins/ExaTrkX/detail/buildEdges.hpp"
0013
0014 #ifndef ACTS_EXATRKX_CPUONLY
0015 #include <c10/cuda/CUDAGuard.h>
0016 #endif
0017
0018 #include <numbers>
0019
0020 #include <torch/script.h>
0021 #include <torch/torch.h>
0022
0023 #include "printCudaMemInfo.hpp"
0024
0025 using namespace torch::indexing;
0026
0027 namespace Acts {
0028
0029 TorchMetricLearning::TorchMetricLearning(const Config &cfg,
0030 std::unique_ptr<const Logger> _logger)
0031 : m_logger(std::move(_logger)),
0032 m_cfg(cfg),
0033 m_device(torch::Device(torch::kCPU)) {
0034 c10::InferenceMode guard(true);
0035 m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
0036
0037 if (m_deviceType == torch::kCPU) {
0038 ACTS_DEBUG("Running on CPU...");
0039 } else {
0040 if (cfg.deviceID >= 0 &&
0041 static_cast<std::size_t>(cfg.deviceID) < torch::cuda::device_count()) {
0042 ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used.");
0043 m_device = torch::Device(torch::kCUDA, cfg.deviceID);
0044 } else {
0045 ACTS_WARNING("GPU device " << cfg.deviceID
0046 << " not available, falling back to CPU.");
0047 }
0048 }
0049
0050 ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
0051 << TORCH_VERSION_MINOR << "."
0052 << TORCH_VERSION_PATCH);
0053 #ifndef ACTS_EXATRKX_CPUONLY
0054 if (not torch::cuda::is_available()) {
0055 ACTS_INFO("CUDA not available, falling back to CPU");
0056 }
0057 #endif
0058
0059 try {
0060 m_model = std::make_unique<torch::jit::Module>();
0061 *m_model = torch::jit::load(m_cfg.modelPath, m_device);
0062 m_model->eval();
0063 } catch (const c10::Error &e) {
0064 throw std::invalid_argument("Failed to load models: " + e.msg());
0065 }
0066 }
0067
0068 TorchMetricLearning::~TorchMetricLearning() {}
0069
0070 std::tuple<std::any, std::any, std::any> TorchMetricLearning::operator()(
0071 std::vector<float> &inputValues, std::size_t numNodes,
0072 const std::vector<std::uint64_t> & ,
0073 const ExecutionContext &execContext) {
0074 const auto &device = execContext.device;
0075 ACTS_DEBUG("Start graph construction");
0076 c10::InferenceMode guard(true);
0077
0078
0079 #ifdef ACTS_EXATRKX_CPUONLY
0080 assert(device == torch::Device(torch::kCPU));
0081 #else
0082 std::optional<c10::cuda::CUDAGuard> device_guard;
0083 std::optional<c10::cuda::CUDAStreamGuard> streamGuard;
0084 if (device.is_cuda()) {
0085 device_guard.emplace(device.index());
0086 streamGuard.emplace(execContext.stream.value());
0087 }
0088 #endif
0089
0090 const std::int64_t numAllFeatures = inputValues.size() / numNodes;
0091
0092
0093 ACTS_VERBOSE("First spacepoint information: " << [&]() {
0094 std::stringstream ss;
0095 for (int i = 0; i < numAllFeatures; ++i) {
0096 ss << inputValues[i] << " ";
0097 }
0098 return ss.str();
0099 }());
0100 printCudaMemInfo(logger());
0101
0102 auto inputTensor = detail::vectorToTensor2D(inputValues, numAllFeatures);
0103
0104
0105
0106 if (inputTensor.options().device() == device) {
0107 inputTensor = inputTensor.clone();
0108 } else {
0109 inputTensor = inputTensor.to(device);
0110 }
0111
0112
0113
0114
0115
0116
0117 auto model = m_model->clone();
0118 model.to(device);
0119
0120 std::vector<torch::jit::IValue> inputTensors;
0121 auto selectedFeaturesTensor =
0122 at::tensor(at::ArrayRef<int>(m_cfg.selectedFeatures));
0123 inputTensors.push_back(
0124 !m_cfg.selectedFeatures.empty()
0125 ? inputTensor.index({Slice{}, selectedFeaturesTensor})
0126 : std::move(inputTensor));
0127
0128 ACTS_DEBUG("embedding input tensor shape "
0129 << inputTensors[0].toTensor().size(0) << ", "
0130 << inputTensors[0].toTensor().size(1));
0131
0132 auto output = model.forward(inputTensors).toTensor();
0133
0134 ACTS_VERBOSE("Embedding space of the first SP:\n"
0135 << output.slice(0, 0, 1));
0136 printCudaMemInfo(logger());
0137
0138
0139
0140
0141
0142 auto edgeList = detail::buildEdges(output, m_cfg.rVal, m_cfg.knnVal,
0143 m_cfg.shuffleDirections);
0144
0145 ACTS_VERBOSE("Shape of built edges: (" << edgeList.size(0) << ", "
0146 << edgeList.size(1));
0147 ACTS_VERBOSE("Slice of edgelist:\n" << edgeList.slice(1, 0, 5));
0148 printCudaMemInfo(logger());
0149
0150
0151 std::any edgeFeatures;
0152 return {std::move(inputTensors[0]).toTensor(), std::move(edgeList),
0153 std::move(edgeFeatures)};
0154 }
0155 }