File indexing completed on 2025-12-15 09:42:23
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/EventData/VectorMultiTrajectory.hpp"
0012 #include "Acts/Propagator/DirectNavigator.hpp"
0013 #include "Acts/Propagator/MultiStepperAborters.hpp"
0014 #include "Acts/Propagator/Navigator.hpp"
0015 #include "Acts/Propagator/StandardAborters.hpp"
0016 #include "Acts/Surfaces/BoundaryTolerance.hpp"
0017 #include "Acts/TrackFitting/GsfError.hpp"
0018 #include "Acts/TrackFitting/GsfOptions.hpp"
0019 #include "Acts/TrackFitting/detail/GsfActor.hpp"
0020 #include "Acts/Utilities/Helpers.hpp"
0021 #include "Acts/Utilities/Intersection.hpp"
0022 #include "Acts/Utilities/Logger.hpp"
0023 #include "Acts/Utilities/TrackHelpers.hpp"
0024
0025 namespace Acts {
0026
0027 namespace detail {
0028
0029
0030
0031
0032 template <typename T>
0033 struct IsMultiComponentBoundParameters : public std::false_type {};
0034
0035 template <>
0036 struct IsMultiComponentBoundParameters<MultiComponentBoundTrackParameters>
0037 : public std::true_type {};
0038
0039 }
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052 template <typename propagator_t, typename bethe_heitler_approx_t,
0053 typename traj_t>
0054 struct GaussianSumFitter {
0055
0056
0057
0058
0059 GaussianSumFitter(propagator_t&& propagator, bethe_heitler_approx_t&& bha,
0060 std::unique_ptr<const Logger> _logger =
0061 getDefaultLogger("GSF", Logging::INFO))
0062 : m_propagator(std::move(propagator)),
0063 m_betheHeitlerApproximation(std::move(bha)),
0064 m_logger{std::move(_logger)},
0065 m_actorLogger(m_logger->cloneWithSuffix("Actor")) {}
0066
0067 /// Propagator owned by the fitter; used for the forward and the backward pass.
0068 propagator_t m_propagator;
0069
0070 /// Bethe-Heitler approximation; its address is stored in the GSF actor configuration.
0071 bethe_heitler_approx_t m_betheHeitlerApproximation;
0072
0073 /// Logger of the fitter itself. Must stay declared before m_actorLogger, which is cloned from it.
0074 std::unique_ptr<const Logger> m_logger;
0075 /// Logger handed (as raw pointer) to the GSF actor; cloned from m_logger with suffix "Actor".
0076 std::unique_ptr<const Logger> m_actorLogger;
0077
0078
0079 /// Read-only access to the fitter's own logger.
0080 const Logger& logger() const { return *m_logger; }
0081
0082 /// Navigator type of the propagator this fitter is instantiated with.
0083 using GsfNavigator = typename propagator_t::Navigator;
0084
0085 /// GSF actor type for the chosen Bethe-Heitler approximation and trajectory type.
0086 using GsfActor = detail::GsfActor<bethe_heitler_approx_t, traj_t>;
0087
0088
0089
0090
0091
0092
0093
0094
0095
0096 template <typename source_link_it_t, typename start_parameters_t,
0097 TrackContainerFrontend track_container_t>
0098 auto fit(source_link_it_t begin, source_link_it_t end,
0099 const start_parameters_t& sParameters,
0100 const GsfOptions<traj_t>& options,
0101 const std::vector<const Surface*>& sSequence,
0102 track_container_t& trackContainer) const {
0103
0104 static_assert(
0105 std::is_same_v<DirectNavigator, typename propagator_t::Navigator>);
0106
0107
0108 auto fwdPropInitializer = [&sSequence, this](const auto& opts) {
0109 using Actors = ActorList<GsfActor>;
0110 using PropagatorOptions = typename propagator_t::template Options<Actors>;
0111
0112 PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0113
0114 propOptions.setPlainOptions(opts.propagatorPlainOptions);
0115
0116 propOptions.navigation.surfaces = sSequence;
0117 propOptions.actorList.template get<GsfActor>()
0118 .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0119
0120 return propOptions;
0121 };
0122
0123
0124 auto bwdPropInitializer = [&sSequence, this](const auto& opts) {
0125 using Actors = ActorList<GsfActor>;
0126 using PropagatorOptions = typename propagator_t::template Options<Actors>;
0127
0128 PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0129
0130 propOptions.setPlainOptions(opts.propagatorPlainOptions);
0131
0132 propOptions.navigation.surfaces = sSequence;
0133 propOptions.actorList.template get<GsfActor>()
0134 .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0135
0136 return propOptions;
0137 };
0138
0139 return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
0140 bwdPropInitializer, trackContainer);
0141 }
0142
0143
0144
0145
0146
0147
0148
0149
0150 template <typename source_link_it_t, typename start_parameters_t,
0151 TrackContainerFrontend track_container_t>
0152 auto fit(source_link_it_t begin, source_link_it_t end,
0153 const start_parameters_t& sParameters,
0154 const GsfOptions<traj_t>& options,
0155 track_container_t& trackContainer) const {
0156
0157 static_assert(std::is_same_v<Navigator, typename propagator_t::Navigator>);
0158
0159
0160 auto fwdPropInitializer = [&](const auto& opts) {
0161 using Actors = ActorList<GsfActor, EndOfWorldReached>;
0162 using PropagatorOptions = typename propagator_t::template Options<Actors>;
0163
0164 PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0165
0166 propOptions.setPlainOptions(opts.propagatorPlainOptions);
0167
0168 if (options.useExternalSurfaces) {
0169 for (auto it = begin; it != end; ++it) {
0170 propOptions.navigation.insertExternalSurface(
0171 options.extensions.surfaceAccessor(SourceLink{*it})
0172 ->geometryId());
0173 }
0174 }
0175
0176 propOptions.actorList.template get<GsfActor>()
0177 .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0178
0179 return propOptions;
0180 };
0181
0182
0183 auto bwdPropInitializer = [&](const auto& opts) {
0184 using Actors = ActorList<GsfActor, EndOfWorldReached>;
0185 using PropagatorOptions = typename propagator_t::template Options<Actors>;
0186
0187 PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0188
0189 propOptions.setPlainOptions(opts.propagatorPlainOptions);
0190
0191 if (options.useExternalSurfaces) {
0192 for (auto it = begin; it != end; ++it) {
0193 propOptions.navigation.insertExternalSurface(
0194 options.extensions.surfaceAccessor(SourceLink{*it})
0195 ->geometryId());
0196 }
0197 }
0198
0199 propOptions.actorList.template get<GsfActor>()
0200 .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0201
0202 return propOptions;
0203 };
0204
0205 return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
0206 bwdPropInitializer, trackContainer);
0207 }
0208
0209
0210
0211
0212
0213
0214
0215
0216
0217
0218
0219
0220 template <typename source_link_it_t, typename start_parameters_t,
0221 typename fwd_prop_initializer_t, typename bwd_prop_initializer_t,
0222 TrackContainerFrontend track_container_t>
0223 Acts::Result<typename track_container_t::TrackProxy> fit_impl(
0224 source_link_it_t begin, source_link_it_t end,
0225 const start_parameters_t& sParameters, const GsfOptions<traj_t>& options,
0226 const fwd_prop_initializer_t& fwdPropInitializer,
0227 const bwd_prop_initializer_t& bwdPropInitializer,
0228 track_container_t& trackContainer) const {
0229
0230 auto return_error_or_abort = [&](auto error) {
0231 if (options.abortOnError) {
0232 std::abort();
0233 }
0234 return error;
0235 };
0236
0237
0238
0239 const auto gsfForward = options.propagatorPlainOptions.direction;
0240 const auto gsfBackward = gsfForward.invert();
0241
0242
0243 IntersectionStatus intersectionStatusStartSurface =
0244 sParameters.referenceSurface()
0245 .intersect(GeometryContext{},
0246 sParameters.position(GeometryContext{}),
0247 sParameters.direction(), BoundaryTolerance::None())
0248 .closest()
0249 .status();
0250
0251 if (intersectionStatusStartSurface != IntersectionStatus::onSurface) {
0252 ACTS_DEBUG(
0253 "Surface intersection of start parameters WITH bound-check failed");
0254 }
0255
0256
0257
0258 ACTS_VERBOSE("Preparing " << std::distance(begin, end)
0259 << " input measurements");
0260 std::map<GeometryIdentifier, SourceLink> inputMeasurements;
0261 for (auto it = begin; it != end; ++it) {
0262 SourceLink sl = *it;
0263 inputMeasurements.emplace(
0264 options.extensions.surfaceAccessor(sl)->geometryId(), std::move(sl));
0265 }
0266
0267 ACTS_VERBOSE(
0268 "Gsf: Final measurement map size: " << inputMeasurements.size());
0269
0270 if (sParameters.covariance() == std::nullopt) {
0271 return GsfError::StartParametersHaveNoCovariance;
0272 }
0273
0274
0275
0276
0277 ACTS_VERBOSE("+-----------------------------+");
0278 ACTS_VERBOSE("| Gsf: Do forward propagation |");
0279 ACTS_VERBOSE("+-----------------------------+");
0280
0281 auto fwdResult = [&]() {
0282 auto fwdPropOptions = fwdPropInitializer(options);
0283
0284
0285 auto& actor = fwdPropOptions.actorList.template get<GsfActor>();
0286 actor.setOptions(options);
0287 actor.m_cfg.inputMeasurements = &inputMeasurements;
0288 actor.m_cfg.numberMeasurements = inputMeasurements.size();
0289 actor.m_cfg.inReversePass = false;
0290 actor.m_cfg.logger = m_actorLogger.get();
0291
0292 fwdPropOptions.direction = gsfForward;
0293
0294
0295 using IsMultiParameters =
0296 detail::IsMultiComponentBoundParameters<start_parameters_t>;
0297
0298
0299 std::optional<MultiComponentBoundTrackParameters> params;
0300
0301
0302
0303 if constexpr (!IsMultiParameters::value) {
0304 params = MultiComponentBoundTrackParameters(
0305 sParameters.referenceSurface().getSharedPtr(),
0306 sParameters.parameters(), *sParameters.covariance(),
0307 sParameters.particleHypothesis());
0308 } else {
0309 params = sParameters;
0310 }
0311
0312 auto state = m_propagator.makeState(fwdPropOptions);
0313
0314
0315 using OptionsType = decltype(fwdPropOptions);
0316 using StateType = decltype(state);
0317 using PropagationResultType =
0318 decltype(m_propagator.propagate(std::declval<StateType&>()));
0319 using ResultType = decltype(m_propagator.makeResult(
0320 std::declval<StateType&&>(), std::declval<PropagationResultType>(),
0321 std::declval<const OptionsType&>(), false));
0322
0323 auto initRes = m_propagator.initialize(state, *params);
0324 if (!initRes.ok()) {
0325 return ResultType::failure(initRes.error());
0326 }
0327
0328 auto& r = state.template get<typename GsfActor::result_type>();
0329 r.fittedStates = &trackContainer.trackStateContainer();
0330
0331 auto propagationResult = m_propagator.propagate(state);
0332
0333 return m_propagator.makeResult(std::move(state), propagationResult,
0334 fwdPropOptions, false);
0335 }();
0336
0337 if (!fwdResult.ok()) {
0338 return return_error_or_abort(fwdResult.error());
0339 }
0340
0341 const auto& fwdGsfResult =
0342 fwdResult->template get<typename GsfActor::result_type>();
0343
0344 if (!fwdGsfResult.result.ok()) {
0345 return return_error_or_abort(fwdGsfResult.result.error());
0346 }
0347
0348 if (fwdGsfResult.measurementStates == 0) {
0349 return return_error_or_abort(GsfError::NoMeasurementStatesCreatedForward);
0350 }
0351
0352 ACTS_VERBOSE("Finished forward propagation");
0353 ACTS_VERBOSE("- visited surfaces: " << fwdGsfResult.visitedSurfaces.size());
0354 ACTS_VERBOSE("- processed states: " << fwdGsfResult.processedStates);
0355 ACTS_VERBOSE("- measurement states: " << fwdGsfResult.measurementStates);
0356
0357 std::size_t nInvalidBetheHeitler = fwdGsfResult.nInvalidBetheHeitler.val();
0358 double maxPathXOverX0 = fwdGsfResult.maxPathXOverX0.val();
0359
0360
0361
0362
0363 ACTS_VERBOSE("+------------------------------+");
0364 ACTS_VERBOSE("| Gsf: Do backward propagation |");
0365 ACTS_VERBOSE("+------------------------------+");
0366
0367 auto bwdResult = [&]() {
0368 auto bwdPropOptions = bwdPropInitializer(options);
0369
0370 auto& actor = bwdPropOptions.actorList.template get<GsfActor>();
0371 actor.setOptions(options);
0372 actor.m_cfg.inputMeasurements = &inputMeasurements;
0373 actor.m_cfg.inReversePass = true;
0374 actor.m_cfg.logger = m_actorLogger.get();
0375
0376 bwdPropOptions.direction = gsfBackward;
0377
0378 const Surface& target = options.referenceSurface
0379 ? *options.referenceSurface
0380 : sParameters.referenceSurface();
0381
0382 std::vector<
0383 std::tuple<double, BoundVector, std::optional<BoundSquareMatrix>>>
0384 inflatedParamVector;
0385 assert(!fwdGsfResult.lastMeasurementComponents.empty());
0386 assert(fwdGsfResult.lastMeasurementSurface != nullptr);
0387 for (auto& [w, p, cov] : fwdGsfResult.lastMeasurementComponents) {
0388 inflatedParamVector.emplace_back(
0389 w, p, cov * options.reverseFilteringCovarianceScaling);
0390 }
0391
0392 MultiComponentBoundTrackParameters inflatedParams(
0393 fwdGsfResult.lastMeasurementSurface->getSharedPtr(),
0394 std::move(inflatedParamVector), sParameters.particleHypothesis());
0395
0396 auto state = m_propagator.template makeState<decltype(bwdPropOptions),
0397 MultiStepperSurfaceReached>(
0398 target, bwdPropOptions);
0399
0400
0401 using OptionsType = decltype(bwdPropOptions);
0402 using StateType = decltype(state);
0403 using PropagationResultType =
0404 decltype(m_propagator.propagate(std::declval<StateType&>()));
0405 using ResultType = decltype(m_propagator.makeResult(
0406 std::declval<StateType&&>(), std::declval<PropagationResultType>(),
0407 target, std::declval<const OptionsType&>()));
0408
0409 auto initRes = m_propagator.initialize(state, inflatedParams);
0410 if (!initRes.ok()) {
0411 return ResultType::failure(initRes.error());
0412 }
0413
0414 assert(
0415 (fwdGsfResult.lastMeasurementTip != MultiTrajectoryTraits::kInvalid &&
0416 "tip is invalid"));
0417
0418 auto proxy = trackContainer.trackStateContainer().getTrackState(
0419 fwdGsfResult.lastMeasurementTip);
0420 proxy.shareFrom(TrackStatePropMask::Filtered,
0421 TrackStatePropMask::Smoothed);
0422
0423 auto& r = state.template get<typename GsfActor::result_type>();
0424 r.fittedStates = &trackContainer.trackStateContainer();
0425 r.currentTip = fwdGsfResult.lastMeasurementTip;
0426 r.visitedSurfaces.push_back(&proxy.referenceSurface());
0427 r.surfacesVisitedBwdAgain.push_back(&proxy.referenceSurface());
0428 r.measurementStates++;
0429 r.processedStates++;
0430
0431 auto propagationResult = m_propagator.propagate(state);
0432
0433 return m_propagator.makeResult(std::move(state), propagationResult,
0434 target, bwdPropOptions);
0435 }();
0436
0437 if (!bwdResult.ok()) {
0438 return return_error_or_abort(bwdResult.error());
0439 }
0440
0441 auto& bwdGsfResult =
0442 bwdResult->template get<typename GsfActor::result_type>();
0443
0444 if (!bwdGsfResult.result.ok()) {
0445 return return_error_or_abort(bwdGsfResult.result.error());
0446 }
0447
0448 if (bwdGsfResult.measurementStates == 0) {
0449 return return_error_or_abort(
0450 GsfError::NoMeasurementStatesCreatedBackward);
0451 }
0452
0453
0454
0455 bwdGsfResult.nInvalidBetheHeitler.update();
0456 bwdGsfResult.maxPathXOverX0.update();
0457 bwdGsfResult.sumPathXOverX0.update();
0458 nInvalidBetheHeitler += bwdGsfResult.nInvalidBetheHeitler.val();
0459 maxPathXOverX0 =
0460 std::max(maxPathXOverX0, bwdGsfResult.maxPathXOverX0.val());
0461
0462 if (nInvalidBetheHeitler > 0) {
0463 ACTS_WARNING("Encountered " << nInvalidBetheHeitler
0464 << " cases where x/X0 exceeds the range "
0465 "of the Bethe-Heitler-Approximation. The "
0466 "maximum x/X0 encountered was "
0467 << maxPathXOverX0
0468 << ". Enable DEBUG output "
0469 "for more information.");
0470 }
0471
0472
0473
0474
0475 ACTS_VERBOSE("Gsf - States summary:");
0476 ACTS_VERBOSE("- Fwd measurement states: " << fwdGsfResult.measurementStates
0477 << ", holes: "
0478 << fwdGsfResult.measurementHoles);
0479 ACTS_VERBOSE("- Bwd measurement states: " << bwdGsfResult.measurementStates
0480 << ", holes: "
0481 << bwdGsfResult.measurementHoles);
0482
0483
0484 if (bwdGsfResult.measurementStates != fwdGsfResult.measurementStates) {
0485 ACTS_DEBUG("Fwd and bwd measurement states do not match");
0486 }
0487
0488
0489
0490 const auto& foundBwd = bwdGsfResult.surfacesVisitedBwdAgain;
0491 std::size_t measurementStatesFinal = 0;
0492
0493 for (auto state : fwdGsfResult.fittedStates->reverseTrackStateRange(
0494 fwdGsfResult.currentTip)) {
0495 const bool found =
0496 rangeContainsValue(foundBwd, &state.referenceSurface());
0497 if (!found && state.typeFlags().test(MeasurementFlag)) {
0498 state.typeFlags().set(OutlierFlag);
0499 state.typeFlags().reset(MeasurementFlag);
0500 state.unset(TrackStatePropMask::Smoothed);
0501 }
0502
0503 measurementStatesFinal +=
0504 static_cast<std::size_t>(state.typeFlags().test(MeasurementFlag));
0505 }
0506
0507 if (measurementStatesFinal == 0) {
0508 return return_error_or_abort(GsfError::NoMeasurementStatesCreatedFinal);
0509 }
0510
0511 auto track = trackContainer.makeTrack();
0512 track.tipIndex() = fwdGsfResult.lastMeasurementTip;
0513
0514 if (options.referenceSurface) {
0515 const auto& params = *bwdResult->endParameters;
0516
0517 const auto [finalPars, finalCov] = detail::mergeGaussianMixture(
0518 params.components(), params.referenceSurface(),
0519 options.componentMergeMethod, [](auto& t) {
0520 return std::tie(std::get<0>(t), std::get<1>(t), *std::get<2>(t));
0521 });
0522
0523 track.parameters() = finalPars;
0524 track.covariance() = finalCov;
0525
0526 track.setReferenceSurface(params.referenceSurface().getSharedPtr());
0527
0528 if (trackContainer.hasColumn(
0529 hashString(GsfConstants::kFinalMultiComponentStateColumn))) {
0530 ACTS_DEBUG("Add final multi-component state to track");
0531 track.template component<GsfConstants::FinalMultiComponentState>(
0532 GsfConstants::kFinalMultiComponentStateColumn) = std::move(params);
0533 }
0534 }
0535
0536 if (trackContainer.hasColumn(
0537 hashString(GsfConstants::kFwdMaxMaterialXOverX0))) {
0538 track.template component<double>(GsfConstants::kFwdMaxMaterialXOverX0) =
0539 fwdGsfResult.maxPathXOverX0.val();
0540 }
0541 if (trackContainer.hasColumn(
0542 hashString(GsfConstants::kFwdSumMaterialXOverX0))) {
0543 track.template component<double>(GsfConstants::kFwdSumMaterialXOverX0) =
0544 fwdGsfResult.sumPathXOverX0.val();
0545 }
0546
0547 calculateTrackQuantities(track);
0548
0549 return track;
0550 }
0551 };
0552
0553 }