File indexing completed on 2025-07-11 07:49:41
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include "Acts/AmbiguityResolution/AmbiguityNetworkConcept.hpp"
#include "Acts/Definitions/Units.hpp"
#include "Acts/EventData/TrackContainer.hpp"
#include "Acts/Utilities/Delegate.hpp"
#include "Acts/Utilities/Logger.hpp"

#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
0023
0024 namespace Acts {
0025
0026
0027
0028 template <AmbiguityNetworkConcept AmbiguityNetwork>
0029 class AmbiguityResolutionML {
0030 public:
0031 struct Config {
0032
0033 std::string inputDuplicateNN = "";
0034
0035 std::size_t nMeasurementsMin = 7;
0036 };
0037
0038
0039
0040
0041 explicit AmbiguityResolutionML(const Config& cfg,
0042 std::unique_ptr<const Logger> logger =
0043 getDefaultLogger("AmbiguityResolutionML",
0044 Logging::INFO))
0045 : m_cfg{cfg},
0046 m_duplicateClassifier(m_cfg.inputDuplicateNN.c_str()),
0047 m_logger{std::move(logger)} {}
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060 template <TrackContainerFrontend track_container_t,
0061 typename source_link_hash_t, typename source_link_equality_t>
0062 std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
0063 mapTrackHits(const track_container_t& tracks,
0064 const source_link_hash_t& sourceLinkHash,
0065 const source_link_equality_t& sourceLinkEquality) const {
0066
0067 auto measurementIndexMap =
0068 std::unordered_map<SourceLink, std::size_t, source_link_hash_t,
0069 source_link_equality_t>(0, sourceLinkHash,
0070 sourceLinkEquality);
0071
0072
0073
0074
0075 std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
0076 trackMap;
0077 std::size_t trackIndex = 0;
0078 std::vector<std::size_t> measurements;
0079
0080 for (const auto& track : tracks) {
0081
0082 if (track.nMeasurements() < m_cfg.nMeasurementsMin) {
0083 continue;
0084 }
0085 measurements.clear();
0086 for (auto ts : track.trackStatesReversed()) {
0087 if (ts.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
0088 SourceLink sourceLink = ts.getUncalibratedSourceLink();
0089
0090 auto emplace = measurementIndexMap.try_emplace(
0091 sourceLink, measurementIndexMap.size());
0092 measurements.push_back(emplace.first->second);
0093 }
0094 }
0095 trackMap.emplace(track.nMeasurements(),
0096 std::make_pair(trackIndex, measurements));
0097 ++trackIndex;
0098 }
0099 return trackMap;
0100 }
0101
0102
0103
0104
0105
0106
0107
0108
0109
0110 template <TrackContainerFrontend track_container_t>
0111 std::vector<std::size_t> solveAmbiguity(
0112 std::unordered_map<std::size_t, std::vector<std::size_t>>& clusters,
0113 const track_container_t& tracks) const {
0114 std::vector<std::vector<float>> outputTensor =
0115 m_duplicateClassifier.inferScores(clusters, tracks);
0116 std::vector<std::size_t> goodTracks =
0117 m_duplicateClassifier.trackSelection(clusters, outputTensor);
0118
0119 return goodTracks;
0120 }
0121
0122 private:
0123
0124 Config m_cfg;
0125
0126
0127
0128 AmbiguityNetwork m_duplicateClassifier;
0129
0130
0131 std::unique_ptr<const Logger> m_logger = nullptr;
0132
0133
0134 const Logger& logger() const { return *m_logger; }
0135 };
0136
0137 }