File indexing completed on 2026-05-14 07:54:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/Definitions/Common.hpp"
0012 #include "Acts/Definitions/TrackParametrization.hpp"
0013 #include "Acts/EventData/MultiTrajectory.hpp"
0014 #include "Acts/EventData/Types.hpp"
0015 #include "Acts/Propagator/detail/PointwiseMaterialInteraction.hpp"
0016 #include "Acts/Surfaces/Surface.hpp"
0017 #include "Acts/TrackFitting/BetheHeitlerApprox.hpp"
0018 #include "Acts/TrackFitting/GsfOptions.hpp"
0019 #include "Acts/TrackFitting/detail/GsfComponentMerging.hpp"
0020 #include "Acts/TrackFitting/detail/GsfUtils.hpp"
0021 #include "Acts/Utilities/Helpers.hpp"
0022
0023 #include <map>
0024
0025 namespace Acts::detail::Gsf {
0026
0027 template <typename traj_t>
0028 struct GsfResult {
0029
0030 traj_t* fittedStates{nullptr};
0031
0032
0033 TrackIndexType currentTip = kTrackIndexInvalid;
0034
0035
0036 TrackIndexType lastMeasurementTip = kTrackIndexInvalid;
0037
0038
0039
0040 std::vector<std::tuple<double, BoundVector, BoundMatrix>>
0041 lastMeasurementComponents;
0042
0043
0044 const Acts::Surface* lastMeasurementSurface = nullptr;
0045
0046
0047 std::size_t measurementStates = 0;
0048 std::size_t measurementHoles = 0;
0049 std::size_t processedStates = 0;
0050
0051 std::vector<const Surface*> visitedSurfaces;
0052 std::vector<const Surface*> surfacesVisitedBwdAgain;
0053
0054
0055 Updatable<std::size_t> nInvalidBetheHeitler;
0056 Updatable<double> maxPathXOverX0;
0057 Updatable<double> sumPathXOverX0;
0058
0059
0060 std::vector<BetheHeitlerApprox::Component> betheHeitlerCache;
0061
0062
0063 std::vector<GsfComponent> componentCache;
0064 };
0065
0066
0067 template <typename traj_t>
0068 struct GsfActor {
0069
0070 GsfActor() = default;
0071
0072 using ComponentCache = GsfComponent;
0073
0074
0075 using result_type = GsfResult<traj_t>;
0076
0077
0078 struct Config {
0079
0080 std::size_t maxComponents = 16;
0081
0082
0083 const std::map<GeometryIdentifier, SourceLink>* inputMeasurements = nullptr;
0084
0085
0086
0087 const BetheHeitlerApprox* bethe_heitler_approx = nullptr;
0088
0089
0090 bool multipleScattering = true;
0091
0092
0093 double weightCutoff = 1.0e-4;
0094
0095
0096
0097
0098 bool disableAllMaterialHandling = false;
0099
0100
0101 bool abortOnError = false;
0102
0103
0104
0105 std::optional<std::size_t> numberMeasurements;
0106
0107
0108 GsfExtensions<traj_t> extensions;
0109
0110
0111
0112
0113 bool inReversePass = false;
0114
0115
0116 ComponentMergeMethod mergeMethod = ComponentMergeMethod::eMaxWeight;
0117
0118 const Logger* logger{nullptr};
0119
0120
0121 const CalibrationContext* calibrationContext{nullptr};
0122
0123 } m_cfg;
0124
0125 const Logger& logger() const { return *m_cfg.logger; }
0126
0127 using TemporaryStates = detail::Gsf::TemporaryStates<traj_t>;
0128
0129 using FiltProjector = MultiTrajectoryProjector<StatesType::eFiltered, traj_t>;
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139
0140 template <typename propagator_state_t, typename stepper_t,
0141 typename navigator_t>
0142 Result<void> act(propagator_state_t& state, const stepper_t& stepper,
0143 const navigator_t& navigator, result_type& result,
0144 const Logger& ) const {
0145 assert(result.fittedStates && "No MultiTrajectory set");
0146
0147
0148
0149 const ScopedGsfInfoPrinterAndChecker printer(state, stepper, navigator,
0150 logger());
0151
0152
0153 if (!navigator.currentSurface(state.navigation)) {
0154 return Result<void>::success();
0155 }
0156
0157 const auto& surface = *navigator.currentSurface(state.navigation);
0158 ACTS_VERBOSE("Step is at surface " << surface.geometryId());
0159
0160
0161
0162 [[maybe_unused]] auto stepperComponents =
0163 stepper.constComponentIterable(state.stepping);
0164 assert(weightsAreNormalized(stepperComponents,
0165 [](const auto& cmp) { return cmp.weight(); }));
0166
0167
0168
0169
0170 using Status [[maybe_unused]] = IntersectionStatus;
0171 assert(std::all_of(
0172 stepperComponents.begin(), stepperComponents.end(),
0173 [](const auto& cmp) { return cmp.status() == Status::onSurface; }));
0174
0175
0176
0177 const bool visited = rangeContainsValue(result.visitedSurfaces, &surface);
0178
0179 if (visited) {
0180 ACTS_VERBOSE("Already visited surface, return");
0181 return Result<void>::success();
0182 }
0183
0184 result.visitedSurfaces.push_back(&surface);
0185
0186
0187 const auto foundSourceLink =
0188 m_cfg.inputMeasurements->find(surface.geometryId());
0189 const bool haveMaterial =
0190 navigator.currentSurface(state.navigation)->surfaceMaterial() &&
0191 !m_cfg.disableAllMaterialHandling;
0192 const bool haveMeasurement =
0193 foundSourceLink != m_cfg.inputMeasurements->end();
0194
0195 ACTS_VERBOSE(std::boolalpha << "haveMaterial " << haveMaterial
0196 << ", haveMeasurement: " << haveMeasurement);
0197
0198
0199
0200
0201
0202
0203 if (!haveMaterial && !haveMeasurement) {
0204
0205 if (result.processedStates > 0 && surface.isSensitive()) {
0206 TemporaryStates tmpStates;
0207 noMeasurementUpdate(state, stepper, navigator, result, tmpStates, true);
0208 }
0209 return Result<void>::success();
0210 }
0211
0212
0213
0214
0215 if (haveMeasurement) {
0216 result.maxPathXOverX0.update();
0217 result.sumPathXOverX0.update();
0218 result.nInvalidBetheHeitler.update();
0219 }
0220
0221 for (auto cmp : stepper.componentIterable(state.stepping)) {
0222 cmp.singleStepper(stepper).transportCovarianceToBound(cmp.state(),
0223 surface);
0224 }
0225
0226 if (m_cfg.multipleScattering && haveMaterial) {
0227 if (haveMeasurement) {
0228 applyMultipleScattering(
0229 state, stepper, navigator,
0230 determineMaterialUpdateMode(state, navigator,
0231 MaterialUpdateMode::PreUpdate),
0232 logger());
0233 } else {
0234 applyMultipleScattering(
0235 state, stepper, navigator,
0236 determineMaterialUpdateMode(state, navigator,
0237 MaterialUpdateMode::FullUpdate),
0238 logger());
0239 }
0240 }
0241
0242
0243
0244
0245 if (!haveMaterial) {
0246 TemporaryStates tmpStates;
0247
0248 auto res = kalmanUpdate(state, stepper, navigator, result, tmpStates,
0249 foundSourceLink->second);
0250
0251 if (!res.ok()) {
0252 if (m_cfg.abortOnError) {
0253 std::abort();
0254 }
0255 return res.error();
0256 }
0257
0258 updateStepper(state, stepper, tmpStates, m_cfg.weightCutoff);
0259 }
0260
0261
0262
0263 else {
0264 TemporaryStates tmpStates;
0265 Result<void> res;
0266
0267 if (haveMeasurement) {
0268 res = kalmanUpdate(state, stepper, navigator, result, tmpStates,
0269 foundSourceLink->second);
0270 } else {
0271 res = noMeasurementUpdate(state, stepper, navigator, result, tmpStates,
0272 false);
0273 }
0274
0275 if (!res.ok()) {
0276 if (m_cfg.abortOnError) {
0277 std::abort();
0278 }
0279 return res.error();
0280 }
0281
0282
0283 std::vector<ComponentCache>& componentCache = result.componentCache;
0284 componentCache.clear();
0285
0286 convoluteComponents(state, stepper, navigator, tmpStates,
0287 *m_cfg.bethe_heitler_approx, result.betheHeitlerCache,
0288 m_cfg.weightCutoff, componentCache,
0289 result.nInvalidBetheHeitler, result.maxPathXOverX0,
0290 result.sumPathXOverX0, logger());
0291
0292 if (componentCache.empty()) {
0293 ACTS_WARNING(
0294 "No components left after applying energy loss. "
0295 "Is the weight cutoff "
0296 << m_cfg.weightCutoff << " too high?");
0297 ACTS_WARNING("Return to propagator without applying energy loss");
0298 return Result<void>::success();
0299 }
0300
0301
0302 const auto finalCmpNumber = std::min(
0303 static_cast<std::size_t>(stepper.maxComponents), m_cfg.maxComponents);
0304 m_cfg.extensions.mixtureReducer(componentCache, finalCmpNumber, surface);
0305
0306 removeLowWeightComponents(componentCache, m_cfg.weightCutoff);
0307
0308 updateStepper(state, stepper, navigator, componentCache, logger());
0309 }
0310
0311
0312 if (m_cfg.multipleScattering && haveMaterial && haveMeasurement) {
0313 applyMultipleScattering(
0314 state, stepper, navigator,
0315 determineMaterialUpdateMode(state, navigator,
0316 MaterialUpdateMode::PostUpdate),
0317 logger());
0318 }
0319
0320 return Result<void>::success();
0321 }
0322
0323 template <typename propagator_state_t, typename stepper_t,
0324 typename navigator_t>
0325 bool checkAbort(propagator_state_t& , const stepper_t& ,
0326 const navigator_t& , const result_type& result,
0327 const Logger& ) const {
0328 if (m_cfg.numberMeasurements &&
0329 result.measurementStates == m_cfg.numberMeasurements) {
0330 ACTS_VERBOSE("Stop navigation because all measurements are found");
0331 return true;
0332 }
0333
0334 return false;
0335 }
0336
0337
0338
0339 template <typename propagator_state_t, typename stepper_t,
0340 typename navigator_t>
0341 Result<void> kalmanUpdate(propagator_state_t& state, const stepper_t& stepper,
0342 const navigator_t& navigator, result_type& result,
0343 TemporaryStates& tmpStates,
0344 const SourceLink& sourceLink) const {
0345 const auto& surface = *navigator.currentSurface(state.navigation);
0346
0347
0348 std::vector<TrackIndexType> allTips;
0349 allTips.reserve(stepper.numberComponents(state.stepping));
0350
0351 for (auto cmp : stepper.componentIterable(state.stepping)) {
0352 auto singleState = cmp.singleState(state);
0353 const auto& singleStepper = cmp.singleStepper(stepper);
0354
0355
0356
0357 TrackStatePropMask mask =
0358 TrackStatePropMask::Predicted | TrackStatePropMask::Filtered |
0359 TrackStatePropMask::Jacobian | TrackStatePropMask::Calibrated;
0360 typename traj_t::TrackStateProxy trackStateProxy =
0361 tmpStates.traj.makeTrackState(mask, kTrackIndexInvalid);
0362 typename traj_t::ConstTrackStateProxy trackStateProxyConst{
0363 trackStateProxy};
0364
0365
0366
0367 {
0368 trackStateProxy.setReferenceSurface(surface.getSharedPtr());
0369
0370 auto res =
0371 singleStepper.boundState(singleState.stepping, surface, false);
0372 if (!res.ok()) {
0373 ACTS_DEBUG("Propagate to surface " << surface.geometryId()
0374 << " failed: " << res.error());
0375 return res.error();
0376 }
0377 const auto& [boundParams, jacobian, pathLength] = *res;
0378
0379
0380 trackStateProxy.predicted() = boundParams.parameters();
0381 trackStateProxy.predictedCovariance() = singleState.stepping.cov;
0382
0383 trackStateProxy.jacobian() = jacobian;
0384 trackStateProxy.pathLength() = pathLength;
0385 }
0386
0387
0388
0389 m_cfg.extensions.calibrator(state.geoContext, *m_cfg.calibrationContext,
0390 sourceLink, trackStateProxy);
0391
0392 if (!m_cfg.extensions.outlierFinder(trackStateProxyConst)) {
0393
0394 auto updateRes = m_cfg.extensions.updater(state.geoContext,
0395 trackStateProxy, logger());
0396 if (!updateRes.ok()) {
0397 ACTS_DEBUG("Update step failed: " << updateRes.error());
0398 return updateRes.error();
0399 }
0400
0401 tmpStates.tips.push_back(trackStateProxy.index());
0402 tmpStates.weights[trackStateProxy.index()] = cmp.weight();
0403 }
0404
0405 allTips.push_back(trackStateProxy.index());
0406 }
0407
0408 const bool isOutlier = tmpStates.tips.empty();
0409
0410 if (!isOutlier) {
0411 computePosteriorWeights(tmpStates.traj, tmpStates.tips,
0412 tmpStates.weights);
0413 normalizeWeights(tmpStates.tips, [&](auto idx) -> double& {
0414 return tmpStates.weights.at(idx);
0415 });
0416 } else {
0417 auto cmps = stepper.componentIterable(state.stepping);
0418 for (const auto [cmp, idx] : zip(cmps, allTips)) {
0419 typename traj_t::TrackStateProxy trackStateProxy =
0420 tmpStates.traj.getTrackState(idx);
0421
0422
0423
0424 trackStateProxy.shareFrom(trackStateProxy,
0425 TrackStatePropMask::Predicted,
0426 TrackStatePropMask::Filtered);
0427
0428 tmpStates.tips.push_back(trackStateProxy.index());
0429 tmpStates.weights[trackStateProxy.index()] = cmp.weight();
0430 }
0431 }
0432
0433
0434 ++result.processedStates;
0435 if (!isOutlier) {
0436 ++result.measurementStates;
0437 }
0438
0439 updateMultiTrajectory(
0440 result, tmpStates, surface,
0441 TrackStateType()
0442 .setHasParameters()
0443 .setHasMaterial(surface.surfaceMaterial() != nullptr)
0444 .setHasMeasurement()
0445 .setIsOutlier(isOutlier));
0446
0447 result.lastMeasurementTip = result.currentTip;
0448 result.lastMeasurementSurface = &surface;
0449
0450
0451
0452 result.lastMeasurementComponents.clear();
0453
0454 FiltProjector proj{tmpStates.traj, tmpStates.weights};
0455 for (const auto& idx : tmpStates.tips) {
0456 const auto& [w, p, c] = proj(idx);
0457
0458 if (w > 0.0) {
0459 result.lastMeasurementComponents.push_back({w, p, c});
0460 }
0461 }
0462
0463
0464 return Result<void>::success();
0465 }
0466
0467 template <typename propagator_state_t, typename stepper_t,
0468 typename navigator_t>
0469 Result<void> noMeasurementUpdate(propagator_state_t& state,
0470 const stepper_t& stepper,
0471 const navigator_t& navigator,
0472 result_type& result,
0473 TemporaryStates& tmpStates,
0474 bool doCovTransport) const {
0475 const Surface& surface = *navigator.currentSurface(state.navigation);
0476
0477 for (auto cmp : stepper.componentIterable(state.stepping)) {
0478 auto& singleState = cmp.state();
0479 const auto& singleStepper = cmp.singleStepper(stepper);
0480
0481
0482
0483 TrackStatePropMask mask =
0484 TrackStatePropMask::Predicted | TrackStatePropMask::Jacobian;
0485 typename traj_t::TrackStateProxy trackStateProxy =
0486 tmpStates.traj.makeTrackState(mask, kTrackIndexInvalid);
0487
0488
0489
0490 {
0491 trackStateProxy.setReferenceSurface(surface.getSharedPtr());
0492
0493 auto res =
0494 singleStepper.boundState(singleState, surface, doCovTransport);
0495 if (!res.ok()) {
0496 return res.error();
0497 }
0498 const auto& [boundParams, jacobian, pathLength] = *res;
0499
0500
0501 trackStateProxy.predicted() = boundParams.parameters();
0502 trackStateProxy.predictedCovariance() = singleState.cov;
0503
0504 trackStateProxy.jacobian() = jacobian;
0505 trackStateProxy.pathLength() = pathLength;
0506
0507
0508
0509 trackStateProxy.shareFrom(trackStateProxy,
0510 TrackStatePropMask::Predicted,
0511 TrackStatePropMask::Filtered);
0512 }
0513
0514 tmpStates.tips.push_back(trackStateProxy.index());
0515 tmpStates.weights[trackStateProxy.index()] = cmp.weight();
0516 }
0517
0518 const bool precedingMeasurementExists = result.processedStates > 0;
0519 const bool isHole = surface.isSensitive();
0520
0521
0522 ++result.processedStates;
0523 if (precedingMeasurementExists && isHole) {
0524 ++result.measurementHoles;
0525 }
0526
0527 updateMultiTrajectory(
0528 result, tmpStates, surface,
0529 TrackStateType()
0530 .setHasParameters()
0531 .setHasMaterial(surface.surfaceMaterial() != nullptr)
0532 .setIsHole(isHole));
0533
0534 return Result<void>::success();
0535 }
0536
0537 void updateMultiTrajectory(result_type& result,
0538 const TemporaryStates& tmpStates,
0539 const Surface& surface,
0540 TrackStateType type) const {
0541 using PrtProjector =
0542 MultiTrajectoryProjector<StatesType::ePredicted, traj_t>;
0543 using FltProjector =
0544 MultiTrajectoryProjector<StatesType::eFiltered, traj_t>;
0545
0546 if (!m_cfg.inReversePass) {
0547 assert(!tmpStates.tips.empty() &&
0548 "No components to update multi-trajectory");
0549
0550 const auto firstCmpProxy =
0551 tmpStates.traj.getTrackState(tmpStates.tips.front());
0552
0553 auto combinedStateMask = TrackStatePropMask::Predicted;
0554 if (type.isMeasurement()) {
0555 combinedStateMask |= TrackStatePropMask::Calibrated |
0556 TrackStatePropMask::Filtered |
0557 TrackStatePropMask::Smoothed;
0558 } else if (type.isOutlier()) {
0559 combinedStateMask |= TrackStatePropMask::Calibrated;
0560 }
0561 auto combinedState = result.fittedStates->makeTrackState(
0562 combinedStateMask, result.currentTip);
0563 result.currentTip = combinedState.index();
0564
0565
0566 auto copyMask = TrackStatePropMask::None;
0567 if (ACTS_CHECK_BIT(combinedStateMask, TrackStatePropMask::Calibrated)) {
0568
0569 copyMask |= TrackStatePropMask::Calibrated;
0570 }
0571 combinedState.copyFrom(firstCmpProxy, copyMask);
0572 combinedState.typeFlags() = type;
0573
0574 auto [prtMean, prtCov] =
0575 mergeGaussianMixture(tmpStates.tips, surface, m_cfg.mergeMethod,
0576 PrtProjector{tmpStates.traj, tmpStates.weights});
0577 combinedState.predicted() = prtMean;
0578 combinedState.predictedCovariance() = prtCov;
0579
0580 if (type.isMeasurement()) {
0581 auto [fltMean, fltCov] = mergeGaussianMixture(
0582 tmpStates.tips, surface, m_cfg.mergeMethod,
0583 FltProjector{tmpStates.traj, tmpStates.weights});
0584 combinedState.filtered() = fltMean;
0585 combinedState.filteredCovariance() = fltCov;
0586
0587
0588
0589 combinedState.smoothed() = BoundVector::Constant(-2);
0590 combinedState.smoothedCovariance() = BoundMatrix::Constant(-2);
0591 } else {
0592 combinedState.shareFrom(TrackStatePropMask::Predicted,
0593 TrackStatePropMask::Filtered);
0594 }
0595
0596 } else {
0597 assert((result.currentTip != kTrackIndexInvalid && "tip not valid"));
0598
0599 result.fittedStates->applyBackwards(
0600 result.currentTip, [&](auto trackState) {
0601 auto fSurface = &trackState.referenceSurface();
0602 if (fSurface == &surface) {
0603 result.surfacesVisitedBwdAgain.push_back(&surface);
0604
0605 if (trackState.hasSmoothed()) {
0606 const auto [smtMean, smtCov] = mergeGaussianMixture(
0607 tmpStates.tips, surface, m_cfg.mergeMethod,
0608 FltProjector{tmpStates.traj, tmpStates.weights});
0609
0610 trackState.smoothed() = smtMean;
0611 trackState.smoothedCovariance() = smtCov;
0612 }
0613 return false;
0614 }
0615 return true;
0616 });
0617 }
0618 }
0619
0620
0621
0622 void setOptions(const GsfOptions<traj_t>& options) {
0623 m_cfg.maxComponents = options.maxComponents;
0624 m_cfg.extensions = options.extensions;
0625 m_cfg.abortOnError = options.abortOnError;
0626 m_cfg.disableAllMaterialHandling = options.disableAllMaterialHandling;
0627 m_cfg.weightCutoff = options.weightCutoff;
0628 m_cfg.mergeMethod = options.componentMergeMethod;
0629 m_cfg.calibrationContext = &options.calibrationContext.get();
0630 }
0631 };
0632
0633 }