Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-06-25 07:49:03

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsPlugins/Gnn/Stages.hpp"
0010 
0011 #include <algorithm>
0012 #include <stdexcept>
0013 
0014 namespace ActsPlugins {
0015 
0016 namespace {
0017 
0018 PipelineTensors cpuRemoveUnusedNodes(PipelineTensors &&tensors,
0019                                      std::vector<int> &spacePointIds,
0020                                      const ExecutionContext & /*execCtx*/) {
0021   const auto nNodes = tensors.nodeFeatures.shape()[0];
0022   const auto nEdges = tensors.edgeIndex.shape()[1];
0023   auto *edgeData = tensors.edgeIndex.data();
0024 
0025   std::vector<std::int64_t> sortedIdx(edgeData, edgeData + 2 * nEdges);
0026   std::ranges::sort(sortedIdx);
0027   auto dups = std::ranges::unique(sortedIdx);
0028   sortedIdx.erase(dups.begin(), dups.end());
0029   const auto nUsed = sortedIdx.size();
0030 
0031   ExecutionContext cpuCtx{Device::Cpu(), {}};
0032   auto mask = Tensor<bool>::Create({nNodes, 1}, cpuCtx);
0033   std::fill(mask.data(), mask.data() + nNodes, false);
0034   std::vector<std::int64_t> oldToNew(nNodes, 0);
0035   for (std::size_t newIdx = 0; newIdx < nUsed; ++newIdx) {
0036     const auto oldIdx = static_cast<std::size_t>(sortedIdx[newIdx]);
0037     mask.data()[oldIdx] = true;
0038     oldToNew[oldIdx] = static_cast<std::int64_t>(newIdx);
0039   }
0040 
0041   auto newNodeFeatures = selectRows(tensors.nodeFeatures, mask, cpuCtx);
0042 
0043   for (std::size_t i = 0; i < 2 * nEdges; ++i) {
0044     edgeData[i] = oldToNew[static_cast<std::size_t>(edgeData[i])];
0045   }
0046 
0047   std::vector<int> remapped;
0048   remapped.reserve(nUsed);
0049   for (const auto oldIdx : sortedIdx) {
0050     remapped.push_back(spacePointIds[static_cast<std::size_t>(oldIdx)]);
0051   }
0052   spacePointIds = std::move(remapped);
0053 
0054   return {std::move(newNodeFeatures), std::move(tensors.edgeIndex),
0055           std::move(tensors.edgeFeatures), std::move(tensors.edgeScores)};
0056 }
0057 
0058 }  // namespace
0059 
0060 #ifdef ACTS_GNN_WITH_CUDA
0061 namespace detail {
0062 PipelineTensors cudaRemoveUnusedNodes(PipelineTensors &&tensors,
0063                                       std::vector<int> &spacePointIds,
0064                                       const ExecutionContext &execCtx);
0065 }  // namespace detail
0066 #endif
0067 
0068 PipelineTensors removeUnusedNodes(PipelineTensors &&tensors,
0069                                   std::vector<int> &spacePointIds,
0070                                   const ExecutionContext &execCtx) {
0071   if (tensors.edgeIndex.shape()[1] == 0) {
0072     throw NoEdgesError{};
0073   }
0074 
0075   if (tensors.nodeFeatures.device().isCuda()) {
0076 #ifdef ACTS_GNN_WITH_CUDA
0077     return detail::cudaRemoveUnusedNodes(std::move(tensors), spacePointIds,
0078                                          execCtx);
0079 #else
0080     throw std::runtime_error(
0081         "Cannot removeUnusedNodes on CUDA tensor, library not compiled with "
0082         "CUDA");
0083 #endif
0084   }
0085 
0086   return cpuRemoveUnusedNodes(std::move(tensors), spacePointIds, execCtx);
0087 }
0088 
0089 }  // namespace ActsPlugins