Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 09:15:13

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp"
0010 
0011 #include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
0012 
0013 #include <torch/torch.h>
0014 
0015 Acts::TorchGraphStoreHook::TorchGraphStoreHook() {
0016   m_storedGraph = std::make_unique<Graph>();
0017 }
0018 
0019 void Acts::TorchGraphStoreHook::operator()(const std::any&,
0020                                            const std::any& edges,
0021                                            const std::any& weights) const {
0022   if (not weights.has_value()) {
0023     return;
0024   }
0025 
0026   m_storedGraph->first = detail::tensor2DToVector<std::int64_t>(
0027       std::any_cast<torch::Tensor>(edges).t());
0028 
0029   auto cpuWeights = std::any_cast<torch::Tensor>(weights).to(torch::kCPU);
0030   m_storedGraph->second =
0031       std::vector<float>(cpuWeights.data_ptr<float>(),
0032                          cpuWeights.data_ptr<float>() + cpuWeights.numel());
0033 }