File indexing completed on 2026-05-21 08:00:08
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/Definitions/Algebra.hpp"
0012 #include "Acts/Definitions/TrackParametrization.hpp"
0013 #include "Acts/Surfaces/CylinderSurface.hpp"
0014 #include "Acts/TrackFitting/GsfOptions.hpp"
0015 #include "Acts/Utilities/detail/periodic.hpp"
0016
0017 #include <numbers>
0018 #include <tuple>
0019
0020 namespace Acts::detail::Gsf {
0021
0022
0023 template <BoundIndices Idx>
0024 struct CyclicAngle {
0025 constexpr static BoundIndices idx = Idx;
0026 constexpr static double constant = 1.0;
0027 };
0028
0029 template <BoundIndices Idx>
0030 struct CyclicRadiusAngle {
0031 constexpr static BoundIndices idx = Idx;
0032 double constant = 1.0;
0033 };
0034
0035
0036 template <Surface::SurfaceType type_t>
0037 struct AngleDescription {
0038 using Desc = std::tuple<CyclicAngle<eBoundPhi>>;
0039 };
0040
0041 template <>
0042 struct AngleDescription<Surface::Disc> {
0043 using Desc = std::tuple<CyclicAngle<eBoundLoc1>, CyclicAngle<eBoundPhi>>;
0044 };
0045
0046 template <>
0047 struct AngleDescription<Surface::Cylinder> {
0048 using Desc =
0049 std::tuple<CyclicRadiusAngle<eBoundLoc0>, CyclicAngle<eBoundPhi>>;
0050 };
0051
0052
0053
0054 template <typename Callable>
0055 auto angleDescriptionSwitch(const Surface &surface, Callable &&callable) {
0056 switch (surface.type()) {
0057 case Surface::Cylinder: {
0058 auto desc = AngleDescription<Surface::Cylinder>::Desc{};
0059 const auto &bounds =
0060 static_cast<const CylinderSurface &>(surface).bounds();
0061 std::get<0>(desc).constant = bounds.get(CylinderBounds::eR);
0062 return callable(desc);
0063 }
0064 case Surface::Disc: {
0065 auto desc = AngleDescription<Surface::Disc>::Desc{};
0066 return callable(desc);
0067 }
0068 default: {
0069 auto desc = AngleDescription<Surface::Plane>::Desc{};
0070 return callable(desc);
0071 }
0072 }
0073 }
0074
0075 template <int D, typename components_t, typename projector_t,
0076 typename angle_desc_t>
0077 auto gaussianMixtureCov(const components_t components,
0078 const ActsVector<D> &mean, double sumOfWeights,
0079 projector_t &&projector,
0080 const angle_desc_t &angleDesc) {
0081 ActsSquareMatrix<D> cov = ActsSquareMatrix<D>::Zero();
0082
0083 for (const auto &cmp : components) {
0084 const auto &[weight_l, pars_l, cov_l] = projector(cmp);
0085
0086 cov += weight_l * cov_l;
0087
0088 ActsVector<D> diff = pars_l - mean;
0089
0090
0091 auto handleCyclicCov = [&l = pars_l, &m = mean, &diff = diff](auto desc) {
0092 diff[desc.idx] = difference_periodic(l[desc.idx] / desc.constant,
0093 m[desc.idx] / desc.constant,
0094 2 * std::numbers::pi) *
0095 desc.constant;
0096 };
0097
0098 std::apply([&](auto... dsc) { (handleCyclicCov(dsc), ...); }, angleDesc);
0099
0100 cov += weight_l * diff * diff.transpose();
0101 }
0102
0103 cov /= sumOfWeights;
0104 return cov;
0105 }
0106
0107
0108
0109
0110
0111
0112
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122
0123
0124 template <typename components_t, typename projector_t = std::identity,
0125 typename angle_desc_t = AngleDescription<Surface::Plane>::Desc>
0126 auto gaussianMixtureMeanCov(const components_t components,
0127 projector_t &&projector = projector_t{},
0128 const angle_desc_t &angleDesc = angle_desc_t{}) {
0129
0130 const auto &[beginWeight, beginPars, beginCov] =
0131 projector(components.front());
0132
0133
0134 using ParsType = std::decay_t<decltype(beginPars)>;
0135 using CovType = std::decay_t<decltype(beginCov)>;
0136 using WeightType = std::decay_t<decltype(beginWeight)>;
0137
0138 constexpr int D = ParsType::RowsAtCompileTime;
0139 EIGEN_STATIC_ASSERT_VECTOR_ONLY(ParsType);
0140 EIGEN_STATIC_ASSERT_MATRIX_SPECIFIC_SIZE(CovType, D, D);
0141 static_assert(std::is_floating_point_v<WeightType>);
0142
0143
0144
0145 #if defined(__GNUC__) && __GNUC__ < 9 && !defined(__clang__)
0146
0147 #else
0148 std::apply(
0149 [&](auto... d) { static_assert((std::less<int>{}(d.idx, D) && ...)); },
0150 angleDesc);
0151 #endif
0152
0153
0154 using RetType = std::tuple<ActsVector<D>, ActsSquareMatrix<D>>;
0155
0156
0157 if (components.size() == 1) {
0158 return RetType{beginPars / beginWeight, beginCov / beginWeight};
0159 }
0160
0161
0162
0163
0164
0165 using CVec = Eigen::Matrix<std::complex<double>, D, 1>;
0166 CVec cMean = CVec::Zero();
0167 WeightType sumOfWeights{0.0};
0168
0169 for (const auto &cmp : components) {
0170 const auto &[weight_l, pars_l, cov_l] = projector(cmp);
0171
0172 CVec pars_l_c = pars_l;
0173
0174 auto setPolar = [&](auto desc) {
0175 pars_l_c[desc.idx] = std::polar(1.0, pars_l[desc.idx] / desc.constant);
0176 };
0177 std::apply([&](auto... dsc) { (setPolar(dsc), ...); }, angleDesc);
0178
0179 sumOfWeights += weight_l;
0180 cMean += weight_l * pars_l_c;
0181 }
0182
0183 cMean /= sumOfWeights;
0184
0185 ActsVector<D> mean = cMean.real();
0186
0187 auto getArg = [&](auto desc) {
0188 mean[desc.idx] = desc.constant * std::arg(cMean[desc.idx]);
0189 };
0190 std::apply([&](auto... dsc) { (getArg(dsc), ...); }, angleDesc);
0191
0192
0193 const auto cov =
0194 gaussianMixtureCov(components, mean, sumOfWeights, projector, angleDesc);
0195
0196
0197 return RetType{mean, cov};
0198 }
0199
0200
0201
0202
0203
0204
0205
0206
0207
0208
0209 template <typename mixture_t, typename projector_t = std::identity>
0210 auto mergeGaussianMixture(const mixture_t &mixture, const Surface &surface,
0211 ComponentMergeMethod method,
0212 projector_t &&projector = projector_t{}) {
0213 using R = std::tuple<Acts::BoundVector, Acts::BoundMatrix>;
0214 const auto [mean, cov] =
0215 angleDescriptionSwitch(surface, [&](const auto &desc) {
0216 return gaussianMixtureMeanCov(mixture, projector, desc);
0217 });
0218
0219 if (method == ComponentMergeMethod::eMean) {
0220 return R{mean, cov};
0221 } else {
0222 const auto maxWeightIt = std::max_element(
0223 mixture.begin(), mixture.end(), [&](const auto &a, const auto &b) {
0224 return std::get<0>(projector(a)) < std::get<0>(projector(b));
0225 });
0226 const BoundVector meanMaxWeight = std::get<1>(projector(*maxWeightIt));
0227
0228 return R{meanMaxWeight, cov};
0229 }
0230 }
0231
0232 }