File indexing completed on 2025-12-09 09:16:57
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/Digitization/ModuleClusters.hpp"
0010
0011 #include "Acts/Clusterization/Clusterization.hpp"
0012 #include "Acts/Utilities/Helpers.hpp"
0013 #include "ActsExamples/Digitization/MeasurementCreation.hpp"
0014
0015 #include <array>
0016 #include <cmath>
0017 #include <cstdlib>
0018 #include <limits>
0019 #include <map>
0020 #include <stdexcept>
0021
0022 namespace ActsExamples {
0023
0024 void ModuleClusters::add(DigitizedParameters params, simhit_t simhit) {
0025 ModuleValue mval;
0026 mval.paramIndices = std::move(params.indices);
0027 mval.paramValues = std::move(params.values);
0028 mval.paramVariances = std::move(params.variances);
0029 mval.sources = {simhit};
0030
0031 if (m_merge && !params.cluster.channels.empty()) {
0032
0033 for (const auto& cell : params.cluster.channels) {
0034 ModuleValue mval_cell = mval;
0035 mval_cell.value = cell;
0036 m_moduleValues.push_back(std::move(mval_cell));
0037 }
0038 } else {
0039
0040 mval.value = std::move(params.cluster);
0041 m_moduleValues.push_back(std::move(mval));
0042 }
0043 }
0044
0045 std::vector<std::pair<DigitizedParameters, std::set<ModuleClusters::simhit_t>>>
0046 ModuleClusters::digitizedParameters() {
0047 if (m_merge) {
0048 merge();
0049 }
0050 std::vector<std::pair<DigitizedParameters, std::set<simhit_t>>> retv;
0051 for (ModuleValue& mval : m_moduleValues) {
0052 if (std::holds_alternative<Cluster::Cell>(mval.value)) {
0053
0054
0055
0056 throw std::runtime_error("Invalid cluster!");
0057 }
0058 DigitizedParameters dpars;
0059 dpars.indices = mval.paramIndices;
0060 dpars.values = mval.paramValues;
0061 dpars.variances = mval.paramVariances;
0062 dpars.cluster = std::get<Cluster>(mval.value);
0063 retv.emplace_back(std::move(dpars), mval.sources);
0064 }
0065 return retv;
0066 }
0067
0068
0069 int getCellRow(const ModuleValue& mval) {
0070 if (std::holds_alternative<ActsExamples::Cluster::Cell>(mval.value)) {
0071 return std::get<ActsExamples::Cluster::Cell>(mval.value).bin[0];
0072 }
0073 throw std::domain_error("ModuleValue does not contain cell!");
0074 }
0075
0076 int getCellColumn(const ActsExamples::ModuleValue& mval) {
0077 if (std::holds_alternative<ActsExamples::Cluster::Cell>(mval.value)) {
0078 return std::get<ActsExamples::Cluster::Cell>(mval.value).bin[1];
0079 }
0080 throw std::domain_error("ModuleValue does not contain cell!");
0081 }
0082
0083 void clusterAddCell(std::vector<ModuleValue>& cl, const ModuleValue& ce) {
0084 cl.push_back(ce);
0085 }
0086
0087 std::vector<ModuleValue> ModuleClusters::createCellCollection() {
0088 std::map<ActsFatras::Segmentizer::Bin2D, ModuleValue> uniqueCells;
0089 for (const ModuleValue& mval : m_moduleValues) {
0090 if (!std::holds_alternative<Cluster::Cell>(mval.value)) {
0091 continue;
0092 }
0093 const auto& cell = std::get<ActsExamples::Cluster::Cell>(mval.value).bin;
0094
0095 if (const auto it = uniqueCells.find(cell); it != uniqueCells.end()) {
0096
0097 std::get<Cluster::Cell>(it->second.value).activation +=
0098 std::get<Cluster::Cell>(mval.value).activation;
0099 } else {
0100
0101 uniqueCells[cell] = mval;
0102 }
0103 }
0104 std::vector<ModuleValue> cells;
0105 cells.reserve(uniqueCells.size());
0106 for (const auto& [_, mval] : uniqueCells) {
0107 cells.push_back(mval);
0108 }
0109 return cells;
0110 }
0111
0112 void ModuleClusters::merge() {
0113 std::vector<ModuleValue> cells = createCellCollection();
0114
0115 std::vector<ModuleValue> newVals;
0116
0117 if (!cells.empty()) {
0118
0119 Acts::Ccl::ClusteringData data;
0120 std::vector<std::vector<ModuleValue>> merged;
0121 Acts::Ccl::createClusters<std::vector<ModuleValue>,
0122 std::vector<std::vector<ModuleValue>>>(
0123 data, cells, merged,
0124 Acts::Ccl::DefaultConnect<ModuleValue>(m_commonCorner));
0125
0126 for (std::vector<ModuleValue>& cellv : merged) {
0127
0128
0129
0130
0131
0132
0133 for (std::vector<ModuleValue>& remerged : mergeParameters(cellv)) {
0134 newVals.push_back(squash(remerged));
0135 }
0136 }
0137 m_moduleValues = std::move(newVals);
0138 } else {
0139
0140 for (std::vector<ModuleValue>& merged : mergeParameters(m_moduleValues)) {
0141 newVals.push_back(squash(merged));
0142 }
0143 m_moduleValues = std::move(newVals);
0144 }
0145 }
0146
0147
0148 std::vector<std::size_t> ModuleClusters::nonGeoEntries(
0149 std::vector<Acts::BoundIndices>& indices) {
0150 std::vector<std::size_t> retv;
0151 for (std::size_t i = 0; i < indices.size(); i++) {
0152 auto idx = indices.at(i);
0153 if (!rangeContainsValue(m_geoIndices, idx)) {
0154 retv.push_back(i);
0155 }
0156 }
0157 return retv;
0158 }
0159
0160
0161 std::vector<std::vector<ModuleValue>> ModuleClusters::mergeParameters(
0162 std::vector<ModuleValue> values) {
0163 std::vector<std::vector<ModuleValue>> retv;
0164
0165 std::vector<bool> used(values.size(), false);
0166 for (std::size_t i = 0; i < values.size(); i++) {
0167 if (used.at(i)) {
0168 continue;
0169 }
0170
0171 retv.emplace_back();
0172 std::vector<ModuleValue>& thisvec = retv.back();
0173
0174
0175 thisvec.push_back(std::move(values.at(i)));
0176 used.at(i) = true;
0177
0178
0179
0180
0181 for (std::size_t j = i + 1; j < values.size(); j++) {
0182
0183 if (used.at(j)) {
0184 continue;
0185 }
0186
0187
0188
0189
0190
0191 bool matched = true;
0192
0193
0194
0195
0196 for (ModuleValue& thisval : thisvec) {
0197
0198 for (auto k : nonGeoEntries(thisval.paramIndices)) {
0199 double p_i = thisval.paramValues.at(k);
0200 double p_j = values.at(j).paramValues.at(k);
0201 double v_i = thisval.paramVariances.at(k);
0202 double v_j = values.at(j).paramVariances.at(k);
0203
0204 double left = 0, right = 0;
0205 if (p_i < p_j) {
0206 left = p_i + m_nsigma * std::sqrt(v_i);
0207 right = p_j - m_nsigma * std::sqrt(v_j);
0208 } else {
0209 left = p_j + m_nsigma * std::sqrt(v_j);
0210 right = p_i - m_nsigma * std::sqrt(v_i);
0211 }
0212 if (left < right) {
0213
0214
0215 matched = false;
0216 break;
0217 }
0218 }
0219 if (matched) {
0220
0221
0222
0223 break;
0224 }
0225 }
0226 if (matched) {
0227
0228 used.at(j) = true;
0229 thisvec.push_back(std::move(values.at(j)));
0230 }
0231 }
0232 }
0233 return retv;
0234 }
0235
0236 ModuleValue ModuleClusters::squash(std::vector<ModuleValue>& values) {
0237 ModuleValue mval;
0238 double tot = 0;
0239 double tot2 = 0;
0240 std::vector<double> weights;
0241
0242
0243 for (ModuleValue& other : values) {
0244 if (std::holds_alternative<Cluster::Cell>(other.value)) {
0245 weights.push_back(std::get<Cluster::Cell>(other.value).activation);
0246 } else {
0247 weights.push_back(1);
0248 }
0249 tot += weights.back();
0250 tot2 += weights.back() * weights.back();
0251 }
0252
0253
0254 for (std::size_t i = 0; i < values.size(); i++) {
0255 ModuleValue& other = values.at(i);
0256 for (std::size_t j = 0; j < other.paramIndices.size(); j++) {
0257 auto idx = other.paramIndices.at(j);
0258 if (!rangeContainsValue(m_geoIndices, idx)) {
0259 if (!rangeContainsValue(mval.paramIndices, idx)) {
0260 mval.paramIndices.push_back(idx);
0261 }
0262 if (mval.paramValues.size() < (j + 1)) {
0263 mval.paramValues.push_back(0);
0264 mval.paramVariances.push_back(0);
0265 }
0266 double f = weights.at(i) / (tot > 0 ? tot : 1);
0267 double f2 = weights.at(i) * weights.at(i) / (tot2 > 0 ? tot2 : 1);
0268 mval.paramValues.at(j) += f * other.paramValues.at(j);
0269 mval.paramVariances.at(j) += f2 * other.paramVariances.at(j);
0270 }
0271 }
0272 }
0273
0274
0275 Cluster clus;
0276
0277 const auto& binningData = m_segmentation.binningData();
0278 Acts::Vector2 pos(0., 0.);
0279 Acts::Vector2 var(0., 0.);
0280
0281 std::size_t b0min = std::numeric_limits<std::size_t>::max();
0282 std::size_t b0max = 0;
0283 std::size_t b1min = std::numeric_limits<std::size_t>::max();
0284 std::size_t b1max = 0;
0285
0286 for (std::size_t i = 0; i < values.size(); i++) {
0287 ModuleValue& other = values.at(i);
0288 if (!std::holds_alternative<Cluster::Cell>(other.value)) {
0289 continue;
0290 }
0291
0292 Cluster::Cell ch = std::get<Cluster::Cell>(other.value);
0293 auto bin = ch.bin;
0294
0295 std::size_t b0 = bin[0];
0296 std::size_t b1 = bin[1];
0297
0298 b0min = std::min(b0min, b0);
0299 b0max = std::max(b0max, b0);
0300 b1min = std::min(b1min, b1);
0301 b1max = std::max(b1max, b1);
0302
0303 float p0 = binningData[0].center(b0);
0304 float w0 = binningData[0].width(b0);
0305 float p1 = binningData[1].center(b1);
0306 float w1 = binningData[1].width(b1);
0307
0308 pos += Acts::Vector2(weights.at(i) * p0, weights.at(i) * p1);
0309
0310
0311
0312 var += Acts::Vector2(weights.at(i) * weights.at(i) * w0 * w0 / 12,
0313 weights.at(i) * weights.at(i) * w1 * w1 / 12);
0314
0315 clus.channels.push_back(std::move(ch));
0316
0317
0318
0319 clus.sizeLoc0 = b0max - b0min + 1;
0320 clus.sizeLoc1 = b1max - b1min + 1;
0321 }
0322
0323 if (tot > 0) {
0324 pos /= tot;
0325 var /= (tot * tot);
0326 }
0327
0328 for (auto idx : m_geoIndices) {
0329 mval.paramIndices.push_back(idx);
0330 mval.paramValues.push_back(pos[idx]);
0331 mval.paramVariances.push_back(var[idx]);
0332 }
0333
0334 mval.value = std::move(clus);
0335
0336
0337 for (ModuleValue& other : values) {
0338 mval.sources.merge(other.sources);
0339 }
0340
0341 return mval;
0342 }
0343
0344 }