Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-11 07:50:57

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
0010 
0011 #include "Acts/Utilities/Helpers.hpp"
0012 
0013 #ifdef ACTS_EXATRKX_WITH_CUDA
0014 #include "Acts/Plugins/ExaTrkX/detail/CudaUtils.hpp"
0015 
0016 namespace {
0017 struct CudaStreamGuard {
0018   cudaStream_t stream{};
0019   CudaStreamGuard() { ACTS_CUDA_CHECK(cudaStreamCreate(&stream)); }
0020   ~CudaStreamGuard() {
0021     ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
0022     ACTS_CUDA_CHECK(cudaStreamDestroy(stream));
0023   }
0024 };
0025 }  // namespace
0026 #endif
0027 
0028 namespace Acts {
0029 
0030 ExaTrkXPipeline::ExaTrkXPipeline(
0031     std::shared_ptr<GraphConstructionBase> graphConstructor,
0032     std::vector<std::shared_ptr<EdgeClassificationBase>> edgeClassifiers,
0033     std::shared_ptr<TrackBuildingBase> trackBuilder,
0034     std::unique_ptr<const Acts::Logger> logger)
0035     : m_logger(std::move(logger)),
0036       m_graphConstructor(std::move(graphConstructor)),
0037       m_edgeClassifiers(std::move(edgeClassifiers)),
0038       m_trackBuilder(std::move(trackBuilder)) {
0039   if (!m_graphConstructor) {
0040     throw std::invalid_argument("Missing graph construction module");
0041   }
0042   if (!m_trackBuilder) {
0043     throw std::invalid_argument("Missing track building module");
0044   }
0045   if (m_edgeClassifiers.empty() ||
0046       rangeContainsValue(m_edgeClassifiers, nullptr)) {
0047     throw std::invalid_argument("Missing graph construction module");
0048   }
0049 }
0050 
0051 std::vector<std::vector<int>> ExaTrkXPipeline::run(
0052     std::vector<float> &features, const std::vector<std::uint64_t> &moduleIds,
0053     std::vector<int> &spacepointIDs, Acts::Device device,
0054     const ExaTrkXHook &hook, ExaTrkXTiming *timing) const {
0055   ExecutionContext ctx;
0056   ctx.device = device;
0057 #ifdef ACTS_EXATRKX_WITH_CUDA
0058   std::optional<CudaStreamGuard> streamGuard;
0059   if (ctx.device.type == Acts::Device::Type::eCUDA) {
0060     streamGuard.emplace();
0061     ctx.stream = streamGuard->stream;
0062   }
0063 #endif
0064 
0065   try {
0066     auto t0 = std::chrono::high_resolution_clock::now();
0067     auto tensors =
0068         (*m_graphConstructor)(features, spacepointIDs.size(), moduleIds, ctx);
0069     auto t1 = std::chrono::high_resolution_clock::now();
0070 
0071     if (timing != nullptr) {
0072       timing->graphBuildingTime = t1 - t0;
0073     }
0074 
0075     hook(tensors, ctx);
0076 
0077     if (timing != nullptr) {
0078       timing->classifierTimes.clear();
0079     }
0080 
0081     for (const auto &edgeClassifier : m_edgeClassifiers) {
0082       t0 = std::chrono::high_resolution_clock::now();
0083       tensors = (*edgeClassifier)(std::move(tensors), ctx);
0084       t1 = std::chrono::high_resolution_clock::now();
0085 
0086       if (timing != nullptr) {
0087         timing->classifierTimes.push_back(t1 - t0);
0088       }
0089 
0090       hook(tensors, ctx);
0091     }
0092 
0093     t0 = std::chrono::high_resolution_clock::now();
0094     auto res = (*m_trackBuilder)(std::move(tensors), spacepointIDs, ctx);
0095     t1 = std::chrono::high_resolution_clock::now();
0096 
0097     if (timing != nullptr) {
0098       timing->trackBuildingTime = t1 - t0;
0099     }
0100 
0101     return res;
0102   } catch (Acts::NoEdgesError &) {
0103     ACTS_DEBUG("No edges left in GNN pipeline, return 0 track candidates");
0104     if (timing != nullptr) {
0105       while (timing->classifierTimes.size() < m_edgeClassifiers.size()) {
0106         timing->classifierTimes.push_back({});
0107       }
0108     }
0109     return {};
0110   }
0111 }
0112 
0113 }  // namespace Acts