Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:24:26

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Utilities/Zip.hpp"
0010 #include "ActsPlugins/Gnn/CudaTrackBuilding.hpp"
0011 #include "ActsPlugins/Gnn/detail/ConnectedComponents.cuh"
0012 #include "ActsPlugins/Gnn/detail/CudaUtils.cuh"
0013 #include "ActsPlugins/Gnn/detail/CudaUtils.hpp"
0014 #include "ActsPlugins/Gnn/detail/JunctionRemoval.hpp"
0015 
0016 using namespace Acts;
0017 
namespace ActsPlugins {

/// Build track candidates on a CUDA device from a classified edge graph.
///
/// Device work is enqueued on `execContext.stream` in this order:
///   1. optional junction removal on the edge list (uses the edge scores),
///   2. connected-components labelling of the spacepoint graph,
///   3. sorting of the spacepoint IDs by label and computation of the
///      per-label [start, end) bounds,
///   4. copy of the sorted IDs and of the bounds back to the host.
/// The stream is synchronised before the host reads any of the results.
///
/// @param tensors       Pipeline tensors. `edgeIndex` must be on a CUDA
///                      device and is read as two consecutive rows of
///                      `numEdges` entries (sources first, then targets).
///                      `edgeScores`, if present, must also be on CUDA; it is
///                      required when junction removal is enabled.
/// @param spacepointIDs Host spacepoint IDs, one per graph node. Overwritten
///                      in place with the IDs sorted by component label.
/// @param execContext   Execution context; its `stream` must be set.
/// @return One vector of spacepoint IDs per connected component with at
///         least `m_cfg.minCandidateSize` members. Empty if no edges are
///         left before or after junction removal.
/// @throws std::runtime_error if the tensors are not on CUDA, or if edge
///         scores are missing or inconsistent while junction removal is on.
std::vector<std::vector<int>> CudaTrackBuilding::operator()(
    PipelineTensors tensors, std::vector<int>& spacepointIDs,
    const ExecutionContext& execContext) {
  ACTS_VERBOSE("Start CUDA track building");

  // Edge scores are only consumed by junction removal. Validate their device
  // when they are present, and require them only when they will be used, so
  // a missing optional is reported as a clear runtime_error instead of
  // escaping as std::bad_optional_access.
  const bool hasEdgeScores = tensors.edgeScores.has_value();
  if (!tensors.edgeIndex.device().isCuda() ||
      (hasEdgeScores && !tensors.edgeScores->device().isCuda())) {
    throw std::runtime_error(
        "CudaTrackBuilding expects tensors to be on CUDA!");
  }
  if (m_cfg.doJunctionRemoval && !hasEdgeScores) {
    throw std::runtime_error(
        "CudaTrackBuilding: junction removal requires edge scores");
  }

  const auto numSpacepoints = spacepointIDs.size();
  auto numEdges = static_cast<std::size_t>(tensors.edgeIndex.shape().at(1));

  if (numEdges == 0) {
    ACTS_DEBUG("No edges remained after edge classification");
    return {};
  }

  auto stream = execContext.stream.value();

  // edgeIndex stores the source row followed by the target row.
  auto cudaSrcPtr = tensors.edgeIndex.data();
  auto cudaTgtPtr = tensors.edgeIndex.data() + numEdges;

  auto ms = [](auto t0, auto t1) {
    return std::chrono::duration_cast<std::chrono::milliseconds>(t1 - t0)
        .count();
  };

  if (m_cfg.doJunctionRemoval) {
    // junctionRemovalCuda receives numEdges alongside the score pointer, so
    // the score tensor must hold exactly one score per edge. Check this in
    // every build type, not only through assert().
    if (static_cast<std::size_t>(tensors.edgeScores->shape().at(0)) !=
        numEdges) {
      throw std::runtime_error(
          "CudaTrackBuilding: number of edge scores does not match number "
          "of edges");
    }
    auto cudaScorePtr = tensors.edgeScores->data();

    ACTS_DEBUG("Do junction removal...");
    auto t0 = std::chrono::high_resolution_clock::now();
    auto [cudaSrcPtrJr, numEdgesOut] = detail::junctionRemovalCuda(
        numEdges, numSpacepoints, cudaScorePtr, cudaSrcPtr, cudaTgtPtr, stream);
    auto t1 = std::chrono::high_resolution_clock::now();

    // The returned edge buffer (sources then targets) is owned by this
    // function and is freed below on every path.
    cudaSrcPtr = cudaSrcPtrJr;
    cudaTgtPtr = cudaSrcPtrJr + numEdgesOut;

    if (numEdgesOut == 0) {
      ACTS_WARNING(
          "No edges remained after junction removal, this should not happen!");
      ACTS_CUDA_CHECK(cudaFreeAsync(cudaSrcPtrJr, stream));
      ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
      return {};
    }

    ACTS_DEBUG("Removed " << numEdges - numEdgesOut
                          << " edges in junction removal");
    ACTS_DEBUG("Junction removal took " << ms(t0, t1) << " ms");
    numEdges = numEdgesOut;
  }

  // One component label per spacepoint (graph node).
  int* cudaLabels{};
  ACTS_CUDA_CHECK(
      cudaMallocAsync(&cudaLabels, numSpacepoints * sizeof(int), stream));

  auto t0 = std::chrono::high_resolution_clock::now();
  std::size_t numberLabels = detail::connectedComponentsCuda(
      numEdges, cudaSrcPtr, cudaTgtPtr, numSpacepoints, cudaLabels, stream,
      m_cfg.useOneBlockImplementation);
  auto t1 = std::chrono::high_resolution_clock::now();
  ACTS_DEBUG("Connected components took " << ms(t0, t1) << " ms");
  ACTS_VERBOSE("Found " << numberLabels << " track candidates");

  // Postprocess labels: upload the spacepoint IDs so they can be sorted by
  // label on the device.
  int* cudaSpacepointIDs{};
  ACTS_CUDA_CHECK(cudaMallocAsync(&cudaSpacepointIDs,
                                  numSpacepoints * sizeof(int), stream));
  ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaSpacepointIDs, spacepointIDs.data(),
                                  numSpacepoints * sizeof(int),
                                  cudaMemcpyHostToDevice, stream));

  // numberLabels + 1 bounds delimit numberLabels [start, end) ranges.
  int* cudaBounds{};
  ACTS_CUDA_CHECK(
      cudaMallocAsync(&cudaBounds, (numberLabels + 1) * sizeof(int), stream));

  // Compute the bounds of the track candidates
  detail::findTrackCandidateBounds(cudaLabels, cudaSpacepointIDs, cudaBounds,
                                   numSpacepoints, numberLabels, stream);

  // Copy the bounds to the host
  std::vector<int> bounds(numberLabels + 1);
  ACTS_CUDA_CHECK(cudaMemcpyAsync(bounds.data(), cudaBounds,
                                  (numberLabels + 1) * sizeof(int),
                                  cudaMemcpyDeviceToHost, stream));

  // Copy the sorted spacepoint IDs back over the caller's vector
  ACTS_CUDA_CHECK(cudaMemcpyAsync(spacepointIDs.data(), cudaSpacepointIDs,
                                  numSpacepoints * sizeof(int),
                                  cudaMemcpyDeviceToHost, stream));

  // Free device memory; the frees are stream-ordered after the copies above.
  ACTS_CUDA_CHECK(cudaFreeAsync(cudaLabels, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(cudaSpacepointIDs, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(cudaBounds, stream));
  if (m_cfg.doJunctionRemoval) {
    // cudaSrcPtr points at the buffer returned by junction removal here.
    ACTS_CUDA_CHECK(cudaFreeAsync(cudaSrcPtr, stream));
  }

  // The host must not read `bounds` or `spacepointIDs` before the
  // device-to-host copies have completed.
  ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
  ACTS_CUDA_CHECK(cudaGetLastError());

  ACTS_DEBUG("Bounds size: " << bounds.size());
  // Print an abbreviated head/tail view only when there are enough entries
  // for indices 0..2 and numberLabels-2..numberLabels to be valid and
  // distinct; for few labels print every entry instead. This keeps the
  // debug output from throwing std::out_of_range (or underflowing
  // numberLabels - 2) when only one or two components were found.
  if (bounds.size() >= 6) {
    ACTS_DEBUG("Bounds: " << bounds.at(0) << ", " << bounds.at(1) << ", "
                          << bounds.at(2) << ", ..., "
                          << bounds.at(numberLabels - 2) << ", "
                          << bounds.at(numberLabels - 1) << ", "
                          << bounds.at(numberLabels));
  } else {
    for (std::size_t i = 0; i < bounds.size(); ++i) {
      ACTS_DEBUG("Bounds[" << i << "]: " << bounds[i]);
    }
  }

  std::vector<std::vector<int>> trackCandidates;
  trackCandidates.reserve(numberLabels);
  for (std::size_t label = 0ul; label < numberLabels; ++label) {
    int start = bounds.at(label);
    int end = bounds.at(label + 1);

    assert(start >= 0);
    assert(end <= static_cast<int>(numSpacepoints));
    assert(start <= end);

    // Drop components that are too small to form a candidate.
    if (end - start < m_cfg.minCandidateSize) {
      continue;
    }

    trackCandidates.emplace_back(spacepointIDs.begin() + start,
                                 spacepointIDs.begin() + end);
  }

  return trackCandidates;
}

}  // namespace ActsPlugins