File indexing completed on 2026-06-04 07:59:26
0001
0002
0003
0004 #include <algorithms/service.h>
0005 #include <edm4eic/EDM4eicVersion.h>
0006 #include <edm4eic/MCRecoParticleAssociationCollection.h>
0007 #include <edm4eic/ReconstructedParticleCollection.h>
0008 #include <edm4hep/MCParticleCollection.h>
0009 #include <edm4hep/Vector3f.h>
0010 #include <edm4hep/utils/vector_utils.h>
0011 #include <podio/detail/Link.h>
0012 #include <podio/detail/LinkCollectionImpl.h>
0013 #include <cmath>
0014 #include <exception>
0015 #include <gsl/pointers>
0016 #include <memory>
0017 #include <random>
0018 #include <stdexcept>
0019 #include <vector>
0020
0021 #include "algorithms/pid_lut/PIDLookup.h"
0022 #include "algorithms/pid_lut/PIDLookupConfig.h"
0023 #include "services/pid_lut/PIDLookupTableSvc.h"
0024
0025 namespace eicrecon {
0026
0027 void PIDLookup::init() {
0028
0029 try {
0030 m_system = m_detector->constant<int32_t>(m_cfg.system);
0031 } catch (const std::exception& e) {
0032 error("Failed to get {} from the detector: {}", m_cfg.system, e.what());
0033 throw std::runtime_error("Failed to get requested ID from the detector");
0034 }
0035
0036 auto& serviceSvc = algorithms::ServiceSvc::instance();
0037 auto* lut_svc = serviceSvc.service<PIDLookupTableSvc>("PIDLookupTableSvc");
0038
0039 m_lut = lut_svc->load(m_cfg.filename,
0040 {
0041 .pdg_values = m_cfg.pdg_values,
0042 .charge_values = m_cfg.charge_values,
0043 .momentum_edges = m_cfg.momentum_edges,
0044 .polar_edges = m_cfg.polar_edges,
0045 .azimuthal_binning = m_cfg.azimuthal_binning,
0046 .azimuthal_bin_centers_in_lut = m_cfg.azimuthal_bin_centers_in_lut,
0047 .momentum_bin_centers_in_lut = m_cfg.momentum_bin_centers_in_lut,
0048 .polar_bin_centers_in_lut = m_cfg.polar_bin_centers_in_lut,
0049 .use_radians = m_cfg.use_radians,
0050 .missing_electron_prob = m_cfg.missing_electron_prob,
0051 });
0052 if (m_lut == nullptr) {
0053 throw std::runtime_error("LUT not available");
0054 }
0055 }
0056
0057 void PIDLookup::process(const Input& input, const Output& output) const {
0058 const auto [headers, recoparts_in, partassocs_in] = input;
0059 #if EDM4EIC_BUILD_VERSION >= EDM4EIC_VERSION(8, 7, 0)
0060 auto [recoparts_out, partlinks_out, partassocs_out, partids_out] = output;
0061 #else
0062 auto [recoparts_out, partassocs_out, partids_out] = output;
0063 #endif
0064
0065
0066 auto seed = m_uid.getUniqueID(*headers, name());
0067 std::default_random_engine generator(seed);
0068 std::uniform_real_distribution<double> uniform;
0069
0070 for (const auto& recopart_without_pid : *recoparts_in) {
0071 auto recopart = recopart_without_pid.clone();
0072
0073
0074 auto best_assoc = edm4eic::MCRecoParticleAssociation::makeEmpty();
0075 for (auto assoc_in : *partassocs_in) {
0076 if (assoc_in.getRec() == recopart_without_pid) {
0077 if ((not best_assoc.isAvailable()) || (best_assoc.getWeight() < assoc_in.getWeight())) {
0078 best_assoc = assoc_in;
0079 }
0080 #if EDM4EIC_BUILD_VERSION >= EDM4EIC_VERSION(8, 7, 0)
0081 auto link_out = partlinks_out->create();
0082 link_out.setFrom(recopart);
0083 link_out.setTo(assoc_in.getSim());
0084 link_out.setWeight(assoc_in.getWeight());
0085 #endif
0086 auto assoc_out = assoc_in.clone();
0087 assoc_out.setRec(recopart);
0088 partassocs_out->push_back(assoc_out);
0089 }
0090 }
0091 if (not best_assoc.isAvailable()) {
0092 recoparts_out->push_back(recopart);
0093 continue;
0094 }
0095
0096 edm4hep::MCParticle mcpart = best_assoc.getSim();
0097
0098 int true_pdg = mcpart.getPDG();
0099 int true_charge = mcpart.getCharge();
0100 int charge = recopart.getCharge();
0101 double momentum = edm4hep::utils::magnitude(recopart.getMomentum());
0102
0103 double theta = edm4hep::utils::anglePolar(recopart.getMomentum()) / M_PI * 180.;
0104 double phi = edm4hep::utils::angleAzimuthal(recopart.getMomentum()) / M_PI * 180.;
0105
0106 trace("lookup for true_pdg={}, true_charge={}, momentum={:.2f} GeV, polar={:.2f}, "
0107 "aziumthal={:.2f}",
0108 true_pdg, true_charge, momentum, theta, phi);
0109 const auto* entry = m_lut->Lookup(true_pdg, true_charge, momentum, theta, phi);
0110
0111 int identified_pdg = 0;
0112
0113 if ((entry != nullptr) && ((entry->prob_electron != 0.) || (entry->prob_pion != 0.) ||
0114 (entry->prob_kaon != 0.) || (entry->prob_proton != 0.))) {
0115 double random_unit_interval = uniform(generator);
0116
0117 trace("entry with e:pi:K:P={}:{}:{}:{}", entry->prob_electron, entry->prob_pion,
0118 entry->prob_kaon, entry->prob_proton);
0119
0120 recopart.addToParticleIDs(
0121 partids_out->create(m_system,
0122 std::copysign(11, -charge),
0123 0,
0124 static_cast<float>(entry->prob_electron)
0125 ));
0126 recopart.addToParticleIDs(
0127 partids_out->create(m_system,
0128 std::copysign(211, charge),
0129 0,
0130 static_cast<float>(entry->prob_pion)
0131 ));
0132 recopart.addToParticleIDs(
0133 partids_out->create(m_system,
0134 std::copysign(321, charge),
0135 0,
0136 static_cast<float>(entry->prob_kaon)
0137 ));
0138 recopart.addToParticleIDs(
0139 partids_out->create(m_system,
0140 std::copysign(2212, charge),
0141 0,
0142 static_cast<float>(entry->prob_proton)
0143 ));
0144
0145 if (random_unit_interval < entry->prob_electron) {
0146 identified_pdg = 11;
0147 recopart.setParticleIDUsed((*partids_out)[partids_out->size() - 4]);
0148 } else if (random_unit_interval < (entry->prob_electron + entry->prob_pion)) {
0149 identified_pdg = 211;
0150 recopart.setParticleIDUsed((*partids_out)[partids_out->size() - 3]);
0151 } else if (random_unit_interval <
0152 (entry->prob_electron + entry->prob_pion + entry->prob_kaon)) {
0153 identified_pdg = 321;
0154 recopart.setParticleIDUsed((*partids_out)[partids_out->size() - 2]);
0155 } else if (random_unit_interval < (entry->prob_electron + entry->prob_pion +
0156 entry->prob_kaon + entry->prob_proton)) {
0157 identified_pdg = 2212;
0158 recopart.setParticleIDUsed((*partids_out)[partids_out->size() - 1]);
0159 }
0160 }
0161
0162 if (identified_pdg != 0) {
0163 recopart.setPDG(std::copysign(identified_pdg, (identified_pdg == 11) ? -charge : charge));
0164 recopart.setMass(m_particleSvc.particle(identified_pdg).mass);
0165 recopart.setEnergy(std::hypot(momentum, m_particleSvc.particle(identified_pdg).mass));
0166 }
0167
0168 if (identified_pdg != 0) {
0169 trace("randomized PDG is {}", recopart.getPDG());
0170 }
0171
0172 recoparts_out->push_back(recopart);
0173 }
0174 }
0175
0176 }