File indexing completed on 2025-12-16 09:24:26
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsPlugins/Gnn/GnnPipeline.hpp"
0010
0011 #include "Acts/Utilities/Helpers.hpp"
0012 #include "ActsPlugins/Gnn/detail/NvtxUtils.hpp"
0013
0014 #ifdef ACTS_GNN_WITH_CUDA
0015 #include "ActsPlugins/Gnn/detail/CudaUtils.hpp"
0016
0017 namespace {
0018 struct CudaStreamGuard {
0019 cudaStream_t stream{};
0020 CudaStreamGuard() { ACTS_CUDA_CHECK(cudaStreamCreate(&stream)); }
0021 ~CudaStreamGuard() {
0022 ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
0023 ACTS_CUDA_CHECK(cudaStreamDestroy(stream));
0024 }
0025 };
0026 }
0027 #endif
0028
0029 using namespace Acts;
0030
0031 namespace ActsPlugins {
0032
0033 GnnPipeline::GnnPipeline(
0034 std::shared_ptr<GraphConstructionBase> graphConstructor,
0035 std::vector<std::shared_ptr<EdgeClassificationBase>> edgeClassifiers,
0036 std::shared_ptr<TrackBuildingBase> trackBuilder,
0037 std::unique_ptr<const Logger> logger)
0038 : m_logger(std::move(logger)),
0039 m_graphConstructor(std::move(graphConstructor)),
0040 m_edgeClassifiers(std::move(edgeClassifiers)),
0041 m_trackBuilder(std::move(trackBuilder)) {
0042 if (!m_graphConstructor) {
0043 throw std::invalid_argument("Missing graph construction module");
0044 }
0045 if (!m_trackBuilder) {
0046 throw std::invalid_argument("Missing track building module");
0047 }
0048 if (m_edgeClassifiers.empty() ||
0049 rangeContainsValue(m_edgeClassifiers, nullptr)) {
0050 throw std::invalid_argument("Missing graph construction module");
0051 }
0052 }
0053
0054 std::vector<std::vector<int>> GnnPipeline::run(
0055 std::vector<float> &features, const std::vector<std::uint64_t> &moduleIds,
0056 std::vector<int> &spacepointIDs, Device device, const GnnHook &hook,
0057 GnnTiming *timing) const {
0058 ExecutionContext ctx;
0059 ctx.device = device;
0060 #ifdef ACTS_GNN_WITH_CUDA
0061 std::optional<CudaStreamGuard> streamGuard;
0062 if (ctx.device.type == Device::Type::eCUDA) {
0063 streamGuard.emplace();
0064 ctx.stream = streamGuard->stream;
0065 }
0066 #endif
0067
0068 try {
0069 auto t0 = std::chrono::high_resolution_clock::now();
0070 ACTS_NVTX_START(graph_construction);
0071 auto tensors =
0072 (*m_graphConstructor)(features, spacepointIDs.size(), moduleIds, ctx);
0073 ACTS_NVTX_STOP(graph_construction);
0074 auto t1 = std::chrono::high_resolution_clock::now();
0075
0076 if (timing != nullptr) {
0077 timing->graphBuildingTime = t1 - t0;
0078 }
0079
0080 hook(tensors, ctx);
0081
0082 if (timing != nullptr) {
0083 timing->classifierTimes.clear();
0084 }
0085
0086 for (const auto &edgeClassifier : m_edgeClassifiers) {
0087 t0 = std::chrono::high_resolution_clock::now();
0088 ACTS_NVTX_START(edge_classifier);
0089 tensors = (*edgeClassifier)(std::move(tensors), ctx);
0090 ACTS_NVTX_STOP(edge_classifier);
0091 t1 = std::chrono::high_resolution_clock::now();
0092
0093 if (timing != nullptr) {
0094 timing->classifierTimes.push_back(t1 - t0);
0095 }
0096
0097 hook(tensors, ctx);
0098 }
0099
0100 t0 = std::chrono::high_resolution_clock::now();
0101 ACTS_NVTX_START(track_building);
0102 auto res = (*m_trackBuilder)(std::move(tensors), spacepointIDs, ctx);
0103 ACTS_NVTX_STOP(track_building);
0104 t1 = std::chrono::high_resolution_clock::now();
0105
0106 if (timing != nullptr) {
0107 timing->trackBuildingTime = t1 - t0;
0108 }
0109
0110 return res;
0111 } catch (NoEdgesError &) {
0112 ACTS_DEBUG("No edges left in GNN pipeline, return 0 track candidates");
0113 if (timing != nullptr) {
0114 while (timing->classifierTimes.size() < m_edgeClassifiers.size()) {
0115 timing->classifierTimes.push_back({});
0116 }
0117 }
0118 return {};
0119 }
0120 }
0121
0122 }