File indexing completed on 2025-07-14 08:10:26
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/AmbiguityResolution/ScoreBasedAmbiguityResolution.hpp"
0012
0013 #include "Acts/Definitions/Units.hpp"
0014 #include "Acts/EventData/TrackContainerFrontendConcept.hpp"
0015 #include "Acts/Utilities/VectorHelpers.hpp"
0016
0017 #include <unordered_map>
0018
0019 namespace Acts {
0020
0021 inline const Logger& ScoreBasedAmbiguityResolution::logger() const {
0022 return *m_logger;
0023 }
0024
0025 template <TrackContainerFrontend track_container_t>
0026 std::vector<std::vector<ScoreBasedAmbiguityResolution::TrackFeatures>>
0027 ScoreBasedAmbiguityResolution::computeInitialState(
0028 const track_container_t& tracks) const {
0029 ACTS_VERBOSE("Starting to compute initial state");
0030 std::vector<std::vector<TrackFeatures>> trackFeaturesVectors;
0031 trackFeaturesVectors.reserve(tracks.size());
0032
0033 for (const auto& track : tracks) {
0034 int numberOfDetectors = m_cfg.detectorConfigs.size();
0035
0036 std::vector<TrackFeatures> trackFeaturesVector(numberOfDetectors);
0037
0038 for (const auto& ts : track.trackStatesReversed()) {
0039 if (!ts.hasReferenceSurface()) {
0040 ACTS_DEBUG("Track state has no reference surface");
0041 continue;
0042 }
0043 auto iVolume = ts.referenceSurface().geometryId().volume();
0044 auto volume_it = m_cfg.volumeMap.find(iVolume);
0045 if (volume_it == m_cfg.volumeMap.end()) {
0046 ACTS_ERROR("Volume " << iVolume << "not found in the volume map");
0047 continue;
0048 }
0049 auto detectorId = volume_it->second;
0050
0051 if (ts.typeFlags().test(Acts::TrackStateFlag::HoleFlag)) {
0052 ACTS_VERBOSE("Track state type is HoleFlag");
0053 trackFeaturesVector[detectorId].nHoles++;
0054 } else if (ts.typeFlags().test(Acts::TrackStateFlag::OutlierFlag)) {
0055 ACTS_VERBOSE("Track state type is OutlierFlag");
0056 trackFeaturesVector[detectorId].nOutliers++;
0057
0058 } else if (ts.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
0059 ACTS_VERBOSE("Track state type is MeasurementFlag");
0060
0061 if (ts.typeFlags().test(Acts::TrackStateFlag::SharedHitFlag)) {
0062 trackFeaturesVector[detectorId].nSharedHits++;
0063 }
0064 trackFeaturesVector[detectorId].nHits++;
0065 }
0066 }
0067 trackFeaturesVectors.push_back(std::move(trackFeaturesVector));
0068 }
0069
0070 return trackFeaturesVectors;
0071 }
0072
0073 template <TrackContainerFrontend track_container_t>
0074 std::vector<double> Acts::ScoreBasedAmbiguityResolution::simpleScore(
0075 const track_container_t& tracks,
0076 const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
0077 const Optionals<typename track_container_t::ConstTrackProxy>& optionals)
0078 const {
0079 std::vector<double> trackScore;
0080 trackScore.reserve(tracks.size());
0081
0082 int iTrack = 0;
0083
0084 ACTS_VERBOSE("Number of detectors: " << m_cfg.detectorConfigs.size());
0085
0086 ACTS_INFO("Starting to score tracks");
0087
0088
0089 for (const auto& track : tracks) {
0090
0091 const auto& trackFeaturesVector = trackFeaturesVectors[iTrack];
0092 double score = 1;
0093 auto eta = Acts::VectorHelpers::eta(track.momentum());
0094
0095
0096 for (const auto& cutFunction : optionals.cuts) {
0097 if (cutFunction(track)) {
0098 score = 0;
0099 ACTS_DEBUG("Track: " << iTrack
0100 << " has score = 0, due to optional cuts.");
0101 break;
0102 }
0103 }
0104
0105 if (score == 0) {
0106 iTrack++;
0107 trackScore.push_back(score);
0108 ACTS_DEBUG("Track: " << iTrack << " score : " << score);
0109 continue;
0110 }
0111
0112
0113 for (std::size_t detectorId = 0; detectorId < m_cfg.detectorConfigs.size();
0114 detectorId++) {
0115 const auto& detector = m_cfg.detectorConfigs.at(detectorId);
0116
0117 const auto& trackFeatures = trackFeaturesVector[detectorId];
0118
0119 ACTS_VERBOSE("---> Found summary information");
0120 ACTS_VERBOSE("---> Detector ID: " << detectorId);
0121 ACTS_VERBOSE("---> Number of hits: " << trackFeatures.nHits);
0122 ACTS_VERBOSE("---> Number of holes: " << trackFeatures.nHoles);
0123 ACTS_VERBOSE("---> Number of outliers: " << trackFeatures.nOutliers);
0124
0125
0126 if (etaBasedCuts(detector, trackFeatures, eta)) {
0127 score = 0;
0128 ACTS_DEBUG("Track: " << iTrack
0129 << " has score = 0, due to detector cuts");
0130 break;
0131 }
0132 }
0133
0134 if (score == 0) {
0135 iTrack++;
0136 trackScore.push_back(score);
0137 ACTS_DEBUG("Track: " << iTrack << " score : " << score);
0138 continue;
0139 }
0140
0141
0142
0143 ACTS_VERBOSE("Using Simple Scoring function");
0144
0145 score = 100;
0146
0147
0148
0149 for (std::size_t detectorId = 0; detectorId < m_cfg.detectorConfigs.size();
0150 detectorId++) {
0151 const auto& detector = m_cfg.detectorConfigs.at(detectorId);
0152 const auto& trackFeatures = trackFeaturesVector[detectorId];
0153
0154 score += trackFeatures.nHits * detector.hitsScoreWeight;
0155 score += trackFeatures.nHoles * detector.holesScoreWeight;
0156 score += trackFeatures.nOutliers * detector.outliersScoreWeight;
0157 score += trackFeatures.nSharedHits * detector.otherScoreWeight;
0158 }
0159
0160
0161 for (const auto& weightFunction : optionals.weights) {
0162 weightFunction(track, score);
0163 }
0164
0165
0166 if (track.chi2() > 0 && track.nDoF() > 0) {
0167 double p = 1. / std::log10(10. + 10. * track.chi2() / track.nDoF());
0168 if (p > 0) {
0169 score += p;
0170 } else {
0171 score -= 50;
0172 }
0173 }
0174
0175 iTrack++;
0176
0177
0178 trackScore.push_back(score);
0179 ACTS_VERBOSE("Track: " << iTrack << " score: " << score);
0180
0181 }
0182
0183 return trackScore;
0184 }
0185
0186 template <TrackContainerFrontend track_container_t>
0187 std::vector<double> Acts::ScoreBasedAmbiguityResolution::ambiguityScore(
0188 const track_container_t& tracks,
0189 const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
0190 const Optionals<typename track_container_t::ConstTrackProxy>& optionals)
0191 const {
0192 std::vector<double> trackScore;
0193 trackScore.reserve(tracks.size());
0194
0195 ACTS_VERBOSE("Using Ambiguity Scoring function");
0196
0197 int iTrack = 0;
0198
0199 ACTS_VERBOSE("Number of detectors: " << m_cfg.detectorConfigs.size());
0200
0201 ACTS_INFO("Starting to score tracks");
0202
0203
0204 for (const auto& track : tracks) {
0205
0206 const auto& trackFeaturesVector = trackFeaturesVectors[iTrack];
0207 double score = 1;
0208 auto pT = Acts::VectorHelpers::perp(track.momentum());
0209 auto eta = Acts::VectorHelpers::eta(track.momentum());
0210
0211
0212 for (const auto& cutFunction : optionals.cuts) {
0213 if (cutFunction(track)) {
0214 score = 0;
0215 ACTS_DEBUG("Track: " << iTrack
0216 << " has score = 0, due to optional cuts.");
0217 break;
0218 }
0219 }
0220
0221 if (score == 0) {
0222 iTrack++;
0223 trackScore.push_back(score);
0224 ACTS_DEBUG("Track: " << iTrack << " score : " << score);
0225 continue;
0226 }
0227
0228
0229 for (std::size_t detectorId = 0; detectorId < m_cfg.detectorConfigs.size();
0230 detectorId++) {
0231 const auto& detector = m_cfg.detectorConfigs.at(detectorId);
0232
0233 const auto& trackFeatures = trackFeaturesVector[detectorId];
0234
0235 ACTS_VERBOSE("---> Found summary information");
0236 ACTS_VERBOSE("---> Detector ID: " << detectorId);
0237 ACTS_VERBOSE("---> Number of hits: " << trackFeatures.nHits);
0238 ACTS_VERBOSE("---> Number of holes: " << trackFeatures.nHoles);
0239 ACTS_VERBOSE("---> Number of outliers: " << trackFeatures.nOutliers);
0240
0241
0242 if (etaBasedCuts(detector, trackFeatures, eta)) {
0243 score = 0;
0244 ACTS_DEBUG("Track: " << iTrack
0245 << " has score = 0, due to detector cuts");
0246 break;
0247 }
0248 }
0249
0250 if (score == 0) {
0251 iTrack++;
0252 trackScore.push_back(score);
0253 ACTS_DEBUG("Track: " << iTrack << " score : " << score);
0254 continue;
0255 }
0256
0257
0258
0259
0260 score = std::log10(pT / UnitConstants::MeV) - 1.;
0261
0262 ACTS_DEBUG("Modifier for pT = " << pT << " GeV is : " << score
0263 << " New score now: " << score);
0264
0265 for (std::size_t detectorId = 0; detectorId < m_cfg.detectorConfigs.size();
0266 detectorId++) {
0267 const auto& detector = m_cfg.detectorConfigs.at(detectorId);
0268
0269 const auto& trackFeatures = trackFeaturesVector[detectorId];
0270
0271
0272
0273 std::size_t nHits = trackFeatures.nHits;
0274 if (nHits > detector.maxHits) {
0275 score = score * (nHits - detector.maxHits + 1);
0276 nHits = detector.maxHits;
0277 }
0278 score = score * detector.factorHits[nHits];
0279 ACTS_DEBUG("Modifier for " << nHits
0280 << " hits: " << detector.factorHits[nHits]
0281 << " New score now: " << score);
0282
0283
0284
0285 std::size_t iHoles = trackFeatures.nHoles;
0286 if (iHoles > detector.maxHoles) {
0287 score /= (iHoles - detector.maxHoles + 1);
0288 iHoles = detector.maxHoles;
0289 }
0290 score = score * detector.factorHoles[iHoles];
0291 ACTS_DEBUG("Modifier for " << iHoles
0292 << " holes: " << detector.factorHoles[iHoles]
0293 << " New score now: " << score);
0294 }
0295
0296 for (const auto& scoreFunction : optionals.scores) {
0297 scoreFunction(track, score);
0298 }
0299
0300 if (track.chi2() > 0 && track.nDoF() > 0) {
0301 double chi2 = track.chi2();
0302 int indf = track.nDoF();
0303 double fac = 1. / std::log10(10. + 10. * chi2 / indf);
0304 score = score * fac;
0305 ACTS_DEBUG("Modifier for chi2 = " << chi2 << " and NDF = " << indf
0306 << " is : " << fac
0307 << " New score now: " << score);
0308 }
0309
0310 iTrack++;
0311
0312
0313 trackScore.push_back(score);
0314 ACTS_VERBOSE("Track: " << iTrack << " score: " << score);
0315
0316 }
0317
0318 return trackScore;
0319 }
0320
0321 template <TrackContainerFrontend track_container_t, typename source_link_hash_t,
0322 typename source_link_equality_t>
0323 std::vector<int> Acts::ScoreBasedAmbiguityResolution::solveAmbiguity(
0324 const track_container_t& tracks, source_link_hash_t sourceLinkHash,
0325 source_link_equality_t sourceLinkEquality,
0326 const Optionals<typename track_container_t::ConstTrackProxy>& optionals)
0327 const {
0328 ACTS_INFO("Number of tracks before Ambiguty Resolution: " << tracks.size());
0329
0330
0331
0332 const std::vector<std::vector<TrackFeatures>> trackFeaturesVectors =
0333 computeInitialState<track_container_t>(tracks);
0334
0335 std::vector<double> trackScore;
0336 trackScore.reserve(tracks.size());
0337 if (m_cfg.useAmbiguityScoring) {
0338 trackScore = ambiguityScore(tracks, trackFeaturesVectors, optionals);
0339 } else {
0340 trackScore = simpleScore(tracks, trackFeaturesVectors, optionals);
0341 }
0342
0343 auto MeasurementIndexMap =
0344 std::unordered_map<SourceLink, std::size_t, source_link_hash_t,
0345 source_link_equality_t>(0, sourceLinkHash,
0346 sourceLinkEquality);
0347
0348 std::vector<std::vector<std::size_t>> measurementsPerTrackVector;
0349 std::map<std::size_t, std::size_t> nTracksPerMeasurement;
0350
0351
0352
0353
0354
0355 for (const auto& track : tracks) {
0356 std::vector<std::size_t> measurementsPerTrack;
0357 for (const auto& ts : track.trackStatesReversed()) {
0358 if (!ts.typeFlags().test(Acts::TrackStateFlag::OutlierFlag) &&
0359 !ts.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
0360 continue;
0361 }
0362 Acts::SourceLink sourceLink = ts.getUncalibratedSourceLink();
0363
0364 auto emplace = MeasurementIndexMap.try_emplace(
0365 sourceLink, MeasurementIndexMap.size());
0366 std::size_t iMeasurement = emplace.first->second;
0367 measurementsPerTrack.push_back(iMeasurement);
0368 if (nTracksPerMeasurement.find(iMeasurement) ==
0369 nTracksPerMeasurement.end()) {
0370 nTracksPerMeasurement[iMeasurement] = 0;
0371 }
0372 nTracksPerMeasurement[iMeasurement]++;
0373 }
0374 measurementsPerTrackVector.push_back(std::move(measurementsPerTrack));
0375 }
0376
0377 std::vector<int> goodTracks;
0378 int cleanTrackIndex = 0;
0379
0380 auto optionalHitSelections = optionals.hitSelections;
0381
0382
0383
0384
0385 for (std::size_t iTrack = 0; const auto& track : tracks) {
0386
0387 if (getCleanedOutTracks(track, trackScore[iTrack],
0388 measurementsPerTrackVector[iTrack],
0389 nTracksPerMeasurement, optionalHitSelections)) {
0390 cleanTrackIndex++;
0391 if (trackScore[iTrack] > m_cfg.minScore) {
0392 goodTracks.push_back(track.index());
0393 }
0394 }
0395 iTrack++;
0396 }
0397 ACTS_INFO("Number of clean tracks: " << cleanTrackIndex);
0398 ACTS_VERBOSE("Min score: " << m_cfg.minScore);
0399 ACTS_INFO("Number of Good tracks: " << goodTracks.size());
0400 return goodTracks;
0401 }
0402
0403 template <TrackProxyConcept track_proxy_t>
0404 bool Acts::ScoreBasedAmbiguityResolution::getCleanedOutTracks(
0405 const track_proxy_t& track, const double& trackScore,
0406 const std::vector<std::size_t>& measurementsPerTrack,
0407 const std::map<std::size_t, std::size_t>& nTracksPerMeasurement,
0408 const std::vector<
0409 std::function<void(const track_proxy_t&,
0410 const typename track_proxy_t::ConstTrackStateProxy&,
0411 TrackStateTypes&)>>& optionalHitSelections) const {
0412
0413
0414 std::vector<TrackStateTypes> trackStateTypes;
0415
0416
0417 for (std::size_t index = 0; const auto& ts : track.trackStatesReversed()) {
0418 if (ts.typeFlags().test(Acts::TrackStateFlag::OutlierFlag) ||
0419 ts.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
0420 std::size_t iMeasurement = measurementsPerTrack[index];
0421 auto it = nTracksPerMeasurement.find(iMeasurement);
0422 if (it == nTracksPerMeasurement.end()) {
0423 trackStateTypes.push_back(TrackStateTypes::OtherTrackStateType);
0424 index++;
0425 continue;
0426 }
0427
0428 std::size_t nTracksShared = it->second;
0429 auto isoutliner = ts.typeFlags().test(Acts::TrackStateFlag::OutlierFlag);
0430
0431 if (isoutliner) {
0432 ACTS_VERBOSE("Measurement is outlier on a fitter track, copy it over");
0433 trackStateTypes.push_back(TrackStateTypes::Outlier);
0434 continue;
0435 }
0436 if (nTracksShared == 1) {
0437 ACTS_VERBOSE("Measurement is not shared, copy it over");
0438
0439 trackStateTypes.push_back(TrackStateTypes::UnsharedHit);
0440 continue;
0441 } else if (nTracksShared > 1) {
0442 ACTS_VERBOSE("Measurement is shared, copy it over");
0443 trackStateTypes.push_back(TrackStateTypes::SharedHit);
0444 continue;
0445 }
0446 }
0447 }
0448 std::vector<std::size_t> newMeasurementsPerTrack;
0449 std::size_t measurement = 0;
0450 std::size_t nshared = 0;
0451
0452
0453
0454
0455 for (std::size_t index = 0; auto ts : track.trackStatesReversed()) {
0456 if (ts.typeFlags().test(Acts::TrackStateFlag::OutlierFlag) ||
0457 ts.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
0458 if (!ts.hasReferenceSurface()) {
0459 ACTS_DEBUG("Track state has no reference surface");
0460 continue;
0461 }
0462
0463 std::size_t ivolume = ts.referenceSurface().geometryId().volume();
0464 auto volume_it = m_cfg.volumeMap.find(ivolume);
0465 if (volume_it == m_cfg.volumeMap.end()) {
0466 ACTS_ERROR("Volume " << ivolume << " not found in the volume map");
0467 continue;
0468 }
0469
0470 std::size_t detectorID = volume_it->second;
0471
0472 const auto& detector = m_cfg.detectorConfigs.at(detectorID);
0473
0474 measurement = measurementsPerTrack[index];
0475
0476 auto it = nTracksPerMeasurement.find(measurement);
0477 if (it == nTracksPerMeasurement.end()) {
0478 index++;
0479 continue;
0480 }
0481 auto nTracksShared = it->second;
0482
0483
0484
0485
0486 for (const auto& hitSelection : optionalHitSelections) {
0487 hitSelection(track, ts, trackStateTypes[index]);
0488 }
0489
0490 if (trackStateTypes[index] == TrackStateTypes::RejectedHit) {
0491 ACTS_DEBUG("Dropping rejected hit");
0492 } else if (trackStateTypes[index] != TrackStateTypes::SharedHit) {
0493 ACTS_DEBUG("Good TSOS, copy hit");
0494 newMeasurementsPerTrack.push_back(measurement);
0495
0496
0497
0498 } else if (nshared >= m_cfg.maxShared) {
0499 ACTS_DEBUG("Too many shared hit, drop it");
0500 }
0501
0502
0503 else {
0504 ACTS_DEBUG("Try to recover shared hit ");
0505 if (nTracksShared <= m_cfg.maxSharedTracksPerMeasurement &&
0506 trackScore > m_cfg.minScoreSharedTracks &&
0507 !detector.sharedHitsFlag) {
0508 ACTS_DEBUG("Accepted hit shared with " << nTracksShared << " tracks");
0509 newMeasurementsPerTrack.push_back(measurement);
0510 nshared++;
0511 } else {
0512 ACTS_DEBUG("Rejected hit shared with " << nTracksShared << " tracks");
0513 }
0514 }
0515 index++;
0516 }
0517 }
0518
0519 if (newMeasurementsPerTrack.size() < m_cfg.minUnshared) {
0520 return false;
0521 } else {
0522 return true;
0523 }
0524 }
0525
0526 }