Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:28:03

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2025 Chun Yuen Tsang
0003 
0004 #include "LGADHitClustering.h"
0005 
0006 #include <Acts/Definitions/Algebra.hpp>
0007 #include <Acts/Definitions/Units.hpp>
0008 #include <Acts/Geometry/GeometryIdentifier.hpp>
0009 #include <Acts/Surfaces/Surface.hpp>
0010 #include <DD4hep/Handle.h>
0011 #include <DD4hep/Readout.h>
0012 #include <DD4hep/VolumeManager.h>
0013 #include <DD4hep/detail/SegmentationsInterna.h>
0014 #include <DDSegmentation/MultiSegmentation.h>
0015 #include <DDSegmentation/Segmentation.h>
0016 #include <Evaluator/DD4hepUnits.h>
0017 #include <Math/GenVector/Cartesian3D.h>
0018 #include <Math/GenVector/DisplacementVector3D.h>
0019 #include <ROOT/RVec.hxx>
0020 #include <algorithms/geo.h>
0021 #include <edm4eic/Cov3f.h>
0022 #include <edm4eic/CovDiag3f.h>
0023 #include <edm4hep/Vector2f.h>
0024 #include <fmt/core.h>
0025 #include <stddef.h>
0026 #include <Eigen/Core>
0027 #include <cmath>
0028 #include <gsl/pointers>
0029 #include <limits>
0030 #include <set>
0031 #include <stdexcept>
0032 #include <unordered_map>
0033 #include <utility>
0034 #include <vector>
0035 
0036 #include "ActsGeometryProvider.h"
0037 #include "algorithms/interfaces/ActsSvc.h"
0038 #include "algorithms/tracking/LGADHitClusteringConfig.h"
0039 
0040 namespace eicrecon {
0041 
0042 void LGADHitClustering::init() {
0043 
0044   m_converter    = algorithms::GeoSvc::instance().cellIDPositionConverter();
0045   m_detector     = algorithms::GeoSvc::instance().detector();
0046   m_seg          = m_detector->readout(m_cfg.readout).segmentation();
0047   auto type      = m_seg.type();
0048   m_decoder      = m_seg.decoder();
0049   m_acts_context = algorithms::ActsSvc::instance().acts_geometry_provider();
0050 }
0051 
0052 void LGADHitClustering::_calcCluster(const Output& output,
0053                                      const std::vector<edm4eic::TrackerHit>& hits) const {
0054   if (hits.empty()) {
0055     return;
0056   }
0057   constexpr double mm_acts = Acts::UnitConstants::mm;
0058   using dd4hep::mm;
0059 
0060   auto [clusters] = output;
0061   auto cluster    = clusters->create();
0062   // Right now the clustering algorithm is either:
0063   // 1. simple average over all hits in a sensors
0064   // 2. Cell position with max ADC value in a cluster
0065   // Switch between option 1 and 2 with m_cfg.useAve
0066   // Will be problematic near the edges, but it's just an illustration
0067   float ave_x = 0, ave_y = 0;
0068   float sigma2_x = 0, sigma2_y = 0;
0069   double tot_charge = 0;
0070   // find cellID for the cell with maximum ADC value within a sensor
0071   dd4hep::rec::CellID cellID;
0072   auto max_charge    = std::numeric_limits<float>::min();
0073   auto earliest_time = std::numeric_limits<float>::max();
0074   float time_err;
0075   float max_charge_x;
0076   float max_charge_y;
0077   float max_charge_sigma2_x;
0078   float max_charge_sigma2_y;
0079 
0080   ROOT::VecOps::RVec<double> weights;
0081 
0082   for (size_t id = 0; id < hits.size(); ++id) {
0083     const auto& hit = hits[id];
0084     if (hit.getTime() < earliest_time) {
0085       earliest_time = hit.getTime();
0086       time_err      = hit.getTimeError();
0087     }
0088     // weigh all hits by ADC value
0089     auto pos = m_seg->position(hit.getCellID());
0090     if (hit.getEdep() < 0) {
0091       error("Edep for hit at cellID{} is negative. Please check the accuracy of your energy "
0092             "calibration. ",
0093             hit.getCellID());
0094     }
0095     const auto Edep = hit.getEdep();
0096     ave_x += Edep * pos.x();
0097     ave_y += Edep * pos.y();
0098     sigma2_x += Edep * Edep * hit.getPositionError().xx * mm_acts * mm_acts;
0099     sigma2_y += Edep * Edep * hit.getPositionError().yy * mm_acts * mm_acts;
0100 
0101     tot_charge += Edep;
0102     if (Edep > max_charge) {
0103       max_charge          = Edep;
0104       cellID              = hit.getCellID();
0105       max_charge_x        = pos.x();
0106       max_charge_y        = pos.y();
0107       max_charge_sigma2_x = hit.getPositionError().xx * mm_acts * mm_acts;
0108       max_charge_sigma2_y = hit.getPositionError().yy * mm_acts * mm_acts;
0109     }
0110     cluster.addToHits(hit);
0111     weights.push_back(Edep);
0112   }
0113 
0114   if (m_cfg.useAve) {
0115     weights /= tot_charge;
0116     ave_x /= tot_charge;
0117     ave_y /= tot_charge;
0118     sigma2_x /= tot_charge * tot_charge;
0119     sigma2_y /= tot_charge * tot_charge;
0120   } else {
0121     ave_x    = max_charge_x;
0122     ave_y    = max_charge_y;
0123     sigma2_x = max_charge_sigma2_x;
0124     sigma2_y = max_charge_sigma2_y;
0125   }
0126 
0127   // covariance copied from TrackerMeasurementFromHits.cc
0128   Acts::SquareMatrix2 cov = Acts::SquareMatrix2::Zero();
0129   cov(0, 0)               = sigma2_x;
0130   cov(1, 1)               = sigma2_y;
0131   cov(0, 1) = cov(1, 0) = 0.0;
0132 
0133   for (const auto& w : weights) {
0134     cluster.addToWeights(w);
0135   }
0136 
0137   edm4eic::Cov3f covariance;
0138   edm4hep::Vector2f locPos{static_cast<float>(ave_x / mm), static_cast<float>(ave_y / mm)};
0139 
0140   const auto* context    = m_converter->findContext(cellID);
0141   auto volID             = context->identifier;
0142   const auto& surfaceMap = m_acts_context->surfaceMap();
0143   const auto is          = surfaceMap.find(volID);
0144   if (is == surfaceMap.end()) {
0145     error("vol_id ({})  not found in m_surfaces.", volID);
0146   }
0147 
0148   const Acts::Surface* surface = is->second;
0149 
0150   cluster.setSurface(surface->geometryId().value());
0151   cluster.setLoc(locPos);
0152   cluster.setTime(earliest_time);
0153   cluster.setCovariance(
0154       {cov(0, 0), cov(1, 1), time_err * time_err, cov(0, 1)}); // Covariance on location and time
0155 }
0156 
0157 void LGADHitClustering::process(const LGADHitClustering::Input& input,
0158                                 const LGADHitClustering::Output& output) const {
0159   const auto [calibrated_hits] = input;
0160 
0161   // use unordered map to efficiently search for hits by CellID
0162   // store the index of hits instead of the hit itself
0163   // UnionFind can only group integer objects, not edm4eic::TrackerHit
0164   std::unordered_map<dd4hep::rec::CellID, std::vector<int>> hitIDsByCells;
0165 
0166   for (size_t hitID = 0; hitID < calibrated_hits->size(); ++hitID) {
0167     hitIDsByCells[calibrated_hits->at(hitID).getCellID()].push_back(hitID);
0168   }
0169 
0170   // merge neighbors by union find
0171   UnionFind uf(static_cast<int>(calibrated_hits->size()));
0172   for (auto [cellID, hitIDs] : hitIDsByCells) {
0173     // code copied from SiliconChargeSharing for neighbor finding
0174     const auto* element = &m_converter->findContext(cellID)->element; // volume context
0175     auto [segmentationIt, segmentationInserted] =
0176         m_segmentation_map.try_emplace(element, getLocalSegmentation(cellID));
0177 
0178     std::set<dd4hep::rec::CellID> cellNeighbors;
0179     segmentationIt->second->neighbours(cellID, cellNeighbors);
0180     // find if there are hits in neighboring cells
0181     for (const auto& neighborCandidates : cellNeighbors) {
0182       auto it = hitIDsByCells.find(neighborCandidates);
0183       if (it != hitIDsByCells.end()) {
0184         for (const auto& hitID1 : hitIDs) {
0185           for (const auto& hitID2 : it->second) {
0186             const auto& hit1 = calibrated_hits->at(hitID1);
0187             const auto& hit2 = calibrated_hits->at(hitID2);
0188             // only consider hits with time difference < deltaT as the same cluster
0189             if (std::fabs(hit1.getTime() - hit2.getTime()) < m_cfg.deltaT) {
0190               uf.merge(hitID1, hitID2);
0191             }
0192           }
0193         }
0194       }
0195     }
0196   }
0197 
0198   // group hits by cluster parent index according to union find algorithm
0199   std::unordered_map<int, std::vector<edm4eic::TrackerHit>> clusters;
0200   for (size_t hitID = 0; hitID < calibrated_hits->size(); ++hitID) {
0201     clusters[uf.find(hitID)].push_back(calibrated_hits->at(hitID));
0202   }
0203 
0204   // calculated weighted averages
0205   for (auto& [_, cluster] : clusters) {
0206     this->_calcCluster(output, cluster);
0207   }
0208 }
0209 
0210 // copied from SiliconChargeSharing
0211 // Get the segmentation relevant to a cellID
0212 const dd4hep::DDSegmentation::CartesianGridXY*
0213 LGADHitClustering::getLocalSegmentation(const dd4hep::rec::CellID& cellID) const {
0214   // Get the segmentation type
0215   auto segmentation_type                                   = m_seg.type();
0216   const dd4hep::DDSegmentation::Segmentation* segmentation = m_seg.segmentation();
0217   // Check if the segmentation is a multi-segmentation
0218   while (segmentation_type == "MultiSegmentation") {
0219     const auto* multi_segmentation =
0220         dynamic_cast<const dd4hep::DDSegmentation::MultiSegmentation*>(segmentation);
0221     segmentation      = &multi_segmentation->subsegmentation(cellID);
0222     segmentation_type = segmentation->type();
0223   }
0224 
0225   // Try to cast the segmentation to CartesianGridXY
0226   const auto* cartesianGrid =
0227       dynamic_cast<const dd4hep::DDSegmentation::CartesianGridXY*>(segmentation);
0228   if (cartesianGrid == nullptr) {
0229     throw std::runtime_error("Segmentation is not of type CartesianGridXY");
0230   }
0231 
0232   return cartesianGrid;
0233 }
0234 
0235 LGADHitClustering::UnionFind::UnionFind(int n) : mParent(n, 0), mRank(n, 0) {
0236   for (int i = 0; i < n; ++i) {
0237     mParent[i] = i;
0238   }
0239 }
0240 
0241 int LGADHitClustering::UnionFind::find(int id) {
0242   if (mParent[id] == id) {
0243     return id;
0244   }
0245   return mParent[id] = find(mParent[id]); // path compression
0246 }
0247 
0248 void LGADHitClustering::UnionFind::merge(int id1, int id2) {
0249   auto root1 = find(id1);
0250   auto root2 = find(id2);
0251 
0252   if (root1 != root2) {
0253     if (mRank[root1] > mRank[root2]) {
0254       mParent[root2] = root1;
0255     } else if (mRank[root1] < mRank[root2]) {
0256       mParent[root1] = root2;
0257     } else {
0258       mParent[root1] = root2;
0259       mRank[root2]++;
0260     }
0261   }
0262 }
0263 
0264 } // namespace eicrecon