File indexing completed on 2025-01-18 09:11:51
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/Io/Csv/CsvSeedWriter.hpp"
0010
0011 #include "Acts/EventData/Seed.hpp"
0012 #include "Acts/EventData/TrackParameters.hpp"
0013 #include "Acts/Utilities/Helpers.hpp"
0014 #include "ActsExamples/EventData/AverageSimHits.hpp"
0015 #include "ActsExamples/EventData/Index.hpp"
0016 #include "ActsExamples/EventData/Measurement.hpp"
0017 #include "ActsExamples/EventData/SimHit.hpp"
0018 #include "ActsExamples/EventData/SimParticle.hpp"
0019 #include "ActsExamples/EventData/SimSeed.hpp"
0020 #include "ActsExamples/Utilities/EventDataTransforms.hpp"
0021 #include "ActsExamples/Utilities/Paths.hpp"
0022 #include "ActsExamples/Utilities/Range.hpp"
0023 #include "ActsExamples/Validation/TrackClassification.hpp"
0024
0025 #include <fstream>
0026 #include <ios>
0027 #include <iostream>
0028 #include <numbers>
0029 #include <stdexcept>
0030 #include <string>
0031 #include <unordered_map>
0032 #include <unordered_set>
0033
0034 using Acts::VectorHelpers::eta;
0035 using Acts::VectorHelpers::phi;
0036 using Acts::VectorHelpers::theta;
0037
0038 ActsExamples::CsvSeedWriter::CsvSeedWriter(
0039 const ActsExamples::CsvSeedWriter::Config& config,
0040 Acts::Logging::Level level)
0041 : WriterT<TrackParametersContainer>(config.inputTrackParameters,
0042 "CsvSeedWriter", level),
0043 m_cfg(config) {
0044 if (m_cfg.inputSimSeeds.empty()) {
0045 throw std::invalid_argument("Missing space points input collection");
0046 }
0047 if (m_cfg.inputSimHits.empty()) {
0048 throw std::invalid_argument("Missing simulated hits input collection");
0049 }
0050 if (m_cfg.inputMeasurementParticlesMap.empty()) {
0051 throw std::invalid_argument("Missing hit-particles map input collection");
0052 }
0053 if (m_cfg.inputMeasurementSimHitsMap.empty()) {
0054 throw std::invalid_argument(
0055 "Missing hit-simulated-hits map input collection");
0056 }
0057 if (m_cfg.fileName.empty()) {
0058 throw std::invalid_argument("Missing output filename");
0059 }
0060 if (m_cfg.outputDir.empty()) {
0061 throw std::invalid_argument("Missing output directory");
0062 }
0063
0064 m_inputSimSeeds.initialize(m_cfg.inputSimSeeds);
0065 m_inputSimHits.initialize(m_cfg.inputSimHits);
0066 m_inputMeasurementParticlesMap.initialize(m_cfg.inputMeasurementParticlesMap);
0067 m_inputMeasurementSimHitsMap.initialize(m_cfg.inputMeasurementSimHitsMap);
0068 }
0069
0070 ActsExamples::ProcessCode ActsExamples::CsvSeedWriter::writeT(
0071 const ActsExamples::AlgorithmContext& ctx,
0072 const TrackParametersContainer& trackParams) {
0073
0074 const auto& seeds = m_inputSimSeeds(ctx);
0075 const auto& simHits = m_inputSimHits(ctx);
0076 const auto& hitParticlesMap = m_inputMeasurementParticlesMap(ctx);
0077 const auto& hitSimHitsMap = m_inputMeasurementSimHitsMap(ctx);
0078
0079 std::string path =
0080 perEventFilepath(m_cfg.outputDir, m_cfg.fileName, ctx.eventNumber);
0081
0082 std::ofstream mos(path, std::ofstream::out | std::ofstream::trunc);
0083 if (!mos) {
0084 throw std::ios_base::failure("Could not open '" + path + "' to write");
0085 }
0086
0087 std::unordered_map<std::size_t, SeedInfo> infoMap;
0088 std::unordered_map<ActsFatras::Barcode, std::pair<std::size_t, float>>
0089 goodSeed;
0090
0091
0092 for (std::size_t iparams = 0; iparams < trackParams.size(); ++iparams) {
0093
0094 const auto params = trackParams[iparams].parameters();
0095
0096 float seedPhi = params[Acts::eBoundPhi];
0097 float seedEta = std::atanh(std::cos(params[Acts::eBoundTheta]));
0098
0099
0100 const auto& seed = seeds[iparams];
0101 const auto& ptrack = seedToPrototrack(seed);
0102
0103 std::vector<ParticleHitCount> particleHitCounts;
0104 identifyContributingParticles(hitParticlesMap, ptrack, particleHitCounts);
0105 bool truthMatched = false;
0106 float truthDistance = -1;
0107 auto majorityParticleId = particleHitCounts.front().particleId;
0108
0109
0110 if (particleHitCounts.size() == 1) {
0111 truthMatched = true;
0112
0113 const auto& hitIdx = ptrack.front();
0114
0115 auto indices = makeRange(hitSimHitsMap.equal_range(hitIdx));
0116
0117 Acts::Vector3 truthUnitDir = {0, 0, 0};
0118 for (auto [_, simHitIdx] : indices) {
0119 const auto& simHit = *simHits.nth(simHitIdx);
0120 if (simHit.particleId() == majorityParticleId) {
0121 truthUnitDir = simHit.direction();
0122 }
0123 }
0124
0125 float truthPhi = phi(truthUnitDir);
0126 float truthEta = std::atanh(std::cos(theta(truthUnitDir)));
0127 float dEta = std::abs(truthEta - seedEta);
0128 float dPhi =
0129 std::abs(truthPhi - seedPhi) < std::numbers::pi_v<float>
0130 ? std::abs(truthPhi - seedPhi)
0131 : std::abs(truthPhi - seedPhi) - std::numbers::pi_v<float>;
0132 truthDistance = sqrt(dPhi * dPhi + dEta * dEta);
0133
0134
0135 if (goodSeed.contains(majorityParticleId)) {
0136 if (goodSeed[majorityParticleId].second > truthDistance) {
0137 goodSeed[majorityParticleId] = std::make_pair(iparams, truthDistance);
0138 }
0139 } else {
0140 goodSeed[majorityParticleId] = std::make_pair(iparams, truthDistance);
0141 }
0142 }
0143
0144 boost::container::small_vector<Acts::Vector3, 3> globalPosition;
0145 for (auto spacePointPtr : seed.sp()) {
0146 Acts::Vector3 pos(spacePointPtr->x(), spacePointPtr->y(),
0147 spacePointPtr->z());
0148 globalPosition.push_back(pos);
0149 }
0150
0151
0152 SeedInfo toAdd;
0153 toAdd.seedID = iparams;
0154 toAdd.particleId = majorityParticleId;
0155 toAdd.seedPt = std::abs(1.0 / params[Acts::eBoundQOverP]) *
0156 std::sin(params[Acts::eBoundTheta]);
0157 toAdd.seedPhi = seedPhi;
0158 toAdd.seedEta = seedEta;
0159 toAdd.vertexZ = seed.z();
0160 toAdd.quality = seed.seedQuality();
0161 toAdd.globalPosition = globalPosition;
0162 toAdd.truthDistance = truthDistance;
0163 toAdd.seedType = truthMatched ? "duplicate" : "fake";
0164 toAdd.measurementsID = ptrack;
0165
0166 infoMap[toAdd.seedID] = toAdd;
0167 }
0168
0169 mos << "seed_id,particleId," << "pT,eta,phi," << "bX,bY,bZ," << "mX,mY,mZ,"
0170 << "tX,tY,tZ," << "good/duplicate/fake," << "vertexZ,quality,"
0171 << "Hits_ID" << '\n';
0172
0173 for (auto& [id, info] : infoMap) {
0174 if (goodSeed[info.particleId].first == id) {
0175 info.seedType = "good";
0176 }
0177
0178 mos << info.seedID << ",";
0179 mos << info.particleId << ",";
0180 mos << info.seedPt << ",";
0181 mos << info.seedEta << ",";
0182 mos << info.seedPhi << ",";
0183 for (auto& point : info.globalPosition) {
0184 mos << point.x() << ",";
0185 mos << point.y() << ",";
0186 mos << point.z() << ",";
0187 }
0188 mos << info.seedType << ",";
0189 mos << info.vertexZ << ",";
0190 mos << info.quality << ",";
0191 mos << "\"[";
0192 for (auto& ID : info.measurementsID) {
0193 mos << ID << ",";
0194 }
0195 mos << "]\"";
0196 mos << '\n';
0197 }
0198
0199 return ProcessCode::SUCCESS;
0200 }