Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-20 09:21:35

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 #include "Acts/Definitions/TrackParametrization.hpp"
0012 #include "Acts/EventData/MultiTrajectory.hpp"
0013 #include "Acts/EventData/MultiTrajectoryHelpers.hpp"
0014 #include "Acts/Propagator/detail/PointwiseMaterialInteraction.hpp"
0015 #include "Acts/Surfaces/Surface.hpp"
0016 #include "Acts/TrackFitting/BetheHeitlerApprox.hpp"
0017 #include "Acts/TrackFitting/GsfOptions.hpp"
0018 #include "Acts/TrackFitting/detail/GsfComponentMerging.hpp"
0019 #include "Acts/TrackFitting/detail/GsfUtils.hpp"
0020 #include "Acts/TrackFitting/detail/KalmanUpdateHelpers.hpp"
0021 #include "Acts/Utilities/Helpers.hpp"
0022 #include "Acts/Utilities/Zip.hpp"
0023 
0024 #include <map>
0025 
0026 namespace Acts::detail {
0027 
0028 template <typename traj_t>
0029 struct GsfResult {
0030   /// The multi-trajectory which stores the graph of components
0031   traj_t* fittedStates{nullptr};
0032 
0033   /// The current top index of the MultiTrajectory
0034   MultiTrajectoryTraits::IndexType currentTip = MultiTrajectoryTraits::kInvalid;
0035 
0036   /// The last tip referring to a measurement state in the MultiTrajectory
0037   MultiTrajectoryTraits::IndexType lastMeasurementTip =
0038       MultiTrajectoryTraits::kInvalid;
0039 
0040   /// The last multi-component measurement state. Used to initialize the
0041   /// backward pass.
0042   std::vector<std::tuple<double, BoundVector, BoundMatrix>>
0043       lastMeasurementComponents;
0044 
0045   /// The last measurement surface. Used to initialize the backward pass.
0046   const Acts::Surface* lastMeasurementSurface = nullptr;
0047 
0048   /// Some counting
0049   std::size_t measurementStates = 0;
0050   std::size_t measurementHoles = 0;
0051   std::size_t processedStates = 0;
0052 
0053   std::vector<const Surface*> visitedSurfaces;
0054   std::vector<const Surface*> surfacesVisitedBwdAgain;
0055 
0056   /// Statistics about material encounterings
0057   Updatable<std::size_t> nInvalidBetheHeitler;
0058   Updatable<double> maxPathXOverX0;
0059   Updatable<double> sumPathXOverX0;
0060 
0061   // Propagate potential errors to the outside
0062   Result<void> result{Result<void>::success()};
0063 
0064   // Internal: bethe heitler approximation component cache
0065   std::vector<BetheHeitlerApprox::Component> betheHeitlerCache;
0066 
0067   // Internal: component cache to avoid reallocation
0068   std::vector<GsfComponent> componentCache;
0069 };
0070 
0071 /// The actor carrying out the GSF algorithm
0072 template <typename traj_t>
0073 struct GsfActor {
0074   /// Enforce default construction
0075   GsfActor() = default;
0076 
0077   using ComponentCache = GsfComponent;
0078 
0079   /// Broadcast the result_type
0080   using result_type = GsfResult<traj_t>;
0081 
0082   // Actor configuration
0083   struct Config {
0084     /// Maximum number of components which the GSF should handle
0085     std::size_t maxComponents = 16;
0086 
0087     /// Input measurements
0088     const std::map<GeometryIdentifier, SourceLink>* inputMeasurements = nullptr;
0089 
0090     /// Bethe Heitler Approximator pointer. The fitter holds the approximator
0091     /// instance TODO if we somehow could initialize a reference here...
0092     const BetheHeitlerApprox* bethe_heitler_approx = nullptr;
0093 
0094     /// Whether to consider multiple scattering.
0095     bool multipleScattering = true;
0096 
0097     /// When to discard components
0098     double weightCutoff = 1.0e-4;
0099 
0100     /// When this option is enabled, material information on all surfaces is
0101     /// ignored. This disables the component convolution as well as the handling
0102     /// of energy. This may be useful for debugging.
0103     bool disableAllMaterialHandling = false;
0104 
0105     /// Whether to abort immediately when an error occurs
0106     bool abortOnError = false;
0107 
0108     /// We can stop the propagation if we reach this number of measurement
0109     /// states
0110     std::optional<std::size_t> numberMeasurements;
0111 
0112     /// The extensions
0113     GsfExtensions<traj_t> extensions;
0114 
0115     /// Whether we are in the reverse pass or not. This is more reliable than
0116     /// checking the navigation direction, because in principle the fitter can
0117     /// be started backwards in the first pass
0118     bool inReversePass = false;
0119 
0120     /// How to reduce the states that are stored in the multi trajectory
0121     ComponentMergeMethod mergeMethod = ComponentMergeMethod::eMaxWeight;
0122 
0123     const Logger* logger{nullptr};
0124 
0125     /// Calibration context for the fit
0126     const CalibrationContext* calibrationContext{nullptr};
0127 
0128   } m_cfg;
0129 
0130   const Logger& logger() const { return *m_cfg.logger; }
0131 
0132   struct TemporaryStates {
0133     traj_t traj;
0134     std::vector<MultiTrajectoryTraits::IndexType> tips;
0135     std::map<MultiTrajectoryTraits::IndexType, double> weights;
0136   };
0137 
0138   using FiltProjector = MultiTrajectoryProjector<StatesType::eFiltered, traj_t>;
0139 
0140   /// @brief GSF actor operation
0141   ///
0142   /// @tparam propagator_state_t is the type of Propagator state
0143   /// @tparam stepper_t Type of the stepper
0144   /// @tparam navigator_t Type of the navigator
0145   ///
0146   /// @param state is the mutable propagator state object
0147   /// @param stepper The stepper in use
0148   /// @param result is the mutable result state object
0149   template <typename propagator_state_t, typename stepper_t,
0150             typename navigator_t>
0151   void act(propagator_state_t& state, const stepper_t& stepper,
0152            const navigator_t& navigator, result_type& result,
0153            const Logger& /*logger*/) const {
0154     assert(result.fittedStates && "No MultiTrajectory set");
0155 
0156     // Return is we found an error earlier
0157     if (!result.result.ok()) {
0158       ACTS_WARNING("result.result not ok, return!");
0159       return;
0160     }
0161 
0162     // Set error or abort utility
0163     auto setErrorOrAbort = [&](auto error) {
0164       if (m_cfg.abortOnError) {
0165         std::abort();
0166       } else {
0167         result.result = error;
0168       }
0169     };
0170 
0171     // Prints some VERBOSE things and performs some asserts. Can be removed
0172     // without change of behaviour
0173     const detail::ScopedGsfInfoPrinterAndChecker printer(state, stepper,
0174                                                          navigator, logger());
0175 
0176     // We only need to do something if we are on a surface
0177     if (!navigator.currentSurface(state.navigation)) {
0178       return;
0179     }
0180 
0181     const auto& surface = *navigator.currentSurface(state.navigation);
0182     ACTS_VERBOSE("Step is at surface " << surface.geometryId());
0183 
0184     // All components must be normalized at the beginning here, otherwise the
0185     // stepper misbehaves
0186     [[maybe_unused]] auto stepperComponents =
0187         stepper.constComponentIterable(state.stepping);
0188     assert(detail::weightsAreNormalized(
0189         stepperComponents, [](const auto& cmp) { return cmp.weight(); }));
0190 
0191     // All components must have status "on surface". It is however possible,
0192     // that currentSurface is nullptr and all components are "on surface" (e.g.,
0193     // for surfaces excluded from the navigation)
0194     using Status [[maybe_unused]] = IntersectionStatus;
0195     assert(std::all_of(
0196         stepperComponents.begin(), stepperComponents.end(),
0197         [](const auto& cmp) { return cmp.status() == Status::onSurface; }));
0198 
0199     // Early return if we already were on this surface TODO why is this
0200     // necessary
0201     const bool visited = rangeContainsValue(result.visitedSurfaces, &surface);
0202 
0203     if (visited) {
0204       ACTS_VERBOSE("Already visited surface, return");
0205       return;
0206     }
0207 
0208     result.visitedSurfaces.push_back(&surface);
0209 
0210     // Check what we have on this surface
0211     const auto foundSourceLink =
0212         m_cfg.inputMeasurements->find(surface.geometryId());
0213     const bool haveMaterial =
0214         navigator.currentSurface(state.navigation)->surfaceMaterial() &&
0215         !m_cfg.disableAllMaterialHandling;
0216     const bool haveMeasurement =
0217         foundSourceLink != m_cfg.inputMeasurements->end();
0218 
0219     ACTS_VERBOSE(std::boolalpha << "haveMaterial " << haveMaterial
0220                                 << ", haveMeasurement: " << haveMeasurement);
0221 
0222     ////////////////////////
0223     // The Core Algorithm
0224     ////////////////////////
0225 
0226     // Early return if nothing happens
0227     if (!haveMaterial && !haveMeasurement) {
0228       // No hole before first measurement
0229       if (result.processedStates > 0 && surface.associatedDetectorElement()) {
0230         TemporaryStates tmpStates;
0231         noMeasurementUpdate(state, stepper, navigator, result, tmpStates, true);
0232       }
0233       return;
0234     }
0235 
0236     // Update the counters. Note that this should be done before potential
0237     // material interactions, because if this is our last measurement this would
0238     // not influence the fit anymore.
0239     if (haveMeasurement) {
0240       result.maxPathXOverX0.update();
0241       result.sumPathXOverX0.update();
0242       result.nInvalidBetheHeitler.update();
0243     }
0244 
0245     for (auto cmp : stepper.componentIterable(state.stepping)) {
0246       cmp.singleStepper(stepper).transportCovarianceToBound(cmp.state(),
0247                                                             surface);
0248     }
0249 
0250     if (haveMaterial) {
0251       if (haveMeasurement) {
0252         applyMultipleScattering(state, stepper, navigator,
0253                                 MaterialUpdateStage::PreUpdate);
0254       } else {
0255         applyMultipleScattering(state, stepper, navigator,
0256                                 MaterialUpdateStage::FullUpdate);
0257       }
0258     }
0259 
0260     // We do not need the component cache here, we can just update our stepper
0261     // state with the filtered components.
0262     // NOTE because of early return before we know that we have a measurement
0263     if (!haveMaterial) {
0264       TemporaryStates tmpStates;
0265 
0266       auto res = kalmanUpdate(state, stepper, navigator, result, tmpStates,
0267                               foundSourceLink->second);
0268 
0269       if (!res.ok()) {
0270         setErrorOrAbort(res.error());
0271         return;
0272       }
0273 
0274       updateStepper(state, stepper, tmpStates);
0275     }
0276     // We have material, we thus need a component cache since we will
0277     // convolute the components and later reduce them again before updating
0278     // the stepper
0279     else {
0280       TemporaryStates tmpStates;
0281       Result<void> res;
0282 
0283       if (haveMeasurement) {
0284         res = kalmanUpdate(state, stepper, navigator, result, tmpStates,
0285                            foundSourceLink->second);
0286       } else {
0287         res = noMeasurementUpdate(state, stepper, navigator, result, tmpStates,
0288                                   false);
0289       }
0290 
0291       if (!res.ok()) {
0292         setErrorOrAbort(res.error());
0293         return;
0294       }
0295 
0296       // Reuse memory over all calls to the Actor in a single propagation
0297       std::vector<ComponentCache>& componentCache = result.componentCache;
0298       componentCache.clear();
0299 
0300       convoluteComponents(state, stepper, navigator, tmpStates, componentCache,
0301                           result);
0302 
0303       if (componentCache.empty()) {
0304         ACTS_WARNING(
0305             "No components left after applying energy loss. "
0306             "Is the weight cutoff "
0307             << m_cfg.weightCutoff << " too high?");
0308         ACTS_WARNING("Return to propagator without applying energy loss");
0309         return;
0310       }
0311 
0312       // reduce component number
0313       const auto finalCmpNumber = std::min(
0314           static_cast<std::size_t>(stepper.maxComponents), m_cfg.maxComponents);
0315       m_cfg.extensions.mixtureReducer(componentCache, finalCmpNumber, surface);
0316 
0317       removeLowWeightComponents(componentCache);
0318 
0319       updateStepper(state, stepper, navigator, componentCache);
0320     }
0321 
0322     // If we have only done preUpdate before, now do postUpdate
0323     if (haveMaterial && haveMeasurement) {
0324       applyMultipleScattering(state, stepper, navigator,
0325                               MaterialUpdateStage::PostUpdate);
0326     }
0327   }
0328 
0329   template <typename propagator_state_t, typename stepper_t,
0330             typename navigator_t>
0331   bool checkAbort(propagator_state_t& /*state*/, const stepper_t& /*stepper*/,
0332                   const navigator_t& /*navigator*/, const result_type& result,
0333                   const Logger& /*logger*/) const {
0334     if (m_cfg.numberMeasurements &&
0335         result.measurementStates == m_cfg.numberMeasurements) {
0336       ACTS_VERBOSE("Stop navigation because all measurements are found");
0337       return true;
0338     }
0339 
0340     return false;
0341   }
0342 
0343   template <typename propagator_state_t, typename stepper_t,
0344             typename navigator_t>
0345   void convoluteComponents(propagator_state_t& state, const stepper_t& stepper,
0346                            const navigator_t& navigator,
0347                            const TemporaryStates& tmpStates,
0348                            std::vector<ComponentCache>& componentCache,
0349                            result_type& result) const {
0350     auto cmps = stepper.componentIterable(state.stepping);
0351     double pathXOverX0 = 0.0;
0352     for (auto [idx, cmp] : zip(tmpStates.tips, cmps)) {
0353       auto proxy = tmpStates.traj.getTrackState(idx);
0354 
0355       BoundTrackParameters bound(proxy.referenceSurface().getSharedPtr(),
0356                                  proxy.filtered(), proxy.filteredCovariance(),
0357                                  stepper.particleHypothesis(state.stepping));
0358 
0359       pathXOverX0 +=
0360           applyBetheHeitler(state, navigator, bound, tmpStates.weights.at(idx),
0361                             componentCache, result);
0362     }
0363 
0364     // Store average material seen by the components
0365     // Should not be too broadly distributed
0366     result.sumPathXOverX0.tmp() += pathXOverX0 / tmpStates.tips.size();
0367   }
0368 
0369   template <typename propagator_state_t, typename navigator_t>
0370   double applyBetheHeitler(const propagator_state_t& state,
0371                            const navigator_t& navigator,
0372                            const BoundTrackParameters& old_bound,
0373                            const double old_weight,
0374                            std::vector<ComponentCache>& componentCache,
0375                            result_type& result) const {
0376     const auto& surface = *navigator.currentSurface(state.navigation);
0377     const auto p_prev = old_bound.absoluteMomentum();
0378     const auto& particleHypothesis = old_bound.particleHypothesis();
0379 
0380     // Evaluate material slab
0381     auto slab = surface.surfaceMaterial()->materialSlab(
0382         old_bound.position(state.geoContext), state.options.direction,
0383         MaterialUpdateStage::FullUpdate);
0384 
0385     const auto pathCorrection = surface.pathCorrection(
0386         state.geoContext, old_bound.position(state.geoContext),
0387         old_bound.direction());
0388     slab.scaleThickness(pathCorrection);
0389 
0390     const double pathXOverX0 = slab.thicknessInX0();
0391     result.maxPathXOverX0.tmp() =
0392         std::max(result.maxPathXOverX0.tmp(), pathXOverX0);
0393 
0394     // Emit a warning if the approximation is not valid for this x/x0
0395     if (!m_cfg.bethe_heitler_approx->validXOverX0(pathXOverX0)) {
0396       ++result.nInvalidBetheHeitler.tmp();
0397       ACTS_DEBUG(
0398           "Bethe-Heitler approximation encountered invalid value for x/x0="
0399           << pathXOverX0 << " at surface " << surface.geometryId());
0400     }
0401 
0402     // Get the mixture
0403     result.betheHeitlerCache.resize(
0404         m_cfg.bethe_heitler_approx->maxComponents());
0405     const auto mixture = m_cfg.bethe_heitler_approx->mixture(
0406         pathXOverX0, result.betheHeitlerCache);
0407 
0408     // Create all possible new components
0409     for (const auto& gaussian : mixture) {
0410       // Here we combine the new child weight with the parent weight.
0411       // However, this must be later re-adjusted
0412       const auto new_weight = gaussian.weight * old_weight;
0413 
0414       if (new_weight < m_cfg.weightCutoff) {
0415         ACTS_VERBOSE("Skip component with weight " << new_weight);
0416         continue;
0417       }
0418 
0419       if (gaussian.mean < 1.e-8) {
0420         ACTS_WARNING("Skip component with gaussian " << gaussian.mean << " +- "
0421                                                      << gaussian.var);
0422         continue;
0423       }
0424 
0425       // compute delta p from mixture and update parameters
0426       auto new_pars = old_bound.parameters();
0427 
0428       const auto delta_p = [&]() {
0429         if (state.options.direction == Direction::Forward()) {
0430           return p_prev * (gaussian.mean - 1.);
0431         } else {
0432           return p_prev * (1. / gaussian.mean - 1.);
0433         }
0434       }();
0435 
0436       assert(p_prev + delta_p > 0. && "new momentum must be > 0");
0437       new_pars[eBoundQOverP] =
0438           particleHypothesis.qOverP(p_prev + delta_p, old_bound.charge());
0439 
0440       // compute inverse variance of p from mixture and update covariance
0441       auto new_cov = old_bound.covariance().value();
0442 
0443       const auto varInvP = [&]() {
0444         if (state.options.direction == Direction::Forward()) {
0445           const auto f = 1. / (p_prev * gaussian.mean);
0446           return f * f * gaussian.var;
0447         } else {
0448           return gaussian.var / (p_prev * p_prev);
0449         }
0450       }();
0451 
0452       new_cov(eBoundQOverP, eBoundQOverP) += varInvP;
0453       assert(std::isfinite(new_cov(eBoundQOverP, eBoundQOverP)) &&
0454              "new cov not finite");
0455 
0456       // Set the remaining things and push to vector
0457       componentCache.push_back({new_weight, new_pars, new_cov});
0458     }
0459 
0460     return pathXOverX0;
0461   }
0462 
0463   /// Remove components with low weights and renormalize from the component
0464   /// cache
0465   /// TODO This function does not expect normalized components, but this
0466   /// could be redundant work...
0467   void removeLowWeightComponents(std::vector<ComponentCache>& cmps) const {
0468     auto proj = [](auto& cmp) -> double& { return cmp.weight; };
0469 
0470     detail::normalizeWeights(cmps, proj);
0471 
0472     auto new_end = std::remove_if(cmps.begin(), cmps.end(), [&](auto& cmp) {
0473       return proj(cmp) < m_cfg.weightCutoff;
0474     });
0475 
0476     // In case we would remove all components, keep only the largest
0477     if (std::distance(cmps.begin(), new_end) == 0) {
0478       cmps = {*std::max_element(
0479           cmps.begin(), cmps.end(),
0480           [&](auto& a, auto& b) { return proj(a) < proj(b); })};
0481       cmps.front().weight = 1.0;
0482     } else {
0483       cmps.erase(new_end, cmps.end());
0484       detail::normalizeWeights(cmps, proj);
0485     }
0486   }
0487 
0488   /// Function that updates the stepper from the MultiTrajectory
0489   template <typename propagator_state_t, typename stepper_t>
0490   void updateStepper(propagator_state_t& state, const stepper_t& stepper,
0491                      const TemporaryStates& tmpStates) const {
0492     auto cmps = stepper.componentIterable(state.stepping);
0493 
0494     for (auto [idx, cmp] : zip(tmpStates.tips, cmps)) {
0495       // we set ignored components to missed, so we can remove them after
0496       // the loop
0497       if (tmpStates.weights.at(idx) < m_cfg.weightCutoff) {
0498         cmp.status() = IntersectionStatus::unreachable;
0499         continue;
0500       }
0501 
0502       auto proxy = tmpStates.traj.getTrackState(idx);
0503 
0504       cmp.pars() =
0505           MultiTrajectoryHelpers::freeFiltered(state.geoContext, proxy);
0506       cmp.cov() = proxy.filteredCovariance();
0507       cmp.weight() = tmpStates.weights.at(idx);
0508     }
0509 
0510     stepper.removeMissedComponents(state.stepping);
0511 
0512     // TODO we have two normalization passes here now, this can probably be
0513     // optimized
0514     detail::normalizeWeights(cmps,
0515                              [&](auto cmp) -> double& { return cmp.weight(); });
0516   }
0517 
0518   /// Function that updates the stepper from the ComponentCache
0519   template <typename propagator_state_t, typename stepper_t,
0520             typename navigator_t>
0521   void updateStepper(propagator_state_t& state, const stepper_t& stepper,
0522                      const navigator_t& navigator,
0523                      const std::vector<ComponentCache>& componentCache) const {
0524     const auto& surface = *navigator.currentSurface(state.navigation);
0525 
0526     // Clear components before adding new ones
0527     stepper.clearComponents(state.stepping);
0528 
0529     // Finally loop over components
0530     for (const auto& [weight, pars, cov] : componentCache) {
0531       // Add the component to the stepper
0532       BoundTrackParameters bound(surface.getSharedPtr(), pars, cov,
0533                                  stepper.particleHypothesis(state.stepping));
0534 
0535       auto res = stepper.addComponent(state.stepping, std::move(bound), weight);
0536 
0537       if (!res.ok()) {
0538         ACTS_ERROR("Error adding component to MultiStepper");
0539         continue;
0540       }
0541 
0542       auto& cmp = *res;
0543       auto freeParams = cmp.pars();
0544       cmp.jacToGlobal() = surface.boundToFreeJacobian(
0545           state.geoContext, freeParams.template segment<3>(eFreePos0),
0546           freeParams.template segment<3>(eFreeDir0));
0547       cmp.pathAccumulated() = state.stepping.pathAccumulated;
0548       cmp.jacobian() = BoundMatrix::Identity();
0549       cmp.derivative() = FreeVector::Zero();
0550       cmp.jacTransport() = FreeMatrix::Identity();
0551     }
0552   }
0553 
0554   /// This function performs the kalman update, computes the new posterior
0555   /// weights, renormalizes all components, and does some statistics.
0556   template <typename propagator_state_t, typename stepper_t,
0557             typename navigator_t>
0558   Result<void> kalmanUpdate(propagator_state_t& state, const stepper_t& stepper,
0559                             const navigator_t& navigator, result_type& result,
0560                             TemporaryStates& tmpStates,
0561                             const SourceLink& sourceLink) const {
0562     const auto& surface = *navigator.currentSurface(state.navigation);
0563 
0564     // Boolean flag, to distinguish measurement and outlier states. This flag
0565     // is only modified by the valid-measurement-branch, so only if there
0566     // isn't any valid measurement state, the flag stays false and the state
0567     // is thus counted as an outlier
0568     bool is_valid_measurement = false;
0569 
0570     auto cmps = stepper.componentIterable(state.stepping);
0571     for (auto cmp : cmps) {
0572       auto singleState = cmp.singleState(state);
0573       const auto& singleStepper = cmp.singleStepper(stepper);
0574 
0575       auto trackStateProxyRes = detail::kalmanHandleMeasurement(
0576           *m_cfg.calibrationContext, singleState, singleStepper,
0577           m_cfg.extensions, surface, sourceLink, tmpStates.traj,
0578           MultiTrajectoryTraits::kInvalid, false, logger());
0579 
0580       if (!trackStateProxyRes.ok()) {
0581         return trackStateProxyRes.error();
0582       }
0583 
0584       const auto& trackStateProxy = *trackStateProxyRes;
0585 
0586       // If at least one component is no outlier, we consider the whole thing
0587       // as a measurementState
0588       if (trackStateProxy.typeFlags().test(TrackStateFlag::MeasurementFlag)) {
0589         is_valid_measurement = true;
0590       }
0591 
0592       tmpStates.tips.push_back(trackStateProxy.index());
0593       tmpStates.weights[tmpStates.tips.back()] = cmp.weight();
0594     }
0595 
0596     computePosteriorWeights(tmpStates.traj, tmpStates.tips, tmpStates.weights);
0597 
0598     detail::normalizeWeights(tmpStates.tips, [&](auto idx) -> double& {
0599       return tmpStates.weights.at(idx);
0600     });
0601 
0602     // Do the statistics
0603     ++result.processedStates;
0604 
0605     // TODO should outlier states also be counted here?
0606     if (is_valid_measurement) {
0607       ++result.measurementStates;
0608     }
0609 
0610     updateMultiTrajectory(result, tmpStates, surface);
0611 
0612     result.lastMeasurementTip = result.currentTip;
0613     result.lastMeasurementSurface = &surface;
0614 
0615     // Note, that we do not normalize the components here.
0616     // This must be done before initializing the backward pass.
0617     result.lastMeasurementComponents.clear();
0618 
0619     FiltProjector proj{tmpStates.traj, tmpStates.weights};
0620     for (const auto& idx : tmpStates.tips) {
0621       const auto& [w, p, c] = proj(idx);
0622       // TODO check why zero weight can occur
0623       if (w > 0.0) {
0624         result.lastMeasurementComponents.push_back({w, p, c});
0625       }
0626     }
0627 
0628     // Return success
0629     return Result<void>::success();
0630   }
0631 
0632   template <typename propagator_state_t, typename stepper_t,
0633             typename navigator_t>
0634   Result<void> noMeasurementUpdate(propagator_state_t& state,
0635                                    const stepper_t& stepper,
0636                                    const navigator_t& navigator,
0637                                    result_type& result,
0638                                    TemporaryStates& tmpStates,
0639                                    bool doCovTransport) const {
0640     const auto& surface = *navigator.currentSurface(state.navigation);
0641 
0642     const bool precedingMeasurementExists = result.processedStates > 0;
0643 
0644     // Initialize as true, so that any component can flip it. However, all
0645     // components should behave the same
0646     bool isHole = true;
0647 
0648     for (auto cmp : stepper.componentIterable(state.stepping)) {
0649       auto& singleState = cmp.state();
0650       const auto& singleStepper = cmp.singleStepper(stepper);
0651 
0652       // There is some redundant checking inside this function, but do this for
0653       // now until we measure this is significant
0654       auto trackStateProxyRes = detail::kalmanHandleNoMeasurement(
0655           singleState, singleStepper, surface, tmpStates.traj,
0656           MultiTrajectoryTraits::kInvalid, doCovTransport, logger(),
0657           precedingMeasurementExists);
0658 
0659       if (!trackStateProxyRes.ok()) {
0660         return trackStateProxyRes.error();
0661       }
0662 
0663       const auto& trackStateProxy = *trackStateProxyRes;
0664 
0665       if (!trackStateProxy.typeFlags().test(TrackStateFlag::HoleFlag)) {
0666         isHole = false;
0667       }
0668 
0669       tmpStates.tips.push_back(trackStateProxy.index());
0670       tmpStates.weights[tmpStates.tips.back()] = cmp.weight();
0671     }
0672 
0673     // These things should only be done once for all components
0674     if (isHole) {
0675       ++result.measurementHoles;
0676     }
0677 
0678     ++result.processedStates;
0679 
0680     updateMultiTrajectory(result, tmpStates, surface);
0681 
0682     return Result<void>::success();
0683   }
0684 
0685   /// Apply the multiple scattering to the state
0686   template <typename propagator_state_t, typename stepper_t,
0687             typename navigator_t>
0688   void applyMultipleScattering(propagator_state_t& state,
0689                                const stepper_t& stepper,
0690                                const navigator_t& navigator,
0691                                const MaterialUpdateStage& updateStage =
0692                                    MaterialUpdateStage::FullUpdate) const {
0693     const auto& surface = *navigator.currentSurface(state.navigation);
0694 
0695     for (auto cmp : stepper.componentIterable(state.stepping)) {
0696       auto singleState = cmp.singleState(state);
0697       const auto& singleStepper = cmp.singleStepper(stepper);
0698 
0699       detail::PointwiseMaterialInteraction interaction(&surface, singleState,
0700                                                        singleStepper);
0701       if (interaction.evaluateMaterialSlab(singleState, navigator,
0702                                            updateStage)) {
0703         // In the Gsf we only need to handle the multiple scattering
0704         interaction.evaluatePointwiseMaterialInteraction(
0705             m_cfg.multipleScattering, false);
0706 
0707         // Screen out material effects info
0708         ACTS_VERBOSE("Material effects on surface: "
0709                      << surface.geometryId()
0710                      << " at update stage: " << updateStage << " are :");
0711         ACTS_VERBOSE("eLoss = "
0712                      << interaction.Eloss << ", "
0713                      << "variancePhi = " << interaction.variancePhi << ", "
0714                      << "varianceTheta = " << interaction.varianceTheta << ", "
0715                      << "varianceQoverP = " << interaction.varianceQoverP);
0716 
0717         // Update the state and stepper with material effects
0718         interaction.updateState(singleState, singleStepper, addNoise);
0719 
0720         assert(singleState.stepping.cov.array().isFinite().all() &&
0721                "covariance not finite after multi scattering");
0722       }
0723     }
0724   }
0725 
0726   void updateMultiTrajectory(result_type& result,
0727                              const TemporaryStates& tmpStates,
0728                              const Surface& surface) const {
0729     using PrtProjector =
0730         MultiTrajectoryProjector<StatesType::ePredicted, traj_t>;
0731     using FltProjector =
0732         MultiTrajectoryProjector<StatesType::eFiltered, traj_t>;
0733 
0734     if (!m_cfg.inReversePass) {
0735       const auto firstCmpProxy =
0736           tmpStates.traj.getTrackState(tmpStates.tips.front());
0737       const auto isMeasurement =
0738           firstCmpProxy.typeFlags().test(MeasurementFlag);
0739 
0740       const auto mask =
0741           isMeasurement
0742               ? TrackStatePropMask::Calibrated | TrackStatePropMask::Predicted |
0743                     TrackStatePropMask::Filtered | TrackStatePropMask::Smoothed
0744               : TrackStatePropMask::Calibrated | TrackStatePropMask::Predicted;
0745 
0746       auto proxy = result.fittedStates->makeTrackState(mask, result.currentTip);
0747       result.currentTip = proxy.index();
0748 
0749       proxy.setReferenceSurface(surface.getSharedPtr());
0750       proxy.copyFrom(firstCmpProxy, mask);
0751 
0752       auto [prtMean, prtCov] =
0753           mergeGaussianMixture(tmpStates.tips, surface, m_cfg.mergeMethod,
0754                                PrtProjector{tmpStates.traj, tmpStates.weights});
0755       proxy.predicted() = prtMean;
0756       proxy.predictedCovariance() = prtCov;
0757 
0758       if (isMeasurement) {
0759         auto [fltMean, fltCov] = mergeGaussianMixture(
0760             tmpStates.tips, surface, m_cfg.mergeMethod,
0761             FltProjector{tmpStates.traj, tmpStates.weights});
0762         proxy.filtered() = fltMean;
0763         proxy.filteredCovariance() = fltCov;
0764         proxy.smoothed() = BoundVector::Constant(-2);
0765         proxy.smoothedCovariance() = BoundSquareMatrix::Constant(-2);
0766       } else {
0767         proxy.shareFrom(TrackStatePropMask::Predicted,
0768                         TrackStatePropMask::Filtered);
0769       }
0770 
0771     } else {
0772       assert((result.currentTip != MultiTrajectoryTraits::kInvalid &&
0773               "tip not valid"));
0774 
0775       result.fittedStates->applyBackwards(
0776           result.currentTip, [&](auto trackState) {
0777             auto fSurface = &trackState.referenceSurface();
0778             if (fSurface == &surface) {
0779               result.surfacesVisitedBwdAgain.push_back(&surface);
0780 
0781               if (trackState.hasSmoothed()) {
0782                 const auto [smtMean, smtCov] = mergeGaussianMixture(
0783                     tmpStates.tips, surface, m_cfg.mergeMethod,
0784                     FltProjector{tmpStates.traj, tmpStates.weights});
0785 
0786                 trackState.smoothed() = smtMean;
0787                 trackState.smoothedCovariance() = smtCov;
0788               }
0789               return false;
0790             }
0791             return true;
0792           });
0793     }
0794   }
0795 
0796   /// Set the relevant options that can be set from the Options struct all in
0797   /// one place
0798   void setOptions(const GsfOptions<traj_t>& options) {
0799     m_cfg.maxComponents = options.maxComponents;
0800     m_cfg.extensions = options.extensions;
0801     m_cfg.abortOnError = options.abortOnError;
0802     m_cfg.disableAllMaterialHandling = options.disableAllMaterialHandling;
0803     m_cfg.weightCutoff = options.weightCutoff;
0804     m_cfg.mergeMethod = options.componentMergeMethod;
0805     m_cfg.calibrationContext = &options.calibrationContext.get();
0806   }
0807 };
0808 
0809 }  // namespace Acts::detail