File indexing completed on 2025-01-18 09:10:45
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include <algorithm>
0010 #include <array>
0011 #include <vector>
0012
0013 #include <boost/pending/disjoint_sets.hpp>
0014
0015 namespace Acts::Ccl::internal {
0016
/// Fallback cell comparator for grid dimensions without a dedicated
/// specialisation.
///
/// The usable comparators are the partial specialisations for `GridDim == 1`
/// and `GridDim == 2`. This primary template carries no call operator; it
/// only rejects unsupported dimensions at compile time. The assertion states
/// the supported set directly, so it fires for e.g. `GridDim == 3` with a
/// message that matches the condition.
///
/// @tparam Cell    Cell type to be ordered.
/// @tparam GridDim Dimensionality of the cell grid (1 or 2).
template <typename Cell, std::size_t GridDim>
struct Compare {
  static_assert(GridDim == 1 || GridDim == 2,
                "Only grid dimensions of 1 or 2 are supported");
};
0022
0023
0024
0025 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0026 struct Compare<Cell, 1> {
0027 bool operator()(const Cell& c0, const Cell& c1) const {
0028 int col0 = getCellColumn(c0);
0029 int col1 = getCellColumn(c1);
0030 return col0 < col1;
0031 }
0032 };
0033
0034
0035 template <typename Cell>
0036 requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0037 Acts::Ccl::HasRetrievableRowInfo<Cell>)
0038 struct Compare<Cell, 2> {
0039 bool operator()(const Cell& c0, const Cell& c1) const {
0040 int row0 = getCellRow(c0);
0041 int row1 = getCellRow(c1);
0042 int col0 = getCellColumn(c0);
0043 int col1 = getCellColumn(c1);
0044 return (col0 == col1) ? row0 < row1 : col0 < col1;
0045 }
0046 };
0047
0048
0049
0050
0051 class DisjointSets {
0052 public:
0053 explicit DisjointSets(std::size_t initial_size = 128)
0054 : m_size(initial_size),
0055 m_rank(m_size),
0056 m_parent(m_size),
0057 m_ds(&m_rank[0], &m_parent[0]) {}
0058
0059 Label makeSet() {
0060
0061
0062 while (m_globalId >= m_size) {
0063 m_size *= 2;
0064 m_rank.resize(m_size);
0065 m_parent.resize(m_size);
0066 m_ds = boost::disjoint_sets<std::size_t*, std::size_t*>(&m_rank[0],
0067 &m_parent[0]);
0068 }
0069 m_ds.make_set(m_globalId);
0070 return static_cast<Label>(m_globalId++);
0071 }
0072
0073 void unionSet(std::size_t x, std::size_t y) { m_ds.union_set(x, y); }
0074 Label findSet(std::size_t x) { return static_cast<Label>(m_ds.find_set(x)); }
0075
0076 private:
0077 std::size_t m_globalId = 1;
0078 std::size_t m_size;
0079 std::vector<std::size_t> m_rank;
0080 std::vector<std::size_t> m_parent;
0081 boost::disjoint_sets<std::size_t*, std::size_t*> m_ds;
0082 };
0083
0084 template <std::size_t BufSize>
0085 struct ConnectionsBase {
0086 std::size_t nconn{0};
0087 std::array<Label, BufSize> buf;
0088 ConnectionsBase() { std::fill(buf.begin(), buf.end(), NO_LABEL); }
0089 };
0090
/// Neighbour-label buffer selected by grid dimension.
///
/// Only the explicit specialisations for dimensions 1 and 2 carry data; this
/// primary template is intentionally empty. It is declared with `struct` so
/// that its class-key matches the specialisations.
template <std::size_t GridDim>
struct Connections {};
0093
0094
0095 template <>
0096 struct Connections<1> : public ConnectionsBase<1> {
0097 using ConnectionsBase::ConnectionsBase;
0098 };
0099
0100
0101 template <>
0102 struct Connections<2> : public ConnectionsBase<4> {
0103 using ConnectionsBase::ConnectionsBase;
0104 };
0105
0106
0107 template <typename Cell, typename Connect, std::size_t GridDim>
0108 Connections<GridDim> getConnections(typename std::vector<Cell>::iterator it,
0109 std::vector<Cell>& set, Connect connect) {
0110 Connections<GridDim> seen;
0111 typename std::vector<Cell>::iterator it_2{it};
0112
0113 while (it_2 != set.begin()) {
0114 it_2 = std::prev(it_2);
0115
0116 ConnectResult cr = connect(*it, *it_2);
0117 if (cr == ConnectResult::eNoConnStop) {
0118 break;
0119 }
0120 if (cr == ConnectResult::eNoConn) {
0121 continue;
0122 }
0123 if (cr == ConnectResult::eConn) {
0124 seen.buf[seen.nconn] = getCellLabel(*it_2);
0125 seen.nconn += 1;
0126 if (seen.nconn == seen.buf.size()) {
0127 break;
0128 }
0129 }
0130 }
0131 return seen;
0132 }
0133
0134 template <typename CellCollection, typename ClusterCollection>
0135 requires(
0136 Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type> &&
0137 Acts::Ccl::CanAcceptCell<typename CellCollection::value_type,
0138 typename ClusterCollection::value_type>)
0139 ClusterCollection mergeClustersImpl(CellCollection& cells) {
0140 using Cluster = typename ClusterCollection::value_type;
0141
0142 if (cells.empty()) {
0143 return {};
0144 }
0145
0146
0147 ClusterCollection outv;
0148 Cluster cl;
0149 int lbl = getCellLabel(cells.front());
0150 for (auto& cell : cells) {
0151 if (getCellLabel(cell) != lbl) {
0152
0153 outv.push_back(std::move(cl));
0154 cl = Cluster();
0155 lbl = getCellLabel(cell);
0156 }
0157 clusterAddCell(cl, cell);
0158 }
0159
0160 outv.push_back(std::move(cl));
0161
0162 return outv;
0163 }
0164
0165 }
0166
0167 namespace Acts::Ccl {
0168
0169 template <typename Cell>
0170 requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0171 Acts::Ccl::HasRetrievableRowInfo<Cell>)
0172 ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
0173 const Cell& iter) const {
0174 int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter));
0175 int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0176
0177
0178 if (deltaCol > 1) {
0179 return ConnectResult::eNoConnStop;
0180 }
0181
0182
0183 if (deltaRow > 1) {
0184 return ConnectResult::eNoConn;
0185 }
0186
0187
0188 if ((deltaRow + deltaCol) <= (conn8 ? 2 : 1)) {
0189 return ConnectResult::eConn;
0190 }
0191 return ConnectResult::eNoConn;
0192 }
0193
0194 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0195 ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
0196 const Cell& iter) const {
0197 int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0198 return deltaCol == 1 ? ConnectResult::eConn : ConnectResult::eNoConnStop;
0199 }
0200
0201 template <std::size_t GridDim>
0202 void recordEquivalences(const internal::Connections<GridDim> seen,
0203 internal::DisjointSets& ds) {
0204
0205
0206 if (seen.nconn > 0 && seen.buf[0] == NO_LABEL) {
0207 throw std::logic_error("seen.nconn > 0 but seen.buf[0] == NO_LABEL");
0208 }
0209 for (std::size_t i = 1; i < seen.nconn; i++) {
0210
0211
0212
0213 if (seen.buf[i] == NO_LABEL) {
0214 throw std::logic_error("i < seen.nconn but see.buf[i] == NO_LABEL");
0215 }
0216
0217 if (seen.buf[0] != seen.buf[i]) {
0218 ds.unionSet(seen.buf[0], seen.buf[i]);
0219 }
0220 }
0221 }
0222
0223 template <typename CellCollection, std::size_t GridDim, typename Connect>
0224 requires(
0225 Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type>)
0226 void labelClusters(CellCollection& cells, Connect connect) {
0227 using Cell = typename CellCollection::value_type;
0228
0229 internal::DisjointSets ds{};
0230
0231
0232 std::ranges::sort(cells, internal::Compare<Cell, GridDim>());
0233
0234
0235 for (auto it = std::ranges::begin(cells); it != std::ranges::end(cells);
0236 ++it) {
0237 const internal::Connections<GridDim> seen =
0238 internal::getConnections<Cell, Connect, GridDim>(it, cells, connect);
0239 if (seen.nconn == 0) {
0240
0241 getCellLabel(*it) = ds.makeSet();
0242 } else {
0243 recordEquivalences(seen, ds);
0244
0245 getCellLabel(*it) = seen.buf[0];
0246 }
0247 }
0248
0249
0250 for (auto& cell : cells) {
0251 Label& lbl = getCellLabel(cell);
0252 lbl = ds.findSet(lbl);
0253 }
0254 }
0255
0256 template <typename CellCollection, typename ClusterCollection,
0257 std::size_t GridDim = 2>
0258 requires(GridDim == 1 || GridDim == 2) &&
0259 Acts::Ccl::HasRetrievableLabelInfo<
0260 typename CellCollection::value_type>
0261 ClusterCollection mergeClusters(CellCollection& cells) {
0262 using Cell = typename CellCollection::value_type;
0263 if constexpr (GridDim > 1) {
0264
0265
0266 std::ranges::sort(cells, {}, [](Cell& c) { return getCellLabel(c); });
0267 }
0268
0269 return internal::mergeClustersImpl<CellCollection, ClusterCollection>(cells);
0270 }
0271
0272 template <typename CellCollection, typename ClusterCollection,
0273 std::size_t GridDim, typename Connect>
0274 ClusterCollection createClusters(CellCollection& cells, Connect connect) {
0275 labelClusters<CellCollection, GridDim, Connect>(cells, connect);
0276 return mergeClusters<CellCollection, ClusterCollection, GridDim>(cells);
0277 }
0278
0279 }