Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-15 08:14:44

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp"
#include "Acts/Plugins/Python/Utilities.hpp"
#include "Acts/TrackFitting/BetheHeitlerApprox.hpp"
#include "Acts/TrackFitting/GsfOptions.hpp"
#include "Acts/Utilities/Logger.hpp"
#include "ActsExamples/EventData/MeasurementCalibration.hpp"
#include "ActsExamples/EventData/ScalingCalibrator.hpp"
#include "ActsExamples/TrackFitting/RefittingAlgorithm.hpp"
#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp"
#include "ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp"

#include <cstddef>
#include <memory>
#include <string>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
0025 
0026 namespace py = pybind11;
0027 
0028 using namespace ActsExamples;
0029 using namespace Acts;
0030 using namespace py::literals;
0031 
0032 namespace Acts::Python {
0033 
0034 void addTrackFitting(Context& ctx) {
0035   auto mex = ctx.get("examples");
0036 
0037   ACTS_PYTHON_DECLARE_ALGORITHM(
0038       ActsExamples::TrackFittingAlgorithm, mex, "TrackFittingAlgorithm",
0039       inputMeasurements, inputProtoTracks, inputInitialTrackParameters,
0040       inputClusters, outputTracks, fit, pickTrack, calibrator);
0041 
0042   ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::RefittingAlgorithm, mex,
0043                                 "RefittingAlgorithm", inputTracks, outputTracks,
0044                                 fit, pickTrack, initialVarInflation);
0045 
0046   {
0047     py::class_<TrackFitterFunction, std::shared_ptr<TrackFitterFunction>>(
0048         mex, "TrackFitterFunction");
0049 
0050     mex.def(
0051         "makeKalmanFitterFunction",
0052         [](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0053            std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0054            bool multipleScattering, bool energyLoss,
0055            double reverseFilteringMomThreshold,
0056            double reverseFilteringCovarianceScaling,
0057            Acts::FreeToBoundCorrection freeToBoundCorrection, double chi2Cut,
0058            Logging::Level level) {
0059           return ActsExamples::makeKalmanFitterFunction(
0060               trackingGeometry, magneticField, multipleScattering, energyLoss,
0061               reverseFilteringMomThreshold, reverseFilteringCovarianceScaling,
0062               freeToBoundCorrection, chi2Cut,
0063               *Acts::getDefaultLogger("Kalman", level));
0064         },
0065         "trackingGeometry"_a, "magneticField"_a, "multipleScattering"_a,
0066         "energyLoss"_a, "reverseFilteringMomThreshold"_a,
0067         "reverseFilteringCovarianceScaling"_a, "freeToBoundCorrection"_a,
0068         "chi2Cut"_a, "level"_a);
0069 
0070     py::class_<MeasurementCalibrator, std::shared_ptr<MeasurementCalibrator>>(
0071         mex, "MeasurementCalibrator");
0072 
0073     mex.def("makePassThroughCalibrator",
0074             []() -> std::shared_ptr<MeasurementCalibrator> {
0075               return std::make_shared<PassThroughCalibrator>();
0076             });
0077 
0078     mex.def(
0079         "makeScalingCalibrator",
0080         [](const char* path) -> std::shared_ptr<MeasurementCalibrator> {
0081           return std::make_shared<ActsExamples::ScalingCalibrator>(path);
0082         },
0083         py::arg("path"));
0084 
0085     py::enum_<Acts::ComponentMergeMethod>(mex, "ComponentMergeMethod")
0086         .value("mean", Acts::ComponentMergeMethod::eMean)
0087         .value("maxWeight", Acts::ComponentMergeMethod::eMaxWeight);
0088 
0089     py::enum_<ActsExamples::MixtureReductionAlgorithm>(
0090         mex, "MixtureReductionAlgorithm")
0091         .value("weightCut", MixtureReductionAlgorithm::weightCut)
0092         .value("KLDistance", MixtureReductionAlgorithm::KLDistance);
0093 
0094     py::class_<ActsExamples::BetheHeitlerApprox>(mex, "AtlasBetheHeitlerApprox")
0095         .def_static(
0096             "loadFromFiles", &ActsExamples::BetheHeitlerApprox::loadFromFiles,
0097             "lowParametersPath"_a, "highParametersPath"_a, "lowLimit"_a = 0.1,
0098             "highLimit"_a = 0.2, "clampToRange"_a = false)
0099         .def_static(
0100             "makeDefault",
0101             [](bool clampToRange) {
0102               return Acts::makeDefaultBetheHeitlerApprox(clampToRange);
0103             },
0104             "clampToRange"_a = false);
0105 
0106     mex.def(
0107         "makeGsfFitterFunction",
0108         [](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0109            std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0110            BetheHeitlerApprox betheHeitlerApprox, std::size_t maxComponents,
0111            double weightCutoff, Acts::ComponentMergeMethod componentMergeMethod,
0112            ActsExamples::MixtureReductionAlgorithm mixtureReductionAlgorithm,
0113            double reverseFilteringCovarianceScaling, Logging::Level level) {
0114           return ActsExamples::makeGsfFitterFunction(
0115               trackingGeometry, magneticField, betheHeitlerApprox,
0116               maxComponents, weightCutoff, componentMergeMethod,
0117               mixtureReductionAlgorithm, reverseFilteringCovarianceScaling,
0118               *Acts::getDefaultLogger("GSFFunc", level));
0119         },
0120         "trackingGeometry"_a, "magneticField"_a, "betheHeitlerApprox"_a,
0121         "maxComponents"_a, "weightCutoff"_a, "componentMergeMethod"_a,
0122         "mixtureReductionAlgorithm"_a, "reverseFilteringCovarianceScaling"_a,
0123         "level"_a);
0124 
0125     mex.def(
0126         "makeGlobalChiSquareFitterFunction",
0127         [](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0128            std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0129            bool multipleScattering, bool energyLoss,
0130            Acts::FreeToBoundCorrection freeToBoundCorrection,
0131            std::size_t nUpdateMax, double relChi2changeCutOff,
0132            Logging::Level level) {
0133           return ActsExamples::makeGlobalChiSquareFitterFunction(
0134               trackingGeometry, magneticField, multipleScattering, energyLoss,
0135               freeToBoundCorrection, nUpdateMax, relChi2changeCutOff,
0136               *Acts::getDefaultLogger("Gx2f", level));
0137         },
0138         py::arg("trackingGeometry"), py::arg("magneticField"),
0139         py::arg("multipleScattering"), py::arg("energyLoss"),
0140         py::arg("freeToBoundCorrection"), py::arg("nUpdateMax"),
0141         py::arg("relChi2changeCutOff"), py::arg("level"));
0142   }
0143 
0144   {
0145     py::class_<FreeToBoundCorrection>(mex, "FreeToBoundCorrection")
0146         .def(py::init<>())
0147         .def(py::init<bool>(), py::arg("apply") = false)
0148         .def(py::init<bool, double, double>(), py::arg("apply") = false,
0149              py::arg("alpha") = 0.1, py::arg("beta") = 2);
0150   }
0151 }
0152 
0153 }  // namespace Acts::Python