Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-10-15 08:04:46

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp"
#include "Acts/TrackFitting/BetheHeitlerApprox.hpp"
#include "Acts/TrackFitting/GsfOptions.hpp"
#include "Acts/Utilities/Logger.hpp"
#include "ActsExamples/EventData/MeasurementCalibration.hpp"
#include "ActsExamples/EventData/ScalingCalibrator.hpp"
#include "ActsExamples/TrackFitting/RefittingAlgorithm.hpp"
#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp"
#include "ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp"
#include "ActsPython/Utilities/Helpers.hpp"
#include "ActsPython/Utilities/Macros.hpp"

#include <cstddef>
#include <memory>
#include <string>
#include <utility>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
0026 
0027 namespace py = pybind11;
0028 
0029 using namespace ActsExamples;
0030 using namespace Acts;
0031 using namespace py::literals;
0032 
0033 namespace ActsPython {
0034 
0035 void addTrackFitting(Context& ctx) {
0036   auto mex = ctx.get("examples");
0037 
0038   ACTS_PYTHON_DECLARE_ALGORITHM(
0039       TrackFittingAlgorithm, mex, "TrackFittingAlgorithm", inputMeasurements,
0040       inputProtoTracks, inputInitialTrackParameters, inputClusters,
0041       outputTracks, fit, pickTrack, calibrator);
0042 
0043   ACTS_PYTHON_DECLARE_ALGORITHM(RefittingAlgorithm, mex, "RefittingAlgorithm",
0044                                 inputTracks, outputTracks, fit, pickTrack,
0045                                 initialVarInflation);
0046 
0047   {
0048     py::class_<TrackFitterFunction, std::shared_ptr<TrackFitterFunction>>(
0049         mex, "TrackFitterFunction");
0050 
0051     mex.def(
0052         "makeKalmanFitterFunction",
0053         [](std::shared_ptr<const TrackingGeometry> trackingGeometry,
0054            std::shared_ptr<const MagneticFieldProvider> magneticField,
0055            bool multipleScattering, bool energyLoss,
0056            double reverseFilteringMomThreshold,
0057            double reverseFilteringCovarianceScaling,
0058            FreeToBoundCorrection freeToBoundCorrection, double chi2Cut,
0059            Logging::Level level) {
0060           return makeKalmanFitterFunction(
0061               trackingGeometry, magneticField, multipleScattering, energyLoss,
0062               reverseFilteringMomThreshold, reverseFilteringCovarianceScaling,
0063               freeToBoundCorrection, chi2Cut,
0064               *getDefaultLogger("Kalman", level));
0065         },
0066         "trackingGeometry"_a, "magneticField"_a, "multipleScattering"_a,
0067         "energyLoss"_a, "reverseFilteringMomThreshold"_a,
0068         "reverseFilteringCovarianceScaling"_a, "freeToBoundCorrection"_a,
0069         "chi2Cut"_a, "level"_a);
0070 
0071     py::class_<MeasurementCalibrator, std::shared_ptr<MeasurementCalibrator>>(
0072         mex, "MeasurementCalibrator");
0073 
0074     mex.def("makePassThroughCalibrator",
0075             []() -> std::shared_ptr<MeasurementCalibrator> {
0076               return std::make_shared<PassThroughCalibrator>();
0077             });
0078 
0079     mex.def(
0080         "makeScalingCalibrator",
0081         [](const char* path) -> std::shared_ptr<MeasurementCalibrator> {
0082           return std::make_shared<ScalingCalibrator>(path);
0083         },
0084         py::arg("path"));
0085 
0086     py::enum_<ComponentMergeMethod>(mex, "ComponentMergeMethod")
0087         .value("mean", ComponentMergeMethod::eMean)
0088         .value("maxWeight", ComponentMergeMethod::eMaxWeight);
0089 
0090     py::enum_<MixtureReductionAlgorithm>(mex, "MixtureReductionAlgorithm")
0091         .value("weightCut", MixtureReductionAlgorithm::weightCut)
0092         .value("KLDistance", MixtureReductionAlgorithm::KLDistance);
0093 
0094     py::class_<BetheHeitlerApprox>(mex, "AtlasBetheHeitlerApprox")
0095         .def_static("loadFromFiles", &BetheHeitlerApprox::loadFromFiles,
0096                     "lowParametersPath"_a, "highParametersPath"_a,
0097                     "lowLimit"_a = 0.1, "highLimit"_a = 0.2,
0098                     "clampToRange"_a = false)
0099         .def_static(
0100             "makeDefault",
0101             [](bool clampToRange) {
0102               return makeDefaultBetheHeitlerApprox(clampToRange);
0103             },
0104             "clampToRange"_a = false);
0105 
0106     mex.def(
0107         "makeGsfFitterFunction",
0108         [](std::shared_ptr<const TrackingGeometry> trackingGeometry,
0109            std::shared_ptr<const MagneticFieldProvider> magneticField,
0110            BetheHeitlerApprox betheHeitlerApprox, std::size_t maxComponents,
0111            double weightCutoff, ComponentMergeMethod componentMergeMethod,
0112            MixtureReductionAlgorithm mixtureReductionAlgorithm,
0113            double reverseFilteringCovarianceScaling, Logging::Level level) {
0114           return makeGsfFitterFunction(
0115               trackingGeometry, magneticField, betheHeitlerApprox,
0116               maxComponents, weightCutoff, componentMergeMethod,
0117               mixtureReductionAlgorithm, reverseFilteringCovarianceScaling,
0118               *getDefaultLogger("GSFFunc", level));
0119         },
0120         "trackingGeometry"_a, "magneticField"_a, "betheHeitlerApprox"_a,
0121         "maxComponents"_a, "weightCutoff"_a, "componentMergeMethod"_a,
0122         "mixtureReductionAlgorithm"_a, "reverseFilteringCovarianceScaling"_a,
0123         "level"_a);
0124 
0125     mex.def(
0126         "makeGlobalChiSquareFitterFunction",
0127         [](std::shared_ptr<const TrackingGeometry> trackingGeometry,
0128            std::shared_ptr<const MagneticFieldProvider> magneticField,
0129            bool multipleScattering, bool energyLoss,
0130            FreeToBoundCorrection freeToBoundCorrection, std::size_t nUpdateMax,
0131            double relChi2changeCutOff, Logging::Level level) {
0132           return makeGlobalChiSquareFitterFunction(
0133               trackingGeometry, magneticField, multipleScattering, energyLoss,
0134               freeToBoundCorrection, nUpdateMax, relChi2changeCutOff,
0135               *getDefaultLogger("Gx2f", level));
0136         },
0137         py::arg("trackingGeometry"), py::arg("magneticField"),
0138         py::arg("multipleScattering"), py::arg("energyLoss"),
0139         py::arg("freeToBoundCorrection"), py::arg("nUpdateMax"),
0140         py::arg("relChi2changeCutOff"), py::arg("level"));
0141   }
0142 
0143   {
0144     py::class_<FreeToBoundCorrection>(mex, "FreeToBoundCorrection")
0145         .def(py::init<>())
0146         .def(py::init<bool>(), py::arg("apply") = false)
0147         .def(py::init<bool, double, double>(), py::arg("apply") = false,
0148              py::arg("alpha") = 0.1, py::arg("beta") = 2);
0149   }
0150 }
0151 
0152 }  // namespace ActsPython