Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:23:35

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsExamples/TrackFindingGnn/TrackFindingAlgorithmGnn.hpp"
0010 
0011 #include "Acts/Definitions/Units.hpp"
0012 #include "Acts/Utilities/Helpers.hpp"
0013 #include "Acts/Utilities/Zip.hpp"
0014 #include "ActsExamples/EventData/Index.hpp"
0015 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0016 #include "ActsExamples/EventData/ProtoTrack.hpp"
0017 #include "ActsExamples/EventData/SimSpacePoint.hpp"
0018 #include "ActsExamples/Framework/WhiteBoard.hpp"
0019 #include "ActsPlugins/Gnn/GraphStoreHook.hpp"
0020 #include "ActsPlugins/Gnn/TruthGraphMetricsHook.hpp"
0021 #include "ActsPlugins/Gnn/detail/NvtxUtils.hpp"
0022 
0023 #include <algorithm>
0024 #include <chrono>
0025 #include <numeric>
0026 
0027 #include "createFeatures.hpp"
0028 
0029 #ifdef ACTS_GNN_WITH_CUDA
0030 #include <cuda_runtime_api.h>
0031 #endif
0032 
0033 using namespace Acts;
0034 using namespace ActsPlugins;
0035 using namespace Acts::UnitLiterals;
0036 
0037 namespace {
0038 
0039 struct LoopHook : public GnnHook {
0040   std::vector<GnnHook*> hooks;
0041 
0042   void operator()(const PipelineTensors& tensors,
0043                   const ExecutionContext& ctx) const override {
0044     for (auto hook : hooks) {
0045       (*hook)(tensors, ctx);
0046     }
0047   }
0048 };
0049 
0050 }  // namespace
0051 
0052 ActsExamples::TrackFindingAlgorithmGnn::TrackFindingAlgorithmGnn(
0053     Config config, Logging::Level level)
0054     : ActsExamples::IAlgorithm("TrackFindingMLBasedAlgorithm", level),
0055       m_cfg(std::move(config)),
0056       m_pipeline(m_cfg.graphConstructor, m_cfg.edgeClassifiers,
0057                  m_cfg.trackBuilder, logger().clone()) {
0058   if (m_cfg.inputSpacePoints.empty()) {
0059     throw std::invalid_argument("Missing spacepoint input collection");
0060   }
0061   if (m_cfg.outputProtoTracks.empty()) {
0062     throw std::invalid_argument("Missing protoTrack output collection");
0063   }
0064 
0065   m_inputSpacePoints.initialize(m_cfg.inputSpacePoints);
0066   m_inputClusters.maybeInitialize(m_cfg.inputClusters);
0067   m_outputProtoTracks.initialize(m_cfg.outputProtoTracks);
0068 
0069   m_inputTruthGraph.maybeInitialize(m_cfg.inputTruthGraph);
0070   m_outputGraph.maybeInitialize(m_cfg.outputGraph);
0071 
0072   // reserve space for timing
0073   m_timing.classifierTimes.resize(
0074       m_cfg.edgeClassifiers.size(),
0075       decltype(m_timing.classifierTimes)::value_type{0.f});
0076 
0077   // Check if we want cluster features but do not have them
0078   const static std::array clFeatures = {
0079       NodeFeature::eClusterLoc0, NodeFeature::eClusterLoc0,
0080       NodeFeature::eCellCount,   NodeFeature::eChargeSum,
0081       NodeFeature::eCluster1R,   NodeFeature::eCluster2R};
0082 
0083   auto wantClFeatures = std::ranges::any_of(
0084       m_cfg.nodeFeatures,
0085       [&](const auto& f) { return rangeContainsValue(clFeatures, f); });
0086 
0087   if (wantClFeatures && !m_inputClusters.isInitialized()) {
0088     throw std::invalid_argument("Cluster features requested, but not provided");
0089   }
0090 
0091   if (m_cfg.nodeFeatures.size() != m_cfg.featureScales.size()) {
0092     throw std::invalid_argument(
0093         "Number of features mismatches number of scale parameters.");
0094   }
0095 
0096 // Reset error state to prevent failure due to previous runs
0097 #ifdef ACTS_GNN_WITH_CUDA
0098   cudaGetLastError();
0099 #endif
0100 }
0101 
0102 /// Allow access to features with nice names
0103 
0104 ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmGnn::execute(
0105     const ActsExamples::AlgorithmContext& ctx) const {
0106   ACTS_NVTX_START(data_preparation);
0107 
0108   using Clock = std::chrono::high_resolution_clock;
0109   using Duration = std::chrono::duration<double, std::milli>;
0110   auto t0 = Clock::now();
0111 
0112   // Setup hooks
0113   LoopHook hook;
0114 
0115   std::unique_ptr<TruthGraphMetricsHook> truthGraphHook;
0116   if (m_inputTruthGraph.isInitialized()) {
0117     truthGraphHook = std::make_unique<TruthGraphMetricsHook>(
0118         m_inputTruthGraph(ctx).edges, this->logger().clone());
0119     hook.hooks.push_back(&*truthGraphHook);
0120   }
0121 
0122   std::unique_ptr<GraphStoreHook> graphStoreHook;
0123   if (m_outputGraph.isInitialized()) {
0124     graphStoreHook = std::make_unique<GraphStoreHook>();
0125     hook.hooks.push_back(&*graphStoreHook);
0126   }
0127 
0128   // Read input data
0129   const auto& spacepoints = m_inputSpacePoints(ctx);
0130 
0131   const ClusterContainer* clusters = nullptr;
0132   if (m_inputClusters.isInitialized()) {
0133     clusters = &m_inputClusters(ctx);
0134   }
0135 
0136   // Convert Input data to a list of size [num_measurements x
0137   // measurement_features]
0138   const std::size_t numSpacepoints = spacepoints.size();
0139   const std::size_t numFeatures = m_cfg.nodeFeatures.size();
0140   ACTS_DEBUG("Received " << numSpacepoints << " spacepoints");
0141   ACTS_DEBUG("Construct " << numFeatures << " node features");
0142 
0143   auto t01 = Clock::now();
0144 
0145   std::vector<std::uint64_t> moduleIds;
0146   moduleIds.reserve(spacepoints.size());
0147 
0148   for (auto isp = 0ul; isp < numSpacepoints; ++isp) {
0149     const auto& sp = spacepoints[isp];
0150 
0151     // For now just take the first index since does require one single index
0152     // per spacepoint
0153     // TODO does it work for the module map construction to use only the first
0154     // sp?
0155     const auto& sl1 = sp.sourceLinks().at(0).template get<IndexSourceLink>();
0156 
0157     if (m_cfg.geometryIdMap != nullptr) {
0158       moduleIds.push_back(m_cfg.geometryIdMap->right.at(sl1.geometryId()));
0159     } else {
0160       moduleIds.push_back(sl1.geometryId().value());
0161     }
0162   }
0163 
0164   auto t02 = Clock::now();
0165 
0166   // Sort the spacepoints by module ide. Required by module map
0167   std::vector<int> idxs(numSpacepoints);
0168   std::iota(idxs.begin(), idxs.end(), 0);
0169   std::ranges::sort(idxs, {}, [&](auto i) { return moduleIds[i]; });
0170 
0171   std::ranges::sort(moduleIds);
0172 
0173   SimSpacePointContainer sortedSpacepoints;
0174   sortedSpacepoints.reserve(spacepoints.size());
0175   std::ranges::transform(idxs, std::back_inserter(sortedSpacepoints),
0176                          [&](auto i) { return spacepoints[i]; });
0177 
0178   auto t03 = Clock::now();
0179 
0180   auto features = createFeatures(sortedSpacepoints, clusters,
0181                                  m_cfg.nodeFeatures, m_cfg.featureScales);
0182 
0183   auto t1 = Clock::now();
0184 
0185   auto ms = [](auto a, auto b) {
0186     return std::chrono::duration<double, std::milli>(b - a).count();
0187   };
0188   ACTS_DEBUG("Setup time:              " << ms(t0, t01));
0189   ACTS_DEBUG("ModuleId mapping & copy: " << ms(t01, t02));
0190   ACTS_DEBUG("Spacepoint sort:         " << ms(t02, t03));
0191   ACTS_DEBUG("Feature creation:        " << ms(t03, t1));
0192 
0193   // Run the pipeline
0194   ACTS_NVTX_STOP(data_preparation);
0195   GnnTiming timing;
0196 #ifdef ACTS_GNN_CPUONLY
0197   Device device = {Device::Type::eCPU, 0};
0198 #else
0199   Device device = {Device::Type::eCUDA, 0};
0200 #endif
0201   auto trackCandidates =
0202       m_pipeline.run(features, moduleIds, idxs, device, hook, &timing);
0203   ACTS_NVTX_START(post_processing);
0204 
0205   auto t2 = Clock::now();
0206 
0207   ACTS_DEBUG("Done with pipeline, received " << trackCandidates.size()
0208                                              << " candidates");
0209 
0210   // Make the prototracks
0211   std::vector<ProtoTrack> protoTracks;
0212   protoTracks.reserve(trackCandidates.size());
0213 
0214   int nShortTracks = 0;
0215 
0216   /// TODO the whole conversion back to meas idxs should be pulled out of the
0217   /// track trackBuilder
0218   for (auto& candidate : trackCandidates) {
0219     ProtoTrack onetrack;
0220     onetrack.reserve(candidate.size());
0221 
0222     for (auto i : candidate) {
0223       for (const auto& sl : spacepoints.at(i).sourceLinks()) {
0224         onetrack.push_back(sl.template get<IndexSourceLink>().index());
0225       }
0226     }
0227 
0228     if (onetrack.size() < m_cfg.minMeasurementsPerTrack) {
0229       nShortTracks++;
0230       continue;
0231     }
0232 
0233     protoTracks.push_back(std::move(onetrack));
0234   }
0235 
0236   ACTS_DEBUG("Removed " << nShortTracks << " with less then "
0237                         << m_cfg.minMeasurementsPerTrack << " hits");
0238   ACTS_DEBUG("Created " << protoTracks.size() << " proto tracks");
0239 
0240   m_outputProtoTracks(ctx, std::move(protoTracks));
0241 
0242   if (m_outputGraph.isInitialized()) {
0243     auto graph = graphStoreHook->storedGraph();
0244     std::transform(graph.first.begin(), graph.first.end(), graph.first.begin(),
0245                    [&](const auto& a) -> std::int64_t { return idxs.at(a); });
0246     m_outputGraph(ctx, {graph.first, graph.second});
0247   }
0248 
0249   auto t3 = Clock::now();
0250 
0251   {
0252     std::lock_guard<std::mutex> lock(m_mutex);
0253 
0254     m_timing.preprocessingTime(Duration(t1 - t0).count());
0255     m_timing.graphBuildingTime(timing.graphBuildingTime.count());
0256 
0257     assert(timing.classifierTimes.size() == m_timing.classifierTimes.size());
0258     for (auto [aggr, a] :
0259          zip(m_timing.classifierTimes, timing.classifierTimes)) {
0260       aggr(a.count());
0261     }
0262 
0263     m_timing.trackBuildingTime(timing.trackBuildingTime.count());
0264     m_timing.postprocessingTime(Duration(t3 - t2).count());
0265     m_timing.fullTime(Duration(t3 - t0).count());
0266   }
0267 
0268   ACTS_NVTX_STOP(post_processing);
0269   return ActsExamples::ProcessCode::SUCCESS;
0270 }
0271 
0272 ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmGnn::finalize() {
0273   namespace ba = boost::accumulators;
0274 
0275   auto print = [](const auto& t) {
0276     std::stringstream ss;
0277     ss << ba::mean(t) << " +- " << std::sqrt(ba::variance(t)) << " ";
0278     ss << "[" << ba::min(t) << ", " << ba::max(t) << "]";
0279     return ss.str();
0280   };
0281 
0282   ACTS_INFO("GNN timing info");
0283   ACTS_INFO("- preprocessing:  " << print(m_timing.preprocessingTime));
0284   ACTS_INFO("- graph building: " << print(m_timing.graphBuildingTime));
0285   // clang-format off
0286   for (const auto& t : m_timing.classifierTimes) {
0287   ACTS_INFO("- classifier:     " << print(t));
0288   }
0289   // clang-format on
0290   ACTS_INFO("- track building: " << print(m_timing.trackBuildingTime));
0291   ACTS_INFO("- postprocessing: " << print(m_timing.postprocessingTime));
0292   ACTS_INFO("- full timing:    " << print(m_timing.fullTime));
0293 
0294   return {};
0295 }