Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:12:04

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp"
#include "Acts/Plugins/Python/Utilities.hpp"
#include "Acts/TrackFitting/BetheHeitlerApprox.hpp"
#include "Acts/TrackFitting/GsfOptions.hpp"
#include "Acts/Utilities/Logger.hpp"
#include "ActsExamples/EventData/MeasurementCalibration.hpp"
#include "ActsExamples/EventData/ScalingCalibrator.hpp"
#include "ActsExamples/TrackFitting/RefittingAlgorithm.hpp"
#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp"
#include "ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp"

#include <cstddef>
#include <memory>
#include <string>
#include <utility>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

using namespace ActsExamples;
using namespace Acts;
using namespace py::literals;
0031 
0032 namespace Acts::Python {
0033 
0034 void addTrackFitting(Context& ctx) {
0035   auto mex = ctx.get("examples");
0036 
0037   ACTS_PYTHON_DECLARE_ALGORITHM(
0038       ActsExamples::TrackFittingAlgorithm, mex, "TrackFittingAlgorithm",
0039       inputMeasurements, inputProtoTracks, inputInitialTrackParameters,
0040       inputClusters, outputTracks, fit, pickTrack, calibrator);
0041 
0042   ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::RefittingAlgorithm, mex,
0043                                 "RefittingAlgorithm", inputTracks, outputTracks,
0044                                 fit, pickTrack);
0045 
0046   {
0047     py::class_<TrackFitterFunction, std::shared_ptr<TrackFitterFunction>>(
0048         mex, "TrackFitterFunction");
0049 
0050     mex.def(
0051         "makeKalmanFitterFunction",
0052         [](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0053            std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0054            bool multipleScattering, bool energyLoss,
0055            double reverseFilteringMomThreshold,
0056            Acts::FreeToBoundCorrection freeToBoundCorrection,
0057            Logging::Level level) {
0058           return ActsExamples::makeKalmanFitterFunction(
0059               trackingGeometry, magneticField, multipleScattering, energyLoss,
0060               reverseFilteringMomThreshold, freeToBoundCorrection,
0061               *Acts::getDefaultLogger("Kalman", level));
0062         },
0063         py::arg("trackingGeometry"), py::arg("magneticField"),
0064         py::arg("multipleScattering"), py::arg("energyLoss"),
0065         py::arg("reverseFilteringMomThreshold"),
0066         py::arg("freeToBoundCorrection"), py::arg("level"));
0067 
0068     py::class_<MeasurementCalibrator, std::shared_ptr<MeasurementCalibrator>>(
0069         mex, "MeasurementCalibrator");
0070 
0071     mex.def("makePassThroughCalibrator",
0072             []() -> std::shared_ptr<MeasurementCalibrator> {
0073               return std::make_shared<PassThroughCalibrator>();
0074             });
0075 
0076     mex.def(
0077         "makeScalingCalibrator",
0078         [](const char* path) -> std::shared_ptr<MeasurementCalibrator> {
0079           return std::make_shared<ActsExamples::ScalingCalibrator>(path);
0080         },
0081         py::arg("path"));
0082 
0083     py::enum_<Acts::ComponentMergeMethod>(mex, "ComponentMergeMethod")
0084         .value("mean", Acts::ComponentMergeMethod::eMean)
0085         .value("maxWeight", Acts::ComponentMergeMethod::eMaxWeight);
0086 
0087     py::enum_<ActsExamples::MixtureReductionAlgorithm>(
0088         mex, "MixtureReductionAlgorithm")
0089         .value("weightCut", MixtureReductionAlgorithm::weightCut)
0090         .value("KLDistance", MixtureReductionAlgorithm::KLDistance);
0091 
0092     py::class_<ActsExamples::BetheHeitlerApprox>(mex, "AtlasBetheHeitlerApprox")
0093         .def_static(
0094             "loadFromFiles", &ActsExamples::BetheHeitlerApprox::loadFromFiles,
0095             "lowParametersPath"_a, "highParametersPath"_a, "lowLimit"_a = 0.1,
0096             "highLimit"_a = 0.2, "clampToRange"_a = false)
0097         .def_static(
0098             "makeDefault",
0099             [](bool clampToRange) {
0100               return Acts::makeDefaultBetheHeitlerApprox(clampToRange);
0101             },
0102             "clampToRange"_a = false);
0103 
0104     mex.def(
0105         "makeGsfFitterFunction",
0106         [](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0107            std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0108            BetheHeitlerApprox betheHeitlerApprox, std::size_t maxComponents,
0109            double weightCutoff, Acts::ComponentMergeMethod componentMergeMethod,
0110            ActsExamples::MixtureReductionAlgorithm mixtureReductionAlgorithm,
0111            Logging::Level level) {
0112           return ActsExamples::makeGsfFitterFunction(
0113               trackingGeometry, magneticField, betheHeitlerApprox,
0114               maxComponents, weightCutoff, componentMergeMethod,
0115               mixtureReductionAlgorithm,
0116               *Acts::getDefaultLogger("GSFFunc", level));
0117         },
0118         py::arg("trackingGeometry"), py::arg("magneticField"),
0119         py::arg("betheHeitlerApprox"), py::arg("maxComponents"),
0120         py::arg("weightCutoff"), py::arg("componentMergeMethod"),
0121         py::arg("mixtureReductionAlgorithm"), py::arg("level"));
0122 
0123     mex.def(
0124         "makeGlobalChiSquareFitterFunction",
0125         [](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0126            std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0127            bool multipleScattering, bool energyLoss,
0128            Acts::FreeToBoundCorrection freeToBoundCorrection,
0129            std::size_t nUpdateMax, double relChi2changeCutOff,
0130            Logging::Level level) {
0131           return ActsExamples::makeGlobalChiSquareFitterFunction(
0132               trackingGeometry, magneticField, multipleScattering, energyLoss,
0133               freeToBoundCorrection, nUpdateMax, relChi2changeCutOff,
0134               *Acts::getDefaultLogger("Gx2f", level));
0135         },
0136         py::arg("trackingGeometry"), py::arg("magneticField"),
0137         py::arg("multipleScattering"), py::arg("energyLoss"),
0138         py::arg("freeToBoundCorrection"), py::arg("nUpdateMax"),
0139         py::arg("relChi2changeCutOff"), py::arg("level"));
0140   }
0141 
0142   {
0143     py::class_<FreeToBoundCorrection>(mex, "FreeToBoundCorrection")
0144         .def(py::init<>())
0145         .def(py::init<bool>(), py::arg("apply") = false)
0146         .def(py::init<bool, double, double>(), py::arg("apply") = false,
0147              py::arg("alpha") = 0.1, py::arg("beta") = 2);
0148   }
0149 }
0150 
0151 }  // namespace Acts::Python