Back to home page

EIC code displayed by LXR

 
 

    


Warning, file /acts/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp was not indexed or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include <Acts/Definitions/Units.hpp>
0010 #include <ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp>
0011 
0012 #include <algorithm>
0013 
0014 using namespace Acts::UnitLiterals;
0015 
0016 namespace ActsExamples {
0017 
0018 TruthGraphBuilder::TruthGraphBuilder(Config config, Acts::Logging::Level level)
0019     : ActsExamples::IAlgorithm("TruthGraphBuilder", level),
0020       m_cfg(std::move(config)) {
0021   m_inputSpacePoints.initialize(m_cfg.inputSpacePoints);
0022   m_inputParticles.initialize(m_cfg.inputParticles);
0023   m_outputGraph.initialize(m_cfg.outputGraph);
0024 
0025   m_inputMeasParticlesMap.maybeInitialize(m_cfg.inputMeasurementParticlesMap);
0026   m_inputSimhits.maybeInitialize(m_cfg.inputSimHits);
0027   m_inputMeasSimhitMap.maybeInitialize(m_cfg.inputMeasurementSimHitsMap);
0028 
0029   bool a = m_inputMeasParticlesMap.isInitialized();
0030   bool b =
0031       m_inputSimhits.isInitialized() && m_inputMeasSimhitMap.isInitialized();
0032 
0033   // Logical XOR operation
0034   if (!a != !b) {
0035     throw std::invalid_argument("Missing inputs, cannot build truth graph");
0036   }
0037 }
0038 
0039 std::vector<std::int64_t> TruthGraphBuilder::buildFromMeasurements(
0040     const SimSpacePointContainer& spacepoints,
0041     const SimParticleContainer& particles,
0042     const IndexMultimap<ActsFatras::Barcode>& measPartMap) const {
0043   if (m_cfg.targetMinPT < 500_MeV) {
0044     ACTS_WARNING(
0045         "truth graph building based on distance from origin, this breaks down "
0046         "for low pT particles. Consider using a higher target pT value");
0047   }
0048 
0049   // Associate tracks to graph, collect momentum
0050   std::unordered_map<ActsFatras::Barcode, std::vector<std::size_t>> tracks;
0051 
0052   for (auto i = 0ul; i < spacepoints.size(); ++i) {
0053     const auto measId =
0054         spacepoints[i].sourceLinks()[0].template get<IndexSourceLink>().index();
0055 
0056     auto [a, b] = measPartMap.equal_range(measId);
0057     for (auto it = a; it != b; ++it) {
0058       tracks[it->second].push_back(i);
0059     }
0060   }
0061 
0062   // Collect edges for truth graph and target graph
0063   std::vector<std::int64_t> graph;
0064   std::size_t notFoundParticles = 0;
0065   std::size_t moduleDuplicatesRemoved = 0;
0066 
0067   for (auto& [pid, track] : tracks) {
0068     auto found = particles.find(pid);
0069     if (found == particles.end()) {
0070       ACTS_VERBOSE("Did not find " << pid << ", skip track");
0071       notFoundParticles++;
0072       continue;
0073     }
0074 
0075     if (found->transverseMomentum() < m_cfg.targetMinPT ||
0076         track.size() < m_cfg.targetMinSize) {
0077       continue;
0078     }
0079 
0080     const Acts::Vector3 vtx = found->fourPosition().segment<3>(0);
0081     auto radiusForOrdering = [&](std::size_t i) {
0082       const auto& sp = spacepoints[i];
0083       return std::hypot(sp.x() - vtx[0], sp.y() - vtx[1], sp.z() - vtx[2]);
0084     };
0085 
0086     // Sort by radius (this breaks down if the particle has to low momentum)
0087     std::ranges::sort(track, {},
0088                       [&](const auto& t) { return radiusForOrdering(t); });
0089 
0090     if (m_cfg.uniqueModules) {
0091       auto newEnd = std::unique(
0092           track.begin(), track.end(), [&](const auto& a, const auto& b) {
0093             auto gidA = spacepoints[a]
0094                             .sourceLinks()[0]
0095                             .template get<IndexSourceLink>()
0096                             .geometryId();
0097             auto gidB = spacepoints[b]
0098                             .sourceLinks()[0]
0099                             .template get<IndexSourceLink>()
0100                             .geometryId();
0101             return gidA == gidB;
0102           });
0103       moduleDuplicatesRemoved += std::distance(newEnd, track.end());
0104       track.erase(newEnd, track.end());
0105     }
0106 
0107     for (auto i = 0ul; i < track.size() - 1; ++i) {
0108       graph.push_back(track[i]);
0109       graph.push_back(track[i + 1]);
0110     }
0111   }
0112 
0113   ACTS_DEBUG("Did not find particles for " << notFoundParticles << " tracks");
0114   if (moduleDuplicatesRemoved > 0) {
0115     ACTS_DEBUG(
0116         "Removed " << moduleDuplicatesRemoved
0117                    << " hit to ensure a unique hit per track and module");
0118   }
0119 
0120   return graph;
0121 }
0122 
// File-local helper, kept in an unnamed namespace so the type has internal
// linkage and cannot clash with another HitInfo defined in a different
// translation unit.
namespace {

// Associates a space point with one of the sim hits that contributed to it.
struct HitInfo {
  // Index of the space point in the space point container
  std::size_t spacePointIndex;
  // Index reported by the sim hit; used to order the hits of one particle
  std::int32_t hitIndex;
};

}  // namespace
0127 
0128 std::vector<std::int64_t> TruthGraphBuilder::buildFromSimhits(
0129     const SimSpacePointContainer& spacepoints,
0130     const IndexMultimap<Index>& measHitMap, const SimHitContainer& simhits,
0131     const SimParticleContainer& particles) const {
0132   // Associate tracks to graph, collect momentum
0133   std::unordered_map<ActsFatras::Barcode, std::vector<HitInfo>> tracks;
0134 
0135   for (auto i = 0ul; i < spacepoints.size(); ++i) {
0136     const auto measId =
0137         spacepoints[i].sourceLinks()[0].template get<IndexSourceLink>().index();
0138 
0139     auto [a, b] = measHitMap.equal_range(measId);
0140     for (auto it = a; it != b; ++it) {
0141       const auto& hit = *simhits.nth(it->second);
0142 
0143       tracks[hit.particleId()].push_back({i, hit.index()});
0144     }
0145   }
0146 
0147   // Collect edges for truth graph and target graph
0148   std::vector<std::int64_t> truthGraph;
0149 
0150   for (auto& [pid, track] : tracks) {
0151     // Sort by hit index, so the edges are connected correctly
0152     std::ranges::sort(track, {}, [](const auto& t) { return t.hitIndex; });
0153 
0154     auto found = particles.find(pid);
0155     if (found == particles.end()) {
0156       ACTS_WARNING("Did not find " << pid << ", skip track");
0157       continue;
0158     }
0159 
0160     for (auto i = 0ul; i < track.size() - 1; ++i) {
0161       if (found->transverseMomentum() > m_cfg.targetMinPT &&
0162           track.size() >= m_cfg.targetMinSize) {
0163         truthGraph.push_back(track[i].spacePointIndex);
0164         truthGraph.push_back(track[i + 1].spacePointIndex);
0165       }
0166     }
0167   }
0168 
0169   return truthGraph;
0170 }
0171 
0172 ProcessCode TruthGraphBuilder::execute(
0173     const ActsExamples::AlgorithmContext& ctx) const {
0174   // Read input data
0175   const auto& spacepoints = m_inputSpacePoints(ctx);
0176   const auto& particles = m_inputParticles(ctx);
0177 
0178   auto edges = (m_inputMeasParticlesMap.isInitialized())
0179                    ? buildFromMeasurements(spacepoints, particles,
0180                                            m_inputMeasParticlesMap(ctx))
0181                    : buildFromSimhits(spacepoints, m_inputMeasSimhitMap(ctx),
0182                                       m_inputSimhits(ctx), particles);
0183 
0184   ACTS_DEBUG("Truth track edges: " << edges.size() / 2);
0185 
0186   Graph g;
0187   g.edges = std::move(edges);
0188 
0189   m_outputGraph(ctx, std::move(g));
0190 
0191   return ProcessCode::SUCCESS;
0192 }
0193 
0194 }  // namespace ActsExamples