Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:10:45

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdlib>
#include <iterator>
#include <stdexcept>
#include <utility>
#include <vector>

#include <boost/pending/disjoint_sets.hpp>
0014 
0015 namespace Acts::Ccl::internal {
0016 
// Primary comparator template. It is only selected when none of the
// specializations below match: either GridDim is an unsupported dimension, or
// Cell lacks the accessors those specializations require. In the latter case
// this type has no operator(), so using it as a comparator fails to compile.
template <typename Cell, std::size_t GridDim>
struct Compare {
  // Reject unsupported grid dimensions with a clear diagnostic. The condition
  // states what IS supported, so it fails exactly for the other dimensions.
  static_assert(GridDim == 1 || GridDim == 2,
                "Only grid dimensions of 1 or 2 are supported");
};
0022 
0023 // Comparator function object for cells, column-wise ordering
0024 // Specialization for 1-D grids
0025 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0026 struct Compare<Cell, 1> {
0027   bool operator()(const Cell& c0, const Cell& c1) const {
0028     int col0 = getCellColumn(c0);
0029     int col1 = getCellColumn(c1);
0030     return col0 < col1;
0031   }
0032 };
0033 
0034 // Specialization for 2-D grid
0035 template <typename Cell>
0036   requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0037            Acts::Ccl::HasRetrievableRowInfo<Cell>)
0038 struct Compare<Cell, 2> {
0039   bool operator()(const Cell& c0, const Cell& c1) const {
0040     int row0 = getCellRow(c0);
0041     int row1 = getCellRow(c1);
0042     int col0 = getCellColumn(c0);
0043     int col1 = getCellColumn(c1);
0044     return (col0 == col1) ? row0 < row1 : col0 < col1;
0045   }
0046 };
0047 
0048 // Simple wrapper around boost::disjoint_sets. In theory, could use
0049 // boost::vector_property_map and use boost::disjoint_sets without
0050 // wrapping, but it's way slower
0051 class DisjointSets {
0052  public:
0053   explicit DisjointSets(std::size_t initial_size = 128)
0054       : m_size(initial_size),
0055         m_rank(m_size),
0056         m_parent(m_size),
0057         m_ds(&m_rank[0], &m_parent[0]) {}
0058 
0059   Label makeSet() {
0060     // Empirically, m_size = 128 seems to be good default. If we
0061     // exceed this, take a performance hit and do the right thing.
0062     while (m_globalId >= m_size) {
0063       m_size *= 2;
0064       m_rank.resize(m_size);
0065       m_parent.resize(m_size);
0066       m_ds = boost::disjoint_sets<std::size_t*, std::size_t*>(&m_rank[0],
0067                                                               &m_parent[0]);
0068     }
0069     m_ds.make_set(m_globalId);
0070     return static_cast<Label>(m_globalId++);
0071   }
0072 
0073   void unionSet(std::size_t x, std::size_t y) { m_ds.union_set(x, y); }
0074   Label findSet(std::size_t x) { return static_cast<Label>(m_ds.find_set(x)); }
0075 
0076  private:
0077   std::size_t m_globalId = 1;
0078   std::size_t m_size;
0079   std::vector<std::size_t> m_rank;
0080   std::vector<std::size_t> m_parent;
0081   boost::disjoint_sets<std::size_t*, std::size_t*> m_ds;
0082 };
0083 
0084 template <std::size_t BufSize>
0085 struct ConnectionsBase {
0086   std::size_t nconn{0};
0087   std::array<Label, BufSize> buf;
0088   ConnectionsBase() { std::fill(buf.begin(), buf.end(), NO_LABEL); }
0089 };
0090 
// Neighbour-label buffer for a grid of dimension GridDim. Only the 1-D and
// 2-D specializations carry data. Declared with `struct` so the class-key
// matches those explicit specializations (avoids -Wmismatched-tags / C4099).
template <std::size_t GridDim>
struct Connections {};
0093 
0094 // On 1-D grid, cells have 1 backward neighbor
0095 template <>
0096 struct Connections<1> : public ConnectionsBase<1> {
0097   using ConnectionsBase::ConnectionsBase;
0098 };
0099 
0100 // On a 2-D grid, cells have 4 backward neighbors
0101 template <>
0102 struct Connections<2> : public ConnectionsBase<4> {
0103   using ConnectionsBase::ConnectionsBase;
0104 };
0105 
0106 // Cell collection logic
0107 template <typename Cell, typename Connect, std::size_t GridDim>
0108 Connections<GridDim> getConnections(typename std::vector<Cell>::iterator it,
0109                                     std::vector<Cell>& set, Connect connect) {
0110   Connections<GridDim> seen;
0111   typename std::vector<Cell>::iterator it_2{it};
0112 
0113   while (it_2 != set.begin()) {
0114     it_2 = std::prev(it_2);
0115 
0116     ConnectResult cr = connect(*it, *it_2);
0117     if (cr == ConnectResult::eNoConnStop) {
0118       break;
0119     }
0120     if (cr == ConnectResult::eNoConn) {
0121       continue;
0122     }
0123     if (cr == ConnectResult::eConn) {
0124       seen.buf[seen.nconn] = getCellLabel(*it_2);
0125       seen.nconn += 1;
0126       if (seen.nconn == seen.buf.size()) {
0127         break;
0128       }
0129     }
0130   }
0131   return seen;
0132 }
0133 
0134 template <typename CellCollection, typename ClusterCollection>
0135   requires(
0136       Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type> &&
0137       Acts::Ccl::CanAcceptCell<typename CellCollection::value_type,
0138                                typename ClusterCollection::value_type>)
0139 ClusterCollection mergeClustersImpl(CellCollection& cells) {
0140   using Cluster = typename ClusterCollection::value_type;
0141 
0142   if (cells.empty()) {
0143     return {};
0144   }
0145 
0146   // Accumulate clusters into the output collection
0147   ClusterCollection outv;
0148   Cluster cl;
0149   int lbl = getCellLabel(cells.front());
0150   for (auto& cell : cells) {
0151     if (getCellLabel(cell) != lbl) {
0152       // New cluster, save previous one
0153       outv.push_back(std::move(cl));
0154       cl = Cluster();
0155       lbl = getCellLabel(cell);
0156     }
0157     clusterAddCell(cl, cell);
0158   }
0159   // Get the last cluster as well
0160   outv.push_back(std::move(cl));
0161 
0162   return outv;
0163 }
0164 
0165 }  // namespace Acts::Ccl::internal
0166 
0167 namespace Acts::Ccl {
0168 
0169 template <typename Cell>
0170   requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0171            Acts::Ccl::HasRetrievableRowInfo<Cell>)
0172 ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
0173                                           const Cell& iter) const {
0174   int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter));
0175   int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0176   // Iteration is column-wise, so if too far in column, can
0177   // safely stop
0178   if (deltaCol > 1) {
0179     return ConnectResult::eNoConnStop;
0180   }
0181   // For same reason, if too far in row we know the pixel is not
0182   // connected, but need to keep iterating
0183   if (deltaRow > 1) {
0184     return ConnectResult::eNoConn;
0185   }
0186   // Decide whether or not cluster is connected based on 4- or
0187   // 8-connectivity
0188   if ((deltaRow + deltaCol) <= (conn8 ? 2 : 1)) {
0189     return ConnectResult::eConn;
0190   }
0191   return ConnectResult::eNoConn;
0192 }
0193 
0194 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0195 ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
0196                                           const Cell& iter) const {
0197   int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0198   return deltaCol == 1 ? ConnectResult::eConn : ConnectResult::eNoConnStop;
0199 }
0200 
0201 template <std::size_t GridDim>
0202 void recordEquivalences(const internal::Connections<GridDim> seen,
0203                         internal::DisjointSets& ds) {
0204   // Sanity check: first element should always have
0205   // label if nconn > 0
0206   if (seen.nconn > 0 && seen.buf[0] == NO_LABEL) {
0207     throw std::logic_error("seen.nconn > 0 but seen.buf[0] == NO_LABEL");
0208   }
0209   for (std::size_t i = 1; i < seen.nconn; i++) {
0210     // Sanity check: since connection lookup is always backward
0211     // while iteration is forward, all connected cells found here
0212     // should have a label
0213     if (seen.buf[i] == NO_LABEL) {
0214       throw std::logic_error("i < seen.nconn but see.buf[i] == NO_LABEL");
0215     }
0216     // Only record equivalence if needed
0217     if (seen.buf[0] != seen.buf[i]) {
0218       ds.unionSet(seen.buf[0], seen.buf[i]);
0219     }
0220   }
0221 }
0222 
0223 template <typename CellCollection, std::size_t GridDim, typename Connect>
0224   requires(
0225       Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type>)
0226 void labelClusters(CellCollection& cells, Connect connect) {
0227   using Cell = typename CellCollection::value_type;
0228 
0229   internal::DisjointSets ds{};
0230 
0231   // Sort cells by position to enable in-order scan
0232   std::ranges::sort(cells, internal::Compare<Cell, GridDim>());
0233 
0234   // First pass: Allocate labels and record equivalences
0235   for (auto it = std::ranges::begin(cells); it != std::ranges::end(cells);
0236        ++it) {
0237     const internal::Connections<GridDim> seen =
0238         internal::getConnections<Cell, Connect, GridDim>(it, cells, connect);
0239     if (seen.nconn == 0) {
0240       // Allocate new label
0241       getCellLabel(*it) = ds.makeSet();
0242     } else {
0243       recordEquivalences(seen, ds);
0244       // Set label for current cell
0245       getCellLabel(*it) = seen.buf[0];
0246     }
0247   }
0248 
0249   // Second pass: Merge labels based on recorded equivalences
0250   for (auto& cell : cells) {
0251     Label& lbl = getCellLabel(cell);
0252     lbl = ds.findSet(lbl);
0253   }
0254 }
0255 
0256 template <typename CellCollection, typename ClusterCollection,
0257           std::size_t GridDim = 2>
0258   requires(GridDim == 1 || GridDim == 2) &&
0259           Acts::Ccl::HasRetrievableLabelInfo<
0260               typename CellCollection::value_type>
0261 ClusterCollection mergeClusters(CellCollection& cells) {
0262   using Cell = typename CellCollection::value_type;
0263   if constexpr (GridDim > 1) {
0264     // Sort the cells by their cluster label, only needed if more than
0265     // one spatial dimension
0266     std::ranges::sort(cells, {}, [](Cell& c) { return getCellLabel(c); });
0267   }
0268 
0269   return internal::mergeClustersImpl<CellCollection, ClusterCollection>(cells);
0270 }
0271 
0272 template <typename CellCollection, typename ClusterCollection,
0273           std::size_t GridDim, typename Connect>
0274 ClusterCollection createClusters(CellCollection& cells, Connect connect) {
0275   labelClusters<CellCollection, GridDim, Connect>(cells, connect);
0276   return mergeClusters<CellCollection, ClusterCollection, GridDim>(cells);
0277 }
0278 
0279 }  // namespace Acts::Ccl