Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-19 09:23:36

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2022 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 #include "Acts/Definitions/TrackParametrization.hpp"
0012 #include "Acts/EventData/MultiComponentTrackParameters.hpp"
0013 #include "Acts/EventData/MultiTrajectory.hpp"
0014 #include "Acts/EventData/MultiTrajectoryHelpers.hpp"
0015 #include "Acts/MagneticField/MagneticFieldProvider.hpp"
0016 #include "Acts/Material/ISurfaceMaterial.hpp"
0017 #include "Acts/Surfaces/CylinderSurface.hpp"
0018 #include "Acts/Surfaces/Surface.hpp"
0019 #include "Acts/TrackFitting/BetheHeitlerApprox.hpp"
0020 #include "Acts/TrackFitting/GsfError.hpp"
0021 #include "Acts/TrackFitting/GsfOptions.hpp"
0022 #include "Acts/TrackFitting/KalmanFitter.hpp"
0023 #include "Acts/TrackFitting/detail/GsfComponentMerging.hpp"
0024 #include "Acts/TrackFitting/detail/GsfUtils.hpp"
0025 #include "Acts/TrackFitting/detail/KalmanUpdateHelpers.hpp"
0026 #include "Acts/Utilities/Zip.hpp"
0027 
0028 #include <ios>
0029 #include <map>
0030 #include <numeric>
0031 
0032 namespace Acts::detail {
0033 
0034 template <typename traj_t>
0035 struct GsfResult {
0036   /// The multi-trajectory which stores the graph of components
0037   traj_t* fittedStates{nullptr};
0038 
0039   /// The current top index of the MultiTrajectory
0040   MultiTrajectoryTraits::IndexType currentTip = MultiTrajectoryTraits::kInvalid;
0041 
0042   /// The last tip referring to a measurement state in the MultiTrajectory
0043   MultiTrajectoryTraits::IndexType lastMeasurementTip =
0044       MultiTrajectoryTraits::kInvalid;
0045 
0046   /// The last multi-component measurement state. Used to initialize the
0047   /// backward pass.
0048   std::optional<MultiComponentBoundTrackParameters> lastMeasurementState;
0049 
0050   /// Some counting
0051   std::size_t measurementStates = 0;
0052   std::size_t measurementHoles = 0;
0053   std::size_t processedStates = 0;
0054 
0055   std::vector<const Acts::Surface*> visitedSurfaces;
0056   std::vector<const Acts::Surface*> surfacesVisitedBwdAgain;
0057 
0058   /// Statistics about material encounterings
0059   Updatable<std::size_t> nInvalidBetheHeitler;
0060   Updatable<double> maxPathXOverX0;
0061   Updatable<double> sumPathXOverX0;
0062 
0063   // Propagate potential errors to the outside
0064   Result<void> result{Result<void>::success()};
0065 
0066   // Internal: component cache to avoid reallocation
0067   std::vector<GsfComponent> componentCache;
0068 };
0069 
0070 /// The actor carrying out the GSF algorithm
0071 template <typename bethe_heitler_approx_t, typename traj_t>
0072 struct GsfActor {
0073   /// Enforce default construction
0074   GsfActor() = default;
0075 
0076   using ComponentCache = Acts::GsfComponent;
0077 
0078   /// Broadcast the result_type
0079   using result_type = GsfResult<traj_t>;
0080 
0081   // Actor configuration
0082   struct Config {
0083     /// Maximum number of components which the GSF should handle
0084     std::size_t maxComponents = 16;
0085 
0086     /// Input measurements
0087     const std::map<GeometryIdentifier, SourceLink>* inputMeasurements = nullptr;
0088 
0089     /// Bethe Heitler Approximator pointer. The fitter holds the approximator
0090     /// instance TODO if we somehow could initialize a reference here...
0091     const bethe_heitler_approx_t* bethe_heitler_approx = nullptr;
0092 
0093     /// Whether to consider multiple scattering.
0094     bool multipleScattering = true;
0095 
0096     /// When to discard components
0097     double weightCutoff = 1.0e-4;
0098 
0099     /// When this option is enabled, material information on all surfaces is
0100     /// ignored. This disables the component convolution as well as the handling
0101     /// of energy. This may be useful for debugging.
0102     bool disableAllMaterialHandling = false;
0103 
0104     /// Whether to abort immediately when an error occurs
0105     bool abortOnError = false;
0106 
0107     /// We can stop the propagation if we reach this number of measurement
0108     /// states
0109     std::optional<std::size_t> numberMeasurements;
0110 
0111     /// The extensions
0112     GsfExtensions<traj_t> extensions;
0113 
0114     /// Whether we are in the reverse pass or not. This is more reliable than
0115     /// checking the navigation direction, because in principle the fitter can
0116     /// be started backwards in the first pass
0117     bool inReversePass = false;
0118 
0119     /// How to reduce the states that are stored in the multi trajectory
0120     ComponentMergeMethod mergeMethod = ComponentMergeMethod::eMaxWeight;
0121 
0122     const Logger* logger{nullptr};
0123 
0124     /// Calibration context for the fit
0125     const CalibrationContext* calibrationContext{nullptr};
0126 
0127   } m_cfg;
0128 
0129   const Logger& logger() const { return *m_cfg.logger; }
0130 
0131   struct TemporaryStates {
0132     traj_t traj;
0133     std::vector<MultiTrajectoryTraits::IndexType> tips;
0134     std::map<MultiTrajectoryTraits::IndexType, double> weights;
0135   };
0136 
0137   /// @brief GSF actor operation
0138   ///
0139   /// @tparam propagator_state_t is the type of Propagator state
0140   /// @tparam stepper_t Type of the stepper
0141   /// @tparam navigator_t Type of the navigator
0142   ///
0143   /// @param state is the mutable propagator state object
0144   /// @param stepper The stepper in use
0145   /// @param result is the mutable result state object
0146   template <typename propagator_state_t, typename stepper_t,
0147             typename navigator_t>
0148   void operator()(propagator_state_t& state, const stepper_t& stepper,
0149                   const navigator_t& navigator, result_type& result,
0150                   const Logger& /*logger*/) const {
0151     assert(result.fittedStates && "No MultiTrajectory set");
0152 
0153     // Return is we found an error earlier
0154     if (!result.result.ok()) {
0155       ACTS_WARNING("result.result not ok, return!")
0156       return;
0157     }
0158 
0159     // Set error or abort utility
0160     auto setErrorOrAbort = [&](auto error) {
0161       if (m_cfg.abortOnError) {
0162         std::abort();
0163       } else {
0164         result.result = error;
0165       }
0166     };
0167 
0168     // Prints some VERBOSE things and performs some asserts. Can be removed
0169     // without change of behaviour
0170     const detail::ScopedGsfInfoPrinterAndChecker printer(state, stepper,
0171                                                          navigator, logger());
0172 
0173     // We only need to do something if we are on a surface
0174     if (!navigator.currentSurface(state.navigation)) {
0175       return;
0176     }
0177 
0178     const auto& surface = *navigator.currentSurface(state.navigation);
0179     ACTS_VERBOSE("Step is at surface " << surface.geometryId());
0180 
0181     // All components must be normalized at the beginning here, otherwise the
0182     // stepper misbehaves
0183     [[maybe_unused]] auto stepperComponents =
0184         stepper.constComponentIterable(state.stepping);
0185     assert(detail::weightsAreNormalized(
0186         stepperComponents, [](const auto& cmp) { return cmp.weight(); }));
0187 
0188     // All components must have status "on surface". It is however possible,
0189     // that currentSurface is nullptr and all components are "on surface" (e.g.,
0190     // for surfaces excluded from the navigation)
0191     using Status = Acts::Intersection3D::Status;
0192     assert(std::all_of(
0193         stepperComponents.begin(), stepperComponents.end(),
0194         [](const auto& cmp) { return cmp.status() == Status::onSurface; }));
0195 
0196     // Early return if we already were on this surface TODO why is this
0197     // necessary
0198     const bool visited =
0199         std::find(result.visitedSurfaces.begin(), result.visitedSurfaces.end(),
0200                   &surface) != result.visitedSurfaces.end();
0201 
0202     if (visited) {
0203       ACTS_VERBOSE("Already visited surface, return");
0204       return;
0205     }
0206 
0207     result.visitedSurfaces.push_back(&surface);
0208 
0209     // Check what we have on this surface
0210     const auto found_source_link =
0211         m_cfg.inputMeasurements->find(surface.geometryId());
0212     const bool haveMaterial =
0213         navigator.currentSurface(state.navigation)->surfaceMaterial() &&
0214         !m_cfg.disableAllMaterialHandling;
0215     const bool haveMeasurement =
0216         found_source_link != m_cfg.inputMeasurements->end();
0217 
0218     ACTS_VERBOSE(std::boolalpha << "haveMaterial " << haveMaterial
0219                                 << ", haveMeasurement: " << haveMeasurement);
0220 
0221     ////////////////////////
0222     // The Core Algorithm
0223     ////////////////////////
0224 
0225     // Early return if nothing happens
0226     if (!haveMaterial && !haveMeasurement) {
0227       // No hole before first measurement
0228       if (result.processedStates > 0 && surface.associatedDetectorElement()) {
0229         TemporaryStates tmpStates;
0230         noMeasurementUpdate(state, stepper, navigator, result, tmpStates, true);
0231       }
0232       return;
0233     }
0234 
0235     // Update the counters. Note that this should be done before potential
0236     // material interactions, because if this is our last measurement this would
0237     // not influence the fit anymore.
0238     if (haveMeasurement) {
0239       result.maxPathXOverX0.update();
0240       result.sumPathXOverX0.update();
0241       result.nInvalidBetheHeitler.update();
0242     }
0243 
0244     for (auto cmp : stepper.componentIterable(state.stepping)) {
0245       auto singleState = cmp.singleState(state);
0246       cmp.singleStepper(stepper).transportCovarianceToBound(
0247           singleState.stepping, surface);
0248     }
0249 
0250     if (haveMaterial) {
0251       if (haveMeasurement) {
0252         applyMultipleScattering(state, stepper, navigator,
0253                                 MaterialUpdateStage::PreUpdate);
0254       } else {
0255         applyMultipleScattering(state, stepper, navigator,
0256                                 MaterialUpdateStage::FullUpdate);
0257       }
0258     }
0259 
0260     // We do not need the component cache here, we can just update our stepper
0261     // state with the filtered components.
0262     // NOTE because of early return before we know that we have a measurement
0263     if (!haveMaterial) {
0264       TemporaryStates tmpStates;
0265 
0266       auto res = kalmanUpdate(state, stepper, navigator, result, tmpStates,
0267                               found_source_link->second);
0268 
0269       if (!res.ok()) {
0270         setErrorOrAbort(res.error());
0271         return;
0272       }
0273 
0274       updateStepper(state, stepper, tmpStates);
0275     }
0276     // We have material, we thus need a component cache since we will
0277     // convolute the components and later reduce them again before updating
0278     // the stepper
0279     else {
0280       TemporaryStates tmpStates;
0281       Result<void> res;
0282 
0283       if (haveMeasurement) {
0284         res = kalmanUpdate(state, stepper, navigator, result, tmpStates,
0285                            found_source_link->second);
0286       } else {
0287         res = noMeasurementUpdate(state, stepper, navigator, result, tmpStates,
0288                                   false);
0289       }
0290 
0291       if (!res.ok()) {
0292         setErrorOrAbort(res.error());
0293         return;
0294       }
0295 
0296       // Reuse memory over all calls to the Actor in a single propagation
0297       std::vector<ComponentCache>& componentCache = result.componentCache;
0298       componentCache.clear();
0299 
0300       convoluteComponents(state, stepper, navigator, tmpStates, componentCache,
0301                           result);
0302 
0303       if (componentCache.empty()) {
0304         ACTS_WARNING(
0305             "No components left after applying energy loss. "
0306             "Is the weight cutoff "
0307             << m_cfg.weightCutoff << " too high?");
0308         ACTS_WARNING("Return to propagator without applying energy loss");
0309         return;
0310       }
0311 
0312       // reduce component number
0313       const auto finalCmpNumber = std::min(
0314           static_cast<std::size_t>(stepper.maxComponents), m_cfg.maxComponents);
0315       m_cfg.extensions.mixtureReducer(componentCache, finalCmpNumber, surface);
0316 
0317       removeLowWeightComponents(componentCache);
0318 
0319       updateStepper(state, stepper, navigator, componentCache);
0320     }
0321 
0322     // If we have only done preUpdate before, now do postUpdate
0323     if (haveMaterial && haveMeasurement) {
0324       applyMultipleScattering(state, stepper, navigator,
0325                               MaterialUpdateStage::PostUpdate);
0326     }
0327 
0328     // Break the navigation if we found all measurements
0329     if (m_cfg.numberMeasurements &&
0330         result.measurementStates == m_cfg.numberMeasurements) {
0331       ACTS_VERBOSE("Stop navigation because all measurements are found");
0332       navigator.navigationBreak(state.navigation, true);
0333     }
0334   }
0335 
0336   template <typename propagator_state_t, typename stepper_t,
0337             typename navigator_t>
0338   void convoluteComponents(propagator_state_t& state, const stepper_t& stepper,
0339                            const navigator_t& navigator,
0340                            const TemporaryStates& tmpStates,
0341                            std::vector<ComponentCache>& componentCache,
0342                            result_type& result) const {
0343     auto cmps = stepper.componentIterable(state.stepping);
0344     double pathXOverX0 = 0.0;
0345     for (auto [idx, cmp] : zip(tmpStates.tips, cmps)) {
0346       auto proxy = tmpStates.traj.getTrackState(idx);
0347 
0348       BoundTrackParameters bound(proxy.referenceSurface().getSharedPtr(),
0349                                  proxy.filtered(), proxy.filteredCovariance(),
0350                                  stepper.particleHypothesis(state.stepping));
0351 
0352       pathXOverX0 +=
0353           applyBetheHeitler(state, navigator, bound, tmpStates.weights.at(idx),
0354                             componentCache, result);
0355     }
0356 
0357     // Store average material seen by the components
0358     // Should not be too broadly distributed
0359     result.sumPathXOverX0.tmp() += pathXOverX0 / tmpStates.tips.size();
0360   }
0361 
0362   template <typename propagator_state_t, typename navigator_t>
0363   double applyBetheHeitler(const propagator_state_t& state,
0364                            const navigator_t& navigator,
0365                            const BoundTrackParameters& old_bound,
0366                            const double old_weight,
0367                            std::vector<ComponentCache>& componentCaches,
0368                            result_type& result) const {
0369     const auto& surface = *navigator.currentSurface(state.navigation);
0370     const auto p_prev = old_bound.absoluteMomentum();
0371 
0372     // Evaluate material slab
0373     auto slab = surface.surfaceMaterial()->materialSlab(
0374         old_bound.position(state.stepping.geoContext), state.options.direction,
0375         MaterialUpdateStage::FullUpdate);
0376 
0377     const auto pathCorrection = surface.pathCorrection(
0378         state.stepping.geoContext,
0379         old_bound.position(state.stepping.geoContext), old_bound.direction());
0380     slab.scaleThickness(pathCorrection);
0381 
0382     const double pathXOverX0 = slab.thicknessInX0();
0383     result.maxPathXOverX0.tmp() =
0384         std::max(result.maxPathXOverX0.tmp(), pathXOverX0);
0385 
0386     // Emit a warning if the approximation is not valid for this x/x0
0387     if (!m_cfg.bethe_heitler_approx->validXOverX0(pathXOverX0)) {
0388       ++result.nInvalidBetheHeitler.tmp();
0389       ACTS_DEBUG(
0390           "Bethe-Heitler approximation encountered invalid value for x/x0="
0391           << pathXOverX0 << " at surface " << surface.geometryId());
0392     }
0393 
0394     // Get the mixture
0395     const auto mixture = m_cfg.bethe_heitler_approx->mixture(pathXOverX0);
0396 
0397     // Create all possible new components
0398     for (const auto& gaussian : mixture) {
0399       // Here we combine the new child weight with the parent weight.
0400       // However, this must be later re-adjusted
0401       const auto new_weight = gaussian.weight * old_weight;
0402 
0403       if (new_weight < m_cfg.weightCutoff) {
0404         ACTS_VERBOSE("Skip component with weight " << new_weight);
0405         continue;
0406       }
0407 
0408       if (gaussian.mean < 1.e-8) {
0409         ACTS_WARNING("Skip component with gaussian " << gaussian.mean << " +- "
0410                                                      << gaussian.var);
0411         continue;
0412       }
0413 
0414       // compute delta p from mixture and update parameters
0415       auto new_pars = old_bound.parameters();
0416 
0417       const auto delta_p = [&]() {
0418         if (state.options.direction == Direction::Forward) {
0419           return p_prev * (gaussian.mean - 1.);
0420         } else {
0421           return p_prev * (1. / gaussian.mean - 1.);
0422         }
0423       }();
0424 
0425       assert(p_prev + delta_p > 0. && "new momentum must be > 0");
0426       new_pars[eBoundQOverP] = old_bound.charge() / (p_prev + delta_p);
0427 
0428       // compute inverse variance of p from mixture and update covariance
0429       auto new_cov = old_bound.covariance().value();
0430 
0431       const auto varInvP = [&]() {
0432         if (state.options.direction == Direction::Forward) {
0433           const auto f = 1. / (p_prev * gaussian.mean);
0434           return f * f * gaussian.var;
0435         } else {
0436           return gaussian.var / (p_prev * p_prev);
0437         }
0438       }();
0439 
0440       new_cov(eBoundQOverP, eBoundQOverP) += varInvP;
0441       assert(std::isfinite(new_cov(eBoundQOverP, eBoundQOverP)) &&
0442              "new cov not finite");
0443 
0444       // Set the remaining things and push to vector
0445       componentCaches.push_back({new_weight, new_pars, new_cov});
0446     }
0447 
0448     return pathXOverX0;
0449   }
0450 
0451   /// Remove components with low weights and renormalize from the component
0452   /// cache
0453   /// TODO This function does not expect normalized components, but this
0454   /// could be redundant work...
0455   void removeLowWeightComponents(std::vector<ComponentCache>& cmps) const {
0456     auto proj = [](auto& cmp) -> double& { return cmp.weight; };
0457 
0458     detail::normalizeWeights(cmps, proj);
0459 
0460     auto new_end = std::remove_if(cmps.begin(), cmps.end(), [&](auto& cmp) {
0461       return proj(cmp) < m_cfg.weightCutoff;
0462     });
0463 
0464     // In case we would remove all components, keep only the largest
0465     if (std::distance(cmps.begin(), new_end) == 0) {
0466       cmps = {*std::max_element(
0467           cmps.begin(), cmps.end(),
0468           [&](auto& a, auto& b) { return proj(a) < proj(b); })};
0469       cmps.front().weight = 1.0;
0470     } else {
0471       cmps.erase(new_end, cmps.end());
0472       detail::normalizeWeights(cmps, proj);
0473     }
0474   }
0475 
0476   /// Function that updates the stepper from the MultiTrajectory
0477   template <typename propagator_state_t, typename stepper_t>
0478   void updateStepper(propagator_state_t& state, const stepper_t& stepper,
0479                      const TemporaryStates& tmpStates) const {
0480     auto cmps = stepper.componentIterable(state.stepping);
0481 
0482     for (auto [idx, cmp] : zip(tmpStates.tips, cmps)) {
0483       // we set ignored components to missed, so we can remove them after
0484       // the loop
0485       if (tmpStates.weights.at(idx) < m_cfg.weightCutoff) {
0486         cmp.status() = Intersection3D::Status::missed;
0487         continue;
0488       }
0489 
0490       auto proxy = tmpStates.traj.getTrackState(idx);
0491 
0492       cmp.pars() =
0493           MultiTrajectoryHelpers::freeFiltered(state.options.geoContext, proxy);
0494       cmp.cov() = proxy.filteredCovariance();
0495       cmp.weight() = tmpStates.weights.at(idx);
0496     }
0497 
0498     stepper.removeMissedComponents(state.stepping);
0499 
0500     // TODO we have two normalization passes here now, this can probably be
0501     // optimized
0502     detail::normalizeWeights(cmps,
0503                              [&](auto cmp) -> double& { return cmp.weight(); });
0504   }
0505 
0506   /// Function that updates the stepper from the ComponentCache
0507   template <typename propagator_state_t, typename stepper_t,
0508             typename navigator_t>
0509   void updateStepper(propagator_state_t& state, const stepper_t& stepper,
0510                      const navigator_t& navigator,
0511                      const std::vector<ComponentCache>& componentCache) const {
0512     const auto& surface = *navigator.currentSurface(state.navigation);
0513 
0514     // Clear components before adding new ones
0515     stepper.clearComponents(state.stepping);
0516 
0517     // Finally loop over components
0518     for (const auto& [weight, pars, cov] : componentCache) {
0519       // Add the component to the stepper
0520       BoundTrackParameters bound(surface.getSharedPtr(), pars, cov,
0521                                  stepper.particleHypothesis(state.stepping));
0522 
0523       auto res = stepper.addComponent(state.stepping, std::move(bound), weight);
0524 
0525       if (!res.ok()) {
0526         ACTS_ERROR("Error adding component to MultiStepper");
0527         continue;
0528       }
0529 
0530       auto& cmp = *res;
0531       auto freeParams = cmp.pars();
0532       cmp.jacToGlobal() = surface.boundToFreeJacobian(
0533           state.geoContext, freeParams.template segment<3>(eFreePos0),
0534           freeParams.template segment<3>(eFreeDir0));
0535       cmp.pathAccumulated() = state.stepping.pathAccumulated;
0536       cmp.jacobian() = Acts::BoundMatrix::Identity();
0537       cmp.derivative() = Acts::FreeVector::Zero();
0538       cmp.jacTransport() = Acts::FreeMatrix::Identity();
0539     }
0540   }
0541 
0542   /// This function performs the kalman update, computes the new posterior
0543   /// weights, renormalizes all components, and does some statistics.
0544   template <typename propagator_state_t, typename stepper_t,
0545             typename navigator_t>
0546   Result<void> kalmanUpdate(propagator_state_t& state, const stepper_t& stepper,
0547                             const navigator_t& navigator, result_type& result,
0548                             TemporaryStates& tmpStates,
0549                             const SourceLink& source_link) const {
0550     const auto& surface = *navigator.currentSurface(state.navigation);
0551 
0552     // Boolean flag, to distinguish measurement and outlier states. This flag
0553     // is only modified by the valid-measurement-branch, so only if there
0554     // isn't any valid measurement state, the flag stays false and the state
0555     // is thus counted as an outlier
0556     bool is_valid_measurement = false;
0557 
0558     auto cmps = stepper.componentIterable(state.stepping);
0559     for (auto cmp : cmps) {
0560       auto singleState = cmp.singleState(state);
0561       const auto& singleStepper = cmp.singleStepper(stepper);
0562 
0563       auto trackStateProxyRes = detail::kalmanHandleMeasurement(
0564           *m_cfg.calibrationContext, singleState, singleStepper,
0565           m_cfg.extensions, surface, source_link, tmpStates.traj,
0566           MultiTrajectoryTraits::kInvalid, false, logger());
0567 
0568       if (!trackStateProxyRes.ok()) {
0569         return trackStateProxyRes.error();
0570       }
0571 
0572       const auto& trackStateProxy = *trackStateProxyRes;
0573 
0574       // If at least one component is no outlier, we consider the whole thing
0575       // as a measurementState
0576       if (trackStateProxy.typeFlags().test(
0577               Acts::TrackStateFlag::MeasurementFlag)) {
0578         is_valid_measurement = true;
0579       }
0580 
0581       tmpStates.tips.push_back(trackStateProxy.index());
0582       tmpStates.weights[tmpStates.tips.back()] = cmp.weight();
0583     }
0584 
0585     computePosteriorWeights(tmpStates.traj, tmpStates.tips, tmpStates.weights);
0586 
0587     detail::normalizeWeights(tmpStates.tips, [&](auto idx) -> double& {
0588       return tmpStates.weights.at(idx);
0589     });
0590 
0591     // Do the statistics
0592     ++result.processedStates;
0593 
0594     // TODO should outlier states also be counted here?
0595     if (is_valid_measurement) {
0596       ++result.measurementStates;
0597     }
0598 
0599     addCombinedState(result, tmpStates, surface);
0600     result.lastMeasurementTip = result.currentTip;
0601 
0602     using FiltProjector =
0603         MultiTrajectoryProjector<StatesType::eFiltered, traj_t>;
0604     FiltProjector proj{tmpStates.traj, tmpStates.weights};
0605 
0606     std::vector<std::tuple<double, BoundVector, BoundMatrix>> v;
0607 
0608     // TODO Check why can zero weights can occur
0609     for (const auto& idx : tmpStates.tips) {
0610       const auto [w, p, c] = proj(idx);
0611       if (w > 0.0) {
0612         v.push_back({w, p, c});
0613       }
0614     }
0615 
0616     normalizeWeights(v, [](auto& c) -> double& { return std::get<double>(c); });
0617 
0618     result.lastMeasurementState = MultiComponentBoundTrackParameters(
0619         surface.getSharedPtr(), std::move(v),
0620         stepper.particleHypothesis(state.stepping));
0621 
0622     // Return success
0623     return Acts::Result<void>::success();
0624   }
0625 
0626   template <typename propagator_state_t, typename stepper_t,
0627             typename navigator_t>
0628   Result<void> noMeasurementUpdate(propagator_state_t& state,
0629                                    const stepper_t& stepper,
0630                                    const navigator_t& navigator,
0631                                    result_type& result,
0632                                    TemporaryStates& tmpStates,
0633                                    bool doCovTransport) const {
0634     const auto& surface = *navigator.currentSurface(state.navigation);
0635 
0636     // Initialize as true, so that any component can flip it. However, all
0637     // components should behave the same
0638     bool is_hole = true;
0639 
0640     auto cmps = stepper.componentIterable(state.stepping);
0641     for (auto cmp : cmps) {
0642       auto singleState = cmp.singleState(state);
0643       const auto& singleStepper = cmp.singleStepper(stepper);
0644 
0645       // There is some redundant checking inside this function, but do this for
0646       // now until we measure this is significant
0647       auto trackStateProxyRes = detail::kalmanHandleNoMeasurement(
0648           singleState, singleStepper, surface, tmpStates.traj,
0649           MultiTrajectoryTraits::kInvalid, doCovTransport, logger());
0650 
0651       if (!trackStateProxyRes.ok()) {
0652         return trackStateProxyRes.error();
0653       }
0654 
0655       const auto& trackStateProxy = *trackStateProxyRes;
0656 
0657       if (!trackStateProxy.typeFlags().test(TrackStateFlag::HoleFlag)) {
0658         is_hole = false;
0659       }
0660 
0661       tmpStates.tips.push_back(trackStateProxy.index());
0662       tmpStates.weights[tmpStates.tips.back()] = cmp.weight();
0663     }
0664 
0665     // These things should only be done once for all components
0666     if (is_hole) {
0667       ++result.measurementHoles;
0668     }
0669 
0670     ++result.processedStates;
0671 
0672     addCombinedState(result, tmpStates, surface);
0673 
0674     return Result<void>::success();
0675   }
0676 
0677   /// Apply the multiple scattering to the state
0678   template <typename propagator_state_t, typename stepper_t,
0679             typename navigator_t>
0680   void applyMultipleScattering(propagator_state_t& state,
0681                                const stepper_t& stepper,
0682                                const navigator_t& navigator,
0683                                const MaterialUpdateStage& updateStage =
0684                                    MaterialUpdateStage::FullUpdate) const {
0685     const auto& surface = *navigator.currentSurface(state.navigation);
0686 
0687     for (auto cmp : stepper.componentIterable(state.stepping)) {
0688       auto singleState = cmp.singleState(state);
0689       const auto& singleStepper = cmp.singleStepper(stepper);
0690 
0691       detail::PointwiseMaterialInteraction interaction(&surface, singleState,
0692                                                        singleStepper);
0693       if (interaction.evaluateMaterialSlab(singleState, navigator,
0694                                            updateStage)) {
0695         // In the Gsf we only need to handle the multiple scattering
0696         interaction.evaluatePointwiseMaterialInteraction(
0697             m_cfg.multipleScattering, false);
0698 
0699         // Screen out material effects info
0700         ACTS_VERBOSE("Material effects on surface: "
0701                      << surface.geometryId()
0702                      << " at update stage: " << updateStage << " are :");
0703         ACTS_VERBOSE("eLoss = "
0704                      << interaction.Eloss << ", "
0705                      << "variancePhi = " << interaction.variancePhi << ", "
0706                      << "varianceTheta = " << interaction.varianceTheta << ", "
0707                      << "varianceQoverP = " << interaction.varianceQoverP);
0708 
0709         // Update the state and stepper with material effects
0710         interaction.updateState(singleState, singleStepper, addNoise);
0711 
0712         assert(singleState.stepping.cov.array().isFinite().all() &&
0713                "covariance not finite after multi scattering");
0714       }
0715     }
0716   }
0717 
0718   void addCombinedState(result_type& result, const TemporaryStates& tmpStates,
0719                         const Surface& surface) const {
0720     using PrtProjector =
0721         MultiTrajectoryProjector<StatesType::ePredicted, traj_t>;
0722     using FltProjector =
0723         MultiTrajectoryProjector<StatesType::eFiltered, traj_t>;
0724 
0725     if (!m_cfg.inReversePass) {
0726       const auto firstCmpProxy =
0727           tmpStates.traj.getTrackState(tmpStates.tips.front());
0728       const auto isMeasurement =
0729           firstCmpProxy.typeFlags().test(MeasurementFlag);
0730 
0731       const auto mask =
0732           isMeasurement
0733               ? TrackStatePropMask::Calibrated | TrackStatePropMask::Predicted |
0734                     TrackStatePropMask::Filtered | TrackStatePropMask::Smoothed
0735               : TrackStatePropMask::Calibrated | TrackStatePropMask::Predicted;
0736 
0737       auto proxy = result.fittedStates->makeTrackState(mask, result.currentTip);
0738       result.currentTip = proxy.index();
0739 
0740       proxy.setReferenceSurface(surface.getSharedPtr());
0741       proxy.copyFrom(firstCmpProxy, mask);
0742 
0743       auto [prtMean, prtCov] =
0744           mergeGaussianMixture(tmpStates.tips, surface, m_cfg.mergeMethod,
0745                                PrtProjector{tmpStates.traj, tmpStates.weights});
0746       proxy.predicted() = prtMean;
0747       proxy.predictedCovariance() = prtCov;
0748 
0749       if (isMeasurement) {
0750         auto [fltMean, fltCov] = mergeGaussianMixture(
0751             tmpStates.tips, surface, m_cfg.mergeMethod,
0752             FltProjector{tmpStates.traj, tmpStates.weights});
0753         proxy.filtered() = fltMean;
0754         proxy.filteredCovariance() = fltCov;
0755         proxy.smoothed() = BoundVector::Constant(-2);
0756         proxy.smoothedCovariance() = BoundSquareMatrix::Constant(-2);
0757       } else {
0758         proxy.shareFrom(TrackStatePropMask::Predicted,
0759                         TrackStatePropMask::Filtered);
0760       }
0761 
0762     } else {
0763       assert((result.currentTip != MultiTrajectoryTraits::kInvalid &&
0764               "tip not valid"));
0765 
0766       result.fittedStates->applyBackwards(
0767           result.currentTip, [&](auto trackState) {
0768             auto fSurface = &trackState.referenceSurface();
0769             if (fSurface == &surface) {
0770               result.surfacesVisitedBwdAgain.push_back(&surface);
0771 
0772               if (trackState.hasSmoothed()) {
0773                 const auto [smtMean, smtCov] = mergeGaussianMixture(
0774                     tmpStates.tips, surface, m_cfg.mergeMethod,
0775                     FltProjector{tmpStates.traj, tmpStates.weights});
0776 
0777                 trackState.smoothed() = smtMean;
0778                 trackState.smoothedCovariance() = smtCov;
0779               }
0780               return false;
0781             }
0782             return true;
0783           });
0784     }
0785   }
0786 
0787   /// Set the relevant options that can be set from the Options struct all in
0788   /// one place
0789   void setOptions(const Acts::GsfOptions<traj_t>& options) {
0790     m_cfg.maxComponents = options.maxComponents;
0791     m_cfg.extensions = options.extensions;
0792     m_cfg.abortOnError = options.abortOnError;
0793     m_cfg.disableAllMaterialHandling = options.disableAllMaterialHandling;
0794     m_cfg.weightCutoff = options.weightCutoff;
0795     m_cfg.mergeMethod = options.componentMergeMethod;
0796     m_cfg.calibrationContext = &options.calibrationContext.get();
0797   }
0798 };
0799 
0800 }  // namespace Acts::detail