Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-08-28 08:12:14

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsExamples/TrackFindingGnn/TrackFindingAlgorithmGnn.hpp"
0010 
0011 #include "Acts/Definitions/Units.hpp"
0012 #include "Acts/Plugins/Gnn/GraphStoreHook.hpp"
0013 #include "Acts/Plugins/Gnn/TruthGraphMetricsHook.hpp"
0014 #include "Acts/Utilities/Helpers.hpp"
0015 #include "Acts/Utilities/Zip.hpp"
0016 #include "ActsExamples/EventData/Index.hpp"
0017 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0018 #include "ActsExamples/EventData/ProtoTrack.hpp"
0019 #include "ActsExamples/EventData/SimSpacePoint.hpp"
0020 #include "ActsExamples/Framework/WhiteBoard.hpp"
0021 
0022 #include <algorithm>
0023 #include <chrono>
0024 #include <numeric>
0025 
0026 #include "createFeatures.hpp"
0027 
0028 using namespace ActsExamples;
0029 using namespace Acts::UnitLiterals;
0030 
0031 namespace {
0032 
0033 struct LoopHook : public Acts::GnnHook {
0034   std::vector<Acts::GnnHook*> hooks;
0035 
0036   void operator()(const Acts::PipelineTensors& tensors,
0037                   const Acts::ExecutionContext& ctx) const override {
0038     for (auto hook : hooks) {
0039       (*hook)(tensors, ctx);
0040     }
0041   }
0042 };
0043 
0044 }  // namespace
0045 
0046 ActsExamples::TrackFindingAlgorithmGnn::TrackFindingAlgorithmGnn(
0047     Config config, Acts::Logging::Level level)
0048     : ActsExamples::IAlgorithm("TrackFindingMLBasedAlgorithm", level),
0049       m_cfg(std::move(config)),
0050       m_pipeline(m_cfg.graphConstructor, m_cfg.edgeClassifiers,
0051                  m_cfg.trackBuilder, logger().clone()) {
0052   if (m_cfg.inputSpacePoints.empty()) {
0053     throw std::invalid_argument("Missing spacepoint input collection");
0054   }
0055   if (m_cfg.outputProtoTracks.empty()) {
0056     throw std::invalid_argument("Missing protoTrack output collection");
0057   }
0058 
0059   m_inputSpacePoints.initialize(m_cfg.inputSpacePoints);
0060   m_inputClusters.maybeInitialize(m_cfg.inputClusters);
0061   m_outputProtoTracks.initialize(m_cfg.outputProtoTracks);
0062 
0063   m_inputTruthGraph.maybeInitialize(m_cfg.inputTruthGraph);
0064   m_outputGraph.maybeInitialize(m_cfg.outputGraph);
0065 
0066   // reserve space for timing
0067   m_timing.classifierTimes.resize(
0068       m_cfg.edgeClassifiers.size(),
0069       decltype(m_timing.classifierTimes)::value_type{0.f});
0070 
0071   // Check if we want cluster features but do not have them
0072   const static std::array clFeatures = {
0073       NodeFeature::eClusterLoc0, NodeFeature::eClusterLoc0,
0074       NodeFeature::eCellCount,   NodeFeature::eChargeSum,
0075       NodeFeature::eCluster1R,   NodeFeature::eCluster2R};
0076 
0077   auto wantClFeatures = std::ranges::any_of(
0078       m_cfg.nodeFeatures,
0079       [&](const auto& f) { return Acts::rangeContainsValue(clFeatures, f); });
0080 
0081   if (wantClFeatures && !m_inputClusters.isInitialized()) {
0082     throw std::invalid_argument("Cluster features requested, but not provided");
0083   }
0084 
0085   if (m_cfg.nodeFeatures.size() != m_cfg.featureScales.size()) {
0086     throw std::invalid_argument(
0087         "Number of features mismatches number of scale parameters.");
0088   }
0089 }
0090 
0091 /// Allow access to features with nice names
0092 
0093 ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmGnn::execute(
0094     const ActsExamples::AlgorithmContext& ctx) const {
0095   using Clock = std::chrono::high_resolution_clock;
0096   using Duration = std::chrono::duration<double, std::milli>;
0097   auto t0 = Clock::now();
0098 
0099   // Setup hooks
0100   LoopHook hook;
0101 
0102   std::unique_ptr<Acts::TruthGraphMetricsHook> truthGraphHook;
0103   if (m_inputTruthGraph.isInitialized()) {
0104     truthGraphHook = std::make_unique<Acts::TruthGraphMetricsHook>(
0105         m_inputTruthGraph(ctx).edges, this->logger().clone());
0106     hook.hooks.push_back(&*truthGraphHook);
0107   }
0108 
0109   std::unique_ptr<Acts::GraphStoreHook> graphStoreHook;
0110   if (m_outputGraph.isInitialized()) {
0111     graphStoreHook = std::make_unique<Acts::GraphStoreHook>();
0112     hook.hooks.push_back(&*graphStoreHook);
0113   }
0114 
0115   // Read input data
0116   const auto& spacepoints = m_inputSpacePoints(ctx);
0117 
0118   const ClusterContainer* clusters = nullptr;
0119   if (m_inputClusters.isInitialized()) {
0120     clusters = &m_inputClusters(ctx);
0121   }
0122 
0123   // Convert Input data to a list of size [num_measurements x
0124   // measurement_features]
0125   const std::size_t numSpacepoints = spacepoints.size();
0126   const std::size_t numFeatures = m_cfg.nodeFeatures.size();
0127   ACTS_DEBUG("Received " << numSpacepoints << " spacepoints");
0128   ACTS_DEBUG("Construct " << numFeatures << " node features");
0129 
0130   auto t01 = Clock::now();
0131 
0132   std::vector<std::uint64_t> moduleIds;
0133   moduleIds.reserve(spacepoints.size());
0134 
0135   for (auto isp = 0ul; isp < numSpacepoints; ++isp) {
0136     const auto& sp = spacepoints[isp];
0137 
0138     // For now just take the first index since does require one single index
0139     // per spacepoint
0140     // TODO does it work for the module map construction to use only the first
0141     // sp?
0142     const auto& sl1 = sp.sourceLinks().at(0).template get<IndexSourceLink>();
0143 
0144     if (m_cfg.geometryIdMap != nullptr) {
0145       moduleIds.push_back(m_cfg.geometryIdMap->right.at(sl1.geometryId()));
0146     } else {
0147       moduleIds.push_back(sl1.geometryId().value());
0148     }
0149   }
0150 
0151   auto t02 = Clock::now();
0152 
0153   // Sort the spacepoints by module ide. Required by module map
0154   std::vector<int> idxs(numSpacepoints);
0155   std::iota(idxs.begin(), idxs.end(), 0);
0156   std::ranges::sort(idxs, {}, [&](auto i) { return moduleIds[i]; });
0157 
0158   std::ranges::sort(moduleIds);
0159 
0160   SimSpacePointContainer sortedSpacepoints;
0161   sortedSpacepoints.reserve(spacepoints.size());
0162   std::ranges::transform(idxs, std::back_inserter(sortedSpacepoints),
0163                          [&](auto i) { return spacepoints[i]; });
0164 
0165   auto t03 = Clock::now();
0166 
0167   auto features = createFeatures(sortedSpacepoints, clusters,
0168                                  m_cfg.nodeFeatures, m_cfg.featureScales);
0169 
0170   auto t1 = Clock::now();
0171 
0172   auto ms = [](auto a, auto b) {
0173     return std::chrono::duration<double, std::milli>(b - a).count();
0174   };
0175   ACTS_DEBUG("Setup time:              " << ms(t0, t01));
0176   ACTS_DEBUG("ModuleId mapping & copy: " << ms(t01, t02));
0177   ACTS_DEBUG("Spacepoint sort:         " << ms(t02, t03));
0178   ACTS_DEBUG("Feature creation:        " << ms(t03, t1));
0179 
0180   // Run the pipeline
0181   Acts::GnnTiming timing;
0182 #ifdef ACTS_GNN_CPUONLY
0183   Acts::Device device = {Acts::Device::Type::eCPU, 0};
0184 #else
0185   Acts::Device device = {Acts::Device::Type::eCUDA, 0};
0186 #endif
0187   auto trackCandidates =
0188       m_pipeline.run(features, moduleIds, idxs, device, hook, &timing);
0189 
0190   auto t2 = Clock::now();
0191 
0192   ACTS_DEBUG("Done with pipeline, received " << trackCandidates.size()
0193                                              << " candidates");
0194 
0195   // Make the prototracks
0196   std::vector<ProtoTrack> protoTracks;
0197   protoTracks.reserve(trackCandidates.size());
0198 
0199   int nShortTracks = 0;
0200 
0201   /// TODO the whole conversion back to meas idxs should be pulled out of the
0202   /// track trackBuilder
0203   for (auto& candidate : trackCandidates) {
0204     ProtoTrack onetrack;
0205     onetrack.reserve(candidate.size());
0206 
0207     for (auto i : candidate) {
0208       for (const auto& sl : spacepoints.at(i).sourceLinks()) {
0209         onetrack.push_back(sl.template get<IndexSourceLink>().index());
0210       }
0211     }
0212 
0213     if (onetrack.size() < m_cfg.minMeasurementsPerTrack) {
0214       nShortTracks++;
0215       continue;
0216     }
0217 
0218     protoTracks.push_back(std::move(onetrack));
0219   }
0220 
0221   ACTS_DEBUG("Removed " << nShortTracks << " with less then "
0222                         << m_cfg.minMeasurementsPerTrack << " hits");
0223   ACTS_DEBUG("Created " << protoTracks.size() << " proto tracks");
0224 
0225   m_outputProtoTracks(ctx, std::move(protoTracks));
0226 
0227   if (m_outputGraph.isInitialized()) {
0228     auto graph = graphStoreHook->storedGraph();
0229     std::transform(graph.first.begin(), graph.first.end(), graph.first.begin(),
0230                    [&](const auto& a) -> std::int64_t { return idxs.at(a); });
0231     m_outputGraph(ctx, {graph.first, graph.second});
0232   }
0233 
0234   auto t3 = Clock::now();
0235 
0236   {
0237     std::lock_guard<std::mutex> lock(m_mutex);
0238 
0239     m_timing.preprocessingTime(Duration(t1 - t0).count());
0240     m_timing.graphBuildingTime(timing.graphBuildingTime.count());
0241 
0242     assert(timing.classifierTimes.size() == m_timing.classifierTimes.size());
0243     for (auto [aggr, a] :
0244          Acts::zip(m_timing.classifierTimes, timing.classifierTimes)) {
0245       aggr(a.count());
0246     }
0247 
0248     m_timing.trackBuildingTime(timing.trackBuildingTime.count());
0249     m_timing.postprocessingTime(Duration(t3 - t2).count());
0250     m_timing.fullTime(Duration(t3 - t0).count());
0251   }
0252 
0253   return ActsExamples::ProcessCode::SUCCESS;
0254 }
0255 
0256 ActsExamples::ProcessCode TrackFindingAlgorithmGnn::finalize() {
0257   namespace ba = boost::accumulators;
0258 
0259   auto print = [](const auto& t) {
0260     std::stringstream ss;
0261     ss << ba::mean(t) << " +- " << std::sqrt(ba::variance(t)) << " ";
0262     ss << "[" << ba::min(t) << ", " << ba::max(t) << "]";
0263     return ss.str();
0264   };
0265 
0266   ACTS_INFO("GNN timing info");
0267   ACTS_INFO("- preprocessing:  " << print(m_timing.preprocessingTime));
0268   ACTS_INFO("- graph building: " << print(m_timing.graphBuildingTime));
0269   // clang-format off
0270   for (const auto& t : m_timing.classifierTimes) {
0271   ACTS_INFO("- classifier:     " << print(t));
0272   }
0273   // clang-format on
0274   ACTS_INFO("- track building: " << print(m_timing.trackBuildingTime));
0275   ACTS_INFO("- postprocessing: " << print(m_timing.postprocessingTime));
0276   ACTS_INFO("- full timing:    " << print(m_timing.fullTime));
0277 
0278   return {};
0279 }