File indexing completed on 2025-10-15 08:04:46
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp"
0010 #include "Acts/TrackFitting/BetheHeitlerApprox.hpp"
0011 #include "Acts/TrackFitting/GsfOptions.hpp"
0012 #include "Acts/Utilities/Logger.hpp"
0013 #include "ActsExamples/EventData/MeasurementCalibration.hpp"
0014 #include "ActsExamples/EventData/ScalingCalibrator.hpp"
0015 #include "ActsExamples/TrackFitting/RefittingAlgorithm.hpp"
0016 #include "ActsExamples/TrackFitting/TrackFitterFunction.hpp"
0017 #include "ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp"
0018 #include "ActsPython/Utilities/Helpers.hpp"
0019 #include "ActsPython/Utilities/Macros.hpp"
0020
0021 #include <cstddef>
0022 #include <memory>
0023
0024 #include <pybind11/pybind11.h>
0025 #include <pybind11/stl.h>
0026
0027 namespace py = pybind11;
0028
0029 using namespace ActsExamples;
0030 using namespace Acts;
0031 using namespace py::literals;
0032
0033 namespace ActsPython {
0034
0035 void addTrackFitting(Context& ctx) {
0036 auto mex = ctx.get("examples");
0037
0038 ACTS_PYTHON_DECLARE_ALGORITHM(
0039 TrackFittingAlgorithm, mex, "TrackFittingAlgorithm", inputMeasurements,
0040 inputProtoTracks, inputInitialTrackParameters, inputClusters,
0041 outputTracks, fit, pickTrack, calibrator);
0042
0043 ACTS_PYTHON_DECLARE_ALGORITHM(RefittingAlgorithm, mex, "RefittingAlgorithm",
0044 inputTracks, outputTracks, fit, pickTrack,
0045 initialVarInflation);
0046
0047 {
0048 py::class_<TrackFitterFunction, std::shared_ptr<TrackFitterFunction>>(
0049 mex, "TrackFitterFunction");
0050
0051 mex.def(
0052 "makeKalmanFitterFunction",
0053 [](std::shared_ptr<const TrackingGeometry> trackingGeometry,
0054 std::shared_ptr<const MagneticFieldProvider> magneticField,
0055 bool multipleScattering, bool energyLoss,
0056 double reverseFilteringMomThreshold,
0057 double reverseFilteringCovarianceScaling,
0058 FreeToBoundCorrection freeToBoundCorrection, double chi2Cut,
0059 Logging::Level level) {
0060 return makeKalmanFitterFunction(
0061 trackingGeometry, magneticField, multipleScattering, energyLoss,
0062 reverseFilteringMomThreshold, reverseFilteringCovarianceScaling,
0063 freeToBoundCorrection, chi2Cut,
0064 *getDefaultLogger("Kalman", level));
0065 },
0066 "trackingGeometry"_a, "magneticField"_a, "multipleScattering"_a,
0067 "energyLoss"_a, "reverseFilteringMomThreshold"_a,
0068 "reverseFilteringCovarianceScaling"_a, "freeToBoundCorrection"_a,
0069 "chi2Cut"_a, "level"_a);
0070
0071 py::class_<MeasurementCalibrator, std::shared_ptr<MeasurementCalibrator>>(
0072 mex, "MeasurementCalibrator");
0073
0074 mex.def("makePassThroughCalibrator",
0075 []() -> std::shared_ptr<MeasurementCalibrator> {
0076 return std::make_shared<PassThroughCalibrator>();
0077 });
0078
0079 mex.def(
0080 "makeScalingCalibrator",
0081 [](const char* path) -> std::shared_ptr<MeasurementCalibrator> {
0082 return std::make_shared<ScalingCalibrator>(path);
0083 },
0084 py::arg("path"));
0085
0086 py::enum_<ComponentMergeMethod>(mex, "ComponentMergeMethod")
0087 .value("mean", ComponentMergeMethod::eMean)
0088 .value("maxWeight", ComponentMergeMethod::eMaxWeight);
0089
0090 py::enum_<MixtureReductionAlgorithm>(mex, "MixtureReductionAlgorithm")
0091 .value("weightCut", MixtureReductionAlgorithm::weightCut)
0092 .value("KLDistance", MixtureReductionAlgorithm::KLDistance);
0093
0094 py::class_<BetheHeitlerApprox>(mex, "AtlasBetheHeitlerApprox")
0095 .def_static("loadFromFiles", &BetheHeitlerApprox::loadFromFiles,
0096 "lowParametersPath"_a, "highParametersPath"_a,
0097 "lowLimit"_a = 0.1, "highLimit"_a = 0.2,
0098 "clampToRange"_a = false)
0099 .def_static(
0100 "makeDefault",
0101 [](bool clampToRange) {
0102 return makeDefaultBetheHeitlerApprox(clampToRange);
0103 },
0104 "clampToRange"_a = false);
0105
0106 mex.def(
0107 "makeGsfFitterFunction",
0108 [](std::shared_ptr<const TrackingGeometry> trackingGeometry,
0109 std::shared_ptr<const MagneticFieldProvider> magneticField,
0110 BetheHeitlerApprox betheHeitlerApprox, std::size_t maxComponents,
0111 double weightCutoff, ComponentMergeMethod componentMergeMethod,
0112 MixtureReductionAlgorithm mixtureReductionAlgorithm,
0113 double reverseFilteringCovarianceScaling, Logging::Level level) {
0114 return makeGsfFitterFunction(
0115 trackingGeometry, magneticField, betheHeitlerApprox,
0116 maxComponents, weightCutoff, componentMergeMethod,
0117 mixtureReductionAlgorithm, reverseFilteringCovarianceScaling,
0118 *getDefaultLogger("GSFFunc", level));
0119 },
0120 "trackingGeometry"_a, "magneticField"_a, "betheHeitlerApprox"_a,
0121 "maxComponents"_a, "weightCutoff"_a, "componentMergeMethod"_a,
0122 "mixtureReductionAlgorithm"_a, "reverseFilteringCovarianceScaling"_a,
0123 "level"_a);
0124
0125 mex.def(
0126 "makeGlobalChiSquareFitterFunction",
0127 [](std::shared_ptr<const TrackingGeometry> trackingGeometry,
0128 std::shared_ptr<const MagneticFieldProvider> magneticField,
0129 bool multipleScattering, bool energyLoss,
0130 FreeToBoundCorrection freeToBoundCorrection, std::size_t nUpdateMax,
0131 double relChi2changeCutOff, Logging::Level level) {
0132 return makeGlobalChiSquareFitterFunction(
0133 trackingGeometry, magneticField, multipleScattering, energyLoss,
0134 freeToBoundCorrection, nUpdateMax, relChi2changeCutOff,
0135 *getDefaultLogger("Gx2f", level));
0136 },
0137 py::arg("trackingGeometry"), py::arg("magneticField"),
0138 py::arg("multipleScattering"), py::arg("energyLoss"),
0139 py::arg("freeToBoundCorrection"), py::arg("nUpdateMax"),
0140 py::arg("relChi2changeCutOff"), py::arg("level"));
0141 }
0142
0143 {
0144 py::class_<FreeToBoundCorrection>(mex, "FreeToBoundCorrection")
0145 .def(py::init<>())
0146 .def(py::init<bool>(), py::arg("apply") = false)
0147 .def(py::init<bool, double, double>(), py::arg("apply") = false,
0148 py::arg("alpha") = 0.1, py::arg("beta") = 2);
0149 }
0150 }
0151
0152 }