File indexing completed on 2026-04-17 07:47:27
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsPlugins/Gnn/BoostTrackBuilding.hpp"
0010 #include "ActsPlugins/Gnn/CudaTrackBuilding.hpp"
0011 #include "ActsPlugins/Gnn/GnnPipeline.hpp"
0012 #include "ActsPlugins/Gnn/ModuleMapCuda.hpp"
0013 #include "ActsPlugins/Gnn/OnnxEdgeClassifier.hpp"
0014 #include "ActsPlugins/Gnn/TensorRTEdgeClassifier.hpp"
0015 #include "ActsPlugins/Gnn/TorchEdgeClassifier.hpp"
0016 #include "ActsPlugins/Gnn/TorchMetricLearning.hpp"
0017 #include "ActsPlugins/Gnn/TruthGraphMetricsHook.hpp"
0018 #include "ActsPython/Utilities/Macros.hpp"
0019
0020 #include <boost/preprocessor/if.hpp>
0021 #include <boost/vmd/tuple/size.hpp>
0022 #include <pybind11/functional.h>
0023 #include <pybind11/pybind11.h>
0024 #include <pybind11/stl.h>
0025
// Declare the Python binding for a single GNN pipeline stage.
//
//   algorithm - the C++ stage type. Its token spelling (#algorithm) is used
//               both as the Python class name and as the logger name.
//   base      - the Python-visible base class the stage derives from
//               (e.g. TrackBuildingBase). It must already be registered with
//               pybind11 in `mod`.
//   mod       - the pybind11 module/scope the class is registered in.
//   ...       - optional list of `algorithm::Config` member names, forwarded
//               to ACTS_PYTHON_STRUCT (ActsPython/Utilities/Macros.hpp), which
//               presumably exposes each listed member on the Python Config
//               class. When the list is empty, no Config members are bound.
//
// The generated Python class:
//   - is held by std::shared_ptr,
//   - is constructed from `(config, level)`,
//   - exposes the stage's configuration via a read-only `config` property.
// A nested `Config` class with a default constructor is registered under it.
//
// Only block comments may appear inside the macro body. A `//` comment on a
// backslash-continued line would swallow the following line.
#define ACTS_PYTHON_DECLARE_GNN_STAGE(algorithm, base, mod, ...)              \
  do {                                                                      \
    using namespace Acts;                                                   \
                                                                            \
    using Alg = algorithm;                                                  \
    using Config = Alg::Config;                                             \
    /* Python class named after `algorithm`, deriving from `base`. */       \
    auto alg = py::class_<Alg, base, std::shared_ptr<Alg>>(mod, #algorithm) \
                   .def(py::init([](const Config &c, Logging::Level lvl) {  \
                          return std::make_shared<Alg>(                     \
                              c, getDefaultLogger(#algorithm, lvl));        \
                        }),                                                 \
                        py::arg("config"), py::arg("level"))                \
                   .def_property_readonly("config", &Alg::config);          \
                                                                            \
    /* Nested `<algorithm>.Config`, default-constructible from Python. */   \
    auto c = py::class_<Config>(alg, "Config").def(py::init<>());           \
    /* Bind the listed Config members; expands to nothing if none given. */ \
    BOOST_PP_IF(BOOST_VMD_IS_EMPTY(__VA_ARGS__), BOOST_PP_EMPTY(),          \
                ACTS_PYTHON_STRUCT(c, __VA_ARGS__));                        \
  } while (0)
0044
0045 namespace py = pybind11;
0046
0047 using namespace Acts;
0048 using namespace ActsPlugins;
0049 using namespace ActsPython;
0050 using namespace py::literals;
0051
0052 PYBIND11_MODULE(ActsPluginsPythonBindingsGnn, gnn) {
0053 {
0054 using C = GraphConstructionBase;
0055 auto c = py::class_<C, std::shared_ptr<C>>(gnn, "GraphConstructionBase");
0056 }
0057 {
0058 using C = EdgeClassificationBase;
0059 auto c = py::class_<C, std::shared_ptr<C>>(gnn, "EdgeClassificationBase");
0060 }
0061 {
0062 using C = TrackBuildingBase;
0063 auto c = py::class_<C, std::shared_ptr<C>>(gnn, "TrackBuildingBase");
0064 }
0065
0066 ACTS_PYTHON_DECLARE_GNN_STAGE(BoostTrackBuilding, TrackBuildingBase, gnn);
0067
0068 #ifdef ACTS_GNN_TORCH_BACKEND
0069 ACTS_PYTHON_DECLARE_GNN_STAGE(TorchMetricLearning, GraphConstructionBase, gnn,
0070 modelPath, selectedFeatures, embeddingDim, rVal,
0071 knnVal, deviceID);
0072
0073 ACTS_PYTHON_DECLARE_GNN_STAGE(TorchEdgeClassifier, EdgeClassificationBase,
0074 gnn, modelPath, selectedFeatures, cut, nChunks,
0075 undirected, deviceID, useEdgeFeatures);
0076 #endif
0077
0078 #ifdef ACTS_GNN_WITH_TENSORRT
0079 ACTS_PYTHON_DECLARE_GNN_STAGE(TensorRTEdgeClassifier, EdgeClassificationBase,
0080 gnn, modelPath, selectedFeatures, cut,
0081 numExecutionContexts);
0082 #endif
0083
0084 #ifdef ACTS_GNN_WITH_CUDA
0085 ACTS_PYTHON_DECLARE_GNN_STAGE(CudaTrackBuilding, TrackBuildingBase, gnn,
0086 useOneBlockImplementation, doJunctionRemoval,
0087 minCandidateSize);
0088 #endif
0089
0090 #ifdef ACTS_GNN_ONNX_BACKEND
0091 ACTS_PYTHON_DECLARE_GNN_STAGE(OnnxEdgeClassifier, EdgeClassificationBase, gnn,
0092 modelPath, cut);
0093 #endif
0094
0095 #ifdef ACTS_GNN_WITH_MODULEMAP
0096 ACTS_PYTHON_DECLARE_GNN_STAGE(
0097 ModuleMapCuda, GraphConstructionBase, gnn, moduleMapPath, rScale,
0098 phiScale, zScale, etaScale, moreParallel, gpuDevice, gpuBlocks, epsilon);
0099 #endif
0100
0101 {
0102 auto cls = py::class_<GnnHook, std::shared_ptr<GnnHook>>(gnn, "GnnHook");
0103 }
0104
0105 {
0106 using Class = TruthGraphMetricsHook;
0107
0108 auto cls = py::class_<Class, GnnHook, std::shared_ptr<Class>>(
0109 gnn, "TruthGraphMetricsHook")
0110 .def(py::init([](const std::vector<std::int64_t> &g,
0111 Logging::Level lvl) {
0112 return std::make_shared<Class>(
0113 g, getDefaultLogger("TruthGraphHook", lvl));
0114 }));
0115 }
0116
0117 {
0118 auto cls = py::class_<Device>(gnn, "Device")
0119 .def_static("Cpu", &Device::Cpu)
0120 .def_static("Cuda", &Device::Cuda, py::arg("index") = 0);
0121 }
0122
0123 {
0124 using Class = GnnPipeline;
0125
0126 auto cls =
0127 py::class_<Class, std::shared_ptr<Class>>(gnn, "GnnPipeline")
0128 .def(py::init(
0129 [](std::shared_ptr<GraphConstructionBase> g,
0130 std::vector<std::shared_ptr<EdgeClassificationBase>> e,
0131 std::shared_ptr<TrackBuildingBase> t,
0132 Logging::Level lvl) {
0133 return std::make_shared<Class>(
0134 g, e, t, getDefaultLogger("MetricLearning", lvl));
0135 }),
0136 py::arg("graphConstructor"), py::arg("edgeClassifiers"),
0137 py::arg("trackBuilder"), py::arg("level"))
0138 .def("run", &GnnPipeline::run, py::arg("features"),
0139 py::arg("moduleIds"), py::arg("spacepoints"),
0140 py::arg("device") = Device::Cuda(0),
0141 py::arg("hook") = GnnHook{}, py::arg("timing") = nullptr);
0142 }
0143 }