File indexing completed on 2025-01-19 09:23:37
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/EventData/TrackHelpers.hpp"
0012 #include "Acts/EventData/VectorMultiTrajectory.hpp"
0013 #include "Acts/Propagator/DirectNavigator.hpp"
0014 #include "Acts/Propagator/MultiStepperAborters.hpp"
0015 #include "Acts/Propagator/Navigator.hpp"
0016 #include "Acts/Propagator/StandardAborters.hpp"
0017 #include "Acts/TrackFitting/GsfOptions.hpp"
0018 #include "Acts/TrackFitting/detail/GsfActor.hpp"
0019 #include "Acts/Utilities/Logger.hpp"
0020
0021 namespace Acts {
0022
0023 namespace detail {
0024
0025
0026
0027
0028 template <typename T>
0029 struct IsMultiComponentBoundParameters : public std::false_type {};
0030
0031 template <>
0032 struct IsMultiComponentBoundParameters<MultiComponentBoundTrackParameters>
0033 : public std::true_type {};
0034
0035 }
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048 template <typename propagator_t, typename bethe_heitler_approx_t,
0049 typename traj_t>
0050 struct GaussianSumFitter {
0051 GaussianSumFitter(propagator_t&& propagator, bethe_heitler_approx_t&& bha,
0052 std::unique_ptr<const Logger> _logger =
0053 getDefaultLogger("GSF", Logging::INFO))
0054 : m_propagator(std::move(propagator)),
0055 m_betheHeitlerApproximation(std::move(bha)),
0056 m_logger{std::move(_logger)},
0057 m_actorLogger(m_logger->cloneWithSuffix("Actor")) {}
0058
0059
0060 propagator_t m_propagator;
0061
0062
0063 bethe_heitler_approx_t m_betheHeitlerApproximation;
0064
0065
0066 std::unique_ptr<const Logger> m_logger;
0067 std::unique_ptr<const Logger> m_actorLogger;
0068
0069 const Logger& logger() const { return *m_logger; }
0070
0071
0072 using GsfNavigator = typename propagator_t::Navigator;
0073
0074
0075 using GsfActor = detail::GsfActor<bethe_heitler_approx_t, traj_t>;
0076
0077
0078
0079 struct NavigationBreakAborter {
0080 NavigationBreakAborter() = default;
0081
0082 template <typename propagator_state_t, typename stepper_t,
0083 typename navigator_t>
0084 bool operator()(propagator_state_t& state, const stepper_t& ,
0085 const navigator_t& navigator,
0086 const Logger& ) const {
0087 return navigator.navigationBreak(state.navigation);
0088 }
0089 };
0090
0091
0092 template <typename source_link_it_t, typename start_parameters_t,
0093 typename track_container_t, template <typename> class holder_t>
0094 auto fit(source_link_it_t begin, source_link_it_t end,
0095 const start_parameters_t& sParameters,
0096 const GsfOptions<traj_t>& options,
0097 const std::vector<const Surface*>& sSequence,
0098 TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
0099 const {
0100
0101 static_assert(
0102 std::is_same_v<DirectNavigator, typename propagator_t::Navigator>);
0103
0104
0105 auto fwdPropInitializer = [&sSequence, this](const auto& opts) {
0106 using Actors = ActionList<GsfActor, DirectNavigator::Initializer>;
0107 using Aborters = AbortList<NavigationBreakAborter>;
0108
0109 PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
0110 opts.magFieldContext);
0111
0112 propOptions.setPlainOptions(opts.propagatorPlainOptions);
0113
0114 propOptions.actionList.template get<DirectNavigator::Initializer>()
0115 .navSurfaces = sSequence;
0116 propOptions.actionList.template get<GsfActor>()
0117 .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0118
0119 return propOptions;
0120 };
0121
0122
0123 auto bwdPropInitializer = [&sSequence, this](const auto& opts) {
0124 using Actors = ActionList<GsfActor, DirectNavigator::Initializer>;
0125 using Aborters = AbortList<>;
0126
0127 std::vector<const Surface*> backwardSequence(
0128 std::next(sSequence.rbegin()), sSequence.rend());
0129 backwardSequence.push_back(opts.referenceSurface);
0130
0131 PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
0132 opts.magFieldContext);
0133
0134 propOptions.setPlainOptions(opts.propagatorPlainOptions);
0135
0136 propOptions.actionList.template get<DirectNavigator::Initializer>()
0137 .navSurfaces = std::move(backwardSequence);
0138 propOptions.actionList.template get<GsfActor>()
0139 .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0140
0141 return propOptions;
0142 };
0143
0144 return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
0145 bwdPropInitializer, trackContainer);
0146 }
0147
0148
0149 template <typename source_link_it_t, typename start_parameters_t,
0150 typename track_container_t, template <typename> class holder_t>
0151 auto fit(source_link_it_t begin, source_link_it_t end,
0152 const start_parameters_t& sParameters,
0153 const GsfOptions<traj_t>& options,
0154 TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
0155 const {
0156
0157 static_assert(std::is_same_v<Navigator, typename propagator_t::Navigator>);
0158
0159
0160 auto fwdPropInitializer = [this](const auto& opts) {
0161 using Actors = ActionList<GsfActor>;
0162 using Aborters = AbortList<EndOfWorldReached, NavigationBreakAborter>;
0163
0164 PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
0165 opts.magFieldContext);
0166 propOptions.setPlainOptions(opts.propagatorPlainOptions);
0167 propOptions.actionList.template get<GsfActor>()
0168 .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0169
0170 return propOptions;
0171 };
0172
0173
0174 auto bwdPropInitializer = [this](const auto& opts) {
0175 using Actors = ActionList<GsfActor>;
0176 using Aborters = AbortList<EndOfWorldReached>;
0177
0178 PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
0179 opts.magFieldContext);
0180
0181 propOptions.setPlainOptions(opts.propagatorPlainOptions);
0182
0183 propOptions.actionList.template get<GsfActor>()
0184 .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0185
0186 return propOptions;
0187 };
0188
0189 return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
0190 bwdPropInitializer, trackContainer);
0191 }
0192
0193
0194
0195
0196 template <typename source_link_it_t, typename start_parameters_t,
0197 typename fwd_prop_initializer_t, typename bwd_prop_initializer_t,
0198 typename track_container_t, template <typename> class holder_t>
0199 Acts::Result<
0200 typename TrackContainer<track_container_t, traj_t, holder_t>::TrackProxy>
0201 fit_impl(source_link_it_t begin, source_link_it_t end,
0202 const start_parameters_t& sParameters,
0203 const GsfOptions<traj_t>& options,
0204 const fwd_prop_initializer_t& fwdPropInitializer,
0205 const bwd_prop_initializer_t& bwdPropInitializer,
0206 TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
0207 const {
0208
0209 auto return_error_or_abort = [&](auto error) {
0210 if (options.abortOnError) {
0211 std::abort();
0212 }
0213 return error;
0214 };
0215
0216
0217
0218 const auto gsfForward = options.propagatorPlainOptions.direction;
0219 const auto gsfBackward = gsfForward.invert();
0220
0221
0222 auto intersectionStatusStartSurface =
0223 sParameters.referenceSurface()
0224 .intersect(GeometryContext{},
0225 sParameters.position(GeometryContext{}),
0226 sParameters.direction(), BoundaryCheck(true))
0227 .closest()
0228 .status();
0229
0230 if (intersectionStatusStartSurface != Intersection3D::Status::onSurface) {
0231 ACTS_DEBUG(
0232 "Surface intersection of start parameters WITH bound-check failed");
0233 }
0234
0235
0236
0237 ACTS_VERBOSE("Preparing " << std::distance(begin, end)
0238 << " input measurements");
0239 std::map<GeometryIdentifier, SourceLink> inputMeasurements;
0240 for (auto it = begin; it != end; ++it) {
0241 SourceLink sl = *it;
0242 inputMeasurements.emplace(
0243 options.extensions.surfaceAccessor(sl)->geometryId(), std::move(sl));
0244 }
0245
0246 ACTS_VERBOSE(
0247 "Gsf: Final measurement map size: " << inputMeasurements.size());
0248
0249 if (sParameters.covariance() == std::nullopt) {
0250 return GsfError::StartParametersHaveNoCovariance;
0251 }
0252
0253
0254
0255
0256 ACTS_VERBOSE("+-----------------------------+");
0257 ACTS_VERBOSE("| Gsf: Do forward propagation |");
0258 ACTS_VERBOSE("+-----------------------------+");
0259
0260 auto fwdResult = [&]() {
0261 auto fwdPropOptions = fwdPropInitializer(options);
0262
0263
0264 auto& actor = fwdPropOptions.actionList.template get<GsfActor>();
0265 actor.setOptions(options);
0266 actor.m_cfg.inputMeasurements = &inputMeasurements;
0267 actor.m_cfg.numberMeasurements = inputMeasurements.size();
0268 actor.m_cfg.inReversePass = false;
0269 actor.m_cfg.logger = m_actorLogger.get();
0270
0271 fwdPropOptions.direction = gsfForward;
0272
0273
0274 using IsMultiParameters =
0275 detail::IsMultiComponentBoundParameters<start_parameters_t>;
0276
0277
0278 std::optional<MultiComponentBoundTrackParameters> params;
0279
0280
0281
0282 if constexpr (!IsMultiParameters::value) {
0283 params = MultiComponentBoundTrackParameters(
0284 sParameters.referenceSurface().getSharedPtr(),
0285 sParameters.parameters(), *sParameters.covariance(),
0286 sParameters.particleHypothesis());
0287 } else {
0288 params = sParameters;
0289 }
0290
0291 auto state = m_propagator.makeState(*params, fwdPropOptions);
0292
0293 auto& r = state.template get<typename GsfActor::result_type>();
0294 r.fittedStates = &trackContainer.trackStateContainer();
0295
0296 auto propagationResult = m_propagator.propagate(state);
0297
0298 return m_propagator.makeResult(std::move(state), propagationResult,
0299 fwdPropOptions, false);
0300 }();
0301
0302 if (!fwdResult.ok()) {
0303 return return_error_or_abort(fwdResult.error());
0304 }
0305
0306 const auto& fwdGsfResult =
0307 fwdResult->template get<typename GsfActor::result_type>();
0308
0309 if (!fwdGsfResult.result.ok()) {
0310 return return_error_or_abort(fwdGsfResult.result.error());
0311 }
0312
0313 if (fwdGsfResult.measurementStates == 0) {
0314 return return_error_or_abort(GsfError::NoMeasurementStatesCreatedForward);
0315 }
0316
0317 ACTS_VERBOSE("Finished forward propagation");
0318 ACTS_VERBOSE("- visited surfaces: " << fwdGsfResult.visitedSurfaces.size());
0319 ACTS_VERBOSE("- processed states: " << fwdGsfResult.processedStates);
0320 ACTS_VERBOSE("- measurement states: " << fwdGsfResult.measurementStates);
0321
0322 std::size_t nInvalidBetheHeitler = fwdGsfResult.nInvalidBetheHeitler.val();
0323 double maxPathXOverX0 = fwdGsfResult.maxPathXOverX0.val();
0324
0325
0326
0327
0328 ACTS_VERBOSE("+------------------------------+");
0329 ACTS_VERBOSE("| Gsf: Do backward propagation |");
0330 ACTS_VERBOSE("+------------------------------+");
0331
0332 auto bwdResult = [&]() {
0333 auto bwdPropOptions = bwdPropInitializer(options);
0334
0335 auto& actor = bwdPropOptions.actionList.template get<GsfActor>();
0336 actor.setOptions(options);
0337 actor.m_cfg.inputMeasurements = &inputMeasurements;
0338 actor.m_cfg.inReversePass = true;
0339 actor.m_cfg.logger = m_actorLogger.get();
0340
0341 bwdPropOptions.direction = gsfBackward;
0342
0343 const Surface& target = options.referenceSurface
0344 ? *options.referenceSurface
0345 : sParameters.referenceSurface();
0346
0347 using PM = TrackStatePropMask;
0348
0349 const auto& params = *fwdGsfResult.lastMeasurementState;
0350 auto state =
0351 m_propagator.template makeState<MultiComponentBoundTrackParameters,
0352 decltype(bwdPropOptions),
0353 MultiStepperSurfaceReached>(
0354 params, target, bwdPropOptions);
0355
0356 assert(
0357 (fwdGsfResult.lastMeasurementTip != MultiTrajectoryTraits::kInvalid &&
0358 "tip is invalid"));
0359
0360 auto proxy = trackContainer.trackStateContainer().getTrackState(
0361 fwdGsfResult.lastMeasurementTip);
0362 proxy.shareFrom(TrackStatePropMask::Filtered,
0363 TrackStatePropMask::Smoothed);
0364
0365 auto& r = state.template get<typename GsfActor::result_type>();
0366 r.fittedStates = &trackContainer.trackStateContainer();
0367 r.currentTip = fwdGsfResult.lastMeasurementTip;
0368 r.visitedSurfaces.push_back(&proxy.referenceSurface());
0369 r.surfacesVisitedBwdAgain.push_back(&proxy.referenceSurface());
0370 r.measurementStates++;
0371 r.processedStates++;
0372
0373 auto propagationResult = m_propagator.propagate(state);
0374
0375 return m_propagator.makeResult(std::move(state), propagationResult,
0376 target, bwdPropOptions);
0377 }();
0378
0379 if (!bwdResult.ok()) {
0380 return return_error_or_abort(bwdResult.error());
0381 }
0382
0383 auto& bwdGsfResult =
0384 bwdResult->template get<typename GsfActor::result_type>();
0385
0386 if (!bwdGsfResult.result.ok()) {
0387 return return_error_or_abort(bwdGsfResult.result.error());
0388 }
0389
0390 if (bwdGsfResult.measurementStates == 0) {
0391 return return_error_or_abort(
0392 GsfError::NoMeasurementStatesCreatedBackward);
0393 }
0394
0395
0396
0397 bwdGsfResult.nInvalidBetheHeitler.update();
0398 bwdGsfResult.maxPathXOverX0.update();
0399 bwdGsfResult.sumPathXOverX0.update();
0400 nInvalidBetheHeitler += bwdGsfResult.nInvalidBetheHeitler.val();
0401 maxPathXOverX0 =
0402 std::max(maxPathXOverX0, bwdGsfResult.maxPathXOverX0.val());
0403
0404 if (nInvalidBetheHeitler > 0) {
0405 ACTS_WARNING("Encountered " << nInvalidBetheHeitler
0406 << " cases where x/X0 exceeds the range "
0407 "of the Bethe-Heitler-Approximation. The "
0408 "maximum x/X0 encountered was "
0409 << maxPathXOverX0
0410 << ". Enable DEBUG output "
0411 "for more information.");
0412 }
0413
0414
0415
0416
0417 ACTS_VERBOSE("Gsf - States summary:");
0418 ACTS_VERBOSE("- Fwd measurement states: " << fwdGsfResult.measurementStates
0419 << ", holes: "
0420 << fwdGsfResult.measurementHoles);
0421 ACTS_VERBOSE("- Bwd measurement states: " << bwdGsfResult.measurementStates
0422 << ", holes: "
0423 << bwdGsfResult.measurementHoles);
0424
0425
0426 if (bwdGsfResult.measurementStates != fwdGsfResult.measurementStates) {
0427 ACTS_DEBUG("Fwd and bwd measurement states do not match");
0428 }
0429
0430
0431
0432 const auto& foundBwd = bwdGsfResult.surfacesVisitedBwdAgain;
0433 std::size_t measurementStatesFinal = 0;
0434
0435 for (auto state : fwdGsfResult.fittedStates->reverseTrackStateRange(
0436 fwdGsfResult.currentTip)) {
0437 const bool found = std::find(foundBwd.begin(), foundBwd.end(),
0438 &state.referenceSurface()) != foundBwd.end();
0439 if (!found && state.typeFlags().test(MeasurementFlag)) {
0440 state.typeFlags().set(OutlierFlag);
0441 state.typeFlags().reset(MeasurementFlag);
0442 state.unset(TrackStatePropMask::Smoothed);
0443 }
0444
0445 measurementStatesFinal +=
0446 static_cast<std::size_t>(state.typeFlags().test(MeasurementFlag));
0447 }
0448
0449 if (measurementStatesFinal == 0) {
0450 return return_error_or_abort(GsfError::NoMeasurementStatesCreatedFinal);
0451 }
0452
0453 auto track = trackContainer.makeTrack();
0454 track.tipIndex() = fwdGsfResult.lastMeasurementTip;
0455
0456 if (options.referenceSurface) {
0457 const auto& params = *bwdResult->endParameters;
0458
0459 const auto [finalPars, finalCov] = detail::mergeGaussianMixture(
0460 params.components(), params.referenceSurface(),
0461 options.componentMergeMethod, [](auto& t) {
0462 return std::tie(std::get<0>(t), std::get<1>(t), *std::get<2>(t));
0463 });
0464
0465 track.parameters() = finalPars;
0466 track.covariance() = finalCov;
0467
0468 track.setReferenceSurface(params.referenceSurface().getSharedPtr());
0469
0470 if (trackContainer.hasColumn(
0471 hashString(GsfConstants::kFinalMultiComponentStateColumn))) {
0472 ACTS_DEBUG("Add final multi-component state to track");
0473 track.template component<GsfConstants::FinalMultiComponentState>(
0474 GsfConstants::kFinalMultiComponentStateColumn) = std::move(params);
0475 }
0476 }
0477
0478 if (trackContainer.hasColumn(
0479 hashString(GsfConstants::kFwdMaxMaterialXOverX0))) {
0480 track.template component<double>(GsfConstants::kFwdMaxMaterialXOverX0) =
0481 fwdGsfResult.maxPathXOverX0.val();
0482 }
0483 if (trackContainer.hasColumn(
0484 hashString(GsfConstants::kFwdSumMaterialXOverX0))) {
0485 track.template component<double>(GsfConstants::kFwdSumMaterialXOverX0) =
0486 fwdGsfResult.sumPathXOverX0.val();
0487 }
0488
0489 calculateTrackQuantities(track);
0490
0491 return track;
0492 }
0493 };
0494
0495 }