File indexing completed on 2025-12-16 09:24:26
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Utilities/Zip.hpp"
0010 #include "ActsPlugins/Gnn/CudaTrackBuilding.hpp"
0011 #include "ActsPlugins/Gnn/detail/ConnectedComponents.cuh"
0012 #include "ActsPlugins/Gnn/detail/CudaUtils.cuh"
0013 #include "ActsPlugins/Gnn/detail/CudaUtils.hpp"
0014 #include "ActsPlugins/Gnn/detail/JunctionRemoval.hpp"
0015
0016 using namespace Acts;
0017
namespace ActsPlugins {

namespace {

/// Owner of (at most) one stream-ordered device allocation.
///
/// The buffer is released with cudaFreeAsync on the stream given at
/// construction. On the regular path call deallocate(), which checks the
/// result. The destructor is only a fallback for scopes left early (e.g. by
/// an exception); it ignores the result because destructors must not throw.
template <typename stream_t>
class StreamOrderedDeviceBuffer {
 public:
  explicit StreamOrderedDeviceBuffer(stream_t stream, void* ptr = nullptr)
      : m_ptr(ptr), m_stream(stream) {}

  StreamOrderedDeviceBuffer(const StreamOrderedDeviceBuffer&) = delete;
  StreamOrderedDeviceBuffer& operator=(const StreamOrderedDeviceBuffer&) =
      delete;

  ~StreamOrderedDeviceBuffer() {
    if (m_ptr != nullptr) {
      static_cast<void>(cudaFreeAsync(m_ptr, m_stream));
    }
  }

  /// Take ownership of @p ptr; this object must not own a buffer yet.
  void adopt(void* ptr) {
    assert(m_ptr == nullptr);
    m_ptr = ptr;
  }

  /// Enqueue the free of the owned buffer (if any) and check the result.
  void deallocate() {
    if (m_ptr == nullptr) {
      return;
    }
    void* ptr = m_ptr;
    // Drop ownership before checking, so that if ACTS_CUDA_CHECK throws the
    // destructor does not try to free the same pointer a second time.
    m_ptr = nullptr;
    ACTS_CUDA_CHECK(cudaFreeAsync(ptr, m_stream));
  }

 private:
  void* m_ptr = nullptr;
  stream_t m_stream;
};

}  // namespace

/// Build track candidates from the classified edge graph on the GPU.
///
/// Optionally runs junction removal (detail::junctionRemovalCuda), then labels
/// connected components (detail::connectedComponentsCuda) and uses
/// detail::findTrackCandidateBounds to obtain, per label, a [start, end)
/// range into the spacepoint-ID array.
///
/// @param tensors      Pipeline tensors; edgeIndex (and edgeScores, when
///                     present) must live on a CUDA device. edgeScores is
///                     required when junction removal is enabled.
/// @param spacepointIDs Spacepoint IDs, one per graph node. Overwritten in
///                     place with the device-side result of
///                     findTrackCandidateBounds; every returned candidate is a
///                     contiguous slice of it.
/// @param execContext  Execution context; its stream must be set.
/// @return One vector of spacepoint IDs per component that has at least
///         m_cfg.minCandidateSize entries.
/// @throws std::runtime_error if the tensors are not on CUDA, or if junction
///         removal is enabled but no edge scores are available.
std::vector<std::vector<int>> CudaTrackBuilding::operator()(
    PipelineTensors tensors, std::vector<int>& spacepointIDs,
    const ExecutionContext& execContext) {
  ACTS_VERBOSE("Start CUDA track building");
  // Edge scores are only consumed by junction removal: validate their device
  // when they exist, and require them only when junction removal is on.
  if (!tensors.edgeIndex.device().isCuda() ||
      (tensors.edgeScores.has_value() &&
       !tensors.edgeScores->device().isCuda())) {
    throw std::runtime_error(
        "CudaTrackBuilding expects tensors to be on CUDA!");
  }
  if (m_cfg.doJunctionRemoval && !tensors.edgeScores.has_value()) {
    throw std::runtime_error(
        "CudaTrackBuilding: junction removal requires edge scores");
  }

  const auto numSpacepoints = spacepointIDs.size();
  auto numEdges = static_cast<std::size_t>(tensors.edgeIndex.shape().at(1));

  if (numEdges == 0) {
    ACTS_DEBUG("No edges remained after edge classification");
    return {};
  }

  auto stream = execContext.stream.value();

  // edgeIndex is read as two consecutive rows of numEdges entries:
  // source node indices followed by target node indices.
  auto cudaSrcPtr = tensors.edgeIndex.data();
  auto cudaTgtPtr = tensors.edgeIndex.data() + numEdges;

  auto ms = [](auto t0, auto t1) {
    return std::chrono::duration_cast<std::chrono::milliseconds>(t1 - t0)
        .count();
  };

  // Owns the edge list returned by junction removal (stays empty otherwise).
  // Declared outside the branch so it outlives the connected-components step.
  StreamOrderedDeviceBuffer junctionRemovalEdges(stream);

  if (m_cfg.doJunctionRemoval) {
    assert(tensors.edgeScores->shape().at(0) ==
           tensors.edgeIndex.shape().at(1));
    auto cudaScorePtr = tensors.edgeScores->data();

    ACTS_DEBUG("Do junction removal...");
    auto t0 = std::chrono::steady_clock::now();
    auto [cudaSrcPtrJr, numEdgesOut] = detail::junctionRemovalCuda(
        numEdges, numSpacepoints, cudaScorePtr, cudaSrcPtr, cudaTgtPtr, stream);
    auto t1 = std::chrono::steady_clock::now();
    junctionRemovalEdges.adopt(cudaSrcPtrJr);
    cudaSrcPtr = cudaSrcPtrJr;
    cudaTgtPtr = cudaSrcPtrJr + numEdgesOut;

    if (numEdgesOut == 0) {
      ACTS_WARNING(
          "No edges remained after junction removal, this should not happen!");
      junctionRemovalEdges.deallocate();
      ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
      return {};
    }

    ACTS_DEBUG("Removed " << numEdges - numEdgesOut
                          << " edges in junction removal");
    ACTS_DEBUG("Junction removal took " << ms(t0, t1) << " ms");
    numEdges = numEdgesOut;
  }

  int* cudaLabels{};
  ACTS_CUDA_CHECK(
      cudaMallocAsync(&cudaLabels, numSpacepoints * sizeof(int), stream));
  StreamOrderedDeviceBuffer labelsBuffer(stream, cudaLabels);

  auto t0 = std::chrono::steady_clock::now();
  std::size_t numberLabels = detail::connectedComponentsCuda(
      numEdges, cudaSrcPtr, cudaTgtPtr, numSpacepoints, cudaLabels, stream,
      m_cfg.useOneBlockImplementation);
  auto t1 = std::chrono::steady_clock::now();
  ACTS_DEBUG("Connected components took " << ms(t0, t1) << " ms");
  ACTS_VERBOSE("Found " << numberLabels << " track candidates");

  int* cudaSpacepointIDs{};
  ACTS_CUDA_CHECK(cudaMallocAsync(&cudaSpacepointIDs,
                                  spacepointIDs.size() * sizeof(int), stream));
  StreamOrderedDeviceBuffer spacepointIDsBuffer(stream, cudaSpacepointIDs);
  ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaSpacepointIDs, spacepointIDs.data(),
                                  spacepointIDs.size() * sizeof(int),
                                  cudaMemcpyHostToDevice, stream));

  // numberLabels + 1 entries: bounds[label] .. bounds[label + 1] per label.
  int* cudaBounds{};
  ACTS_CUDA_CHECK(
      cudaMallocAsync(&cudaBounds, (numberLabels + 1) * sizeof(int), stream));
  StreamOrderedDeviceBuffer boundsBuffer(stream, cudaBounds);

  detail::findTrackCandidateBounds(cudaLabels, cudaSpacepointIDs, cudaBounds,
                                   numSpacepoints, numberLabels, stream);

  std::vector<int> bounds(numberLabels + 1);
  ACTS_CUDA_CHECK(cudaMemcpyAsync(bounds.data(), cudaBounds,
                                  (numberLabels + 1) * sizeof(int),
                                  cudaMemcpyDeviceToHost, stream));

  ACTS_CUDA_CHECK(cudaMemcpyAsync(spacepointIDs.data(), cudaSpacepointIDs,
                                  spacepointIDs.size() * sizeof(int),
                                  cudaMemcpyDeviceToHost, stream));

  // Regular-path release, in the original order, before the stream sync.
  labelsBuffer.deallocate();
  spacepointIDsBuffer.deallocate();
  boundsBuffer.deallocate();
  junctionRemovalEdges.deallocate();  // no-op when junction removal is off

  ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
  ACTS_CUDA_CHECK(cudaGetLastError());

  ACTS_DEBUG("Bounds size: " << bounds.size());
  // Abbreviated dump. The fixed indices below need at least two labels;
  // with fewer, bounds.at() would throw std::out_of_range.
  if (numberLabels >= 2) {
    ACTS_DEBUG("Bounds: " << bounds.at(0) << ", " << bounds.at(1) << ", "
                          << bounds.at(2) << ", ..., "
                          << bounds.at(numberLabels - 2) << ", "
                          << bounds.at(numberLabels - 1) << ", "
                          << bounds.at(numberLabels));
  } else if (numberLabels == 1) {
    ACTS_DEBUG("Bounds: " << bounds.at(0) << ", " << bounds.at(1));
  }

  std::vector<std::vector<int>> trackCandidates;
  trackCandidates.reserve(numberLabels);
  for (std::size_t label = 0ul; label < numberLabels; ++label) {
    int start = bounds.at(label);
    int end = bounds.at(label + 1);

    assert(start >= 0);
    assert(end <= static_cast<int>(numSpacepoints));
    assert(start <= end);

    // Drop components that are too small to be a track candidate.
    if (end - start < m_cfg.minCandidateSize) {
      continue;
    }

    trackCandidates.emplace_back(spacepointIDs.begin() + start,
                                 spacepointIDs.begin() + end);
  }

  return trackCandidates;
}

}  // namespace ActsPlugins