File indexing completed on 2025-11-04 09:21:24
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/Clusterization/Clusterization.hpp"
0012
0013 #include <algorithm>
0014 #include <array>
0015 #include <ranges>
0016 #include <vector>
0017
0018 namespace Acts::Ccl {
0019
/// Fallback `reserve` for cluster types that offer no way to pre-allocate
/// storage: the capacity hint is simply ignored.
template <typename Cluster>
void reserve(Cluster& /*cluster*/, std::size_t /*capacity*/) {
  // Intentionally empty: nothing to reserve for this cluster type.
}
0022
/// `reserve` overload for cluster types satisfying Acts::Ccl::CanReserve.
///
/// Being more constrained than the unconstrained `reserve` overload, this one
/// is preferred by overload resolution whenever the concept is satisfied.
/// The call to clusterReserve is deliberately unqualified so that
/// argument-dependent lookup can find a clusterReserve declared alongside
/// the user's Cluster type.
///
/// @param cl  cluster whose storage should be reserved
/// @param n   number of cells the cluster is expected to receive
template <Acts::Ccl::CanReserve Cluster>
void reserve(Cluster& cl, std::size_t n) {
  clusterReserve(cl, n);
}
0027
/// Primary template of the comparator used to sort cells before labelling.
///
/// Only the constrained partial specializations Compare<Cell, 1> and
/// Compare<Cell, 2> provide a call operator; this primary template exists to
/// reject unsupported grid dimensions at compile time.
template <typename Cell, std::size_t GridDim>
struct Compare {
  // static_assert fires when its condition is *false*, so the condition must
  // state the supported case. The previous form (GridDim != 1 && GridDim != 2)
  // never fired for unsupported dimensions and always fired for 1 and 2.
  static_assert(GridDim == 1 || GridDim == 2,
                "Only grid dimensions of 1 or 2 are supported");
};
0033
0034
0035
0036 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0037 struct Compare<Cell, 1> {
0038 bool operator()(const Cell& c0, const Cell& c1) const {
0039 int col0 = getCellColumn(c0);
0040 int col1 = getCellColumn(c1);
0041 return col0 < col1;
0042 }
0043 };
0044
0045
0046 template <typename Cell>
0047 requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0048 Acts::Ccl::HasRetrievableRowInfo<Cell>)
0049 struct Compare<Cell, 2> {
0050 bool operator()(const Cell& c0, const Cell& c1) const {
0051 int row0 = getCellRow(c0);
0052 int row1 = getCellRow(c1);
0053 int col0 = getCellColumn(c0);
0054 int col1 = getCellColumn(c1);
0055 return (col0 == col1) ? row0 < row1 : col0 < col1;
0056 }
0057 };
0058
0059 template <std::size_t BufSize>
0060 struct ConnectionsBase {
0061 std::size_t nconn{0};
0062 std::array<Label, BufSize> buf;
0063 ConnectionsBase() { std::ranges::fill(buf, NO_LABEL); }
0064 };
0065
0066 template <std::size_t GridDim>
0067 class Connections {};
0068
0069
0070 template <>
0071 struct Connections<1> : public ConnectionsBase<1> {
0072 using ConnectionsBase::ConnectionsBase;
0073 };
0074
0075
0076 template <>
0077 struct Connections<2> : public ConnectionsBase<4> {
0078 using ConnectionsBase::ConnectionsBase;
0079 };
0080
0081
0082 template <typename Cell, typename Connect, std::size_t GridDim>
0083 Connections<GridDim> getConnections(std::size_t idx, std::vector<Cell>& cells,
0084 std::vector<Label>& labels,
0085 Connect&& connect) {
0086 Connections<GridDim> seen;
0087
0088 for (std::size_t i = 0; i < idx; ++i) {
0089 std::size_t idx2 = idx - i - 1;
0090 ConnectResult cr = connect(cells[idx], cells[idx2]);
0091
0092 if (cr == ConnectResult::eDuplicate) {
0093 throw std::invalid_argument(
0094 "Clusterization: input contains duplicate cells");
0095 }
0096 if (cr == ConnectResult::eNoConnStop) {
0097 break;
0098 }
0099 if (cr == ConnectResult::eNoConn) {
0100 continue;
0101 }
0102 if (cr == ConnectResult::eConn) {
0103 seen.buf[seen.nconn] = labels[idx2];
0104 seen.nconn += 1;
0105 if (seen.nconn == seen.buf.size()) {
0106 break;
0107 }
0108 }
0109 }
0110
0111 return seen;
0112 }
0113
0114 template <typename CellCollection, typename ClusterCollection>
0115 requires(Acts::Ccl::CanAcceptCell<typename CellCollection::value_type,
0116 typename ClusterCollection::value_type>)
0117 void mergeClusters(Acts::Ccl::ClusteringData& data, const CellCollection& cells,
0118 ClusterCollection& outv) {
0119 using Cluster = typename ClusterCollection::value_type;
0120
0121
0122 std::size_t previousSize = outv.size();
0123 outv.resize(previousSize + data.nClusters.size());
0124 for (std::size_t i = 0; i < data.nClusters.size(); ++i) {
0125 Acts::Ccl::reserve(outv[previousSize + i], data.nClusters[i]);
0126 }
0127
0128
0129
0130
0131 for (std::size_t i = 0; i < cells.size(); ++i) {
0132 Label label = data.labels[i] - 1;
0133 Cluster& cl = outv[previousSize + label];
0134 clusterAddCell(cl, cells[i]);
0135 }
0136
0137
0138
0139 std::size_t invalidClusters = 0ul;
0140 for (std::size_t i = 0; i < data.nClusters.size(); ++i) {
0141 std::size_t idx = data.nClusters.size() - i - 1;
0142 if (data.nClusters[idx] != 0) {
0143 continue;
0144 }
0145
0146
0147
0148 std::swap(outv[previousSize + idx],
0149 outv[outv.size() - invalidClusters - 1]);
0150 ++invalidClusters;
0151 }
0152 outv.resize(outv.size() - invalidClusters);
0153 }
0154
0155 template <typename Cell>
0156 requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0157 Acts::Ccl::HasRetrievableRowInfo<Cell>)
0158 ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
0159 const Cell& iter) const {
0160 int deltaRow = getCellRow(iter) - getCellRow(ref);
0161 int deltaCol = getCellColumn(iter) - getCellColumn(ref);
0162 assert((deltaCol < 0 || (deltaCol == 0 && deltaRow <= 0)) &&
0163 "Not iterating backwards");
0164
0165 switch (deltaCol) {
0166 case 0:
0167 if (deltaRow == 0) {
0168 return ConnectResult::eDuplicate;
0169 } else if (deltaRow == -1) {
0170 return ConnectResult::eConn;
0171 } else {
0172 return ConnectResult::eNoConn;
0173 }
0174 case -1:
0175 if (deltaRow > static_cast<int>(conn8)) {
0176 return ConnectResult::eNoConn;
0177 } else if (deltaRow < -static_cast<int>(conn8)) {
0178 return ConnectResult::eNoConnStop;
0179 } else {
0180 return ConnectResult::eConn;
0181 }
0182 default:
0183 return ConnectResult::eNoConnStop;
0184 }
0185 }
0186
0187 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0188 ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
0189 const Cell& iter) const {
0190 int deltaCol = getCellColumn(iter) - getCellColumn(ref);
0191 assert((deltaCol <= 0) && "Not iterating backwards");
0192
0193 switch (deltaCol) {
0194 case 0:
0195 return ConnectResult::eDuplicate;
0196 case -1:
0197 return ConnectResult::eConn;
0198 default:
0199 return ConnectResult::eNoConnStop;
0200 }
0201 }
0202
0203 template <std::size_t GridDim>
0204 void recordEquivalences(const Connections<GridDim> seen, DisjointSets& ds) {
0205
0206
0207 if (seen.nconn > 0 && seen.buf[0] == NO_LABEL) {
0208 throw std::logic_error("seen.nconn > 0 but seen.buf[0] == NO_LABEL");
0209 }
0210 for (std::size_t i = 1; i < seen.nconn; i++) {
0211
0212
0213
0214 if (seen.buf[i] == NO_LABEL) {
0215 throw std::logic_error("i < seen.nconn but see.buf[i] == NO_LABEL");
0216 }
0217
0218 if (seen.buf[0] != seen.buf[i]) {
0219 ds.unionSet(seen.buf[0], seen.buf[i]);
0220 }
0221 }
0222 }
0223
0224 template <typename CellCollection, std::size_t GridDim, typename Connect>
0225 void labelClusters(Acts::Ccl::ClusteringData& data, CellCollection& cells,
0226 Connect&& connect) {
0227 using Cell = typename CellCollection::value_type;
0228
0229 data.labels.resize(cells.size(), NO_LABEL);
0230
0231 std::ranges::sort(cells, Acts::Ccl::Compare<Cell, GridDim>());
0232
0233
0234 for (std::size_t nCell(0ul); nCell < cells.size(); ++nCell) {
0235 const Acts::Ccl::Connections<GridDim> seen =
0236 Acts::Ccl::getConnections<Cell, Connect, GridDim>(
0237 nCell, cells, data.labels, std::forward<Connect>(connect));
0238
0239 if (seen.nconn == 0) {
0240
0241 data.labels[nCell] = data.ds.makeSet();
0242 } else {
0243 recordEquivalences(seen, data.ds);
0244
0245 data.labels[nCell] = seen.buf[0];
0246 }
0247 }
0248
0249
0250 int maxNClusters = 0;
0251 for (Label& lbl : data.labels) {
0252 lbl = data.ds.findSet(lbl);
0253 maxNClusters = std::max(maxNClusters, lbl);
0254 }
0255
0256
0257
0258 data.nClusters.resize(maxNClusters, 0);
0259 for (const Label label : data.labels) {
0260 ++data.nClusters[label - 1];
0261 }
0262 }
0263
0264 template <typename CellCollection, typename ClusterCollection,
0265 std::size_t GridDim, typename Connect>
0266 ClusterCollection createClusters(CellCollection& cells, Connect&& connect) {
0267 ClusterCollection clusters;
0268 Acts::Ccl::ClusteringData data;
0269 Acts::Ccl::createClusters<CellCollection, ClusterCollection, GridDim,
0270 Connect>(data, cells, clusters,
0271 std::forward<Connect>(connect));
0272 return clusters;
0273 }
0274
0275 template <typename CellCollection, typename ClusterCollection,
0276 std::size_t GridDim, typename Connect>
0277 requires(GridDim == 1 || GridDim == 2)
0278 void createClusters(Acts::Ccl::ClusteringData& data, CellCollection& cells,
0279 ClusterCollection& clusters, Connect&& connect) {
0280 if (cells.empty()) {
0281 return;
0282 }
0283 data.clear();
0284
0285 Acts::Ccl::labelClusters<CellCollection, GridDim, Connect>(
0286 data, cells, std::forward<Connect>(connect));
0287 Acts::Ccl::mergeClusters<CellCollection, ClusterCollection>(data, cells,
0288 clusters);
0289 }
0290
0291 }