Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 09:15:12

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
0010 
0011 #include "Acts/Utilities/Helpers.hpp"
0012 
0013 #include <algorithm>
0014 
0015 namespace Acts {
0016 
0017 ExaTrkXPipeline::ExaTrkXPipeline(
0018     std::shared_ptr<GraphConstructionBase> graphConstructor,
0019     std::vector<std::shared_ptr<EdgeClassificationBase>> edgeClassifiers,
0020     std::shared_ptr<TrackBuildingBase> trackBuilder,
0021     std::unique_ptr<const Acts::Logger> logger)
0022     : m_logger(std::move(logger)),
0023       m_graphConstructor(graphConstructor),
0024       m_edgeClassifiers(edgeClassifiers),
0025       m_trackBuilder(trackBuilder) {
0026   if (!m_graphConstructor) {
0027     throw std::invalid_argument("Missing graph construction module");
0028   }
0029   if (!m_trackBuilder) {
0030     throw std::invalid_argument("Missing track building module");
0031   }
0032   if (m_edgeClassifiers.empty() ||
0033       rangeContainsValue(m_edgeClassifiers, nullptr)) {
0034     throw std::invalid_argument("Missing graph construction module");
0035   }
0036 }
0037 
0038 std::vector<std::vector<int>> ExaTrkXPipeline::run(
0039     std::vector<float> &features, const std::vector<std::uint64_t> &moduleIds,
0040     std::vector<int> &spacepointIDs, const ExaTrkXHook &hook,
0041     ExaTrkXTiming *timing) const {
0042   ExecutionContext ctx;
0043   ctx.device = m_graphConstructor->device();
0044 #ifndef ACTS_EXATRKX_CPUONLY
0045   if (ctx.device.type() == torch::kCUDA) {
0046     ctx.stream = c10::cuda::getStreamFromPool(ctx.device.index());
0047   }
0048 #endif
0049 
0050   try {
0051     auto t0 = std::chrono::high_resolution_clock::now();
0052     auto [nodeFeatures, edgeIndex, edgeFeatures] =
0053         (*m_graphConstructor)(features, spacepointIDs.size(), moduleIds, ctx);
0054     auto t1 = std::chrono::high_resolution_clock::now();
0055 
0056     if (timing != nullptr) {
0057       timing->graphBuildingTime = t1 - t0;
0058     }
0059 
0060     hook(nodeFeatures, edgeIndex, {});
0061 
0062     std::any edgeScores;
0063     timing->classifierTimes.clear();
0064 
0065     for (auto edgeClassifier : m_edgeClassifiers) {
0066       t0 = std::chrono::high_resolution_clock::now();
0067       auto [newNodeFeatures, newEdgeIndex, newEdgeFeatures, newEdgeScores] =
0068           (*edgeClassifier)(std::move(nodeFeatures), std::move(edgeIndex),
0069                             std::move(edgeFeatures), ctx);
0070       t1 = std::chrono::high_resolution_clock::now();
0071 
0072       if (timing != nullptr) {
0073         timing->classifierTimes.push_back(t1 - t0);
0074       }
0075 
0076       nodeFeatures = std::move(newNodeFeatures);
0077       edgeFeatures = std::move(newEdgeFeatures);
0078       edgeIndex = std::move(newEdgeIndex);
0079       edgeScores = std::move(newEdgeScores);
0080 
0081       hook(nodeFeatures, edgeIndex, edgeScores);
0082     }
0083 
0084     t0 = std::chrono::high_resolution_clock::now();
0085     auto res = (*m_trackBuilder)(std::move(nodeFeatures), std::move(edgeIndex),
0086                                  std::move(edgeScores), spacepointIDs, ctx);
0087     t1 = std::chrono::high_resolution_clock::now();
0088 
0089     if (timing != nullptr) {
0090       timing->trackBuildingTime = t1 - t0;
0091     }
0092 
0093     return res;
0094   } catch (Acts::NoEdgesError &) {
0095     ACTS_WARNING("No egdges left in GNN pipeline, return 0 track candidates");
0096     if (timing != nullptr) {
0097       while (timing->classifierTimes.size() < m_edgeClassifiers.size()) {
0098         timing->classifierTimes.push_back({});
0099       }
0100     }
0101     return {};
0102   }
0103 }
0104 
0105 }  // namespace Acts