Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:24:26

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsPlugins/Gnn/detail/buildEdges.hpp"
0010 
0011 #include "Acts/Utilities/Helpers.hpp"
0012 #include "Acts/Utilities/KDTree.hpp"
0013 #include "ActsPlugins/Gnn/detail/TensorVectorConversion.hpp"
0014 
0015 #include <iostream>
0016 #include <mutex>
0017 #include <vector>
0018 
0019 #include <torch/script.h>
0020 #include <torch/torch.h>
0021 
0022 #ifndef ACTS_GNN_CPUONLY
0023 #include <cuda.h>
0024 #include <cuda_runtime_api.h>
0025 #include <grid/counting_sort.h>
0026 #include <grid/find_nbrs.h>
0027 #include <grid/grid.h>
0028 #include <grid/insert_points.h>
0029 #include <grid/prefix_sum.h>
0030 #endif
0031 
0032 using namespace torch::indexing;
0033 
0034 using namespace Acts;
0035 
0036 torch::Tensor ActsPlugins::detail::postprocessEdgeTensor(torch::Tensor edges,
0037                                                          bool removeSelfLoops,
0038                                                          bool removeDuplicates,
0039                                                          bool flipDirections) {
0040   // Remove self-loops
0041   if (removeSelfLoops) {
0042     torch::Tensor selfLoopMask = edges.index({0}) != edges.index({1});
0043     edges = edges.index({Slice(), selfLoopMask});
0044   }
0045 
0046   // Remove duplicates
0047   if (removeDuplicates) {
0048     torch::Tensor mask = edges.index({0}) > edges.index({1});
0049     edges.index_put_({Slice(), mask}, edges.index({Slice(), mask}).flip(0));
0050     edges = std::get<0>(torch::unique_dim(edges, -1, false));
0051   }
0052 
0053   // Randomly flip direction
0054   if (flipDirections) {
0055     torch::Tensor random_cut_keep = torch::randint(2, {edges.size(1)});
0056     torch::Tensor random_cut_flip = 1 - random_cut_keep;
0057     torch::Tensor keep_edges =
0058         edges.index({Slice(), random_cut_keep.to(torch::kBool)});
0059     torch::Tensor flip_edges =
0060         edges.index({Slice(), random_cut_flip.to(torch::kBool)}).flip({0});
0061     edges = torch::cat({keep_edges, flip_edges}, 1);
0062   }
0063 
0064   return edges.toType(torch::kInt64);
0065 }
0066 
0067 torch::Tensor ActsPlugins::detail::buildEdgesFRNN(torch::Tensor &embedFeatures,
0068                                                   float rVal, int kVal,
0069                                                   bool flipDirections) {
0070 #ifndef ACTS_GNN_CPUONLY
0071   const auto device = embedFeatures.device();
0072 
0073   const std::int64_t numSpacepoints = embedFeatures.size(0);
0074   const int dim = embedFeatures.size(1);
0075 
0076   const int grid_params_size = 8;
0077   const int grid_delta_idx = 3;
0078   const int grid_total_idx = 7;
0079   const int grid_max_res = 128;
0080   const int grid_dim = 3;
0081 
0082   if (dim < 3) {
0083     throw std::runtime_error("DIM < 3 is not supported for now.\n");
0084   }
0085 
0086   const float radius_cell_ratio = 2.0;
0087   const int batch_size = 1;
0088   int G = -1;
0089 
0090   // Set up grid properties
0091   torch::Tensor grid_min;
0092   torch::Tensor grid_max;
0093   torch::Tensor grid_size;
0094 
0095   torch::Tensor embedTensor = embedFeatures.reshape({1, numSpacepoints, dim});
0096   torch::Tensor gridParamsCuda =
0097       torch::zeros({batch_size, grid_params_size}, device).to(torch::kFloat32);
0098   torch::Tensor r_tensor = torch::full({batch_size}, rVal, device);
0099   torch::Tensor lengths = torch::full({batch_size}, numSpacepoints, device);
0100 
0101   // build the grid
0102   for (int i = 0; i < batch_size; i++) {
0103     torch::Tensor allPoints =
0104         embedTensor.index({i, Slice(None, lengths.index({i}).item().to<long>()),
0105                            Slice(None, grid_dim)});
0106     grid_min = std::get<0>(allPoints.min(0));
0107     grid_max = std::get<0>(allPoints.max(0));
0108     gridParamsCuda.index_put_({i, Slice(None, grid_delta_idx)}, grid_min);
0109 
0110     grid_size = grid_max - grid_min;
0111 
0112     float cell_size =
0113         r_tensor.index({i}).item().to<float>() / radius_cell_ratio;
0114 
0115     if (cell_size < (grid_size.min().item().to<float>() / grid_max_res)) {
0116       cell_size = grid_size.min().item().to<float>() / grid_max_res;
0117     }
0118 
0119     gridParamsCuda.index_put_({i, grid_delta_idx}, 1 / cell_size);
0120 
0121     gridParamsCuda.index_put_({i, Slice(1 + grid_delta_idx, grid_total_idx)},
0122                               floor(grid_size / cell_size) + 1);
0123 
0124     gridParamsCuda.index_put_(
0125         {i, grid_total_idx},
0126         gridParamsCuda.index({i, Slice(1 + grid_delta_idx, grid_total_idx)})
0127             .prod());
0128 
0129     if (G < gridParamsCuda.index({i, grid_total_idx}).item().to<int>()) {
0130       G = gridParamsCuda.index({i, grid_total_idx}).item().to<int>();
0131     }
0132   }
0133 
0134   torch::Tensor pc_grid_cnt =
0135       torch::zeros({batch_size, G}, device).to(torch::kInt32);
0136   torch::Tensor pc_grid_cell =
0137       torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);
0138   torch::Tensor pc_grid_idx =
0139       torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);
0140 
0141   // put spacepoints into the grid
0142   InsertPointsCUDA(embedTensor, lengths.to(torch::kInt64), gridParamsCuda,
0143                    pc_grid_cnt, pc_grid_cell, pc_grid_idx, G);
0144 
0145   torch::Tensor pc_grid_off =
0146       torch::full({batch_size, G}, 0, device).to(torch::kInt32);
0147   torch::Tensor grid_params = gridParamsCuda.to(torch::kCPU);
0148 
0149   // for loop seems not to be necessary anymore
0150   pc_grid_off = PrefixSumCUDA(pc_grid_cnt, grid_params);
0151 
0152   torch::Tensor sorted_points =
0153       torch::zeros({batch_size, numSpacepoints, dim}, device)
0154           .to(torch::kFloat32);
0155   torch::Tensor sorted_points_idxs =
0156       torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);
0157 
0158   CountingSortCUDA(embedTensor, lengths.to(torch::kInt64), pc_grid_cell,
0159                    pc_grid_idx, pc_grid_off, sorted_points, sorted_points_idxs);
0160 
0161   auto [indices, distances] = FindNbrsCUDA(
0162       sorted_points, sorted_points, lengths.to(torch::kInt64),
0163       lengths.to(torch::kInt64), pc_grid_off.to(torch::kInt32),
0164       sorted_points_idxs, sorted_points_idxs,
0165       gridParamsCuda.to(torch::kFloat32), kVal, r_tensor, r_tensor * r_tensor);
0166   torch::Tensor positiveIndices = indices >= 0;
0167 
0168   torch::Tensor repeatRange = torch::arange(positiveIndices.size(1), device)
0169                                   .repeat({1, positiveIndices.size(2), 1})
0170                                   .transpose(1, 2);
0171 
0172   torch::Tensor stackedEdges = torch::stack(
0173       {repeatRange.index({positiveIndices}), indices.index({positiveIndices})});
0174 
0175   return postprocessEdgeTensor(std::move(stackedEdges), true, true,
0176                                flipDirections);
0177 #else
0178   throw std::runtime_error(
0179       "ACTS not compiled with CUDA, cannot run ActsPlugins::buildEdgesFRNN");
0180 #endif
0181 }
0182 
/// Minimal non-owning view over exactly S contiguous elements of type T.
///
/// Serves as the coordinate type handed to the KDTree so that tree points
/// refer into the embedding buffer instead of copying it. Should be replaced
/// with std::span once that is available.
template <typename T, std::size_t S>
struct Span {
  /// First viewed element; the view never owns or frees it.
  T *ptr;

  using const_iterator = T const *;

  /// Number of viewed elements (the compile-time extent).
  constexpr std::size_t size() const noexcept { return S; }

  const_iterator cbegin() const { return ptr; }
  const_iterator cend() const { return cbegin() + S; }

  /// Unchecked element access; mutable through a const view because the
  /// view itself does not own the data.
  T &operator[](std::size_t i) const { return *(ptr + i); }
};
0198 
0199 template <std::size_t Dim>
0200 float dist(const Span<float, Dim> &a, const Span<float, Dim> &b) {
0201   float s = 0.f;
0202   for (auto i = 0ul; i < Dim; ++i) {
0203     s += (a[i] - b[i]) * (a[i] - b[i]);
0204   }
0205   return std::sqrt(s);
0206 };
0207 
0208 template <std::size_t Dim>
0209 struct BuildEdgesKDTree {
0210   static torch::Tensor invoke(torch::Tensor &embedFeatures, float rVal,
0211                               int kVal) {
0212     assert(embedFeatures.size(1) == Dim);
0213     embedFeatures = embedFeatures.to(torch::kCPU);
0214 
0215     ////////////////
0216     // Build tree //
0217     ////////////////
0218     using KDTree = KDTree<Dim, int, float, Span>;
0219 
0220     typename KDTree::vector_t features;
0221     features.reserve(embedFeatures.size(0));
0222 
0223     auto dataPtr = embedFeatures.data_ptr<float>();
0224 
0225     for (int i = 0; i < embedFeatures.size(0); ++i) {
0226       features.push_back({Span<float, Dim>{dataPtr + i * Dim}, i});
0227     }
0228 
0229     KDTree tree(std::move(features));
0230 
0231     /////////////////
0232     // Search tree //
0233     /////////////////
0234     std::vector<std::int32_t> edges;
0235     edges.reserve(2 * kVal * embedFeatures.size(0));
0236 
0237     for (int iself = 0; iself < embedFeatures.size(0); ++iself) {
0238       const Span<float, Dim> self{dataPtr + iself * Dim};
0239 
0240       RangeXD<Dim, float> range;
0241       for (auto j = 0ul; j < Dim; ++j) {
0242         range[j] = Range1D<float>(self[j] - rVal, self[j] + rVal);
0243       }
0244 
0245       tree.rangeSearchMapDiscard(
0246           range, [&](const Span<float, Dim> &other, const int &iother) {
0247             if (iself != iother && dist(self, other) <= rVal) {
0248               edges.push_back(iself);
0249               edges.push_back(iother);
0250             }
0251           });
0252     }
0253 
0254     // Transpose is necessary here, clone to get ownership
0255     return ActsPlugins::detail::vectorToTensor2D(edges, 2).t().clone();
0256   }
0257 };
0258 
0259 torch::Tensor ActsPlugins::detail::buildEdgesKDTree(
0260     torch::Tensor &embedFeatures, float rVal, int kVal, bool flipDirections) {
0261   auto tensor = template_switch<BuildEdgesKDTree, 1, 12>(
0262       embedFeatures.size(1), embedFeatures, rVal, kVal);
0263 
0264   return postprocessEdgeTensor(tensor, true, true, flipDirections);
0265 }
0266 
0267 torch::Tensor ActsPlugins::detail::buildEdges(torch::Tensor &embedFeatures,
0268                                               float rVal, int kVal,
0269                                               bool flipDirections) {
0270 #ifndef ACTS_GNN_CPUONLY
0271   if (torch::cuda::is_available()) {
0272     return detail::buildEdgesFRNN(embedFeatures, rVal, kVal, flipDirections);
0273   } else {
0274     return detail::buildEdgesKDTree(embedFeatures, rVal, kVal, flipDirections);
0275   }
0276 #else
0277   return detail::buildEdgesKDTree(embedFeatures, rVal, kVal, flipDirections);
0278 #endif
0279 }