File indexing completed on 2026-06-25 07:49:03
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsPlugins/Gnn/Stages.hpp"
0010
0011 #include <algorithm>
0012 #include <stdexcept>
0013
0014 namespace ActsPlugins {
0015
0016 namespace {
0017
0018 PipelineTensors cpuRemoveUnusedNodes(PipelineTensors &&tensors,
0019 std::vector<int> &spacePointIds,
0020 const ExecutionContext & ) {
0021 const auto nNodes = tensors.nodeFeatures.shape()[0];
0022 const auto nEdges = tensors.edgeIndex.shape()[1];
0023 auto *edgeData = tensors.edgeIndex.data();
0024
0025 std::vector<std::int64_t> sortedIdx(edgeData, edgeData + 2 * nEdges);
0026 std::ranges::sort(sortedIdx);
0027 auto dups = std::ranges::unique(sortedIdx);
0028 sortedIdx.erase(dups.begin(), dups.end());
0029 const auto nUsed = sortedIdx.size();
0030
0031 ExecutionContext cpuCtx{Device::Cpu(), {}};
0032 auto mask = Tensor<bool>::Create({nNodes, 1}, cpuCtx);
0033 std::fill(mask.data(), mask.data() + nNodes, false);
0034 std::vector<std::int64_t> oldToNew(nNodes, 0);
0035 for (std::size_t newIdx = 0; newIdx < nUsed; ++newIdx) {
0036 const auto oldIdx = static_cast<std::size_t>(sortedIdx[newIdx]);
0037 mask.data()[oldIdx] = true;
0038 oldToNew[oldIdx] = static_cast<std::int64_t>(newIdx);
0039 }
0040
0041 auto newNodeFeatures = selectRows(tensors.nodeFeatures, mask, cpuCtx);
0042
0043 for (std::size_t i = 0; i < 2 * nEdges; ++i) {
0044 edgeData[i] = oldToNew[static_cast<std::size_t>(edgeData[i])];
0045 }
0046
0047 std::vector<int> remapped;
0048 remapped.reserve(nUsed);
0049 for (const auto oldIdx : sortedIdx) {
0050 remapped.push_back(spacePointIds[static_cast<std::size_t>(oldIdx)]);
0051 }
0052 spacePointIds = std::move(remapped);
0053
0054 return {std::move(newNodeFeatures), std::move(tensors.edgeIndex),
0055 std::move(tensors.edgeFeatures), std::move(tensors.edgeScores)};
0056 }
0057
0058 }
0059
#ifdef ACTS_GNN_WITH_CUDA
namespace detail {

/// CUDA variant of the unused-node removal. Only declared here; the
/// definition is not part of this file and is only referenced when
/// ACTS_GNN_WITH_CUDA is defined.
///
/// @param tensors Pipeline tensors, passed on by the dispatcher via move.
/// @param spacePointIds Space point ids, passed by mutable reference.
/// @param execCtx Execution context forwarded from the caller.
/// @return The resulting pipeline tensors.
PipelineTensors cudaRemoveUnusedNodes(PipelineTensors &&tensors,
                                      std::vector<int> &spacePointIds,
                                      const ExecutionContext &execCtx);

}  // namespace detail
#endif  // ACTS_GNN_WITH_CUDA
0067
0068 PipelineTensors removeUnusedNodes(PipelineTensors &&tensors,
0069 std::vector<int> &spacePointIds,
0070 const ExecutionContext &execCtx) {
0071 if (tensors.edgeIndex.shape()[1] == 0) {
0072 throw NoEdgesError{};
0073 }
0074
0075 if (tensors.nodeFeatures.device().isCuda()) {
0076 #ifdef ACTS_GNN_WITH_CUDA
0077 return detail::cudaRemoveUnusedNodes(std::move(tensors), spacePointIds,
0078 execCtx);
0079 #else
0080 throw std::runtime_error(
0081 "Cannot removeUnusedNodes on CUDA tensor, library not compiled with "
0082 "CUDA");
0083 #endif
0084 }
0085
0086 return cpuRemoveUnusedNodes(std::move(tensors), spacePointIds, execCtx);
0087 }
0088
0089 }