Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-06 08:01:13

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsExamples/Io/Csv/CsvTrackWriter.hpp"
0010 
0011 #include "Acts/EventData/ProxyAccessor.hpp"
0012 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0013 #include "ActsExamples/EventData/Track.hpp"
0014 #include "ActsExamples/Framework/AlgorithmContext.hpp"
0015 #include "ActsExamples/Utilities/Paths.hpp"
0016 #include "ActsExamples/Validation/TrackClassification.hpp"
0017 
0018 #include <algorithm>
0019 #include <fstream>
0020 #include <iomanip>
0021 #include <map>
0022 #include <stdexcept>
0023 #include <string>
0024 #include <tuple>
0025 #include <unordered_map>
0026 #include <unordered_set>
0027 #include <utility>
0028 
0029 namespace ActsExamples {
0030 
0031 CsvTrackWriter::CsvTrackWriter(const Config& config, Acts::Logging::Level level)
0032     : WriterT<ConstTrackContainer>(config.inputTracks, "CsvTrackWriter", level),
0033       m_cfg(config) {
0034   if (m_cfg.inputTracks.empty()) {
0035     throw std::invalid_argument("Missing input tracks collection");
0036   }
0037 
0038   m_inputMeasurementParticlesMap.initialize(m_cfg.inputMeasurementParticlesMap);
0039 }
0040 
0041 ProcessCode CsvTrackWriter::writeT(const AlgorithmContext& context,
0042                                    const ConstTrackContainer& tracks) {
0043   // open per-event file
0044   std::string path =
0045       perEventFilepath(m_cfg.outputDir, m_cfg.fileName, context.eventNumber);
0046   std::ofstream mos(path, std::ofstream::out | std::ofstream::trunc);
0047   if (!mos) {
0048     throw std::ios_base::failure("Could not open '" + path + "' to write");
0049   }
0050 
0051   const auto& hitParticlesMap = m_inputMeasurementParticlesMap(context);
0052 
0053   std::unordered_map<Acts::TrackIndexType, TrackInfo> infoMap;
0054 
0055   // Counter of truth-matched reco tracks
0056   using RecoTrackInfo = std::pair<TrackInfo, std::size_t>;
0057   std::map<SimBarcode, std::vector<RecoTrackInfo>> matched;
0058 
0059   for (const auto& track : tracks) {
0060     // Reco track selection
0061     //@TODO: add interface for applying others cuts on reco tracks:
0062     // -> pT, d0, z0, detector-specific hits/holes number cut
0063     if (track.nMeasurements() < m_cfg.nMeasurementsMin) {
0064       continue;
0065     }
0066 
0067     // Check if the reco track has fitted track parameters
0068     if (!track.hasReferenceSurface()) {
0069       ACTS_WARNING(
0070           "No fitted track parameters for trajectory with entry index = "
0071           << track.tipIndex());
0072       continue;
0073     }
0074 
0075     // Get the majority truth particle to this track
0076     std::vector<ParticleHitCount> particleHitCount;
0077     identifyContributingParticles(hitParticlesMap, track, particleHitCount);
0078     if (m_cfg.onlyTruthMatched && particleHitCount.empty()) {
0079       ACTS_WARNING(
0080           "No truth particle associated with this trajectory with entry "
0081           "index = "
0082           << track.tipIndex());
0083       continue;
0084     }
0085 
0086     // Requirement on the pT of the track
0087     auto params = track.createParametersAtReference();
0088     const auto momentum = params.momentum();
0089     const auto pT = Acts::VectorHelpers::perp(momentum);
0090     if (pT < m_cfg.ptMin) {
0091       continue;
0092     }
0093     std::size_t nMajorityHits = 0;
0094     SimBarcode majorityParticleId;
0095     if (!particleHitCount.empty()) {
0096       // Get the majority particle counts
0097       majorityParticleId = particleHitCount.front().particleId;
0098       // n Majority hits
0099       nMajorityHits = particleHitCount.front().hitCount;
0100     }
0101 
0102     static const Acts::ConstProxyAccessor<unsigned int> seedNumber(
0103         "trackGroup");
0104 
0105     // track info
0106     TrackInfo toAdd;
0107     toAdd.trackId = track.index();
0108     if (tracks.hasColumn(Acts::hashString("trackGroup"))) {
0109       toAdd.seedID = seedNumber(track);
0110     } else {
0111       toAdd.seedID = 0;
0112     }
0113     toAdd.particleId = majorityParticleId;
0114     toAdd.nStates = track.nTrackStates();
0115     toAdd.nMajorityHits = nMajorityHits;
0116     toAdd.nMeasurements = track.nMeasurements();
0117     toAdd.nOutliers = track.nOutliers();
0118     toAdd.nHoles = track.nHoles();
0119     toAdd.nSharedHits = track.nSharedHits();
0120     toAdd.chi2Sum = track.chi2();
0121     toAdd.NDF = track.nDoF();
0122     toAdd.truthMatchProb = toAdd.nMajorityHits * 1. / track.nMeasurements();
0123     toAdd.fittedParameters = params;
0124     toAdd.trackType = "unknown";
0125 
0126     for (const auto& state : track.trackStatesReversed()) {
0127       if (state.typeFlags().hasMeasurement()) {
0128         auto sl =
0129             state.getUncalibratedSourceLink().template get<IndexSourceLink>();
0130         auto hitIndex = sl.index();
0131         toAdd.measurementsID.insert(toAdd.measurementsID.begin(), hitIndex);
0132       }
0133     }
0134 
0135     // Check if the trajectory is matched with truth.
0136     if (toAdd.truthMatchProb >= m_cfg.truthMatchProbMin) {
0137       matched[toAdd.particleId].push_back({toAdd, toAdd.trackId});
0138     } else {
0139       toAdd.trackType = "fake";
0140     }
0141 
0142     infoMap[toAdd.trackId] = toAdd;
0143   }
0144 
0145   // Find duplicates
0146   std::unordered_set<std::size_t> listGoodTracks;
0147   for (auto& [particleId, matchedTracks] : matched) {
0148     std::ranges::sort(matchedTracks, [](const auto& lhs, const auto& rhs) {
0149       const auto& t1 = lhs.first;
0150       const auto& t2 = rhs.first;
0151       // nMajorityHits are sorted descending, others ascending
0152       return std::tie(t2.nMajorityHits, t1.nOutliers, t1.chi2Sum) <
0153              std::tie(t1.nMajorityHits, t2.nOutliers, t2.chi2Sum);
0154     });
0155 
0156     listGoodTracks.insert(matchedTracks.front().first.trackId);
0157   }
0158 
0159   // write csv header
0160   mos << "track_id,seed_id,particleId,"
0161       << "nStates,nMajorityHits,nMeasurements,nOutliers,nHoles,nSharedHits,"
0162       << "chi2,ndf,chi2/ndf,"
0163       << "pT,eta,phi,"
0164       << "truthMatchProbability,"
0165       << "good/duplicate/fake,"
0166       << "Measurements_ID";
0167 
0168   mos << '\n';
0169   mos << std::setprecision(m_cfg.outputPrecision);
0170 
0171   // good/duplicate/fake = 0/1/2
0172   for (auto& [id, trajState] : infoMap) {
0173     if (listGoodTracks.contains(id)) {
0174       trajState.trackType = "good";
0175     } else if (trajState.trackType != "fake") {
0176       trajState.trackType = "duplicate";
0177     }
0178 
0179     const auto& params = *trajState.fittedParameters;
0180     const auto momentum = params.momentum();
0181 
0182     // write the track info
0183     mos << trajState.trackId << ",";
0184     mos << trajState.seedID << ",";
0185     mos << trajState.particleId << ",";
0186     mos << trajState.nStates << ",";
0187     mos << trajState.nMajorityHits << ",";
0188     mos << trajState.nMeasurements << ",";
0189     mos << trajState.nOutliers << ",";
0190     mos << trajState.nHoles << ",";
0191     mos << trajState.nSharedHits << ",";
0192     mos << trajState.chi2Sum << ",";
0193     mos << trajState.NDF << ",";
0194     mos << trajState.chi2Sum * 1.0 / trajState.NDF << ",";
0195     mos << Acts::VectorHelpers::perp(momentum) << ",";
0196     mos << Acts::VectorHelpers::eta(momentum) << ",";
0197     mos << Acts::VectorHelpers::phi(momentum) << ",";
0198     mos << trajState.truthMatchProb << ",";
0199     mos << trajState.trackType << ",";
0200     mos << "\"[";
0201     for (auto& ID : trajState.measurementsID) {
0202       mos << ID << ",";
0203     }
0204     mos << "]\"";
0205     mos << '\n';
0206   }
0207 
0208   return ProcessCode::SUCCESS;
0209 }
0210 
0211 }  // namespace ActsExamples