Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-09-28 07:02:24

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2022 Sylvester Joosten, Chao Peng, Whitney Armstrong
0003 
0004 /*
0005  *  Reconstruct the cluster with the Center of Gravity method.
0006  *  Logarithmic weighting is used to mimic the energy-deposit profile in the transverse direction.
0007  *
0008  *  Author: Chao Peng (ANL), 09/27/2020
0009  */
0010 
0011 #include <algorithms/calorimetry/ClusterRecoCoG.h>
0012 
#include <algorithm>
#include <cctype>
#include <cmath>
#include <functional>
#include <iterator>
#include <limits>
#include <map>
#include <string>
#include <vector>
0016 
0017 #include <fmt/format.h>
0018 #include <fmt/ranges.h>
0019 
0020 // Event Model related classes
0021 #include "edm4hep/utils/vector_utils.h"
0022 
0023 namespace algorithms::calorimetry {
0024 namespace {
0025 
0026   // weighting functions (with place holders for hit energy, total energy, one parameter
0027   double constWeight(double /*E*/, double /*tE*/, double /*p*/) { return 1.0; }
0028   double linearWeight(double E, double /*tE*/, double /*p*/) { return E; }
0029   double logWeight(double E, double tE, double base) {
0030     return std::max(0., base + std::log(E / tE));
0031   }
0032 
0033   const std::map<std::string, ClusterRecoCoG::WeightFunc> weightMethods{
0034       {"none", constWeight},
0035       {"linear", linearWeight},
0036       {"log", logWeight},
0037   };
0038 } // namespace
0039 
0040 void ClusterRecoCoG::init() {
0041   // select weighting method
0042   std::string ew = m_energyWeight;
0043   // make it case-insensitive
0044   std::transform(ew.begin(), ew.end(), ew.begin(), [](char s) { return std::tolower(s); });
0045   if (!weightMethods.count(ew)) {
0046     std::vector<std::string> keys;
0047     std::transform(weightMethods.begin(), weightMethods.end(), std::back_inserter(keys),
0048                    [](const auto& keyvalue) { return keyvalue.first; });
0049     raise(fmt::format("Cannot find energy weighting method {}, choose one from {}", ew,
0050                       keys));
0051   }
0052   m_weightFunc = weightMethods.at(ew);
0053   info() << fmt::format("Energy weight method set to: {}", ew) << endmsg;
0054 }
0055 
0056 void ClusterRecoCoG::process(const ClusterRecoCoG::Input& input,
0057                              const ClusterRecoCoG::Output& output) const {
0058   const auto [proto, opt_simhits] = input;
0059   auto [clusters, opt_assoc]      = output;
0060 
0061   for (const auto& pcl : *proto) {
0062     auto cl = reconstruct(pcl);
0063 
0064     if (aboveDebugThreshold()) {
0065       debug() << cl.getNhits() << " hits: " << cl.getEnergy() / dd4hep::GeV << " GeV, ("
0066               << cl.getPosition().x / dd4hep::mm << ", " << cl.getPosition().y / dd4hep::mm << ", "
0067               << cl.getPosition().z / dd4hep::mm << ")" << endmsg;
0068     }
0069     clusters->push_back(cl);
0070 
0071     // If mcHits are available, associate cluster with MCParticle
0072     // 1. find proto-cluster hit with largest energy deposition
0073     // 2. find first mchit with same CellID
0074     // 3. assign mchit's MCParticle as cluster truth
0075     if (opt_simhits && opt_assoc) {
0076 
0077       // 1. find pclhit with largest energy deposition
0078       auto pclhits = pcl.getHits();
0079       auto pclhit  = std::max_element(pclhits.begin(), pclhits.end(),
0080                                      [](const auto& pclhit1, const auto& pclhit2) {
0081                                        return pclhit1.getEnergy() < pclhit2.getEnergy();
0082                                      });
0083 
0084       // 2. find mchit with same CellID
0085       // find_if not working, https://github.com/AIDASoft/podio/pull/273
0086       // auto mchit = std::find_if(
0087       //  opt_simhits->begin(),
0088       //  opt_simhits->end(),
0089       //  [&pclhit](const auto& mchit1) {
0090       //    return mchit1.getCellID() == pclhit->getCellID();
0091       //  }
0092       //);
0093       auto mchit = opt_simhits->begin();
0094       for (; mchit != opt_simhits->end(); ++mchit) {
0095         // break loop when CellID match found
0096         if (mchit->getCellID() == pclhit->getCellID()) {
0097           break;
0098         }
0099       }
0100       if (!(mchit != opt_simhits->end())) {
0101         // error condition should not happen
0102         // break if no matching hit found for this CellID
0103         warning() << "Proto-cluster has highest energy in CellID " << pclhit->getCellID()
0104                   << ", but no mc hit with that CellID was found." << endmsg;
0105         info() << "Proto-cluster hits: " << endmsg;
0106         for (const auto& pclhit1 : pclhits) {
0107           info() << pclhit1.getCellID() << ": " << pclhit1.getEnergy() << endmsg;
0108         }
0109         info() << "MC hits: " << endmsg;
0110         for (const auto& mchit1 : *opt_simhits) {
0111           info() << mchit1.getCellID() << ": " << mchit1.getEnergy() << endmsg;
0112         }
0113         break;
0114       }
0115 
0116       // 3. find mchit's MCParticle
0117       const auto& mcp = mchit->getContributions(0).getParticle();
0118 
0119       // debug output
0120       if (aboveDebugThreshold()) {
0121         debug() << "cluster has largest energy in cellID: " << pclhit->getCellID() << endmsg;
0122         debug() << "pcl hit with highest energy " << pclhit->getEnergy() << " at index "
0123                 << pclhit->getObjectID().index << endmsg;
0124         debug() << "corresponding mc hit energy " << mchit->getEnergy() << " at index "
0125                 << mchit->getObjectID().index << endmsg;
0126         debug() << "from MCParticle index " << mcp.getObjectID().index << ", PDG " << mcp.getPDG()
0127                 << ", " << edm4hep::utils::magnitude(mcp.getMomentum()) << endmsg;
0128       }
0129 
0130       // set association
0131       edm4eic::MutableMCRecoClusterParticleAssociation clusterassoc;
0132       clusterassoc.setRecID(cl.getObjectID().index);
0133       clusterassoc.setSimID(mcp.getObjectID().index);
0134       clusterassoc.setWeight(1.0);
0135       clusterassoc.setRec(cl);
0136       // clusterassoc.setSim(mcp);
0137       opt_assoc->push_back(clusterassoc);
0138     } else {
0139       if (aboveDebugThreshold()) {
0140         debug() << "No mcHitCollection was provided, so no truth association will be performed."
0141                 << endmsg;
0142       }
0143     }
0144   }
0145 }
0146 
0147 edm4eic::MutableCluster ClusterRecoCoG::reconstruct(const edm4eic::ProtoCluster& pcl) const {
0148   edm4eic::MutableCluster cl;
0149   cl.setNhits(pcl.hits_size());
0150 
0151   // no hits
0152   if (aboveDebugThreshold()) {
0153     debug() << "hit size = " << pcl.hits_size() << endmsg;
0154   }
0155   if (pcl.hits_size() == 0) {
0156     return cl;
0157   }
0158 
0159   // calculate total energy, find the cell with the maximum energy deposit
0160   float totalE = 0.;
0161   float maxE   = 0.;
0162   // Used to optionally constrain the cluster eta to those of the contributing hits
0163   float minHitEta = std::numeric_limits<float>::max();
0164   float maxHitEta = std::numeric_limits<float>::min();
0165   auto time       = pcl.getHits()[0].getTime();
0166   auto timeError  = pcl.getHits()[0].getTimeError();
0167   for (unsigned i = 0; i < pcl.getHits().size(); ++i) {
0168     const auto& hit   = pcl.getHits()[i];
0169     const auto weight = pcl.getWeights()[i];
0170     if (aboveDebugThreshold()) {
0171       debug() << "hit energy = " << hit.getEnergy() << " hit weight: " << weight << endmsg;
0172     }
0173     auto energy = hit.getEnergy() * weight;
0174     totalE += energy;
0175     if (energy > maxE) {
0176     }
0177     const float eta = edm4hep::utils::eta(hit.getPosition());
0178     if (eta < minHitEta) {
0179       minHitEta = eta;
0180     }
0181     if (eta > maxHitEta) {
0182       maxHitEta = eta;
0183     }
0184   }
0185   cl.setEnergy(totalE / m_sampFrac);
0186   cl.setEnergyError(0.);
0187   cl.setTime(time);
0188   cl.setTimeError(timeError);
0189 
0190   // center of gravity with logarithmic weighting
0191   float tw = 0.;
0192   auto v   = cl.getPosition();
0193   for (unsigned i = 0; i < pcl.getHits().size(); ++i) {
0194     const auto& hit   = pcl.getHits()[i];
0195     const auto weight = pcl.getWeights()[i];
0196     float w           = m_weightFunc(hit.getEnergy() * weight, totalE, m_logWeightBase.value());
0197     tw += w;
0198     v = v + (hit.getPosition() * w);
0199   }
0200   if (tw == 0.) {
0201     warning() << "zero total weights encountered, you may want to adjust your weighting parameter."
0202               << endmsg;
0203   }
0204   cl.setPosition(v / tw);
0205   cl.setPositionError({}); // @TODO: Covariance matrix
0206 
0207   // Optionally constrain the cluster to the hit eta values
0208   if (m_enableEtaBounds) {
0209     const bool overflow  = (edm4hep::utils::eta(cl.getPosition()) > maxHitEta);
0210     const bool underflow = (edm4hep::utils::eta(cl.getPosition()) < minHitEta);
0211     if (overflow || underflow) {
0212       const double newEta   = overflow ? maxHitEta : minHitEta;
0213       const double newTheta = edm4hep::utils::etaToAngle(newEta);
0214       const double newR     = edm4hep::utils::magnitude(cl.getPosition());
0215       const double newPhi   = edm4hep::utils::angleAzimuthal(cl.getPosition());
0216       cl.setPosition(edm4hep::utils::sphericalToVector(newR, newTheta, newPhi));
0217       if (aboveDebugThreshold()) {
0218         debug() << "Bound cluster position to contributing hits due to "
0219                 << (overflow ? "overflow" : "underflow") << endmsg;
0220       }
0221     }
0222   }
0223 
0224   // Additional convenience variables
0225 
0226   // best estimate on the cluster direction is the cluster position
0227   // for simple 2D CoG clustering
0228   cl.setIntrinsicTheta(edm4hep::utils::anglePolar(cl.getPosition()));
0229   cl.setIntrinsicPhi(edm4hep::utils::angleAzimuthal(cl.getPosition()));
0230   // TODO errors
0231 
0232   // Calculate radius
0233   // @TODO: add skewness
0234   if (cl.getNhits() > 1) {
0235     double radius = 0;
0236     for (const auto& hit : pcl.getHits()) {
0237       const auto delta = cl.getPosition() - hit.getPosition();
0238       radius += delta * delta;
0239     }
0240     radius = sqrt((1. / (cl.getNhits() - 1.)) * radius);
0241     cl.addToShapeParameters(radius);
0242     cl.addToShapeParameters(0 /* skewness */); // skewness not yet calculated
0243   }
0244 
0245   return cl;
0246 }
0247 
0248 } // namespace algorithms::calorimetry