Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:11:06

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 #include "Acts/EventData/VectorMultiTrajectory.hpp"
0012 #include "Acts/Propagator/DirectNavigator.hpp"
0013 #include "Acts/Propagator/MultiStepperAborters.hpp"
0014 #include "Acts/Propagator/Navigator.hpp"
0015 #include "Acts/Propagator/StandardAborters.hpp"
0016 #include "Acts/Surfaces/BoundaryTolerance.hpp"
0017 #include "Acts/TrackFitting/GsfError.hpp"
0018 #include "Acts/TrackFitting/GsfOptions.hpp"
0019 #include "Acts/TrackFitting/detail/GsfActor.hpp"
0020 #include "Acts/Utilities/Helpers.hpp"
0021 #include "Acts/Utilities/Logger.hpp"
0022 #include "Acts/Utilities/TrackHelpers.hpp"
0023 
0024 namespace Acts {
0025 
0026 namespace detail {
0027 
0028 /// Type trait to identify if a type is a MultiComponentBoundTrackParameters and
0029 /// to inspect its charge representation if not TODO this probably gives an ugly
0030 /// error message if detectCharge does not compile
0031 template <typename T>
0032 struct IsMultiComponentBoundParameters : public std::false_type {};
0033 
0034 template <>
0035 struct IsMultiComponentBoundParameters<MultiComponentBoundTrackParameters>
0036     : public std::true_type {};
0037 
0038 }  // namespace detail
0039 
0040 /// Gaussian Sum Fitter implementation.
0041 /// @tparam propagator_t The propagator type on which the algorithm is built on
0042 /// @tparam bethe_heitler_approx_t The type of the Bethe-Heitler-Approximation
0043 /// @tparam traj_t The MultiTrajectory type (backend)
0044 ///
0045 /// @note This GSF implementation tries to be as compatible to the KalmanFitter
0046 /// as possible. However, strict compatibility is not garantueed.
0047 /// @note Currently there is no possibility to export the states of the
0048 /// individual components from the GSF, the only information returned in the
0049 /// MultiTrajectory are the means of the states. Therefore, also NO dedicated
0050 /// component smoothing is performed as described e.g. by R. Fruewirth.
0051 template <typename propagator_t, typename bethe_heitler_approx_t,
0052           typename traj_t>
0053 struct GaussianSumFitter {
0054   GaussianSumFitter(propagator_t&& propagator, bethe_heitler_approx_t&& bha,
0055                     std::unique_ptr<const Logger> _logger =
0056                         getDefaultLogger("GSF", Logging::INFO))
0057       : m_propagator(std::move(propagator)),
0058         m_betheHeitlerApproximation(std::move(bha)),
0059         m_logger{std::move(_logger)},
0060         m_actorLogger(m_logger->cloneWithSuffix("Actor")) {}
0061 
0062   /// The propagator instance used by the fit function
0063   propagator_t m_propagator;
0064 
0065   /// The fitter holds the instance of the bethe heitler approx
0066   bethe_heitler_approx_t m_betheHeitlerApproximation;
0067 
0068   /// The logger
0069   std::unique_ptr<const Logger> m_logger;
0070   std::unique_ptr<const Logger> m_actorLogger;
0071 
0072   const Logger& logger() const { return *m_logger; }
0073 
0074   /// The navigator type
0075   using GsfNavigator = typename propagator_t::Navigator;
0076 
0077   /// The actor type
0078   using GsfActor = detail::GsfActor<bethe_heitler_approx_t, traj_t>;
0079 
0080   /// @brief The fit function for the Direct navigator
0081   template <typename source_link_it_t, typename start_parameters_t,
0082             TrackContainerFrontend track_container_t>
0083   auto fit(source_link_it_t begin, source_link_it_t end,
0084            const start_parameters_t& sParameters,
0085            const GsfOptions<traj_t>& options,
0086            const std::vector<const Surface*>& sSequence,
0087            track_container_t& trackContainer) const {
0088     // Check if we have the correct navigator
0089     static_assert(
0090         std::is_same_v<DirectNavigator, typename propagator_t::Navigator>);
0091 
0092     // Initialize the forward propagation with the DirectNavigator
0093     auto fwdPropInitializer = [&sSequence, this](const auto& opts) {
0094       using Actors = ActorList<GsfActor>;
0095       using PropagatorOptions = typename propagator_t::template Options<Actors>;
0096 
0097       PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0098 
0099       propOptions.setPlainOptions(opts.propagatorPlainOptions);
0100 
0101       propOptions.navigation.surfaces = sSequence;
0102       propOptions.actorList.template get<GsfActor>()
0103           .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0104 
0105       return propOptions;
0106     };
0107 
0108     // Initialize the backward propagation with the DirectNavigator
0109     auto bwdPropInitializer = [&sSequence, this](const auto& opts) {
0110       using Actors = ActorList<GsfActor>;
0111       using PropagatorOptions = typename propagator_t::template Options<Actors>;
0112 
0113       PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0114 
0115       propOptions.setPlainOptions(opts.propagatorPlainOptions);
0116 
0117       propOptions.navigation.surfaces = sSequence;
0118       propOptions.actorList.template get<GsfActor>()
0119           .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0120 
0121       return propOptions;
0122     };
0123 
0124     return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
0125                     bwdPropInitializer, trackContainer);
0126   }
0127 
0128   /// @brief The fit function for the standard navigator
0129   template <typename source_link_it_t, typename start_parameters_t,
0130             TrackContainerFrontend track_container_t>
0131   auto fit(source_link_it_t begin, source_link_it_t end,
0132            const start_parameters_t& sParameters,
0133            const GsfOptions<traj_t>& options,
0134            track_container_t& trackContainer) const {
0135     // Check if we have the correct navigator
0136     static_assert(std::is_same_v<Navigator, typename propagator_t::Navigator>);
0137 
0138     // Initialize the forward propagation with the DirectNavigator
0139     auto fwdPropInitializer = [this](const auto& opts) {
0140       using Actors = ActorList<GsfActor, EndOfWorldReached>;
0141       using PropagatorOptions = typename propagator_t::template Options<Actors>;
0142 
0143       PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0144 
0145       propOptions.setPlainOptions(opts.propagatorPlainOptions);
0146 
0147       propOptions.actorList.template get<GsfActor>()
0148           .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0149 
0150       return propOptions;
0151     };
0152 
0153     // Initialize the backward propagation with the DirectNavigator
0154     auto bwdPropInitializer = [this](const auto& opts) {
0155       using Actors = ActorList<GsfActor, EndOfWorldReached>;
0156       using PropagatorOptions = typename propagator_t::template Options<Actors>;
0157 
0158       PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0159 
0160       propOptions.setPlainOptions(opts.propagatorPlainOptions);
0161 
0162       propOptions.actorList.template get<GsfActor>()
0163           .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0164 
0165       return propOptions;
0166     };
0167 
0168     return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
0169                     bwdPropInitializer, trackContainer);
0170   }
0171 
0172   /// The generic implementation of the fit function.
0173   /// TODO check what this function does with the referenceSurface is e.g. the
0174   /// first measurementSurface
0175   template <typename source_link_it_t, typename start_parameters_t,
0176             typename fwd_prop_initializer_t, typename bwd_prop_initializer_t,
0177             TrackContainerFrontend track_container_t>
0178   Acts::Result<typename track_container_t::TrackProxy> fit_impl(
0179       source_link_it_t begin, source_link_it_t end,
0180       const start_parameters_t& sParameters, const GsfOptions<traj_t>& options,
0181       const fwd_prop_initializer_t& fwdPropInitializer,
0182       const bwd_prop_initializer_t& bwdPropInitializer,
0183       track_container_t& trackContainer) const {
0184     // return or abort utility
0185     auto return_error_or_abort = [&](auto error) {
0186       if (options.abortOnError) {
0187         std::abort();
0188       }
0189       return error;
0190     };
0191 
0192     // Define directions based on input propagation direction. This way we can
0193     // refer to 'forward' and 'backward' regardless of the actual direction.
0194     const auto gsfForward = options.propagatorPlainOptions.direction;
0195     const auto gsfBackward = gsfForward.invert();
0196 
0197     // Check if the start parameters are on the start surface
0198     auto intersectionStatusStartSurface =
0199         sParameters.referenceSurface()
0200             .intersect(GeometryContext{},
0201                        sParameters.position(GeometryContext{}),
0202                        sParameters.direction(), BoundaryTolerance::None())
0203             .closest()
0204             .status();
0205 
0206     if (intersectionStatusStartSurface != IntersectionStatus::onSurface) {
0207       ACTS_DEBUG(
0208           "Surface intersection of start parameters WITH bound-check failed");
0209     }
0210 
0211     // To be able to find measurements later, we put them into a map
0212     // We need to copy input SourceLinks anyway, so the map can own them.
0213     ACTS_VERBOSE("Preparing " << std::distance(begin, end)
0214                               << " input measurements");
0215     std::map<GeometryIdentifier, SourceLink> inputMeasurements;
0216     for (auto it = begin; it != end; ++it) {
0217       SourceLink sl = *it;
0218       inputMeasurements.emplace(
0219           options.extensions.surfaceAccessor(sl)->geometryId(), std::move(sl));
0220     }
0221 
0222     ACTS_VERBOSE(
0223         "Gsf: Final measurement map size: " << inputMeasurements.size());
0224 
0225     if (sParameters.covariance() == std::nullopt) {
0226       return GsfError::StartParametersHaveNoCovariance;
0227     }
0228 
0229     /////////////////
0230     // Forward pass
0231     /////////////////
0232     ACTS_VERBOSE("+-----------------------------+");
0233     ACTS_VERBOSE("| Gsf: Do forward propagation |");
0234     ACTS_VERBOSE("+-----------------------------+");
0235 
0236     auto fwdResult = [&]() {
0237       auto fwdPropOptions = fwdPropInitializer(options);
0238 
0239       // Catch the actor and set the measurements
0240       auto& actor = fwdPropOptions.actorList.template get<GsfActor>();
0241       actor.setOptions(options);
0242       actor.m_cfg.inputMeasurements = &inputMeasurements;
0243       actor.m_cfg.numberMeasurements = inputMeasurements.size();
0244       actor.m_cfg.inReversePass = false;
0245       actor.m_cfg.logger = m_actorLogger.get();
0246 
0247       fwdPropOptions.direction = gsfForward;
0248 
0249       // If necessary convert to MultiComponentBoundTrackParameters
0250       using IsMultiParameters =
0251           detail::IsMultiComponentBoundParameters<start_parameters_t>;
0252 
0253       // dirty optional because parameters are not default constructible
0254       std::optional<MultiComponentBoundTrackParameters> params;
0255 
0256       // This allows the initialization with single- and multicomponent start
0257       // parameters
0258       if constexpr (!IsMultiParameters::value) {
0259         params = MultiComponentBoundTrackParameters(
0260             sParameters.referenceSurface().getSharedPtr(),
0261             sParameters.parameters(), *sParameters.covariance(),
0262             sParameters.particleHypothesis());
0263       } else {
0264         params = sParameters;
0265       }
0266 
0267       auto state = m_propagator.makeState(*params, fwdPropOptions);
0268 
0269       auto& r = state.template get<typename GsfActor::result_type>();
0270       r.fittedStates = &trackContainer.trackStateContainer();
0271 
0272       auto propagationResult = m_propagator.propagate(state);
0273 
0274       return m_propagator.makeResult(std::move(state), propagationResult,
0275                                      fwdPropOptions, false);
0276     }();
0277 
0278     if (!fwdResult.ok()) {
0279       return return_error_or_abort(fwdResult.error());
0280     }
0281 
0282     const auto& fwdGsfResult =
0283         fwdResult->template get<typename GsfActor::result_type>();
0284 
0285     if (!fwdGsfResult.result.ok()) {
0286       return return_error_or_abort(fwdGsfResult.result.error());
0287     }
0288 
0289     if (fwdGsfResult.measurementStates == 0) {
0290       return return_error_or_abort(GsfError::NoMeasurementStatesCreatedForward);
0291     }
0292 
0293     ACTS_VERBOSE("Finished forward propagation");
0294     ACTS_VERBOSE("- visited surfaces: " << fwdGsfResult.visitedSurfaces.size());
0295     ACTS_VERBOSE("- processed states: " << fwdGsfResult.processedStates);
0296     ACTS_VERBOSE("- measurement states: " << fwdGsfResult.measurementStates);
0297 
0298     std::size_t nInvalidBetheHeitler = fwdGsfResult.nInvalidBetheHeitler.val();
0299     double maxPathXOverX0 = fwdGsfResult.maxPathXOverX0.val();
0300 
0301     //////////////////
0302     // Backward pass
0303     //////////////////
0304     ACTS_VERBOSE("+------------------------------+");
0305     ACTS_VERBOSE("| Gsf: Do backward propagation |");
0306     ACTS_VERBOSE("+------------------------------+");
0307 
0308     auto bwdResult = [&]() {
0309       auto bwdPropOptions = bwdPropInitializer(options);
0310 
0311       auto& actor = bwdPropOptions.actorList.template get<GsfActor>();
0312       actor.setOptions(options);
0313       actor.m_cfg.inputMeasurements = &inputMeasurements;
0314       actor.m_cfg.inReversePass = true;
0315       actor.m_cfg.logger = m_actorLogger.get();
0316 
0317       bwdPropOptions.direction = gsfBackward;
0318 
0319       const Surface& target = options.referenceSurface
0320                                   ? *options.referenceSurface
0321                                   : sParameters.referenceSurface();
0322 
0323       const auto& params = *fwdGsfResult.lastMeasurementState;
0324       auto state =
0325           m_propagator.template makeState<MultiComponentBoundTrackParameters,
0326                                           decltype(bwdPropOptions),
0327                                           MultiStepperSurfaceReached>(
0328               params, target, bwdPropOptions);
0329 
0330       assert(
0331           (fwdGsfResult.lastMeasurementTip != MultiTrajectoryTraits::kInvalid &&
0332            "tip is invalid"));
0333 
0334       auto proxy = trackContainer.trackStateContainer().getTrackState(
0335           fwdGsfResult.lastMeasurementTip);
0336       proxy.shareFrom(TrackStatePropMask::Filtered,
0337                       TrackStatePropMask::Smoothed);
0338 
0339       auto& r = state.template get<typename GsfActor::result_type>();
0340       r.fittedStates = &trackContainer.trackStateContainer();
0341       r.currentTip = fwdGsfResult.lastMeasurementTip;
0342       r.visitedSurfaces.push_back(&proxy.referenceSurface());
0343       r.surfacesVisitedBwdAgain.push_back(&proxy.referenceSurface());
0344       r.measurementStates++;
0345       r.processedStates++;
0346 
0347       auto propagationResult = m_propagator.propagate(state);
0348 
0349       return m_propagator.makeResult(std::move(state), propagationResult,
0350                                      target, bwdPropOptions);
0351     }();
0352 
0353     if (!bwdResult.ok()) {
0354       return return_error_or_abort(bwdResult.error());
0355     }
0356 
0357     auto& bwdGsfResult =
0358         bwdResult->template get<typename GsfActor::result_type>();
0359 
0360     if (!bwdGsfResult.result.ok()) {
0361       return return_error_or_abort(bwdGsfResult.result.error());
0362     }
0363 
0364     if (bwdGsfResult.measurementStates == 0) {
0365       return return_error_or_abort(
0366           GsfError::NoMeasurementStatesCreatedBackward);
0367     }
0368 
0369     // For the backward pass we want the counters at in end (= at the
0370     // interaction point) and not at the last measurement surface
0371     bwdGsfResult.nInvalidBetheHeitler.update();
0372     bwdGsfResult.maxPathXOverX0.update();
0373     bwdGsfResult.sumPathXOverX0.update();
0374     nInvalidBetheHeitler += bwdGsfResult.nInvalidBetheHeitler.val();
0375     maxPathXOverX0 =
0376         std::max(maxPathXOverX0, bwdGsfResult.maxPathXOverX0.val());
0377 
0378     if (nInvalidBetheHeitler > 0) {
0379       ACTS_WARNING("Encountered " << nInvalidBetheHeitler
0380                                   << " cases where x/X0 exceeds the range "
0381                                      "of the Bethe-Heitler-Approximation. The "
0382                                      "maximum x/X0 encountered was "
0383                                   << maxPathXOverX0
0384                                   << ". Enable DEBUG output "
0385                                      "for more information.");
0386     }
0387 
0388     ////////////////////////////////////
0389     // Create Kalman Result
0390     ////////////////////////////////////
0391     ACTS_VERBOSE("Gsf - States summary:");
0392     ACTS_VERBOSE("- Fwd measurement states: " << fwdGsfResult.measurementStates
0393                                               << ", holes: "
0394                                               << fwdGsfResult.measurementHoles);
0395     ACTS_VERBOSE("- Bwd measurement states: " << bwdGsfResult.measurementStates
0396                                               << ", holes: "
0397                                               << bwdGsfResult.measurementHoles);
0398 
0399     // TODO should this be warning level? it happens quite often... Investigate!
0400     if (bwdGsfResult.measurementStates != fwdGsfResult.measurementStates) {
0401       ACTS_DEBUG("Fwd and bwd measurement states do not match");
0402     }
0403 
0404     // Go through the states and assign outliers / unset smoothed if surface not
0405     // passed in backward pass
0406     const auto& foundBwd = bwdGsfResult.surfacesVisitedBwdAgain;
0407     std::size_t measurementStatesFinal = 0;
0408 
0409     for (auto state : fwdGsfResult.fittedStates->reverseTrackStateRange(
0410              fwdGsfResult.currentTip)) {
0411       const bool found =
0412           rangeContainsValue(foundBwd, &state.referenceSurface());
0413       if (!found && state.typeFlags().test(MeasurementFlag)) {
0414         state.typeFlags().set(OutlierFlag);
0415         state.typeFlags().reset(MeasurementFlag);
0416         state.unset(TrackStatePropMask::Smoothed);
0417       }
0418 
0419       measurementStatesFinal +=
0420           static_cast<std::size_t>(state.typeFlags().test(MeasurementFlag));
0421     }
0422 
0423     if (measurementStatesFinal == 0) {
0424       return return_error_or_abort(GsfError::NoMeasurementStatesCreatedFinal);
0425     }
0426 
0427     auto track = trackContainer.makeTrack();
0428     track.tipIndex() = fwdGsfResult.lastMeasurementTip;
0429 
0430     if (options.referenceSurface) {
0431       const auto& params = *bwdResult->endParameters;
0432 
0433       const auto [finalPars, finalCov] = detail::mergeGaussianMixture(
0434           params.components(), params.referenceSurface(),
0435           options.componentMergeMethod, [](auto& t) {
0436             return std::tie(std::get<0>(t), std::get<1>(t), *std::get<2>(t));
0437           });
0438 
0439       track.parameters() = finalPars;
0440       track.covariance() = finalCov;
0441 
0442       track.setReferenceSurface(params.referenceSurface().getSharedPtr());
0443 
0444       if (trackContainer.hasColumn(
0445               hashString(GsfConstants::kFinalMultiComponentStateColumn))) {
0446         ACTS_DEBUG("Add final multi-component state to track");
0447         track.template component<GsfConstants::FinalMultiComponentState>(
0448             GsfConstants::kFinalMultiComponentStateColumn) = std::move(params);
0449       }
0450     }
0451 
0452     if (trackContainer.hasColumn(
0453             hashString(GsfConstants::kFwdMaxMaterialXOverX0))) {
0454       track.template component<double>(GsfConstants::kFwdMaxMaterialXOverX0) =
0455           fwdGsfResult.maxPathXOverX0.val();
0456     }
0457     if (trackContainer.hasColumn(
0458             hashString(GsfConstants::kFwdSumMaterialXOverX0))) {
0459       track.template component<double>(GsfConstants::kFwdSumMaterialXOverX0) =
0460           fwdGsfResult.sumPathXOverX0.val();
0461     }
0462 
0463     calculateTrackQuantities(track);
0464 
0465     return track;
0466   }
0467 };
0468 
0469 }  // namespace Acts