Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-11-04 09:21:24

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
#include "Acts/Clusterization/Clusterization.hpp"

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <ranges>
#include <stdexcept>
#include <utility>
#include <vector>
0017 
0018 namespace Acts::Ccl {
0019 
0020 template <typename Cluster>
0021 void reserve(Cluster& /*cl*/, std::size_t /*n*/) {}
0022 
0023 template <Acts::Ccl::CanReserve Cluster>
0024 void reserve(Cluster& cl, std::size_t n) {
0025   clusterReserve(cl, n);
0026 }
0027 
/// Cell comparator used to sort cells before labelling.
///
/// Only the 1-D and 2-D partial specializations provide an operator(). The
/// primary template is instantiated only when no specialization matches,
/// which is always a usage error. Either GridDim is unsupported, or Cell does
/// not satisfy the concepts the matching specialization requires.
///
/// Exactly one of the two assertions below fires in each of those cases, so
/// the diagnostic names the actual problem. (The previous single assertion
/// `GridDim != 1 && GridDim != 2` was inverted: it stayed silent for e.g.
/// GridDim == 3, and for GridDim 1/2 it reported an unsupported dimension
/// instead of an unsuitable Cell type.)
template <typename Cell, std::size_t GridDim>
struct Compare {
  static_assert(GridDim == 1 || GridDim == 2,
                "Only grid dimensions of 1 or 2 are supported");
  static_assert(GridDim != 1 && GridDim != 2,
                "Cell type does not satisfy HasRetrievableColumnInfo (and, "
                "for 2-D grids, HasRetrievableRowInfo) required by Compare");
};
0033 
0034 // Comparator function object for cells, column-wise ordering
0035 // Specialization for 1-D grids
0036 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0037 struct Compare<Cell, 1> {
0038   bool operator()(const Cell& c0, const Cell& c1) const {
0039     int col0 = getCellColumn(c0);
0040     int col1 = getCellColumn(c1);
0041     return col0 < col1;
0042   }
0043 };
0044 
0045 // Specialization for 2-D grid
0046 template <typename Cell>
0047   requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0048            Acts::Ccl::HasRetrievableRowInfo<Cell>)
0049 struct Compare<Cell, 2> {
0050   bool operator()(const Cell& c0, const Cell& c1) const {
0051     int row0 = getCellRow(c0);
0052     int row1 = getCellRow(c1);
0053     int col0 = getCellColumn(c0);
0054     int col1 = getCellColumn(c1);
0055     return (col0 == col1) ? row0 < row1 : col0 < col1;
0056   }
0057 };
0058 
0059 template <std::size_t BufSize>
0060 struct ConnectionsBase {
0061   std::size_t nconn{0};
0062   std::array<Label, BufSize> buf;
0063   ConnectionsBase() { std::ranges::fill(buf, NO_LABEL); }
0064 };
0065 
0066 template <std::size_t GridDim>
0067 class Connections {};
0068 
0069 // On 1-D grid, cells have 1 backward neighbor
0070 template <>
0071 struct Connections<1> : public ConnectionsBase<1> {
0072   using ConnectionsBase::ConnectionsBase;
0073 };
0074 
0075 // On a 2-D grid, cells have 4 backward neighbors
0076 template <>
0077 struct Connections<2> : public ConnectionsBase<4> {
0078   using ConnectionsBase::ConnectionsBase;
0079 };
0080 
0081 // Cell collection logic
0082 template <typename Cell, typename Connect, std::size_t GridDim>
0083 Connections<GridDim> getConnections(std::size_t idx, std::vector<Cell>& cells,
0084                                     std::vector<Label>& labels,
0085                                     Connect&& connect) {
0086   Connections<GridDim> seen;
0087 
0088   for (std::size_t i = 0; i < idx; ++i) {
0089     std::size_t idx2 = idx - i - 1;
0090     ConnectResult cr = connect(cells[idx], cells[idx2]);
0091 
0092     if (cr == ConnectResult::eDuplicate) {
0093       throw std::invalid_argument(
0094           "Clusterization: input contains duplicate cells");
0095     }
0096     if (cr == ConnectResult::eNoConnStop) {
0097       break;
0098     }
0099     if (cr == ConnectResult::eNoConn) {
0100       continue;
0101     }
0102     if (cr == ConnectResult::eConn) {
0103       seen.buf[seen.nconn] = labels[idx2];
0104       seen.nconn += 1;
0105       if (seen.nconn == seen.buf.size()) {
0106         break;
0107       }
0108     }
0109   }
0110 
0111   return seen;
0112 }
0113 
0114 template <typename CellCollection, typename ClusterCollection>
0115   requires(Acts::Ccl::CanAcceptCell<typename CellCollection::value_type,
0116                                     typename ClusterCollection::value_type>)
0117 void mergeClusters(Acts::Ccl::ClusteringData& data, const CellCollection& cells,
0118                    ClusterCollection& outv) {
0119   using Cluster = typename ClusterCollection::value_type;
0120 
0121   // Accumulate clusters into the output collection
0122   std::size_t previousSize = outv.size();
0123   outv.resize(previousSize + data.nClusters.size());
0124   for (std::size_t i = 0; i < data.nClusters.size(); ++i) {
0125     Acts::Ccl::reserve(outv[previousSize + i], data.nClusters[i]);
0126   }
0127 
0128   // Fill clusters with cells
0129   // We are not using enumerate, since that is less optimal than
0130   // this loop
0131   for (std::size_t i = 0; i < cells.size(); ++i) {
0132     Label label = data.labels[i] - 1;
0133     Cluster& cl = outv[previousSize + label];
0134     clusterAddCell(cl, cells[i]);
0135   }
0136 
0137   // Due to previous merging, we may have now clusters with
0138   // no cells. We need to remove them
0139   std::size_t invalidClusters = 0ul;
0140   for (std::size_t i = 0; i < data.nClusters.size(); ++i) {
0141     std::size_t idx = data.nClusters.size() - i - 1;
0142     if (data.nClusters[idx] != 0) {
0143       continue;
0144     }
0145     // we have an invalid cluster.
0146     // move them all to the back so that we can remove
0147     // them later
0148     std::swap(outv[previousSize + idx],
0149               outv[outv.size() - invalidClusters - 1]);
0150     ++invalidClusters;
0151   }
0152   outv.resize(outv.size() - invalidClusters);
0153 }
0154 
0155 template <typename Cell>
0156   requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
0157            Acts::Ccl::HasRetrievableRowInfo<Cell>)
0158 ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
0159                                           const Cell& iter) const {
0160   int deltaRow = getCellRow(iter) - getCellRow(ref);
0161   int deltaCol = getCellColumn(iter) - getCellColumn(ref);
0162   assert((deltaCol < 0 || (deltaCol == 0 && deltaRow <= 0)) &&
0163          "Not iterating backwards");
0164 
0165   switch (deltaCol) {
0166     case 0:
0167       if (deltaRow == 0) {
0168         return ConnectResult::eDuplicate;
0169       } else if (deltaRow == -1) {
0170         return ConnectResult::eConn;
0171       } else {
0172         return ConnectResult::eNoConn;
0173       }
0174     case -1:
0175       if (deltaRow > static_cast<int>(conn8)) {
0176         return ConnectResult::eNoConn;
0177       } else if (deltaRow < -static_cast<int>(conn8)) {
0178         return ConnectResult::eNoConnStop;
0179       } else {
0180         return ConnectResult::eConn;
0181       }
0182     default:
0183       return ConnectResult::eNoConnStop;
0184   }
0185 }
0186 
0187 template <Acts::Ccl::HasRetrievableColumnInfo Cell>
0188 ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
0189                                           const Cell& iter) const {
0190   int deltaCol = getCellColumn(iter) - getCellColumn(ref);
0191   assert((deltaCol <= 0) && "Not iterating backwards");
0192 
0193   switch (deltaCol) {
0194     case 0:
0195       return ConnectResult::eDuplicate;
0196     case -1:
0197       return ConnectResult::eConn;
0198     default:
0199       return ConnectResult::eNoConnStop;
0200   }
0201 }
0202 
0203 template <std::size_t GridDim>
0204 void recordEquivalences(const Connections<GridDim> seen, DisjointSets& ds) {
0205   // Sanity check: first element should always have
0206   // label if nconn > 0
0207   if (seen.nconn > 0 && seen.buf[0] == NO_LABEL) {
0208     throw std::logic_error("seen.nconn > 0 but seen.buf[0] == NO_LABEL");
0209   }
0210   for (std::size_t i = 1; i < seen.nconn; i++) {
0211     // Sanity check: since connection lookup is always backward
0212     // while iteration is forward, all connected cells found here
0213     // should have a label
0214     if (seen.buf[i] == NO_LABEL) {
0215       throw std::logic_error("i < seen.nconn but see.buf[i] == NO_LABEL");
0216     }
0217     // Only record equivalence if needed
0218     if (seen.buf[0] != seen.buf[i]) {
0219       ds.unionSet(seen.buf[0], seen.buf[i]);
0220     }
0221   }
0222 }
0223 
0224 template <typename CellCollection, std::size_t GridDim, typename Connect>
0225 void labelClusters(Acts::Ccl::ClusteringData& data, CellCollection& cells,
0226                    Connect&& connect) {
0227   using Cell = typename CellCollection::value_type;
0228 
0229   data.labels.resize(cells.size(), NO_LABEL);
0230   // Sort cells by position to enable in-order scan
0231   std::ranges::sort(cells, Acts::Ccl::Compare<Cell, GridDim>());
0232 
0233   // First pass: Allocate labels and record equivalences
0234   for (std::size_t nCell(0ul); nCell < cells.size(); ++nCell) {
0235     const Acts::Ccl::Connections<GridDim> seen =
0236         Acts::Ccl::getConnections<Cell, Connect, GridDim>(
0237             nCell, cells, data.labels, std::forward<Connect>(connect));
0238 
0239     if (seen.nconn == 0) {
0240       // Allocate new label
0241       data.labels[nCell] = data.ds.makeSet();
0242     } else {
0243       recordEquivalences(seen, data.ds);
0244       // Set label for current cell
0245       data.labels[nCell] = seen.buf[0];
0246     }
0247   }  // loop on cells
0248 
0249   // Second pass: Merge labels based on recorded equivalences
0250   int maxNClusters = 0;
0251   for (Label& lbl : data.labels) {
0252     lbl = data.ds.findSet(lbl);
0253     maxNClusters = std::max(maxNClusters, lbl);
0254   }
0255 
0256   // Third pass: Keep count of how many cells go in each
0257   // to-be-created clusters
0258   data.nClusters.resize(maxNClusters, 0);
0259   for (const Label label : data.labels) {
0260     ++data.nClusters[label - 1];
0261   }
0262 }
0263 
0264 template <typename CellCollection, typename ClusterCollection,
0265           std::size_t GridDim, typename Connect>
0266 ClusterCollection createClusters(CellCollection& cells, Connect&& connect) {
0267   ClusterCollection clusters;
0268   Acts::Ccl::ClusteringData data;
0269   Acts::Ccl::createClusters<CellCollection, ClusterCollection, GridDim,
0270                             Connect>(data, cells, clusters,
0271                                      std::forward<Connect>(connect));
0272   return clusters;
0273 }
0274 
0275 template <typename CellCollection, typename ClusterCollection,
0276           std::size_t GridDim, typename Connect>
0277   requires(GridDim == 1 || GridDim == 2)
0278 void createClusters(Acts::Ccl::ClusteringData& data, CellCollection& cells,
0279                     ClusterCollection& clusters, Connect&& connect) {
0280   if (cells.empty()) {
0281     return;
0282   }
0283   data.clear();
0284 
0285   Acts::Ccl::labelClusters<CellCollection, GridDim, Connect>(
0286       data, cells, std::forward<Connect>(connect));
0287   Acts::Ccl::mergeClusters<CellCollection, ClusterCollection>(data, cells,
0288                                                               clusters);
0289 }
0290 
0291 }  // namespace Acts::Ccl