Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-11 07:49:41

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
#include "Acts/AmbiguityResolution/AmbiguityNetworkConcept.hpp"
#include "Acts/Definitions/Units.hpp"
#include "Acts/EventData/TrackContainer.hpp"
#include "Acts/Utilities/Delegate.hpp"
#include "Acts/Utilities/Logger.hpp"

#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
0023 
0024 namespace Acts {
0025 
0026 /// Generic implementation of the machine learning ambiguity resolution
0027 /// Contains method for data preparations
0028 template <AmbiguityNetworkConcept AmbiguityNetwork>
0029 class AmbiguityResolutionML {
0030  public:
0031   struct Config {
0032     /// Path to the model file for the duplicate neural network
0033     std::string inputDuplicateNN = "";
0034     /// Minimum number of measurement to form a track.
0035     std::size_t nMeasurementsMin = 7;
0036   };
0037   /// Construct the ambiguity resolution algorithm.
0038   ///
0039   /// @param cfg is the algorithm configuration
0040   /// @param logger is the logging instance
0041   explicit AmbiguityResolutionML(const Config& cfg,
0042                                  std::unique_ptr<const Logger> logger =
0043                                      getDefaultLogger("AmbiguityResolutionML",
0044                                                       Logging::INFO))
0045       : m_cfg{cfg},
0046         m_duplicateClassifier(m_cfg.inputDuplicateNN.c_str()),
0047         m_logger{std::move(logger)} {}
0048 
0049   /// Associate the hits to the tracks
0050   ///
0051   /// This algorithm performs the mapping of hits ID to track ID. Our final goal
0052   /// is too loop over all the tracks (and their associated hits) by order of
0053   /// decreasing number hits for this we use a multimap where the key is the
0054   /// number of hits as this will automatically perform the sorting.
0055   ///
0056   /// @param tracks is the input track container
0057   /// @param sourceLinkHash is the hash function for the source link, will be used to associate to tracks
0058   /// @param sourceLinkEquality is the equality function for the source link used used to associated hits to tracks
0059   /// @return an ordered list containing pairs of track ID and associated measurement ID
0060   template <TrackContainerFrontend track_container_t,
0061             typename source_link_hash_t, typename source_link_equality_t>
0062   std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
0063   mapTrackHits(const track_container_t& tracks,
0064                const source_link_hash_t& sourceLinkHash,
0065                const source_link_equality_t& sourceLinkEquality) const {
0066     // A map to store (and generate) the measurement index for each source link
0067     auto measurementIndexMap =
0068         std::unordered_map<SourceLink, std::size_t, source_link_hash_t,
0069                            source_link_equality_t>(0, sourceLinkHash,
0070                                                    sourceLinkEquality);
0071 
0072     // A map to store the track Id and their associated measurements ID, a
0073     // multimap is used to automatically sort the tracks by the number of
0074     // measurements
0075     std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
0076         trackMap;
0077     std::size_t trackIndex = 0;
0078     std::vector<std::size_t> measurements;
0079     // Loop over all the trajectories in the events
0080     for (const auto& track : tracks) {
0081       // Kick out tracks that do not fulfill our initial requirements
0082       if (track.nMeasurements() < m_cfg.nMeasurementsMin) {
0083         continue;
0084       }
0085       measurements.clear();
0086       for (auto ts : track.trackStatesReversed()) {
0087         if (ts.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
0088           SourceLink sourceLink = ts.getUncalibratedSourceLink();
0089           // assign a new measurement index if the source link was not seen yet
0090           auto emplace = measurementIndexMap.try_emplace(
0091               sourceLink, measurementIndexMap.size());
0092           measurements.push_back(emplace.first->second);
0093         }
0094       }
0095       trackMap.emplace(track.nMeasurements(),
0096                        std::make_pair(trackIndex, measurements));
0097       ++trackIndex;
0098     }
0099     return trackMap;
0100   }
0101 
0102   /// Select the track associated with each cluster
0103   ///
0104   /// In this algorithm the call the neural network to score the tracks and then
0105   /// select the track with the highest score in each cluster
0106   ///
0107   /// @param clusters is a map of clusters, each cluster correspond to a vector of track ID
0108   /// @param tracks is the input track container
0109   /// @return a vector of trackID corresponding tho the good tracks
0110   template <TrackContainerFrontend track_container_t>
0111   std::vector<std::size_t> solveAmbiguity(
0112       std::unordered_map<std::size_t, std::vector<std::size_t>>& clusters,
0113       const track_container_t& tracks) const {
0114     std::vector<std::vector<float>> outputTensor =
0115         m_duplicateClassifier.inferScores(clusters, tracks);
0116     std::vector<std::size_t> goodTracks =
0117         m_duplicateClassifier.trackSelection(clusters, outputTensor);
0118 
0119     return goodTracks;
0120   }
0121 
0122  private:
0123   // Configuration
0124   Config m_cfg;
0125 
0126   // The neural network for duplicate classification, the network
0127   // implementation is chosen with the AmbiguityNetwork template parameter
0128   AmbiguityNetwork m_duplicateClassifier;
0129 
0130   /// Logging instance
0131   std::unique_ptr<const Logger> m_logger = nullptr;
0132 
0133   /// Private access to logging instance
0134   const Logger& logger() const { return *m_logger; }
0135 };
0136 
0137 }  // namespace Acts