File indexing completed on 2025-01-18 09:11:41
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/TrackFindingML/AmbiguityResolutionMLAlgorithm.hpp"
0010
0011 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0012 #include "ActsExamples/EventData/Measurement.hpp"
0013 #include "ActsExamples/Framework/ProcessCode.hpp"
0014
0015 #include <iterator>
0016 #include <map>
0017
0018 static std::size_t sourceLinkHash(const Acts::SourceLink& a) {
0019 return static_cast<std::size_t>(
0020 a.get<ActsExamples::IndexSourceLink>().index());
0021 }
0022
0023 static bool sourceLinkEquality(const Acts::SourceLink& a,
0024 const Acts::SourceLink& b) {
0025 return a.get<ActsExamples::IndexSourceLink>().index() ==
0026 b.get<ActsExamples::IndexSourceLink>().index();
0027 }
0028
0029 ActsExamples::AmbiguityResolutionMLAlgorithm::AmbiguityResolutionMLAlgorithm(
0030 ActsExamples::AmbiguityResolutionMLAlgorithm::Config cfg,
0031 Acts::Logging::Level lvl)
0032 : ActsExamples::IAlgorithm("AmbiguityResolutionMLAlgorithm", lvl),
0033 m_cfg(std::move(cfg)),
0034 m_ambiML(m_cfg.toAmbiguityResolutionMLConfig(), logger().clone()) {
0035 if (m_cfg.inputTracks.empty()) {
0036 throw std::invalid_argument("Missing trajectories input collection");
0037 }
0038 if (m_cfg.outputTracks.empty()) {
0039 throw std::invalid_argument("Missing trajectories output collection");
0040 }
0041 m_inputTracks.initialize(m_cfg.inputTracks);
0042 m_outputTracks.initialize(m_cfg.outputTracks);
0043 }
0044
0045 ActsExamples::ProcessCode ActsExamples::AmbiguityResolutionMLAlgorithm::execute(
0046 const AlgorithmContext& ctx) const {
0047
0048 const auto& tracks = m_inputTracks(ctx);
0049
0050
0051 std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
0052 trackMap =
0053 m_ambiML.mapTrackHits(tracks, &sourceLinkHash, &sourceLinkEquality);
0054
0055 auto cluster = Acts::detail::clusterDuplicateTracks(trackMap);
0056
0057 std::vector<std::size_t> goodTracks =
0058 m_ambiML.solveAmbiguity(cluster, tracks);
0059
0060 TrackContainer solvedTracks{std::make_shared<Acts::VectorTrackContainer>(),
0061 std::make_shared<Acts::VectorMultiTrajectory>()};
0062 solvedTracks.ensureDynamicColumns(tracks);
0063 for (auto iTrack : goodTracks) {
0064 auto destProxy = solvedTracks.makeTrack();
0065 auto srcProxy = tracks.getTrack(iTrack);
0066 destProxy.copyFrom(srcProxy, false);
0067 destProxy.tipIndex() = srcProxy.tipIndex();
0068 }
0069
0070 ActsExamples::ConstTrackContainer outputTracks{
0071 std::make_shared<Acts::ConstVectorTrackContainer>(
0072 std::move(solvedTracks.container())),
0073 tracks.trackStateContainerHolder()};
0074
0075 m_outputTracks(ctx, std::move(outputTracks));
0076
0077 return ActsExamples::ProcessCode::SUCCESS;
0078 }