File indexing completed on 2025-12-16 09:23:36
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/TrackFindingGnn/TruthGraphBuilder.hpp"
0010
0011 #include "Acts/Definitions/Units.hpp"
0012
0013 #include <algorithm>
0014
0015 using namespace Acts;
0016 using namespace Acts::UnitLiterals;
0017
0018 namespace ActsExamples {
0019
0020 TruthGraphBuilder::TruthGraphBuilder(Config config, Logging::Level level)
0021 : ActsExamples::IAlgorithm("TruthGraphBuilder", level),
0022 m_cfg(std::move(config)) {
0023 m_inputSpacePoints.initialize(m_cfg.inputSpacePoints);
0024 m_inputParticles.initialize(m_cfg.inputParticles);
0025 m_outputGraph.initialize(m_cfg.outputGraph);
0026
0027 m_inputMeasParticlesMap.maybeInitialize(m_cfg.inputMeasurementParticlesMap);
0028 m_inputSimhits.maybeInitialize(m_cfg.inputSimHits);
0029 m_inputMeasSimhitMap.maybeInitialize(m_cfg.inputMeasurementSimHitsMap);
0030
0031 bool a = m_inputMeasParticlesMap.isInitialized();
0032 bool b =
0033 m_inputSimhits.isInitialized() && m_inputMeasSimhitMap.isInitialized();
0034
0035
0036 if (!a != !b) {
0037 throw std::invalid_argument("Missing inputs, cannot build truth graph");
0038 }
0039 }
0040
0041 std::vector<std::int64_t> TruthGraphBuilder::buildFromMeasurements(
0042 const SimSpacePointContainer& spacepoints,
0043 const SimParticleContainer& particles,
0044 const IndexMultimap<ActsFatras::Barcode>& measPartMap) const {
0045 if (m_cfg.targetMinPT < 500_MeV) {
0046 ACTS_WARNING(
0047 "truth graph building based on distance from origin, this breaks down "
0048 "for low pT particles. Consider using a higher target pT value");
0049 }
0050
0051
0052 std::unordered_map<ActsFatras::Barcode, std::vector<std::size_t>> tracks;
0053
0054 for (auto i = 0ul; i < spacepoints.size(); ++i) {
0055 const auto measId =
0056 spacepoints[i].sourceLinks()[0].template get<IndexSourceLink>().index();
0057
0058 auto [a, b] = measPartMap.equal_range(measId);
0059 for (auto it = a; it != b; ++it) {
0060 tracks[it->second].push_back(i);
0061 }
0062 }
0063
0064
0065 std::vector<std::int64_t> graph;
0066 std::size_t notFoundParticles = 0;
0067 std::size_t moduleDuplicatesRemoved = 0;
0068
0069 for (auto& [pid, track] : tracks) {
0070 auto found = particles.find(pid);
0071 if (found == particles.end()) {
0072 ACTS_VERBOSE("Did not find " << pid << ", skip track");
0073 notFoundParticles++;
0074 continue;
0075 }
0076
0077 if (found->transverseMomentum() < m_cfg.targetMinPT ||
0078 track.size() < m_cfg.targetMinSize) {
0079 continue;
0080 }
0081
0082 const Vector3 vtx = found->fourPosition().segment<3>(0);
0083 auto radiusForOrdering = [&](std::size_t i) {
0084 const auto& sp = spacepoints[i];
0085 return std::hypot(sp.x() - vtx[0], sp.y() - vtx[1], sp.z() - vtx[2]);
0086 };
0087
0088
0089 std::ranges::sort(track, {},
0090 [&](const auto& t) { return radiusForOrdering(t); });
0091
0092 if (m_cfg.uniqueModules) {
0093 auto newEnd = std::unique(
0094 track.begin(), track.end(), [&](const auto& a, const auto& b) {
0095 auto gidA = spacepoints[a]
0096 .sourceLinks()[0]
0097 .template get<IndexSourceLink>()
0098 .geometryId();
0099 auto gidB = spacepoints[b]
0100 .sourceLinks()[0]
0101 .template get<IndexSourceLink>()
0102 .geometryId();
0103 return gidA == gidB;
0104 });
0105 moduleDuplicatesRemoved += std::distance(newEnd, track.end());
0106 track.erase(newEnd, track.end());
0107 }
0108
0109 for (auto i = 0ul; i < track.size() - 1; ++i) {
0110 graph.push_back(track[i]);
0111 graph.push_back(track[i + 1]);
0112 }
0113 }
0114
0115 ACTS_DEBUG("Did not find particles for " << notFoundParticles << " tracks");
0116 if (moduleDuplicatesRemoved > 0) {
0117 ACTS_DEBUG(
0118 "Removed " << moduleDuplicatesRemoved
0119 << " hit to ensure a unique hit per track and module");
0120 }
0121
0122 return graph;
0123 }
0124
namespace {

/// A space point paired with the simulated hit it was matched to; collected
/// per particle when building the truth graph from simulated hits.
/// Kept in an unnamed namespace so this file-local helper type has internal
/// linkage and cannot clash with a same-named type in another translation unit.
struct HitInfo {
  /// Index of the space point in the input space-point container.
  std::size_t spacePointIndex;
  /// Value of the simulated hit's `index()`; used as the sort key that orders
  /// a particle's space points (presumably the hit's position along the
  /// particle trajectory — TODO confirm against ActsFatras::Hit).
  std::int32_t hitIndex;
};

}  // namespace
0129
0130 std::vector<std::int64_t> TruthGraphBuilder::buildFromSimhits(
0131 const SimSpacePointContainer& spacepoints,
0132 const IndexMultimap<Index>& measHitMap, const SimHitContainer& simhits,
0133 const SimParticleContainer& particles) const {
0134
0135 std::unordered_map<ActsFatras::Barcode, std::vector<HitInfo>> tracks;
0136
0137 for (auto i = 0ul; i < spacepoints.size(); ++i) {
0138 const auto measId =
0139 spacepoints[i].sourceLinks()[0].template get<IndexSourceLink>().index();
0140
0141 auto [a, b] = measHitMap.equal_range(measId);
0142 for (auto it = a; it != b; ++it) {
0143 const auto& hit = *simhits.nth(it->second);
0144
0145 tracks[hit.particleId()].push_back({i, hit.index()});
0146 }
0147 }
0148
0149
0150 std::vector<std::int64_t> truthGraph;
0151
0152 for (auto& [pid, track] : tracks) {
0153
0154 std::ranges::sort(track, {}, [](const auto& t) { return t.hitIndex; });
0155
0156 auto found = particles.find(pid);
0157 if (found == particles.end()) {
0158 ACTS_WARNING("Did not find " << pid << ", skip track");
0159 continue;
0160 }
0161
0162 for (auto i = 0ul; i < track.size() - 1; ++i) {
0163 if (found->transverseMomentum() > m_cfg.targetMinPT &&
0164 track.size() >= m_cfg.targetMinSize) {
0165 truthGraph.push_back(track[i].spacePointIndex);
0166 truthGraph.push_back(track[i + 1].spacePointIndex);
0167 }
0168 }
0169 }
0170
0171 return truthGraph;
0172 }
0173
0174 ProcessCode TruthGraphBuilder::execute(
0175 const ActsExamples::AlgorithmContext& ctx) const {
0176
0177 const auto& spacepoints = m_inputSpacePoints(ctx);
0178 const auto& particles = m_inputParticles(ctx);
0179
0180 auto edges = (m_inputMeasParticlesMap.isInitialized())
0181 ? buildFromMeasurements(spacepoints, particles,
0182 m_inputMeasParticlesMap(ctx))
0183 : buildFromSimhits(spacepoints, m_inputMeasSimhitMap(ctx),
0184 m_inputSimhits(ctx), particles);
0185
0186 ACTS_DEBUG("Truth track edges: " << edges.size() / 2);
0187
0188 Graph g;
0189 g.edges = std::move(edges);
0190
0191 m_outputGraph(ctx, std::move(g));
0192
0193 return ProcessCode::SUCCESS;
0194 }
0195
0196 }