File indexing completed on 2025-01-18 09:10:44
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/AmbiguityResolution/ScoreBasedAmbiguityResolution.hpp"
0012 #include "Acts/Definitions/Units.hpp"
0013 #include "Acts/EventData/TrackContainerFrontendConcept.hpp"
0014 #include "Acts/Utilities/VectorHelpers.hpp"
0015
0016 #include <unordered_map>
0017
0018 namespace Acts {
0019
0020 inline const Logger& ScoreBasedAmbiguityResolution::logger() const {
0021 return *m_logger;
0022 }
0023
0024 template <TrackContainerFrontend track_container_t>
0025 std::vector<std::vector<ScoreBasedAmbiguityResolution::TrackFeatures>>
0026 ScoreBasedAmbiguityResolution::computeInitialState(
0027 const track_container_t& tracks) const {
0028 ACTS_VERBOSE("Starting to compute initial state");
0029 std::vector<std::vector<TrackFeatures>> trackFeaturesVectors;
0030 trackFeaturesVectors.reserve(tracks.size());
0031
0032 for (const auto& track : tracks) {
0033 int numberOfDetectors = m_cfg.detectorConfigs.size();
0034
0035 std::vector<TrackFeatures> trackFeaturesVector(numberOfDetectors);
0036
0037 for (const auto& ts : track.trackStatesReversed()) {
0038 if (!ts.hasReferenceSurface()) {
0039 ACTS_DEBUG("Track state has no reference surface");
0040 continue;
0041 }
0042 auto iVolume = ts.referenceSurface().geometryId().volume();
0043 auto volume_it = m_cfg.volumeMap.find(iVolume);
0044 if (volume_it == m_cfg.volumeMap.end()) {
0045 ACTS_ERROR("Volume " << iVolume << "not found in the volume map");
0046 continue;
0047 }
0048 auto detectorId = volume_it->second;
0049
0050 if (ts.typeFlags().test(Acts::TrackStateFlag::HoleFlag)) {
0051 ACTS_VERBOSE("Track state type is HoleFlag");
0052 trackFeaturesVector[detectorId].nHoles++;
0053 } else if (ts.typeFlags().test(Acts::TrackStateFlag::OutlierFlag)) {
0054 ACTS_VERBOSE("Track state type is OutlierFlag");
0055 trackFeaturesVector[detectorId].nOutliers++;
0056
0057 } else if (ts.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
0058 ACTS_VERBOSE("Track state type is MeasurementFlag");
0059
0060 if (ts.typeFlags().test(Acts::TrackStateFlag::SharedHitFlag)) {
0061 trackFeaturesVector[detectorId].nSharedHits++;
0062 }
0063 trackFeaturesVector[detectorId].nHits++;
0064 }
0065 }
0066 trackFeaturesVectors.push_back(std::move(trackFeaturesVector));
0067 }
0068
0069 return trackFeaturesVectors;
0070 }
0071
0072 template <TrackContainerFrontend track_container_t>
0073 std::vector<double> Acts::ScoreBasedAmbiguityResolution::simpleScore(
0074 const track_container_t& tracks,
0075 const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
0076 const OptionalCuts<typename track_container_t::ConstTrackProxy>&
0077 optionalCuts) const {
0078 std::vector<double> trackScore;
0079 trackScore.reserve(tracks.size());
0080
0081 int iTrack = 0;
0082
0083 ACTS_VERBOSE("Number of detectors: " << m_cfg.detectorConfigs.size());
0084
0085 ACTS_INFO("Starting to score tracks");
0086
0087
0088 for (const auto& track : tracks) {
0089
0090 const auto& trackFeaturesVector = trackFeaturesVectors[iTrack];
0091 double score = 1;
0092 auto pT = Acts::VectorHelpers::perp(track.momentum());
0093 auto eta = Acts::VectorHelpers::eta(track.momentum());
0094 auto phi = Acts::VectorHelpers::phi(track.momentum());
0095
0096 if (pT < m_cfg.pTMin || pT > m_cfg.pTMax) {
0097 score = 0;
0098 iTrack++;
0099 trackScore.push_back(score);
0100 ACTS_DEBUG("Track: " << iTrack
0101 << " has score = 0, due to pT cuts --- pT = " << pT);
0102 continue;
0103 }
0104
0105
0106 if (phi > m_cfg.phiMax || phi < m_cfg.phiMin) {
0107 score = 0;
0108 iTrack++;
0109 trackScore.push_back(score);
0110 ACTS_DEBUG("Track: " << iTrack
0111 << " has score = 0, due to phi cuts --- phi = "
0112 << phi);
0113 continue;
0114 }
0115
0116
0117 if (eta > m_cfg.etaMax || eta < m_cfg.etaMin) {
0118 score = 0;
0119 iTrack++;
0120 trackScore.push_back(score);
0121 ACTS_DEBUG("Track: " << iTrack
0122 << " has score = 0, due to eta cuts --- eta = "
0123 << eta);
0124 continue;
0125 }
0126
0127
0128 for (const auto& cutFunction : optionalCuts.cuts) {
0129 if (cutFunction(track)) {
0130 score = 0;
0131 ACTS_DEBUG("Track: " << iTrack
0132 << " has score = 0, due to optional cuts.");
0133 break;
0134 }
0135 }
0136
0137 if (score == 0) {
0138 iTrack++;
0139 trackScore.push_back(score);
0140 ACTS_DEBUG("Track: " << iTrack << " score : " << score);
0141 continue;
0142 }
0143
0144 for (std::size_t detectorId = 0; detectorId < m_cfg.detectorConfigs.size();
0145 detectorId++) {
0146 const auto& detector = m_cfg.detectorConfigs.at(detectorId);
0147
0148 const auto& trackFeatures = trackFeaturesVector[detectorId];
0149
0150 ACTS_DEBUG("---> Found summary information");
0151 ACTS_DEBUG("---> Detector ID: " << detectorId);
0152 ACTS_DEBUG("---> Number of hits: " << trackFeatures.nHits);
0153 ACTS_DEBUG("---> Number of holes: " << trackFeatures.nHoles);
0154 ACTS_DEBUG("---> Number of outliers: " << trackFeatures.nOutliers);
0155
0156 if ((trackFeatures.nHits < detector.minHits) ||
0157 (trackFeatures.nHoles > detector.maxHoles) ||
0158 (trackFeatures.nOutliers > detector.maxOutliers)) {
0159 score = 0;
0160 ACTS_DEBUG("Track: " << iTrack
0161 << " has score = 0, due to detector cuts");
0162 break;
0163 }
0164 }
0165
0166 if (score == 0) {
0167 iTrack++;
0168 trackScore.push_back(score);
0169 ACTS_DEBUG("Track: " << iTrack << " score : " << score);
0170 continue;
0171 }
0172
0173
0174
0175
0176
0177 ACTS_VERBOSE("Using Simple Scoring function");
0178
0179 score = 100;
0180
0181
0182
0183 for (std::size_t detectorId = 0; detectorId < m_cfg.detectorConfigs.size();
0184 detectorId++) {
0185 const auto& detector = m_cfg.detectorConfigs.at(detectorId);
0186 const auto& trackFeatures = trackFeaturesVector[detectorId];
0187
0188 score += trackFeatures.nHits * detector.hitsScoreWeight;
0189 score += trackFeatures.nHoles * detector.holesScoreWeight;
0190 score += trackFeatures.nOutliers * detector.outliersScoreWeight;
0191 score += trackFeatures.nSharedHits * detector.otherScoreWeight;
0192 }
0193
0194
0195 for (const auto& weightFunction : optionalCuts.weights) {
0196 weightFunction(track, score);
0197 }
0198
0199
0200 if (track.chi2() > 0 && track.nDoF() > 0) {
0201 double p = 1. / std::log10(10. + 10. * track.chi2() / track.nDoF());
0202 if (p > 0) {
0203 score += p;
0204 } else {
0205 score -= 50;
0206 }
0207 }
0208
0209 iTrack++;
0210
0211
0212 trackScore.push_back(score);
0213 ACTS_VERBOSE("Track: " << iTrack << " score: " << score);
0214
0215 }
0216
0217 return trackScore;
0218 }
0219
0220 template <TrackContainerFrontend track_container_t>
0221 std::vector<double> Acts::ScoreBasedAmbiguityResolution::ambiguityScore(
0222 const track_container_t& tracks,
0223 const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
0224 const OptionalCuts<typename track_container_t::ConstTrackProxy>&
0225 optionalCuts) const {
0226 std::vector<double> trackScore;
0227 trackScore.reserve(tracks.size());
0228
0229 ACTS_VERBOSE("Using Ambiguity Scoring function");
0230
0231 int iTrack = 0;
0232
0233 ACTS_VERBOSE("Number of detectors: " << m_cfg.detectorConfigs.size());
0234
0235 ACTS_INFO("Starting to score tracks");
0236
0237
0238 for (const auto& track : tracks) {
0239
0240 const auto& trackFeaturesVector = trackFeaturesVectors[iTrack];
0241 double score = 1;
0242 auto pT = Acts::VectorHelpers::perp(track.momentum());
0243 auto eta = Acts::VectorHelpers::eta(track.momentum());
0244 auto phi = Acts::VectorHelpers::phi(track.momentum());
0245
0246 if (pT < m_cfg.pTMin || pT > m_cfg.pTMax) {
0247 score = 0;
0248 iTrack++;
0249 trackScore.push_back(score);
0250 ACTS_DEBUG("Track: " << iTrack
0251 << " has score = 0, due to pT cuts --- pT = " << pT);
0252 continue;
0253 }
0254
0255
0256 if (phi > m_cfg.phiMax || phi < m_cfg.phiMin) {
0257 score = 0;
0258 iTrack++;
0259 trackScore.push_back(score);
0260 ACTS_DEBUG("Track: " << iTrack
0261 << " has score = 0, due to phi cuts --- phi = "
0262 << phi);
0263 continue;
0264 }
0265
0266
0267 if (eta > m_cfg.etaMax || eta < m_cfg.etaMin) {
0268 score = 0;
0269 iTrack++;
0270 trackScore.push_back(score);
0271 ACTS_DEBUG("Track: " << iTrack
0272 << " has score = 0, due to eta cuts --- eta = "
0273 << eta);
0274 continue;
0275 }
0276
0277
0278 for (const auto& cutFunction : optionalCuts.cuts) {
0279 if (cutFunction(track)) {
0280 score = 0;
0281 ACTS_DEBUG("Track: " << iTrack
0282 << " has score = 0, due to optional cuts.");
0283 break;
0284 }
0285 }
0286
0287 if (score == 0) {
0288 iTrack++;
0289 trackScore.push_back(score);
0290 ACTS_DEBUG("Track: " << iTrack << " score : " << score);
0291 continue;
0292 }
0293
0294 for (std::size_t detectorId = 0; detectorId < m_cfg.detectorConfigs.size();
0295 detectorId++) {
0296 const auto& detector = m_cfg.detectorConfigs.at(detectorId);
0297
0298 const auto& trackFeatures = trackFeaturesVector[detectorId];
0299
0300 ACTS_DEBUG("---> Found summary information");
0301 ACTS_DEBUG("---> Detector ID: " << detectorId);
0302 ACTS_DEBUG("---> Number of hits: " << trackFeatures.nHits);
0303 ACTS_DEBUG("---> Number of holes: " << trackFeatures.nHoles);
0304 ACTS_DEBUG("---> Number of outliers: " << trackFeatures.nOutliers);
0305
0306 if ((trackFeatures.nHits < detector.minHits) ||
0307 (trackFeatures.nHoles > detector.maxHoles) ||
0308 (trackFeatures.nOutliers > detector.maxOutliers)) {
0309 score = 0;
0310 ACTS_DEBUG("Track: " << iTrack
0311 << " has score = 0, due to detector cuts");
0312 break;
0313 }
0314 }
0315
0316 if (score == 0) {
0317 iTrack++;
0318 trackScore.push_back(score);
0319 ACTS_DEBUG("Track: " << iTrack << " score : " << score);
0320 continue;
0321 }
0322
0323
0324 score = std::log10(pT / UnitConstants::MeV) - 1.;
0325
0326 ACTS_DEBUG("Modifier for pT = " << pT << " GeV is : " << score
0327 << " New score now: " << score);
0328
0329 for (std::size_t detectorId = 0; detectorId < m_cfg.detectorConfigs.size();
0330 detectorId++) {
0331 const auto& detector = m_cfg.detectorConfigs.at(detectorId);
0332
0333 const auto& trackFeatures = trackFeaturesVector[detectorId];
0334
0335
0336
0337 std::size_t nHits = trackFeatures.nHits;
0338 if (nHits > detector.maxHits) {
0339 score = score * (detector.maxHits - nHits + 1);
0340 nHits = detector.maxHits;
0341 }
0342 score = score * detector.factorHits[nHits];
0343 ACTS_DEBUG("Modifier for " << nHits
0344 << " hits: " << detector.factorHits[nHits]
0345 << " New score now: " << score);
0346
0347
0348
0349 std::size_t iHoles = trackFeatures.nHoles;
0350 if (iHoles > detector.maxHoles) {
0351 score /= (iHoles - detector.maxHoles + 1);
0352 iHoles = detector.maxHoles;
0353 }
0354 score = score * detector.factorHoles[iHoles];
0355 ACTS_DEBUG("Modifier for " << iHoles
0356 << " holes: " << detector.factorHoles[iHoles]
0357 << " New score now: " << score);
0358 }
0359
0360 for (const auto& scoreFunction : optionalCuts.scores) {
0361 scoreFunction(track, score);
0362 }
0363
0364 if (track.chi2() > 0 && track.nDoF() > 0) {
0365 double chi2 = track.chi2();
0366 int indf = track.nDoF();
0367 double fac = 1. / std::log10(10. + 10. * chi2 / indf);
0368 score = score * fac;
0369 ACTS_DEBUG("Modifier for chi2 = " << chi2 << " and NDF = " << indf
0370 << " is : " << fac
0371 << " New score now: " << score);
0372 }
0373 iTrack++;
0374
0375
0376 trackScore.push_back(score);
0377 ACTS_VERBOSE("Track: " << iTrack << " score: " << score);
0378
0379 }
0380
0381 return trackScore;
0382 }
0383
0384 template <TrackContainerFrontend track_container_t, typename source_link_hash_t,
0385 typename source_link_equality_t>
0386 std::vector<int> Acts::ScoreBasedAmbiguityResolution::solveAmbiguity(
0387 const track_container_t& tracks, source_link_hash_t sourceLinkHash,
0388 source_link_equality_t sourceLinkEquality,
0389 const OptionalCuts<typename track_container_t::ConstTrackProxy>&
0390 optionalCuts) const {
0391 ACTS_INFO("Number of tracks before Ambiguty Resolution: " << tracks.size());
0392
0393
0394
0395 const std::vector<std::vector<TrackFeatures>> trackFeaturesVectors =
0396 computeInitialState<track_container_t>(tracks);
0397
0398 std::vector<double> trackScore;
0399 trackScore.reserve(tracks.size());
0400 if (m_cfg.useAmbiguityFunction) {
0401 trackScore = ambiguityScore(tracks, trackFeaturesVectors, optionalCuts);
0402 } else {
0403 trackScore = simpleScore(tracks, trackFeaturesVectors, optionalCuts);
0404 }
0405
0406 auto MeasurementIndexMap =
0407 std::unordered_map<SourceLink, std::size_t, source_link_hash_t,
0408 source_link_equality_t>(0, sourceLinkHash,
0409 sourceLinkEquality);
0410
0411 std::vector<std::vector<std::size_t>> measurementsPerTrackVector;
0412 std::map<std::size_t, std::size_t> nTracksPerMeasurement;
0413
0414
0415
0416
0417
0418 for (const auto& track : tracks) {
0419 std::vector<std::size_t> measurementsPerTrack;
0420 for (const auto& ts : track.trackStatesReversed()) {
0421 if (!ts.typeFlags().test(Acts::TrackStateFlag::OutlierFlag) &&
0422 !ts.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
0423 continue;
0424 }
0425 Acts::SourceLink sourceLink = ts.getUncalibratedSourceLink();
0426
0427 auto emplace = MeasurementIndexMap.try_emplace(
0428 sourceLink, MeasurementIndexMap.size());
0429 std::size_t iMeasurement = emplace.first->second;
0430 measurementsPerTrack.push_back(iMeasurement);
0431 if (nTracksPerMeasurement.find(iMeasurement) ==
0432 nTracksPerMeasurement.end()) {
0433 nTracksPerMeasurement[iMeasurement] = 0;
0434 }
0435 nTracksPerMeasurement[iMeasurement]++;
0436 }
0437 measurementsPerTrackVector.push_back(std::move(measurementsPerTrack));
0438 }
0439
0440 std::vector<int> goodTracks;
0441 int cleanTrackIndex = 0;
0442
0443 auto optionalHitSelections = optionalCuts.hitSelections;
0444
0445
0446
0447
0448
0449
0450 for (std::size_t iTrack = 0; const auto& track : tracks) {
0451
0452 auto trackFeaturesVector = trackFeaturesVectors.at(iTrack);
0453 bool trkCouldBeAccepted = true;
0454 for (std::size_t detectorId = 0; detectorId < m_cfg.detectorConfigs.size();
0455 detectorId++) {
0456 auto detector = m_cfg.detectorConfigs.at(detectorId);
0457 if (trackFeaturesVector[detectorId].nSharedHits >
0458 detector.maxSharedHits) {
0459 trkCouldBeAccepted = false;
0460 break;
0461 }
0462 }
0463 if (!trkCouldBeAccepted) {
0464 iTrack++;
0465 continue;
0466 }
0467
0468 trkCouldBeAccepted = getCleanedOutTracks(
0469 track, trackScore[iTrack], measurementsPerTrackVector[iTrack],
0470 nTracksPerMeasurement, optionalHitSelections);
0471 if (trkCouldBeAccepted) {
0472 cleanTrackIndex++;
0473 if (trackScore[iTrack] >= m_cfg.minScore) {
0474 goodTracks.push_back(track.index());
0475 }
0476 }
0477 iTrack++;
0478 }
0479 ACTS_INFO("Number of clean tracks: " << cleanTrackIndex);
0480 ACTS_VERBOSE("Min score: " << m_cfg.minScore);
0481 ACTS_INFO("Number of Good tracks: " << goodTracks.size());
0482 return goodTracks;
0483 }
0484
0485 template <TrackProxyConcept track_proxy_t>
0486 bool Acts::ScoreBasedAmbiguityResolution::getCleanedOutTracks(
0487 const track_proxy_t& track, const double& trackScore,
0488 const std::vector<std::size_t>& measurementsPerTrack,
0489 const std::map<std::size_t, std::size_t>& nTracksPerMeasurement,
0490 const std::vector<
0491 std::function<void(const track_proxy_t&,
0492 const typename track_proxy_t::ConstTrackStateProxy&,
0493 TrackStateTypes&)>>& optionalHitSelections) const {
0494
0495
0496 std::vector<TrackStateTypes> trackStateTypes;
0497
0498
0499 for (std::size_t index = 0; const auto& ts : track.trackStatesReversed()) {
0500 if (ts.typeFlags().test(Acts::TrackStateFlag::OutlierFlag) ||
0501 ts.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
0502 std::size_t iMeasurement = measurementsPerTrack[index];
0503 auto it = nTracksPerMeasurement.find(iMeasurement);
0504 if (it == nTracksPerMeasurement.end()) {
0505 trackStateTypes.push_back(TrackStateTypes::OtherTrackStateType);
0506 index++;
0507 continue;
0508 }
0509
0510 std::size_t nTracksShared = it->second;
0511 auto isoutliner = ts.typeFlags().test(Acts::TrackStateFlag::OutlierFlag);
0512
0513 if (isoutliner) {
0514 ACTS_VERBOSE("Measurement is outlier on a fitter track, copy it over");
0515 trackStateTypes.push_back(TrackStateTypes::Outlier);
0516 continue;
0517 }
0518 if (nTracksShared == 1) {
0519 ACTS_VERBOSE("Measurement is not shared, copy it over");
0520
0521 trackStateTypes.push_back(TrackStateTypes::UnsharedHit);
0522 continue;
0523 } else if (nTracksShared > 1) {
0524 ACTS_VERBOSE("Measurement is shared, copy it over");
0525 trackStateTypes.push_back(TrackStateTypes::SharedHit);
0526 continue;
0527 }
0528 }
0529 }
0530 std::vector<std::size_t> newMeasurementsPerTrack;
0531 std::size_t measurement = 0;
0532 std::size_t nshared = 0;
0533
0534
0535
0536
0537 for (std::size_t index = 0; auto ts : track.trackStatesReversed()) {
0538 if (ts.typeFlags().test(Acts::TrackStateFlag::OutlierFlag) ||
0539 ts.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
0540 if (!ts.hasReferenceSurface()) {
0541 ACTS_DEBUG("Track state has no reference surface");
0542 continue;
0543 }
0544 measurement = measurementsPerTrack[index];
0545
0546 auto it = nTracksPerMeasurement.find(measurement);
0547 if (it == nTracksPerMeasurement.end()) {
0548 index++;
0549 continue;
0550 }
0551 auto nTracksShared = it->second;
0552
0553
0554
0555
0556 for (const auto& hitSelection : optionalHitSelections) {
0557 hitSelection(track, ts, trackStateTypes[index]);
0558 }
0559
0560 if (trackStateTypes[index] == TrackStateTypes::RejectedHit) {
0561 ACTS_DEBUG("Dropping rejected hit");
0562 } else if (trackStateTypes[index] != TrackStateTypes::SharedHit) {
0563 ACTS_DEBUG("Good TSOS, copy hit");
0564 newMeasurementsPerTrack.push_back(measurement);
0565
0566
0567
0568 } else if (nshared >= m_cfg.maxShared) {
0569 ACTS_DEBUG("Too many shared hit, drop it");
0570 }
0571
0572
0573 else {
0574 ACTS_DEBUG("Try to recover shared hit ");
0575 if (nTracksShared <= m_cfg.maxSharedTracksPerMeasurement &&
0576 trackScore > m_cfg.minScoreSharedTracks) {
0577 ACTS_DEBUG("Accepted hit shared with " << nTracksShared << " tracks");
0578 newMeasurementsPerTrack.push_back(measurement);
0579 nshared++;
0580 } else {
0581 ACTS_DEBUG("Rejected hit shared with " << nTracksShared << " tracks");
0582 }
0583 }
0584 index++;
0585 }
0586 }
0587
0588 if (newMeasurementsPerTrack.size() < 3) {
0589 return false;
0590 } else {
0591 return true;
0592 }
0593 }
0594
0595 }