Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-19 09:23:37

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2021 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 #include "Acts/EventData/TrackHelpers.hpp"
0012 #include "Acts/EventData/VectorMultiTrajectory.hpp"
0013 #include "Acts/Propagator/DirectNavigator.hpp"
0014 #include "Acts/Propagator/MultiStepperAborters.hpp"
0015 #include "Acts/Propagator/Navigator.hpp"
0016 #include "Acts/Propagator/StandardAborters.hpp"
0017 #include "Acts/TrackFitting/GsfOptions.hpp"
0018 #include "Acts/TrackFitting/detail/GsfActor.hpp"
0019 #include "Acts/Utilities/Logger.hpp"
0020 
0021 namespace Acts {
0022 
0023 namespace detail {
0024 
0025 /// Type trait to identify if a type is a MultiComponentBoundTrackParameters and
0026 /// to inspect its charge representation if not TODO this probably gives an ugly
0027 /// error message if detectCharge does not compile
0028 template <typename T>
0029 struct IsMultiComponentBoundParameters : public std::false_type {};
0030 
0031 template <>
0032 struct IsMultiComponentBoundParameters<MultiComponentBoundTrackParameters>
0033     : public std::true_type {};
0034 
0035 }  // namespace detail
0036 
0037 /// Gaussian Sum Fitter implementation.
0038 /// @tparam propagator_t The propagator type on which the algorithm is built on
0039 /// @tparam bethe_heitler_approx_t The type of the Bethe-Heitler-Approximation
0040 /// @tparam traj_t The MultiTrajectory type (backend)
0041 ///
0042 /// @note This GSF implementation tries to be as compatible to the KalmanFitter
0043 /// as possible. However, strict compatibility is not garantueed.
0044 /// @note Currently there is no possibility to export the states of the
0045 /// individual components from the GSF, the only information returned in the
0046 /// MultiTrajectory are the means of the states. Therefore, also NO dedicated
0047 /// component smoothing is performed as described e.g. by R. Fruewirth.
0048 template <typename propagator_t, typename bethe_heitler_approx_t,
0049           typename traj_t>
0050 struct GaussianSumFitter {
0051   GaussianSumFitter(propagator_t&& propagator, bethe_heitler_approx_t&& bha,
0052                     std::unique_ptr<const Logger> _logger =
0053                         getDefaultLogger("GSF", Logging::INFO))
0054       : m_propagator(std::move(propagator)),
0055         m_betheHeitlerApproximation(std::move(bha)),
0056         m_logger{std::move(_logger)},
0057         m_actorLogger(m_logger->cloneWithSuffix("Actor")) {}
0058 
0059   /// The propagator instance used by the fit function
0060   propagator_t m_propagator;
0061 
0062   /// The fitter holds the instance of the bethe heitler approx
0063   bethe_heitler_approx_t m_betheHeitlerApproximation;
0064 
0065   /// The logger
0066   std::unique_ptr<const Logger> m_logger;
0067   std::unique_ptr<const Logger> m_actorLogger;
0068 
0069   const Logger& logger() const { return *m_logger; }
0070 
0071   /// The navigator type
0072   using GsfNavigator = typename propagator_t::Navigator;
0073 
0074   /// The actor type
0075   using GsfActor = detail::GsfActor<bethe_heitler_approx_t, traj_t>;
0076 
0077   /// This allows to break the propagation by setting the navigationBreak
0078   /// TODO refactor once we can do this more elegantly
0079   struct NavigationBreakAborter {
0080     NavigationBreakAborter() = default;
0081 
0082     template <typename propagator_state_t, typename stepper_t,
0083               typename navigator_t>
0084     bool operator()(propagator_state_t& state, const stepper_t& /*stepper*/,
0085                     const navigator_t& navigator,
0086                     const Logger& /*logger*/) const {
0087       return navigator.navigationBreak(state.navigation);
0088     }
0089   };
0090 
0091   /// @brief The fit function for the Direct navigator
0092   template <typename source_link_it_t, typename start_parameters_t,
0093             typename track_container_t, template <typename> class holder_t>
0094   auto fit(source_link_it_t begin, source_link_it_t end,
0095            const start_parameters_t& sParameters,
0096            const GsfOptions<traj_t>& options,
0097            const std::vector<const Surface*>& sSequence,
0098            TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
0099       const {
0100     // Check if we have the correct navigator
0101     static_assert(
0102         std::is_same_v<DirectNavigator, typename propagator_t::Navigator>);
0103 
0104     // Initialize the forward propagation with the DirectNavigator
0105     auto fwdPropInitializer = [&sSequence, this](const auto& opts) {
0106       using Actors = ActionList<GsfActor, DirectNavigator::Initializer>;
0107       using Aborters = AbortList<NavigationBreakAborter>;
0108 
0109       PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
0110                                                       opts.magFieldContext);
0111 
0112       propOptions.setPlainOptions(opts.propagatorPlainOptions);
0113 
0114       propOptions.actionList.template get<DirectNavigator::Initializer>()
0115           .navSurfaces = sSequence;
0116       propOptions.actionList.template get<GsfActor>()
0117           .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0118 
0119       return propOptions;
0120     };
0121 
0122     // Initialize the backward propagation with the DirectNavigator
0123     auto bwdPropInitializer = [&sSequence, this](const auto& opts) {
0124       using Actors = ActionList<GsfActor, DirectNavigator::Initializer>;
0125       using Aborters = AbortList<>;
0126 
0127       std::vector<const Surface*> backwardSequence(
0128           std::next(sSequence.rbegin()), sSequence.rend());
0129       backwardSequence.push_back(opts.referenceSurface);
0130 
0131       PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
0132                                                       opts.magFieldContext);
0133 
0134       propOptions.setPlainOptions(opts.propagatorPlainOptions);
0135 
0136       propOptions.actionList.template get<DirectNavigator::Initializer>()
0137           .navSurfaces = std::move(backwardSequence);
0138       propOptions.actionList.template get<GsfActor>()
0139           .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0140 
0141       return propOptions;
0142     };
0143 
0144     return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
0145                     bwdPropInitializer, trackContainer);
0146   }
0147 
0148   /// @brief The fit function for the standard navigator
0149   template <typename source_link_it_t, typename start_parameters_t,
0150             typename track_container_t, template <typename> class holder_t>
0151   auto fit(source_link_it_t begin, source_link_it_t end,
0152            const start_parameters_t& sParameters,
0153            const GsfOptions<traj_t>& options,
0154            TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
0155       const {
0156     // Check if we have the correct navigator
0157     static_assert(std::is_same_v<Navigator, typename propagator_t::Navigator>);
0158 
0159     // Initialize the forward propagation with the DirectNavigator
0160     auto fwdPropInitializer = [this](const auto& opts) {
0161       using Actors = ActionList<GsfActor>;
0162       using Aborters = AbortList<EndOfWorldReached, NavigationBreakAborter>;
0163 
0164       PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
0165                                                       opts.magFieldContext);
0166       propOptions.setPlainOptions(opts.propagatorPlainOptions);
0167       propOptions.actionList.template get<GsfActor>()
0168           .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0169 
0170       return propOptions;
0171     };
0172 
0173     // Initialize the backward propagation with the DirectNavigator
0174     auto bwdPropInitializer = [this](const auto& opts) {
0175       using Actors = ActionList<GsfActor>;
0176       using Aborters = AbortList<EndOfWorldReached>;
0177 
0178       PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
0179                                                       opts.magFieldContext);
0180 
0181       propOptions.setPlainOptions(opts.propagatorPlainOptions);
0182 
0183       propOptions.actionList.template get<GsfActor>()
0184           .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0185 
0186       return propOptions;
0187     };
0188 
0189     return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
0190                     bwdPropInitializer, trackContainer);
0191   }
0192 
0193   /// The generic implementation of the fit function.
0194   /// TODO check what this function does with the referenceSurface is e.g. the
0195   /// first measurementSurface
0196   template <typename source_link_it_t, typename start_parameters_t,
0197             typename fwd_prop_initializer_t, typename bwd_prop_initializer_t,
0198             typename track_container_t, template <typename> class holder_t>
0199   Acts::Result<
0200       typename TrackContainer<track_container_t, traj_t, holder_t>::TrackProxy>
0201   fit_impl(source_link_it_t begin, source_link_it_t end,
0202            const start_parameters_t& sParameters,
0203            const GsfOptions<traj_t>& options,
0204            const fwd_prop_initializer_t& fwdPropInitializer,
0205            const bwd_prop_initializer_t& bwdPropInitializer,
0206            TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
0207       const {
0208     // return or abort utility
0209     auto return_error_or_abort = [&](auto error) {
0210       if (options.abortOnError) {
0211         std::abort();
0212       }
0213       return error;
0214     };
0215 
0216     // Define directions based on input propagation direction. This way we can
0217     // refer to 'forward' and 'backward' regardless of the actual direction.
0218     const auto gsfForward = options.propagatorPlainOptions.direction;
0219     const auto gsfBackward = gsfForward.invert();
0220 
0221     // Check if the start parameters are on the start surface
0222     auto intersectionStatusStartSurface =
0223         sParameters.referenceSurface()
0224             .intersect(GeometryContext{},
0225                        sParameters.position(GeometryContext{}),
0226                        sParameters.direction(), BoundaryCheck(true))
0227             .closest()
0228             .status();
0229 
0230     if (intersectionStatusStartSurface != Intersection3D::Status::onSurface) {
0231       ACTS_DEBUG(
0232           "Surface intersection of start parameters WITH bound-check failed");
0233     }
0234 
0235     // To be able to find measurements later, we put them into a map
0236     // We need to copy input SourceLinks anyway, so the map can own them.
0237     ACTS_VERBOSE("Preparing " << std::distance(begin, end)
0238                               << " input measurements");
0239     std::map<GeometryIdentifier, SourceLink> inputMeasurements;
0240     for (auto it = begin; it != end; ++it) {
0241       SourceLink sl = *it;
0242       inputMeasurements.emplace(
0243           options.extensions.surfaceAccessor(sl)->geometryId(), std::move(sl));
0244     }
0245 
0246     ACTS_VERBOSE(
0247         "Gsf: Final measurement map size: " << inputMeasurements.size());
0248 
0249     if (sParameters.covariance() == std::nullopt) {
0250       return GsfError::StartParametersHaveNoCovariance;
0251     }
0252 
0253     /////////////////
0254     // Forward pass
0255     /////////////////
0256     ACTS_VERBOSE("+-----------------------------+");
0257     ACTS_VERBOSE("| Gsf: Do forward propagation |");
0258     ACTS_VERBOSE("+-----------------------------+");
0259 
0260     auto fwdResult = [&]() {
0261       auto fwdPropOptions = fwdPropInitializer(options);
0262 
0263       // Catch the actor and set the measurements
0264       auto& actor = fwdPropOptions.actionList.template get<GsfActor>();
0265       actor.setOptions(options);
0266       actor.m_cfg.inputMeasurements = &inputMeasurements;
0267       actor.m_cfg.numberMeasurements = inputMeasurements.size();
0268       actor.m_cfg.inReversePass = false;
0269       actor.m_cfg.logger = m_actorLogger.get();
0270 
0271       fwdPropOptions.direction = gsfForward;
0272 
0273       // If necessary convert to MultiComponentBoundTrackParameters
0274       using IsMultiParameters =
0275           detail::IsMultiComponentBoundParameters<start_parameters_t>;
0276 
0277       // dirty optional because parameters are not default constructible
0278       std::optional<MultiComponentBoundTrackParameters> params;
0279 
0280       // This allows the initialization with single- and multicomponent start
0281       // parameters
0282       if constexpr (!IsMultiParameters::value) {
0283         params = MultiComponentBoundTrackParameters(
0284             sParameters.referenceSurface().getSharedPtr(),
0285             sParameters.parameters(), *sParameters.covariance(),
0286             sParameters.particleHypothesis());
0287       } else {
0288         params = sParameters;
0289       }
0290 
0291       auto state = m_propagator.makeState(*params, fwdPropOptions);
0292 
0293       auto& r = state.template get<typename GsfActor::result_type>();
0294       r.fittedStates = &trackContainer.trackStateContainer();
0295 
0296       auto propagationResult = m_propagator.propagate(state);
0297 
0298       return m_propagator.makeResult(std::move(state), propagationResult,
0299                                      fwdPropOptions, false);
0300     }();
0301 
0302     if (!fwdResult.ok()) {
0303       return return_error_or_abort(fwdResult.error());
0304     }
0305 
0306     const auto& fwdGsfResult =
0307         fwdResult->template get<typename GsfActor::result_type>();
0308 
0309     if (!fwdGsfResult.result.ok()) {
0310       return return_error_or_abort(fwdGsfResult.result.error());
0311     }
0312 
0313     if (fwdGsfResult.measurementStates == 0) {
0314       return return_error_or_abort(GsfError::NoMeasurementStatesCreatedForward);
0315     }
0316 
0317     ACTS_VERBOSE("Finished forward propagation");
0318     ACTS_VERBOSE("- visited surfaces: " << fwdGsfResult.visitedSurfaces.size());
0319     ACTS_VERBOSE("- processed states: " << fwdGsfResult.processedStates);
0320     ACTS_VERBOSE("- measurement states: " << fwdGsfResult.measurementStates);
0321 
0322     std::size_t nInvalidBetheHeitler = fwdGsfResult.nInvalidBetheHeitler.val();
0323     double maxPathXOverX0 = fwdGsfResult.maxPathXOverX0.val();
0324 
0325     //////////////////
0326     // Backward pass
0327     //////////////////
0328     ACTS_VERBOSE("+------------------------------+");
0329     ACTS_VERBOSE("| Gsf: Do backward propagation |");
0330     ACTS_VERBOSE("+------------------------------+");
0331 
0332     auto bwdResult = [&]() {
0333       auto bwdPropOptions = bwdPropInitializer(options);
0334 
0335       auto& actor = bwdPropOptions.actionList.template get<GsfActor>();
0336       actor.setOptions(options);
0337       actor.m_cfg.inputMeasurements = &inputMeasurements;
0338       actor.m_cfg.inReversePass = true;
0339       actor.m_cfg.logger = m_actorLogger.get();
0340 
0341       bwdPropOptions.direction = gsfBackward;
0342 
0343       const Surface& target = options.referenceSurface
0344                                   ? *options.referenceSurface
0345                                   : sParameters.referenceSurface();
0346 
0347       using PM = TrackStatePropMask;
0348 
0349       const auto& params = *fwdGsfResult.lastMeasurementState;
0350       auto state =
0351           m_propagator.template makeState<MultiComponentBoundTrackParameters,
0352                                           decltype(bwdPropOptions),
0353                                           MultiStepperSurfaceReached>(
0354               params, target, bwdPropOptions);
0355 
0356       assert(
0357           (fwdGsfResult.lastMeasurementTip != MultiTrajectoryTraits::kInvalid &&
0358            "tip is invalid"));
0359 
0360       auto proxy = trackContainer.trackStateContainer().getTrackState(
0361           fwdGsfResult.lastMeasurementTip);
0362       proxy.shareFrom(TrackStatePropMask::Filtered,
0363                       TrackStatePropMask::Smoothed);
0364 
0365       auto& r = state.template get<typename GsfActor::result_type>();
0366       r.fittedStates = &trackContainer.trackStateContainer();
0367       r.currentTip = fwdGsfResult.lastMeasurementTip;
0368       r.visitedSurfaces.push_back(&proxy.referenceSurface());
0369       r.surfacesVisitedBwdAgain.push_back(&proxy.referenceSurface());
0370       r.measurementStates++;
0371       r.processedStates++;
0372 
0373       auto propagationResult = m_propagator.propagate(state);
0374 
0375       return m_propagator.makeResult(std::move(state), propagationResult,
0376                                      target, bwdPropOptions);
0377     }();
0378 
0379     if (!bwdResult.ok()) {
0380       return return_error_or_abort(bwdResult.error());
0381     }
0382 
0383     auto& bwdGsfResult =
0384         bwdResult->template get<typename GsfActor::result_type>();
0385 
0386     if (!bwdGsfResult.result.ok()) {
0387       return return_error_or_abort(bwdGsfResult.result.error());
0388     }
0389 
0390     if (bwdGsfResult.measurementStates == 0) {
0391       return return_error_or_abort(
0392           GsfError::NoMeasurementStatesCreatedBackward);
0393     }
0394 
0395     // For the backward pass we want the counters at in end (= at the
0396     // interaction point) and not at the last measurement surface
0397     bwdGsfResult.nInvalidBetheHeitler.update();
0398     bwdGsfResult.maxPathXOverX0.update();
0399     bwdGsfResult.sumPathXOverX0.update();
0400     nInvalidBetheHeitler += bwdGsfResult.nInvalidBetheHeitler.val();
0401     maxPathXOverX0 =
0402         std::max(maxPathXOverX0, bwdGsfResult.maxPathXOverX0.val());
0403 
0404     if (nInvalidBetheHeitler > 0) {
0405       ACTS_WARNING("Encountered " << nInvalidBetheHeitler
0406                                   << " cases where x/X0 exceeds the range "
0407                                      "of the Bethe-Heitler-Approximation. The "
0408                                      "maximum x/X0 encountered was "
0409                                   << maxPathXOverX0
0410                                   << ". Enable DEBUG output "
0411                                      "for more information.");
0412     }
0413 
0414     ////////////////////////////////////
0415     // Create Kalman Result
0416     ////////////////////////////////////
0417     ACTS_VERBOSE("Gsf - States summary:");
0418     ACTS_VERBOSE("- Fwd measurement states: " << fwdGsfResult.measurementStates
0419                                               << ", holes: "
0420                                               << fwdGsfResult.measurementHoles);
0421     ACTS_VERBOSE("- Bwd measurement states: " << bwdGsfResult.measurementStates
0422                                               << ", holes: "
0423                                               << bwdGsfResult.measurementHoles);
0424 
0425     // TODO should this be warning level? it happens quite often... Investigate!
0426     if (bwdGsfResult.measurementStates != fwdGsfResult.measurementStates) {
0427       ACTS_DEBUG("Fwd and bwd measurement states do not match");
0428     }
0429 
0430     // Go through the states and assign outliers / unset smoothed if surface not
0431     // passed in backward pass
0432     const auto& foundBwd = bwdGsfResult.surfacesVisitedBwdAgain;
0433     std::size_t measurementStatesFinal = 0;
0434 
0435     for (auto state : fwdGsfResult.fittedStates->reverseTrackStateRange(
0436              fwdGsfResult.currentTip)) {
0437       const bool found = std::find(foundBwd.begin(), foundBwd.end(),
0438                                    &state.referenceSurface()) != foundBwd.end();
0439       if (!found && state.typeFlags().test(MeasurementFlag)) {
0440         state.typeFlags().set(OutlierFlag);
0441         state.typeFlags().reset(MeasurementFlag);
0442         state.unset(TrackStatePropMask::Smoothed);
0443       }
0444 
0445       measurementStatesFinal +=
0446           static_cast<std::size_t>(state.typeFlags().test(MeasurementFlag));
0447     }
0448 
0449     if (measurementStatesFinal == 0) {
0450       return return_error_or_abort(GsfError::NoMeasurementStatesCreatedFinal);
0451     }
0452 
0453     auto track = trackContainer.makeTrack();
0454     track.tipIndex() = fwdGsfResult.lastMeasurementTip;
0455 
0456     if (options.referenceSurface) {
0457       const auto& params = *bwdResult->endParameters;
0458 
0459       const auto [finalPars, finalCov] = detail::mergeGaussianMixture(
0460           params.components(), params.referenceSurface(),
0461           options.componentMergeMethod, [](auto& t) {
0462             return std::tie(std::get<0>(t), std::get<1>(t), *std::get<2>(t));
0463           });
0464 
0465       track.parameters() = finalPars;
0466       track.covariance() = finalCov;
0467 
0468       track.setReferenceSurface(params.referenceSurface().getSharedPtr());
0469 
0470       if (trackContainer.hasColumn(
0471               hashString(GsfConstants::kFinalMultiComponentStateColumn))) {
0472         ACTS_DEBUG("Add final multi-component state to track");
0473         track.template component<GsfConstants::FinalMultiComponentState>(
0474             GsfConstants::kFinalMultiComponentStateColumn) = std::move(params);
0475       }
0476     }
0477 
0478     if (trackContainer.hasColumn(
0479             hashString(GsfConstants::kFwdMaxMaterialXOverX0))) {
0480       track.template component<double>(GsfConstants::kFwdMaxMaterialXOverX0) =
0481           fwdGsfResult.maxPathXOverX0.val();
0482     }
0483     if (trackContainer.hasColumn(
0484             hashString(GsfConstants::kFwdSumMaterialXOverX0))) {
0485       track.template component<double>(GsfConstants::kFwdSumMaterialXOverX0) =
0486           fwdGsfResult.sumPathXOverX0.val();
0487     }
0488 
0489     calculateTrackQuantities(track);
0490 
0491     return track;
0492   }
0493 };
0494 
0495 }  // namespace Acts