Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-01-09 09:26:48

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp"
0010 #include "Acts/TrackFitting/BetheHeitlerApprox.hpp"
0011 #include "Acts/TrackFitting/GsfOptions.hpp"
0012 #include "Acts/Utilities/Logger.hpp"
0013 #include "ActsExamples/EventData/MeasurementCalibration.hpp"
0014 #include "ActsExamples/EventData/ScalingCalibrator.hpp"
0015 #include "ActsExamples/TrackFitting/RefittingAlgorithm.hpp"
0016 #include "ActsExamples/TrackFitting/TrackFitterFunction.hpp"
0017 #include "ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp"
0018 #include "ActsPython/Utilities/Helpers.hpp"
0019 #include "ActsPython/Utilities/Macros.hpp"
0020 
#include <cstddef>
#include <memory>
#include <string>
#include <utility>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
0027 
0028 namespace py = pybind11;
0029 
0030 using namespace ActsExamples;
0031 using namespace Acts;
0032 using namespace py::literals;
0033 
0034 namespace ActsPython {
0035 
0036 void addTrackFitting(py::module& mex) {
0037   ACTS_PYTHON_DECLARE_ALGORITHM(
0038       TrackFittingAlgorithm, mex, "TrackFittingAlgorithm", inputMeasurements,
0039       inputProtoTracks, inputInitialTrackParameters, inputClusters,
0040       outputTracks, fit, pickTrack, calibrator);
0041 
0042   ACTS_PYTHON_DECLARE_ALGORITHM(RefittingAlgorithm, mex, "RefittingAlgorithm",
0043                                 inputTracks, outputTracks, fit, pickTrack,
0044                                 initialVarInflation);
0045 
0046   {
0047     py::class_<TrackFitterFunction, std::shared_ptr<TrackFitterFunction>>(
0048         mex, "TrackFitterFunction");
0049 
0050     mex.def(
0051         "makeKalmanFitterFunction",
0052         [](std::shared_ptr<const TrackingGeometry> trackingGeometry,
0053            std::shared_ptr<const MagneticFieldProvider> magneticField,
0054            bool multipleScattering, bool energyLoss,
0055            double reverseFilteringMomThreshold,
0056            double reverseFilteringCovarianceScaling,
0057            FreeToBoundCorrection freeToBoundCorrection, double chi2Cut,
0058            Logging::Level level) {
0059           return makeKalmanFitterFunction(
0060               std::move(trackingGeometry), std::move(magneticField),
0061               multipleScattering, energyLoss, reverseFilteringMomThreshold,
0062               reverseFilteringCovarianceScaling, freeToBoundCorrection, chi2Cut,
0063               *getDefaultLogger("Kalman", level));
0064         },
0065         "trackingGeometry"_a, "magneticField"_a, "multipleScattering"_a,
0066         "energyLoss"_a, "reverseFilteringMomThreshold"_a,
0067         "reverseFilteringCovarianceScaling"_a, "freeToBoundCorrection"_a,
0068         "chi2Cut"_a, "level"_a);
0069 
0070     py::class_<MeasurementCalibrator, std::shared_ptr<MeasurementCalibrator>>(
0071         mex, "MeasurementCalibrator");
0072 
0073     mex.def("makePassThroughCalibrator",
0074             []() -> std::shared_ptr<MeasurementCalibrator> {
0075               return std::make_shared<PassThroughCalibrator>();
0076             });
0077 
0078     mex.def(
0079         "makeScalingCalibrator",
0080         [](const char* path) -> std::shared_ptr<MeasurementCalibrator> {
0081           return std::make_shared<ScalingCalibrator>(path);
0082         },
0083         py::arg("path"));
0084 
0085     py::enum_<ComponentMergeMethod>(mex, "ComponentMergeMethod")
0086         .value("mean", ComponentMergeMethod::eMean)
0087         .value("maxWeight", ComponentMergeMethod::eMaxWeight);
0088 
0089     py::enum_<MixtureReductionAlgorithm>(mex, "MixtureReductionAlgorithm")
0090         .value("weightCut", MixtureReductionAlgorithm::weightCut)
0091         .value("KLDistance", MixtureReductionAlgorithm::KLDistance);
0092 
0093     py::class_<BetheHeitlerApprox, std::shared_ptr<BetheHeitlerApprox>>(
0094         mex, "BetheHeitlerApprox");
0095     py::class_<AtlasBetheHeitlerApprox, BetheHeitlerApprox,
0096                std::shared_ptr<AtlasBetheHeitlerApprox>>(
0097         mex, "AtlasBetheHeitlerApprox")
0098         .def_static("loadFromFiles", &AtlasBetheHeitlerApprox::loadFromFiles,
0099                     "lowParametersPath"_a, "highParametersPath"_a, "lowLimit"_a,
0100                     "highLimit"_a, "clampToRange"_a, "noChangeLimit"_a,
0101                     "singleGaussianLimit"_a)
0102         .def_static(
0103             "makeDefault",
0104             [](bool clampToRange) {
0105               return makeDefaultBetheHeitlerApprox(clampToRange);
0106             },
0107             "clampToRange"_a);
0108 
0109     mex.def(
0110         "makeGsfFitterFunction",
0111         [](std::shared_ptr<const TrackingGeometry> trackingGeometry,
0112            std::shared_ptr<const MagneticFieldProvider> magneticField,
0113            const std::shared_ptr<const BetheHeitlerApprox>& betheHeitlerApprox,
0114            std::size_t maxComponents, double weightCutoff,
0115            ComponentMergeMethod componentMergeMethod,
0116            MixtureReductionAlgorithm mixtureReductionAlgorithm,
0117            double reverseFilteringCovarianceScaling, Logging::Level level) {
0118           return makeGsfFitterFunction(
0119               std::move(trackingGeometry), std::move(magneticField),
0120               betheHeitlerApprox, maxComponents, weightCutoff,
0121               componentMergeMethod, mixtureReductionAlgorithm,
0122               reverseFilteringCovarianceScaling,
0123               *getDefaultLogger("GSFFunc", level));
0124         },
0125         "trackingGeometry"_a, "magneticField"_a, "betheHeitlerApprox"_a,
0126         "maxComponents"_a, "weightCutoff"_a, "componentMergeMethod"_a,
0127         "mixtureReductionAlgorithm"_a, "reverseFilteringCovarianceScaling"_a,
0128         "level"_a);
0129 
0130     mex.def(
0131         "makeGlobalChiSquareFitterFunction",
0132         [](std::shared_ptr<const TrackingGeometry> trackingGeometry,
0133            std::shared_ptr<const MagneticFieldProvider> magneticField,
0134            bool multipleScattering, bool energyLoss,
0135            FreeToBoundCorrection freeToBoundCorrection, std::size_t nUpdateMax,
0136            double relChi2changeCutOff, Logging::Level level) {
0137           return makeGlobalChiSquareFitterFunction(
0138               std::move(trackingGeometry), std::move(magneticField),
0139               multipleScattering, energyLoss, freeToBoundCorrection, nUpdateMax,
0140               relChi2changeCutOff, *getDefaultLogger("Gx2f", level));
0141         },
0142         py::arg("trackingGeometry"), py::arg("magneticField"),
0143         py::arg("multipleScattering"), py::arg("energyLoss"),
0144         py::arg("freeToBoundCorrection"), py::arg("nUpdateMax"),
0145         py::arg("relChi2changeCutOff"), py::arg("level"));
0146   }
0147 
0148   {
0149     py::class_<FreeToBoundCorrection>(mex, "FreeToBoundCorrection")
0150         .def(py::init<>())
0151         .def(py::init<bool>(), py::arg("apply") = false)
0152         .def(py::init<bool, double, double>(), py::arg("apply") = false,
0153              py::arg("alpha") = 0.1, py::arg("beta") = 2);
0154   }
0155 }
0156 
0157 }  // namespace ActsPython