File indexing completed on 2025-12-16 09:23:27
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/AmbiguityResolution/ScoreBasedAmbiguityResolutionAlgorithm.hpp"
0010
0011 #include "Acts/AmbiguityResolution/ScoreBasedAmbiguityResolution.hpp"
0012 #include "Acts/EventData/MultiTrajectoryHelpers.hpp"
0013 #include "Acts/Utilities/Logger.hpp"
0014 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0015 #include "ActsExamples/EventData/Measurement.hpp"
0016 #include "ActsExamples/Framework/ProcessCode.hpp"
0017 #include "ActsExamples/Framework/WhiteBoard.hpp"
0018 #include "ActsPlugins/Json/AmbiguityConfigJsonConverter.hpp"
0019
0020 #include <fstream>
0021
0022 namespace {
0023
0024 Acts::ScoreBasedAmbiguityResolution::Config transformConfig(
0025 const ActsExamples::ScoreBasedAmbiguityResolutionAlgorithm::Config& cfg,
0026 const std::string& configFile) {
0027 Acts::ScoreBasedAmbiguityResolution::Config result;
0028
0029 Acts::ConfigPair configPair;
0030 nlohmann::json json_file;
0031 std::ifstream file(configFile);
0032 if (!file.is_open()) {
0033 std::cerr << "Error opening file: " << configFile << std::endl;
0034 return {};
0035 }
0036 file >> json_file;
0037 file.close();
0038
0039 Acts::from_json(json_file, configPair);
0040
0041 result.volumeMap = configPair.first;
0042 result.detectorConfigs = configPair.second;
0043 result.minScore = cfg.minScore;
0044 result.minScoreSharedTracks = cfg.minScoreSharedTracks;
0045 result.maxSharedTracksPerMeasurement = cfg.maxSharedTracksPerMeasurement;
0046 result.maxShared = cfg.maxShared;
0047 result.minUnshared = cfg.minUnshared;
0048 result.useAmbiguityScoring = cfg.useAmbiguityScoring;
0049 return result;
0050 }
0051
0052 std::size_t sourceLinkHash(const Acts::SourceLink& a) {
0053 return static_cast<std::size_t>(
0054 a.get<ActsExamples::IndexSourceLink>().index());
0055 }
0056
0057 bool sourceLinkEquality(const Acts::SourceLink& a, const Acts::SourceLink& b) {
0058 return a.get<ActsExamples::IndexSourceLink>().index() ==
0059 b.get<ActsExamples::IndexSourceLink>().index();
0060 }
0061
0062 bool doubleHolesFilter(const Acts::TrackProxy<Acts::ConstVectorTrackContainer,
0063 Acts::ConstVectorMultiTrajectory,
0064 std::shared_ptr, true>& track) {
0065 bool doubleFlag = false;
0066 int counter = 0;
0067 for (const auto& ts : track.trackStatesReversed()) {
0068 auto iTypeFlags = ts.typeFlags();
0069 if (!iTypeFlags.test(Acts::TrackStateFlag::HoleFlag)) {
0070 doubleFlag = false;
0071 }
0072
0073 if (iTypeFlags.test(Acts::TrackStateFlag::HoleFlag)) {
0074 if (doubleFlag) {
0075 counter++;
0076 doubleFlag = false;
0077 } else {
0078 doubleFlag = true;
0079 };
0080 }
0081 }
0082 if (counter > 1) {
0083 return true;
0084 } else {
0085 return false;
0086 }
0087 }
0088
0089 }
0090
0091 ActsExamples::ScoreBasedAmbiguityResolutionAlgorithm::
0092 ScoreBasedAmbiguityResolutionAlgorithm(
0093 ActsExamples::ScoreBasedAmbiguityResolutionAlgorithm::Config cfg,
0094 Acts::Logging::Level lvl)
0095 : ActsExamples::IAlgorithm("ScoreBasedAmbiguityResolutionAlgorithm", lvl),
0096 m_cfg(std::move(cfg)),
0097 m_ambi(transformConfig(cfg, m_cfg.configFile), logger().clone()) {
0098 if (m_cfg.inputTracks.empty()) {
0099 throw std::invalid_argument("Missing trajectories input collection");
0100 }
0101 if (m_cfg.outputTracks.empty()) {
0102 throw std::invalid_argument("Missing trajectories output collection");
0103 }
0104 m_inputTracks.initialize(m_cfg.inputTracks);
0105 m_outputTracks.initialize(m_cfg.outputTracks);
0106 }
0107
0108 ActsExamples::ProcessCode
0109 ActsExamples::ScoreBasedAmbiguityResolutionAlgorithm::execute(
0110 const AlgorithmContext& ctx) const {
0111 const auto& tracks = m_inputTracks(ctx);
0112 ACTS_VERBOSE("Number of input tracks: " << tracks.size());
0113
0114 Acts::ScoreBasedAmbiguityResolution::Optionals<ConstTrackProxy> optionals;
0115 optionals.cuts.push_back(doubleHolesFilter);
0116 std::vector<int> goodTracks = m_ambi.solveAmbiguity(
0117 tracks, &sourceLinkHash, &sourceLinkEquality, optionals);
0118
0119 TrackContainer solvedTracks{std::make_shared<Acts::VectorTrackContainer>(),
0120 std::make_shared<Acts::VectorMultiTrajectory>()};
0121 solvedTracks.ensureDynamicColumns(tracks);
0122 for (auto iTrack : goodTracks) {
0123 auto destProxy = solvedTracks.makeTrack();
0124 auto srcProxy = tracks.getTrack(iTrack);
0125 destProxy.copyFromWithoutStates(srcProxy);
0126 destProxy.tipIndex() = srcProxy.tipIndex();
0127 }
0128
0129 ActsExamples::ConstTrackContainer outputTracks{
0130 std::make_shared<Acts::ConstVectorTrackContainer>(
0131 std::move(solvedTracks.container())),
0132 tracks.trackStateContainerHolder()};
0133
0134 m_outputTracks(ctx, std::move(outputTracks));
0135 return ActsExamples::ProcessCode::SUCCESS;
0136 }