File indexing completed on 2026-04-27 07:26:17
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include "Acts/AmbiguityResolution/AmbiguityNetworkConcept.hpp"
#include "Acts/Utilities/Logger.hpp"

#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
0019
0020 namespace Acts {
0021
0022
0023
0024 template <AmbiguityNetworkConcept AmbiguityNetwork>
0025 class AmbiguityResolutionML {
0026 public:
0027
0028 struct Config {
0029
0030 std::string inputDuplicateNN = "";
0031
0032 std::size_t nMeasurementsMin = 7;
0033 };
0034
0035
0036
0037
0038 explicit AmbiguityResolutionML(const Config& cfg,
0039 std::unique_ptr<const Logger> logger =
0040 getDefaultLogger("AmbiguityResolutionML",
0041 Logging::INFO))
0042 : m_cfg{cfg},
0043 m_duplicateClassifier(m_cfg.inputDuplicateNN.c_str()),
0044 m_logger{std::move(logger)} {}
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057 template <TrackContainerFrontend track_container_t,
0058 typename source_link_hash_t, typename source_link_equality_t>
0059 std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
0060 mapTrackHits(const track_container_t& tracks,
0061 const source_link_hash_t& sourceLinkHash,
0062 const source_link_equality_t& sourceLinkEquality) const {
0063
0064 auto measurementIndexMap =
0065 std::unordered_map<SourceLink, std::size_t, source_link_hash_t,
0066 source_link_equality_t>(0, sourceLinkHash,
0067 sourceLinkEquality);
0068
0069
0070
0071
0072 std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
0073 trackMap;
0074 std::size_t trackIndex = 0;
0075 std::vector<std::size_t> measurements;
0076
0077 for (const auto& track : tracks) {
0078
0079 if (track.nMeasurements() < m_cfg.nMeasurementsMin) {
0080 continue;
0081 }
0082 measurements.clear();
0083 for (auto ts : track.trackStatesReversed()) {
0084 if (ts.typeFlags().isMeasurement()) {
0085 SourceLink sourceLink = ts.getUncalibratedSourceLink();
0086
0087 auto emplace = measurementIndexMap.try_emplace(
0088 sourceLink, measurementIndexMap.size());
0089 measurements.push_back(emplace.first->second);
0090 }
0091 }
0092 trackMap.emplace(track.nMeasurements(),
0093 std::make_pair(trackIndex, measurements));
0094 ++trackIndex;
0095 }
0096 return trackMap;
0097 }
0098
0099
0100
0101
0102
0103
0104
0105
0106
0107 template <TrackContainerFrontend track_container_t>
0108 std::vector<std::size_t> solveAmbiguity(
0109 std::unordered_map<std::size_t, std::vector<std::size_t>>& clusters,
0110 const track_container_t& tracks) const {
0111 std::vector<std::vector<float>> outputTensor =
0112 m_duplicateClassifier.inferScores(clusters, tracks);
0113 std::vector<std::size_t> goodTracks =
0114 m_duplicateClassifier.trackSelection(clusters, outputTensor);
0115
0116 return goodTracks;
0117 }
0118
0119 private:
0120
0121 Config m_cfg;
0122
0123
0124
0125 AmbiguityNetwork m_duplicateClassifier;
0126
0127
0128 std::unique_ptr<const Logger> m_logger = nullptr;
0129
0130
0131 const Logger& logger() const { return *m_logger; }
0132 };
0133
0134 }