File indexing completed on 2025-01-30 09:15:12
0001
0002
0003
0004
0005
0006
0007
0008
#include "Acts/Plugins/ExaTrkX/detail/buildEdges.hpp"

#include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
#include "Acts/Utilities/Helpers.hpp"
#include "Acts/Utilities/KDTree.hpp"

#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>

#ifndef ACTS_EXATRKX_CPUONLY
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <grid/counting_sort.h>
#include <grid/find_nbrs.h>
#include <grid/grid.h>
#include <grid/insert_points.h>
#include <grid/prefix_sum.h>
#endif
0031
0032 using namespace torch::indexing;
0033
0034 torch::Tensor Acts::detail::postprocessEdgeTensor(torch::Tensor edges,
0035 bool removeSelfLoops,
0036 bool removeDuplicates,
0037 bool flipDirections) {
0038
0039 if (removeSelfLoops) {
0040 torch::Tensor selfLoopMask = edges.index({0}) != edges.index({1});
0041 edges = edges.index({Slice(), selfLoopMask});
0042 }
0043
0044
0045 if (removeDuplicates) {
0046 torch::Tensor mask = edges.index({0}) > edges.index({1});
0047 edges.index_put_({Slice(), mask}, edges.index({Slice(), mask}).flip(0));
0048 edges = std::get<0>(torch::unique_dim(edges, -1, false));
0049 }
0050
0051
0052 if (flipDirections) {
0053 torch::Tensor random_cut_keep = torch::randint(2, {edges.size(1)});
0054 torch::Tensor random_cut_flip = 1 - random_cut_keep;
0055 torch::Tensor keep_edges =
0056 edges.index({Slice(), random_cut_keep.to(torch::kBool)});
0057 torch::Tensor flip_edges =
0058 edges.index({Slice(), random_cut_flip.to(torch::kBool)}).flip({0});
0059 edges = torch::cat({keep_edges, flip_edges}, 1);
0060 }
0061
0062 return edges.toType(torch::kInt64);
0063 }
0064
/// Build an edge list on the GPU with the FRNN (fixed-radius nearest
/// neighbour) CUDA kernels: insert the embedded points into a uniform grid,
/// prefix-sum and counting-sort them per cell, then query neighbours within
/// radius rVal (at most kVal per point). The raw candidate pairs are finally
/// cleaned up via postprocessEdgeTensor (self-loops and duplicates removed).
/// Throws std::runtime_error when the embedding dimension is < 3 or when the
/// build is CPU-only.
torch::Tensor Acts::detail::buildEdgesFRNN(torch::Tensor &embedFeatures,
                                           float rVal, int kVal,
                                           bool flipDirections) {
#ifndef ACTS_EXATRKX_CPUONLY
  const auto device = embedFeatures.device();

  const std::int64_t numSpacepoints = embedFeatures.size(0);
  const int dim = embedFeatures.size(1);

  // Layout of one row of gridParamsCuda (written below):
  //   [0..2]  grid minimum corner, [3] 1/cell_size,
  //   [4..6]  grid resolution per axis, [7] total number of cells.
  const int grid_params_size = 8;
  const int grid_delta_idx = 3;
  const int grid_total_idx = 7;
  const int grid_max_res = 128;
  const int grid_dim = 3;

  // The grid is built from the first three embedding coordinates only.
  if (dim < 3) {
    throw std::runtime_error("DIM < 3 is not supported for now.\n");
  }

  const float radius_cell_ratio = 2.0;
  const int batch_size = 1;
  // G: largest total cell count over all batches; sizes the grid buffers.
  int G = -1;

  torch::Tensor grid_min;
  torch::Tensor grid_max;
  torch::Tensor grid_size;

  // FRNN expects a batched point cloud; we use a single batch.
  torch::Tensor embedTensor = embedFeatures.reshape({1, numSpacepoints, dim});
  torch::Tensor gridParamsCuda =
      torch::zeros({batch_size, grid_params_size}, device).to(torch::kFloat32);
  torch::Tensor r_tensor = torch::full({batch_size}, rVal, device);
  torch::Tensor lengths = torch::full({batch_size}, numSpacepoints, device);

  // Compute the grid parameters (bounding box, cell size, resolution) for
  // each batch entry.
  for (int i = 0; i < batch_size; i++) {
    torch::Tensor allPoints =
        embedTensor.index({i, Slice(None, lengths.index({i}).item().to<long>()),
                           Slice(None, grid_dim)});
    grid_min = std::get<0>(allPoints.min(0));
    grid_max = std::get<0>(allPoints.max(0));
    gridParamsCuda.index_put_({i, Slice(None, grid_delta_idx)}, grid_min);

    grid_size = grid_max - grid_min;

    // Default cell size: half the search radius.
    float cell_size =
        r_tensor.index({i}).item().to<float>() / radius_cell_ratio;

    // Clamp the cell size so no axis exceeds grid_max_res cells.
    if (cell_size < (grid_size.min().item().to<float>() / grid_max_res)) {
      cell_size = grid_size.min().item().to<float>() / grid_max_res;
    }

    // Store the reciprocal cell size ("delta").
    gridParamsCuda.index_put_({i, grid_delta_idx}, 1 / cell_size);

    // Cells per axis.
    gridParamsCuda.index_put_({i, Slice(1 + grid_delta_idx, grid_total_idx)},
                              floor(grid_size / cell_size) + 1);

    // Total number of cells = product of the per-axis resolutions.
    gridParamsCuda.index_put_(
        {i, grid_total_idx},
        gridParamsCuda.index({i, Slice(1 + grid_delta_idx, grid_total_idx)})
            .prod());

    if (G < gridParamsCuda.index({i, grid_total_idx}).item().to<int>()) {
      G = gridParamsCuda.index({i, grid_total_idx}).item().to<int>();
    }
  }

  // Per-cell point counts and, per point, its cell id and rank within cell.
  torch::Tensor pc_grid_cnt =
      torch::zeros({batch_size, G}, device).to(torch::kInt32);
  torch::Tensor pc_grid_cell =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);
  torch::Tensor pc_grid_idx =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);

  InsertPointsCUDA(embedTensor, lengths.to(torch::kInt64), gridParamsCuda,
                   pc_grid_cnt, pc_grid_cell, pc_grid_idx, G);

  torch::Tensor pc_grid_off =
      torch::full({batch_size, G}, 0, device).to(torch::kInt32);
  // NOTE(review): PrefixSumCUDA appears to take the grid params on the CPU —
  // confirm against the FRNN API.
  torch::Tensor grid_params = gridParamsCuda.to(torch::kCPU);

  // Exclusive prefix sum of the cell counts -> per-cell start offsets.
  pc_grid_off = PrefixSumCUDA(pc_grid_cnt, grid_params);

  // Points reordered by cell, plus the permutation back to original indices.
  torch::Tensor sorted_points =
      torch::zeros({batch_size, numSpacepoints, dim}, device)
          .to(torch::kFloat32);
  torch::Tensor sorted_points_idxs =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);

  CountingSortCUDA(embedTensor, lengths.to(torch::kInt64), pc_grid_cell,
                   pc_grid_idx, pc_grid_off, sorted_points, sorted_points_idxs);

  // Query each point against the same cloud: up to kVal neighbours within
  // radius r (the kernel receives both r and r^2).
  auto [indices, distances] = FindNbrsCUDA(
      sorted_points, sorted_points, lengths.to(torch::kInt64),
      lengths.to(torch::kInt64), pc_grid_off.to(torch::kInt32),
      sorted_points_idxs, sorted_points_idxs,
      gridParamsCuda.to(torch::kFloat32), kVal, r_tensor, r_tensor * r_tensor);
  // Negative entries mark "no neighbour found" slots.
  torch::Tensor positiveIndices = indices >= 0;

  // Source index for every (point, neighbour-slot) pair.
  torch::Tensor repeatRange = torch::arange(positiveIndices.size(1), device)
                                  .repeat({1, positiveIndices.size(2), 1})
                                  .transpose(1, 2);

  // Stack (source, target) pairs for all valid neighbour slots.
  torch::Tensor stackedEdges = torch::stack(
      {repeatRange.index({positiveIndices}), indices.index({positiveIndices})});

  // Always remove self-loops and duplicates; flipping is caller-controlled.
  return postprocessEdgeTensor(std::move(stackedEdges), true, true,
                               flipDirections);
#else
  throw std::runtime_error(
      "ACTS not compiled with CUDA, cannot run Acts::buildEdgesFRNN");
#endif
}
0180
0181
0182
0183
/// Minimal non-owning view over S consecutive elements of type T.
/// Serves as the fixed-size coordinate container handed to Acts::KDTree.
/// begin()/end() are provided in addition to cbegin()/cend() so the view
/// works with range-based for loops and standard algorithms.
template <typename T, std::size_t S>
struct Span {
  T *ptr;  ///< Non-owning pointer to the first of S elements.

  /// Number of elements in the view (compile-time constant).
  constexpr std::size_t size() const { return S; }

  using const_iterator = T const *;
  const_iterator cbegin() const { return ptr; }
  const_iterator cend() const { return ptr + S; }
  const_iterator begin() const { return ptr; }
  const_iterator end() const { return ptr + S; }

  auto &operator[](std::size_t i) const { return ptr[i]; }
};
0196
0197 template <std::size_t Dim>
0198 float dist(const Span<float, Dim> &a, const Span<float, Dim> &b) {
0199 float s = 0.f;
0200 for (auto i = 0ul; i < Dim; ++i) {
0201 s += (a[i] - b[i]) * (a[i] - b[i]);
0202 }
0203 return std::sqrt(s);
0204 };
0205
/// CPU edge building for a compile-time embedding dimension Dim:
/// builds a KD-tree over the embedded space-points and connects every pair
/// closer than rVal. Instantiated per dimension via Acts::template_switch.
template <std::size_t Dim>
struct BuildEdgesKDTree {
  /// Returns a (num_edges x 2) int32 tensor of directed candidate edges.
  /// Moves embedFeatures to the CPU as a side effect.
  static torch::Tensor invoke(torch::Tensor &embedFeatures, float rVal,
                              int kVal) {
    assert(embedFeatures.size(1) == Dim);
    // The KD-tree reads the raw float buffer, so the tensor must be on CPU.
    embedFeatures = embedFeatures.to(torch::kCPU);

    // Spans alias the tensor's buffer directly; no coordinate copies.
    using KDTree = Acts::KDTree<Dim, int, float, Span>;

    typename KDTree::vector_t features;
    features.reserve(embedFeatures.size(0));

    auto dataPtr = embedFeatures.data_ptr<float>();

    // One (coordinates, row-index) entry per space-point.
    for (int i = 0; i < embedFeatures.size(0); ++i) {
      features.push_back({Span<float, Dim>{dataPtr + i * Dim}, i});
    }

    KDTree tree(std::move(features));

    // Flat (source, target, source, target, ...) index list.
    // NOTE(review): kVal only sizes this reserve; unlike the FRNN path, the
    // range search does not cap the neighbour count at kVal.
    std::vector<std::int32_t> edges;
    edges.reserve(2 * kVal * embedFeatures.size(0));

    for (int iself = 0; iself < embedFeatures.size(0); ++iself) {
      const Span<float, Dim> self{dataPtr + iself * Dim};

      // Axis-aligned box of half-width rVal around the query point.
      Acts::RangeXD<Dim, float> range;
      for (auto j = 0ul; j < Dim; ++j) {
        range[j] = Acts::Range1D<float>(self[j] - rVal, self[j] + rVal);
      }

      // Keep box candidates that are inside the sphere and not the point
      // itself.
      tree.rangeSearchMapDiscard(
          range, [&](const Span<float, Dim> &other, const int &iother) {
            if (iself != iother && dist(self, other) <= rVal) {
              edges.push_back(iself);
              edges.push_back(iother);
            }
          });
    }

    // Reshape the flat list to (N, 2), transpose to (2, N), and clone so the
    // result owns its memory independently of the temporary vector.
    return Acts::detail::vectorToTensor2D(edges, 2).t().clone();
  }
};
0256
0257 torch::Tensor Acts::detail::buildEdgesKDTree(torch::Tensor &embedFeatures,
0258 float rVal, int kVal,
0259 bool flipDirections) {
0260 auto tensor = Acts::template_switch<BuildEdgesKDTree, 1, 12>(
0261 embedFeatures.size(1), embedFeatures, rVal, kVal);
0262
0263 return postprocessEdgeTensor(tensor, true, true, flipDirections);
0264 }
0265
0266 torch::Tensor Acts::detail::buildEdges(torch::Tensor &embedFeatures, float rVal,
0267 int kVal, bool flipDirections) {
0268 #ifndef ACTS_EXATRKX_CPUONLY
0269 if (torch::cuda::is_available()) {
0270 return detail::buildEdgesFRNN(embedFeatures, rVal, kVal, flipDirections);
0271 } else {
0272 return detail::buildEdgesKDTree(embedFeatures, rVal, kVal, flipDirections);
0273 }
0274 #else
0275 return detail::buildEdgesKDTree(embedFeatures, rVal, kVal, flipDirections);
0276 #endif
0277 }