File indexing completed on 2025-10-22 07:52:25
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/TrackFindingML/AmbiguityResolutionMLAlgorithm.hpp"
0010
0011 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0012 #include "ActsExamples/EventData/Measurement.hpp"
0013 #include "ActsExamples/Framework/ProcessCode.hpp"
0014
0015 #include <iterator>
0016 #include <map>
0017
0018 using namespace Acts;
0019 using namespace ActsPlugins;
0020
0021 static std::size_t sourceLinkHash(const SourceLink& a) {
0022 return static_cast<std::size_t>(
0023 a.get<ActsExamples::IndexSourceLink>().index());
0024 }
0025
0026 static bool sourceLinkEquality(const SourceLink& a, const SourceLink& b) {
0027 return a.get<ActsExamples::IndexSourceLink>().index() ==
0028 b.get<ActsExamples::IndexSourceLink>().index();
0029 }
0030
0031 ActsExamples::AmbiguityResolutionMLAlgorithm::AmbiguityResolutionMLAlgorithm(
0032 ActsExamples::AmbiguityResolutionMLAlgorithm::Config cfg,
0033 Logging::Level lvl)
0034 : ActsExamples::IAlgorithm("AmbiguityResolutionMLAlgorithm", lvl),
0035 m_cfg(std::move(cfg)),
0036 m_ambiML(m_cfg.toAmbiguityResolutionMLConfig(), logger().clone()) {
0037 if (m_cfg.inputTracks.empty()) {
0038 throw std::invalid_argument("Missing trajectories input collection");
0039 }
0040 if (m_cfg.outputTracks.empty()) {
0041 throw std::invalid_argument("Missing trajectories output collection");
0042 }
0043 m_inputTracks.initialize(m_cfg.inputTracks);
0044 m_outputTracks.initialize(m_cfg.outputTracks);
0045 }
0046
0047 ActsExamples::ProcessCode ActsExamples::AmbiguityResolutionMLAlgorithm::execute(
0048 const AlgorithmContext& ctx) const {
0049
0050 const auto& tracks = m_inputTracks(ctx);
0051
0052
0053 std::multimap<int, std::pair<std::size_t, std::vector<std::size_t>>>
0054 trackMap =
0055 m_ambiML.mapTrackHits(tracks, &sourceLinkHash, &sourceLinkEquality);
0056
0057 auto cluster = Acts::detail::clusterDuplicateTracks(trackMap);
0058
0059 std::vector<std::size_t> goodTracks =
0060 m_ambiML.solveAmbiguity(cluster, tracks);
0061
0062 TrackContainer solvedTracks{std::make_shared<VectorTrackContainer>(),
0063 std::make_shared<VectorMultiTrajectory>()};
0064 solvedTracks.ensureDynamicColumns(tracks);
0065 for (auto iTrack : goodTracks) {
0066 auto destProxy = solvedTracks.makeTrack();
0067 auto srcProxy = tracks.getTrack(iTrack);
0068 destProxy.copyFromWithoutStates(srcProxy);
0069 destProxy.tipIndex() = srcProxy.tipIndex();
0070 }
0071
0072 ActsExamples::ConstTrackContainer outputTracks{
0073 std::make_shared<ConstVectorTrackContainer>(
0074 std::move(solvedTracks.container())),
0075 tracks.trackStateContainerHolder()};
0076
0077 m_outputTracks(ctx, std::move(outputTracks));
0078
0079 return ActsExamples::ProcessCode::SUCCESS;
0080 }