File indexing completed on 2025-01-18 09:11:06
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/EventData/VectorMultiTrajectory.hpp"
0012 #include "Acts/Propagator/DirectNavigator.hpp"
0013 #include "Acts/Propagator/MultiStepperAborters.hpp"
0014 #include "Acts/Propagator/Navigator.hpp"
0015 #include "Acts/Propagator/StandardAborters.hpp"
0016 #include "Acts/Surfaces/BoundaryTolerance.hpp"
0017 #include "Acts/TrackFitting/GsfError.hpp"
0018 #include "Acts/TrackFitting/GsfOptions.hpp"
0019 #include "Acts/TrackFitting/detail/GsfActor.hpp"
0020 #include "Acts/Utilities/Helpers.hpp"
0021 #include "Acts/Utilities/Logger.hpp"
0022 #include "Acts/Utilities/TrackHelpers.hpp"
0023
0024 namespace Acts {
0025
0026 namespace detail {
0027
0028
0029
0030
0031 template <typename T>
0032 struct IsMultiComponentBoundParameters : public std::false_type {};
0033
0034 template <>
0035 struct IsMultiComponentBoundParameters<MultiComponentBoundTrackParameters>
0036 : public std::true_type {};
0037
0038 }
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051 template <typename propagator_t, typename bethe_heitler_approx_t,
0052 typename traj_t>
0053 struct GaussianSumFitter {
0054 GaussianSumFitter(propagator_t&& propagator, bethe_heitler_approx_t&& bha,
0055 std::unique_ptr<const Logger> _logger =
0056 getDefaultLogger("GSF", Logging::INFO))
0057 : m_propagator(std::move(propagator)),
0058 m_betheHeitlerApproximation(std::move(bha)),
0059 m_logger{std::move(_logger)},
0060 m_actorLogger(m_logger->cloneWithSuffix("Actor")) {}
0061
0062
0063 propagator_t m_propagator;
0064
0065
0066 bethe_heitler_approx_t m_betheHeitlerApproximation;
0067
0068
0069 std::unique_ptr<const Logger> m_logger;
0070 std::unique_ptr<const Logger> m_actorLogger;
0071
0072 const Logger& logger() const { return *m_logger; }
0073
0074
0075 using GsfNavigator = typename propagator_t::Navigator;
0076
0077
0078 using GsfActor = detail::GsfActor<bethe_heitler_approx_t, traj_t>;
0079
0080
0081 template <typename source_link_it_t, typename start_parameters_t,
0082 TrackContainerFrontend track_container_t>
0083 auto fit(source_link_it_t begin, source_link_it_t end,
0084 const start_parameters_t& sParameters,
0085 const GsfOptions<traj_t>& options,
0086 const std::vector<const Surface*>& sSequence,
0087 track_container_t& trackContainer) const {
0088
0089 static_assert(
0090 std::is_same_v<DirectNavigator, typename propagator_t::Navigator>);
0091
0092
0093 auto fwdPropInitializer = [&sSequence, this](const auto& opts) {
0094 using Actors = ActorList<GsfActor>;
0095 using PropagatorOptions = typename propagator_t::template Options<Actors>;
0096
0097 PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0098
0099 propOptions.setPlainOptions(opts.propagatorPlainOptions);
0100
0101 propOptions.navigation.surfaces = sSequence;
0102 propOptions.actorList.template get<GsfActor>()
0103 .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0104
0105 return propOptions;
0106 };
0107
0108
0109 auto bwdPropInitializer = [&sSequence, this](const auto& opts) {
0110 using Actors = ActorList<GsfActor>;
0111 using PropagatorOptions = typename propagator_t::template Options<Actors>;
0112
0113 PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0114
0115 propOptions.setPlainOptions(opts.propagatorPlainOptions);
0116
0117 propOptions.navigation.surfaces = sSequence;
0118 propOptions.actorList.template get<GsfActor>()
0119 .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0120
0121 return propOptions;
0122 };
0123
0124 return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
0125 bwdPropInitializer, trackContainer);
0126 }
0127
0128
0129 template <typename source_link_it_t, typename start_parameters_t,
0130 TrackContainerFrontend track_container_t>
0131 auto fit(source_link_it_t begin, source_link_it_t end,
0132 const start_parameters_t& sParameters,
0133 const GsfOptions<traj_t>& options,
0134 track_container_t& trackContainer) const {
0135
0136 static_assert(std::is_same_v<Navigator, typename propagator_t::Navigator>);
0137
0138
0139 auto fwdPropInitializer = [this](const auto& opts) {
0140 using Actors = ActorList<GsfActor, EndOfWorldReached>;
0141 using PropagatorOptions = typename propagator_t::template Options<Actors>;
0142
0143 PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0144
0145 propOptions.setPlainOptions(opts.propagatorPlainOptions);
0146
0147 propOptions.actorList.template get<GsfActor>()
0148 .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0149
0150 return propOptions;
0151 };
0152
0153
0154 auto bwdPropInitializer = [this](const auto& opts) {
0155 using Actors = ActorList<GsfActor, EndOfWorldReached>;
0156 using PropagatorOptions = typename propagator_t::template Options<Actors>;
0157
0158 PropagatorOptions propOptions(opts.geoContext, opts.magFieldContext);
0159
0160 propOptions.setPlainOptions(opts.propagatorPlainOptions);
0161
0162 propOptions.actorList.template get<GsfActor>()
0163 .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
0164
0165 return propOptions;
0166 };
0167
0168 return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
0169 bwdPropInitializer, trackContainer);
0170 }
0171
0172
0173
0174
0175 template <typename source_link_it_t, typename start_parameters_t,
0176 typename fwd_prop_initializer_t, typename bwd_prop_initializer_t,
0177 TrackContainerFrontend track_container_t>
0178 Acts::Result<typename track_container_t::TrackProxy> fit_impl(
0179 source_link_it_t begin, source_link_it_t end,
0180 const start_parameters_t& sParameters, const GsfOptions<traj_t>& options,
0181 const fwd_prop_initializer_t& fwdPropInitializer,
0182 const bwd_prop_initializer_t& bwdPropInitializer,
0183 track_container_t& trackContainer) const {
0184
0185 auto return_error_or_abort = [&](auto error) {
0186 if (options.abortOnError) {
0187 std::abort();
0188 }
0189 return error;
0190 };
0191
0192
0193
0194 const auto gsfForward = options.propagatorPlainOptions.direction;
0195 const auto gsfBackward = gsfForward.invert();
0196
0197
0198 auto intersectionStatusStartSurface =
0199 sParameters.referenceSurface()
0200 .intersect(GeometryContext{},
0201 sParameters.position(GeometryContext{}),
0202 sParameters.direction(), BoundaryTolerance::None())
0203 .closest()
0204 .status();
0205
0206 if (intersectionStatusStartSurface != IntersectionStatus::onSurface) {
0207 ACTS_DEBUG(
0208 "Surface intersection of start parameters WITH bound-check failed");
0209 }
0210
0211
0212
0213 ACTS_VERBOSE("Preparing " << std::distance(begin, end)
0214 << " input measurements");
0215 std::map<GeometryIdentifier, SourceLink> inputMeasurements;
0216 for (auto it = begin; it != end; ++it) {
0217 SourceLink sl = *it;
0218 inputMeasurements.emplace(
0219 options.extensions.surfaceAccessor(sl)->geometryId(), std::move(sl));
0220 }
0221
0222 ACTS_VERBOSE(
0223 "Gsf: Final measurement map size: " << inputMeasurements.size());
0224
0225 if (sParameters.covariance() == std::nullopt) {
0226 return GsfError::StartParametersHaveNoCovariance;
0227 }
0228
0229
0230
0231
0232 ACTS_VERBOSE("+-----------------------------+");
0233 ACTS_VERBOSE("| Gsf: Do forward propagation |");
0234 ACTS_VERBOSE("+-----------------------------+");
0235
0236 auto fwdResult = [&]() {
0237 auto fwdPropOptions = fwdPropInitializer(options);
0238
0239
0240 auto& actor = fwdPropOptions.actorList.template get<GsfActor>();
0241 actor.setOptions(options);
0242 actor.m_cfg.inputMeasurements = &inputMeasurements;
0243 actor.m_cfg.numberMeasurements = inputMeasurements.size();
0244 actor.m_cfg.inReversePass = false;
0245 actor.m_cfg.logger = m_actorLogger.get();
0246
0247 fwdPropOptions.direction = gsfForward;
0248
0249
0250 using IsMultiParameters =
0251 detail::IsMultiComponentBoundParameters<start_parameters_t>;
0252
0253
0254 std::optional<MultiComponentBoundTrackParameters> params;
0255
0256
0257
0258 if constexpr (!IsMultiParameters::value) {
0259 params = MultiComponentBoundTrackParameters(
0260 sParameters.referenceSurface().getSharedPtr(),
0261 sParameters.parameters(), *sParameters.covariance(),
0262 sParameters.particleHypothesis());
0263 } else {
0264 params = sParameters;
0265 }
0266
0267 auto state = m_propagator.makeState(*params, fwdPropOptions);
0268
0269 auto& r = state.template get<typename GsfActor::result_type>();
0270 r.fittedStates = &trackContainer.trackStateContainer();
0271
0272 auto propagationResult = m_propagator.propagate(state);
0273
0274 return m_propagator.makeResult(std::move(state), propagationResult,
0275 fwdPropOptions, false);
0276 }();
0277
0278 if (!fwdResult.ok()) {
0279 return return_error_or_abort(fwdResult.error());
0280 }
0281
0282 const auto& fwdGsfResult =
0283 fwdResult->template get<typename GsfActor::result_type>();
0284
0285 if (!fwdGsfResult.result.ok()) {
0286 return return_error_or_abort(fwdGsfResult.result.error());
0287 }
0288
0289 if (fwdGsfResult.measurementStates == 0) {
0290 return return_error_or_abort(GsfError::NoMeasurementStatesCreatedForward);
0291 }
0292
0293 ACTS_VERBOSE("Finished forward propagation");
0294 ACTS_VERBOSE("- visited surfaces: " << fwdGsfResult.visitedSurfaces.size());
0295 ACTS_VERBOSE("- processed states: " << fwdGsfResult.processedStates);
0296 ACTS_VERBOSE("- measurement states: " << fwdGsfResult.measurementStates);
0297
0298 std::size_t nInvalidBetheHeitler = fwdGsfResult.nInvalidBetheHeitler.val();
0299 double maxPathXOverX0 = fwdGsfResult.maxPathXOverX0.val();
0300
0301
0302
0303
0304 ACTS_VERBOSE("+------------------------------+");
0305 ACTS_VERBOSE("| Gsf: Do backward propagation |");
0306 ACTS_VERBOSE("+------------------------------+");
0307
0308 auto bwdResult = [&]() {
0309 auto bwdPropOptions = bwdPropInitializer(options);
0310
0311 auto& actor = bwdPropOptions.actorList.template get<GsfActor>();
0312 actor.setOptions(options);
0313 actor.m_cfg.inputMeasurements = &inputMeasurements;
0314 actor.m_cfg.inReversePass = true;
0315 actor.m_cfg.logger = m_actorLogger.get();
0316
0317 bwdPropOptions.direction = gsfBackward;
0318
0319 const Surface& target = options.referenceSurface
0320 ? *options.referenceSurface
0321 : sParameters.referenceSurface();
0322
0323 const auto& params = *fwdGsfResult.lastMeasurementState;
0324 auto state =
0325 m_propagator.template makeState<MultiComponentBoundTrackParameters,
0326 decltype(bwdPropOptions),
0327 MultiStepperSurfaceReached>(
0328 params, target, bwdPropOptions);
0329
0330 assert(
0331 (fwdGsfResult.lastMeasurementTip != MultiTrajectoryTraits::kInvalid &&
0332 "tip is invalid"));
0333
0334 auto proxy = trackContainer.trackStateContainer().getTrackState(
0335 fwdGsfResult.lastMeasurementTip);
0336 proxy.shareFrom(TrackStatePropMask::Filtered,
0337 TrackStatePropMask::Smoothed);
0338
0339 auto& r = state.template get<typename GsfActor::result_type>();
0340 r.fittedStates = &trackContainer.trackStateContainer();
0341 r.currentTip = fwdGsfResult.lastMeasurementTip;
0342 r.visitedSurfaces.push_back(&proxy.referenceSurface());
0343 r.surfacesVisitedBwdAgain.push_back(&proxy.referenceSurface());
0344 r.measurementStates++;
0345 r.processedStates++;
0346
0347 auto propagationResult = m_propagator.propagate(state);
0348
0349 return m_propagator.makeResult(std::move(state), propagationResult,
0350 target, bwdPropOptions);
0351 }();
0352
0353 if (!bwdResult.ok()) {
0354 return return_error_or_abort(bwdResult.error());
0355 }
0356
0357 auto& bwdGsfResult =
0358 bwdResult->template get<typename GsfActor::result_type>();
0359
0360 if (!bwdGsfResult.result.ok()) {
0361 return return_error_or_abort(bwdGsfResult.result.error());
0362 }
0363
0364 if (bwdGsfResult.measurementStates == 0) {
0365 return return_error_or_abort(
0366 GsfError::NoMeasurementStatesCreatedBackward);
0367 }
0368
0369
0370
0371 bwdGsfResult.nInvalidBetheHeitler.update();
0372 bwdGsfResult.maxPathXOverX0.update();
0373 bwdGsfResult.sumPathXOverX0.update();
0374 nInvalidBetheHeitler += bwdGsfResult.nInvalidBetheHeitler.val();
0375 maxPathXOverX0 =
0376 std::max(maxPathXOverX0, bwdGsfResult.maxPathXOverX0.val());
0377
0378 if (nInvalidBetheHeitler > 0) {
0379 ACTS_WARNING("Encountered " << nInvalidBetheHeitler
0380 << " cases where x/X0 exceeds the range "
0381 "of the Bethe-Heitler-Approximation. The "
0382 "maximum x/X0 encountered was "
0383 << maxPathXOverX0
0384 << ". Enable DEBUG output "
0385 "for more information.");
0386 }
0387
0388
0389
0390
0391 ACTS_VERBOSE("Gsf - States summary:");
0392 ACTS_VERBOSE("- Fwd measurement states: " << fwdGsfResult.measurementStates
0393 << ", holes: "
0394 << fwdGsfResult.measurementHoles);
0395 ACTS_VERBOSE("- Bwd measurement states: " << bwdGsfResult.measurementStates
0396 << ", holes: "
0397 << bwdGsfResult.measurementHoles);
0398
0399
0400 if (bwdGsfResult.measurementStates != fwdGsfResult.measurementStates) {
0401 ACTS_DEBUG("Fwd and bwd measurement states do not match");
0402 }
0403
0404
0405
0406 const auto& foundBwd = bwdGsfResult.surfacesVisitedBwdAgain;
0407 std::size_t measurementStatesFinal = 0;
0408
0409 for (auto state : fwdGsfResult.fittedStates->reverseTrackStateRange(
0410 fwdGsfResult.currentTip)) {
0411 const bool found =
0412 rangeContainsValue(foundBwd, &state.referenceSurface());
0413 if (!found && state.typeFlags().test(MeasurementFlag)) {
0414 state.typeFlags().set(OutlierFlag);
0415 state.typeFlags().reset(MeasurementFlag);
0416 state.unset(TrackStatePropMask::Smoothed);
0417 }
0418
0419 measurementStatesFinal +=
0420 static_cast<std::size_t>(state.typeFlags().test(MeasurementFlag));
0421 }
0422
0423 if (measurementStatesFinal == 0) {
0424 return return_error_or_abort(GsfError::NoMeasurementStatesCreatedFinal);
0425 }
0426
0427 auto track = trackContainer.makeTrack();
0428 track.tipIndex() = fwdGsfResult.lastMeasurementTip;
0429
0430 if (options.referenceSurface) {
0431 const auto& params = *bwdResult->endParameters;
0432
0433 const auto [finalPars, finalCov] = detail::mergeGaussianMixture(
0434 params.components(), params.referenceSurface(),
0435 options.componentMergeMethod, [](auto& t) {
0436 return std::tie(std::get<0>(t), std::get<1>(t), *std::get<2>(t));
0437 });
0438
0439 track.parameters() = finalPars;
0440 track.covariance() = finalCov;
0441
0442 track.setReferenceSurface(params.referenceSurface().getSharedPtr());
0443
0444 if (trackContainer.hasColumn(
0445 hashString(GsfConstants::kFinalMultiComponentStateColumn))) {
0446 ACTS_DEBUG("Add final multi-component state to track");
0447 track.template component<GsfConstants::FinalMultiComponentState>(
0448 GsfConstants::kFinalMultiComponentStateColumn) = std::move(params);
0449 }
0450 }
0451
0452 if (trackContainer.hasColumn(
0453 hashString(GsfConstants::kFwdMaxMaterialXOverX0))) {
0454 track.template component<double>(GsfConstants::kFwdMaxMaterialXOverX0) =
0455 fwdGsfResult.maxPathXOverX0.val();
0456 }
0457 if (trackContainer.hasColumn(
0458 hashString(GsfConstants::kFwdSumMaterialXOverX0))) {
0459 track.template component<double>(GsfConstants::kFwdSumMaterialXOverX0) =
0460 fwdGsfResult.sumPathXOverX0.val();
0461 }
0462
0463 calculateTrackQuantities(track);
0464
0465 return track;
0466 }
0467 };
0468
0469 }