File indexing completed on 2025-01-30 09:15:12
0001
0002
0003
0004
0005
0006
0007
0008
#include "Acts/Plugins/ExaTrkX/CudaTrackBuilding.hpp"
#include "Acts/Plugins/ExaTrkX/detail/ConnectedComponents.cuh"
#include "Acts/Plugins/ExaTrkX/detail/CudaUtils.cuh"
#include "Acts/Utilities/Zip.hpp"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>
0017
namespace Acts {

/// Build track candidates from the classified edge graph on the GPU.
///
/// The edge list is moved to the current CUDA device (as a contiguous int64
/// tensor), the connected components of the spacepoint graph are computed on
/// the stream carried by @p execContext, and every spacepoint ID is appended
/// to the candidate addressed by its component label.
///
/// @param edges          std::any holding a torch::Tensor of shape (2, numEdges);
///                       row 0 is read as source node indices, row 1 as target
///                       node indices
/// @param spacepointIDs  one ID per graph node; its size is the node count
///                       handed to the connected-components kernel
/// @param execContext    execution context; its CUDA stream is used for all
///                       device work (assumes the stream is set)
/// @return one vector of spacepoint IDs per component label; entries may be
///         empty. Returns an empty result if there are no edges.
/// @throws std::bad_any_cast if @p edges does not hold a torch::Tensor
/// @throws std::invalid_argument if the edge tensor is not (2, numEdges), or
///         if edges are given but @p spacepointIDs is empty
/// @throws std::runtime_error if a negative component label is produced
/// CUDA failures are reported through ACTS_CUDA_CHECK.
std::vector<std::vector<int>> CudaTrackBuilding::operator()(
    std::any /*nodes*/, std::any edges, std::any /*weights*/,
    std::vector<int>& spacepointIDs, const ExecutionContext& execContext) {
  ACTS_VERBOSE("Start CUDA track building");
  c10::cuda::CUDAStreamGuard guard(execContext.stream.value());

  // The src/tgt rows are addressed below by raw pointer arithmetic, which is
  // only valid for a row-major (2, numEdges) int64 buffer. `.to()` may return
  // the input unchanged or keep its strides (e.g. a transposed view), so force
  // a contiguous layout explicitly.
  const auto edgeTensor = std::any_cast<torch::Tensor>(edges)
                              .to(torch::kCUDA, torch::kInt64)
                              .contiguous();
  if (edgeTensor.dim() != 2 || edgeTensor.size(0) != 2) {
    throw std::invalid_argument(
        "CudaTrackBuilding: edge tensor must have shape (2, numEdges)");
  }

  const auto numSpacepoints = spacepointIDs.size();
  const auto numEdges = static_cast<std::size_t>(edgeTensor.size(1));

  if (numEdges == 0) {
    ACTS_WARNING("No edges remained after edge classification");
    return {};
  }

  // Edges reference node indices; with no nodes the label buffer would be
  // empty while the kernel still indexes into it.
  if (numSpacepoints == 0) {
    throw std::invalid_argument(
        "CudaTrackBuilding: edges given but no spacepoint IDs");
  }

  auto stream = execContext.stream->stream();

  auto cudaSrcPtr = edgeTensor.data_ptr<std::int64_t>();
  auto cudaTgtPtr = cudaSrcPtr + numEdges;

  int* cudaLabelsRaw = nullptr;
  ACTS_CUDA_CHECK(
      cudaMallocAsync(&cudaLabelsRaw, numSpacepoints * sizeof(int), stream));

  // Owns the label buffer so it is still freed (stream-ordered) if we leave
  // this scope early, e.g. via an exception. On the success path ownership is
  // released and the free is issued explicitly so its result is checked.
  auto freeOnStream = [stream](int* ptr) { (void)cudaFreeAsync(ptr, stream); };
  std::unique_ptr<int, decltype(freeOnStream)> cudaLabels(cudaLabelsRaw,
                                                          freeOnStream);

  std::size_t numberLabels = detail::connectedComponentsCuda(
      numEdges, cudaSrcPtr, cudaTgtPtr, numSpacepoints, cudaLabels.get(),
      stream);

  // NOTE(review): historical workaround. The reported label count presumably
  // came out one too small in some cases (TODO confirm against
  // connectedComponentsCuda). Kept so the number of returned candidates is
  // unchanged; the fill loop below no longer depends on it for memory safety.
  numberLabels += 1;

  std::vector<int> trackLabels(numSpacepoints);
  ACTS_CUDA_CHECK(cudaMemcpyAsync(trackLabels.data(), cudaLabels.get(),
                                  numSpacepoints * sizeof(int),
                                  cudaMemcpyDeviceToHost, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(cudaLabels.release(), stream));
  ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
  ACTS_CUDA_CHECK(cudaGetLastError());

  std::vector<std::vector<int>> trackCandidates(numberLabels);

  for (const auto [label, id] : Acts::zip(trackLabels, spacepointIDs)) {
    if (label < 0) {
      throw std::runtime_error(
          "CudaTrackBuilding: negative connected-component label");
    }
    const auto index = static_cast<std::size_t>(label);
    // Grow rather than write out of bounds if a label exceeds the count
    // reported above.
    if (index >= trackCandidates.size()) {
      trackCandidates.resize(index + 1);
    }
    trackCandidates[index].push_back(id);
  }

  ACTS_VERBOSE("Found " << trackCandidates.size() << " track candidates");

  return trackCandidates;
}

}  // namespace Acts
0071 }