Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-02 07:51:54

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/CudaTrackBuilding.hpp"
0010 #include "Acts/Plugins/ExaTrkX/detail/ConnectedComponents.cuh"
0011 #include "Acts/Plugins/ExaTrkX/detail/CudaUtils.cuh"
0012 #include "Acts/Plugins/ExaTrkX/detail/CudaUtils.hpp"
0013 #include "Acts/Plugins/ExaTrkX/detail/JunctionRemoval.hpp"
0014 #include "Acts/Utilities/Zip.hpp"
0015 
0016 namespace Acts {
0017 
0018 std::vector<std::vector<int>> CudaTrackBuilding::operator()(
0019     PipelineTensors tensors, std::vector<int>& spacepointIDs,
0020     const ExecutionContext& execContext) {
0021   ACTS_VERBOSE("Start CUDA track building");
0022   if (!(tensors.edgeIndex.device().isCuda() &&
0023         tensors.edgeScores.value().device().isCuda())) {
0024     throw std::runtime_error(
0025         "CudaTrackBuilding expects tensors to be on CUDA!");
0026   }
0027 
0028   const auto numSpacepoints = spacepointIDs.size();
0029   auto numEdges = static_cast<std::size_t>(tensors.edgeIndex.shape().at(1));
0030 
0031   if (numEdges == 0) {
0032     ACTS_DEBUG("No edges remained after edge classification");
0033     return {};
0034   }
0035 
0036   auto stream = execContext.stream.value();
0037 
0038   auto cudaSrcPtr = tensors.edgeIndex.data();
0039   auto cudaTgtPtr = tensors.edgeIndex.data() + numEdges;
0040 
0041   auto ms = [](auto t0, auto t1) {
0042     return std::chrono::duration_cast<std::chrono::milliseconds>(t1 - t0)
0043         .count();
0044   };
0045 
0046   if (m_cfg.doJunctionRemoval) {
0047     assert(tensors.edgeScores->shape().at(0) ==
0048            tensors.edgeIndex.shape().at(1));
0049     auto cudaScorePtr = tensors.edgeScores->data();
0050 
0051     ACTS_DEBUG("Do junction removal...");
0052     auto t0 = std::chrono::high_resolution_clock::now();
0053     auto [cudaSrcPtrJr, numEdgesOut] = detail::junctionRemovalCuda(
0054         numEdges, numSpacepoints, cudaScorePtr, cudaSrcPtr, cudaTgtPtr, stream);
0055     auto t1 = std::chrono::high_resolution_clock::now();
0056     cudaSrcPtr = cudaSrcPtrJr;
0057     cudaTgtPtr = cudaSrcPtrJr + numEdgesOut;
0058 
0059     if (numEdgesOut == 0) {
0060       ACTS_WARNING(
0061           "No edges remained after junction removal, this should not happen!");
0062       ACTS_CUDA_CHECK(cudaFreeAsync(cudaSrcPtrJr, stream));
0063       ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
0064       return {};
0065     }
0066 
0067     ACTS_DEBUG("Removed " << numEdges - numEdgesOut
0068                           << " edges in junction removal");
0069     ACTS_DEBUG("Junction removal took " << ms(t0, t1) << " ms");
0070     numEdges = numEdgesOut;
0071   }
0072 
0073   int* cudaLabels{};
0074   ACTS_CUDA_CHECK(
0075       cudaMallocAsync(&cudaLabels, numSpacepoints * sizeof(int), stream));
0076 
0077   auto t0 = std::chrono::high_resolution_clock::now();
0078   std::size_t numberLabels = detail::connectedComponentsCuda(
0079       numEdges, cudaSrcPtr, cudaTgtPtr, numSpacepoints, cudaLabels, stream,
0080       m_cfg.useOneBlockImplementation);
0081   auto t1 = std::chrono::high_resolution_clock::now();
0082   ACTS_DEBUG("Connected components took " << ms(t0, t1) << " ms");
0083 
0084   // TODO not sure why there is an issue that is not detected in the unit tests
0085   numberLabels += 1;
0086 
0087   std::vector<int> trackLabels(numSpacepoints);
0088   ACTS_CUDA_CHECK(cudaMemcpyAsync(trackLabels.data(), cudaLabels,
0089                                   numSpacepoints * sizeof(int),
0090                                   cudaMemcpyDeviceToHost, stream));
0091 
0092   // Free Memory
0093   ACTS_CUDA_CHECK(cudaFreeAsync(cudaLabels, stream));
0094   if (m_cfg.doJunctionRemoval) {
0095     ACTS_CUDA_CHECK(cudaFreeAsync(cudaSrcPtr, stream));
0096   }
0097 
0098   ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
0099   ACTS_CUDA_CHECK(cudaGetLastError());
0100 
0101   ACTS_VERBOSE("Found " << numberLabels << " track candidates");
0102 
0103   std::vector<std::vector<int>> trackCandidates(numberLabels);
0104 
0105   for (const auto [label, id] : Acts::zip(trackLabels, spacepointIDs)) {
0106     trackCandidates[label].push_back(id);
0107   }
0108 
0109   return trackCandidates;
0110 }
0111 
0112 }  // namespace Acts