Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:23:36

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsExamples/TrackFindingGnn/TruthGraphBuilder.hpp"
0010 
0011 #include "Acts/Definitions/Units.hpp"
0012 
0013 #include <algorithm>
0014 
0015 using namespace Acts;
0016 using namespace Acts::UnitLiterals;
0017 
0018 namespace ActsExamples {
0019 
0020 TruthGraphBuilder::TruthGraphBuilder(Config config, Logging::Level level)
0021     : ActsExamples::IAlgorithm("TruthGraphBuilder", level),
0022       m_cfg(std::move(config)) {
0023   m_inputSpacePoints.initialize(m_cfg.inputSpacePoints);
0024   m_inputParticles.initialize(m_cfg.inputParticles);
0025   m_outputGraph.initialize(m_cfg.outputGraph);
0026 
0027   m_inputMeasParticlesMap.maybeInitialize(m_cfg.inputMeasurementParticlesMap);
0028   m_inputSimhits.maybeInitialize(m_cfg.inputSimHits);
0029   m_inputMeasSimhitMap.maybeInitialize(m_cfg.inputMeasurementSimHitsMap);
0030 
0031   bool a = m_inputMeasParticlesMap.isInitialized();
0032   bool b =
0033       m_inputSimhits.isInitialized() && m_inputMeasSimhitMap.isInitialized();
0034 
0035   // Logical XOR operation
0036   if (!a != !b) {
0037     throw std::invalid_argument("Missing inputs, cannot build truth graph");
0038   }
0039 }
0040 
0041 std::vector<std::int64_t> TruthGraphBuilder::buildFromMeasurements(
0042     const SimSpacePointContainer& spacepoints,
0043     const SimParticleContainer& particles,
0044     const IndexMultimap<ActsFatras::Barcode>& measPartMap) const {
0045   if (m_cfg.targetMinPT < 500_MeV) {
0046     ACTS_WARNING(
0047         "truth graph building based on distance from origin, this breaks down "
0048         "for low pT particles. Consider using a higher target pT value");
0049   }
0050 
0051   // Associate tracks to graph, collect momentum
0052   std::unordered_map<ActsFatras::Barcode, std::vector<std::size_t>> tracks;
0053 
0054   for (auto i = 0ul; i < spacepoints.size(); ++i) {
0055     const auto measId =
0056         spacepoints[i].sourceLinks()[0].template get<IndexSourceLink>().index();
0057 
0058     auto [a, b] = measPartMap.equal_range(measId);
0059     for (auto it = a; it != b; ++it) {
0060       tracks[it->second].push_back(i);
0061     }
0062   }
0063 
0064   // Collect edges for truth graph and target graph
0065   std::vector<std::int64_t> graph;
0066   std::size_t notFoundParticles = 0;
0067   std::size_t moduleDuplicatesRemoved = 0;
0068 
0069   for (auto& [pid, track] : tracks) {
0070     auto found = particles.find(pid);
0071     if (found == particles.end()) {
0072       ACTS_VERBOSE("Did not find " << pid << ", skip track");
0073       notFoundParticles++;
0074       continue;
0075     }
0076 
0077     if (found->transverseMomentum() < m_cfg.targetMinPT ||
0078         track.size() < m_cfg.targetMinSize) {
0079       continue;
0080     }
0081 
0082     const Vector3 vtx = found->fourPosition().segment<3>(0);
0083     auto radiusForOrdering = [&](std::size_t i) {
0084       const auto& sp = spacepoints[i];
0085       return std::hypot(sp.x() - vtx[0], sp.y() - vtx[1], sp.z() - vtx[2]);
0086     };
0087 
0088     // Sort by radius (this breaks down if the particle has to low momentum)
0089     std::ranges::sort(track, {},
0090                       [&](const auto& t) { return radiusForOrdering(t); });
0091 
0092     if (m_cfg.uniqueModules) {
0093       auto newEnd = std::unique(
0094           track.begin(), track.end(), [&](const auto& a, const auto& b) {
0095             auto gidA = spacepoints[a]
0096                             .sourceLinks()[0]
0097                             .template get<IndexSourceLink>()
0098                             .geometryId();
0099             auto gidB = spacepoints[b]
0100                             .sourceLinks()[0]
0101                             .template get<IndexSourceLink>()
0102                             .geometryId();
0103             return gidA == gidB;
0104           });
0105       moduleDuplicatesRemoved += std::distance(newEnd, track.end());
0106       track.erase(newEnd, track.end());
0107     }
0108 
0109     for (auto i = 0ul; i < track.size() - 1; ++i) {
0110       graph.push_back(track[i]);
0111       graph.push_back(track[i + 1]);
0112     }
0113   }
0114 
0115   ACTS_DEBUG("Did not find particles for " << notFoundParticles << " tracks");
0116   if (moduleDuplicatesRemoved > 0) {
0117     ACTS_DEBUG(
0118         "Removed " << moduleDuplicatesRemoved
0119                    << " hit to ensure a unique hit per track and module");
0120   }
0121 
0122   return graph;
0123 }
0124 
namespace {

/// A space point index paired with the index of a simhit that contributed to
/// it. buildFromSimhits sorts a particle's entries by `hitIndex` to recover
/// the order in which the space points should be connected.
///
/// Kept in an unnamed namespace so this generic helper name has internal
/// linkage and cannot clash with another translation unit's definition.
struct HitInfo {
  std::size_t spacePointIndex{};  // index into the space point container
  std::int32_t hitIndex{};        // value of the simhit's index()
};

}  // namespace
0129 
0130 std::vector<std::int64_t> TruthGraphBuilder::buildFromSimhits(
0131     const SimSpacePointContainer& spacepoints,
0132     const IndexMultimap<Index>& measHitMap, const SimHitContainer& simhits,
0133     const SimParticleContainer& particles) const {
0134   // Associate tracks to graph, collect momentum
0135   std::unordered_map<ActsFatras::Barcode, std::vector<HitInfo>> tracks;
0136 
0137   for (auto i = 0ul; i < spacepoints.size(); ++i) {
0138     const auto measId =
0139         spacepoints[i].sourceLinks()[0].template get<IndexSourceLink>().index();
0140 
0141     auto [a, b] = measHitMap.equal_range(measId);
0142     for (auto it = a; it != b; ++it) {
0143       const auto& hit = *simhits.nth(it->second);
0144 
0145       tracks[hit.particleId()].push_back({i, hit.index()});
0146     }
0147   }
0148 
0149   // Collect edges for truth graph and target graph
0150   std::vector<std::int64_t> truthGraph;
0151 
0152   for (auto& [pid, track] : tracks) {
0153     // Sort by hit index, so the edges are connected correctly
0154     std::ranges::sort(track, {}, [](const auto& t) { return t.hitIndex; });
0155 
0156     auto found = particles.find(pid);
0157     if (found == particles.end()) {
0158       ACTS_WARNING("Did not find " << pid << ", skip track");
0159       continue;
0160     }
0161 
0162     for (auto i = 0ul; i < track.size() - 1; ++i) {
0163       if (found->transverseMomentum() > m_cfg.targetMinPT &&
0164           track.size() >= m_cfg.targetMinSize) {
0165         truthGraph.push_back(track[i].spacePointIndex);
0166         truthGraph.push_back(track[i + 1].spacePointIndex);
0167       }
0168     }
0169   }
0170 
0171   return truthGraph;
0172 }
0173 
0174 ProcessCode TruthGraphBuilder::execute(
0175     const ActsExamples::AlgorithmContext& ctx) const {
0176   // Read input data
0177   const auto& spacepoints = m_inputSpacePoints(ctx);
0178   const auto& particles = m_inputParticles(ctx);
0179 
0180   auto edges = (m_inputMeasParticlesMap.isInitialized())
0181                    ? buildFromMeasurements(spacepoints, particles,
0182                                            m_inputMeasParticlesMap(ctx))
0183                    : buildFromSimhits(spacepoints, m_inputMeasSimhitMap(ctx),
0184                                       m_inputSimhits(ctx), particles);
0185 
0186   ACTS_DEBUG("Truth track edges: " << edges.size() / 2);
0187 
0188   Graph g;
0189   g.edges = std::move(edges);
0190 
0191   m_outputGraph(ctx, std::move(g));
0192 
0193   return ProcessCode::SUCCESS;
0194 }
0195 
0196 }  // namespace ActsExamples