File indexing completed on 2025-07-11 07:50:57
0001
0002
0003
0004
0005
0006
0007
0008
#include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"

#include "Acts/Utilities/Helpers.hpp"

#include <chrono>
#include <optional>
#include <stdexcept>
0012
0013 #ifdef ACTS_EXATRKX_WITH_CUDA
0014 #include "Acts/Plugins/ExaTrkX/detail/CudaUtils.hpp"
0015
0016 namespace {
0017 struct CudaStreamGuard {
0018 cudaStream_t stream{};
0019 CudaStreamGuard() { ACTS_CUDA_CHECK(cudaStreamCreate(&stream)); }
0020 ~CudaStreamGuard() {
0021 ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
0022 ACTS_CUDA_CHECK(cudaStreamDestroy(stream));
0023 }
0024 };
0025 }
0026 #endif
0027
0028 namespace Acts {
0029
0030 ExaTrkXPipeline::ExaTrkXPipeline(
0031 std::shared_ptr<GraphConstructionBase> graphConstructor,
0032 std::vector<std::shared_ptr<EdgeClassificationBase>> edgeClassifiers,
0033 std::shared_ptr<TrackBuildingBase> trackBuilder,
0034 std::unique_ptr<const Acts::Logger> logger)
0035 : m_logger(std::move(logger)),
0036 m_graphConstructor(std::move(graphConstructor)),
0037 m_edgeClassifiers(std::move(edgeClassifiers)),
0038 m_trackBuilder(std::move(trackBuilder)) {
0039 if (!m_graphConstructor) {
0040 throw std::invalid_argument("Missing graph construction module");
0041 }
0042 if (!m_trackBuilder) {
0043 throw std::invalid_argument("Missing track building module");
0044 }
0045 if (m_edgeClassifiers.empty() ||
0046 rangeContainsValue(m_edgeClassifiers, nullptr)) {
0047 throw std::invalid_argument("Missing graph construction module");
0048 }
0049 }
0050
0051 std::vector<std::vector<int>> ExaTrkXPipeline::run(
0052 std::vector<float> &features, const std::vector<std::uint64_t> &moduleIds,
0053 std::vector<int> &spacepointIDs, Acts::Device device,
0054 const ExaTrkXHook &hook, ExaTrkXTiming *timing) const {
0055 ExecutionContext ctx;
0056 ctx.device = device;
0057 #ifdef ACTS_EXATRKX_WITH_CUDA
0058 std::optional<CudaStreamGuard> streamGuard;
0059 if (ctx.device.type == Acts::Device::Type::eCUDA) {
0060 streamGuard.emplace();
0061 ctx.stream = streamGuard->stream;
0062 }
0063 #endif
0064
0065 try {
0066 auto t0 = std::chrono::high_resolution_clock::now();
0067 auto tensors =
0068 (*m_graphConstructor)(features, spacepointIDs.size(), moduleIds, ctx);
0069 auto t1 = std::chrono::high_resolution_clock::now();
0070
0071 if (timing != nullptr) {
0072 timing->graphBuildingTime = t1 - t0;
0073 }
0074
0075 hook(tensors, ctx);
0076
0077 if (timing != nullptr) {
0078 timing->classifierTimes.clear();
0079 }
0080
0081 for (const auto &edgeClassifier : m_edgeClassifiers) {
0082 t0 = std::chrono::high_resolution_clock::now();
0083 tensors = (*edgeClassifier)(std::move(tensors), ctx);
0084 t1 = std::chrono::high_resolution_clock::now();
0085
0086 if (timing != nullptr) {
0087 timing->classifierTimes.push_back(t1 - t0);
0088 }
0089
0090 hook(tensors, ctx);
0091 }
0092
0093 t0 = std::chrono::high_resolution_clock::now();
0094 auto res = (*m_trackBuilder)(std::move(tensors), spacepointIDs, ctx);
0095 t1 = std::chrono::high_resolution_clock::now();
0096
0097 if (timing != nullptr) {
0098 timing->trackBuildingTime = t1 - t0;
0099 }
0100
0101 return res;
0102 } catch (Acts::NoEdgesError &) {
0103 ACTS_DEBUG("No edges left in GNN pipeline, return 0 track candidates");
0104 if (timing != nullptr) {
0105 while (timing->classifierTimes.size() < m_edgeClassifiers.size()) {
0106 timing->classifierTimes.push_back({});
0107 }
0108 }
0109 return {};
0110 }
0111 }
0112
0113 }