File indexing completed on 2026-04-05 07:45:35
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/Digitization/ModuleClusters.hpp"
0010
0011 #include "Acts/Clusterization/Clusterization.hpp"
0012 #include "Acts/Utilities/Helpers.hpp"
0013 #include "ActsExamples/Digitization/MeasurementCreation.hpp"
0014 #include "ActsExamples/EventData/SimHit.hpp"
0015
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdlib>
#include <iterator>
#include <limits>
#include <map>
#include <set>
#include <stdexcept>
#include <utility>
#include <variant>
#include <vector>
0022
0023 namespace ActsExamples {
0024
0025 void ModuleClusters::add(DigitizedParameters params, SimHitIndex simhit) {
0026 ModuleValue mval;
0027 mval.paramIndices = std::move(params.indices);
0028 mval.paramValues = std::move(params.values);
0029 mval.paramVariances = std::move(params.variances);
0030 mval.sources = {simhit};
0031
0032 if (m_merge && !params.cluster.channels.empty()) {
0033
0034 for (const auto& cell : params.cluster.channels) {
0035 ModuleValue mval_cell = mval;
0036 mval_cell.value = cell;
0037 m_moduleValues.push_back(std::move(mval_cell));
0038 }
0039 } else {
0040
0041 mval.value = std::move(params.cluster);
0042 m_moduleValues.push_back(std::move(mval));
0043 }
0044 }
0045
0046 std::vector<std::pair<DigitizedParameters, std::set<SimHitIndex>>>
0047 ModuleClusters::digitizedParameters() {
0048 if (m_merge) {
0049 merge();
0050 }
0051 std::vector<std::pair<DigitizedParameters, std::set<SimHitIndex>>> retv;
0052 for (ModuleValue& mval : m_moduleValues) {
0053 if (std::holds_alternative<Cluster::Cell>(mval.value)) {
0054
0055
0056
0057 throw std::runtime_error("Invalid cluster!");
0058 }
0059 DigitizedParameters dpars;
0060 dpars.indices = mval.paramIndices;
0061 dpars.values = mval.paramValues;
0062 dpars.variances = mval.paramVariances;
0063 dpars.cluster = std::get<Cluster>(mval.value);
0064 retv.emplace_back(std::move(dpars), mval.sources);
0065 }
0066 return retv;
0067 }
0068
0069
0070 int getCellRow(const ModuleValue& mval) {
0071 if (std::holds_alternative<Cluster::Cell>(mval.value)) {
0072 return std::get<Cluster::Cell>(mval.value).bin[0];
0073 }
0074 throw std::domain_error("ModuleValue does not contain cell!");
0075 }
0076
0077 int getCellColumn(const ModuleValue& mval) {
0078 if (std::holds_alternative<Cluster::Cell>(mval.value)) {
0079 return std::get<Cluster::Cell>(mval.value).bin[1];
0080 }
0081 throw std::domain_error("ModuleValue does not contain cell!");
0082 }
0083
0084 void clusterAddCell(std::vector<ModuleValue>& cl, const ModuleValue& ce) {
0085 cl.push_back(ce);
0086 }
0087
0088 std::vector<ModuleValue> ModuleClusters::createCellCollection() {
0089 std::map<ActsFatras::Segmentizer::Bin2D, ModuleValue> uniqueCells;
0090 for (const ModuleValue& mval : m_moduleValues) {
0091 if (!std::holds_alternative<Cluster::Cell>(mval.value)) {
0092 continue;
0093 }
0094 const auto& cell = std::get<Cluster::Cell>(mval.value).bin;
0095
0096 if (const auto it = uniqueCells.find(cell); it != uniqueCells.end()) {
0097
0098 std::get<Cluster::Cell>(it->second.value).activation +=
0099 std::get<Cluster::Cell>(mval.value).activation;
0100 } else {
0101
0102 uniqueCells[cell] = mval;
0103 }
0104 }
0105 std::vector<ModuleValue> cells;
0106 cells.reserve(uniqueCells.size());
0107 for (const auto& [_, mval] : uniqueCells) {
0108 cells.push_back(mval);
0109 }
0110 return cells;
0111 }
0112
0113 void ModuleClusters::merge() {
0114 std::vector<ModuleValue> cells = createCellCollection();
0115
0116 std::vector<ModuleValue> newVals;
0117
0118 if (!cells.empty()) {
0119
0120 Acts::Ccl::ClusteringData data;
0121 std::vector<std::vector<ModuleValue>> merged;
0122 Acts::Ccl::createClusters<std::vector<ModuleValue>,
0123 std::vector<std::vector<ModuleValue>>>(
0124 data, cells, merged,
0125 Acts::Ccl::DefaultConnect<ModuleValue>(m_commonCorner));
0126
0127 for (std::vector<ModuleValue>& cellv : merged) {
0128
0129
0130
0131
0132
0133
0134 for (std::vector<ModuleValue>& remerged : mergeParameters(cellv)) {
0135 newVals.push_back(squash(remerged));
0136 }
0137 }
0138 m_moduleValues = std::move(newVals);
0139 } else {
0140
0141 for (std::vector<ModuleValue>& merged : mergeParameters(m_moduleValues)) {
0142 newVals.push_back(squash(merged));
0143 }
0144 m_moduleValues = std::move(newVals);
0145 }
0146 }
0147
0148
0149 std::vector<std::size_t> ModuleClusters::nonGeoEntries(
0150 std::vector<Acts::BoundIndices>& indices) {
0151 std::vector<std::size_t> retv;
0152 for (std::size_t i = 0; i < indices.size(); i++) {
0153 auto idx = indices.at(i);
0154 if (!rangeContainsValue(m_geoIndices, idx)) {
0155 retv.push_back(i);
0156 }
0157 }
0158 return retv;
0159 }
0160
0161
0162 std::vector<std::vector<ModuleValue>> ModuleClusters::mergeParameters(
0163 std::vector<ModuleValue> values) {
0164 std::vector<std::vector<ModuleValue>> retv;
0165
0166 std::vector<bool> used(values.size(), false);
0167 for (std::size_t i = 0; i < values.size(); i++) {
0168 if (used.at(i)) {
0169 continue;
0170 }
0171
0172 retv.emplace_back();
0173 std::vector<ModuleValue>& thisvec = retv.back();
0174
0175
0176 thisvec.push_back(std::move(values.at(i)));
0177 used.at(i) = true;
0178
0179
0180
0181
0182 for (std::size_t j = i + 1; j < values.size(); j++) {
0183
0184 if (used.at(j)) {
0185 continue;
0186 }
0187
0188
0189
0190
0191
0192 bool matched = true;
0193
0194
0195
0196
0197 for (ModuleValue& thisval : thisvec) {
0198
0199 for (auto k : nonGeoEntries(thisval.paramIndices)) {
0200 double p_i = thisval.paramValues.at(k);
0201 double p_j = values.at(j).paramValues.at(k);
0202 double v_i = thisval.paramVariances.at(k);
0203 double v_j = values.at(j).paramVariances.at(k);
0204
0205 double left = 0, right = 0;
0206 if (p_i < p_j) {
0207 left = p_i + m_nsigma * std::sqrt(v_i);
0208 right = p_j - m_nsigma * std::sqrt(v_j);
0209 } else {
0210 left = p_j + m_nsigma * std::sqrt(v_j);
0211 right = p_i - m_nsigma * std::sqrt(v_i);
0212 }
0213 if (left < right) {
0214
0215
0216 matched = false;
0217 break;
0218 }
0219 }
0220 if (matched) {
0221
0222
0223
0224 break;
0225 }
0226 }
0227 if (matched) {
0228
0229 used.at(j) = true;
0230 thisvec.push_back(std::move(values.at(j)));
0231 }
0232 }
0233 }
0234 return retv;
0235 }
0236
0237 ModuleValue ModuleClusters::squash(std::vector<ModuleValue>& values) {
0238 ModuleValue mval;
0239 double tot = 0;
0240 double tot2 = 0;
0241 std::vector<double> weights;
0242
0243
0244 for (ModuleValue& other : values) {
0245 if (std::holds_alternative<Cluster::Cell>(other.value)) {
0246 weights.push_back(std::get<Cluster::Cell>(other.value).activation);
0247 } else {
0248 weights.push_back(1);
0249 }
0250 tot += weights.back();
0251 tot2 += weights.back() * weights.back();
0252 }
0253
0254
0255 for (std::size_t i = 0; i < values.size(); i++) {
0256 ModuleValue& other = values.at(i);
0257 for (std::size_t j = 0; j < other.paramIndices.size(); j++) {
0258 auto idx = other.paramIndices.at(j);
0259 if (!rangeContainsValue(m_geoIndices, idx)) {
0260 if (!rangeContainsValue(mval.paramIndices, idx)) {
0261 mval.paramIndices.push_back(idx);
0262 }
0263 if (mval.paramValues.size() < (j + 1)) {
0264 mval.paramValues.push_back(0);
0265 mval.paramVariances.push_back(0);
0266 }
0267 double f = weights.at(i) / (tot > 0 ? tot : 1);
0268 double f2 = weights.at(i) * weights.at(i) / (tot2 > 0 ? tot2 : 1);
0269 mval.paramValues.at(j) += f * other.paramValues.at(j);
0270 mval.paramVariances.at(j) += f2 * other.paramVariances.at(j);
0271 }
0272 }
0273 }
0274
0275
0276 Cluster clus;
0277
0278 const auto& binningData = m_segmentation.binningData();
0279 Acts::Vector2 pos(0., 0.);
0280 Acts::Vector2 var(0., 0.);
0281
0282 std::size_t b0min = std::numeric_limits<std::size_t>::max();
0283 std::size_t b0max = 0;
0284 std::size_t b1min = std::numeric_limits<std::size_t>::max();
0285 std::size_t b1max = 0;
0286
0287 for (std::size_t i = 0; i < values.size(); i++) {
0288 ModuleValue& other = values.at(i);
0289 if (!std::holds_alternative<Cluster::Cell>(other.value)) {
0290 continue;
0291 }
0292
0293 Cluster::Cell ch = std::get<Cluster::Cell>(other.value);
0294 auto bin = ch.bin;
0295
0296 std::size_t b0 = bin[0];
0297 std::size_t b1 = bin[1];
0298
0299 b0min = std::min(b0min, b0);
0300 b0max = std::max(b0max, b0);
0301 b1min = std::min(b1min, b1);
0302 b1max = std::max(b1max, b1);
0303
0304 float p0 = binningData[0].center(b0);
0305 float w0 = binningData[0].width(b0);
0306 float p1 = binningData[1].center(b1);
0307 float w1 = binningData[1].width(b1);
0308
0309 pos += Acts::Vector2(weights.at(i) * p0, weights.at(i) * p1);
0310
0311
0312
0313 var += Acts::Vector2(weights.at(i) * weights.at(i) * w0 * w0 / 12,
0314 weights.at(i) * weights.at(i) * w1 * w1 / 12);
0315
0316 clus.channels.push_back(std::move(ch));
0317
0318
0319
0320 clus.sizeLoc0 = b0max - b0min + 1;
0321 clus.sizeLoc1 = b1max - b1min + 1;
0322 }
0323
0324 if (tot > 0) {
0325 pos /= tot;
0326 var /= (tot * tot);
0327 }
0328
0329 for (auto idx : m_geoIndices) {
0330 mval.paramIndices.push_back(idx);
0331 mval.paramValues.push_back(pos[idx]);
0332 mval.paramVariances.push_back(var[idx]);
0333 }
0334
0335 mval.value = std::move(clus);
0336
0337
0338 for (ModuleValue& other : values) {
0339 mval.sources.merge(other.sources);
0340 }
0341
0342 return mval;
0343 }
0344
0345 }