File indexing completed on 2026-01-09 09:26:48
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp"
0010 #include "Acts/TrackFitting/BetheHeitlerApprox.hpp"
0011 #include "Acts/TrackFitting/GsfOptions.hpp"
0012 #include "Acts/Utilities/Logger.hpp"
0013 #include "ActsExamples/EventData/MeasurementCalibration.hpp"
0014 #include "ActsExamples/EventData/ScalingCalibrator.hpp"
0015 #include "ActsExamples/TrackFitting/RefittingAlgorithm.hpp"
0016 #include "ActsExamples/TrackFitting/TrackFitterFunction.hpp"
0017 #include "ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp"
0018 #include "ActsPython/Utilities/Helpers.hpp"
0019 #include "ActsPython/Utilities/Macros.hpp"
0020
0021 #include <cstddef>
0022 #include <memory>
0023 #include <utility>
0024
0025 #include <pybind11/pybind11.h>
0026 #include <pybind11/stl.h>
0027
0028 namespace py = pybind11;
0029
0030 using namespace ActsExamples;
0031 using namespace Acts;
0032 using namespace py::literals;
0033
0034 namespace ActsPython {
0035
0036 void addTrackFitting(py::module& mex) {
0037 ACTS_PYTHON_DECLARE_ALGORITHM(
0038 TrackFittingAlgorithm, mex, "TrackFittingAlgorithm", inputMeasurements,
0039 inputProtoTracks, inputInitialTrackParameters, inputClusters,
0040 outputTracks, fit, pickTrack, calibrator);
0041
0042 ACTS_PYTHON_DECLARE_ALGORITHM(RefittingAlgorithm, mex, "RefittingAlgorithm",
0043 inputTracks, outputTracks, fit, pickTrack,
0044 initialVarInflation);
0045
0046 {
0047 py::class_<TrackFitterFunction, std::shared_ptr<TrackFitterFunction>>(
0048 mex, "TrackFitterFunction");
0049
0050 mex.def(
0051 "makeKalmanFitterFunction",
0052 [](std::shared_ptr<const TrackingGeometry> trackingGeometry,
0053 std::shared_ptr<const MagneticFieldProvider> magneticField,
0054 bool multipleScattering, bool energyLoss,
0055 double reverseFilteringMomThreshold,
0056 double reverseFilteringCovarianceScaling,
0057 FreeToBoundCorrection freeToBoundCorrection, double chi2Cut,
0058 Logging::Level level) {
0059 return makeKalmanFitterFunction(
0060 std::move(trackingGeometry), std::move(magneticField),
0061 multipleScattering, energyLoss, reverseFilteringMomThreshold,
0062 reverseFilteringCovarianceScaling, freeToBoundCorrection, chi2Cut,
0063 *getDefaultLogger("Kalman", level));
0064 },
0065 "trackingGeometry"_a, "magneticField"_a, "multipleScattering"_a,
0066 "energyLoss"_a, "reverseFilteringMomThreshold"_a,
0067 "reverseFilteringCovarianceScaling"_a, "freeToBoundCorrection"_a,
0068 "chi2Cut"_a, "level"_a);
0069
0070 py::class_<MeasurementCalibrator, std::shared_ptr<MeasurementCalibrator>>(
0071 mex, "MeasurementCalibrator");
0072
0073 mex.def("makePassThroughCalibrator",
0074 []() -> std::shared_ptr<MeasurementCalibrator> {
0075 return std::make_shared<PassThroughCalibrator>();
0076 });
0077
0078 mex.def(
0079 "makeScalingCalibrator",
0080 [](const char* path) -> std::shared_ptr<MeasurementCalibrator> {
0081 return std::make_shared<ScalingCalibrator>(path);
0082 },
0083 py::arg("path"));
0084
0085 py::enum_<ComponentMergeMethod>(mex, "ComponentMergeMethod")
0086 .value("mean", ComponentMergeMethod::eMean)
0087 .value("maxWeight", ComponentMergeMethod::eMaxWeight);
0088
0089 py::enum_<MixtureReductionAlgorithm>(mex, "MixtureReductionAlgorithm")
0090 .value("weightCut", MixtureReductionAlgorithm::weightCut)
0091 .value("KLDistance", MixtureReductionAlgorithm::KLDistance);
0092
0093 py::class_<BetheHeitlerApprox, std::shared_ptr<BetheHeitlerApprox>>(
0094 mex, "BetheHeitlerApprox");
0095 py::class_<AtlasBetheHeitlerApprox, BetheHeitlerApprox,
0096 std::shared_ptr<AtlasBetheHeitlerApprox>>(
0097 mex, "AtlasBetheHeitlerApprox")
0098 .def_static("loadFromFiles", &AtlasBetheHeitlerApprox::loadFromFiles,
0099 "lowParametersPath"_a, "highParametersPath"_a, "lowLimit"_a,
0100 "highLimit"_a, "clampToRange"_a, "noChangeLimit"_a,
0101 "singleGaussianLimit"_a)
0102 .def_static(
0103 "makeDefault",
0104 [](bool clampToRange) {
0105 return makeDefaultBetheHeitlerApprox(clampToRange);
0106 },
0107 "clampToRange"_a);
0108
0109 mex.def(
0110 "makeGsfFitterFunction",
0111 [](std::shared_ptr<const TrackingGeometry> trackingGeometry,
0112 std::shared_ptr<const MagneticFieldProvider> magneticField,
0113 const std::shared_ptr<const BetheHeitlerApprox>& betheHeitlerApprox,
0114 std::size_t maxComponents, double weightCutoff,
0115 ComponentMergeMethod componentMergeMethod,
0116 MixtureReductionAlgorithm mixtureReductionAlgorithm,
0117 double reverseFilteringCovarianceScaling, Logging::Level level) {
0118 return makeGsfFitterFunction(
0119 std::move(trackingGeometry), std::move(magneticField),
0120 betheHeitlerApprox, maxComponents, weightCutoff,
0121 componentMergeMethod, mixtureReductionAlgorithm,
0122 reverseFilteringCovarianceScaling,
0123 *getDefaultLogger("GSFFunc", level));
0124 },
0125 "trackingGeometry"_a, "magneticField"_a, "betheHeitlerApprox"_a,
0126 "maxComponents"_a, "weightCutoff"_a, "componentMergeMethod"_a,
0127 "mixtureReductionAlgorithm"_a, "reverseFilteringCovarianceScaling"_a,
0128 "level"_a);
0129
0130 mex.def(
0131 "makeGlobalChiSquareFitterFunction",
0132 [](std::shared_ptr<const TrackingGeometry> trackingGeometry,
0133 std::shared_ptr<const MagneticFieldProvider> magneticField,
0134 bool multipleScattering, bool energyLoss,
0135 FreeToBoundCorrection freeToBoundCorrection, std::size_t nUpdateMax,
0136 double relChi2changeCutOff, Logging::Level level) {
0137 return makeGlobalChiSquareFitterFunction(
0138 std::move(trackingGeometry), std::move(magneticField),
0139 multipleScattering, energyLoss, freeToBoundCorrection, nUpdateMax,
0140 relChi2changeCutOff, *getDefaultLogger("Gx2f", level));
0141 },
0142 py::arg("trackingGeometry"), py::arg("magneticField"),
0143 py::arg("multipleScattering"), py::arg("energyLoss"),
0144 py::arg("freeToBoundCorrection"), py::arg("nUpdateMax"),
0145 py::arg("relChi2changeCutOff"), py::arg("level"));
0146 }
0147
0148 {
0149 py::class_<FreeToBoundCorrection>(mex, "FreeToBoundCorrection")
0150 .def(py::init<>())
0151 .def(py::init<bool>(), py::arg("apply") = false)
0152 .def(py::init<bool, double, double>(), py::arg("apply") = false,
0153 py::arg("alpha") = 0.1, py::arg("beta") = 2);
0154 }
0155 }
0156
0157 }