File indexing completed on 2025-05-13 07:58:41
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/detail/CudaUtils.cuh"
0010 #include "Acts/Plugins/ExaTrkX/detail/CudaUtils.hpp"
0011 #include "Acts/Plugins/ExaTrkX/detail/JunctionRemoval.hpp"
0012
0013 #include <thrust/count.h>
0014 #include <thrust/execution_policy.h>
0015 #include <thrust/scan.h>
0016 #include <thrust/transform_scan.h>
0017
0018 namespace Acts::detail {
0019
/// Count, per node, the number of edges ending at it and starting from it.
///
/// Every thread derives a flat 1D index from blockIdx.x, blockDim.x and
/// threadIdx.x and handles the edge with that index; threads whose index is
/// not below @p nEdges do nothing.
///
/// @param nEdges number of edges, i.e. number of entries read from
///        @p srcNodes and @p dstNodes
/// @param srcNodes source node index of each edge
/// @param dstNodes destination node index of each edge
/// @param numInEdges counter array, incremented atomically at the
///        destination node index of each edge
/// @param numOutEdges counter array, incremented atomically at the source
///        node index of each edge
///
/// @note The counters are only incremented here, never reset; the caller
///       has to initialise them before the launch.
__global__ void findNumInOutEdge(std::size_t nEdges,
                                 const std::int64_t *srcNodes,
                                 const std::int64_t *dstNodes, int *numInEdges,
                                 int *numOutEdges) {
  const std::size_t edgeId = blockIdx.x * blockDim.x + threadIdx.x;

  if (edgeId < nEdges) {
    const std::int64_t from = srcNodes[edgeId];
    const std::int64_t to = dstNodes[edgeId];

    // Several edges can touch the same node, hence the atomic increments.
    atomicAdd(numOutEdges + from, 1);
    atomicAdd(numInEdges + to, 1);
  }
}
0035
/// Collect, per node, the ids of the edges attached to it into a compact
/// list.
///
/// Every thread handles the edge whose index equals its flat 1D thread index.
/// The node of that edge is looked up in @p edgeNodes. The range
/// [numEdgesPrefixSum[node], numEdgesPrefixSum[node + 1]) of
/// @p junctionEdges is the storage reserved for that node. An empty range
/// means the edge is skipped; otherwise the edge id is written into the next
/// free slot of the range.
///
/// @param nEdges number of edges
/// @param edgeNodes node index of each edge (the caller passes either the
///        source or the destination nodes)
/// @param numEdgesPrefixSum per-node start offsets into @p junctionEdges;
///        entry node + 1 is read as well, so it needs one entry more than
///        the largest node index + 1
/// @param junctionEdges output list of edge ids
/// @param junctionEdgeOffset per-node fill counter, atomically incremented
///        for every stored edge; expected to start at zero
__global__ void fillJunctionEdges(std::size_t nEdges,
                                  const std::int64_t *edgeNodes,
                                  const int *numEdgesPrefixSum,
                                  int *junctionEdges, int *junctionEdgeOffset) {
  const std::size_t edgeId = blockIdx.x * blockDim.x + threadIdx.x;
  if (!(edgeId < nEdges)) {
    return;
  }

  const int node = static_cast<int>(edgeNodes[edgeId]);
  const int base = numEdgesPrefixSum[node];
  const int numEdgesNode = numEdgesPrefixSum[node + 1] - base;

  // A range of exactly one slot would mean a non-junction node got storage.
  assert(numEdgesNode != 1 && "node is not a junction");

  // Nodes without reserved storage are not junctions: nothing to record.
  if (numEdgesNode == 0) {
    return;
  }

  // Claim a unique slot inside the node's range.
  const int offset = atomicAdd(junctionEdgeOffset + node, 1);
  assert(offset < numEdgesNode && "inconsistent offset with number of edges");
  junctionEdges[base + offset] = static_cast<int>(edgeId);
}
0058
/// Mark all edges of a node for removal except the one with the highest
/// score.
///
/// Every thread handles the node whose index equals its flat 1D thread index.
/// The edges of that node are the entries of @p junctionEdges in the range
/// [numEdgesPrefixSum[node], numEdgesPrefixSum[node + 1]). Nodes with an
/// empty range are left untouched.
///
/// @param nNodes number of nodes
/// @param scores score of each edge, indexed by edge id
/// @param numEdgesPrefixSum per-node start offsets into @p junctionEdges
///        (entries 0 .. nNodes are read)
/// @param junctionEdges edge ids grouped by node
/// @param edgesToRemoveMask per-edge flag; set to true for every listed edge
///        that is not the best one. Flags are never cleared here.
///
/// @note The running maximum starts at 0 and the comparison is strict, so an
///       edge is only kept if its score is greater than 0. If no listed edge
///       satisfies this, all listed edges of the node are flagged.
__global__ void fillEdgeMask(std::size_t nNodes, const float *scores,
                             const int *numEdgesPrefixSum,
                             const int *junctionEdges,
                             bool *edgesToRemoveMask) {
  const std::size_t nodeId = blockIdx.x * blockDim.x + threadIdx.x;

  if (nodeId < nNodes) {
    const int begin = numEdgesPrefixSum[nodeId];
    const int end = numEdgesPrefixSum[nodeId + 1];

    // First pass: locate the best-scoring edge of this node.
    int keptEdge = -1;
    float bestScore = 0.0f;
    for (int slot = begin; slot < end; ++slot) {
      const int candidate = junctionEdges[slot];
      const float candidateScore = scores[candidate];
      if (candidateScore > bestScore) {
        bestScore = candidateScore;
        keptEdge = candidate;
      }
    }

    // Second pass: flag every other edge of this node.
    for (int slot = begin; slot < end; ++slot) {
      const int candidate = junctionEdges[slot];
      if (candidate == keptEdge) {
        continue;
      }
      edgesToRemoveMask[candidate] = true;
    }
  }
}
0092
/// Unary predicate returning the logical negation of its argument.
///
/// Used as the predicate for stencil-based selection: with a "remove" mask as
/// stencil, it selects exactly the elements whose mask entry is false.
/// The call operator is const (matching AccumulateJunctionEdges) and callable
/// from both host and device code.
struct LogicalNotPredicate {
  __host__ __device__ bool operator()(bool b) const { return !b; }
};
0096
0097
0098
0099
0100
0101
0102
0103
0104
0105
0106
/// Binary operator that adds two edge counts, treating every operand smaller
/// than 2 as 0.
///
/// Used as the scan operator on the per-node edge counts: nodes with fewer
/// than two edges contribute nothing to the sum, so only nodes with two or
/// more edges get storage reserved in the resulting offsets.
struct AccumulateJunctionEdges {
  int __device__ operator()(int a, int b) const {
    const int lhs = (a >= 2) ? a : 0;
    const int rhs = (b >= 2) ? b : 0;
    return lhs + rhs;
  }
};
0114
/// Resolve junctions in a graph: for every node with two or more incoming
/// edges only the highest-scoring incoming edge is kept, and likewise for
/// every node with two or more outgoing edges.
///
/// All work is enqueued on @p stream; the stream is synchronised before the
/// function returns.
///
/// @param nEdges number of edges
/// @param nNodes number of nodes; node indices in @p srcNodes / @p dstNodes
///        must be smaller than this
/// @param scores score of each edge (nEdges entries)
/// @param srcNodes source node index of each edge (nEdges entries)
/// @param dstNodes destination node index of each edge (nEdges entries)
/// @param stream CUDA stream used for allocations, kernels and copies
///
/// @return pair of (buffer, number of remaining edges). The buffer holds
///         2 * nRemaining values: first the source nodes of the remaining
///         edges, then their destination nodes. It is allocated with
///         cudaMallocAsync on @p stream and is not freed here (presumably
///         the caller releases it). For an empty input (@p nEdges == 0) the
///         result is {nullptr, 0}.
std::pair<std::int64_t *, std::size_t> junctionRemovalCuda(
    std::size_t nEdges, std::size_t nNodes, const float *scores,
    const std::int64_t *srcNodes, const std::int64_t *dstNodes,
    cudaStream_t stream) {
  // Without edges there is nothing to remove. Bail out early: the per-edge
  // kernels would otherwise be launched with a grid of zero blocks, which is
  // an invalid launch configuration.
  if (nEdges == 0) {
    return std::pair<std::int64_t *, std::size_t>{nullptr, 0};
  }

  // Per-node in/out edge counters. One extra element so that, after the
  // exclusive scan below, entry nNodes holds the total.
  const std::size_t counterBytes = (nNodes + 1) * sizeof(int);
  int *numInEdges{}, *numOutEdges{};
  ACTS_CUDA_CHECK(cudaMallocAsync(&numInEdges, counterBytes, stream));
  ACTS_CUDA_CHECK(cudaMallocAsync(&numOutEdges, counterBytes, stream));

  // The counting kernel only increments, so start from zero.
  ACTS_CUDA_CHECK(cudaMemsetAsync(numInEdges, 0, counterBytes, stream));
  ACTS_CUDA_CHECK(cudaMemsetAsync(numOutEdges, 0, counterBytes, stream));

  // Count the incoming and outgoing edges of every node.
  const dim3 blockSize = 512;
  const dim3 gridSizeEdges = (nEdges + blockSize.x - 1) / blockSize.x;
  findNumInOutEdge<<<gridSizeEdges, blockSize, 0, stream>>>(
      nEdges, srcNodes, dstNodes, numInEdges, numOutEdges);
  ACTS_CUDA_CHECK(cudaGetLastError());

  // In-place exclusive scan that ignores counts below 2. Afterwards the
  // arrays hold, per node, the start offset into the junction edge lists;
  // non-junction nodes get an empty range.
  thrust::exclusive_scan(thrust::device.on(stream), numInEdges,
                         numInEdges + nNodes + 1, numInEdges, 0,
                         AccumulateJunctionEdges{});
  thrust::exclusive_scan(thrust::device.on(stream), numOutEdges,
                         numOutEdges + nNodes + 1, numOutEdges, 0,
                         AccumulateJunctionEdges{});

  // Read back the totals (last scan entry); they size the edge lists.
  int numJunctionInEdges{}, numJunctionOutEdges{};
  ACTS_CUDA_CHECK(cudaMemcpyAsync(&numJunctionInEdges, &numInEdges[nNodes],
                                  sizeof(int), cudaMemcpyDeviceToHost, stream));
  ACTS_CUDA_CHECK(cudaMemcpyAsync(&numJunctionOutEdges, &numOutEdges[nNodes],
                                  sizeof(int), cudaMemcpyDeviceToHost, stream));
  // The host values are needed right away for the allocations below.
  ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));

  // Storage for the edge ids attached to junction nodes.
  int *junctionInEdges{}, *junctionOutEdges{};
  ACTS_CUDA_CHECK(cudaMallocAsync(
      &junctionInEdges, static_cast<std::size_t>(numJunctionInEdges) * sizeof(int),
      stream));
  ACTS_CUDA_CHECK(cudaMallocAsync(
      &junctionOutEdges,
      static_cast<std::size_t>(numJunctionOutEdges) * sizeof(int), stream));

  // Per-node fill counters used to hand out unique slots in the edge lists.
  const std::size_t offsetBytes = nNodes * sizeof(int);
  int *junctionInEdgeOffset{}, *junctionOutEdgeOffset{};
  ACTS_CUDA_CHECK(cudaMallocAsync(&junctionInEdgeOffset, offsetBytes, stream));
  ACTS_CUDA_CHECK(cudaMallocAsync(&junctionOutEdgeOffset, offsetBytes, stream));
  ACTS_CUDA_CHECK(
      cudaMemsetAsync(junctionInEdgeOffset, 0, offsetBytes, stream));
  ACTS_CUDA_CHECK(
      cudaMemsetAsync(junctionOutEdgeOffset, 0, offsetBytes, stream));

  // Fill the edge lists: outgoing edges are grouped by source node, incoming
  // edges by destination node.
  fillJunctionEdges<<<gridSizeEdges, blockSize, 0, stream>>>(
      nEdges, srcNodes, numOutEdges, junctionOutEdges, junctionOutEdgeOffset);
  ACTS_CUDA_CHECK(cudaGetLastError());
  fillJunctionEdges<<<gridSizeEdges, blockSize, 0, stream>>>(
      nEdges, dstNodes, numInEdges, junctionInEdges, junctionInEdgeOffset);
  ACTS_CUDA_CHECK(cudaGetLastError());

  // Per-edge removal flags, initially all false.
  bool *edgesToRemoveMask{};
  ACTS_CUDA_CHECK(
      cudaMallocAsync(&edgesToRemoveMask, nEdges * sizeof(bool), stream));
  ACTS_CUDA_CHECK(
      cudaMemsetAsync(edgesToRemoveMask, 0, nEdges * sizeof(bool), stream));

  // Flag every junction edge that is not the best-scoring one of its node,
  // once for the incoming and once for the outgoing side.
  const dim3 gridSizeNodes = (nNodes + blockSize.x - 1) / blockSize.x;
  fillEdgeMask<<<gridSizeNodes, blockSize, 0, stream>>>(
      nNodes, scores, numInEdges, junctionInEdges, edgesToRemoveMask);
  ACTS_CUDA_CHECK(cudaGetLastError());
  fillEdgeMask<<<gridSizeNodes, blockSize, 0, stream>>>(
      nNodes, scores, numOutEdges, junctionOutEdges, edgesToRemoveMask);
  ACTS_CUDA_CHECK(cudaGetLastError());

  // The intermediate buffers are no longer needed.
  ACTS_CUDA_CHECK(cudaFreeAsync(numInEdges, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(numOutEdges, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(junctionInEdges, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(junctionOutEdges, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(junctionInEdgeOffset, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(junctionOutEdgeOffset, stream));

  // Number of surviving edges. Kept in std::size_t so that neither the
  // count nor the byte size below is narrowed to int.
  const auto nEdgesToRemove = static_cast<std::size_t>(
      thrust::count(thrust::device.on(stream), edgesToRemoveMask,
                    edgesToRemoveMask + nEdges, true));
  const std::size_t nEdgesAfter = nEdges - nEdgesToRemove;

  // One allocation for both halves: sources first, destinations second.
  std::int64_t *newSrcNodes{};
  ACTS_CUDA_CHECK(cudaMallocAsync(
      &newSrcNodes, 2 * nEdgesAfter * sizeof(std::int64_t), stream));
  std::int64_t *newDstNodes = newSrcNodes + nEdgesAfter;

  // Copy over the edges whose removal flag is not set.
  thrust::copy_if(thrust::device.on(stream), srcNodes, srcNodes + nEdges,
                  edgesToRemoveMask, newSrcNodes, LogicalNotPredicate{});
  thrust::copy_if(thrust::device.on(stream), dstNodes, dstNodes + nEdges,
                  edgesToRemoveMask, newDstNodes, LogicalNotPredicate{});

  ACTS_CUDA_CHECK(cudaFreeAsync(edgesToRemoveMask, stream));

  // Make sure the result is complete before handing it back.
  ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));

  return std::make_pair(newSrcNodes, nEdgesAfter);
}
0232
0233 }