Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:10:44

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
#include "Acts/AmbiguityResolution/AmbiguityNetworkConcept.hpp"
#include "Acts/Definitions/Units.hpp"
#include "Acts/EventData/TrackContainer.hpp"
#include "Acts/Utilities/Delegate.hpp"
#include "Acts/Utilities/Logger.hpp"

#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
0023 
0024 namespace Acts {
0025 
0026 /// Generic implementation of the machine learning ambiguity resolution
0027 /// Contains method for data preparations
0028 template <AmbiguityNetworkConcept AmbiguityNetwork>
0029 class AmbiguityResolutionML {
0030  public:
0031   struct Config {
0032     /// Path to the model file for the duplicate neural network
0033     std::string inputDuplicateNN = "";
0034     /// Minimum number of measurement to form a track.
0035     std::size_t nMeasurementsMin = 7;
0036   };
0037   /// Construct the ambiguity resolution algorithm.
0038   ///
0039   /// @param cfg is the algorithm configuration
0040   /// @param logger is the logging instance
0041   AmbiguityResolutionML(const Config& cfg,
0042                         std::unique_ptr<const Logger> logger = getDefaultLogger(
0043                             "AmbiguityResolutionML", Logging::INFO))
0044       : m_cfg{cfg},
0045         m_duplicateClassifier(m_cfg.inputDuplicateNN.c_str()),
0046         m_logger{std::move(logger)} {}
0047 
0048   /// Associate the hits to the tracks
0049   ///
0050   /// This algorithm performs the mapping of hits ID to track ID. Our final goal
0051   /// is too loop over all the tracks (and their associated hits) by order of
0052   /// decreasing number hits for this we use a multimap where the key is the
0053   /// number of hits as this will automatically perform the sorting.
0054   ///
0055   /// @param tracks is the input track container
0056   /// @param sourceLinkHash is the hash function for the source link, will be used to associate to tracks
0057   /// @param sourceLinkEquality is the equality function for the source link used used to associated hits to tracks
0058   /// @return an ordered list containing pairs of track ID and associated measurement ID
0059   template <TrackContainerFrontend track_container_t,
0060             typename source_link_hash_t, typename source_link_equality_t>
0061   std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
0062   mapTrackHits(const track_container_t& tracks,
0063                const source_link_hash_t& sourceLinkHash,
0064                const source_link_equality_t& sourceLinkEquality) const {
0065     // A map to store (and generate) the measurement index for each source link
0066     auto measurementIndexMap =
0067         std::unordered_map<SourceLink, std::size_t, source_link_hash_t,
0068                            source_link_equality_t>(0, sourceLinkHash,
0069                                                    sourceLinkEquality);
0070 
0071     // A map to store the track Id and their associated measurements ID, a
0072     // multimap is used to automatically sort the tracks by the number of
0073     // measurements
0074     std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
0075         trackMap;
0076     std::size_t trackIndex = 0;
0077     std::vector<std::size_t> measurements;
0078     // Loop over all the trajectories in the events
0079     for (const auto& track : tracks) {
0080       // Kick out tracks that do not fulfill our initial requirements
0081       if (track.nMeasurements() < m_cfg.nMeasurementsMin) {
0082         continue;
0083       }
0084       measurements.clear();
0085       for (auto ts : track.trackStatesReversed()) {
0086         if (ts.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
0087           SourceLink sourceLink = ts.getUncalibratedSourceLink();
0088           // assign a new measurement index if the source link was not seen yet
0089           auto emplace = measurementIndexMap.try_emplace(
0090               sourceLink, measurementIndexMap.size());
0091           measurements.push_back(emplace.first->second);
0092         }
0093       }
0094       trackMap.emplace(track.nMeasurements(),
0095                        std::make_pair(trackIndex, measurements));
0096       ++trackIndex;
0097     }
0098     return trackMap;
0099   }
0100 
0101   /// Select the track associated with each cluster
0102   ///
0103   /// In this algorithm the call the neural network to score the tracks and then
0104   /// select the track with the highest score in each cluster
0105   ///
0106   /// @param clusters is a map of clusters, each cluster correspond to a vector of track ID
0107   /// @param tracks is the input track container
0108   /// @return a vector of trackID corresponding tho the good tracks
0109   template <TrackContainerFrontend track_container_t>
0110   std::vector<std::size_t> solveAmbiguity(
0111       std::unordered_map<std::size_t, std::vector<std::size_t>>& clusters,
0112       const track_container_t& tracks) const {
0113     std::vector<std::vector<float>> outputTensor =
0114         m_duplicateClassifier.inferScores(clusters, tracks);
0115     std::vector<std::size_t> goodTracks =
0116         m_duplicateClassifier.trackSelection(clusters, outputTensor);
0117 
0118     return goodTracks;
0119   }
0120 
0121  private:
0122   // Configuration
0123   Config m_cfg;
0124 
0125   // The neural network for duplicate classification, the network
0126   // implementation is chosen with the AmbiguityNetwork template parameter
0127   AmbiguityNetwork m_duplicateClassifier;
0128 
0129   /// Logging instance
0130   std::unique_ptr<const Logger> m_logger = nullptr;
0131 
0132   /// Private access to logging instance
0133   const Logger& logger() const { return *m_logger; }
0134 };
0135 
0136 }  // namespace Acts