Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-27 07:26:17

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
#include "Acts/AmbiguityResolution/AmbiguityNetworkConcept.hpp"
#include "Acts/Utilities/Logger.hpp"

#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
0019 
0020 namespace Acts {
0021 
0022 /// Generic implementation of the machine learning ambiguity resolution
0023 /// Contains method for data preparations
0024 template <AmbiguityNetworkConcept AmbiguityNetwork>
0025 class AmbiguityResolutionML {
0026  public:
0027   /// @brief Configuration for the ambiguity resolution algorithm.
0028   struct Config {
0029     /// Path to the model file for the duplicate neural network
0030     std::string inputDuplicateNN = "";
0031     /// Minimum number of measurement to form a track.
0032     std::size_t nMeasurementsMin = 7;
0033   };
0034   /// Construct the ambiguity resolution algorithm.
0035   ///
0036   /// @param cfg is the algorithm configuration
0037   /// @param logger is the logging instance
0038   explicit AmbiguityResolutionML(const Config& cfg,
0039                                  std::unique_ptr<const Logger> logger =
0040                                      getDefaultLogger("AmbiguityResolutionML",
0041                                                       Logging::INFO))
0042       : m_cfg{cfg},
0043         m_duplicateClassifier(m_cfg.inputDuplicateNN.c_str()),
0044         m_logger{std::move(logger)} {}
0045 
0046   /// Associate the hits to the tracks
0047   ///
0048   /// This algorithm performs the mapping of hits ID to track ID. Our final goal
0049   /// is too loop over all the tracks (and their associated hits) by order of
0050   /// decreasing number hits for this we use a multimap where the key is the
0051   /// number of hits as this will automatically perform the sorting.
0052   ///
0053   /// @param tracks is the input track container
0054   /// @param sourceLinkHash is the hash function for the source link, will be used to associate to tracks
0055   /// @param sourceLinkEquality is the equality function for the source link used used to associated hits to tracks
0056   /// @return an ordered list containing pairs of track ID and associated measurement ID
0057   template <TrackContainerFrontend track_container_t,
0058             typename source_link_hash_t, typename source_link_equality_t>
0059   std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
0060   mapTrackHits(const track_container_t& tracks,
0061                const source_link_hash_t& sourceLinkHash,
0062                const source_link_equality_t& sourceLinkEquality) const {
0063     // A map to store (and generate) the measurement index for each source link
0064     auto measurementIndexMap =
0065         std::unordered_map<SourceLink, std::size_t, source_link_hash_t,
0066                            source_link_equality_t>(0, sourceLinkHash,
0067                                                    sourceLinkEquality);
0068 
0069     // A map to store the track Id and their associated measurements ID, a
0070     // multimap is used to automatically sort the tracks by the number of
0071     // measurements
0072     std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
0073         trackMap;
0074     std::size_t trackIndex = 0;
0075     std::vector<std::size_t> measurements;
0076     // Loop over all the trajectories in the events
0077     for (const auto& track : tracks) {
0078       // Kick out tracks that do not fulfill our initial requirements
0079       if (track.nMeasurements() < m_cfg.nMeasurementsMin) {
0080         continue;
0081       }
0082       measurements.clear();
0083       for (auto ts : track.trackStatesReversed()) {
0084         if (ts.typeFlags().isMeasurement()) {
0085           SourceLink sourceLink = ts.getUncalibratedSourceLink();
0086           // assign a new measurement index if the source link was not seen yet
0087           auto emplace = measurementIndexMap.try_emplace(
0088               sourceLink, measurementIndexMap.size());
0089           measurements.push_back(emplace.first->second);
0090         }
0091       }
0092       trackMap.emplace(track.nMeasurements(),
0093                        std::make_pair(trackIndex, measurements));
0094       ++trackIndex;
0095     }
0096     return trackMap;
0097   }
0098 
0099   /// Select the track associated with each cluster
0100   ///
0101   /// In this algorithm the call the neural network to score the tracks and then
0102   /// select the track with the highest score in each cluster
0103   ///
0104   /// @param clusters is a map of clusters, each cluster correspond to a vector of track ID
0105   /// @param tracks is the input track container
0106   /// @return a vector of trackID corresponding tho the good tracks
0107   template <TrackContainerFrontend track_container_t>
0108   std::vector<std::size_t> solveAmbiguity(
0109       std::unordered_map<std::size_t, std::vector<std::size_t>>& clusters,
0110       const track_container_t& tracks) const {
0111     std::vector<std::vector<float>> outputTensor =
0112         m_duplicateClassifier.inferScores(clusters, tracks);
0113     std::vector<std::size_t> goodTracks =
0114         m_duplicateClassifier.trackSelection(clusters, outputTensor);
0115 
0116     return goodTracks;
0117   }
0118 
0119  private:
0120   // Configuration
0121   Config m_cfg;
0122 
0123   // The neural network for duplicate classification, the network
0124   // implementation is chosen with the AmbiguityNetwork template parameter
0125   AmbiguityNetwork m_duplicateClassifier;
0126 
0127   /// Logging instance
0128   std::unique_ptr<const Logger> m_logger = nullptr;
0129 
0130   /// Private access to logging instance
0131   const Logger& logger() const { return *m_logger; }
0132 };
0133 
0134 }  // namespace Acts