Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:12:02

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp"
0010 #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
0011 #include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"
0012 #include "Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp"
0013 #include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"
0014 #include "Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp"
0015 #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
0016 #include "Acts/Plugins/Python/Utilities.hpp"
0017 #include "ActsExamples/TrackFindingExaTrkX/PrototracksToParameters.hpp"
0018 #include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp"
0019 #include "ActsExamples/TrackFindingExaTrkX/TrackFindingFromPrototrackAlgorithm.hpp"
0020 #include "ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp"
0021 
0022 #include <memory>
0023 
0024 #include <pybind11/functional.h>
0025 #include <pybind11/pybind11.h>
0026 #include <pybind11/stl.h>
0027 
0028 namespace py = pybind11;
0029 
0030 using namespace ActsExamples;
0031 using namespace Acts;
0032 using namespace py::literals;
0033 
0034 namespace Acts::Python {
0035 
0036 void addExaTrkXTrackFinding(Context &ctx) {
0037   auto [m, mex] = ctx.get("main", "examples");
0038 
0039   {
0040     using C = Acts::GraphConstructionBase;
0041     auto c = py::class_<C, std::shared_ptr<C>>(mex, "GraphConstructionBase");
0042   }
0043   {
0044     using C = Acts::EdgeClassificationBase;
0045     auto c = py::class_<C, std::shared_ptr<C>>(mex, "EdgeClassificationBase");
0046   }
0047   {
0048     using C = Acts::TrackBuildingBase;
0049     auto c = py::class_<C, std::shared_ptr<C>>(mex, "TrackBuildingBase");
0050   }
0051 
0052 #ifdef ACTS_EXATRKX_TORCH_BACKEND
0053   {
0054     using Alg = Acts::TorchMetricLearning;
0055     using Config = Alg::Config;
0056 
0057     auto alg =
0058         py::class_<Alg, Acts::GraphConstructionBase, std::shared_ptr<Alg>>(
0059             mex, "TorchMetricLearning")
0060             .def(py::init([](const Config &c, Logging::Level lvl) {
0061                    return std::make_shared<Alg>(
0062                        c, getDefaultLogger("MetricLearning", lvl));
0063                  }),
0064                  py::arg("config"), py::arg("level"))
0065             .def_property_readonly("config", &Alg::config);
0066 
0067     auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0068     ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0069     ACTS_PYTHON_MEMBER(modelPath);
0070     ACTS_PYTHON_MEMBER(selectedFeatures);
0071     ACTS_PYTHON_MEMBER(embeddingDim);
0072     ACTS_PYTHON_MEMBER(rVal);
0073     ACTS_PYTHON_MEMBER(knnVal);
0074     ACTS_PYTHON_MEMBER(deviceID);
0075     ACTS_PYTHON_STRUCT_END();
0076   }
0077   {
0078     using Alg = Acts::TorchEdgeClassifier;
0079     using Config = Alg::Config;
0080 
0081     auto alg =
0082         py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
0083             mex, "TorchEdgeClassifier")
0084             .def(py::init([](const Config &c, Logging::Level lvl) {
0085                    return std::make_shared<Alg>(
0086                        c, getDefaultLogger("EdgeClassifier", lvl));
0087                  }),
0088                  py::arg("config"), py::arg("level"))
0089             .def_property_readonly("config", &Alg::config);
0090 
0091     auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0092     ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0093     ACTS_PYTHON_MEMBER(modelPath);
0094     ACTS_PYTHON_MEMBER(selectedFeatures);
0095     ACTS_PYTHON_MEMBER(cut);
0096     ACTS_PYTHON_MEMBER(nChunks);
0097     ACTS_PYTHON_MEMBER(undirected);
0098     ACTS_PYTHON_MEMBER(deviceID);
0099     ACTS_PYTHON_MEMBER(useEdgeFeatures);
0100     ACTS_PYTHON_STRUCT_END();
0101   }
0102   {
0103     using Alg = Acts::BoostTrackBuilding;
0104 
0105     auto alg = py::class_<Alg, Acts::TrackBuildingBase, std::shared_ptr<Alg>>(
0106                    mex, "BoostTrackBuilding")
0107                    .def(py::init([](Logging::Level lvl) {
0108                           return std::make_shared<Alg>(
0109                               getDefaultLogger("EdgeClassifier", lvl));
0110                         }),
0111                         py::arg("level"));
0112   }
0113 #endif
0114 
0115 #ifdef ACTS_EXATRKX_ONNX_BACKEND
0116   {
0117     using Alg = Acts::OnnxMetricLearning;
0118     using Config = Alg::Config;
0119 
0120     auto alg =
0121         py::class_<Alg, Acts::GraphConstructionBase, std::shared_ptr<Alg>>(
0122             mex, "OnnxMetricLearning")
0123             .def(py::init([](const Config &c, Logging::Level lvl) {
0124                    return std::make_shared<Alg>(
0125                        c, getDefaultLogger("MetricLearning", lvl));
0126                  }),
0127                  py::arg("config"), py::arg("level"))
0128             .def_property_readonly("config", &Alg::config);
0129 
0130     auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0131     ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0132     ACTS_PYTHON_MEMBER(modelPath);
0133     ACTS_PYTHON_MEMBER(spacepointFeatures);
0134     ACTS_PYTHON_MEMBER(embeddingDim);
0135     ACTS_PYTHON_MEMBER(rVal);
0136     ACTS_PYTHON_MEMBER(knnVal);
0137     ACTS_PYTHON_STRUCT_END();
0138   }
0139   {
0140     using Alg = Acts::OnnxEdgeClassifier;
0141     using Config = Alg::Config;
0142 
0143     auto alg =
0144         py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
0145             mex, "OnnxEdgeClassifier")
0146             .def(py::init([](const Config &c, Logging::Level lvl) {
0147                    return std::make_shared<Alg>(
0148                        c, getDefaultLogger("EdgeClassifier", lvl));
0149                  }),
0150                  py::arg("config"), py::arg("level"))
0151             .def_property_readonly("config", &Alg::config);
0152 
0153     auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0154     ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0155     ACTS_PYTHON_MEMBER(modelPath);
0156     ACTS_PYTHON_MEMBER(cut);
0157     ACTS_PYTHON_STRUCT_END();
0158   }
0159 #endif
0160 
0161   ACTS_PYTHON_DECLARE_ALGORITHM(
0162       ActsExamples::TruthGraphBuilder, mex, "TruthGraphBuilder",
0163       inputSpacePoints, inputSimHits, inputParticles,
0164       inputMeasurementSimHitsMap, inputMeasurementParticlesMap, outputGraph,
0165       targetMinPT, targetMinSize, uniqueModules);
0166 
0167   {
0168     auto nodeFeatureEnum =
0169         py::enum_<TrackFindingAlgorithmExaTrkX::NodeFeature>(mex, "NodeFeature")
0170             .value("R", TrackFindingAlgorithmExaTrkX::NodeFeature::eR)
0171             .value("Phi", TrackFindingAlgorithmExaTrkX::NodeFeature::ePhi)
0172             .value("Z", TrackFindingAlgorithmExaTrkX::NodeFeature::eZ)
0173             .value("X", TrackFindingAlgorithmExaTrkX::NodeFeature::eX)
0174             .value("Y", TrackFindingAlgorithmExaTrkX::NodeFeature::eY)
0175             .value("Eta", TrackFindingAlgorithmExaTrkX::NodeFeature::eEta)
0176             .value("ClusterX",
0177                    TrackFindingAlgorithmExaTrkX::NodeFeature::eClusterLoc0)
0178             .value("ClusterY",
0179                    TrackFindingAlgorithmExaTrkX::NodeFeature::eClusterLoc1)
0180             .value("CellCount",
0181                    TrackFindingAlgorithmExaTrkX::NodeFeature::eCellCount)
0182             .value("ChargeSum",
0183                    TrackFindingAlgorithmExaTrkX::NodeFeature::eChargeSum);
0184 
0185     // clang-format off
0186 #define ADD_FEATURE_ENUMS(n) \
0187   nodeFeatureEnum \
0188     .value("Cluster" #n "X", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##X) \
0189     .value("Cluster" #n "Y", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Y) \
0190     .value("Cluster" #n "Z", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Z) \
0191     .value("Cluster" #n "R", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##R) \
0192     .value("Cluster" #n "Phi", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Phi) \
0193     .value("Cluster" #n "Eta", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Eta) \
0194     .value("CellCount" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eCellCount##n) \
0195     .value("ChargeSum" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eChargeSum##n) \
0196     .value("LocEta" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocEta##n) \
0197     .value("LocPhi" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocPhi##n) \
0198     .value("LocDir0" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocDir0##n) \
0199     .value("LocDir1" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocDir1##n) \
0200     .value("LocDir2" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocDir2##n) \
0201     .value("LengthDir0" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLengthDir0##n) \
0202     .value("LengthDir1" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLengthDir1##n) \
0203     .value("LengthDir2" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLengthDir2##n) \
0204     .value("GlobEta" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eGlobEta##n) \
0205     .value("GlobPhi" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eGlobPhi##n) \
0206     .value("EtaAngle" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eEtaAngle##n) \
0207     .value("PhiAngle" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::ePhiAngle##n)
0208     // clang-format on
0209 
0210     ADD_FEATURE_ENUMS(1);
0211     ADD_FEATURE_ENUMS(2);
0212 
0213 #undef ADD_FEATURE_ENUMS
0214   }
0215 
0216   ACTS_PYTHON_DECLARE_ALGORITHM(
0217       ActsExamples::TrackFindingAlgorithmExaTrkX, mex,
0218       "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputClusters,
0219       inputTruthGraph, outputProtoTracks, outputGraph, graphConstructor,
0220       edgeClassifiers, trackBuilder, nodeFeatures, featureScales,
0221       minMeasurementsPerTrack, geometryIdMap);
0222 
0223   {
0224     auto cls =
0225         py::class_<Acts::ExaTrkXHook, std::shared_ptr<Acts::ExaTrkXHook>>(
0226             mex, "ExaTrkXHook");
0227   }
0228 
0229   {
0230     using Class = Acts::TorchTruthGraphMetricsHook;
0231 
0232     auto cls = py::class_<Class, Acts::ExaTrkXHook, std::shared_ptr<Class>>(
0233                    mex, "TorchTruthGraphMetricsHook")
0234                    .def(py::init([](const std::vector<std::int64_t> &g,
0235                                     Logging::Level lvl) {
0236                      return std::make_shared<Class>(
0237                          g, getDefaultLogger("PipelineHook", lvl));
0238                    }));
0239   }
0240 
0241   {
0242     using Class = Acts::ExaTrkXPipeline;
0243 
0244     auto cls =
0245         py::class_<Class, std::shared_ptr<Class>>(mex, "ExaTrkXPipeline")
0246             .def(py::init(
0247                      [](std::shared_ptr<GraphConstructionBase> g,
0248                         std::vector<std::shared_ptr<EdgeClassificationBase>> e,
0249                         std::shared_ptr<TrackBuildingBase> t,
0250                         Logging::Level lvl) {
0251                        return std::make_shared<Class>(
0252                            g, e, t, getDefaultLogger("MetricLearning", lvl));
0253                      }),
0254                  py::arg("graphConstructor"), py::arg("edgeClassifiers"),
0255                  py::arg("trackBuilder"), py::arg("level"))
0256             .def("run", &ExaTrkXPipeline::run, py::arg("features"),
0257                  py::arg("moduleIds"), py::arg("spacepoints"),
0258                  py::arg("hook") = Acts::ExaTrkXHook{},
0259                  py::arg("timing") = nullptr);
0260   }
0261 
0262   ACTS_PYTHON_DECLARE_ALGORITHM(
0263       ActsExamples::PrototracksToParameters, mex, "PrototracksToParameters",
0264       inputProtoTracks, inputSpacePoints, outputSeeds, outputParameters,
0265       outputProtoTracks, geometry, magneticField, buildTightSeeds);
0266 
0267   ACTS_PYTHON_DECLARE_ALGORITHM(
0268       ActsExamples::TrackFindingFromPrototrackAlgorithm, mex,
0269       "TrackFindingFromPrototrackAlgorithm", inputProtoTracks,
0270       inputMeasurements, inputInitialTrackParameters, outputTracks,
0271       measurementSelectorCfg, trackingGeometry, magneticField, findTracks, tag);
0272 }
0273 
0274 }  // namespace Acts::Python