Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 09:15:12

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/detail/buildEdges.hpp"
0010 
0011 #include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
0012 #include "Acts/Utilities/Helpers.hpp"
0013 #include "Acts/Utilities/KDTree.hpp"
0014 
#include <cmath>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <vector>
0018 
0019 #include <torch/script.h>
0020 #include <torch/torch.h>
0021 
0022 #ifndef ACTS_EXATRKX_CPUONLY
0023 #include <cuda.h>
0024 #include <cuda_runtime_api.h>
0025 #include <grid/counting_sort.h>
0026 #include <grid/find_nbrs.h>
0027 #include <grid/grid.h>
0028 #include <grid/insert_points.h>
0029 #include <grid/prefix_sum.h>
0030 #endif
0031 
0032 using namespace torch::indexing;
0033 
0034 torch::Tensor Acts::detail::postprocessEdgeTensor(torch::Tensor edges,
0035                                                   bool removeSelfLoops,
0036                                                   bool removeDuplicates,
0037                                                   bool flipDirections) {
0038   // Remove self-loops
0039   if (removeSelfLoops) {
0040     torch::Tensor selfLoopMask = edges.index({0}) != edges.index({1});
0041     edges = edges.index({Slice(), selfLoopMask});
0042   }
0043 
0044   // Remove duplicates
0045   if (removeDuplicates) {
0046     torch::Tensor mask = edges.index({0}) > edges.index({1});
0047     edges.index_put_({Slice(), mask}, edges.index({Slice(), mask}).flip(0));
0048     edges = std::get<0>(torch::unique_dim(edges, -1, false));
0049   }
0050 
0051   // Randomly flip direction
0052   if (flipDirections) {
0053     torch::Tensor random_cut_keep = torch::randint(2, {edges.size(1)});
0054     torch::Tensor random_cut_flip = 1 - random_cut_keep;
0055     torch::Tensor keep_edges =
0056         edges.index({Slice(), random_cut_keep.to(torch::kBool)});
0057     torch::Tensor flip_edges =
0058         edges.index({Slice(), random_cut_flip.to(torch::kBool)}).flip({0});
0059     edges = torch::cat({keep_edges, flip_edges}, 1);
0060   }
0061 
0062   return edges.toType(torch::kInt64);
0063 }
0064 
/// Build edge candidates with a fixed-radius nearest-neighbor (FRNN) search
/// on the GPU. The embedded spacepoints are binned into a uniform 3D grid
/// (built from the first three embedding dimensions) and neighbors within
/// radius rVal are looked up via the FRNN CUDA kernels.
///
/// @param embedFeatures (numSpacepoints x dim) embedding tensor; assumed to
///        live on a CUDA device — TODO confirm, the device is taken from the
///        tensor itself
/// @param rVal search radius in embedding space
/// @param kVal maximum number of neighbors returned per point by FindNbrsCUDA
/// @param flipDirections forwarded to postprocessEdgeTensor
/// @return postprocessed 2xN edge tensor (self-loops and duplicates removed)
/// @throws std::runtime_error if dim < 3 or ACTS was built CPU-only
torch::Tensor Acts::detail::buildEdgesFRNN(torch::Tensor &embedFeatures,
                                           float rVal, int kVal,
                                           bool flipDirections) {
#ifndef ACTS_EXATRKX_CPUONLY
  const auto device = embedFeatures.device();

  const std::int64_t numSpacepoints = embedFeatures.size(0);
  const int dim = embedFeatures.size(1);

  // Layout of the per-batch grid parameter vector consumed by the FRNN
  // kernels: [0..2] grid minimum, [3] inverse cell size (delta),
  // [4..6] grid resolution per axis, [7] total number of cells.
  const int grid_params_size = 8;
  const int grid_delta_idx = 3;
  const int grid_total_idx = 7;
  const int grid_max_res = 128;
  // Only the first three embedding dimensions are used to build the grid.
  const int grid_dim = 3;

  // The grid is strictly 3D, so fewer than 3 embedding dimensions cannot work.
  if (dim < 3) {
    throw std::runtime_error("DIM < 3 is not supported for now.\n");
  }

  // Cell edge length is radius / radius_cell_ratio (i.e. two cells per radius).
  const float radius_cell_ratio = 2.0;
  // Everything is processed as a single batch.
  const int batch_size = 1;
  // Maximum number of grid cells over all batches (computed below).
  int G = -1;

  // Set up grid properties
  torch::Tensor grid_min;
  torch::Tensor grid_max;
  torch::Tensor grid_size;

  // FRNN kernels expect a batched (batch, points, dim) layout.
  torch::Tensor embedTensor = embedFeatures.reshape({1, numSpacepoints, dim});
  torch::Tensor gridParamsCuda =
      torch::zeros({batch_size, grid_params_size}, device).to(torch::kFloat32);
  // Per-batch search radius and per-batch point count.
  torch::Tensor r_tensor = torch::full({batch_size}, rVal, device);
  torch::Tensor lengths = torch::full({batch_size}, numSpacepoints, device);

  // build the grid
  for (int i = 0; i < batch_size; i++) {
    // First grid_dim coordinates of all valid points in this batch.
    torch::Tensor allPoints =
        embedTensor.index({i, Slice(None, lengths.index({i}).item().to<long>()),
                           Slice(None, grid_dim)});
    // Axis-aligned bounding box of the point cloud.
    grid_min = std::get<0>(allPoints.min(0));
    grid_max = std::get<0>(allPoints.max(0));
    gridParamsCuda.index_put_({i, Slice(None, grid_delta_idx)}, grid_min);

    grid_size = grid_max - grid_min;

    // Nominal cell size derived from the search radius ...
    float cell_size =
        r_tensor.index({i}).item().to<float>() / radius_cell_ratio;

    // ... but never allow more than grid_max_res cells along the
    // shortest bounding-box axis.
    if (cell_size < (grid_size.min().item().to<float>() / grid_max_res)) {
      cell_size = grid_size.min().item().to<float>() / grid_max_res;
    }

    // Store the inverse cell size ("delta") used by the kernels.
    gridParamsCuda.index_put_({i, grid_delta_idx}, 1 / cell_size);

    // Number of cells per axis (at least 1 along each axis).
    gridParamsCuda.index_put_({i, Slice(1 + grid_delta_idx, grid_total_idx)},
                              floor(grid_size / cell_size) + 1);

    // Total cell count = product of the per-axis resolutions.
    gridParamsCuda.index_put_(
        {i, grid_total_idx},
        gridParamsCuda.index({i, Slice(1 + grid_delta_idx, grid_total_idx)})
            .prod());

    // Track the maximum cell count over all batches.
    if (G < gridParamsCuda.index({i, grid_total_idx}).item().to<int>()) {
      G = gridParamsCuda.index({i, grid_total_idx}).item().to<int>();
    }
  }

  // Per-cell point counters, per-point cell assignment and per-point
  // index-within-cell, filled by InsertPointsCUDA.
  torch::Tensor pc_grid_cnt =
      torch::zeros({batch_size, G}, device).to(torch::kInt32);
  torch::Tensor pc_grid_cell =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);
  torch::Tensor pc_grid_idx =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);

  // put spacepoints into the grid
  InsertPointsCUDA(embedTensor, lengths.to(torch::kInt64), gridParamsCuda,
                   pc_grid_cnt, pc_grid_cell, pc_grid_idx, G);

  // Exclusive prefix sum over the cell counters gives each cell's offset
  // into the sorted point array.
  torch::Tensor pc_grid_off =
      torch::full({batch_size, G}, 0, device).to(torch::kInt32);
  torch::Tensor grid_params = gridParamsCuda.to(torch::kCPU);

  // for loop seems not to be necessary anymore
  pc_grid_off = PrefixSumCUDA(pc_grid_cnt, grid_params);

  // Points reordered by grid cell, plus the mapping back to original indices.
  torch::Tensor sorted_points =
      torch::zeros({batch_size, numSpacepoints, dim}, device)
          .to(torch::kFloat32);
  torch::Tensor sorted_points_idxs =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);

  CountingSortCUDA(embedTensor, lengths.to(torch::kInt64), pc_grid_cell,
                   pc_grid_idx, pc_grid_off, sorted_points, sorted_points_idxs);

  // Search the grid: for every point, up to kVal neighbors within r.
  // The last argument is r^2, presumably to avoid square roots in the
  // kernel — TODO confirm against the FRNN kernel signature.
  auto [indices, distances] = FindNbrsCUDA(
      sorted_points, sorted_points, lengths.to(torch::kInt64),
      lengths.to(torch::kInt64), pc_grid_off.to(torch::kInt32),
      sorted_points_idxs, sorted_points_idxs,
      gridParamsCuda.to(torch::kFloat32), kVal, r_tensor, r_tensor * r_tensor);
  // Unfilled neighbor slots are negative; keep only real matches.
  torch::Tensor positiveIndices = indices >= 0;

  // Source-node index for every neighbor slot, shaped like `indices`.
  torch::Tensor repeatRange = torch::arange(positiveIndices.size(1), device)
                                  .repeat({1, positiveIndices.size(2), 1})
                                  .transpose(1, 2);

  // Row 0: source node, row 1: found neighbor.
  torch::Tensor stackedEdges = torch::stack(
      {repeatRange.index({positiveIndices}), indices.index({positiveIndices})});

  return postprocessEdgeTensor(std::move(stackedEdges), true, true,
                               flipDirections);
#else
  throw std::runtime_error(
      "ACTS not compiled with CUDA, cannot run Acts::buildEdgesFRNN");
#endif
}
0180 
/// This is a very unsophisticated span implementation to avoid data copies in
/// the KDTree search.
/// Should be replaced with std::span when possible
///
/// @tparam T element type
/// @tparam S compile-time number of elements viewed through the span
template <typename T, std::size_t S>
struct Span {
  T *ptr;  ///< non-owning pointer to the first of S contiguous elements

  /// Extent is a compile-time constant, so expose it as such.
  static constexpr std::size_t size() noexcept { return S; }

  using const_iterator = T const *;
  const_iterator cbegin() const noexcept { return ptr; }
  const_iterator cend() const noexcept { return ptr + S; }

  // begin()/end() in addition to cbegin()/cend(), so the span works with
  // range-for and the standard algorithms.
  const_iterator begin() const noexcept { return ptr; }
  const_iterator end() const noexcept { return ptr + S; }

  auto &operator[](std::size_t i) const { return ptr[i]; }
};

/// Euclidean distance between two fixed-size spans of coordinates.
///
/// @param a first point
/// @param b second point
/// @return sqrt of the summed squared component differences
template <std::size_t Dim>
float dist(const Span<float, Dim> &a, const Span<float, Dim> &b) {
  float s = 0.f;
  for (auto i = 0ul; i < Dim; ++i) {
    s += (a[i] - b[i]) * (a[i] - b[i]);
  }
  return std::sqrt(s);
}
0205 
/// CPU edge building via a KD-tree range search, for a fixed embedding
/// dimension Dim. Instantiated for Dim = 1..12 through Acts::template_switch
/// in buildEdgesKDTree below.
template <std::size_t Dim>
struct BuildEdgesKDTree {
  /// Build all (iself, iother) pairs whose embedded positions are within
  /// rVal of each other (Euclidean distance).
  ///
  /// @param embedFeatures (N x Dim) float tensor; moved to CPU in place
  /// @param rVal search radius in embedding space
  /// @param kVal only used to pre-size the output buffer here — the range
  ///        search itself is not limited to k neighbors (unlike the FRNN
  ///        GPU path)
  /// @return (M x 2) int32 tensor of edges
  static torch::Tensor invoke(torch::Tensor &embedFeatures, float rVal,
                              int kVal) {
    assert(embedFeatures.size(1) == Dim);
    // The KD-tree reads raw host memory, so the tensor must live on the CPU.
    embedFeatures = embedFeatures.to(torch::kCPU);

    ////////////////
    // Build tree //
    ////////////////
    using KDTree = Acts::KDTree<Dim, int, float, Span>;

    typename KDTree::vector_t features;
    features.reserve(embedFeatures.size(0));

    auto dataPtr = embedFeatures.data_ptr<float>();

    // Each entry is a non-owning view into row i of the tensor, paired with
    // the row index as payload. The tensor must outlive the tree.
    for (int i = 0; i < embedFeatures.size(0); ++i) {
      features.push_back({Span<float, Dim>{dataPtr + i * Dim}, i});
    }

    KDTree tree(std::move(features));

    /////////////////
    // Search tree //
    /////////////////
    // Flat (src, dst, src, dst, ...) list; reshaped to 2 columns below.
    std::vector<std::int32_t> edges;
    edges.reserve(2 * kVal * embedFeatures.size(0));

    for (int iself = 0; iself < embedFeatures.size(0); ++iself) {
      const Span<float, Dim> self{dataPtr + iself * Dim};

      // Axis-aligned box of half-width rVal around the query point; the
      // exact Euclidean radius is enforced by the dist() check below.
      Acts::RangeXD<Dim, float> range;
      for (auto j = 0ul; j < Dim; ++j) {
        range[j] = Acts::Range1D<float>(self[j] - rVal, self[j] + rVal);
      }

      tree.rangeSearchMapDiscard(
          range, [&](const Span<float, Dim> &other, const int &iother) {
            // Skip self-pairing and box hits outside the Euclidean radius.
            if (iself != iother && dist(self, other) <= rVal) {
              edges.push_back(iself);
              edges.push_back(iother);
            }
          });
    }

    // Transpose is necessary here, clone to get ownership
    return Acts::detail::vectorToTensor2D(edges, 2).t().clone();
  }
};
0256 
0257 torch::Tensor Acts::detail::buildEdgesKDTree(torch::Tensor &embedFeatures,
0258                                              float rVal, int kVal,
0259                                              bool flipDirections) {
0260   auto tensor = Acts::template_switch<BuildEdgesKDTree, 1, 12>(
0261       embedFeatures.size(1), embedFeatures, rVal, kVal);
0262 
0263   return postprocessEdgeTensor(tensor, true, true, flipDirections);
0264 }
0265 
0266 torch::Tensor Acts::detail::buildEdges(torch::Tensor &embedFeatures, float rVal,
0267                                        int kVal, bool flipDirections) {
0268 #ifndef ACTS_EXATRKX_CPUONLY
0269   if (torch::cuda::is_available()) {
0270     return detail::buildEdgesFRNN(embedFeatures, rVal, kVal, flipDirections);
0271   } else {
0272     return detail::buildEdgesKDTree(embedFeatures, rVal, kVal, flipDirections);
0273   }
0274 #else
0275   return detail::buildEdgesKDTree(embedFeatures, rVal, kVal, flipDirections);
0276 #endif
0277 }