File indexing completed on 2026-03-28 07:45:44
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/TrackFitting/detail/GsfComponentMerging.hpp"
0010
0011 #include <iostream>
0012
0013 namespace Acts {
0014
0015 std::tuple<BoundVector, BoundMatrix> detail::Gsf::mergeGaussianMixture(
0016 std::span<const GsfComponent> mixture, const Surface &surface,
0017 ComponentMergeMethod method) {
0018 return mergeGaussianMixture(
0019 mixture,
0020 [](const GsfComponent &c) {
0021 return std::tie(c.weight, c.boundPars, c.boundCov);
0022 },
0023 surface, method);
0024 }
0025
0026 GsfComponent detail::Gsf::mergeTwoComponents(const GsfComponent &a,
0027 const GsfComponent &b,
0028 const Surface &surface) {
0029 assert(a.weight >= 0.0 && b.weight >= 0.0 && "non-positive weight");
0030
0031 std::array components = {&a, &b};
0032 const auto proj = [](const GsfComponent *c) {
0033 return std::tie(c->weight, c->boundPars, c->boundCov);
0034 };
0035 auto [mergedPars, mergedCov] =
0036 angleDescriptionSwitch(surface, [&](const auto &desc) {
0037 return mergeGaussianMixtureMeanCov(components, proj, desc);
0038 });
0039
0040 GsfComponent ret = a;
0041 ret.boundPars = mergedPars;
0042 ret.boundCov = mergedCov;
0043 ret.weight = a.weight + b.weight;
0044 return ret;
0045 }
0046
0047 double detail::Gsf::computeSymmetricKlDivergence(const GsfComponent &a,
0048 const GsfComponent &b) {
0049 const double parsA = a.boundPars[eBoundQOverP];
0050 const double parsB = b.boundPars[eBoundQOverP];
0051 const double covA = a.boundCov(eBoundQOverP, eBoundQOverP);
0052 const double covB = b.boundCov(eBoundQOverP, eBoundQOverP);
0053
0054 assert(covA != 0.0);
0055 assert(std::isfinite(covA));
0056 assert(covB != 0.0);
0057 assert(std::isfinite(covB));
0058
0059 const double kl = covA * (1 / covB) + covB * (1 / covA) +
0060 (parsA - parsB) * (1 / covA + 1 / covB) * (parsA - parsB);
0061
0062 assert(kl >= 0.0 && "kl-divergence must be non-negative");
0063
0064 return kl;
0065 }
0066
0067 namespace detail::Gsf {
0068
0069 SymmetricKLDistanceMatrix::SymmetricKLDistanceMatrix(
0070 std::span<const GsfComponent> cmps)
0071 : m_distances(Array::Zero(cmps.size() * (cmps.size() - 1) / 2)),
0072 m_mask(Mask::Ones(cmps.size() * (cmps.size() - 1) / 2)),
0073 m_mapToPair(m_distances.size()),
0074 m_numberComponents(cmps.size()) {
0075 for (std::size_t i = 1; i < m_numberComponents; ++i) {
0076 const std::size_t indexConst = (i - 1) * i / 2;
0077 for (std::size_t j = 0; j < i; ++j) {
0078 m_mapToPair.at(indexConst + j) = {i, j};
0079 m_distances[indexConst + j] =
0080 computeSymmetricKlDivergence(cmps[i], cmps[j]);
0081 }
0082 }
0083 }
0084
0085 double SymmetricKLDistanceMatrix::at(std::size_t i, std::size_t j) const {
0086 return m_distances[i * (i - 1) / 2 + j];
0087 }
0088
0089 void SymmetricKLDistanceMatrix::recomputeAssociatedDistances(
0090 std::size_t n, std::span<const GsfComponent> cmps) {
0091 assert(cmps.size() == m_numberComponents && "size mismatch");
0092
0093 setAssociated(n, m_distances, [&](std::size_t i, std::size_t j) {
0094 return computeSymmetricKlDivergence(cmps[i], cmps[j]);
0095 });
0096 }
0097
0098 void SymmetricKLDistanceMatrix::maskAssociatedDistances(std::size_t n) {
0099 setAssociated(n, m_mask, [&](std::size_t, std::size_t) { return false; });
0100 }
0101
0102 std::pair<std::size_t, std::size_t> SymmetricKLDistanceMatrix::minDistancePair()
0103 const {
0104 double min = std::numeric_limits<double>::max();
0105 std::size_t idx = 0;
0106
0107 for (std::size_t i = 0; i < static_cast<std::size_t>(m_distances.size());
0108 ++i) {
0109 if (double new_min = std::min(min, m_distances[i]);
0110 m_mask[i] && new_min < min) {
0111 min = new_min;
0112 idx = i;
0113 }
0114 }
0115
0116 return m_mapToPair.at(idx);
0117 }
0118
0119 std::ostream &SymmetricKLDistanceMatrix::toStream(std::ostream &os) const {
0120 const auto prev_precision = os.precision();
0121 const int width = 8;
0122 const int prec = 2;
0123
0124 os << "\n";
0125 os << std::string(width, ' ') << " | ";
0126 for (std::size_t j = 0ul; j < m_numberComponents - 1; ++j) {
0127 os << std::setw(width) << j << " ";
0128 }
0129 os << "\n";
0130 os << std::string((width + 3) + (width + 2) * (m_numberComponents - 1), '-');
0131 os << "\n";
0132
0133 for (std::size_t i = 1ul; i < m_numberComponents; ++i) {
0134 const std::size_t indexConst = (i - 1) * i / 2;
0135 os << std::setw(width) << i << " | ";
0136 for (std::size_t j = 0ul; j < i; ++j) {
0137 os << std::setw(width) << std::setprecision(prec)
0138 << m_distances[indexConst + j] << " ";
0139 }
0140 os << "\n";
0141 }
0142 os << std::setprecision(prev_precision);
0143 return os;
0144 }
0145
0146 }
0147
0148 }