Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-11 07:49:41

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 #include "Acts/Clusterization/Clusterization.hpp"
0012 
0013 #include <algorithm>
0014 #include <array>
0015 #include <ranges>
0016 #include <vector>
0017 
0018 #include <boost/pending/disjoint_sets.hpp>
0019 
0020 namespace Acts::Ccl::internal {
0021 
0022 template <typename Cluster>
0023 void reserve(Cluster& /*cl*/, std::size_t /*n*/) {}
0024 
0025 template <Acts::Ccl::CanReserve Cluster>
0026 void reserve(Cluster& cl, std::size_t n) {
0027   clusterReserve(cl, n);
0028 }
0029 
/// Comparator used to order cells before the in-order clustering scan.
///
/// This primary template provides no call operator; the usable comparators
/// are the partial specializations for 1-D and 2-D grids. Instantiating the
/// primary template with any other grid dimension is rejected at compile
/// time.
///
/// @tparam Cell    cell type being compared
/// @tparam GridDim grid dimensionality, must be 1 or 2
template <typename Cell, std::size_t GridDim>
struct Compare {
  // The asserted condition has to be true for the *supported* dimensions.
  // The former `GridDim != 1 && GridDim != 2` was inverted: it accepted
  // every unsupported dimension and contradicted its own message.
  static_assert(GridDim == 1 || GridDim == 2,
                "Only grid dimensions of 1 or 2 are supported");
};
0035 
0036 // Comparator function object for cells, column-wise ordering
0037 // Specialization for 1-D grids
0038 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0039 struct Compare<Cell, 1> {
0040   bool operator()(const Cell& c0, const Cell& c1) const {
0041     int col0 = getCellColumn(c0);
0042     int col1 = getCellColumn(c1);
0043     return col0 < col1;
0044   }
0045 };
0046 
0047 // Specialization for 2-D grid
0048 template <typename Cell>
0049   requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0050            Acts::Ccl::HasRetrievableRowInfo<Cell>)
0051 struct Compare<Cell, 2> {
0052   bool operator()(const Cell& c0, const Cell& c1) const {
0053     int row0 = getCellRow(c0);
0054     int row1 = getCellRow(c1);
0055     int col0 = getCellColumn(c0);
0056     int col1 = getCellColumn(c1);
0057     return (col0 == col1) ? row0 < row1 : col0 < col1;
0058   }
0059 };
0060 
0061 // Simple wrapper around boost::disjoint_sets. In theory, could use
0062 // boost::vector_property_map and use boost::disjoint_sets without
0063 // wrapping, but it's way slower
0064 class DisjointSets {
0065  public:
0066   explicit DisjointSets(std::size_t initial_size = 128)
0067       : m_size(initial_size),
0068         m_rank(m_size),
0069         m_parent(m_size),
0070         m_ds(&m_rank[0], &m_parent[0]) {}
0071 
0072   Label makeSet() {
0073     // Empirically, m_size = 128 seems to be good default. If we
0074     // exceed this, take a performance hit and do the right thing.
0075     while (m_globalId >= m_size) {
0076       m_size *= 2;
0077       m_rank.resize(m_size);
0078       m_parent.resize(m_size);
0079       m_ds = boost::disjoint_sets<std::size_t*, std::size_t*>(&m_rank[0],
0080                                                               &m_parent[0]);
0081     }
0082     m_ds.make_set(m_globalId);
0083     return static_cast<Label>(m_globalId++);
0084   }
0085 
0086   void unionSet(std::size_t x, std::size_t y) { m_ds.union_set(x, y); }
0087   Label findSet(std::size_t x) { return static_cast<Label>(m_ds.find_set(x)); }
0088 
0089  private:
0090   std::size_t m_globalId = 1;
0091   std::size_t m_size;
0092   std::vector<std::size_t> m_rank;
0093   std::vector<std::size_t> m_parent;
0094   boost::disjoint_sets<std::size_t*, std::size_t*> m_ds;
0095 };
0096 
0097 template <std::size_t BufSize>
0098 struct ConnectionsBase {
0099   std::size_t nconn{0};
0100   std::array<Label, BufSize> buf;
0101   ConnectionsBase() { std::ranges::fill(buf, NO_LABEL); }
0102 };
0103 
0104 template <std::size_t GridDim>
0105 class Connections {};
0106 
0107 // On 1-D grid, cells have 1 backward neighbor
0108 template <>
0109 struct Connections<1> : public ConnectionsBase<1> {
0110   using ConnectionsBase::ConnectionsBase;
0111 };
0112 
0113 // On a 2-D grid, cells have 4 backward neighbors
0114 template <>
0115 struct Connections<2> : public ConnectionsBase<4> {
0116   using ConnectionsBase::ConnectionsBase;
0117 };
0118 
0119 // Cell collection logic
0120 template <typename Cell, typename Connect, std::size_t GridDim>
0121 Connections<GridDim> getConnections(std::size_t idx, std::vector<Cell>& cells,
0122                                     std::vector<Label>& labels,
0123                                     Connect&& connect) {
0124   Connections<GridDim> seen;
0125 
0126   for (std::size_t i = 0; i < idx; ++i) {
0127     std::size_t idx2 = idx - i - 1;
0128     ConnectResult cr = connect(cells[idx], cells[idx2]);
0129 
0130     if (cr == ConnectResult::eNoConnStop) {
0131       break;
0132     }
0133     if (cr == ConnectResult::eNoConn) {
0134       continue;
0135     }
0136     if (cr == ConnectResult::eConn) {
0137       seen.buf[seen.nconn] = labels[idx2];
0138       seen.nconn += 1;
0139       if (seen.nconn == seen.buf.size()) {
0140         break;
0141       }
0142     }
0143   }
0144 
0145   return seen;
0146 }
0147 
0148 template <typename CellCollection, typename ClusterCollection>
0149   requires(Acts::Ccl::CanAcceptCell<typename CellCollection::value_type,
0150                                     typename ClusterCollection::value_type>)
0151 ClusterCollection mergeClustersImpl(CellCollection& cells,
0152                                     const std::vector<Label>& cellLabels,
0153                                     const std::vector<std::size_t>& nClusters) {
0154   using Cluster = typename ClusterCollection::value_type;
0155 
0156   // Accumulate clusters into the output collection
0157   ClusterCollection clusters(nClusters.size());
0158   for (std::size_t i = 0; i < clusters.size(); ++i) {
0159     reserve(clusters[i], nClusters[i]);
0160   }
0161 
0162   // Fill clusters with cells
0163   for (std::size_t i = 0; i < cells.size(); ++i) {
0164     Label label = cellLabels[i] - 1;
0165     Cluster& cl = clusters[label];
0166     clusterAddCell(cl, cells[i]);
0167   }
0168 
0169   // Due to previous merging, we may have now clusters with
0170   // no cells. We need to remove them
0171   std::size_t invalidClusters = 0ul;
0172   for (std::size_t i = 0; i < clusters.size(); ++i) {
0173     std::size_t idx = clusters.size() - i - 1;
0174     if (nClusters[idx] != 0) {
0175       continue;
0176     }
0177     // we have an invalid cluster.
0178     // move them all to the back so that we can remove
0179     // them later
0180     std::swap(clusters[idx], clusters[clusters.size() - invalidClusters - 1]);
0181     ++invalidClusters;
0182   }
0183   clusters.resize(clusters.size() - invalidClusters);
0184 
0185   return clusters;
0186 }
0187 
0188 }  // namespace Acts::Ccl::internal
0189 
0190 namespace Acts::Ccl {
0191 
0192 template <typename Cell>
0193   requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0194            Acts::Ccl::HasRetrievableRowInfo<Cell>)
0195 ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
0196                                           const Cell& iter) const {
0197   int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter));
0198   int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0199   // Iteration is column-wise, so if too far in column, can
0200   // safely stop
0201   if (deltaCol > 1) {
0202     return ConnectResult::eNoConnStop;
0203   }
0204   // For same reason, if too far in row we know the pixel is not
0205   // connected, but need to keep iterating
0206   if (deltaRow > 1) {
0207     return ConnectResult::eNoConn;
0208   }
0209   // Decide whether or not cluster is connected based on 4- or
0210   // 8-connectivity
0211   if ((deltaRow + deltaCol) <= (conn8 ? 2 : 1)) {
0212     return ConnectResult::eConn;
0213   }
0214   return ConnectResult::eNoConn;
0215 }
0216 
0217 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0218 ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
0219                                           const Cell& iter) const {
0220   int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0221   return deltaCol == 1 ? ConnectResult::eConn : ConnectResult::eNoConnStop;
0222 }
0223 
0224 template <std::size_t GridDim>
0225 void recordEquivalences(const internal::Connections<GridDim> seen,
0226                         internal::DisjointSets& ds) {
0227   // Sanity check: first element should always have
0228   // label if nconn > 0
0229   if (seen.nconn > 0 && seen.buf[0] == NO_LABEL) {
0230     throw std::logic_error("seen.nconn > 0 but seen.buf[0] == NO_LABEL");
0231   }
0232   for (std::size_t i = 1; i < seen.nconn; i++) {
0233     // Sanity check: since connection lookup is always backward
0234     // while iteration is forward, all connected cells found here
0235     // should have a label
0236     if (seen.buf[i] == NO_LABEL) {
0237       throw std::logic_error("i < seen.nconn but see.buf[i] == NO_LABEL");
0238     }
0239     // Only record equivalence if needed
0240     if (seen.buf[0] != seen.buf[i]) {
0241       ds.unionSet(seen.buf[0], seen.buf[i]);
0242     }
0243   }
0244 }
0245 
0246 template <typename CellCollection, std::size_t GridDim, typename Connect>
0247 std::vector<std::size_t> labelClusters(CellCollection& cells,
0248                                        std::vector<Label>& cellLabels,
0249                                        Connect&& connect) {
0250   using Cell = typename CellCollection::value_type;
0251 
0252   internal::DisjointSets ds{};
0253 
0254   // Sort cells by position to enable in-order scan
0255   std::ranges::sort(cells, internal::Compare<Cell, GridDim>());
0256 
0257   // First pass: Allocate labels and record equivalences
0258   for (std::size_t nCell(0ul); nCell < cells.size(); ++nCell) {
0259     const internal::Connections<GridDim> seen =
0260         internal::getConnections<Cell, Connect, GridDim>(
0261             nCell, cells, cellLabels, std::forward<Connect>(connect));
0262 
0263     if (seen.nconn == 0) {
0264       // Allocate new label
0265       cellLabels[nCell] = ds.makeSet();
0266     } else {
0267       recordEquivalences(seen, ds);
0268       // Set label for current cell
0269       cellLabels[nCell] = seen.buf[0];
0270     }
0271   }  // loop on cells
0272 
0273   // Second pass: Merge labels based on recorded equivalences
0274   int maxNClusters = 0;
0275   for (Label& lbl : cellLabels) {
0276     lbl = ds.findSet(lbl);
0277     maxNClusters = std::max(maxNClusters, lbl);
0278   }
0279 
0280   // Third pass: Keep count of how many cells go in each
0281   // to-be-created clusters
0282   std::vector<std::size_t> nClusters(maxNClusters, 0);
0283   for (const Label label : cellLabels) {
0284     ++nClusters[label - 1];
0285   }
0286 
0287   return nClusters;
0288 }
0289 
0290 template <typename CellCollection, typename ClusterCollection,
0291           std::size_t GridDim = 2>
0292   requires(GridDim == 1 || GridDim == 2)
0293 ClusterCollection mergeClusters(CellCollection& cells,
0294                                 const std::vector<Label>& cellLabels,
0295                                 const std::vector<std::size_t>& nClusters) {
0296   return internal::mergeClustersImpl<CellCollection, ClusterCollection>(
0297       cells, cellLabels, nClusters);
0298 }
0299 
0300 template <typename CellCollection, typename ClusterCollection,
0301           std::size_t GridDim, typename Connect>
0302 ClusterCollection createClusters(CellCollection& cells, Connect&& connect) {
0303   if (cells.empty()) {
0304     return {};
0305   }
0306   std::vector<Label> cellLabels(cells.size(), NO_LABEL);
0307   std::vector<std::size_t> nClusters =
0308       labelClusters<CellCollection, GridDim, Connect>(
0309           cells, cellLabels, std::forward<Connect>(connect));
0310   return mergeClusters<CellCollection, ClusterCollection, GridDim>(
0311       cells, cellLabels, nClusters);
0312 }
0313 
0314 }  // namespace Acts::Ccl