File indexing completed on 2026-05-20 07:34:07
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/Clusterization/Clusterization.hpp"
0012
0013 #include <algorithm>
0014 #include <array>
0015 #include <ranges>
0016 #include <vector>
0017
0018 namespace Acts::Ccl {
0019
0020 template <typename Cluster>
0021 void reserve(Cluster& cl, std::size_t n) {
0022 if constexpr (Acts::Ccl::CanReserve<Cluster>) {
0023 clusterReserve(cl, n);
0024 }
0025 }
0026
/// Comparator used to sort cells before labelling (primary template).
///
/// This primary template is only selected when neither the 1-D nor the 2-D
/// partial specialisation matches, so any instantiation of it is a usage
/// error. The two assertions below distinguish the two possible causes; both
/// conditions depend on GridDim, so exactly one of them fires.
template <typename Cell, std::size_t GridDim>
struct Compare {
  // Unsupported dimensionality (e.g. GridDim == 3).
  static_assert(GridDim == 1 || GridDim == 2,
                "Only grid dimensions of 1 or 2 are supported");
  // Supported dimensionality, but Cell does not satisfy the accessor
  // concepts required by the matching specialisation.
  static_assert(GridDim != 1 && GridDim != 2,
                "Cell type does not provide the column (1-D) or column and "
                "row (2-D) accessors required for sorting");
};
0032
0033
0034
0035 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0036 struct Compare<Cell, 1> {
0037 bool operator()(const Cell& c0, const Cell& c1) const {
0038 int col0 = getCellColumn(c0);
0039 int col1 = getCellColumn(c1);
0040 return col0 < col1;
0041 }
0042 };
0043
0044
0045 template <typename Cell>
0046 requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0047 Acts::Ccl::HasRetrievableRowInfo<Cell>)
0048 struct Compare<Cell, 2> {
0049 bool operator()(const Cell& c0, const Cell& c1) const {
0050 int row0 = getCellRow(c0);
0051 int row1 = getCellRow(c1);
0052 int col0 = getCellColumn(c0);
0053 int col1 = getCellColumn(c1);
0054 return (col0 == col1) ? row0 < row1 : col0 < col1;
0055 }
0056 };
0057
0058 template <std::size_t BufSize>
0059 struct ConnectionsBase {
0060 std::size_t nconn{0};
0061 std::array<Label, BufSize> buf;
0062 ConnectionsBase() { std::ranges::fill(buf, NO_LABEL); }
0063 };
0064
0065 template <std::size_t GridDim>
0066 class Connections {};
0067
0068
0069 template <>
0070 struct Connections<1> : public ConnectionsBase<1> {
0071 using ConnectionsBase::ConnectionsBase;
0072 };
0073
0074
0075 template <>
0076 struct Connections<2> : public ConnectionsBase<4> {
0077 using ConnectionsBase::ConnectionsBase;
0078 };
0079
0080
0081 template <typename Cell, typename Connect, std::size_t GridDim>
0082 Connections<GridDim> getConnections(std::size_t idx, std::vector<Cell>& cells,
0083 std::vector<Label>& labels,
0084 Connect&& connect) {
0085 Connections<GridDim> seen;
0086
0087 for (std::size_t i = 0; i < idx; ++i) {
0088 std::size_t idx2 = idx - i - 1;
0089 ConnectResult cr = connect(cells[idx], cells[idx2]);
0090
0091 if (cr == ConnectResult::eDuplicate) {
0092 throw std::invalid_argument(
0093 "Clusterization: input contains duplicate cells");
0094 }
0095 if (cr == ConnectResult::eNoConnStop) {
0096 break;
0097 }
0098 if (cr == ConnectResult::eNoConn) {
0099 continue;
0100 }
0101 if (cr == ConnectResult::eConn) {
0102 seen.buf[seen.nconn] = labels[idx2];
0103 seen.nconn += 1;
0104 if (seen.nconn == seen.buf.size()) {
0105 break;
0106 }
0107 }
0108 }
0109
0110 return seen;
0111 }
0112
0113 template <typename CellCollection, typename ClusterCollection>
0114 requires(Acts::Ccl::CanAcceptCell<typename CellCollection::value_type,
0115 typename ClusterCollection::value_type>)
0116 void mergeClusters(Acts::Ccl::ClusteringData& data, const CellCollection& cells,
0117 ClusterCollection& outv) {
0118 using Cluster = typename ClusterCollection::value_type;
0119
0120
0121 std::size_t previousSize = outv.size();
0122 outv.resize(previousSize + data.nClusters.size());
0123 for (std::size_t i = 0; i < data.nClusters.size(); ++i) {
0124 Acts::Ccl::reserve(outv[previousSize + i], data.nClusters[i]);
0125 }
0126
0127
0128
0129
0130 for (std::size_t i = 0; i < cells.size(); ++i) {
0131 Label label = data.labels[i] - 1;
0132 Cluster& cl = outv[previousSize + label];
0133 clusterAddCell(cl, cells[i]);
0134 }
0135
0136
0137
0138 std::size_t invalidClusters = 0ul;
0139 for (std::size_t i = 0; i < data.nClusters.size(); ++i) {
0140 std::size_t idx = data.nClusters.size() - i - 1;
0141 if (data.nClusters[idx] != 0) {
0142 continue;
0143 }
0144
0145
0146
0147 std::swap(outv[previousSize + idx],
0148 outv[outv.size() - invalidClusters - 1]);
0149 ++invalidClusters;
0150 }
0151 outv.resize(outv.size() - invalidClusters);
0152 }
0153
0154 template <typename Cell>
0155 requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0156 Acts::Ccl::HasRetrievableRowInfo<Cell>)
0157 ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
0158 const Cell& iter) const {
0159 int deltaRow = getCellRow(iter) - getCellRow(ref);
0160 int deltaCol = getCellColumn(iter) - getCellColumn(ref);
0161 assert((deltaCol < 0 || (deltaCol == 0 && deltaRow <= 0)) &&
0162 "Not iterating backwards");
0163
0164 switch (deltaCol) {
0165 case 0:
0166 if (deltaRow == 0) {
0167 return ConnectResult::eDuplicate;
0168 } else if (deltaRow == -1) {
0169 return ConnectResult::eConn;
0170 } else {
0171 return ConnectResult::eNoConn;
0172 }
0173 case -1:
0174 if (deltaRow > static_cast<int>(conn8)) {
0175 return ConnectResult::eNoConn;
0176 } else if (deltaRow < -static_cast<int>(conn8)) {
0177 return ConnectResult::eNoConnStop;
0178 } else {
0179 return ConnectResult::eConn;
0180 }
0181 default:
0182 return ConnectResult::eNoConnStop;
0183 }
0184 }
0185
0186 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0187 ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
0188 const Cell& iter) const {
0189 int deltaCol = getCellColumn(iter) - getCellColumn(ref);
0190 assert((deltaCol <= 0) && "Not iterating backwards");
0191
0192 switch (deltaCol) {
0193 case 0:
0194 return ConnectResult::eDuplicate;
0195 case -1:
0196 return ConnectResult::eConn;
0197 default:
0198 return ConnectResult::eNoConnStop;
0199 }
0200 }
0201
0202 template <std::size_t GridDim>
0203 void recordEquivalences(const Connections<GridDim> seen, DisjointSets& ds) {
0204
0205
0206 if (seen.nconn > 0 && seen.buf[0] == NO_LABEL) {
0207 throw std::logic_error("seen.nconn > 0 but seen.buf[0] == NO_LABEL");
0208 }
0209 for (std::size_t i = 1; i < seen.nconn; i++) {
0210
0211
0212
0213 if (seen.buf[i] == NO_LABEL) {
0214 throw std::logic_error("i < seen.nconn but see.buf[i] == NO_LABEL");
0215 }
0216
0217 if (seen.buf[0] != seen.buf[i]) {
0218 ds.unionSet(seen.buf[0], seen.buf[i]);
0219 }
0220 }
0221 }
0222
0223 template <typename CellCollection, std::size_t GridDim, typename Connect>
0224 void labelClusters(Acts::Ccl::ClusteringData& data, CellCollection& cells,
0225 Connect&& connect) {
0226 using Cell = typename CellCollection::value_type;
0227
0228 data.labels.resize(cells.size(), NO_LABEL);
0229
0230 std::ranges::sort(cells, Acts::Ccl::Compare<Cell, GridDim>());
0231
0232
0233 for (std::size_t nCell(0ul); nCell < cells.size(); ++nCell) {
0234 const Acts::Ccl::Connections<GridDim> seen =
0235 Acts::Ccl::getConnections<Cell, Connect, GridDim>(
0236 nCell, cells, data.labels, std::forward<Connect>(connect));
0237
0238 if (seen.nconn == 0) {
0239
0240 data.labels[nCell] = data.ds.makeSet();
0241 } else {
0242 recordEquivalences(seen, data.ds);
0243
0244 data.labels[nCell] = seen.buf[0];
0245 }
0246 }
0247
0248
0249 int maxNClusters = 0;
0250 for (Label& lbl : data.labels) {
0251 lbl = data.ds.findSet(lbl);
0252 maxNClusters = std::max(maxNClusters, lbl);
0253 }
0254
0255
0256
0257 data.nClusters.resize(maxNClusters, 0);
0258 for (const Label label : data.labels) {
0259 ++data.nClusters[label - 1];
0260 }
0261 }
0262
0263 template <typename CellCollection, typename ClusterCollection,
0264 std::size_t GridDim, typename Connect>
0265 requires(GridDim == 1 || GridDim == 2)
0266 void createClusters(Acts::Ccl::ClusteringData& data, CellCollection& cells,
0267 ClusterCollection& clusters, Connect&& connect) {
0268 if (cells.empty()) {
0269 return;
0270 }
0271 data.clear();
0272
0273 Acts::Ccl::labelClusters<CellCollection, GridDim, Connect>(
0274 data, cells, std::forward<Connect>(connect));
0275 Acts::Ccl::mergeClusters<CellCollection, ClusterCollection>(data, cells,
0276 clusters);
0277 }
0278
0279 }