Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 09:15:12

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
#include "Acts/Plugins/ExaTrkX/CudaTrackBuilding.hpp"
#include "Acts/Plugins/ExaTrkX/detail/ConnectedComponents.cuh"
#include "Acts/Plugins/ExaTrkX/detail/CudaUtils.cuh"
#include "Acts/Utilities/Zip.hpp"

#include <algorithm>
#include <any>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>
0017 
namespace Acts {

/// Build track candidates from the classified edge graph on the GPU.
///
/// The edge tensor is moved to the CUDA device, and a connected-components
/// labelling (detail::connectedComponentsCuda) writes one label per
/// spacepoint. All spacepoint IDs sharing a label end up in the same
/// candidate.
///
/// @param edges          std::any holding a torch::Tensor of shape
///                       (2, numEdges) and dtype int64 (required by
///                       data_ptr<std::int64_t>()). Row 0 is read as source
///                       node indices, row 1 as target node indices.
/// @param weights        Not used by this implementation.
/// @param spacepointIDs  One entry per graph node. The node with label i
///                       contributes its spacepointIDs entry to candidate i.
///                       NOTE(review): assumes node indices in @p edges index
///                       into this vector - confirm with graph construction.
/// @param execContext    Must hold a CUDA stream (accessed via .value()). All
///                       device work is enqueued on it and synchronized
///                       before returning.
/// @return One vector of spacepoint IDs per label; entries may be empty.
///         Returns an empty result when there are no edges.
/// @throws std::invalid_argument if the edge tensor is not shaped (2, N).
/// @throws std::runtime_error if a negative label is produced.
std::vector<std::vector<int>> CudaTrackBuilding::operator()(
    std::any /*nodes*/, std::any edges, std::any weights,
    std::vector<int>& spacepointIDs, const ExecutionContext& execContext) {
  ACTS_VERBOSE("Start CUDA track building");
  c10::cuda::CUDAStreamGuard guard(execContext.stream.value());

  // contiguous() guarantees the row-major layout the source/target pointer
  // arithmetic below relies on; it is a no-op for already contiguous tensors.
  const auto edgeTensor =
      std::any_cast<torch::Tensor>(edges).to(torch::kCUDA).contiguous();
  if (edgeTensor.dim() != 2 || edgeTensor.size(0) != 2) {
    throw std::invalid_argument(
        "CudaTrackBuilding: edge tensor must have shape (2, numEdges)");
  }

  const auto numSpacepoints = spacepointIDs.size();
  const auto numEdges = static_cast<std::size_t>(edgeTensor.size(1));

  if (numEdges == 0) {
    ACTS_WARNING("No edges remained after edge classification");
    return {};
  }

  auto stream = execContext.stream->stream();

  // Row 0 of the contiguous (2, numEdges) tensor holds the sources, row 1
  // the targets.
  auto cudaSrcPtr = edgeTensor.data_ptr<std::int64_t>();
  auto cudaTgtPtr = cudaSrcPtr + numEdges;

  // Per-spacepoint label buffer from the stream-ordered allocator. Owned by a
  // unique_ptr so it is still returned if an exception propagates out of the
  // calls below (e.g. from a failing ACTS_CUDA_CHECK).
  int* rawLabels = nullptr;
  ACTS_CUDA_CHECK(
      cudaMallocAsync(&rawLabels, numSpacepoints * sizeof(int), stream));
  auto freeOnStream = [stream](int* ptr) {
    // Only reached on the exceptional path; errors cannot be reported here.
    static_cast<void>(cudaFreeAsync(ptr, stream));
  };
  std::unique_ptr<int, decltype(freeOnStream)> cudaLabels(rawLabels,
                                                          freeOnStream);

  std::size_t numberLabels = detail::connectedComponentsCuda(
      numEdges, cudaSrcPtr, cudaTgtPtr, numSpacepoints, cudaLabels.get(),
      stream);

  // TODO not sure why there is an issue that is not detected in the unit tests
  // NOTE(review): this +1 apparently works around a label count that is too
  // small for some inputs; the label scan below additionally guarantees that
  // no label indexes past the end of the candidate list.
  numberLabels += 1;

  std::vector<int> trackLabels(numSpacepoints);
  ACTS_CUDA_CHECK(cudaMemcpyAsync(trackLabels.data(), cudaLabels.get(),
                                  numSpacepoints * sizeof(int),
                                  cudaMemcpyDeviceToHost, stream));
  // Normal path: free explicitly so that a failure is reported by the check.
  int* labelsToFree = cudaLabels.release();
  ACTS_CUDA_CHECK(cudaFreeAsync(labelsToFree, stream));
  ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
  ACTS_CUDA_CHECK(cudaGetLastError());

  // Make sure every label is a valid index into trackCandidates, even if the
  // reported label count underestimates the actual label range.
  std::size_t numCandidates = numberLabels;
  for (const int label : trackLabels) {
    if (label < 0) {
      throw std::runtime_error(
          "CudaTrackBuilding: negative label from connected components");
    }
    numCandidates =
        std::max(numCandidates, static_cast<std::size_t>(label) + 1);
  }
  if (numCandidates != numberLabels) {
    ACTS_WARNING("Connected components label count ("
                 << numberLabels << ") is smaller than largest label + 1 ("
                 << numCandidates << ")");
  }

  ACTS_VERBOSE("Found " << numCandidates << " track candidates");

  std::vector<std::vector<int>> trackCandidates(numCandidates);

  for (const auto [label, id] : Acts::zip(trackLabels, spacepointIDs)) {
    trackCandidates[label].push_back(id);
  }

  return trackCandidates;
}

}  // namespace Acts