File indexing completed on 2025-01-30 09:15:13
0001
0002
0003
0004
0005
0006
0007
0008
#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp"

#include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"

#include <any>
#include <cstdint>
#include <utility>
#include <vector>

#include <torch/torch.h>
0014
0015 Acts::TorchGraphStoreHook::TorchGraphStoreHook() {
0016 m_storedGraph = std::make_unique<Graph>();
0017 }
0018
0019 void Acts::TorchGraphStoreHook::operator()(const std::any&,
0020 const std::any& edges,
0021 const std::any& weights) const {
0022 if (not weights.has_value()) {
0023 return;
0024 }
0025
0026 m_storedGraph->first = detail::tensor2DToVector<std::int64_t>(
0027 std::any_cast<torch::Tensor>(edges).t());
0028
0029 auto cpuWeights = std::any_cast<torch::Tensor>(weights).to(torch::kCPU);
0030 m_storedGraph->second =
0031 std::vector<float>(cpuWeights.data_ptr<float>(),
0032 cpuWeights.data_ptr<float>() + cpuWeights.numel());
0033 }