File indexing completed on 2025-09-17 08:01:21
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/Clusterization/Clusterization.hpp"
0012
0013 #include <algorithm>
0014 #include <array>
0015 #include <ranges>
0016 #include <vector>
0017
0018 namespace Acts::Ccl {
0019
/// Fallback overload of `reserve` for cluster types without a dedicated
/// reservation hook: accepts any cluster and does nothing.
///
/// @tparam Cluster any cluster type
template <typename Cluster>
void reserve(Cluster& /*cl*/, std::size_t /*n*/) {
  // Intentionally empty: there is nothing to pre-allocate for this type.
}
0022
0023 template <Acts::Ccl::CanReserve Cluster>
0024 void reserve(Cluster& cl, std::size_t n) {
0025 clusterReserve(cl, n);
0026 }
0027
/// Ordering predicate used to sort cells prior to labelling.
///
/// This primary template is only instantiated when none of the
/// specializations for `GridDim == 1` / `GridDim == 2` is usable, i.e. when
/// the grid dimension is unsupported or when the cell type does not satisfy
/// the constraints of the matching specialization. Instantiating it is
/// therefore always an error, and exactly one of the two assertions fires
/// with a message describing the actual cause.
///
/// @tparam Cell    cell type to be compared
/// @tparam GridDim dimensionality of the grid (1 or 2)
template <typename Cell, std::size_t GridDim>
struct Compare {
  // Fires for any dimension other than 1 or 2.
  static_assert(GridDim == 1 || GridDim == 2,
                "Only grid dimensions of 1 or 2 are supported");
  // Fires when the dimension is fine but the constrained specialization was
  // rejected, which means the cell type lacks the required accessors.
  static_assert(GridDim != 1 && GridDim != 2,
                "Cell type does not provide the column/row accessors "
                "required for this grid dimension");
};
0033
0034
0035
0036 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0037 struct Compare<Cell, 1> {
0038 bool operator()(const Cell& c0, const Cell& c1) const {
0039 int col0 = getCellColumn(c0);
0040 int col1 = getCellColumn(c1);
0041 return col0 < col1;
0042 }
0043 };
0044
0045
0046 template <typename Cell>
0047 requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0048 Acts::Ccl::HasRetrievableRowInfo<Cell>)
0049 struct Compare<Cell, 2> {
0050 bool operator()(const Cell& c0, const Cell& c1) const {
0051 int row0 = getCellRow(c0);
0052 int row1 = getCellRow(c1);
0053 int col0 = getCellColumn(c0);
0054 int col1 = getCellColumn(c1);
0055 return (col0 == col1) ? row0 < row1 : col0 < col1;
0056 }
0057 };
0058
0059 template <std::size_t BufSize>
0060 struct ConnectionsBase {
0061 std::size_t nconn{0};
0062 std::array<Label, BufSize> buf;
0063 ConnectionsBase() { std::ranges::fill(buf, NO_LABEL); }
0064 };
0065
0066 template <std::size_t GridDim>
0067 class Connections {};
0068
0069
0070 template <>
0071 struct Connections<1> : public ConnectionsBase<1> {
0072 using ConnectionsBase::ConnectionsBase;
0073 };
0074
0075
0076 template <>
0077 struct Connections<2> : public ConnectionsBase<4> {
0078 using ConnectionsBase::ConnectionsBase;
0079 };
0080
0081
0082 template <typename Cell, typename Connect, std::size_t GridDim>
0083 Connections<GridDim> getConnections(std::size_t idx, std::vector<Cell>& cells,
0084 std::vector<Label>& labels,
0085 Connect&& connect) {
0086 Connections<GridDim> seen;
0087
0088 for (std::size_t i = 0; i < idx; ++i) {
0089 std::size_t idx2 = idx - i - 1;
0090 ConnectResult cr = connect(cells[idx], cells[idx2]);
0091
0092 if (cr == ConnectResult::eNoConnStop) {
0093 break;
0094 }
0095 if (cr == ConnectResult::eNoConn) {
0096 continue;
0097 }
0098 if (cr == ConnectResult::eConn) {
0099 seen.buf[seen.nconn] = labels[idx2];
0100 seen.nconn += 1;
0101 if (seen.nconn == seen.buf.size()) {
0102 break;
0103 }
0104 }
0105 }
0106
0107 return seen;
0108 }
0109
0110 template <typename CellCollection, typename ClusterCollection>
0111 requires(Acts::Ccl::CanAcceptCell<typename CellCollection::value_type,
0112 typename ClusterCollection::value_type>)
0113 void mergeClusters(Acts::Ccl::ClusteringData& data, const CellCollection& cells,
0114 ClusterCollection& outv) {
0115 using Cluster = typename ClusterCollection::value_type;
0116
0117
0118 std::size_t previousSize = outv.size();
0119 outv.resize(previousSize + data.nClusters.size());
0120 for (std::size_t i = 0; i < data.nClusters.size(); ++i) {
0121 Acts::Ccl::reserve(outv[previousSize + i], data.nClusters[i]);
0122 }
0123
0124
0125
0126
0127 for (std::size_t i = 0; i < cells.size(); ++i) {
0128 Label label = data.labels[i] - 1;
0129 Cluster& cl = outv[previousSize + label];
0130 clusterAddCell(cl, cells[i]);
0131 }
0132
0133
0134
0135 std::size_t invalidClusters = 0ul;
0136 for (std::size_t i = 0; i < data.nClusters.size(); ++i) {
0137 std::size_t idx = data.nClusters.size() - i - 1;
0138 if (data.nClusters[idx] != 0) {
0139 continue;
0140 }
0141
0142
0143
0144 std::swap(outv[previousSize + idx],
0145 outv[outv.size() - invalidClusters - 1]);
0146 ++invalidClusters;
0147 }
0148 outv.resize(outv.size() - invalidClusters);
0149 }
0150
0151 template <typename Cell>
0152 requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0153 Acts::Ccl::HasRetrievableRowInfo<Cell>)
0154 ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
0155 const Cell& iter) const {
0156 int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter));
0157 int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0158
0159
0160 if (deltaCol > 1) {
0161 return ConnectResult::eNoConnStop;
0162 }
0163
0164
0165 if (deltaRow > 1) {
0166 return ConnectResult::eNoConn;
0167 }
0168
0169
0170 if ((deltaRow + deltaCol) <= (conn8 ? 2 : 1)) {
0171 return ConnectResult::eConn;
0172 }
0173 return ConnectResult::eNoConn;
0174 }
0175
0176 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0177 ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
0178 const Cell& iter) const {
0179 int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0180 return deltaCol == 1 ? ConnectResult::eConn : ConnectResult::eNoConnStop;
0181 }
0182
0183 template <std::size_t GridDim>
0184 void recordEquivalences(const Connections<GridDim> seen, DisjointSets& ds) {
0185
0186
0187 if (seen.nconn > 0 && seen.buf[0] == NO_LABEL) {
0188 throw std::logic_error("seen.nconn > 0 but seen.buf[0] == NO_LABEL");
0189 }
0190 for (std::size_t i = 1; i < seen.nconn; i++) {
0191
0192
0193
0194 if (seen.buf[i] == NO_LABEL) {
0195 throw std::logic_error("i < seen.nconn but see.buf[i] == NO_LABEL");
0196 }
0197
0198 if (seen.buf[0] != seen.buf[i]) {
0199 ds.unionSet(seen.buf[0], seen.buf[i]);
0200 }
0201 }
0202 }
0203
0204 template <typename CellCollection, std::size_t GridDim, typename Connect>
0205 void labelClusters(Acts::Ccl::ClusteringData& data, CellCollection& cells,
0206 Connect&& connect) {
0207 using Cell = typename CellCollection::value_type;
0208
0209 data.labels.resize(cells.size(), NO_LABEL);
0210
0211 std::ranges::sort(cells, Acts::Ccl::Compare<Cell, GridDim>());
0212
0213
0214 for (std::size_t nCell(0ul); nCell < cells.size(); ++nCell) {
0215 const Acts::Ccl::Connections<GridDim> seen =
0216 Acts::Ccl::getConnections<Cell, Connect, GridDim>(
0217 nCell, cells, data.labels, std::forward<Connect>(connect));
0218
0219 if (seen.nconn == 0) {
0220
0221 data.labels[nCell] = data.ds.makeSet();
0222 } else {
0223 recordEquivalences(seen, data.ds);
0224
0225 data.labels[nCell] = seen.buf[0];
0226 }
0227 }
0228
0229
0230 int maxNClusters = 0;
0231 for (Label& lbl : data.labels) {
0232 lbl = data.ds.findSet(lbl);
0233 maxNClusters = std::max(maxNClusters, lbl);
0234 }
0235
0236
0237
0238 data.nClusters.resize(maxNClusters, 0);
0239 for (const Label label : data.labels) {
0240 ++data.nClusters[label - 1];
0241 }
0242 }
0243
0244 template <typename CellCollection, typename ClusterCollection,
0245 std::size_t GridDim, typename Connect>
0246 ClusterCollection createClusters(CellCollection& cells, Connect&& connect) {
0247 ClusterCollection clusters;
0248 Acts::Ccl::ClusteringData data;
0249 Acts::Ccl::createClusters<CellCollection, ClusterCollection, GridDim,
0250 Connect>(data, cells, clusters,
0251 std::forward<Connect>(connect));
0252 return clusters;
0253 }
0254
0255 template <typename CellCollection, typename ClusterCollection,
0256 std::size_t GridDim, typename Connect>
0257 requires(GridDim == 1 || GridDim == 2)
0258 void createClusters(Acts::Ccl::ClusteringData& data, CellCollection& cells,
0259 ClusterCollection& clusters, Connect&& connect) {
0260 if (cells.empty()) {
0261 return;
0262 }
0263 data.clear();
0264
0265 Acts::Ccl::labelClusters<CellCollection, GridDim, Connect>(
0266 data, cells, std::forward<Connect>(connect));
0267 Acts::Ccl::mergeClusters<CellCollection, ClusterCollection>(data, cells,
0268 clusters);
0269 }
0270
0271 }