File indexing completed on 2025-01-18 09:11:35
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/Digitization/ModuleClusters.hpp"
0010
0011 #include "Acts/Clusterization/Clusterization.hpp"
0012 #include "Acts/Utilities/Helpers.hpp"
0013 #include "ActsExamples/Digitization/MeasurementCreation.hpp"
0014 #include "ActsFatras/Digitization/Channelizer.hpp"
0015
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <limits>
#include <memory>
#include <stdexcept>
#include <type_traits>
0024
0025 namespace ActsExamples {
0026
0027 void ModuleClusters::add(DigitizedParameters params, simhit_t simhit) {
0028 ModuleValue mval;
0029 mval.paramIndices = std::move(params.indices);
0030 mval.paramValues = std::move(params.values);
0031 mval.paramVariances = std::move(params.variances);
0032 mval.sources = {simhit};
0033
0034 if (m_merge && !params.cluster.channels.empty()) {
0035
0036 for (const auto& cell : params.cluster.channels) {
0037 ModuleValue mval_cell = mval;
0038 mval_cell.value = cell;
0039 m_moduleValues.push_back(std::move(mval_cell));
0040 }
0041 } else {
0042
0043 mval.value = std::move(params.cluster);
0044 m_moduleValues.push_back(std::move(mval));
0045 }
0046 }
0047
0048 std::vector<std::pair<DigitizedParameters, std::set<ModuleClusters::simhit_t>>>
0049 ModuleClusters::digitizedParameters() {
0050 if (m_merge) {
0051 merge();
0052 }
0053 std::vector<std::pair<DigitizedParameters, std::set<simhit_t>>> retv;
0054 for (ModuleValue& mval : m_moduleValues) {
0055 if (std::holds_alternative<Cluster::Cell>(mval.value)) {
0056
0057
0058
0059 throw std::runtime_error("Invalid cluster!");
0060 }
0061 DigitizedParameters dpars;
0062 dpars.indices = mval.paramIndices;
0063 dpars.values = mval.paramValues;
0064 dpars.variances = mval.paramVariances;
0065 dpars.cluster = std::get<Cluster>(mval.value);
0066 retv.emplace_back(std::move(dpars), mval.sources);
0067 }
0068 return retv;
0069 }
0070
0071
0072 int getCellRow(const ModuleValue& mval) {
0073 if (std::holds_alternative<ActsExamples::Cluster::Cell>(mval.value)) {
0074 return std::get<ActsExamples::Cluster::Cell>(mval.value).bin[0];
0075 }
0076 throw std::domain_error("ModuleValue does not contain cell!");
0077 }
0078
0079 int getCellColumn(const ActsExamples::ModuleValue& mval) {
0080 if (std::holds_alternative<ActsExamples::Cluster::Cell>(mval.value)) {
0081 return std::get<ActsExamples::Cluster::Cell>(mval.value).bin[1];
0082 }
0083 throw std::domain_error("ModuleValue does not contain cell!");
0084 }
0085
0086 int& getCellLabel(ActsExamples::ModuleValue& mval) {
0087 return mval.label;
0088 }
0089
0090 void clusterAddCell(std::vector<ModuleValue>& cl, const ModuleValue& ce) {
0091 cl.push_back(ce);
0092 }
0093
0094 std::vector<ModuleValue> ModuleClusters::createCellCollection() {
0095 std::vector<ModuleValue> cells;
0096 for (ModuleValue& mval : m_moduleValues) {
0097 if (std::holds_alternative<Cluster::Cell>(mval.value)) {
0098 cells.push_back(mval);
0099 }
0100 }
0101 return cells;
0102 }
0103
0104 void ModuleClusters::merge() {
0105 std::vector<ModuleValue> cells = createCellCollection();
0106
0107 std::vector<ModuleValue> newVals;
0108
0109 if (!cells.empty()) {
0110
0111 std::vector<std::vector<ModuleValue>> merged =
0112 Acts::Ccl::createClusters<std::vector<ModuleValue>,
0113 std::vector<std::vector<ModuleValue>>>(
0114 cells, Acts::Ccl::DefaultConnect<ModuleValue>(m_commonCorner));
0115
0116 for (std::vector<ModuleValue>& cellv : merged) {
0117
0118
0119
0120
0121
0122
0123 for (std::vector<ModuleValue>& remerged : mergeParameters(cellv)) {
0124 newVals.push_back(squash(remerged));
0125 }
0126 }
0127 m_moduleValues = std::move(newVals);
0128 } else {
0129
0130 for (std::vector<ModuleValue>& merged : mergeParameters(m_moduleValues)) {
0131 newVals.push_back(squash(merged));
0132 }
0133 m_moduleValues = std::move(newVals);
0134 }
0135 }
0136
0137
0138 std::vector<std::size_t> ModuleClusters::nonGeoEntries(
0139 std::vector<Acts::BoundIndices>& indices) {
0140 std::vector<std::size_t> retv;
0141 for (std::size_t i = 0; i < indices.size(); i++) {
0142 auto idx = indices.at(i);
0143 if (!rangeContainsValue(m_geoIndices, idx)) {
0144 retv.push_back(i);
0145 }
0146 }
0147 return retv;
0148 }
0149
0150
0151 std::vector<std::vector<ModuleValue>> ModuleClusters::mergeParameters(
0152 std::vector<ModuleValue> values) {
0153 std::vector<std::vector<ModuleValue>> retv;
0154
0155 std::vector<bool> used(values.size(), false);
0156 for (std::size_t i = 0; i < values.size(); i++) {
0157 if (used.at(i)) {
0158 continue;
0159 }
0160
0161 retv.emplace_back();
0162 std::vector<ModuleValue>& thisvec = retv.back();
0163
0164
0165 thisvec.push_back(std::move(values.at(i)));
0166 used.at(i) = true;
0167
0168
0169
0170
0171 for (std::size_t j = i + 1; j < values.size(); j++) {
0172
0173 if (used.at(j)) {
0174 continue;
0175 }
0176
0177
0178
0179
0180
0181 bool matched = true;
0182
0183
0184
0185
0186 for (ModuleValue& thisval : thisvec) {
0187
0188 for (auto k : nonGeoEntries(thisval.paramIndices)) {
0189 double p_i = thisval.paramValues.at(k);
0190 double p_j = values.at(j).paramValues.at(k);
0191 double v_i = thisval.paramVariances.at(k);
0192 double v_j = values.at(j).paramVariances.at(k);
0193
0194 double left = 0, right = 0;
0195 if (p_i < p_j) {
0196 left = p_i + m_nsigma * std::sqrt(v_i);
0197 right = p_j - m_nsigma * std::sqrt(v_j);
0198 } else {
0199 left = p_j + m_nsigma * std::sqrt(v_j);
0200 right = p_i - m_nsigma * std::sqrt(v_i);
0201 }
0202 if (left < right) {
0203
0204
0205 matched = false;
0206 break;
0207 }
0208 }
0209 if (matched) {
0210
0211
0212
0213 break;
0214 }
0215 }
0216 if (matched) {
0217
0218 used.at(j) = true;
0219 thisvec.push_back(std::move(values.at(j)));
0220 }
0221 }
0222 }
0223 return retv;
0224 }
0225
0226 ModuleValue ModuleClusters::squash(std::vector<ModuleValue>& values) {
0227 ModuleValue mval;
0228 double tot = 0;
0229 double tot2 = 0;
0230 std::vector<double> weights;
0231
0232
0233 for (ModuleValue& other : values) {
0234 if (std::holds_alternative<Cluster::Cell>(other.value)) {
0235 weights.push_back(std::get<Cluster::Cell>(other.value).activation);
0236 } else {
0237 weights.push_back(1);
0238 }
0239 tot += weights.back();
0240 tot2 += weights.back() * weights.back();
0241 }
0242
0243
0244 for (std::size_t i = 0; i < values.size(); i++) {
0245 ModuleValue& other = values.at(i);
0246 for (std::size_t j = 0; j < other.paramIndices.size(); j++) {
0247 auto idx = other.paramIndices.at(j);
0248 if (!rangeContainsValue(m_geoIndices, idx)) {
0249 if (!rangeContainsValue(mval.paramIndices, idx)) {
0250 mval.paramIndices.push_back(idx);
0251 }
0252 if (mval.paramValues.size() < (j + 1)) {
0253 mval.paramValues.push_back(0);
0254 mval.paramVariances.push_back(0);
0255 }
0256 double f = weights.at(i) / (tot > 0 ? tot : 1);
0257 double f2 = weights.at(i) * weights.at(i) / (tot2 > 0 ? tot2 : 1);
0258 mval.paramValues.at(j) += f * other.paramValues.at(j);
0259 mval.paramVariances.at(j) += f2 * other.paramVariances.at(j);
0260 }
0261 }
0262 }
0263
0264
0265 Cluster clus;
0266
0267 const auto& binningData = m_segmentation.binningData();
0268 Acts::Vector2 pos(0., 0.);
0269 Acts::Vector2 var(0., 0.);
0270
0271 std::size_t b0min = std::numeric_limits<std::size_t>::max();
0272 std::size_t b0max = 0;
0273 std::size_t b1min = std::numeric_limits<std::size_t>::max();
0274 std::size_t b1max = 0;
0275
0276 for (std::size_t i = 0; i < values.size(); i++) {
0277 ModuleValue& other = values.at(i);
0278 if (!std::holds_alternative<Cluster::Cell>(other.value)) {
0279 continue;
0280 }
0281
0282 Cluster::Cell ch = std::get<Cluster::Cell>(other.value);
0283 auto bin = ch.bin;
0284
0285 std::size_t b0 = bin[0];
0286 std::size_t b1 = bin[1];
0287
0288 b0min = std::min(b0min, b0);
0289 b0max = std::max(b0max, b0);
0290 b1min = std::min(b1min, b1);
0291 b1max = std::max(b1max, b1);
0292
0293 float p0 = binningData[0].center(b0);
0294 float w0 = binningData[0].width(b0);
0295 float p1 = binningData[1].center(b1);
0296 float w1 = binningData[1].width(b1);
0297
0298 pos += Acts::Vector2(weights.at(i) * p0, weights.at(i) * p1);
0299
0300
0301
0302 var += Acts::Vector2(weights.at(i) * weights.at(i) * w0 * w0 / 12,
0303 weights.at(i) * weights.at(i) * w1 * w1 / 12);
0304
0305 clus.channels.push_back(std::move(ch));
0306
0307
0308
0309 clus.sizeLoc0 = b0max - b0min + 1;
0310 clus.sizeLoc1 = b1max - b1min + 1;
0311 }
0312
0313 if (tot > 0) {
0314 pos /= tot;
0315 var /= (tot * tot);
0316 }
0317
0318 for (auto idx : m_geoIndices) {
0319 mval.paramIndices.push_back(idx);
0320 mval.paramValues.push_back(pos[idx]);
0321 mval.paramVariances.push_back(var[idx]);
0322 }
0323
0324 mval.value = std::move(clus);
0325
0326
0327 for (ModuleValue& other : values) {
0328 mval.sources.merge(other.sources);
0329 }
0330
0331 return mval;
0332 }
0333
0334 }