File indexing completed on 2025-01-18 09:12:02
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp"
0010 #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
0011 #include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"
0012 #include "Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp"
0013 #include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"
0014 #include "Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp"
0015 #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
0016 #include "Acts/Plugins/Python/Utilities.hpp"
0017 #include "ActsExamples/TrackFindingExaTrkX/PrototracksToParameters.hpp"
0018 #include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp"
0019 #include "ActsExamples/TrackFindingExaTrkX/TrackFindingFromPrototrackAlgorithm.hpp"
0020 #include "ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp"
0021
0022 #include <memory>
0023
0024 #include <pybind11/functional.h>
0025 #include <pybind11/pybind11.h>
0026 #include <pybind11/stl.h>
0027
0028 namespace py = pybind11;
0029
0030 using namespace ActsExamples;
0031 using namespace Acts;
0032 using namespace py::literals;
0033
0034 namespace Acts::Python {
0035
0036 void addExaTrkXTrackFinding(Context &ctx) {
0037 auto [m, mex] = ctx.get("main", "examples");
0038
0039 {
0040 using C = Acts::GraphConstructionBase;
0041 auto c = py::class_<C, std::shared_ptr<C>>(mex, "GraphConstructionBase");
0042 }
0043 {
0044 using C = Acts::EdgeClassificationBase;
0045 auto c = py::class_<C, std::shared_ptr<C>>(mex, "EdgeClassificationBase");
0046 }
0047 {
0048 using C = Acts::TrackBuildingBase;
0049 auto c = py::class_<C, std::shared_ptr<C>>(mex, "TrackBuildingBase");
0050 }
0051
0052 #ifdef ACTS_EXATRKX_TORCH_BACKEND
0053 {
0054 using Alg = Acts::TorchMetricLearning;
0055 using Config = Alg::Config;
0056
0057 auto alg =
0058 py::class_<Alg, Acts::GraphConstructionBase, std::shared_ptr<Alg>>(
0059 mex, "TorchMetricLearning")
0060 .def(py::init([](const Config &c, Logging::Level lvl) {
0061 return std::make_shared<Alg>(
0062 c, getDefaultLogger("MetricLearning", lvl));
0063 }),
0064 py::arg("config"), py::arg("level"))
0065 .def_property_readonly("config", &Alg::config);
0066
0067 auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0068 ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0069 ACTS_PYTHON_MEMBER(modelPath);
0070 ACTS_PYTHON_MEMBER(selectedFeatures);
0071 ACTS_PYTHON_MEMBER(embeddingDim);
0072 ACTS_PYTHON_MEMBER(rVal);
0073 ACTS_PYTHON_MEMBER(knnVal);
0074 ACTS_PYTHON_MEMBER(deviceID);
0075 ACTS_PYTHON_STRUCT_END();
0076 }
0077 {
0078 using Alg = Acts::TorchEdgeClassifier;
0079 using Config = Alg::Config;
0080
0081 auto alg =
0082 py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
0083 mex, "TorchEdgeClassifier")
0084 .def(py::init([](const Config &c, Logging::Level lvl) {
0085 return std::make_shared<Alg>(
0086 c, getDefaultLogger("EdgeClassifier", lvl));
0087 }),
0088 py::arg("config"), py::arg("level"))
0089 .def_property_readonly("config", &Alg::config);
0090
0091 auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0092 ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0093 ACTS_PYTHON_MEMBER(modelPath);
0094 ACTS_PYTHON_MEMBER(selectedFeatures);
0095 ACTS_PYTHON_MEMBER(cut);
0096 ACTS_PYTHON_MEMBER(nChunks);
0097 ACTS_PYTHON_MEMBER(undirected);
0098 ACTS_PYTHON_MEMBER(deviceID);
0099 ACTS_PYTHON_MEMBER(useEdgeFeatures);
0100 ACTS_PYTHON_STRUCT_END();
0101 }
0102 {
0103 using Alg = Acts::BoostTrackBuilding;
0104
0105 auto alg = py::class_<Alg, Acts::TrackBuildingBase, std::shared_ptr<Alg>>(
0106 mex, "BoostTrackBuilding")
0107 .def(py::init([](Logging::Level lvl) {
0108 return std::make_shared<Alg>(
0109 getDefaultLogger("EdgeClassifier", lvl));
0110 }),
0111 py::arg("level"));
0112 }
0113 #endif
0114
0115 #ifdef ACTS_EXATRKX_ONNX_BACKEND
0116 {
0117 using Alg = Acts::OnnxMetricLearning;
0118 using Config = Alg::Config;
0119
0120 auto alg =
0121 py::class_<Alg, Acts::GraphConstructionBase, std::shared_ptr<Alg>>(
0122 mex, "OnnxMetricLearning")
0123 .def(py::init([](const Config &c, Logging::Level lvl) {
0124 return std::make_shared<Alg>(
0125 c, getDefaultLogger("MetricLearning", lvl));
0126 }),
0127 py::arg("config"), py::arg("level"))
0128 .def_property_readonly("config", &Alg::config);
0129
0130 auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0131 ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0132 ACTS_PYTHON_MEMBER(modelPath);
0133 ACTS_PYTHON_MEMBER(spacepointFeatures);
0134 ACTS_PYTHON_MEMBER(embeddingDim);
0135 ACTS_PYTHON_MEMBER(rVal);
0136 ACTS_PYTHON_MEMBER(knnVal);
0137 ACTS_PYTHON_STRUCT_END();
0138 }
0139 {
0140 using Alg = Acts::OnnxEdgeClassifier;
0141 using Config = Alg::Config;
0142
0143 auto alg =
0144 py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
0145 mex, "OnnxEdgeClassifier")
0146 .def(py::init([](const Config &c, Logging::Level lvl) {
0147 return std::make_shared<Alg>(
0148 c, getDefaultLogger("EdgeClassifier", lvl));
0149 }),
0150 py::arg("config"), py::arg("level"))
0151 .def_property_readonly("config", &Alg::config);
0152
0153 auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0154 ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0155 ACTS_PYTHON_MEMBER(modelPath);
0156 ACTS_PYTHON_MEMBER(cut);
0157 ACTS_PYTHON_STRUCT_END();
0158 }
0159 #endif
0160
0161 ACTS_PYTHON_DECLARE_ALGORITHM(
0162 ActsExamples::TruthGraphBuilder, mex, "TruthGraphBuilder",
0163 inputSpacePoints, inputSimHits, inputParticles,
0164 inputMeasurementSimHitsMap, inputMeasurementParticlesMap, outputGraph,
0165 targetMinPT, targetMinSize, uniqueModules);
0166
0167 {
0168 auto nodeFeatureEnum =
0169 py::enum_<TrackFindingAlgorithmExaTrkX::NodeFeature>(mex, "NodeFeature")
0170 .value("R", TrackFindingAlgorithmExaTrkX::NodeFeature::eR)
0171 .value("Phi", TrackFindingAlgorithmExaTrkX::NodeFeature::ePhi)
0172 .value("Z", TrackFindingAlgorithmExaTrkX::NodeFeature::eZ)
0173 .value("X", TrackFindingAlgorithmExaTrkX::NodeFeature::eX)
0174 .value("Y", TrackFindingAlgorithmExaTrkX::NodeFeature::eY)
0175 .value("Eta", TrackFindingAlgorithmExaTrkX::NodeFeature::eEta)
0176 .value("ClusterX",
0177 TrackFindingAlgorithmExaTrkX::NodeFeature::eClusterLoc0)
0178 .value("ClusterY",
0179 TrackFindingAlgorithmExaTrkX::NodeFeature::eClusterLoc1)
0180 .value("CellCount",
0181 TrackFindingAlgorithmExaTrkX::NodeFeature::eCellCount)
0182 .value("ChargeSum",
0183 TrackFindingAlgorithmExaTrkX::NodeFeature::eChargeSum);
0184
0185
0186 #define ADD_FEATURE_ENUMS(n) \
0187 nodeFeatureEnum \
0188 .value("Cluster" #n "X", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##X) \
0189 .value("Cluster" #n "Y", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Y) \
0190 .value("Cluster" #n "Z", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Z) \
0191 .value("Cluster" #n "R", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##R) \
0192 .value("Cluster" #n "Phi", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Phi) \
0193 .value("Cluster" #n "Eta", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Eta) \
0194 .value("CellCount" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eCellCount##n) \
0195 .value("ChargeSum" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eChargeSum##n) \
0196 .value("LocEta" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocEta##n) \
0197 .value("LocPhi" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocPhi##n) \
0198 .value("LocDir0" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocDir0##n) \
0199 .value("LocDir1" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocDir1##n) \
0200 .value("LocDir2" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocDir2##n) \
0201 .value("LengthDir0" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLengthDir0##n) \
0202 .value("LengthDir1" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLengthDir1##n) \
0203 .value("LengthDir2" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLengthDir2##n) \
0204 .value("GlobEta" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eGlobEta##n) \
0205 .value("GlobPhi" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eGlobPhi##n) \
0206 .value("EtaAngle" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eEtaAngle##n) \
0207 .value("PhiAngle" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::ePhiAngle##n)
0208
0209
0210 ADD_FEATURE_ENUMS(1);
0211 ADD_FEATURE_ENUMS(2);
0212
0213 #undef ADD_FEATURE_ENUMS
0214 }
0215
0216 ACTS_PYTHON_DECLARE_ALGORITHM(
0217 ActsExamples::TrackFindingAlgorithmExaTrkX, mex,
0218 "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputClusters,
0219 inputTruthGraph, outputProtoTracks, outputGraph, graphConstructor,
0220 edgeClassifiers, trackBuilder, nodeFeatures, featureScales,
0221 minMeasurementsPerTrack, geometryIdMap);
0222
0223 {
0224 auto cls =
0225 py::class_<Acts::ExaTrkXHook, std::shared_ptr<Acts::ExaTrkXHook>>(
0226 mex, "ExaTrkXHook");
0227 }
0228
0229 {
0230 using Class = Acts::TorchTruthGraphMetricsHook;
0231
0232 auto cls = py::class_<Class, Acts::ExaTrkXHook, std::shared_ptr<Class>>(
0233 mex, "TorchTruthGraphMetricsHook")
0234 .def(py::init([](const std::vector<std::int64_t> &g,
0235 Logging::Level lvl) {
0236 return std::make_shared<Class>(
0237 g, getDefaultLogger("PipelineHook", lvl));
0238 }));
0239 }
0240
0241 {
0242 using Class = Acts::ExaTrkXPipeline;
0243
0244 auto cls =
0245 py::class_<Class, std::shared_ptr<Class>>(mex, "ExaTrkXPipeline")
0246 .def(py::init(
0247 [](std::shared_ptr<GraphConstructionBase> g,
0248 std::vector<std::shared_ptr<EdgeClassificationBase>> e,
0249 std::shared_ptr<TrackBuildingBase> t,
0250 Logging::Level lvl) {
0251 return std::make_shared<Class>(
0252 g, e, t, getDefaultLogger("MetricLearning", lvl));
0253 }),
0254 py::arg("graphConstructor"), py::arg("edgeClassifiers"),
0255 py::arg("trackBuilder"), py::arg("level"))
0256 .def("run", &ExaTrkXPipeline::run, py::arg("features"),
0257 py::arg("moduleIds"), py::arg("spacepoints"),
0258 py::arg("hook") = Acts::ExaTrkXHook{},
0259 py::arg("timing") = nullptr);
0260 }
0261
0262 ACTS_PYTHON_DECLARE_ALGORITHM(
0263 ActsExamples::PrototracksToParameters, mex, "PrototracksToParameters",
0264 inputProtoTracks, inputSpacePoints, outputSeeds, outputParameters,
0265 outputProtoTracks, geometry, magneticField, buildTightSeeds);
0266
0267 ACTS_PYTHON_DECLARE_ALGORITHM(
0268 ActsExamples::TrackFindingFromPrototrackAlgorithm, mex,
0269 "TrackFindingFromPrototrackAlgorithm", inputProtoTracks,
0270 inputMeasurements, inputInitialTrackParameters, outputTracks,
0271 measurementSelectorCfg, trackingGeometry, magneticField, findTracks, tag);
0272 }
0273
0274 }