File indexing completed on 2026-05-19 07:33:57
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/Definitions/Common.hpp"
0012 #include "Acts/Definitions/TrackParametrization.hpp"
0013 #include "Acts/EventData/MultiTrajectory.hpp"
0014 #include "Acts/EventData/Types.hpp"
0015 #include "Acts/Propagator/detail/PointwiseMaterialInteraction.hpp"
0016 #include "Acts/Surfaces/Surface.hpp"
0017 #include "Acts/TrackFitting/BetheHeitlerApprox.hpp"
0018 #include "Acts/TrackFitting/GsfComponent.hpp"
0019 #include "Acts/TrackFitting/GsfOptions.hpp"
0020 #include "Acts/TrackFitting/detail/GsfComponentMerging.hpp"
0021 #include "Acts/TrackFitting/detail/GsfUtils.hpp"
0022 #include "Acts/Utilities/Helpers.hpp"
0023
0024 #include <map>
0025
0026 namespace Acts::detail::Gsf {
0027
0028 template <typename traj_t>
0029 struct GsfResult {
0030
0031 traj_t* fittedStates{nullptr};
0032
0033
0034 TrackIndexType currentTip = kTrackIndexInvalid;
0035
0036
0037 TrackIndexType lastMeasurementTip = kTrackIndexInvalid;
0038
0039
0040
0041 std::vector<std::tuple<double, BoundVector, BoundMatrix>>
0042 lastMeasurementComponents;
0043
0044
0045 const Acts::Surface* lastMeasurementSurface = nullptr;
0046
0047
0048 std::size_t measurementStates = 0;
0049 std::size_t measurementHoles = 0;
0050 std::size_t processedStates = 0;
0051
0052 std::vector<const Surface*> visitedSurfaces;
0053 std::vector<const Surface*> surfacesVisitedBwdAgain;
0054
0055
0056 Updatable<std::size_t> nInvalidBetheHeitler;
0057 Updatable<double> maxPathXOverX0;
0058 Updatable<double> sumPathXOverX0;
0059
0060
0061 std::vector<BetheHeitlerApprox::Component> betheHeitlerCache;
0062
0063
0064 std::vector<GsfComponent> componentCache;
0065 };
0066
0067
0068 template <typename traj_t>
0069 struct GsfActor {
0070
0071 GsfActor() = default;
0072
0073
0074 using result_type = GsfResult<traj_t>;
0075
0076
0077 struct Config {
0078
0079 std::size_t maxComponents = 16;
0080
0081
0082 const std::map<GeometryIdentifier, SourceLink>* inputMeasurements = nullptr;
0083
0084
0085
0086 const BetheHeitlerApprox* bethe_heitler_approx = nullptr;
0087
0088
0089 bool multipleScattering = true;
0090
0091
0092 double weightCutoff = 1.0e-4;
0093
0094
0095
0096
0097 bool disableAllMaterialHandling = false;
0098
0099
0100 bool abortOnError = false;
0101
0102
0103
0104 std::optional<std::size_t> numberMeasurements;
0105
0106
0107 GsfExtensions<traj_t> extensions;
0108
0109
0110
0111
0112 bool inReversePass = false;
0113
0114
0115 ComponentMergeMethod mergeMethod = ComponentMergeMethod::eMaxWeight;
0116
0117 const Logger* logger{nullptr};
0118
0119
0120 const CalibrationContext* calibrationContext{nullptr};
0121
0122 } m_cfg;
0123
0124 const Logger& logger() const { return *m_cfg.logger; }
0125
0126 using TemporaryStates = detail::Gsf::TemporaryStates<traj_t>;
0127
0128 using FiltProjector = MultiTrajectoryProjector<StatesType::eFiltered, traj_t>;
0129
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139 template <typename propagator_state_t, typename stepper_t,
0140 typename navigator_t>
0141 Result<void> act(propagator_state_t& state, const stepper_t& stepper,
0142 const navigator_t& navigator, result_type& result,
0143 const Logger& ) const {
0144 assert(result.fittedStates && "No MultiTrajectory set");
0145
0146
0147
0148 const ScopedGsfInfoPrinterAndChecker printer(state, stepper, navigator,
0149 logger());
0150
0151
0152 if (navigator.currentSurface(state.navigation) == nullptr) {
0153 return Result<void>::success();
0154 }
0155
0156 const auto& surface = *navigator.currentSurface(state.navigation);
0157 ACTS_VERBOSE("Step is at surface " << surface.geometryId());
0158
0159
0160
0161 [[maybe_unused]] auto stepperComponents =
0162 stepper.constComponentIterable(state.stepping);
0163 assert(weightsAreNormalized(stepperComponents,
0164 [](const auto& cmp) { return cmp.weight(); }));
0165
0166
0167
0168
0169 using Status [[maybe_unused]] = IntersectionStatus;
0170 assert(std::all_of(
0171 stepperComponents.begin(), stepperComponents.end(),
0172 [](const auto& cmp) { return cmp.status() == Status::onSurface; }));
0173
0174
0175
0176 const bool visited = rangeContainsValue(result.visitedSurfaces, &surface);
0177
0178 if (visited) {
0179 ACTS_VERBOSE("Already visited surface, return");
0180 return Result<void>::success();
0181 }
0182
0183 result.visitedSurfaces.push_back(&surface);
0184
0185
0186 const auto foundSourceLink =
0187 m_cfg.inputMeasurements->find(surface.geometryId());
0188 const bool haveMaterial =
0189 surface.surfaceMaterial() && !m_cfg.disableAllMaterialHandling;
0190 const bool haveMeasurement =
0191 foundSourceLink != m_cfg.inputMeasurements->end();
0192
0193 ACTS_VERBOSE(std::boolalpha << "haveMaterial " << haveMaterial
0194 << ", haveMeasurement: " << haveMeasurement);
0195
0196
0197
0198
0199
0200
0201 if (!haveMaterial && !haveMeasurement) {
0202
0203 if (result.processedStates > 0 && surface.isSensitive()) {
0204 TemporaryStates tmpStates;
0205 noMeasurementUpdate(state, stepper, surface, result, tmpStates, true);
0206 }
0207 return Result<void>::success();
0208 }
0209
0210
0211
0212
0213 if (haveMeasurement) {
0214 result.maxPathXOverX0.update();
0215 result.sumPathXOverX0.update();
0216 result.nInvalidBetheHeitler.update();
0217 }
0218
0219 for (auto cmp : stepper.componentIterable(state.stepping)) {
0220 cmp.singleStepper(stepper).transportCovarianceToBound(cmp.state(),
0221 surface);
0222 }
0223
0224 if (m_cfg.multipleScattering && haveMaterial) {
0225 if (haveMeasurement) {
0226 applyMultipleScattering(
0227 state, stepper, surface,
0228 determineMaterialUpdateMode(state, navigator,
0229 MaterialUpdateMode::PreUpdate),
0230 logger());
0231 } else {
0232 applyMultipleScattering(
0233 state, stepper, surface,
0234 determineMaterialUpdateMode(state, navigator,
0235 MaterialUpdateMode::FullUpdate),
0236 logger());
0237 }
0238 }
0239
0240
0241
0242
0243 if (!haveMaterial) {
0244 TemporaryStates tmpStates;
0245
0246 auto res = kalmanUpdate(state, stepper, surface, result, tmpStates,
0247 foundSourceLink->second);
0248
0249 if (!res.ok()) {
0250 if (m_cfg.abortOnError) {
0251 std::abort();
0252 }
0253 return res.error();
0254 }
0255
0256 updateStepper(state, stepper, tmpStates, m_cfg.weightCutoff);
0257 }
0258
0259
0260
0261 else {
0262 TemporaryStates tmpStates;
0263 Result<void> res;
0264
0265 if (haveMeasurement) {
0266 res = kalmanUpdate(state, stepper, surface, result, tmpStates,
0267 foundSourceLink->second);
0268 } else {
0269 res = noMeasurementUpdate(state, stepper, surface, result, tmpStates,
0270 false);
0271 }
0272
0273 if (!res.ok()) {
0274 if (m_cfg.abortOnError) {
0275 std::abort();
0276 }
0277 return res.error();
0278 }
0279
0280
0281 std::vector<GsfComponent>& componentCache = result.componentCache;
0282 componentCache.clear();
0283
0284 convoluteComponents(
0285 state, stepper, tmpStates, *m_cfg.bethe_heitler_approx,
0286 result.betheHeitlerCache, m_cfg.weightCutoff, componentCache,
0287 result.nInvalidBetheHeitler.tmp(), result.maxPathXOverX0.tmp(),
0288 result.sumPathXOverX0.tmp(), logger());
0289
0290 if (componentCache.empty()) {
0291 ACTS_WARNING(
0292 "No components left after applying energy loss. "
0293 "Is the weight cutoff "
0294 << m_cfg.weightCutoff << " too high?");
0295 ACTS_WARNING("Return to propagator without applying energy loss");
0296 return Result<void>::success();
0297 }
0298
0299
0300 const auto finalCmpNumber = std::min(
0301 static_cast<std::size_t>(stepper.maxComponents), m_cfg.maxComponents);
0302 m_cfg.extensions.mixtureReducer(componentCache, finalCmpNumber, surface);
0303
0304 removeLowWeightComponents(componentCache, m_cfg.weightCutoff);
0305
0306 updateStepper(state, stepper, surface, componentCache, logger());
0307 }
0308
0309
0310 if (m_cfg.multipleScattering && haveMaterial && haveMeasurement) {
0311 applyMultipleScattering(
0312 state, stepper, surface,
0313 determineMaterialUpdateMode(state, navigator,
0314 MaterialUpdateMode::PostUpdate),
0315 logger());
0316 }
0317
0318 return Result<void>::success();
0319 }
0320
0321 template <typename propagator_state_t, typename stepper_t,
0322 typename navigator_t>
0323 bool checkAbort(propagator_state_t& , const stepper_t& ,
0324 const navigator_t& , const result_type& result,
0325 const Logger& ) const {
0326 if (m_cfg.numberMeasurements &&
0327 result.measurementStates == m_cfg.numberMeasurements) {
0328 ACTS_VERBOSE("Stop navigation because all measurements are found");
0329 return true;
0330 }
0331
0332 return false;
0333 }
0334
0335
0336
0337 template <typename propagator_state_t, typename stepper_t>
0338 Result<void> kalmanUpdate(propagator_state_t& state, const stepper_t& stepper,
0339 const Surface& surface, result_type& result,
0340 TemporaryStates& tmpStates,
0341 const SourceLink& sourceLink) const {
0342
0343 std::vector<TrackIndexType> allTips;
0344 allTips.reserve(stepper.numberComponents(state.stepping));
0345
0346 for (auto cmp : stepper.componentIterable(state.stepping)) {
0347 auto singleState = cmp.singleState(state);
0348 const auto& singleStepper = cmp.singleStepper(stepper);
0349
0350
0351
0352 TrackStatePropMask mask =
0353 TrackStatePropMask::Predicted | TrackStatePropMask::Filtered |
0354 TrackStatePropMask::Jacobian | TrackStatePropMask::Calibrated;
0355 typename traj_t::TrackStateProxy trackStateProxy =
0356 tmpStates.traj.makeTrackState(mask, kTrackIndexInvalid);
0357 typename traj_t::ConstTrackStateProxy trackStateProxyConst{
0358 trackStateProxy};
0359
0360
0361
0362 {
0363 trackStateProxy.setReferenceSurface(surface.getSharedPtr());
0364
0365 auto res =
0366 singleStepper.boundState(singleState.stepping, surface, false);
0367 if (!res.ok()) {
0368 ACTS_DEBUG("Propagate to surface " << surface.geometryId()
0369 << " failed: " << res.error());
0370 return res.error();
0371 }
0372 const auto& [boundParams, jacobian, pathLength] = *res;
0373
0374
0375 trackStateProxy.predicted() = boundParams.parameters();
0376 trackStateProxy.predictedCovariance() = singleState.stepping.cov;
0377
0378 trackStateProxy.jacobian() = jacobian;
0379 trackStateProxy.pathLength() = pathLength;
0380 }
0381
0382
0383
0384 m_cfg.extensions.calibrator(state.geoContext, *m_cfg.calibrationContext,
0385 sourceLink, trackStateProxy);
0386
0387 if (!m_cfg.extensions.outlierFinder(trackStateProxyConst)) {
0388
0389 auto updateRes = m_cfg.extensions.updater(state.geoContext,
0390 trackStateProxy, logger());
0391 if (!updateRes.ok()) {
0392 ACTS_DEBUG("Update step failed: " << updateRes.error());
0393 return updateRes.error();
0394 }
0395
0396 tmpStates.tips.push_back(trackStateProxy.index());
0397 tmpStates.weights[trackStateProxy.index()] = cmp.weight();
0398 }
0399
0400 allTips.push_back(trackStateProxy.index());
0401 }
0402
0403 const bool isOutlier = tmpStates.tips.empty();
0404
0405 if (!isOutlier) {
0406 computePosteriorWeights(tmpStates.traj, tmpStates.tips,
0407 tmpStates.weights);
0408 normalizeWeights(tmpStates.tips, [&](auto idx) -> double& {
0409 return tmpStates.weights.at(idx);
0410 });
0411 } else {
0412 auto cmps = stepper.componentIterable(state.stepping);
0413 for (const auto [cmp, idx] : zip(cmps, allTips)) {
0414 typename traj_t::TrackStateProxy trackStateProxy =
0415 tmpStates.traj.getTrackState(idx);
0416
0417
0418
0419 trackStateProxy.shareFrom(trackStateProxy,
0420 TrackStatePropMask::Predicted,
0421 TrackStatePropMask::Filtered);
0422
0423 tmpStates.tips.push_back(trackStateProxy.index());
0424 tmpStates.weights[trackStateProxy.index()] = cmp.weight();
0425 }
0426 }
0427
0428
0429 ++result.processedStates;
0430 if (!isOutlier) {
0431 ++result.measurementStates;
0432 }
0433
0434 updateMultiTrajectory(
0435 result, tmpStates, surface,
0436 TrackStateType()
0437 .setHasParameters()
0438 .setHasMaterial(surface.surfaceMaterial() != nullptr)
0439 .setHasMeasurement()
0440 .setIsOutlier(isOutlier));
0441
0442 result.lastMeasurementTip = result.currentTip;
0443 result.lastMeasurementSurface = &surface;
0444
0445
0446
0447 result.lastMeasurementComponents.clear();
0448
0449 FiltProjector proj{tmpStates.traj, tmpStates.weights};
0450 for (const auto& idx : tmpStates.tips) {
0451 const auto& [w, p, c] = proj(idx);
0452
0453 if (w > 0.0) {
0454 result.lastMeasurementComponents.push_back({w, p, c});
0455 }
0456 }
0457
0458
0459 return Result<void>::success();
0460 }
0461
0462 template <typename propagator_state_t, typename stepper_t>
0463 Result<void> noMeasurementUpdate(propagator_state_t& state,
0464 const stepper_t& stepper,
0465 const Surface& surface, result_type& result,
0466 TemporaryStates& tmpStates,
0467 bool doCovTransport) const {
0468 for (auto cmp : stepper.componentIterable(state.stepping)) {
0469 auto& singleState = cmp.state();
0470 const auto& singleStepper = cmp.singleStepper(stepper);
0471
0472
0473
0474 TrackStatePropMask mask =
0475 TrackStatePropMask::Predicted | TrackStatePropMask::Jacobian;
0476 typename traj_t::TrackStateProxy trackStateProxy =
0477 tmpStates.traj.makeTrackState(mask, kTrackIndexInvalid);
0478
0479
0480
0481 {
0482 trackStateProxy.setReferenceSurface(surface.getSharedPtr());
0483
0484 auto res =
0485 singleStepper.boundState(singleState, surface, doCovTransport);
0486 if (!res.ok()) {
0487 return res.error();
0488 }
0489 const auto& [boundParams, jacobian, pathLength] = *res;
0490
0491
0492 trackStateProxy.predicted() = boundParams.parameters();
0493 trackStateProxy.predictedCovariance() = singleState.cov;
0494
0495 trackStateProxy.jacobian() = jacobian;
0496 trackStateProxy.pathLength() = pathLength;
0497
0498
0499
0500 trackStateProxy.shareFrom(trackStateProxy,
0501 TrackStatePropMask::Predicted,
0502 TrackStatePropMask::Filtered);
0503 }
0504
0505 tmpStates.tips.push_back(trackStateProxy.index());
0506 tmpStates.weights[trackStateProxy.index()] = cmp.weight();
0507 }
0508
0509 const bool precedingMeasurementExists = result.processedStates > 0;
0510 const bool isHole = surface.isSensitive();
0511
0512
0513 ++result.processedStates;
0514 if (precedingMeasurementExists && isHole) {
0515 ++result.measurementHoles;
0516 }
0517
0518 updateMultiTrajectory(
0519 result, tmpStates, surface,
0520 TrackStateType()
0521 .setHasParameters()
0522 .setHasMaterial(surface.surfaceMaterial() != nullptr)
0523 .setIsHole(isHole));
0524
0525 return Result<void>::success();
0526 }
0527
0528 void updateMultiTrajectory(result_type& result,
0529 const TemporaryStates& tmpStates,
0530 const Surface& surface,
0531 TrackStateType type) const {
0532 using PrtProjector =
0533 MultiTrajectoryProjector<StatesType::ePredicted, traj_t>;
0534 using FltProjector =
0535 MultiTrajectoryProjector<StatesType::eFiltered, traj_t>;
0536
0537 if (!m_cfg.inReversePass) {
0538 assert(!tmpStates.tips.empty() &&
0539 "No components to update multi-trajectory");
0540
0541 const auto firstCmpProxy =
0542 tmpStates.traj.getTrackState(tmpStates.tips.front());
0543
0544 auto combinedStateMask = TrackStatePropMask::Predicted;
0545 if (type.isMeasurement()) {
0546 combinedStateMask |= TrackStatePropMask::Calibrated |
0547 TrackStatePropMask::Filtered |
0548 TrackStatePropMask::Smoothed;
0549 } else if (type.isOutlier()) {
0550 combinedStateMask |= TrackStatePropMask::Calibrated;
0551 }
0552 auto combinedState = result.fittedStates->makeTrackState(
0553 combinedStateMask, result.currentTip);
0554 result.currentTip = combinedState.index();
0555
0556
0557 auto copyMask = TrackStatePropMask::None;
0558 if (ACTS_CHECK_BIT(combinedStateMask, TrackStatePropMask::Calibrated)) {
0559
0560 copyMask |= TrackStatePropMask::Calibrated;
0561 }
0562 combinedState.copyFrom(firstCmpProxy, copyMask);
0563 combinedState.typeFlags() = type;
0564
0565 auto [prtMean, prtCov] = mergeGaussianMixture(
0566 tmpStates.tips, PrtProjector{tmpStates.traj, tmpStates.weights},
0567 surface, m_cfg.mergeMethod);
0568 combinedState.predicted() = prtMean;
0569 combinedState.predictedCovariance() = prtCov;
0570
0571 if (type.isMeasurement()) {
0572 auto [fltMean, fltCov] = mergeGaussianMixture(
0573 tmpStates.tips, FltProjector{tmpStates.traj, tmpStates.weights},
0574 surface, m_cfg.mergeMethod);
0575 combinedState.filtered() = fltMean;
0576 combinedState.filteredCovariance() = fltCov;
0577
0578
0579
0580 combinedState.smoothed() = BoundVector::Constant(-2);
0581 combinedState.smoothedCovariance() = BoundMatrix::Constant(-2);
0582 } else {
0583 combinedState.shareFrom(TrackStatePropMask::Predicted,
0584 TrackStatePropMask::Filtered);
0585 }
0586
0587 } else {
0588 assert((result.currentTip != kTrackIndexInvalid && "tip not valid"));
0589
0590 result.fittedStates->applyBackwards(
0591 result.currentTip, [&](auto trackState) {
0592 if (&trackState.referenceSurface() != &surface) {
0593 return true;
0594 }
0595
0596 result.surfacesVisitedBwdAgain.push_back(&surface);
0597
0598 if (trackState.hasSmoothed()) {
0599 const auto [smtMean, smtCov] = mergeGaussianMixture(
0600 tmpStates.tips,
0601 FltProjector{tmpStates.traj, tmpStates.weights}, surface,
0602 m_cfg.mergeMethod);
0603
0604 trackState.smoothed() = smtMean;
0605 trackState.smoothedCovariance() = smtCov;
0606 }
0607
0608 return false;
0609 });
0610 }
0611 }
0612
0613
0614
0615 void setOptions(const GsfOptions<traj_t>& options) {
0616 m_cfg.maxComponents = options.maxComponents;
0617 m_cfg.extensions = options.extensions;
0618 m_cfg.abortOnError = options.abortOnError;
0619 m_cfg.disableAllMaterialHandling = options.disableAllMaterialHandling;
0620 m_cfg.weightCutoff = options.weightCutoff;
0621 m_cfg.mergeMethod = options.componentMergeMethod;
0622 m_cfg.calibrationContext = &options.calibrationContext.get();
0623 }
0624 };
0625
0626 }