File indexing completed on 2026-01-09 09:26:46
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/TrackFindingGnn/PrototracksToParameters.hpp"
0010 #include "ActsExamples/TrackFindingGnn/TrackFindingAlgorithmGnn.hpp"
0011 #include "ActsExamples/TrackFindingGnn/TrackFindingFromPrototrackAlgorithm.hpp"
0012 #include "ActsExamples/TrackFindingGnn/TruthGraphBuilder.hpp"
0013 #include "ActsPlugins/Gnn/BoostTrackBuilding.hpp"
0014 #include "ActsPlugins/Gnn/CudaTrackBuilding.hpp"
0015 #include "ActsPlugins/Gnn/GnnPipeline.hpp"
0016 #include "ActsPlugins/Gnn/ModuleMapCuda.hpp"
0017 #include "ActsPlugins/Gnn/OnnxEdgeClassifier.hpp"
0018 #include "ActsPlugins/Gnn/TensorRTEdgeClassifier.hpp"
0019 #include "ActsPlugins/Gnn/TorchEdgeClassifier.hpp"
0020 #include "ActsPlugins/Gnn/TorchMetricLearning.hpp"
0021 #include "ActsPlugins/Gnn/TruthGraphMetricsHook.hpp"
0022 #include "ActsPython/Utilities/Helpers.hpp"
0023 #include "ActsPython/Utilities/Macros.hpp"
0024
0025 #include <memory>
0026
0027 #include <boost/preprocessor/if.hpp>
0028 #include <boost/vmd/tuple/size.hpp>
0029 #include <pybind11/functional.h>
0030 #include <pybind11/pybind11.h>
0031 #include <pybind11/stl.h>
0032
// Registers one GNN pipeline stage with pybind11.
//
//   algorithm : the C++ stage type; it must expose a nested `Config`, a
//               constructor (Config, logger) and a `config()` accessor.
//   base      : the already-registered stage base class (pybind11 base).
//   mod       : the pybind11 module/scope to register into.
//   ...       : optional list of `Config` member names to expose as
//               read/write attributes (via ACTS_PYTHON_STRUCT). When the list
//               is empty, only the default-constructible `Config` is bound.
//
// The Python class is named after `algorithm`, is held by std::shared_ptr,
// is constructed from (config, level), and creates a default logger named
// after the stage.
#define ACTS_PYTHON_DECLARE_GNN_STAGE(algorithm, base, mod, ...)               \
  do {                                                                         \
    using namespace Acts;                                                      \
                                                                               \
    using Alg = algorithm;                                                     \
    using Config = Alg::Config;                                                \
                                                                               \
    /* The stage class itself, held by shared_ptr like its base. */            \
    py::class_<Alg, base, std::shared_ptr<Alg>> alg(mod, #algorithm);          \
    alg.def(py::init([](const Config &cfg, Logging::Level lvl) {               \
              return std::make_shared<Alg>(cfg,                                \
                                           getDefaultLogger(#algorithm, lvl)); \
            }),                                                                \
            py::arg("config"), py::arg("level"));                              \
    alg.def_property_readonly("config", &Alg::config);                         \
                                                                               \
    /* Nested Config class; members are bound only if any were listed. */      \
    py::class_<Config> c(alg, "Config");                                       \
    c.def(py::init<>());                                                       \
    BOOST_PP_IF(BOOST_VMD_IS_EMPTY(__VA_ARGS__), BOOST_PP_EMPTY(),             \
                ACTS_PYTHON_STRUCT(c, __VA_ARGS__));                           \
  } while (0)
0051
0052 namespace py = pybind11;
0053
0054 using namespace Acts;
0055 using namespace ActsPlugins;
0056 using namespace ActsExamples;
0057 using namespace ActsPython;
0058 using namespace py::literals;
0059
0060 PYBIND11_MODULE(ActsExamplesPythonBindingsGnn, gnn) {
0061 {
0062 using C = GraphConstructionBase;
0063 auto c = py::class_<C, std::shared_ptr<C>>(gnn, "GraphConstructionBase");
0064 }
0065 {
0066 using C = EdgeClassificationBase;
0067 auto c = py::class_<C, std::shared_ptr<C>>(gnn, "EdgeClassificationBase");
0068 }
0069 {
0070 using C = TrackBuildingBase;
0071 auto c = py::class_<C, std::shared_ptr<C>>(gnn, "TrackBuildingBase");
0072 }
0073
0074 ACTS_PYTHON_DECLARE_GNN_STAGE(BoostTrackBuilding, TrackBuildingBase, gnn);
0075
0076 #ifdef ACTS_GNN_TORCH_BACKEND
0077 ACTS_PYTHON_DECLARE_GNN_STAGE(TorchMetricLearning, GraphConstructionBase, gnn,
0078 modelPath, selectedFeatures, embeddingDim, rVal,
0079 knnVal, deviceID);
0080
0081 ACTS_PYTHON_DECLARE_GNN_STAGE(TorchEdgeClassifier, EdgeClassificationBase,
0082 gnn, modelPath, selectedFeatures, cut, nChunks,
0083 undirected, deviceID, useEdgeFeatures);
0084 #endif
0085
0086 #ifdef ACTS_GNN_WITH_TENSORRT
0087 ACTS_PYTHON_DECLARE_GNN_STAGE(TensorRTEdgeClassifier, EdgeClassificationBase,
0088 gnn, modelPath, selectedFeatures, cut,
0089 nugnnecutionContexts);
0090 #endif
0091
0092 #ifdef ACTS_GNN_WITH_CUDA
0093 ACTS_PYTHON_DECLARE_GNN_STAGE(CudaTrackBuilding, TrackBuildingBase, gnn,
0094 useOneBlockImplementation, doJunctionRemoval,
0095 minCandidateSize);
0096 #endif
0097
0098 #ifdef ACTS_GNN_ONNX_BACKEND
0099 ACTS_PYTHON_DECLARE_GNN_STAGE(OnnxEdgeClassifier, EdgeClassificationBase, gnn,
0100 modelPath, cut);
0101 #endif
0102
0103 #ifdef ACTS_GNN_WITH_MODULEMAP
0104 ACTS_PYTHON_DECLARE_GNN_STAGE(
0105 ModuleMapCuda, GraphConstructionBase, gnn, moduleMapPath, rScale,
0106 phiScale, zScale, etaScale, moreParallel, gpuDevice, gpuBlocks, epsilon);
0107 #endif
0108
0109 ACTS_PYTHON_DECLARE_ALGORITHM(TruthGraphBuilder, gnn, "TruthGraphBuilder",
0110 inputSpacePoints, inputSimHits, inputParticles,
0111 inputMeasurementSimHitsMap,
0112 inputMeasurementParticlesMap, outputGraph,
0113 targetMinPT, targetMinSize, uniqueModules);
0114
0115 {
0116 auto nodeFeatureEnum =
0117 py::enum_<TrackFindingAlgorithmGnn::NodeFeature>(gnn, "NodeFeature")
0118 .value("R", TrackFindingAlgorithmGnn::NodeFeature::eR)
0119 .value("Phi", TrackFindingAlgorithmGnn::NodeFeature::ePhi)
0120 .value("Z", TrackFindingAlgorithmGnn::NodeFeature::eZ)
0121 .value("X", TrackFindingAlgorithmGnn::NodeFeature::eX)
0122 .value("Y", TrackFindingAlgorithmGnn::NodeFeature::eY)
0123 .value("Eta", TrackFindingAlgorithmGnn::NodeFeature::eEta)
0124 .value("ClusterX",
0125 TrackFindingAlgorithmGnn::NodeFeature::eClusterLoc0)
0126 .value("ClusterY",
0127 TrackFindingAlgorithmGnn::NodeFeature::eClusterLoc1)
0128 .value("CellCount",
0129 TrackFindingAlgorithmGnn::NodeFeature::eCellCount)
0130 .value("ChargeSum",
0131 TrackFindingAlgorithmGnn::NodeFeature::eChargeSum);
0132
0133
0134 #define ADD_FEATURE_ENUMS(n) \
0135 nodeFeatureEnum \
0136 .value("Cluster" #n "X", TrackFindingAlgorithmGnn::NodeFeature::eCluster##n##X) \
0137 .value("Cluster" #n "Y", TrackFindingAlgorithmGnn::NodeFeature::eCluster##n##Y) \
0138 .value("Cluster" #n "Z", TrackFindingAlgorithmGnn::NodeFeature::eCluster##n##Z) \
0139 .value("Cluster" #n "R", TrackFindingAlgorithmGnn::NodeFeature::eCluster##n##R) \
0140 .value("Cluster" #n "Phi", TrackFindingAlgorithmGnn::NodeFeature::eCluster##n##Phi) \
0141 .value("Cluster" #n "Eta", TrackFindingAlgorithmGnn::NodeFeature::eCluster##n##Eta) \
0142 .value("CellCount" #n, TrackFindingAlgorithmGnn::NodeFeature::eCellCount##n) \
0143 .value("ChargeSum" #n, TrackFindingAlgorithmGnn::NodeFeature::eChargeSum##n) \
0144 .value("LocEta" #n, TrackFindingAlgorithmGnn::NodeFeature::eLocEta##n) \
0145 .value("LocPhi" #n, TrackFindingAlgorithmGnn::NodeFeature::eLocPhi##n) \
0146 .value("LocDir0" #n, TrackFindingAlgorithmGnn::NodeFeature::eLocDir0##n) \
0147 .value("LocDir1" #n, TrackFindingAlgorithmGnn::NodeFeature::eLocDir1##n) \
0148 .value("LocDir2" #n, TrackFindingAlgorithmGnn::NodeFeature::eLocDir2##n) \
0149 .value("LengthDir0" #n, TrackFindingAlgorithmGnn::NodeFeature::eLengthDir0##n) \
0150 .value("LengthDir1" #n, TrackFindingAlgorithmGnn::NodeFeature::eLengthDir1##n) \
0151 .value("LengthDir2" #n, TrackFindingAlgorithmGnn::NodeFeature::eLengthDir2##n) \
0152 .value("GlobEta" #n, TrackFindingAlgorithmGnn::NodeFeature::eGlobEta##n) \
0153 .value("GlobPhi" #n, TrackFindingAlgorithmGnn::NodeFeature::eGlobPhi##n) \
0154 .value("EtaAngle" #n, TrackFindingAlgorithmGnn::NodeFeature::eEtaAngle##n) \
0155 .value("PhiAngle" #n, TrackFindingAlgorithmGnn::NodeFeature::ePhiAngle##n)
0156
0157
0158 ADD_FEATURE_ENUMS(1);
0159 ADD_FEATURE_ENUMS(2);
0160
0161 #undef ADD_FEATURE_ENUMS
0162 }
0163
0164 ACTS_PYTHON_DECLARE_ALGORITHM(
0165 TrackFindingAlgorithmGnn, gnn, "TrackFindingAlgorithmGnn",
0166 inputSpacePoints, inputClusters, inputTruthGraph, outputProtoTracks,
0167 outputGraph, graphConstructor, edgeClassifiers, trackBuilder,
0168 nodeFeatures, featureScales, minMeasurementsPerTrack, geometryIdMap);
0169
0170 { auto cls = py::class_<GnnHook, std::shared_ptr<GnnHook>>(gnn, "GnnHook"); }
0171
0172 {
0173 using Class = TruthGraphMetricsHook;
0174
0175 auto cls = py::class_<Class, GnnHook, std::shared_ptr<Class>>(
0176 gnn, "TruthGraphMetricsHook")
0177 .def(py::init([](const std::vector<std::int64_t> &g,
0178 Logging::Level lvl) {
0179 return std::make_shared<Class>(
0180 g, getDefaultLogger("TruthGraphHook", lvl));
0181 }));
0182 }
0183
0184 {
0185 auto cls = py::class_<Device>(gnn, "Device")
0186 .def_static("Cpu", &Device::Cpu)
0187 .def_static("Cuda", &Device::Cuda, py::arg("index") = 0);
0188 }
0189
0190 {
0191 using Class = GnnPipeline;
0192
0193 auto cls =
0194 py::class_<Class, std::shared_ptr<Class>>(gnn, "GnnPipeline")
0195 .def(py::init(
0196 [](std::shared_ptr<GraphConstructionBase> g,
0197 std::vector<std::shared_ptr<EdgeClassificationBase>> e,
0198 std::shared_ptr<TrackBuildingBase> t,
0199 Logging::Level lvl) {
0200 return std::make_shared<Class>(
0201 g, e, t, getDefaultLogger("MetricLearning", lvl));
0202 }),
0203 py::arg("graphConstructor"), py::arg("edgeClassifiers"),
0204 py::arg("trackBuilder"), py::arg("level"))
0205 .def("run", &GnnPipeline::run, py::arg("features"),
0206 py::arg("moduleIds"), py::arg("spacepoints"),
0207 py::arg("device") = Device::Cuda(0),
0208 py::arg("hook") = GnnHook{}, py::arg("timing") = nullptr);
0209 }
0210
0211 ACTS_PYTHON_DECLARE_ALGORITHM(
0212 PrototracksToParameters, gnn, "PrototracksToParameters", inputProtoTracks,
0213 inputSpacePoints, outputSeeds, outputParameters, outputProtoTracks,
0214 geometry, magneticField, buildTightSeeds);
0215
0216 ACTS_PYTHON_DECLARE_ALGORITHM(
0217 TrackFindingFromPrototrackAlgorithm, gnn,
0218 "TrackFindingFromPrototrackAlgorithm", inputProtoTracks,
0219 inputMeasurements, inputInitialTrackParameters, outputTracks,
0220 measurementSelectorCfg, trackingGeometry, magneticField, findTracks, tag);
0221 }