File indexing completed on 2025-07-02 07:51:54
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/CudaTrackBuilding.hpp"
0010 #include "Acts/Plugins/ExaTrkX/detail/ConnectedComponents.cuh"
0011 #include "Acts/Plugins/ExaTrkX/detail/CudaUtils.cuh"
0012 #include "Acts/Plugins/ExaTrkX/detail/CudaUtils.hpp"
0013 #include "Acts/Plugins/ExaTrkX/detail/JunctionRemoval.hpp"
0014 #include "Acts/Utilities/Zip.hpp"
0015
namespace Acts {

namespace {

/// Owner of a single stream-ordered device allocation.
///
/// The pointer handed to this object is released with `cudaFreeAsync` on the
/// given stream when the owner is destroyed, unless ownership was returned
/// with `release()` first. It exists so that device buffers are not leaked
/// when a CUDA call or one of the `detail::` helpers throws part-way through
/// track building. On the regular path the buffers are freed explicitly, so
/// the result of `cudaFreeAsync` can be checked there.
class ScopedStreamAllocation {
 public:
  explicit ScopedStreamAllocation(cudaStream_t stream, void* ptr = nullptr)
      : m_ptr(ptr), m_stream(stream) {}

  ScopedStreamAllocation(const ScopedStreamAllocation&) = delete;
  ScopedStreamAllocation& operator=(const ScopedStreamAllocation&) = delete;

  ~ScopedStreamAllocation() {
    if (m_ptr != nullptr) {
      // Destructors must not throw, so a failing free is ignored here; the
      // regular code path frees explicitly and checks the result instead.
      static_cast<void>(cudaFreeAsync(m_ptr, m_stream));
    }
  }

  /// Take ownership of @p ptr. Must only be called while nothing is owned.
  void adopt(void* ptr) { m_ptr = ptr; }

  /// Give up ownership and return the owned pointer (may be nullptr).
  void* release() {
    void* ptr = m_ptr;
    m_ptr = nullptr;
    return ptr;
  }

 private:
  void* m_ptr = nullptr;
  cudaStream_t m_stream{};
};

}  // namespace

/// Build track candidates on the GPU by labelling the connected components of
/// the edge graph, optionally after running junction removal on it.
///
/// @param tensors Pipeline tensors. `edgeIndex` is read as two consecutive
///        rows of `shape().at(1)` entries (sources, then targets) and must be
///        on CUDA. `edgeScores`, if present, must be on CUDA; it is required
///        (one score per edge) only when junction removal is enabled.
/// @param spacepointIDs Spacepoint IDs; entry i is assigned to the candidate
///        of node i. Its size is used as the number of graph nodes.
/// @param execContext Execution context; `stream` must be set (`.value()` is
///        called) and is used for all allocations, copies and kernels.
/// @return One vector of spacepoint IDs per component label. Entries may be
///         empty (at least one extra slot is always reserved, see below).
/// @throws std::runtime_error on invalid input (tensors not on CUDA, missing
///         or mismatching edge scores, negative labels). CUDA errors are
///         reported through ACTS_CUDA_CHECK (presumably also by throwing).
std::vector<std::vector<int>> CudaTrackBuilding::operator()(
    PipelineTensors tensors, std::vector<int>& spacepointIDs,
    const ExecutionContext& execContext) {
  ACTS_VERBOSE("Start CUDA track building");

  // Every tensor handed in must live on the GPU. Edge scores are only
  // consumed by junction removal, so they are mandatory only when it runs.
  if (!tensors.edgeIndex.device().isCuda() ||
      (tensors.edgeScores.has_value() &&
       !tensors.edgeScores->device().isCuda())) {
    throw std::runtime_error(
        "CudaTrackBuilding expects tensors to be on CUDA!");
  }
  if (m_cfg.doJunctionRemoval && !tensors.edgeScores.has_value()) {
    throw std::runtime_error(
        "CudaTrackBuilding: junction removal requires edge scores");
  }

  const auto numSpacepoints = spacepointIDs.size();
  auto numEdges = static_cast<std::size_t>(tensors.edgeIndex.shape().at(1));

  if (numEdges == 0) {
    ACTS_DEBUG("No edges remained after edge classification");
    return {};
  }

  auto stream = execContext.stream.value();

  // Source row first, target row directly after it.
  auto cudaSrcPtr = tensors.edgeIndex.data();
  auto cudaTgtPtr = tensors.edgeIndex.data() + numEdges;

  // Fractional milliseconds, so that fast steps do not show up as "0 ms".
  auto ms = [](auto t0, auto t1) {
    return std::chrono::duration<double, std::milli>(t1 - t0).count();
  };

  // Owns the edge list produced by junction removal (if it runs), so that it
  // is released on every exit path, including exceptions.
  ScopedStreamAllocation junctionRemovalEdges(stream);

  if (m_cfg.doJunctionRemoval) {
    // Checked at runtime rather than with assert(): in release builds a
    // mismatch would let the kernel read past the end of the score buffer.
    if (tensors.edgeScores->shape().at(0) !=
        tensors.edgeIndex.shape().at(1)) {
      throw std::runtime_error(
          "CudaTrackBuilding: number of edge scores does not match number of "
          "edges");
    }
    auto cudaScorePtr = tensors.edgeScores->data();

    ACTS_DEBUG("Do junction removal...");
    auto t0 = std::chrono::high_resolution_clock::now();
    auto [cudaSrcPtrJr, numEdgesOut] = detail::junctionRemovalCuda(
        numEdges, numSpacepoints, cudaScorePtr, cudaSrcPtr, cudaTgtPtr, stream);
    auto t1 = std::chrono::high_resolution_clock::now();
    junctionRemovalEdges.adopt(cudaSrcPtrJr);
    cudaSrcPtr = cudaSrcPtrJr;
    cudaTgtPtr = cudaSrcPtrJr + numEdgesOut;

    if (numEdgesOut == 0) {
      ACTS_WARNING(
          "No edges remained after junction removal, this should not happen!");
      ACTS_CUDA_CHECK(cudaFreeAsync(junctionRemovalEdges.release(), stream));
      ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
      return {};
    }

    ACTS_DEBUG("Removed " << numEdges - numEdgesOut
                          << " edges in junction removal");
    ACTS_DEBUG("Junction removal took " << ms(t0, t1) << " ms");
    numEdges = numEdgesOut;
  }

  int* cudaLabels{};
  ACTS_CUDA_CHECK(
      cudaMallocAsync(&cudaLabels, numSpacepoints * sizeof(int), stream));
  ScopedStreamAllocation labelsOwner(stream, cudaLabels);

  // Host-clock timing; presumably meaningful because the helper returns a
  // host-side label count and therefore waits for the GPU — TODO confirm.
  auto t0 = std::chrono::high_resolution_clock::now();
  std::size_t numberLabels = detail::connectedComponentsCuda(
      numEdges, cudaSrcPtr, cudaTgtPtr, numSpacepoints, cudaLabels, stream,
      m_cfg.useOneBlockImplementation);
  auto t1 = std::chrono::high_resolution_clock::now();
  ACTS_DEBUG("Connected components took " << ms(t0, t1) << " ms");

  // NOTE(review): one slot more than the count reported by
  // connectedComponentsCuda is reserved. The reason is not visible from this
  // file (presumably the reported count can be one too low); it is kept so
  // the size of the returned vector is unchanged.
  numberLabels += 1;

  std::vector<int> trackLabels(numSpacepoints);
  ACTS_CUDA_CHECK(cudaMemcpyAsync(trackLabels.data(), cudaLabels,
                                  numSpacepoints * sizeof(int),
                                  cudaMemcpyDeviceToHost, stream));

  // Free device memory explicitly so errors are reported; the owners are
  // emptied first so nothing is freed twice.
  ACTS_CUDA_CHECK(cudaFreeAsync(labelsOwner.release(), stream));
  if (void* jrEdges = junctionRemovalEdges.release(); jrEdges != nullptr) {
    ACTS_CUDA_CHECK(cudaFreeAsync(jrEdges, stream));
  }

  ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
  ACTS_CUDA_CHECK(cudaGetLastError());

  // Labels come straight from the kernel; indexing with an out-of-range label
  // would be undefined behaviour, so validate them and grow the output if a
  // label exceeds the expected range instead of writing out of bounds.
  std::size_t numCandidates = numberLabels;
  for (const int label : trackLabels) {
    if (label < 0) {
      throw std::runtime_error(
          "CudaTrackBuilding: connected components produced a negative "
          "label");
    }
    if (static_cast<std::size_t>(label) >= numCandidates) {
      numCandidates = static_cast<std::size_t>(label) + 1;
    }
  }
  if (numCandidates != numberLabels) {
    ACTS_WARNING("Connected components produced label "
                 << numCandidates - 1 << " although only " << numberLabels
                 << " labels were expected");
  }

  ACTS_VERBOSE("Found " << numCandidates << " track candidates");

  std::vector<std::vector<int>> trackCandidates(numCandidates);

  for (const auto [label, id] : Acts::zip(trackLabels, spacepointIDs)) {
    trackCandidates[label].push_back(id);
  }

  return trackCandidates;
}

}  // namespace Acts