Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-14 08:25:23

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2021 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 #include "Acts/EventData/VectorMultiTrajectory.hpp"
0012 #include "Acts/Propagator/DirectNavigator.hpp"
0013 #include "Acts/Propagator/MultiStepperAborters.hpp"
0014 #include "Acts/Propagator/Navigator.hpp"
0015 #include "Acts/Propagator/StandardAborters.hpp"
0016 #include "Acts/Surfaces/BoundaryTolerance.hpp"
0017 #include "Acts/TrackFitting/GsfOptions.hpp"
0018 #include "Acts/TrackFitting/detail/GsfActor.hpp"
0019 #include "Acts/Utilities/Logger.hpp"
0020 #include "Acts/Utilities/TrackHelpers.hpp"
0021 
0022 namespace Acts {
0023 
0024 namespace detail {
0025 
0026 /// Type trait to identify if a type is a MultiComponentBoundTrackParameters and
0027 /// to inspect its charge representation if not TODO this probably gives an ugly
0028 /// error message if detectCharge does not compile
0029 template <typename T>
0030 struct IsMultiComponentBoundParameters : public std::false_type {};
0031 
0032 template <>
0033 struct IsMultiComponentBoundParameters<MultiComponentBoundTrackParameters>
0034     : public std::true_type {};
0035 
0036 }  // namespace detail
0037 
0038 /// Gaussian Sum Fitter implementation.
0039 /// @tparam propagator_t The propagator type on which the algorithm is built on
0040 /// @tparam bethe_heitler_approx_t The type of the Bethe-Heitler-Approximation
0041 /// @tparam traj_t The MultiTrajectory type (backend)
0042 ///
0043 /// @note This GSF implementation tries to be as compatible to the KalmanFitter
0044 /// as possible. However, strict compatibility is not garantueed.
0045 /// @note Currently there is no possibility to export the states of the
0046 /// individual components from the GSF, the only information returned in the
0047 /// MultiTrajectory are the means of the states. Therefore, also NO dedicated
0048 /// component smoothing is performed as described e.g. by R. Fruewirth.
0049 template <typename propagator_t, typename bethe_heitler_approx_t,
0050           typename traj_t>
0051 struct GaussianSumFitter {
0052   GaussianSumFitter(propagator_t&& propagator, bethe_heitler_approx_t&& bha,
0053                     std::unique_ptr<const Logger> _logger =
0054                         getDefaultLogger("GSF", Logging::INFO))
0055       : m_propagator(std::move(propagator)),
0056         m_betheHeitlerApproximation(std::move(bha)),
0057         m_logger{std::move(_logger)},
0058         m_actorLogger(m_logger->cloneWithSuffix("Actor")) {}
0059 
0060   /// The propagator instance used by the fit function
0061   propagator_t m_propagator;
0062 
0063   /// The fitter holds the instance of the bethe heitler approx
0064   bethe_heitler_approx_t m_betheHeitlerApproximation;
0065 
0066   /// The logger
0067   std::unique_ptr<const Logger> m_logger;
0068   std::unique_ptr<const Logger> m_actorLogger;
0069 
0070   const Logger& logger() const { return *m_logger; }
0071 
0072   /// The navigator type
0073   using GsfNavigator = typename propagator_t::Navigator;
0074 
0075   /// The actor type
0076   using GsfActor = detail::GsfActor<bethe_heitler_approx_t, traj_t>;
0077 
0078   /// This allows to break the propagation by setting the navigationBreak
0079   /// TODO refactor once we can do this more elegantly
0080   struct NavigationBreakAborter {
0081     NavigationBreakAborter() = default;
0082 
0083     template <typename propagator_state_t, typename stepper_t,
0084               typename navigator_t>
0085     bool operator()(propagator_state_t& state, const stepper_t& /*stepper*/,
0086                     const navigator_t& navigator,
0087                     const Logger& /*logger*/) const {
0088       return navigator.navigationBreak(state.navigation);
0089     }
0090   };
0091 
0092   /// @brief The fit function for the Direct navigator
0093   template <typename source_link_it_t, typename start_parameters_t,
0094             typename track_container_t, template <typename> class holder_t>
0095   auto fit(source_link_it_t begin, source_link_it_t end,
0096            const start_parameters_t& sParameters,
0097            const GsfOptions<traj_t>& options,
0098            const std::vector<const Surface*>& sSequence,
0099            TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
0100       const {
0101     // Check if we have the correct navigator
0102     static_assert(
0103         std::is_same_v<DirectNavigator, typename propagator_t::Navigator>);
0104 
0105     // Initialize the forward propagation with the DirectNavigator
0106     auto fwdPropInitializer = [&sSequence, this](const auto& opts) {
0107       using Actors = ActionList<GsfActor>;
0108       using Aborters = AbortList<NavigationBreakAborter>;
0109       using PropagatorOptions =
0110           typename propagator_t::template Options<Actors, Aborters>;
0111 
0112       PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0113 
0114       propOptions.setPlainOptions(opts.propagatorPlainOptions);
0115 
0116       propOptions.navigation.surfaces = sSequence;
0117       propOptions.actionList.template get<GsfActor>()
0118           .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0119 
0120       return propOptions;
0121     };
0122 
0123     // Initialize the backward propagation with the DirectNavigator
0124     auto bwdPropInitializer = [&sSequence, this](const auto& opts) {
0125       using Actors = ActionList<GsfActor>;
0126       using Aborters = AbortList<>;
0127       using PropagatorOptions =
0128           typename propagator_t::template Options<Actors, Aborters>;
0129 
0130       std::vector<const Surface*> backwardSequence(
0131           std::next(sSequence.rbegin()), sSequence.rend());
0132       backwardSequence.push_back(opts.referenceSurface);
0133 
0134       PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0135 
0136       propOptions.setPlainOptions(opts.propagatorPlainOptions);
0137 
0138       propOptions.navigation.surfaces = backwardSequence;
0139       propOptions.actionList.template get<GsfActor>()
0140           .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0141 
0142       return propOptions;
0143     };
0144 
0145     return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
0146                     bwdPropInitializer, trackContainer);
0147   }
0148 
0149   /// @brief The fit function for the standard navigator
0150   template <typename source_link_it_t, typename start_parameters_t,
0151             typename track_container_t, template <typename> class holder_t>
0152   auto fit(source_link_it_t begin, source_link_it_t end,
0153            const start_parameters_t& sParameters,
0154            const GsfOptions<traj_t>& options,
0155            TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
0156       const {
0157     // Check if we have the correct navigator
0158     static_assert(std::is_same_v<Navigator, typename propagator_t::Navigator>);
0159 
0160     // Initialize the forward propagation with the DirectNavigator
0161     auto fwdPropInitializer = [this](const auto& opts) {
0162       using Actors = ActionList<GsfActor>;
0163       using Aborters = AbortList<EndOfWorldReached, NavigationBreakAborter>;
0164       using PropagatorOptions =
0165           typename propagator_t::template Options<Actors, Aborters>;
0166 
0167       PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0168 
0169       propOptions.setPlainOptions(opts.propagatorPlainOptions);
0170 
0171       propOptions.actionList.template get<GsfActor>()
0172           .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0173 
0174       return propOptions;
0175     };
0176 
0177     // Initialize the backward propagation with the DirectNavigator
0178     auto bwdPropInitializer = [this](const auto& opts) {
0179       using Actors = ActionList<GsfActor>;
0180       using Aborters = AbortList<EndOfWorldReached>;
0181       using PropagatorOptions =
0182           typename propagator_t::template Options<Actors, Aborters>;
0183 
0184       PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0185 
0186       propOptions.setPlainOptions(opts.propagatorPlainOptions);
0187 
0188       propOptions.actionList.template get<GsfActor>()
0189           .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0190 
0191       return propOptions;
0192     };
0193 
0194     return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
0195                     bwdPropInitializer, trackContainer);
0196   }
0197 
0198   /// The generic implementation of the fit function.
0199   /// TODO check what this function does with the referenceSurface is e.g. the
0200   /// first measurementSurface
0201   template <typename source_link_it_t, typename start_parameters_t,
0202             typename fwd_prop_initializer_t, typename bwd_prop_initializer_t,
0203             typename track_container_t, template <typename> class holder_t>
0204   Acts::Result<
0205       typename TrackContainer<track_container_t, traj_t, holder_t>::TrackProxy>
0206   fit_impl(source_link_it_t begin, source_link_it_t end,
0207            const start_parameters_t& sParameters,
0208            const GsfOptions<traj_t>& options,
0209            const fwd_prop_initializer_t& fwdPropInitializer,
0210            const bwd_prop_initializer_t& bwdPropInitializer,
0211            TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
0212       const {
0213     // return or abort utility
0214     auto return_error_or_abort = [&](auto error) {
0215       if (options.abortOnError) {
0216         std::abort();
0217       }
0218       return error;
0219     };
0220 
0221     // Define directions based on input propagation direction. This way we can
0222     // refer to 'forward' and 'backward' regardless of the actual direction.
0223     const auto gsfForward = options.propagatorPlainOptions.direction;
0224     const auto gsfBackward = gsfForward.invert();
0225 
0226     // Check if the start parameters are on the start surface
0227     auto intersectionStatusStartSurface =
0228         sParameters.referenceSurface()
0229             .intersect(GeometryContext{},
0230                        sParameters.position(GeometryContext{}),
0231                        sParameters.direction(), BoundaryTolerance::None())
0232             .closest()
0233             .status();
0234 
0235     if (intersectionStatusStartSurface != Intersection3D::Status::onSurface) {
0236       ACTS_DEBUG(
0237           "Surface intersection of start parameters WITH bound-check failed");
0238     }
0239 
0240     // To be able to find measurements later, we put them into a map
0241     // We need to copy input SourceLinks anyway, so the map can own them.
0242     ACTS_VERBOSE("Preparing " << std::distance(begin, end)
0243                               << " input measurements");
0244     std::map<GeometryIdentifier, SourceLink> inputMeasurements;
0245     for (auto it = begin; it != end; ++it) {
0246       SourceLink sl = *it;
0247       inputMeasurements.emplace(
0248           options.extensions.surfaceAccessor(sl)->geometryId(), std::move(sl));
0249     }
0250 
0251     ACTS_VERBOSE(
0252         "Gsf: Final measurement map size: " << inputMeasurements.size());
0253 
0254     if (sParameters.covariance() == std::nullopt) {
0255       return GsfError::StartParametersHaveNoCovariance;
0256     }
0257 
0258     /////////////////
0259     // Forward pass
0260     /////////////////
0261     ACTS_VERBOSE("+-----------------------------+");
0262     ACTS_VERBOSE("| Gsf: Do forward propagation |");
0263     ACTS_VERBOSE("+-----------------------------+");
0264 
0265     auto fwdResult = [&]() {
0266       auto fwdPropOptions = fwdPropInitializer(options);
0267 
0268       // Catch the actor and set the measurements
0269       auto& actor = fwdPropOptions.actionList.template get<GsfActor>();
0270       actor.setOptions(options);
0271       actor.m_cfg.inputMeasurements = &inputMeasurements;
0272       actor.m_cfg.numberMeasurements = inputMeasurements.size();
0273       actor.m_cfg.inReversePass = false;
0274       actor.m_cfg.logger = m_actorLogger.get();
0275 
0276       fwdPropOptions.direction = gsfForward;
0277 
0278       // If necessary convert to MultiComponentBoundTrackParameters
0279       using IsMultiParameters =
0280           detail::IsMultiComponentBoundParameters<start_parameters_t>;
0281 
0282       // dirty optional because parameters are not default constructible
0283       std::optional<MultiComponentBoundTrackParameters> params;
0284 
0285       // This allows the initialization with single- and multicomponent start
0286       // parameters
0287       if constexpr (!IsMultiParameters::value) {
0288         params = MultiComponentBoundTrackParameters(
0289             sParameters.referenceSurface().getSharedPtr(),
0290             sParameters.parameters(), *sParameters.covariance(),
0291             sParameters.particleHypothesis());
0292       } else {
0293         params = sParameters;
0294       }
0295 
0296       auto state = m_propagator.makeState(*params, fwdPropOptions);
0297 
0298       auto& r = state.template get<typename GsfActor::result_type>();
0299       r.fittedStates = &trackContainer.trackStateContainer();
0300 
0301       auto propagationResult = m_propagator.propagate(state);
0302 
0303       return m_propagator.makeResult(std::move(state), propagationResult,
0304                                      fwdPropOptions, false);
0305     }();
0306 
0307     if (!fwdResult.ok()) {
0308       return return_error_or_abort(fwdResult.error());
0309     }
0310 
0311     const auto& fwdGsfResult =
0312         fwdResult->template get<typename GsfActor::result_type>();
0313 
0314     if (!fwdGsfResult.result.ok()) {
0315       return return_error_or_abort(fwdGsfResult.result.error());
0316     }
0317 
0318     if (fwdGsfResult.measurementStates == 0) {
0319       return return_error_or_abort(GsfError::NoMeasurementStatesCreatedForward);
0320     }
0321 
0322     ACTS_VERBOSE("Finished forward propagation");
0323     ACTS_VERBOSE("- visited surfaces: " << fwdGsfResult.visitedSurfaces.size());
0324     ACTS_VERBOSE("- processed states: " << fwdGsfResult.processedStates);
0325     ACTS_VERBOSE("- measurement states: " << fwdGsfResult.measurementStates);
0326 
0327     std::size_t nInvalidBetheHeitler = fwdGsfResult.nInvalidBetheHeitler.val();
0328     double maxPathXOverX0 = fwdGsfResult.maxPathXOverX0.val();
0329 
0330     //////////////////
0331     // Backward pass
0332     //////////////////
0333     ACTS_VERBOSE("+------------------------------+");
0334     ACTS_VERBOSE("| Gsf: Do backward propagation |");
0335     ACTS_VERBOSE("+------------------------------+");
0336 
0337     auto bwdResult = [&]() {
0338       auto bwdPropOptions = bwdPropInitializer(options);
0339 
0340       auto& actor = bwdPropOptions.actionList.template get<GsfActor>();
0341       actor.setOptions(options);
0342       actor.m_cfg.inputMeasurements = &inputMeasurements;
0343       actor.m_cfg.inReversePass = true;
0344       actor.m_cfg.logger = m_actorLogger.get();
0345 
0346       bwdPropOptions.direction = gsfBackward;
0347 
0348       const Surface& target = options.referenceSurface
0349                                   ? *options.referenceSurface
0350                                   : sParameters.referenceSurface();
0351 
0352       const auto& params = *fwdGsfResult.lastMeasurementState;
0353       auto state =
0354           m_propagator.template makeState<MultiComponentBoundTrackParameters,
0355                                           decltype(bwdPropOptions),
0356                                           MultiStepperSurfaceReached>(
0357               params, target, bwdPropOptions);
0358 
0359       assert(
0360           (fwdGsfResult.lastMeasurementTip != MultiTrajectoryTraits::kInvalid &&
0361            "tip is invalid"));
0362 
0363       auto proxy = trackContainer.trackStateContainer().getTrackState(
0364           fwdGsfResult.lastMeasurementTip);
0365       proxy.shareFrom(TrackStatePropMask::Filtered,
0366                       TrackStatePropMask::Smoothed);
0367 
0368       auto& r = state.template get<typename GsfActor::result_type>();
0369       r.fittedStates = &trackContainer.trackStateContainer();
0370       r.currentTip = fwdGsfResult.lastMeasurementTip;
0371       r.visitedSurfaces.push_back(&proxy.referenceSurface());
0372       r.surfacesVisitedBwdAgain.push_back(&proxy.referenceSurface());
0373       r.measurementStates++;
0374       r.processedStates++;
0375 
0376       auto propagationResult = m_propagator.propagate(state);
0377 
0378       return m_propagator.makeResult(std::move(state), propagationResult,
0379                                      target, bwdPropOptions);
0380     }();
0381 
0382     if (!bwdResult.ok()) {
0383       return return_error_or_abort(bwdResult.error());
0384     }
0385 
0386     auto& bwdGsfResult =
0387         bwdResult->template get<typename GsfActor::result_type>();
0388 
0389     if (!bwdGsfResult.result.ok()) {
0390       return return_error_or_abort(bwdGsfResult.result.error());
0391     }
0392 
0393     if (bwdGsfResult.measurementStates == 0) {
0394       return return_error_or_abort(
0395           GsfError::NoMeasurementStatesCreatedBackward);
0396     }
0397 
0398     // For the backward pass we want the counters at in end (= at the
0399     // interaction point) and not at the last measurement surface
0400     bwdGsfResult.nInvalidBetheHeitler.update();
0401     bwdGsfResult.maxPathXOverX0.update();
0402     bwdGsfResult.sumPathXOverX0.update();
0403     nInvalidBetheHeitler += bwdGsfResult.nInvalidBetheHeitler.val();
0404     maxPathXOverX0 =
0405         std::max(maxPathXOverX0, bwdGsfResult.maxPathXOverX0.val());
0406 
0407     if (nInvalidBetheHeitler > 0) {
0408       ACTS_WARNING("Encountered " << nInvalidBetheHeitler
0409                                   << " cases where x/X0 exceeds the range "
0410                                      "of the Bethe-Heitler-Approximation. The "
0411                                      "maximum x/X0 encountered was "
0412                                   << maxPathXOverX0
0413                                   << ". Enable DEBUG output "
0414                                      "for more information.");
0415     }
0416 
0417     ////////////////////////////////////
0418     // Create Kalman Result
0419     ////////////////////////////////////
0420     ACTS_VERBOSE("Gsf - States summary:");
0421     ACTS_VERBOSE("- Fwd measurement states: " << fwdGsfResult.measurementStates
0422                                               << ", holes: "
0423                                               << fwdGsfResult.measurementHoles);
0424     ACTS_VERBOSE("- Bwd measurement states: " << bwdGsfResult.measurementStates
0425                                               << ", holes: "
0426                                               << bwdGsfResult.measurementHoles);
0427 
0428     // TODO should this be warning level? it happens quite often... Investigate!
0429     if (bwdGsfResult.measurementStates != fwdGsfResult.measurementStates) {
0430       ACTS_DEBUG("Fwd and bwd measurement states do not match");
0431     }
0432 
0433     // Go through the states and assign outliers / unset smoothed if surface not
0434     // passed in backward pass
0435     const auto& foundBwd = bwdGsfResult.surfacesVisitedBwdAgain;
0436     std::size_t measurementStatesFinal = 0;
0437 
0438     for (auto state : fwdGsfResult.fittedStates->reverseTrackStateRange(
0439              fwdGsfResult.currentTip)) {
0440       const bool found = std::find(foundBwd.begin(), foundBwd.end(),
0441                                    &state.referenceSurface()) != foundBwd.end();
0442       if (!found && state.typeFlags().test(MeasurementFlag)) {
0443         state.typeFlags().set(OutlierFlag);
0444         state.typeFlags().reset(MeasurementFlag);
0445         state.unset(TrackStatePropMask::Smoothed);
0446       }
0447 
0448       measurementStatesFinal +=
0449           static_cast<std::size_t>(state.typeFlags().test(MeasurementFlag));
0450     }
0451 
0452     if (measurementStatesFinal == 0) {
0453       return return_error_or_abort(GsfError::NoMeasurementStatesCreatedFinal);
0454     }
0455 
0456     auto track = trackContainer.makeTrack();
0457     track.tipIndex() = fwdGsfResult.lastMeasurementTip;
0458 
0459     if (options.referenceSurface) {
0460       const auto& params = *bwdResult->endParameters;
0461 
0462       const auto [finalPars, finalCov] = detail::mergeGaussianMixture(
0463           params.components(), params.referenceSurface(),
0464           options.componentMergeMethod, [](auto& t) {
0465             return std::tie(std::get<0>(t), std::get<1>(t), *std::get<2>(t));
0466           });
0467 
0468       track.parameters() = finalPars;
0469       track.covariance() = finalCov;
0470 
0471       track.setReferenceSurface(params.referenceSurface().getSharedPtr());
0472 
0473       if (trackContainer.hasColumn(
0474               hashString(GsfConstants::kFinalMultiComponentStateColumn))) {
0475         ACTS_DEBUG("Add final multi-component state to track");
0476         track.template component<GsfConstants::FinalMultiComponentState>(
0477             GsfConstants::kFinalMultiComponentStateColumn) = std::move(params);
0478       }
0479     }
0480 
0481     if (trackContainer.hasColumn(
0482             hashString(GsfConstants::kFwdMaxMaterialXOverX0))) {
0483       track.template component<double>(GsfConstants::kFwdMaxMaterialXOverX0) =
0484           fwdGsfResult.maxPathXOverX0.val();
0485     }
0486     if (trackContainer.hasColumn(
0487             hashString(GsfConstants::kFwdSumMaterialXOverX0))) {
0488       track.template component<double>(GsfConstants::kFwdSumMaterialXOverX0) =
0489           fwdGsfResult.sumPathXOverX0.val();
0490     }
0491 
0492     calculateTrackQuantities(track);
0493 
0494     return track;
0495   }
0496 };
0497 
0498 }  // namespace Acts