File indexing completed on 2026-04-08 07:47:21
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/EventData/MultiTrajectory.hpp"
0010 #include "Acts/EventData/TrackContainer.hpp"
0011 #include "Acts/EventData/VectorMultiTrajectory.hpp"
0012 #include "Acts/EventData/VectorTrackContainer.hpp"
0013 #include "Acts/Geometry/GeometryIdentifier.hpp"
0014 #include "Acts/Propagator/DirectNavigator.hpp"
0015 #include "Acts/Propagator/MultiEigenStepperLoop.hpp"
0016 #include "Acts/Propagator/Navigator.hpp"
0017 #include "Acts/Propagator/Propagator.hpp"
0018 #include "Acts/TrackFitting/GainMatrixUpdater.hpp"
0019 #include "Acts/TrackFitting/GaussianSumFitter.hpp"
0020 #include "Acts/TrackFitting/GsfMixtureReduction.hpp"
0021 #include "Acts/TrackFitting/GsfOptions.hpp"
0022 #include "Acts/Utilities/Delegate.hpp"
0023 #include "Acts/Utilities/HashedString.hpp"
0024 #include "Acts/Utilities/Logger.hpp"
0025 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0026 #include "ActsExamples/EventData/MeasurementCalibration.hpp"
0027 #include "ActsExamples/EventData/Track.hpp"
0028 #include "ActsExamples/TrackFitting/RefittingCalibrator.hpp"
0029 #include "ActsExamples/TrackFitting/TrackFitterFunction.hpp"
0030
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
0037
0038 using namespace ActsExamples;
0039
0040 namespace {
0041
0042 using MultiStepper =
0043 Acts::MultiEigenStepperLoop<Acts::EigenStepperDefaultExtension,
0044 Acts::MaxWeightReducerLoop>;
0045 using Propagator = Acts::Propagator<MultiStepper, Acts::Navigator>;
0046 using DirectPropagator = Acts::Propagator<MultiStepper, Acts::DirectNavigator>;
0047
0048 using Fitter = Acts::GaussianSumFitter<Propagator, Acts::VectorMultiTrajectory>;
0049 using DirectFitter =
0050 Acts::GaussianSumFitter<DirectPropagator, Acts::VectorMultiTrajectory>;
0051 using TrackContainer =
0052 Acts::TrackContainer<Acts::VectorTrackContainer,
0053 Acts::VectorMultiTrajectory, std::shared_ptr>;
0054
0055 struct GsfFitterFunctionImpl final : public TrackFitterFunction {
0056 Fitter fitter;
0057 DirectFitter directFitter;
0058
0059 Acts::GainMatrixUpdater updater;
0060
0061 std::size_t maxComponents = 0;
0062 double weightCutoff = 0;
0063 const double momentumCutoff = 0;
0064 bool abortOnError = false;
0065 bool disableAllMaterialHandling = false;
0066 MixtureReductionAlgorithm reductionAlg =
0067 MixtureReductionAlgorithm::KLDistance;
0068 Acts::ComponentMergeMethod mergeMethod =
0069 Acts::ComponentMergeMethod::eMaxWeight;
0070 double reverseFilteringCovarianceScaling = 100.0;
0071
0072 IndexSourceLink::SurfaceAccessor m_slSurfaceAccessor;
0073
0074 GsfFitterFunctionImpl(Fitter&& f, DirectFitter&& df,
0075 const Acts::TrackingGeometry& trkGeo)
0076 : fitter(std::move(f)),
0077 directFitter(std::move(df)),
0078 m_slSurfaceAccessor{trkGeo} {}
0079
0080 template <typename calibrator_t>
0081 auto makeGsfOptions(const GeneralFitterOptions& options,
0082 const calibrator_t& calibrator) const {
0083 Acts::GsfExtensions<Acts::VectorMultiTrajectory> extensions;
0084 extensions.updater.connect<
0085 &Acts::GainMatrixUpdater::operator()<Acts::VectorMultiTrajectory>>(
0086 &updater);
0087
0088 Acts::GsfOptions<Acts::VectorMultiTrajectory> gsfOptions{
0089 options.geoContext, options.magFieldContext,
0090 options.calibrationContext};
0091 gsfOptions.extensions = extensions;
0092 gsfOptions.propagatorPlainOptions = options.propOptions;
0093 gsfOptions.referenceSurface = options.referenceSurface;
0094 gsfOptions.maxComponents = maxComponents;
0095 gsfOptions.weightCutoff = weightCutoff;
0096 gsfOptions.abortOnError = abortOnError;
0097 gsfOptions.disableAllMaterialHandling = disableAllMaterialHandling;
0098 gsfOptions.componentMergeMethod = mergeMethod;
0099 gsfOptions.reverseFilteringCovarianceScaling =
0100 reverseFilteringCovarianceScaling;
0101
0102 gsfOptions.extensions.calibrator.connect<&calibrator_t::calibrate>(
0103 &calibrator);
0104
0105 if (options.doRefit) {
0106 gsfOptions.extensions.surfaceAccessor
0107 .connect<&RefittingCalibrator::accessSurface>();
0108 } else {
0109 gsfOptions.extensions.surfaceAccessor
0110 .connect<&IndexSourceLink::SurfaceAccessor::operator()>(
0111 &m_slSurfaceAccessor);
0112 }
0113 switch (reductionAlg) {
0114 case MixtureReductionAlgorithm::weightCut: {
0115 gsfOptions.extensions.mixtureReducer
0116 .connect<&Acts::reduceMixtureLargestWeights>();
0117 } break;
0118 case MixtureReductionAlgorithm::KLDistance: {
0119 gsfOptions.extensions.mixtureReducer
0120 .connect<&Acts::reduceMixtureWithKLDistance>();
0121 } break;
0122 case MixtureReductionAlgorithm::KLDistanceNaive: {
0123 gsfOptions.extensions.mixtureReducer
0124 .connect<&Acts::reduceMixtureWithKLDistanceNaive>();
0125 } break;
0126 }
0127
0128 return gsfOptions;
0129 }
0130
0131 TrackFitterResult operator()(const std::vector<Acts::SourceLink>& sourceLinks,
0132 const TrackParameters& initialParameters,
0133 const GeneralFitterOptions& options,
0134 const MeasurementCalibratorAdapter& calibrator,
0135 TrackContainer& tracks) const override {
0136 const auto gsfOptions = makeGsfOptions(options, calibrator);
0137
0138 using namespace Acts::GsfConstants;
0139 if (!tracks.hasColumn(Acts::hashString(kFinalMultiComponentStateColumn))) {
0140 std::string key(kFinalMultiComponentStateColumn);
0141 tracks.template addColumn<FinalMultiComponentState>(key);
0142 }
0143 if (!tracks.hasColumn(Acts::hashString(kFwdMaxMaterialXOverX0))) {
0144 tracks.template addColumn<double>(std::string(kFwdMaxMaterialXOverX0));
0145 }
0146 if (!tracks.hasColumn(Acts::hashString(kFwdSumMaterialXOverX0))) {
0147 tracks.template addColumn<double>(std::string(kFwdSumMaterialXOverX0));
0148 }
0149
0150 return fitter.fit(sourceLinks.begin(), sourceLinks.end(), initialParameters,
0151 gsfOptions, tracks);
0152 }
0153
0154 TrackFitterResult operator()(
0155 const std::vector<Acts::SourceLink>& sourceLinks,
0156 const TrackParameters& initialParameters,
0157 const GeneralFitterOptions& options,
0158 const RefittingCalibrator& calibrator,
0159 const std::vector<const Acts::Surface*>& surfaceSequence,
0160 TrackContainer& tracks) const override {
0161 const auto gsfOptions = makeGsfOptions(options, calibrator);
0162
0163 using namespace Acts::GsfConstants;
0164 if (!tracks.hasColumn(Acts::hashString(kFinalMultiComponentStateColumn))) {
0165 std::string key(kFinalMultiComponentStateColumn);
0166 tracks.template addColumn<FinalMultiComponentState>(key);
0167 }
0168 if (!tracks.hasColumn(Acts::hashString(kFwdMaxMaterialXOverX0))) {
0169 tracks.template addColumn<double>(std::string(kFwdMaxMaterialXOverX0));
0170 }
0171 if (!tracks.hasColumn(Acts::hashString(kFwdSumMaterialXOverX0))) {
0172 tracks.template addColumn<double>(std::string(kFwdSumMaterialXOverX0));
0173 }
0174
0175 return directFitter.fit(sourceLinks.begin(), sourceLinks.end(),
0176 initialParameters, gsfOptions, surfaceSequence,
0177 tracks);
0178 }
0179 };
0180
0181 }
0182
0183 std::shared_ptr<TrackFitterFunction> ActsExamples::makeGsfFitterFunction(
0184 std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0185 std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0186 const std::shared_ptr<const Acts::BetheHeitlerApprox>& betheHeitlerApprox,
0187 std::size_t maxComponents, double weightCutoff,
0188 Acts::ComponentMergeMethod componentMergeMethod,
0189 MixtureReductionAlgorithm mixtureReductionAlgorithm,
0190 double reverseFilteringCovarianceScaling, const Acts::Logger& logger) {
0191
0192 MultiStepper stepper(magneticField, logger.cloneWithSuffix("Step"));
0193 const auto& geo = *trackingGeometry;
0194 Acts::Navigator::Config cfg{std::move(trackingGeometry)};
0195 cfg.resolvePassive = false;
0196 cfg.resolveMaterial = true;
0197 cfg.resolveSensitive = true;
0198 Acts::Navigator navigator(cfg, logger.cloneWithSuffix("Navigator"));
0199 Propagator propagator(std::move(stepper), std::move(navigator),
0200 logger.cloneWithSuffix("Propagator"));
0201 Fitter trackFitter(std::move(propagator), betheHeitlerApprox,
0202 logger.cloneWithSuffix("GSF"));
0203
0204
0205 MultiStepper directStepper(std::move(magneticField),
0206 logger.cloneWithSuffix("Step"));
0207 Acts::DirectNavigator directNavigator{
0208 logger.cloneWithSuffix("DirectNavigator")};
0209 DirectPropagator directPropagator(std::move(directStepper),
0210 std::move(directNavigator),
0211 logger.cloneWithSuffix("DirectPropagator"));
0212 DirectFitter directTrackFitter(std::move(directPropagator),
0213 betheHeitlerApprox,
0214 logger.cloneWithSuffix("DirectGSF"));
0215
0216
0217 auto fitterFunction = std::make_shared<GsfFitterFunctionImpl>(
0218 std::move(trackFitter), std::move(directTrackFitter), geo);
0219 fitterFunction->maxComponents = maxComponents;
0220 fitterFunction->weightCutoff = weightCutoff;
0221 fitterFunction->mergeMethod = componentMergeMethod;
0222 fitterFunction->reductionAlg = mixtureReductionAlgorithm;
0223 fitterFunction->reverseFilteringCovarianceScaling =
0224 reverseFilteringCovarianceScaling;
0225
0226 return fitterFunction;
0227 }