Warning: file /acts/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp was not indexed,
or was modified since it was last indexed (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include <Acts/Definitions/Units.hpp>
0010 #include <ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp>
0011
0012 #include <algorithm>
0013
0014 using namespace Acts::UnitLiterals;
0015
0016 namespace ActsExamples {
0017
0018 TruthGraphBuilder::TruthGraphBuilder(Config config, Acts::Logging::Level level)
0019 : ActsExamples::IAlgorithm("TruthGraphBuilder", level),
0020 m_cfg(std::move(config)) {
0021 m_inputSpacePoints.initialize(m_cfg.inputSpacePoints);
0022 m_inputParticles.initialize(m_cfg.inputParticles);
0023 m_outputGraph.initialize(m_cfg.outputGraph);
0024
0025 m_inputMeasParticlesMap.maybeInitialize(m_cfg.inputMeasurementParticlesMap);
0026 m_inputSimhits.maybeInitialize(m_cfg.inputSimHits);
0027 m_inputMeasSimhitMap.maybeInitialize(m_cfg.inputMeasurementSimHitsMap);
0028
0029 bool a = m_inputMeasParticlesMap.isInitialized();
0030 bool b =
0031 m_inputSimhits.isInitialized() && m_inputMeasSimhitMap.isInitialized();
0032
0033
0034 if (!a != !b) {
0035 throw std::invalid_argument("Missing inputs, cannot build truth graph");
0036 }
0037 }
0038
0039 std::vector<std::int64_t> TruthGraphBuilder::buildFromMeasurements(
0040 const SimSpacePointContainer& spacepoints,
0041 const SimParticleContainer& particles,
0042 const IndexMultimap<ActsFatras::Barcode>& measPartMap) const {
0043 if (m_cfg.targetMinPT < 500_MeV) {
0044 ACTS_WARNING(
0045 "truth graph building based on distance from origin, this breaks down "
0046 "for low pT particles. Consider using a higher target pT value");
0047 }
0048
0049
0050 std::unordered_map<ActsFatras::Barcode, std::vector<std::size_t>> tracks;
0051
0052 for (auto i = 0ul; i < spacepoints.size(); ++i) {
0053 const auto measId =
0054 spacepoints[i].sourceLinks()[0].template get<IndexSourceLink>().index();
0055
0056 auto [a, b] = measPartMap.equal_range(measId);
0057 for (auto it = a; it != b; ++it) {
0058 tracks[it->second].push_back(i);
0059 }
0060 }
0061
0062
0063 std::vector<std::int64_t> graph;
0064 std::size_t notFoundParticles = 0;
0065 std::size_t moduleDuplicatesRemoved = 0;
0066
0067 for (auto& [pid, track] : tracks) {
0068 auto found = particles.find(pid);
0069 if (found == particles.end()) {
0070 ACTS_VERBOSE("Did not find " << pid << ", skip track");
0071 notFoundParticles++;
0072 continue;
0073 }
0074
0075 if (found->transverseMomentum() < m_cfg.targetMinPT ||
0076 track.size() < m_cfg.targetMinSize) {
0077 continue;
0078 }
0079
0080 const Acts::Vector3 vtx = found->fourPosition().segment<3>(0);
0081 auto radiusForOrdering = [&](std::size_t i) {
0082 const auto& sp = spacepoints[i];
0083 return std::hypot(sp.x() - vtx[0], sp.y() - vtx[1], sp.z() - vtx[2]);
0084 };
0085
0086
0087 std::ranges::sort(track, {},
0088 [&](const auto& t) { return radiusForOrdering(t); });
0089
0090 if (m_cfg.uniqueModules) {
0091 auto newEnd = std::unique(
0092 track.begin(), track.end(), [&](const auto& a, const auto& b) {
0093 auto gidA = spacepoints[a]
0094 .sourceLinks()[0]
0095 .template get<IndexSourceLink>()
0096 .geometryId();
0097 auto gidB = spacepoints[b]
0098 .sourceLinks()[0]
0099 .template get<IndexSourceLink>()
0100 .geometryId();
0101 return gidA == gidB;
0102 });
0103 moduleDuplicatesRemoved += std::distance(newEnd, track.end());
0104 track.erase(newEnd, track.end());
0105 }
0106
0107 for (auto i = 0ul; i < track.size() - 1; ++i) {
0108 graph.push_back(track[i]);
0109 graph.push_back(track[i + 1]);
0110 }
0111 }
0112
0113 ACTS_DEBUG("Did not find particles for " << notFoundParticles << " tracks");
0114 if (moduleDuplicatesRemoved > 0) {
0115 ACTS_DEBUG(
0116 "Removed " << moduleDuplicatesRemoved
0117 << " hit to ensure a unique hit per track and module");
0118 }
0119
0120 return graph;
0121 }
0122
namespace {

// Pairs a space point with the index of one simulated hit behind its
// measurement. buildFromSimhits sorts a particle's entries by `hitIndex`.
// The anonymous namespace gives this helper internal linkage, so it cannot
// collide (ODR) with another `ActsExamples::HitInfo` defined in a different
// translation unit.
struct HitInfo {
  // Index of the space point in the input space point container.
  std::size_t spacePointIndex;
  // Value of the simulated hit's index(), used as the ordering key.
  std::int32_t hitIndex;
};

}  // namespace
0127
0128 std::vector<std::int64_t> TruthGraphBuilder::buildFromSimhits(
0129 const SimSpacePointContainer& spacepoints,
0130 const IndexMultimap<Index>& measHitMap, const SimHitContainer& simhits,
0131 const SimParticleContainer& particles) const {
0132
0133 std::unordered_map<ActsFatras::Barcode, std::vector<HitInfo>> tracks;
0134
0135 for (auto i = 0ul; i < spacepoints.size(); ++i) {
0136 const auto measId =
0137 spacepoints[i].sourceLinks()[0].template get<IndexSourceLink>().index();
0138
0139 auto [a, b] = measHitMap.equal_range(measId);
0140 for (auto it = a; it != b; ++it) {
0141 const auto& hit = *simhits.nth(it->second);
0142
0143 tracks[hit.particleId()].push_back({i, hit.index()});
0144 }
0145 }
0146
0147
0148 std::vector<std::int64_t> truthGraph;
0149
0150 for (auto& [pid, track] : tracks) {
0151
0152 std::ranges::sort(track, {}, [](const auto& t) { return t.hitIndex; });
0153
0154 auto found = particles.find(pid);
0155 if (found == particles.end()) {
0156 ACTS_WARNING("Did not find " << pid << ", skip track");
0157 continue;
0158 }
0159
0160 for (auto i = 0ul; i < track.size() - 1; ++i) {
0161 if (found->transverseMomentum() > m_cfg.targetMinPT &&
0162 track.size() >= m_cfg.targetMinSize) {
0163 truthGraph.push_back(track[i].spacePointIndex);
0164 truthGraph.push_back(track[i + 1].spacePointIndex);
0165 }
0166 }
0167 }
0168
0169 return truthGraph;
0170 }
0171
0172 ProcessCode TruthGraphBuilder::execute(
0173 const ActsExamples::AlgorithmContext& ctx) const {
0174
0175 const auto& spacepoints = m_inputSpacePoints(ctx);
0176 const auto& particles = m_inputParticles(ctx);
0177
0178 auto edges = (m_inputMeasParticlesMap.isInitialized())
0179 ? buildFromMeasurements(spacepoints, particles,
0180 m_inputMeasParticlesMap(ctx))
0181 : buildFromSimhits(spacepoints, m_inputMeasSimhitMap(ctx),
0182 m_inputSimhits(ctx), particles);
0183
0184 ACTS_DEBUG("Truth track edges: " << edges.size() / 2);
0185
0186 Graph g;
0187 g.edges = std::move(edges);
0188
0189 m_outputGraph(ctx, std::move(g));
0190
0191 return ProcessCode::SUCCESS;
0192 }
0193
0194 }