Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:24:26

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsPlugins/Gnn/GnnPipeline.hpp"
0010 
0011 #include "Acts/Utilities/Helpers.hpp"
0012 #include "ActsPlugins/Gnn/detail/NvtxUtils.hpp"
0013 
0014 #ifdef ACTS_GNN_WITH_CUDA
0015 #include "ActsPlugins/Gnn/detail/CudaUtils.hpp"
0016 
0017 namespace {
0018 struct CudaStreamGuard {
0019   cudaStream_t stream{};
0020   CudaStreamGuard() { ACTS_CUDA_CHECK(cudaStreamCreate(&stream)); }
0021   ~CudaStreamGuard() {
0022     ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
0023     ACTS_CUDA_CHECK(cudaStreamDestroy(stream));
0024   }
0025 };
0026 }  // namespace
0027 #endif
0028 
0029 using namespace Acts;
0030 
0031 namespace ActsPlugins {
0032 
0033 GnnPipeline::GnnPipeline(
0034     std::shared_ptr<GraphConstructionBase> graphConstructor,
0035     std::vector<std::shared_ptr<EdgeClassificationBase>> edgeClassifiers,
0036     std::shared_ptr<TrackBuildingBase> trackBuilder,
0037     std::unique_ptr<const Logger> logger)
0038     : m_logger(std::move(logger)),
0039       m_graphConstructor(std::move(graphConstructor)),
0040       m_edgeClassifiers(std::move(edgeClassifiers)),
0041       m_trackBuilder(std::move(trackBuilder)) {
0042   if (!m_graphConstructor) {
0043     throw std::invalid_argument("Missing graph construction module");
0044   }
0045   if (!m_trackBuilder) {
0046     throw std::invalid_argument("Missing track building module");
0047   }
0048   if (m_edgeClassifiers.empty() ||
0049       rangeContainsValue(m_edgeClassifiers, nullptr)) {
0050     throw std::invalid_argument("Missing graph construction module");
0051   }
0052 }
0053 
0054 std::vector<std::vector<int>> GnnPipeline::run(
0055     std::vector<float> &features, const std::vector<std::uint64_t> &moduleIds,
0056     std::vector<int> &spacepointIDs, Device device, const GnnHook &hook,
0057     GnnTiming *timing) const {
0058   ExecutionContext ctx;
0059   ctx.device = device;
0060 #ifdef ACTS_GNN_WITH_CUDA
0061   std::optional<CudaStreamGuard> streamGuard;
0062   if (ctx.device.type == Device::Type::eCUDA) {
0063     streamGuard.emplace();
0064     ctx.stream = streamGuard->stream;
0065   }
0066 #endif
0067 
0068   try {
0069     auto t0 = std::chrono::high_resolution_clock::now();
0070     ACTS_NVTX_START(graph_construction);
0071     auto tensors =
0072         (*m_graphConstructor)(features, spacepointIDs.size(), moduleIds, ctx);
0073     ACTS_NVTX_STOP(graph_construction);
0074     auto t1 = std::chrono::high_resolution_clock::now();
0075 
0076     if (timing != nullptr) {
0077       timing->graphBuildingTime = t1 - t0;
0078     }
0079 
0080     hook(tensors, ctx);
0081 
0082     if (timing != nullptr) {
0083       timing->classifierTimes.clear();
0084     }
0085 
0086     for (const auto &edgeClassifier : m_edgeClassifiers) {
0087       t0 = std::chrono::high_resolution_clock::now();
0088       ACTS_NVTX_START(edge_classifier);
0089       tensors = (*edgeClassifier)(std::move(tensors), ctx);
0090       ACTS_NVTX_STOP(edge_classifier);
0091       t1 = std::chrono::high_resolution_clock::now();
0092 
0093       if (timing != nullptr) {
0094         timing->classifierTimes.push_back(t1 - t0);
0095       }
0096 
0097       hook(tensors, ctx);
0098     }
0099 
0100     t0 = std::chrono::high_resolution_clock::now();
0101     ACTS_NVTX_START(track_building);
0102     auto res = (*m_trackBuilder)(std::move(tensors), spacepointIDs, ctx);
0103     ACTS_NVTX_STOP(track_building);
0104     t1 = std::chrono::high_resolution_clock::now();
0105 
0106     if (timing != nullptr) {
0107       timing->trackBuildingTime = t1 - t0;
0108     }
0109 
0110     return res;
0111   } catch (NoEdgesError &) {
0112     ACTS_DEBUG("No edges left in GNN pipeline, return 0 track candidates");
0113     if (timing != nullptr) {
0114       while (timing->classifierTimes.size() < m_edgeClassifiers.size()) {
0115         timing->classifierTimes.push_back({});
0116       }
0117     }
0118     return {};
0119   }
0120 }
0121 
0122 }  // namespace ActsPlugins