File indexing completed on 2025-07-11 07:49:41
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include "Acts/Clusterization/Clusterization.hpp"

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdlib>
#include <ranges>
#include <stdexcept>
#include <utility>
#include <vector>

#include <boost/pending/disjoint_sets.hpp>
0019
0020 namespace Acts::Ccl::internal {
0021
0022 template <typename Cluster>
0023 void reserve(Cluster& , std::size_t ) {}
0024
0025 template <Acts::Ccl::CanReserve Cluster>
0026 void reserve(Cluster& cl, std::size_t n) {
0027 clusterReserve(cl, n);
0028 }
0029
/// Primary comparator template: only reached when no specialization applies.
///
/// The specializations cover GridDim == 1 (cells providing getCellColumn) and
/// GridDim == 2 (cells providing getCellColumn and getCellRow). Landing here
/// therefore means either an unsupported dimension or a Cell type that lacks
/// the required accessors. Each of the two assertions below fires in exactly
/// one of those situations, so the diagnostic always names the real problem.
/// (Previously the single assertion was inverted: it passed for GridDim == 3,
/// leaving an empty comparator, and blamed the dimension when the Cell type
/// was at fault.)
template <typename Cell, std::size_t GridDim>
struct Compare {
  static_assert(GridDim == 1 || GridDim == 2,
                "Only grid dimensions of 1 or 2 are supported");
  static_assert(GridDim != 1 && GridDim != 2,
                "Cell type does not provide the getCellColumn (and, for 2D, "
                "getCellRow) accessors required for this grid dimension");
};
0035
0036
0037
0038 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0039 struct Compare<Cell, 1> {
0040 bool operator()(const Cell& c0, const Cell& c1) const {
0041 int col0 = getCellColumn(c0);
0042 int col1 = getCellColumn(c1);
0043 return col0 < col1;
0044 }
0045 };
0046
0047
0048 template <typename Cell>
0049 requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0050 Acts::Ccl::HasRetrievableRowInfo<Cell>)
0051 struct Compare<Cell, 2> {
0052 bool operator()(const Cell& c0, const Cell& c1) const {
0053 int row0 = getCellRow(c0);
0054 int row1 = getCellRow(c1);
0055 int col0 = getCellColumn(c0);
0056 int col1 = getCellColumn(c1);
0057 return (col0 == col1) ? row0 < row1 : col0 < col1;
0058 }
0059 };
0060
0061
0062
0063
0064 class DisjointSets {
0065 public:
0066 explicit DisjointSets(std::size_t initial_size = 128)
0067 : m_size(initial_size),
0068 m_rank(m_size),
0069 m_parent(m_size),
0070 m_ds(&m_rank[0], &m_parent[0]) {}
0071
0072 Label makeSet() {
0073
0074
0075 while (m_globalId >= m_size) {
0076 m_size *= 2;
0077 m_rank.resize(m_size);
0078 m_parent.resize(m_size);
0079 m_ds = boost::disjoint_sets<std::size_t*, std::size_t*>(&m_rank[0],
0080 &m_parent[0]);
0081 }
0082 m_ds.make_set(m_globalId);
0083 return static_cast<Label>(m_globalId++);
0084 }
0085
0086 void unionSet(std::size_t x, std::size_t y) { m_ds.union_set(x, y); }
0087 Label findSet(std::size_t x) { return static_cast<Label>(m_ds.find_set(x)); }
0088
0089 private:
0090 std::size_t m_globalId = 1;
0091 std::size_t m_size;
0092 std::vector<std::size_t> m_rank;
0093 std::vector<std::size_t> m_parent;
0094 boost::disjoint_sets<std::size_t*, std::size_t*> m_ds;
0095 };
0096
0097 template <std::size_t BufSize>
0098 struct ConnectionsBase {
0099 std::size_t nconn{0};
0100 std::array<Label, BufSize> buf;
0101 ConnectionsBase() { std::ranges::fill(buf, NO_LABEL); }
0102 };
0103
0104 template <std::size_t GridDim>
0105 class Connections {};
0106
0107
0108 template <>
0109 struct Connections<1> : public ConnectionsBase<1> {
0110 using ConnectionsBase::ConnectionsBase;
0111 };
0112
0113
0114 template <>
0115 struct Connections<2> : public ConnectionsBase<4> {
0116 using ConnectionsBase::ConnectionsBase;
0117 };
0118
0119
0120 template <typename Cell, typename Connect, std::size_t GridDim>
0121 Connections<GridDim> getConnections(std::size_t idx, std::vector<Cell>& cells,
0122 std::vector<Label>& labels,
0123 Connect&& connect) {
0124 Connections<GridDim> seen;
0125
0126 for (std::size_t i = 0; i < idx; ++i) {
0127 std::size_t idx2 = idx - i - 1;
0128 ConnectResult cr = connect(cells[idx], cells[idx2]);
0129
0130 if (cr == ConnectResult::eNoConnStop) {
0131 break;
0132 }
0133 if (cr == ConnectResult::eNoConn) {
0134 continue;
0135 }
0136 if (cr == ConnectResult::eConn) {
0137 seen.buf[seen.nconn] = labels[idx2];
0138 seen.nconn += 1;
0139 if (seen.nconn == seen.buf.size()) {
0140 break;
0141 }
0142 }
0143 }
0144
0145 return seen;
0146 }
0147
0148 template <typename CellCollection, typename ClusterCollection>
0149 requires(Acts::Ccl::CanAcceptCell<typename CellCollection::value_type,
0150 typename ClusterCollection::value_type>)
0151 ClusterCollection mergeClustersImpl(CellCollection& cells,
0152 const std::vector<Label>& cellLabels,
0153 const std::vector<std::size_t>& nClusters) {
0154 using Cluster = typename ClusterCollection::value_type;
0155
0156
0157 ClusterCollection clusters(nClusters.size());
0158 for (std::size_t i = 0; i < clusters.size(); ++i) {
0159 reserve(clusters[i], nClusters[i]);
0160 }
0161
0162
0163 for (std::size_t i = 0; i < cells.size(); ++i) {
0164 Label label = cellLabels[i] - 1;
0165 Cluster& cl = clusters[label];
0166 clusterAddCell(cl, cells[i]);
0167 }
0168
0169
0170
0171 std::size_t invalidClusters = 0ul;
0172 for (std::size_t i = 0; i < clusters.size(); ++i) {
0173 std::size_t idx = clusters.size() - i - 1;
0174 if (nClusters[idx] != 0) {
0175 continue;
0176 }
0177
0178
0179
0180 std::swap(clusters[idx], clusters[clusters.size() - invalidClusters - 1]);
0181 ++invalidClusters;
0182 }
0183 clusters.resize(clusters.size() - invalidClusters);
0184
0185 return clusters;
0186 }
0187
0188 }
0189
0190 namespace Acts::Ccl {
0191
0192 template <typename Cell>
0193 requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0194 Acts::Ccl::HasRetrievableRowInfo<Cell>)
0195 ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
0196 const Cell& iter) const {
0197 int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter));
0198 int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0199
0200
0201 if (deltaCol > 1) {
0202 return ConnectResult::eNoConnStop;
0203 }
0204
0205
0206 if (deltaRow > 1) {
0207 return ConnectResult::eNoConn;
0208 }
0209
0210
0211 if ((deltaRow + deltaCol) <= (conn8 ? 2 : 1)) {
0212 return ConnectResult::eConn;
0213 }
0214 return ConnectResult::eNoConn;
0215 }
0216
0217 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0218 ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
0219 const Cell& iter) const {
0220 int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0221 return deltaCol == 1 ? ConnectResult::eConn : ConnectResult::eNoConnStop;
0222 }
0223
0224 template <std::size_t GridDim>
0225 void recordEquivalences(const internal::Connections<GridDim> seen,
0226 internal::DisjointSets& ds) {
0227
0228
0229 if (seen.nconn > 0 && seen.buf[0] == NO_LABEL) {
0230 throw std::logic_error("seen.nconn > 0 but seen.buf[0] == NO_LABEL");
0231 }
0232 for (std::size_t i = 1; i < seen.nconn; i++) {
0233
0234
0235
0236 if (seen.buf[i] == NO_LABEL) {
0237 throw std::logic_error("i < seen.nconn but see.buf[i] == NO_LABEL");
0238 }
0239
0240 if (seen.buf[0] != seen.buf[i]) {
0241 ds.unionSet(seen.buf[0], seen.buf[i]);
0242 }
0243 }
0244 }
0245
0246 template <typename CellCollection, std::size_t GridDim, typename Connect>
0247 std::vector<std::size_t> labelClusters(CellCollection& cells,
0248 std::vector<Label>& cellLabels,
0249 Connect&& connect) {
0250 using Cell = typename CellCollection::value_type;
0251
0252 internal::DisjointSets ds{};
0253
0254
0255 std::ranges::sort(cells, internal::Compare<Cell, GridDim>());
0256
0257
0258 for (std::size_t nCell(0ul); nCell < cells.size(); ++nCell) {
0259 const internal::Connections<GridDim> seen =
0260 internal::getConnections<Cell, Connect, GridDim>(
0261 nCell, cells, cellLabels, std::forward<Connect>(connect));
0262
0263 if (seen.nconn == 0) {
0264
0265 cellLabels[nCell] = ds.makeSet();
0266 } else {
0267 recordEquivalences(seen, ds);
0268
0269 cellLabels[nCell] = seen.buf[0];
0270 }
0271 }
0272
0273
0274 int maxNClusters = 0;
0275 for (Label& lbl : cellLabels) {
0276 lbl = ds.findSet(lbl);
0277 maxNClusters = std::max(maxNClusters, lbl);
0278 }
0279
0280
0281
0282 std::vector<std::size_t> nClusters(maxNClusters, 0);
0283 for (const Label label : cellLabels) {
0284 ++nClusters[label - 1];
0285 }
0286
0287 return nClusters;
0288 }
0289
0290 template <typename CellCollection, typename ClusterCollection,
0291 std::size_t GridDim = 2>
0292 requires(GridDim == 1 || GridDim == 2)
0293 ClusterCollection mergeClusters(CellCollection& cells,
0294 const std::vector<Label>& cellLabels,
0295 const std::vector<std::size_t>& nClusters) {
0296 return internal::mergeClustersImpl<CellCollection, ClusterCollection>(
0297 cells, cellLabels, nClusters);
0298 }
0299
0300 template <typename CellCollection, typename ClusterCollection,
0301 std::size_t GridDim, typename Connect>
0302 ClusterCollection createClusters(CellCollection& cells, Connect&& connect) {
0303 if (cells.empty()) {
0304 return {};
0305 }
0306 std::vector<Label> cellLabels(cells.size(), NO_LABEL);
0307 std::vector<std::size_t> nClusters =
0308 labelClusters<CellCollection, GridDim, Connect>(
0309 cells, cellLabels, std::forward<Connect>(connect));
0310 return mergeClusters<CellCollection, ClusterCollection, GridDim>(
0311 cells, cellLabels, nClusters);
0312 }
0313
0314 }