File indexing completed on 2025-12-16 09:24:26
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsPlugins/Gnn/detail/buildEdges.hpp"
0010
0011 #include "Acts/Utilities/Helpers.hpp"
0012 #include "Acts/Utilities/KDTree.hpp"
0013 #include "ActsPlugins/Gnn/detail/TensorVectorConversion.hpp"
0014
0015 #include <iostream>
0016 #include <mutex>
0017 #include <vector>
0018
0019 #include <torch/script.h>
0020 #include <torch/torch.h>
0021
0022 #ifndef ACTS_GNN_CPUONLY
0023 #include <cuda.h>
0024 #include <cuda_runtime_api.h>
0025 #include <grid/counting_sort.h>
0026 #include <grid/find_nbrs.h>
0027 #include <grid/grid.h>
0028 #include <grid/insert_points.h>
0029 #include <grid/prefix_sum.h>
0030 #endif
0031
0032 using namespace torch::indexing;
0033
0034 using namespace Acts;
0035
torch::Tensor ActsPlugins::detail::postprocessEdgeTensor(torch::Tensor edges,
                                                         bool removeSelfLoops,
                                                         bool removeDuplicates,
                                                         bool flipDirections) {
  // Post-process a 2 x nEdges tensor of (source, target) index pairs.
  // Each cleanup stage is optional and controlled by its flag; the result
  // is always returned as int64.

  // Drop self-loops: keep only the columns where source != target.
  if (removeSelfLoops) {
    torch::Tensor selfLoopMask = edges.index({0}) != edges.index({1});
    edges = edges.index({Slice(), selfLoopMask});
  }

  // Remove duplicates: canonicalize each column so that source <= target
  // (flip the columns where source > target), then deduplicate columns.
  if (removeDuplicates) {
    torch::Tensor mask = edges.index({0}) > edges.index({1});
    edges.index_put_({Slice(), mask}, edges.index({Slice(), mask}).flip(0));
    edges = std::get<0>(torch::unique_dim(edges, -1, false));
  }

  // Randomly flip the direction of roughly half of the edges: split the
  // columns by a random 0/1 mask, flip one half, and concatenate again.
  // Note: this reorders the edges (kept edges first, flipped edges last).
  if (flipDirections) {
    torch::Tensor random_cut_keep = torch::randint(2, {edges.size(1)});
    torch::Tensor random_cut_flip = 1 - random_cut_keep;
    torch::Tensor keep_edges =
        edges.index({Slice(), random_cut_keep.to(torch::kBool)});
    torch::Tensor flip_edges =
        edges.index({Slice(), random_cut_flip.to(torch::kBool)}).flip({0});
    edges = torch::cat({keep_edges, flip_edges}, 1);
  }

  return edges.toType(torch::kInt64);
}
0066
torch::Tensor ActsPlugins::detail::buildEdgesFRNN(torch::Tensor &embedFeatures,
                                                  float rVal, int kVal,
                                                  bool flipDirections) {
#ifndef ACTS_GNN_CPUONLY
  // GPU edge building via fixed-radius nearest-neighbour (FRNN) search:
  // insert the points into a uniform 3D grid, sort them by cell, then query
  // up to kVal neighbours within radius rVal for every point.
  const auto device = embedFeatures.device();

  const std::int64_t numSpacepoints = embedFeatures.size(0);
  const int dim = embedFeatures.size(1);

  // Layout of one row of gridParamsCuda (one row per batch element):
  // [0..2] grid minimum per axis, [3] inverse cell size ("delta"),
  // [4..6] number of cells per axis, [7] total cell count.
  const int grid_params_size = 8;
  const int grid_delta_idx = 3;
  const int grid_total_idx = 7;
  const int grid_max_res = 128;
  const int grid_dim = 3;

  // The grid is always built over the first 3 embedding dimensions, so the
  // embedding must be at least 3-dimensional.
  if (dim < 3) {
    throw std::runtime_error("DIM < 3 is not supported for now.\n");
  }

  // Cell size target is radius / radius_cell_ratio, capped below so that no
  // axis exceeds grid_max_res cells. Everything runs as a single batch.
  const float radius_cell_ratio = 2.0;
  const int batch_size = 1;
  int G = -1;  // maximum total cell count over all batch elements

  torch::Tensor grid_min;
  torch::Tensor grid_max;
  torch::Tensor grid_size;

  torch::Tensor embedTensor = embedFeatures.reshape({1, numSpacepoints, dim});
  torch::Tensor gridParamsCuda =
      torch::zeros({batch_size, grid_params_size}, device).to(torch::kFloat32);
  torch::Tensor r_tensor = torch::full({batch_size}, rVal, device);
  torch::Tensor lengths = torch::full({batch_size}, numSpacepoints, device);

  // Compute the grid parameters (bounding box, cell size, per-axis
  // resolution, total cell count) for each batch element.
  for (int i = 0; i < batch_size; i++) {
    torch::Tensor allPoints =
        embedTensor.index({i, Slice(None, lengths.index({i}).item().to<long>()),
                           Slice(None, grid_dim)});
    grid_min = std::get<0>(allPoints.min(0));
    grid_max = std::get<0>(allPoints.max(0));
    gridParamsCuda.index_put_({i, Slice(None, grid_delta_idx)}, grid_min);

    grid_size = grid_max - grid_min;

    float cell_size =
        r_tensor.index({i}).item().to<float>() / radius_cell_ratio;

    // Grow the cells if needed so no axis has more than grid_max_res cells.
    if (cell_size < (grid_size.min().item().to<float>() / grid_max_res)) {
      cell_size = grid_size.min().item().to<float>() / grid_max_res;
    }

    // Store the inverse cell size ("delta").
    gridParamsCuda.index_put_({i, grid_delta_idx}, 1 / cell_size);

    // Cells per axis ...
    gridParamsCuda.index_put_({i, Slice(1 + grid_delta_idx, grid_total_idx)},
                              floor(grid_size / cell_size) + 1);

    // ... and total cell count (product over the three axes).
    gridParamsCuda.index_put_(
        {i, grid_total_idx},
        gridParamsCuda.index({i, Slice(1 + grid_delta_idx, grid_total_idx)})
            .prod());

    if (G < gridParamsCuda.index({i, grid_total_idx}).item().to<int>()) {
      G = gridParamsCuda.index({i, grid_total_idx}).item().to<int>();
    }
  }

  // Per-cell point counts, per-point cell assignment and per-point position
  // within its cell; filled by InsertPointsCUDA (-1 marks unfilled slots).
  torch::Tensor pc_grid_cnt =
      torch::zeros({batch_size, G}, device).to(torch::kInt32);
  torch::Tensor pc_grid_cell =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);
  torch::Tensor pc_grid_idx =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);

  InsertPointsCUDA(embedTensor, lengths.to(torch::kInt64), gridParamsCuda,
                   pc_grid_cnt, pc_grid_cell, pc_grid_idx, G);

  // Prefix sum over the cell counts — presumably each cell's start offset
  // into the cell-sorted point array (FRNN grid convention; confirm against
  // the FRNN library if touched).
  torch::Tensor pc_grid_off =
      torch::full({batch_size, G}, 0, device).to(torch::kInt32);
  torch::Tensor grid_params = gridParamsCuda.to(torch::kCPU);

  pc_grid_off = PrefixSumCUDA(pc_grid_cnt, grid_params);

  torch::Tensor sorted_points =
      torch::zeros({batch_size, numSpacepoints, dim}, device)
          .to(torch::kFloat32);
  torch::Tensor sorted_points_idxs =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);

  // Counting-sort the points by grid cell so that cell members are
  // contiguous; sorted_points_idxs maps sorted position -> original index.
  CountingSortCUDA(embedTensor, lengths.to(torch::kInt64), pc_grid_cell,
                   pc_grid_idx, pc_grid_off, sorted_points, sorted_points_idxs);

  // Query up to kVal neighbours within rVal for every point (points are
  // both queries and candidates; the squared radius is precomputed).
  auto [indices, distances] = FindNbrsCUDA(
      sorted_points, sorted_points, lengths.to(torch::kInt64),
      lengths.to(torch::kInt64), pc_grid_off.to(torch::kInt32),
      sorted_points_idxs, sorted_points_idxs,
      gridParamsCuda.to(torch::kFloat32), kVal, r_tensor, r_tensor * r_tensor);
  // Negative entries mark unused neighbour slots (fewer than kVal found).
  torch::Tensor positiveIndices = indices >= 0;

  // Source index for every neighbour slot: 0..N-1, one copy per slot.
  torch::Tensor repeatRange = torch::arange(positiveIndices.size(1), device)
                                  .repeat({1, positiveIndices.size(2), 1})
                                  .transpose(1, 2);

  // Gather the valid (source, neighbour) pairs into a 2 x nEdges tensor.
  torch::Tensor stackedEdges = torch::stack(
      {repeatRange.index({positiveIndices}), indices.index({positiveIndices})});

  // Clean up: drop self-loops and duplicates, optionally randomize
  // edge directions.
  return postprocessEdgeTensor(std::move(stackedEdges), true, true,
                               flipDirections);
#else
  throw std::runtime_error(
      "ACTS not compiled with CUDA, cannot run ActsPlugins::buildEdgesFRNN");
#endif
}
0182
0183
0184
0185
/// Minimal non-owning view over exactly S consecutive elements of type T.
/// Provides just enough of a container interface (size, indexing, const
/// iteration) to serve as the point type of the KD-tree below.
template <typename T, std::size_t S>
struct Span {
  T *ptr;  // first element of the viewed sequence (not owned)

  using const_iterator = T const *;

  /// Compile-time element count.
  auto size() const { return S; }

  /// Unchecked element access.
  auto &operator[](std::size_t i) const { return *(ptr + i); }

  const_iterator cbegin() const { return ptr; }
  const_iterator cend() const { return cbegin() + S; }
};
0198
0199 template <std::size_t Dim>
0200 float dist(const Span<float, Dim> &a, const Span<float, Dim> &b) {
0201 float s = 0.f;
0202 for (auto i = 0ul; i < Dim; ++i) {
0203 s += (a[i] - b[i]) * (a[i] - b[i]);
0204 }
0205 return std::sqrt(s);
0206 };
0207
template <std::size_t Dim>
struct BuildEdgesKDTree {
  /// CPU fallback edge building via an Acts KD-tree: collect every pair
  /// (iself, iother) whose Euclidean distance in the Dim-dimensional
  /// embedding is <= rVal.
  /// Note: kVal only sizes the edge-vector reservation here — unlike the
  /// FRNN path, no maximum-neighbour cut is applied.
  /// Side effect: moves embedFeatures to the CPU (the tensor reference is
  /// reassigned in place).
  static torch::Tensor invoke(torch::Tensor &embedFeatures, float rVal,
                              int kVal) {
    assert(embedFeatures.size(1) == Dim);
    embedFeatures = embedFeatures.to(torch::kCPU);

    // Build the tree from non-owning Span views directly into the tensor's
    // contiguous float storage; the payload is the point index.
    using KDTree = KDTree<Dim, int, float, Span>;

    typename KDTree::vector_t features;
    features.reserve(embedFeatures.size(0));

    auto dataPtr = embedFeatures.data_ptr<float>();

    for (int i = 0; i < embedFeatures.size(0); ++i) {
      features.push_back({Span<float, Dim>{dataPtr + i * Dim}, i});
    }

    KDTree tree(std::move(features));

    // For each point, query the axis-aligned box of half-width rVal and
    // keep only neighbours inside the Euclidean ball of radius rVal.
    // Edges are appended as flat (src, dst) index pairs.
    std::vector<std::int32_t> edges;
    edges.reserve(2 * kVal * embedFeatures.size(0));

    for (int iself = 0; iself < embedFeatures.size(0); ++iself) {
      const Span<float, Dim> self{dataPtr + iself * Dim};

      RangeXD<Dim, float> range;
      for (auto j = 0ul; j < Dim; ++j) {
        range[j] = Range1D<float>(self[j] - rVal, self[j] + rVal);
      }

      tree.rangeSearchMapDiscard(
          range, [&](const Span<float, Dim> &other, const int &iother) {
            // Exclude self-loops; refine the box query to the sphere.
            if (iself != iother && dist(self, other) <= rVal) {
              edges.push_back(iself);
              edges.push_back(iother);
            }
          });
    }

    // Reinterpret the flat pair list as nEdges x 2, then transpose to the
    // canonical 2 x nEdges layout (clone() makes the result contiguous).
    return ActsPlugins::detail::vectorToTensor2D(edges, 2).t().clone();
  }
};
0258
0259 torch::Tensor ActsPlugins::detail::buildEdgesKDTree(
0260 torch::Tensor &embedFeatures, float rVal, int kVal, bool flipDirections) {
0261 auto tensor = template_switch<BuildEdgesKDTree, 1, 12>(
0262 embedFeatures.size(1), embedFeatures, rVal, kVal);
0263
0264 return postprocessEdgeTensor(tensor, true, true, flipDirections);
0265 }
0266
0267 torch::Tensor ActsPlugins::detail::buildEdges(torch::Tensor &embedFeatures,
0268 float rVal, int kVal,
0269 bool flipDirections) {
0270 #ifndef ACTS_GNN_CPUONLY
0271 if (torch::cuda::is_available()) {
0272 return detail::buildEdgesFRNN(embedFeatures, rVal, kVal, flipDirections);
0273 } else {
0274 return detail::buildEdgesKDTree(embedFeatures, rVal, kVal, flipDirections);
0275 }
0276 #else
0277 return detail::buildEdgesKDTree(embedFeatures, rVal, kVal, flipDirections);
0278 #endif
0279 }