Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-10-13 08:23:39

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2025 Chun Yuen Tsang
0003 
0004 #include "LGADHitClustering.h"
0005 
0006 #include <Acts/Definitions/Algebra.hpp>
0007 #include <Acts/Definitions/Units.hpp>
0008 #include <Acts/Geometry/GeometryIdentifier.hpp>
0009 #include <Acts/Surfaces/Surface.hpp>
0010 #include <DD4hep/Handle.h>
0011 #include <DD4hep/Readout.h>
0012 #include <DD4hep/VolumeManager.h>
0013 #include <DD4hep/detail/SegmentationsInterna.h>
0014 #include <DDSegmentation/MultiSegmentation.h>
0015 #include <DDSegmentation/Segmentation.h>
0016 #include <Evaluator/DD4hepUnits.h>
0017 #include <Math/GenVector/Cartesian3D.h>
0018 #include <Math/GenVector/DisplacementVector3D.h>
0019 #include <ROOT/RVec.hxx>
0020 #include <algorithms/geo.h>
0021 #include <edm4eic/Cov3f.h>
0022 #include <edm4eic/CovDiag3f.h>
0023 #include <edm4hep/Vector2f.h>
0024 #include <fmt/core.h>
0025 #include <stddef.h>
0026 #include <Eigen/Core>
0027 #include <cmath>
0028 #include <gsl/pointers>
0029 #include <limits>
0030 #include <set>
0031 #include <stdexcept>
0032 #include <unordered_map>
0033 #include <utility>
0034 #include <vector>
0035 
0036 #include "ActsGeometryProvider.h"
0037 #include "algorithms/interfaces/ActsSvc.h"
0038 #include "algorithms/tracking/LGADHitClusteringConfig.h"
0039 
0040 namespace eicrecon {
0041 
0042 void LGADHitClustering::init() {
0043 
0044   m_converter    = algorithms::GeoSvc::instance().cellIDPositionConverter();
0045   m_detector     = algorithms::GeoSvc::instance().detector();
0046   m_seg          = m_detector->readout(m_cfg.readout).segmentation();
0047   auto type      = m_seg.type();
0048   m_decoder      = m_seg.decoder();
0049   m_acts_context = algorithms::ActsSvc::instance().acts_geometry_provider();
0050 }
0051 
0052 void LGADHitClustering::_calcCluster(const Output& output,
0053                                      const std::vector<edm4eic::TrackerHit>& hits) const {
0054   if (hits.size() == 0)
0055     return;
0056   constexpr double mm_acts = Acts::UnitConstants::mm;
0057   using dd4hep::mm;
0058 
0059   auto [clusters] = output;
0060   auto cluster    = clusters->create();
0061   // Right now the clustering algorithm is either:
0062   // 1. simple average over all hits in a sensors
0063   // 2. Cell position with max ADC value in a cluster
0064   // Switch between option 1 and 2 with m_cfg.useAve
0065   // Will be problematic near the edges, but it's just an illustration
0066   float ave_x = 0, ave_y = 0;
0067   float sigma2_x = 0, sigma2_y = 0;
0068   double tot_charge = 0;
0069   // find cellID for the cell with maximum ADC value within a sensor
0070   dd4hep::rec::CellID cellID;
0071   auto max_charge    = std::numeric_limits<float>::min();
0072   auto earliest_time = std::numeric_limits<float>::max();
0073   float time_err;
0074   float max_charge_x;
0075   float max_charge_y;
0076   float max_charge_sigma2_x;
0077   float max_charge_sigma2_y;
0078 
0079   ROOT::VecOps::RVec<double> weights;
0080 
0081   for (size_t id = 0; id < hits.size(); ++id) {
0082     const auto& hit = hits[id];
0083     if (hit.getTime() < earliest_time) {
0084       earliest_time = hit.getTime();
0085       time_err      = hit.getTimeError();
0086     }
0087     // weigh all hits by ADC value
0088     auto pos = m_seg->position(hit.getCellID());
0089     if (hit.getEdep() < 0)
0090       error("Edep for hit at cellID{} is negative. Please check the accuracy of your energy "
0091             "calibration. ",
0092             hit.getCellID());
0093     const auto Edep = hit.getEdep();
0094     ave_x += Edep * pos.x();
0095     ave_y += Edep * pos.y();
0096     sigma2_x += Edep * Edep * hit.getPositionError().xx * mm_acts * mm_acts;
0097     sigma2_y += Edep * Edep * hit.getPositionError().yy * mm_acts * mm_acts;
0098 
0099     tot_charge += Edep;
0100     if (Edep > max_charge) {
0101       max_charge          = Edep;
0102       cellID              = hit.getCellID();
0103       max_charge_x        = pos.x();
0104       max_charge_y        = pos.y();
0105       max_charge_sigma2_x = hit.getPositionError().xx * mm_acts * mm_acts;
0106       max_charge_sigma2_y = hit.getPositionError().yy * mm_acts * mm_acts;
0107     }
0108     cluster.addToHits(hit);
0109     weights.push_back(Edep);
0110   }
0111 
0112   if (m_cfg.useAve) {
0113     weights /= tot_charge;
0114     ave_x /= tot_charge;
0115     ave_y /= tot_charge;
0116     sigma2_x /= tot_charge * tot_charge;
0117     sigma2_y /= tot_charge * tot_charge;
0118   } else {
0119     ave_x    = max_charge_x;
0120     ave_y    = max_charge_y;
0121     sigma2_x = max_charge_sigma2_x;
0122     sigma2_y = max_charge_sigma2_y;
0123   }
0124 
0125   // covariance copied from TrackerMeasurementFromHits.cc
0126   Acts::SquareMatrix2 cov = Acts::SquareMatrix2::Zero();
0127   cov(0, 0)               = sigma2_x;
0128   cov(1, 1)               = sigma2_y;
0129   cov(0, 1) = cov(1, 0) = 0.0;
0130 
0131   for (const auto& w : weights)
0132     cluster.addToWeights(w);
0133 
0134   edm4eic::Cov3f covariance;
0135   edm4hep::Vector2f locPos{static_cast<float>(ave_x / mm), static_cast<float>(ave_y / mm)};
0136 
0137   const auto* context    = m_converter->findContext(cellID);
0138   auto volID             = context->identifier;
0139   const auto& surfaceMap = m_acts_context->surfaceMap();
0140   const auto is          = surfaceMap.find(volID);
0141   if (is == surfaceMap.end())
0142     error("vol_id ({})  not found in m_surfaces.", volID);
0143 
0144   const Acts::Surface* surface = is->second;
0145 
0146   cluster.setSurface(surface->geometryId().value());
0147   cluster.setLoc(locPos);
0148   cluster.setTime(earliest_time);
0149   cluster.setCovariance(
0150       {cov(0, 0), cov(1, 1), time_err * time_err, cov(0, 1)}); // Covariance on location and time
0151 }
0152 
0153 void LGADHitClustering::process(const LGADHitClustering::Input& input,
0154                                 const LGADHitClustering::Output& output) const {
0155   const auto [calibrated_hits] = input;
0156 
0157   // use unordered map to efficiently search for hits by CellID
0158   // store the index of hits instead of the hit itself
0159   // UnionFind can only group integer objects, not edm4eic::TrackerHit
0160   std::unordered_map<dd4hep::rec::CellID, std::vector<int>> hitIDsByCells;
0161 
0162   for (size_t hitID = 0; hitID < calibrated_hits->size(); ++hitID) {
0163     hitIDsByCells[calibrated_hits->at(hitID).getCellID()].push_back(hitID);
0164   }
0165 
0166   // merge neighbors by union find
0167   UnionFind uf(static_cast<int>(calibrated_hits->size()));
0168   for (auto [cellID, hitIDs] : hitIDsByCells) {
0169     // code copied from SiliconChargeSharing for neighbor finding
0170     const auto* element = &m_converter->findContext(cellID)->element; // volume context
0171     auto [segmentationIt, segmentationInserted] =
0172         m_segmentation_map.try_emplace(element, getLocalSegmentation(cellID));
0173 
0174     std::set<dd4hep::rec::CellID> cellNeighbors;
0175     segmentationIt->second->neighbours(cellID, cellNeighbors);
0176     // find if there are hits in neighboring cells
0177     for (const auto& neighborCandidates : cellNeighbors) {
0178       auto it = hitIDsByCells.find(neighborCandidates);
0179       if (it != hitIDsByCells.end()) {
0180         for (const auto& hitID1 : hitIDs)
0181           for (const auto& hitID2 : it->second) {
0182             const auto& hit1 = calibrated_hits->at(hitID1);
0183             const auto& hit2 = calibrated_hits->at(hitID2);
0184             // only consider hits with time difference < deltaT as the same cluster
0185             if (std::fabs(hit1.getTime() - hit2.getTime()) < m_cfg.deltaT)
0186               uf.merge(hitID1, hitID2);
0187           }
0188       }
0189     }
0190   }
0191 
0192   // group hits by cluster parent index according to union find algorithm
0193   std::unordered_map<int, std::vector<edm4eic::TrackerHit>> clusters;
0194   for (size_t hitID = 0; hitID < calibrated_hits->size(); ++hitID)
0195     clusters[uf.find(hitID)].push_back(calibrated_hits->at(hitID));
0196 
0197   // calculated weighted averages
0198   for (auto& [_, cluster] : clusters) {
0199     this->_calcCluster(output, cluster);
0200   }
0201 }
0202 
0203 // copied from SiliconChargeSharing
0204 // Get the segmentation relevant to a cellID
0205 const dd4hep::DDSegmentation::CartesianGridXY*
0206 LGADHitClustering::getLocalSegmentation(const dd4hep::rec::CellID& cellID) const {
0207   // Get the segmentation type
0208   auto segmentation_type                                   = m_seg.type();
0209   const dd4hep::DDSegmentation::Segmentation* segmentation = m_seg.segmentation();
0210   // Check if the segmentation is a multi-segmentation
0211   while (segmentation_type == "MultiSegmentation") {
0212     const auto* multi_segmentation =
0213         dynamic_cast<const dd4hep::DDSegmentation::MultiSegmentation*>(segmentation);
0214     segmentation      = &multi_segmentation->subsegmentation(cellID);
0215     segmentation_type = segmentation->type();
0216   }
0217 
0218   // Try to cast the segmentation to CartesianGridXY
0219   const auto* cartesianGrid =
0220       dynamic_cast<const dd4hep::DDSegmentation::CartesianGridXY*>(segmentation);
0221   if (cartesianGrid == nullptr) {
0222     throw std::runtime_error("Segmentation is not of type CartesianGridXY");
0223   }
0224 
0225   return cartesianGrid;
0226 }
0227 
0228 LGADHitClustering::UnionFind::UnionFind(int n) : mParent(n, 0), mRank(n, 0) {
0229   for (int i = 0; i < n; ++i)
0230     mParent[i] = i;
0231 }
0232 
0233 int LGADHitClustering::UnionFind::find(int id) {
0234   if (mParent[id] == id)
0235     return id;
0236   return mParent[id] = find(mParent[id]); // path compression
0237 }
0238 
0239 void LGADHitClustering::UnionFind::merge(int id1, int id2) {
0240   auto root1 = find(id1);
0241   auto root2 = find(id2);
0242 
0243   if (root1 != root2) {
0244     if (mRank[root1] > mRank[root2])
0245       mParent[root2] = root1;
0246     else if (mRank[root1] < mRank[root2])
0247       mParent[root1] = root2;
0248     else {
0249       mParent[root1] = root2;
0250       mRank[root2]++;
0251     }
0252   }
0253 }
0254 
0255 } // namespace eicrecon