File indexing completed on 2025-01-18 09:11:06
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/Definitions/Algebra.hpp"
0012 #include "Acts/Definitions/TrackParametrization.hpp"
0013 #include "Acts/Surfaces/CylinderSurface.hpp"
0014 #include "Acts/TrackFitting/GsfOptions.hpp"
0015 #include "Acts/Utilities/detail/periodic.hpp"
0016
#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <numbers>
#include <optional>
#include <tuple>
#include <type_traits>
0021
0022 namespace Acts::detail {
0023
0024
0025 template <BoundIndices Idx>
0026 struct CyclicAngle {
0027 constexpr static BoundIndices idx = Idx;
0028 constexpr static double constant = 1.0;
0029 };
0030
0031 template <BoundIndices Idx>
0032 struct CyclicRadiusAngle {
0033 constexpr static BoundIndices idx = Idx;
0034 double constant = 1.0;
0035 };
0036
0037
0038 template <Surface::SurfaceType type_t>
0039 struct AngleDescription {
0040 using Desc = std::tuple<CyclicAngle<eBoundPhi>>;
0041 };
0042
0043 template <>
0044 struct AngleDescription<Surface::Disc> {
0045 using Desc = std::tuple<CyclicAngle<eBoundLoc1>, CyclicAngle<eBoundPhi>>;
0046 };
0047
0048 template <>
0049 struct AngleDescription<Surface::Cylinder> {
0050 using Desc =
0051 std::tuple<CyclicRadiusAngle<eBoundLoc0>, CyclicAngle<eBoundPhi>>;
0052 };
0053
0054
0055
0056 template <typename Callable>
0057 auto angleDescriptionSwitch(const Surface &surface, Callable &&callable) {
0058 switch (surface.type()) {
0059 case Surface::Cylinder: {
0060 auto desc = AngleDescription<Surface::Cylinder>::Desc{};
0061 const auto &bounds =
0062 static_cast<const CylinderSurface &>(surface).bounds();
0063 std::get<0>(desc).constant = bounds.get(CylinderBounds::eR);
0064 return callable(desc);
0065 }
0066 case Surface::Disc: {
0067 auto desc = AngleDescription<Surface::Disc>::Desc{};
0068 return callable(desc);
0069 }
0070 default: {
0071 auto desc = AngleDescription<Surface::Plane>::Desc{};
0072 return callable(desc);
0073 }
0074 }
0075 }
0076
0077 template <int D, typename components_t, typename projector_t,
0078 typename angle_desc_t>
0079 auto gaussianMixtureCov(const components_t components,
0080 const ActsVector<D> &mean, double sumOfWeights,
0081 projector_t &&projector,
0082 const angle_desc_t &angleDesc) {
0083 ActsSquareMatrix<D> cov = ActsSquareMatrix<D>::Zero();
0084
0085 for (const auto &cmp : components) {
0086 const auto &[weight_l, pars_l, cov_l] = projector(cmp);
0087
0088 cov += weight_l * cov_l;
0089
0090 ActsVector<D> diff = pars_l - mean;
0091
0092
0093 auto handleCyclicCov = [&l = pars_l, &m = mean, &diff = diff](auto desc) {
0094 diff[desc.idx] = difference_periodic(l[desc.idx] / desc.constant,
0095 m[desc.idx] / desc.constant,
0096 2 * std::numbers::pi) *
0097 desc.constant;
0098 };
0099
0100 std::apply([&](auto... dsc) { (handleCyclicCov(dsc), ...); }, angleDesc);
0101
0102 cov += weight_l * diff * diff.transpose();
0103 }
0104
0105 cov /= sumOfWeights;
0106 return cov;
0107 }
0108
0109
0110
0111
0112
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122
0123
0124
0125
0126 template <typename components_t, typename projector_t = std::identity,
0127 typename angle_desc_t = AngleDescription<Surface::Plane>::Desc>
0128 auto gaussianMixtureMeanCov(const components_t components,
0129 projector_t &&projector = projector_t{},
0130 const angle_desc_t &angleDesc = angle_desc_t{}) {
0131
0132 const auto &[beginWeight, beginPars, beginCov] =
0133 projector(components.front());
0134
0135
0136 using ParsType = std::decay_t<decltype(beginPars)>;
0137 using CovType = std::decay_t<decltype(beginCov)>;
0138 using WeightType = std::decay_t<decltype(beginWeight)>;
0139
0140 constexpr int D = ParsType::RowsAtCompileTime;
0141 EIGEN_STATIC_ASSERT_VECTOR_ONLY(ParsType);
0142 EIGEN_STATIC_ASSERT_MATRIX_SPECIFIC_SIZE(CovType, D, D);
0143 static_assert(std::is_floating_point_v<WeightType>);
0144
0145
0146
0147 #if defined(__GNUC__) && __GNUC__ < 9 && !defined(__clang__)
0148
0149 #else
0150 std::apply(
0151 [&](auto... d) { static_assert((std::less<int>{}(d.idx, D) && ...)); },
0152 angleDesc);
0153 #endif
0154
0155
0156 using RetType = std::tuple<ActsVector<D>, ActsSquareMatrix<D>>;
0157
0158
0159 if (components.size() == 1) {
0160 return RetType{beginPars / beginWeight, beginCov / beginWeight};
0161 }
0162
0163
0164 ActsVector<D> mean = ActsVector<D>::Zero();
0165 WeightType sumOfWeights{0.0};
0166
0167 for (const auto &cmp : components) {
0168 const auto &[weight_l, pars_l, cov_l] = projector(cmp);
0169
0170 sumOfWeights += weight_l;
0171 mean += weight_l * pars_l;
0172
0173
0174 auto handleCyclicMean = [&ref = beginPars, &pars = pars_l,
0175 &weight = weight_l, &mean = mean](auto desc) {
0176 const auto delta = (ref[desc.idx] - pars[desc.idx]) / desc.constant;
0177
0178 if (delta > std::numbers::pi) {
0179 mean[desc.idx] += 2. * std::numbers::pi * weight * desc.constant;
0180 } else if (delta < -std::numbers::pi) {
0181 mean[desc.idx] -= 2. * std::numbers::pi * weight * desc.constant;
0182 }
0183 };
0184
0185 std::apply([&](auto... dsc) { (handleCyclicMean(dsc), ...); }, angleDesc);
0186 }
0187
0188 mean /= sumOfWeights;
0189
0190 auto wrap = [&](auto desc) {
0191 mean[desc.idx] = wrap_periodic(mean[desc.idx] / desc.constant,
0192 -std::numbers::pi, 2 * std::numbers::pi) *
0193 desc.constant;
0194 };
0195
0196 std::apply([&](auto... dsc) { (wrap(dsc), ...); }, angleDesc);
0197
0198
0199 const auto cov =
0200 gaussianMixtureCov(components, mean, sumOfWeights, projector, angleDesc);
0201
0202
0203 return RetType{mean, cov};
0204 }
0205
0206
0207
0208
0209
0210
0211
0212
0213
0214
0215 template <typename mixture_t, typename projector_t = std::identity>
0216 auto mergeGaussianMixture(const mixture_t &mixture, const Surface &surface,
0217 ComponentMergeMethod method,
0218 projector_t &&projector = projector_t{}) {
0219 using R = std::tuple<Acts::BoundVector, Acts::BoundSquareMatrix>;
0220 const auto [mean, cov] =
0221 detail::angleDescriptionSwitch(surface, [&](const auto &desc) {
0222 return detail::gaussianMixtureMeanCov(mixture, projector, desc);
0223 });
0224
0225 if (method == ComponentMergeMethod::eMean) {
0226 return R{mean, cov};
0227 } else {
0228 const auto maxWeightIt = std::max_element(
0229 mixture.begin(), mixture.end(), [&](const auto &a, const auto &b) {
0230 return std::get<0>(projector(a)) < std::get<0>(projector(b));
0231 });
0232 const BoundVector meanMaxWeight = std::get<1>(projector(*maxWeightIt));
0233
0234 return R{meanMaxWeight, cov};
0235 }
0236 }
0237
0238 }