File indexing completed on 2025-01-30 09:15:12
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
0010
0011 #include "Acts/Utilities/Helpers.hpp"
0012
0013 #include <algorithm>
0014
0015 namespace Acts {
0016
0017 ExaTrkXPipeline::ExaTrkXPipeline(
0018 std::shared_ptr<GraphConstructionBase> graphConstructor,
0019 std::vector<std::shared_ptr<EdgeClassificationBase>> edgeClassifiers,
0020 std::shared_ptr<TrackBuildingBase> trackBuilder,
0021 std::unique_ptr<const Acts::Logger> logger)
0022 : m_logger(std::move(logger)),
0023 m_graphConstructor(graphConstructor),
0024 m_edgeClassifiers(edgeClassifiers),
0025 m_trackBuilder(trackBuilder) {
0026 if (!m_graphConstructor) {
0027 throw std::invalid_argument("Missing graph construction module");
0028 }
0029 if (!m_trackBuilder) {
0030 throw std::invalid_argument("Missing track building module");
0031 }
0032 if (m_edgeClassifiers.empty() ||
0033 rangeContainsValue(m_edgeClassifiers, nullptr)) {
0034 throw std::invalid_argument("Missing graph construction module");
0035 }
0036 }
0037
0038 std::vector<std::vector<int>> ExaTrkXPipeline::run(
0039 std::vector<float> &features, const std::vector<std::uint64_t> &moduleIds,
0040 std::vector<int> &spacepointIDs, const ExaTrkXHook &hook,
0041 ExaTrkXTiming *timing) const {
0042 ExecutionContext ctx;
0043 ctx.device = m_graphConstructor->device();
0044 #ifndef ACTS_EXATRKX_CPUONLY
0045 if (ctx.device.type() == torch::kCUDA) {
0046 ctx.stream = c10::cuda::getStreamFromPool(ctx.device.index());
0047 }
0048 #endif
0049
0050 try {
0051 auto t0 = std::chrono::high_resolution_clock::now();
0052 auto [nodeFeatures, edgeIndex, edgeFeatures] =
0053 (*m_graphConstructor)(features, spacepointIDs.size(), moduleIds, ctx);
0054 auto t1 = std::chrono::high_resolution_clock::now();
0055
0056 if (timing != nullptr) {
0057 timing->graphBuildingTime = t1 - t0;
0058 }
0059
0060 hook(nodeFeatures, edgeIndex, {});
0061
0062 std::any edgeScores;
0063 timing->classifierTimes.clear();
0064
0065 for (auto edgeClassifier : m_edgeClassifiers) {
0066 t0 = std::chrono::high_resolution_clock::now();
0067 auto [newNodeFeatures, newEdgeIndex, newEdgeFeatures, newEdgeScores] =
0068 (*edgeClassifier)(std::move(nodeFeatures), std::move(edgeIndex),
0069 std::move(edgeFeatures), ctx);
0070 t1 = std::chrono::high_resolution_clock::now();
0071
0072 if (timing != nullptr) {
0073 timing->classifierTimes.push_back(t1 - t0);
0074 }
0075
0076 nodeFeatures = std::move(newNodeFeatures);
0077 edgeFeatures = std::move(newEdgeFeatures);
0078 edgeIndex = std::move(newEdgeIndex);
0079 edgeScores = std::move(newEdgeScores);
0080
0081 hook(nodeFeatures, edgeIndex, edgeScores);
0082 }
0083
0084 t0 = std::chrono::high_resolution_clock::now();
0085 auto res = (*m_trackBuilder)(std::move(nodeFeatures), std::move(edgeIndex),
0086 std::move(edgeScores), spacepointIDs, ctx);
0087 t1 = std::chrono::high_resolution_clock::now();
0088
0089 if (timing != nullptr) {
0090 timing->trackBuildingTime = t1 - t0;
0091 }
0092
0093 return res;
0094 } catch (Acts::NoEdgesError &) {
0095 ACTS_WARNING("No egdges left in GNN pipeline, return 0 track candidates");
0096 if (timing != nullptr) {
0097 while (timing->classifierTimes.size() < m_edgeClassifiers.size()) {
0098 timing->classifierTimes.push_back({});
0099 }
0100 }
0101 return {};
0102 }
0103 }
0104
0105 }