Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-17 07:47:27

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
#include "ActsPlugins/Gnn/BoostTrackBuilding.hpp"
#include "ActsPlugins/Gnn/CudaTrackBuilding.hpp"
#include "ActsPlugins/Gnn/GnnPipeline.hpp"
#include "ActsPlugins/Gnn/ModuleMapCuda.hpp"
#include "ActsPlugins/Gnn/OnnxEdgeClassifier.hpp"
#include "ActsPlugins/Gnn/TensorRTEdgeClassifier.hpp"
#include "ActsPlugins/Gnn/TorchEdgeClassifier.hpp"
#include "ActsPlugins/Gnn/TorchMetricLearning.hpp"
#include "ActsPlugins/Gnn/TruthGraphMetricsHook.hpp"
#include "ActsPython/Utilities/Macros.hpp"

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include <boost/preprocessor/facilities/empty.hpp>
#include <boost/preprocessor/if.hpp>
#include <boost/vmd/is_empty.hpp>
#include <boost/vmd/tuple/size.hpp>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
0025 
/// Register one GNN pipeline stage, plus its nested `Config` struct, with
/// pybind11.
///
/// The generated Python class derives from `base`, uses `std::shared_ptr` as
/// its holder, is constructed as `algorithm(config, level)` (the C++ object
/// receives a default logger named after `algorithm`), and exposes the stage
/// configuration through a read-only `config` property.
///
/// @param algorithm C++ stage type; its spelling is reused as the Python
///                  class name and as the logger name.
/// @param base      C++ base class of the stage; pybind11 requires it to be
///                  registered before this macro runs.
/// @param mod       pybind11 scope the stage class is added to.
/// @param ...       Optional `Config` member names, forwarded to
///                  ACTS_PYTHON_STRUCT; nothing is forwarded when empty.
#define ACTS_PYTHON_DECLARE_GNN_STAGE(algorithm, base, mod, ...)               \
  do {                                                                         \
    using namespace Acts;                                                      \
    using Alg = algorithm;                                                     \
    using Config = Alg::Config;                                                \
                                                                               \
    /* Stage class: constructor taking (config, level) + `config` getter. */   \
    py::class_<Alg, base, std::shared_ptr<Alg>> stage(mod, #algorithm);        \
    stage.def(py::init([](const Config &cfg, Logging::Level level) {           \
                return std::make_shared<Alg>(                                  \
                    cfg, getDefaultLogger(#algorithm, level));                 \
              }),                                                              \
              py::arg("config"), py::arg("level"));                            \
    stage.def_property_readonly("config", &Alg::config);                       \
                                                                               \
    /* Config is nested in the stage class and default-constructible. */       \
    py::class_<Config> cfgClass(stage, "Config");                              \
    cfgClass.def(py::init<>());                                                \
    BOOST_PP_IF(BOOST_VMD_IS_EMPTY(__VA_ARGS__), BOOST_PP_EMPTY(),             \
                ACTS_PYTHON_STRUCT(cfgClass, __VA_ARGS__));                    \
  } while (0)
0044 
0045 namespace py = pybind11;
0046 
0047 using namespace Acts;
0048 using namespace ActsPlugins;
0049 using namespace ActsPython;
0050 using namespace py::literals;
0051 
0052 PYBIND11_MODULE(ActsPluginsPythonBindingsGnn, gnn) {
0053   {
0054     using C = GraphConstructionBase;
0055     auto c = py::class_<C, std::shared_ptr<C>>(gnn, "GraphConstructionBase");
0056   }
0057   {
0058     using C = EdgeClassificationBase;
0059     auto c = py::class_<C, std::shared_ptr<C>>(gnn, "EdgeClassificationBase");
0060   }
0061   {
0062     using C = TrackBuildingBase;
0063     auto c = py::class_<C, std::shared_ptr<C>>(gnn, "TrackBuildingBase");
0064   }
0065 
0066   ACTS_PYTHON_DECLARE_GNN_STAGE(BoostTrackBuilding, TrackBuildingBase, gnn);
0067 
0068 #ifdef ACTS_GNN_TORCH_BACKEND
0069   ACTS_PYTHON_DECLARE_GNN_STAGE(TorchMetricLearning, GraphConstructionBase, gnn,
0070                                 modelPath, selectedFeatures, embeddingDim, rVal,
0071                                 knnVal, deviceID);
0072 
0073   ACTS_PYTHON_DECLARE_GNN_STAGE(TorchEdgeClassifier, EdgeClassificationBase,
0074                                 gnn, modelPath, selectedFeatures, cut, nChunks,
0075                                 undirected, deviceID, useEdgeFeatures);
0076 #endif
0077 
0078 #ifdef ACTS_GNN_WITH_TENSORRT
0079   ACTS_PYTHON_DECLARE_GNN_STAGE(TensorRTEdgeClassifier, EdgeClassificationBase,
0080                                 gnn, modelPath, selectedFeatures, cut,
0081                                 numExecutionContexts);
0082 #endif
0083 
0084 #ifdef ACTS_GNN_WITH_CUDA
0085   ACTS_PYTHON_DECLARE_GNN_STAGE(CudaTrackBuilding, TrackBuildingBase, gnn,
0086                                 useOneBlockImplementation, doJunctionRemoval,
0087                                 minCandidateSize);
0088 #endif
0089 
0090 #ifdef ACTS_GNN_ONNX_BACKEND
0091   ACTS_PYTHON_DECLARE_GNN_STAGE(OnnxEdgeClassifier, EdgeClassificationBase, gnn,
0092                                 modelPath, cut);
0093 #endif
0094 
0095 #ifdef ACTS_GNN_WITH_MODULEMAP
0096   ACTS_PYTHON_DECLARE_GNN_STAGE(
0097       ModuleMapCuda, GraphConstructionBase, gnn, moduleMapPath, rScale,
0098       phiScale, zScale, etaScale, moreParallel, gpuDevice, gpuBlocks, epsilon);
0099 #endif
0100 
0101   {
0102     auto cls = py::class_<GnnHook, std::shared_ptr<GnnHook>>(gnn, "GnnHook");
0103   }
0104 
0105   {
0106     using Class = TruthGraphMetricsHook;
0107 
0108     auto cls = py::class_<Class, GnnHook, std::shared_ptr<Class>>(
0109                    gnn, "TruthGraphMetricsHook")
0110                    .def(py::init([](const std::vector<std::int64_t> &g,
0111                                     Logging::Level lvl) {
0112                      return std::make_shared<Class>(
0113                          g, getDefaultLogger("TruthGraphHook", lvl));
0114                    }));
0115   }
0116 
0117   {
0118     auto cls = py::class_<Device>(gnn, "Device")
0119                    .def_static("Cpu", &Device::Cpu)
0120                    .def_static("Cuda", &Device::Cuda, py::arg("index") = 0);
0121   }
0122 
0123   {
0124     using Class = GnnPipeline;
0125 
0126     auto cls =
0127         py::class_<Class, std::shared_ptr<Class>>(gnn, "GnnPipeline")
0128             .def(py::init(
0129                      [](std::shared_ptr<GraphConstructionBase> g,
0130                         std::vector<std::shared_ptr<EdgeClassificationBase>> e,
0131                         std::shared_ptr<TrackBuildingBase> t,
0132                         Logging::Level lvl) {
0133                        return std::make_shared<Class>(
0134                            g, e, t, getDefaultLogger("MetricLearning", lvl));
0135                      }),
0136                  py::arg("graphConstructor"), py::arg("edgeClassifiers"),
0137                  py::arg("trackBuilder"), py::arg("level"))
0138             .def("run", &GnnPipeline::run, py::arg("features"),
0139                  py::arg("moduleIds"), py::arg("spacepoints"),
0140                  py::arg("device") = Device::Cuda(0),
0141                  py::arg("hook") = GnnHook{}, py::arg("timing") = nullptr);
0142   }
0143 }