File indexing completed on 2025-01-19 09:23:36
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include "Acts/Definitions/Algebra.hpp"
#include "Acts/Definitions/TrackParametrization.hpp"
#include "Acts/EventData/MultiComponentTrackParameters.hpp"
#include "Acts/EventData/MultiTrajectory.hpp"
#include "Acts/EventData/TrackParameters.hpp"
#include "Acts/Utilities/Logger.hpp"

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iomanip>
#include <map>
#include <numeric>
#include <ostream>
#include <stdexcept>
#include <tuple>
#include <vector>
0028
0029 namespace Acts {
0030
0031
/// Maximum absolute deviation of a sum of component weights from 1 that is
/// still accepted as "normalized" (default tolerance of
/// detail::weightsAreNormalized).
static constexpr double s_normalizationTolerance{1.e-4};
0033
0034 namespace detail {
0035
0036 template <typename component_range_t, typename projector_t>
0037 bool weightsAreNormalized(const component_range_t &cmps,
0038 const projector_t &proj,
0039 double tol = s_normalizationTolerance) {
0040 double sumOfWeights = 0.0;
0041
0042 for (auto it = cmps.begin(); it != cmps.end(); ++it) {
0043 sumOfWeights += proj(*it);
0044 }
0045
0046 return std::abs(sumOfWeights - 1.0) < tol;
0047 }
0048
/// Rescale the weights of a range of components in place so they sum to one.
///
/// @param cmps range of components; it is iterated twice
/// @param proj callable returning a mutable reference to a component's weight
///             (the result of proj(cmp) is divided in place)
/// @note In debug builds this asserts that every weight is finite and that
///       their sum is > 0. With NDEBUG no check is made, so a zero sum ends
///       up dividing by zero.
template <typename component_range_t, typename projector_t>
void normalizeWeights(component_range_t &cmps, const projector_t &proj) {
  double total = 0.0;

  // First pass: accumulate the sum of all weights.
  for (auto &&cmp : cmps) {
    assert(std::isfinite(proj(cmp)) && "weight not finite in normalization");
    total += proj(cmp);
  }

  assert(total > 0 && "sum of weights is not > 0");

  // Second pass: divide every weight by that sum.
  for (auto &&cmp : cmps) {
    proj(cmp) /= total;
  }
}
0068
0069
0070
0071
0072
0073 template <typename propagator_state_t, typename stepper_t, typename navigator_t>
0074 class ScopedGsfInfoPrinterAndChecker {
0075 const propagator_state_t &m_state;
0076 const stepper_t &m_stepper;
0077 const navigator_t &m_navigator;
0078 double m_p_initial;
0079 const Logger &m_logger;
0080
0081 const Logger &logger() const { return m_logger; }
0082
0083 void print_component_stats() const {
0084 std::size_t i = 0;
0085 for (auto cmp : m_stepper.constComponentIterable(m_state.stepping)) {
0086 auto getVector = [&](auto idx) {
0087 return cmp.pars().template segment<3>(idx).transpose();
0088 };
0089 ACTS_VERBOSE(" #" << i++ << " pos: " << getVector(eFreePos0) << ", dir: "
0090 << getVector(eFreeDir0) << ", weight: " << cmp.weight()
0091 << ", status: " << cmp.status()
0092 << ", qop: " << cmp.pars()[eFreeQOverP]
0093 << ", det(cov): " << cmp.cov().determinant());
0094 }
0095 }
0096
0097 void checks(bool onStart) const {
0098 const auto cmps = m_stepper.constComponentIterable(m_state.stepping);
0099 [[maybe_unused]] const bool allFinite =
0100 std::all_of(cmps.begin(), cmps.end(),
0101 [](auto cmp) { return std::isfinite(cmp.weight()); });
0102 [[maybe_unused]] const bool allNormalized = detail::weightsAreNormalized(
0103 cmps, [](const auto &cmp) { return cmp.weight(); });
0104 [[maybe_unused]] const bool zeroComponents =
0105 m_stepper.numberComponents(m_state.stepping) == 0;
0106
0107 if (onStart) {
0108 assert(!zeroComponents && "no cmps at the start");
0109 assert(allFinite && "weights not finite at the start");
0110 assert(allNormalized && "not normalized at the start");
0111 } else {
0112 assert(!zeroComponents && "no cmps at the end");
0113 assert(allFinite && "weights not finite at the end");
0114 assert(allNormalized && "not normalized at the end");
0115 }
0116 }
0117
0118 public:
0119 ScopedGsfInfoPrinterAndChecker(const propagator_state_t &state,
0120 const stepper_t &stepper,
0121 const navigator_t &navigator,
0122 const Logger &logger)
0123 : m_state(state),
0124 m_stepper(stepper),
0125 m_navigator(navigator),
0126 m_p_initial(stepper.absoluteMomentum(state.stepping)),
0127 m_logger{logger} {
0128
0129 checks(true);
0130 ACTS_VERBOSE("Gsf step "
0131 << state.stepping.steps << " at mean position "
0132 << stepper.position(state.stepping).transpose()
0133 << " with direction "
0134 << stepper.direction(state.stepping).transpose()
0135 << " and momentum " << stepper.absoluteMomentum(state.stepping)
0136 << " and charge " << stepper.charge(state.stepping));
0137 ACTS_VERBOSE("Propagation is in " << state.options.direction << " mode");
0138 print_component_stats();
0139 }
0140
0141 ~ScopedGsfInfoPrinterAndChecker() {
0142 if (m_navigator.currentSurface(m_state.navigation)) {
0143 const auto p_final = m_stepper.absoluteMomentum(m_state.stepping);
0144 ACTS_VERBOSE("Component status at end of step:");
0145 print_component_stats();
0146 ACTS_VERBOSE("Delta Momentum = " << std::setprecision(5)
0147 << p_final - m_p_initial);
0148 }
0149 checks(false);
0150 }
0151 };
0152
0153 ActsScalar calculateDeterminant(
0154 const double *fullCalibratedCovariance,
0155 TrackStateTraits<MultiTrajectoryTraits::MeasurementSizeMax,
0156 true>::Covariance predictedCovariance,
0157 TrackStateTraits<MultiTrajectoryTraits::MeasurementSizeMax, true>::Projector
0158 projector,
0159 unsigned int calibratedSize);
0160
0161
0162
0163
0164
0165 template <typename traj_t>
0166 void computePosteriorWeights(
0167 const traj_t &mt, const std::vector<MultiTrajectoryTraits::IndexType> &tips,
0168 std::map<MultiTrajectoryTraits::IndexType, double> &weights) {
0169
0170
0171
0172
0173 const auto minChi2 =
0174 mt.getTrackState(*std::min_element(tips.begin(), tips.end(),
0175 [&](const auto &a, const auto &b) {
0176 return mt.getTrackState(a).chi2() <
0177 mt.getTrackState(b).chi2();
0178 }))
0179 .chi2();
0180
0181
0182 for (auto tip : tips) {
0183 const auto state = mt.getTrackState(tip);
0184 const double chi2 = state.chi2() - minChi2;
0185 const double detR = calculateDeterminant(
0186
0187
0188
0189 state
0190 .template calibratedCovariance<
0191 MultiTrajectoryTraits::MeasurementSizeMax>()
0192 .data(),
0193 state.predictedCovariance(), state.projector(), state.calibratedSize());
0194
0195 const auto factor = std::sqrt(1. / detR) * safeExp(-0.5 * chi2);
0196
0197
0198 if (std::isfinite(factor)) {
0199 weights.at(tip) *= factor;
0200 }
0201 }
0202 }
0203
0204
0205
/// Selects which set of track-state parameters is read: predicted, filtered
/// or smoothed.
enum class StatesType { ePredicted, eFiltered, eSmoothed };

/// Stream the lowercase name of @p type ("predicted", "filtered" or
/// "smoothed").
///
/// A value outside the three enumerators (possible via static_cast) is printed
/// as "StatesType(<int>)". This replaces an out-of-bounds read of a name table.
inline std::ostream &operator<<(std::ostream &os, StatesType type) {
  switch (type) {
    case StatesType::ePredicted:
      return os << "predicted";
    case StatesType::eFiltered:
      return os << "filtered";
    case StatesType::eSmoothed:
      return os << "smoothed";
  }
  return os << "StatesType(" << static_cast<int>(type) << ")";
}
0213
0214
0215
0216
0217 template <StatesType type, typename traj_t>
0218 struct MultiTrajectoryProjector {
0219 const traj_t &mt;
0220 const std::map<MultiTrajectoryTraits::IndexType, double> &weights;
0221
0222 auto operator()(MultiTrajectoryTraits::IndexType idx) const {
0223 const auto proxy = mt.getTrackState(idx);
0224 switch (type) {
0225 case StatesType::ePredicted:
0226 return std::make_tuple(weights.at(idx), proxy.predicted(),
0227 proxy.predictedCovariance());
0228 case StatesType::eFiltered:
0229 return std::make_tuple(weights.at(idx), proxy.filtered(),
0230 proxy.filteredCovariance());
0231 case StatesType::eSmoothed:
0232 return std::make_tuple(weights.at(idx), proxy.smoothed(),
0233 proxy.smoothedCovariance());
0234 default:
0235 throw std::invalid_argument(
0236 "Incorrect StatesType, should be ePredicted"
0237 ", eFiltered, or eSmoothed.");
0238 }
0239 }
0240 };
0241
0242
0243
0244
/// Holds a committed value and a scratch value.
///
/// tmp() gives write access to the scratch value. update() copies the scratch
/// value into the committed one, which is read through val().
template <typename T>
class Updatable {
 public:
  /// Both the scratch and the committed value start as T(0).
  Updatable() : m_tmp(0), m_val(0) {}

  /// Mutable access to the scratch value.
  T &tmp() { return m_tmp; }

  /// Commit the scratch value: val() returns it from now on.
  void update() { m_val = m_tmp; }

  /// The last committed value.
  const T &val() const { return m_val; }

 private:
  T m_tmp;
  T m_val;
};
0258
0259 }
0260 }