Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:27:32

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2024 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 #include "Acts/Definitions/Units.hpp"
0012 #include "Acts/EventData/TrackContainer.hpp"
0013 #include "Acts/Utilities/Delegate.hpp"
0014 #include "Acts/Utilities/Logger.hpp"
0015 
0016 #include <cstddef>
0017 #include <map>
0018 #include <memory>
0019 #include <string>
0020 #include <tuple>
0021 #include <vector>
0022 
0023 #include <boost/container/flat_map.hpp>
0024 #include <boost/container/flat_set.hpp>
0025 
0026 namespace Acts {
0027 
0028 /// Generic implementation of the score based ambiguity resolution.
0029 /// The alhorithm is based on the following steps:
0030 /// 1) Compute the initial state of the tracks
0031 /// 2) Compute the score of each track
0032 /// 3) Removes hits that are not good enough for each track
0033 /// 4) Remove tracks that have a score below a certain threshold or not have
0034 /// enough hits
0035 /// 5) Remove tracks that are not good enough based on cuts Contains method for
0036 /// data preparations
0037 class ScoreBasedAmbiguityResolution {
0038  public:
0039   /// @brief Detector configuration struct : contains the configuration for each detector
0040   ///
0041   /// The configuration can be saved in a json file and loaded from there.
0042   ///
0043   struct DetectorConfig {
0044     int hitsScoreWeight = 0;
0045     int holesScoreWeight = 0;
0046     int outliersScoreWeight = 0;
0047     int otherScoreWeight = 0;
0048 
0049     std::size_t minHits = 0;
0050     std::size_t maxHits = 0;
0051     std::size_t maxHoles = 0;
0052     std::size_t maxOutliers = 0;
0053     std::size_t maxSharedHits = 0;
0054 
0055     /// if true, the shared hits are considered as bad hits for this detector
0056     bool sharedHitsFlag = false;
0057 
0058     std::size_t detectorId = 0;
0059 
0060     /// a list of values from  0 to 1, the higher number of hits, higher value
0061     /// in the list is multiplied to ambuiguity score applied only if
0062     /// useAmbiguityFunction is true
0063     std::vector<double> factorHits = {1.0};
0064 
0065     /// a list of values from  0 to 1, the higher number of holes, lower value
0066     /// in the list is multiplied to ambuiguity score applied only if
0067     /// useAmbiguityFunction is true
0068     std::vector<double> factorHoles = {1.0};
0069   };
0070 
0071   /// @brief  TrackFeatures struct : contains the features that are counted for each track.
0072   ///
0073   /// The trackFeatures is used to compute the score of each track
0074   struct TrackFeatures {
0075     std::size_t nHits = 0;
0076     std::size_t nHoles = 0;
0077     std::size_t nOutliers = 0;
0078     std::size_t nSharedHits = 0;
0079   };
0080 
0081   /// @brief MeasurementInfo : contains the measurement ID and the detector ID
0082   struct MeasurementInfo {
0083     std::size_t iMeasurement = 0;
0084     std::size_t detectorId = 0;
0085     bool isOutlier = false;
0086   };
0087 
0088   /// @brief Configuration struct : contains the configuration for the ambiguity resolution.
0089   struct Config {
0090     std::map<std::size_t, std::size_t> volumeMap = {{0, 0}};
0091     std::vector<DetectorConfig> detectorConfigs;
0092     /// minimum score for any track
0093     double minScore = 0;
0094     /// minimum score for shared tracks
0095     double minScoreSharedTracks = 0;
0096     /// maximum number of shared tracks per measurement
0097     std::size_t maxSharedTracksPerMeasurement = 10;
0098     /// maximum number of shared hit per track
0099     std::size_t maxShared = 5;
0100 
0101     double pTMin = 0 * UnitConstants::GeV;
0102     double pTMax = 1e5 * UnitConstants::GeV;
0103 
0104     double phiMin = -M_PI * UnitConstants::rad;
0105     double phiMax = M_PI * UnitConstants::rad;
0106 
0107     double etaMin = -5;
0108     double etaMax = 5;
0109 
0110     // if true, the ambiguity score is computed based on a different function.
0111     bool useAmbiguityFunction = false;
0112   };
0113 
0114   /// @brief OptionalCuts struct : contains the optional cuts to be applied.
0115   ///
0116   /// The optional cuts,weights and score are used to remove tracks that are not
0117   /// good enough, based on some criteria. Users are free to add their own cuts
0118   /// with the help of this struct.
0119   template <typename track_container_t, typename traj_t,
0120             template <typename> class holder_t, bool ReadOnly>
0121   struct OptionalCuts {
0122     using OptionalFilter =
0123         std::function<bool(const Acts::TrackProxy<track_container_t, traj_t,
0124                                                   holder_t, ReadOnly>&)>;
0125 
0126     using OptionalScoreModifier = std::function<void(
0127         const Acts::TrackProxy<track_container_t, traj_t, holder_t, ReadOnly>&,
0128         double&)>;
0129     std::vector<OptionalFilter> cuts = {};
0130     std::vector<OptionalScoreModifier> weights = {};
0131 
0132     /// applied only if useAmbiguityFunction is true
0133     std::vector<OptionalScoreModifier> scores = {};
0134   };
0135 
0136   ScoreBasedAmbiguityResolution(
0137       const Config& cfg,
0138       std::unique_ptr<const Logger> logger =
0139           getDefaultLogger("ScoreBasedAmbiguityResolution", Logging::INFO))
0140       : m_cfg{cfg}, m_logger{std::move(logger)} {}
0141 
0142   /// Compute the initial state of the tracks.
0143   ///
0144   /// @param tracks is the input track container
0145   /// @param sourceLinkHash is the  source links
0146   /// @param sourceLinkEquality is the equality function for the source links
0147   /// @param trackFeaturesVectors is the trackFeatures map from detector ID to trackFeatures
0148   /// @return a vector of the initial state of the tracks
0149   template <typename track_container_t, typename traj_t,
0150             template <typename> class holder_t, typename source_link_hash_t,
0151             typename source_link_equality_t>
0152   std::vector<std::vector<MeasurementInfo>> computeInitialState(
0153       const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
0154       source_link_hash_t sourceLinkHash,
0155       source_link_equality_t sourceLinkEquality,
0156       std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors) const;
0157 
0158   /// Compute the score of each track.
0159   ///
0160   /// @param tracks is the input track container
0161   /// @param trackFeaturesVectors is the trackFeatures map from detector ID to trackFeatures
0162   /// @param optionalCuts is the user defined optional cuts to be applied.
0163   /// @return a vector of scores for each track
0164   template <typename track_container_t, typename traj_t,
0165             template <typename> class holder_t, bool ReadOnly>
0166   std::vector<double> simpleScore(
0167       const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
0168       const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
0169       const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
0170           optionalCuts = {}) const;
0171 
0172   /// Compute the score of each track based on the ambiguity function.
0173   ///
0174   /// @param tracks is the input track container
0175   /// @param trackFeaturesVectors is the trackFeatures map from detector ID to trackFeatures
0176   /// @param optionalCuts is the user defined optional cuts to be applied.
0177   /// @return a vector of scores for each track
0178   template <typename track_container_t, typename traj_t,
0179             template <typename> class holder_t, bool ReadOnly>
0180   std::vector<double> ambiguityScore(
0181       const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
0182       const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
0183       const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
0184           optionalCuts = {}) const;
0185 
0186   /// Remove hits that are not good enough for each track and removes tracks
0187   /// that have a score below a certain threshold or not enough hits.
0188   ///
0189   /// @brief Remove tracks that are not good enough based on cuts
0190   /// @param trackScore is the score of each track
0191   /// @param trackFeaturesVectors is the trackFeatures map for each track
0192   /// @param measurementsPerTrack is the list of measurements for each track
0193   /// @return a vector of IDs of the tracks we want to keep
0194   std::vector<bool> getCleanedOutTracks(
0195       const std::vector<double>& trackScore,
0196       const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
0197       const std::vector<std::vector<MeasurementInfo>>& measurementsPerTrack)
0198       const;
0199 
0200   /// Remove tracks that are bad based on cuts and weighted scores.
0201   ///
0202   /// @brief Remove tracks that are not good enough
0203   /// @param tracks is the input track container
0204   /// @param measurementsPerTrack is the list of measurements for each track
0205   /// @param trackFeaturesVectors is the map of detector id to trackFeatures for each track
0206   /// @param optionalCuts is the optional cuts to be applied
0207   /// @return a vector of IDs of the tracks we want to keep
0208   template <typename track_container_t, typename traj_t,
0209             template <typename> class holder_t, bool ReadOnly>
0210   std::vector<int> solveAmbiguity(
0211       const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
0212       const std::vector<std::vector<MeasurementInfo>>& measurementsPerTrack,
0213       const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
0214       const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
0215           optionalCuts = {}) const;
0216 
0217  private:
0218   Config m_cfg;
0219 
0220   /// Logging instance
0221   std::unique_ptr<const Logger> m_logger = nullptr;
0222 
0223   /// Private access to logging instance
0224   const Logger& logger() const;
0225 };
0226 
0227 }  // namespace Acts
0228 
0229 #include "Acts/AmbiguityResolution/ScoreBasedAmbiguityResolution.ipp"