Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-05 08:12:06

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp"
0010 #include "Acts/Plugins/ExaTrkX/CudaTrackBuilding.hpp"
0011 #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
0012 #include "Acts/Plugins/ExaTrkX/ModuleMapCuda.hpp"
0013 #include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"
0014 #include "Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp"
0015 #include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"
0016 #include "Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp"
0017 #include "Acts/Plugins/ExaTrkX/TruthGraphMetricsHook.hpp"
0018 #include "Acts/Plugins/Python/Utilities.hpp"
0019 #include "ActsExamples/TrackFindingExaTrkX/PrototracksToParameters.hpp"
0020 #include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp"
0021 #include "ActsExamples/TrackFindingExaTrkX/TrackFindingFromPrototrackAlgorithm.hpp"
0022 #include "ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp"
0023 
0024 #include <memory>
0025 
0026 #include <boost/preprocessor/if.hpp>
0027 #include <boost/vmd/tuple/size.hpp>
0028 #include <pybind11/functional.h>
0029 #include <pybind11/pybind11.h>
0030 #include <pybind11/stl.h>
0031 
// Declare the Python binding for one ExaTrkX GNN pipeline stage.
//
// Arguments:
//   algorithm - C++ stage type. Its stringified name is used both as the
//               Python class name and as the logger name.
//   base      - pybind11 base class of the stage; it must already be
//               registered (e.g. GraphConstructionBase, EdgeClassificationBase,
//               TrackBuildingBase).
//   mod       - pybind11 module the class is added to.
//   ...       - optional list of `algorithm::Config` member names, forwarded to
//               ACTS_PYTHON_STRUCT for the nested Python `Config` class. It may
//               be empty, in which case only the default constructor of
//               `Config` is bound.
//
// The generated Python class:
//   - is held by std::shared_ptr;
//   - is constructed from (config, level), with a logger from getDefaultLogger;
//   - exposes the stage configuration through a read-only `config` property.
//
// The body is wrapped in do { ... } while (0) so that an invocation followed
// by `;` acts as a single statement.
//
// NOTE: comments inside the macro body must use /* */. A `//` comment would
// swallow the trailing line-continuation backslash.
#define ACTS_PYTHON_DECLARE_GNN_STAGE(algorithm, base, mod, ...)            \
  do {                                                                      \
    using namespace Acts;                                                   \
                                                                            \
    using Alg = algorithm;                                                  \
    using Config = Alg::Config;                                             \
    /* Python class for the stage itself, constructed from (config, level). */ \
    auto alg = py::class_<Alg, base, std::shared_ptr<Alg>>(mod, #algorithm) \
                   .def(py::init([](const Config &c, Logging::Level lvl) {  \
                          return std::make_shared<Alg>(                     \
                              c, getDefaultLogger(#algorithm, lvl));        \
                        }),                                                 \
                        py::arg("config"), py::arg("level"))                \
                   .def_property_readonly("config", &Alg::config);          \
                                                                            \
    /* Nested `Config` class, registered in the scope of the stage class. */ \
    auto c = py::class_<Config>(alg, "Config").def(py::init<>());           \
    /* Bind the listed Config members only when at least one was given. */  \
    BOOST_PP_IF(BOOST_VMD_IS_EMPTY(__VA_ARGS__), BOOST_PP_EMPTY(),          \
                ACTS_PYTHON_STRUCT(c, __VA_ARGS__));                        \
  } while (0)
0050 
0051 namespace py = pybind11;
0052 
0053 using namespace ActsExamples;
0054 using namespace Acts;
0055 using namespace py::literals;
0056 
0057 namespace Acts::Python {
0058 
0059 void addExaTrkXTrackFinding(Context &ctx) {
0060   auto [m, mex] = ctx.get("main", "examples");
0061 
0062   {
0063     using C = Acts::GraphConstructionBase;
0064     auto c = py::class_<C, std::shared_ptr<C>>(mex, "GraphConstructionBase");
0065   }
0066   {
0067     using C = Acts::EdgeClassificationBase;
0068     auto c = py::class_<C, std::shared_ptr<C>>(mex, "EdgeClassificationBase");
0069   }
0070   {
0071     using C = Acts::TrackBuildingBase;
0072     auto c = py::class_<C, std::shared_ptr<C>>(mex, "TrackBuildingBase");
0073   }
0074 
0075   ACTS_PYTHON_DECLARE_GNN_STAGE(BoostTrackBuilding, TrackBuildingBase, mex);
0076 
0077 #ifdef ACTS_EXATRKX_TORCH_BACKEND
0078   ACTS_PYTHON_DECLARE_GNN_STAGE(TorchMetricLearning, GraphConstructionBase, mex,
0079                                 modelPath, selectedFeatures, embeddingDim, rVal,
0080                                 knnVal, deviceID);
0081 
0082   ACTS_PYTHON_DECLARE_GNN_STAGE(TorchEdgeClassifier, EdgeClassificationBase,
0083                                 mex, modelPath, selectedFeatures, cut, nChunks,
0084                                 undirected, deviceID, useEdgeFeatures);
0085 #endif
0086 
0087 #ifdef ACTS_EXATRKX_WITH_TENSORRT
0088   ACTS_PYTHON_DECLARE_GNN_STAGE(TensorRTEdgeClassifier, EdgeClassificationBase,
0089                                 mex, modelPath, selectedFeatures, cut,
0090                                 numExecutionContexts);
0091 #endif
0092 
0093 #ifdef ACTS_EXATRKX_WITH_CUDA
0094   ACTS_PYTHON_DECLARE_GNN_STAGE(CudaTrackBuilding, TrackBuildingBase, mex,
0095                                 useOneBlockImplementation, doJunctionRemoval);
0096 #endif
0097 
0098 #ifdef ACTS_EXATRKX_ONNX_BACKEND
0099   ACTS_PYTHON_DECLARE_GNN_STAGE(OnnxEdgeClassifier, EdgeClassificationBase, mex,
0100                                 modelPath, cut);
0101 #endif
0102 
0103 #ifdef ACTS_EXATRKX_WITH_MODULEMAP
0104   ACTS_PYTHON_DECLARE_GNN_STAGE(
0105       ModuleMapCuda, GraphConstructionBase, mex, moduleMapPath, rScale,
0106       phiScale, zScale, etaScale, moreParallel, gpuDevice, gpuBlocks, epsilon);
0107 #endif
0108 
0109   ACTS_PYTHON_DECLARE_ALGORITHM(
0110       ActsExamples::TruthGraphBuilder, mex, "TruthGraphBuilder",
0111       inputSpacePoints, inputSimHits, inputParticles,
0112       inputMeasurementSimHitsMap, inputMeasurementParticlesMap, outputGraph,
0113       targetMinPT, targetMinSize, uniqueModules);
0114 
0115   {
0116     auto nodeFeatureEnum =
0117         py::enum_<TrackFindingAlgorithmExaTrkX::NodeFeature>(mex, "NodeFeature")
0118             .value("R", TrackFindingAlgorithmExaTrkX::NodeFeature::eR)
0119             .value("Phi", TrackFindingAlgorithmExaTrkX::NodeFeature::ePhi)
0120             .value("Z", TrackFindingAlgorithmExaTrkX::NodeFeature::eZ)
0121             .value("X", TrackFindingAlgorithmExaTrkX::NodeFeature::eX)
0122             .value("Y", TrackFindingAlgorithmExaTrkX::NodeFeature::eY)
0123             .value("Eta", TrackFindingAlgorithmExaTrkX::NodeFeature::eEta)
0124             .value("ClusterX",
0125                    TrackFindingAlgorithmExaTrkX::NodeFeature::eClusterLoc0)
0126             .value("ClusterY",
0127                    TrackFindingAlgorithmExaTrkX::NodeFeature::eClusterLoc1)
0128             .value("CellCount",
0129                    TrackFindingAlgorithmExaTrkX::NodeFeature::eCellCount)
0130             .value("ChargeSum",
0131                    TrackFindingAlgorithmExaTrkX::NodeFeature::eChargeSum);
0132 
0133     // clang-format off
0134 #define ADD_FEATURE_ENUMS(n) \
0135   nodeFeatureEnum \
0136     .value("Cluster" #n "X", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##X) \
0137     .value("Cluster" #n "Y", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Y) \
0138     .value("Cluster" #n "Z", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Z) \
0139     .value("Cluster" #n "R", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##R) \
0140     .value("Cluster" #n "Phi", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Phi) \
0141     .value("Cluster" #n "Eta", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Eta) \
0142     .value("CellCount" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eCellCount##n) \
0143     .value("ChargeSum" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eChargeSum##n) \
0144     .value("LocEta" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocEta##n) \
0145     .value("LocPhi" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocPhi##n) \
0146     .value("LocDir0" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocDir0##n) \
0147     .value("LocDir1" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocDir1##n) \
0148     .value("LocDir2" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocDir2##n) \
0149     .value("LengthDir0" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLengthDir0##n) \
0150     .value("LengthDir1" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLengthDir1##n) \
0151     .value("LengthDir2" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLengthDir2##n) \
0152     .value("GlobEta" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eGlobEta##n) \
0153     .value("GlobPhi" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eGlobPhi##n) \
0154     .value("EtaAngle" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eEtaAngle##n) \
0155     .value("PhiAngle" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::ePhiAngle##n)
0156     // clang-format on
0157 
0158     ADD_FEATURE_ENUMS(1);
0159     ADD_FEATURE_ENUMS(2);
0160 
0161 #undef ADD_FEATURE_ENUMS
0162   }
0163 
0164   ACTS_PYTHON_DECLARE_ALGORITHM(
0165       ActsExamples::TrackFindingAlgorithmExaTrkX, mex,
0166       "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputClusters,
0167       inputTruthGraph, outputProtoTracks, outputGraph, graphConstructor,
0168       edgeClassifiers, trackBuilder, nodeFeatures, featureScales,
0169       minMeasurementsPerTrack, geometryIdMap);
0170 
0171   {
0172     auto cls =
0173         py::class_<Acts::ExaTrkXHook, std::shared_ptr<Acts::ExaTrkXHook>>(
0174             mex, "ExaTrkXHook");
0175   }
0176 
0177   {
0178     using Class = Acts::TruthGraphMetricsHook;
0179 
0180     auto cls = py::class_<Class, Acts::ExaTrkXHook, std::shared_ptr<Class>>(
0181                    mex, "TruthGraphMetricsHook")
0182                    .def(py::init([](const std::vector<std::int64_t> &g,
0183                                     Logging::Level lvl) {
0184                      return std::make_shared<Class>(
0185                          g, getDefaultLogger("TruthGraphHook", lvl));
0186                    }));
0187   }
0188 
0189   {
0190     auto cls =
0191         py::class_<Acts::Device>(mex, "Device")
0192             .def_static("Cpu", &Acts::Device::Cpu)
0193             .def_static("Cuda", &Acts::Device::Cuda, py::arg("index") = 0);
0194   }
0195 
0196   {
0197     using Class = Acts::ExaTrkXPipeline;
0198 
0199     auto cls =
0200         py::class_<Class, std::shared_ptr<Class>>(mex, "ExaTrkXPipeline")
0201             .def(py::init(
0202                      [](std::shared_ptr<GraphConstructionBase> g,
0203                         std::vector<std::shared_ptr<EdgeClassificationBase>> e,
0204                         std::shared_ptr<TrackBuildingBase> t,
0205                         Logging::Level lvl) {
0206                        return std::make_shared<Class>(
0207                            g, e, t, getDefaultLogger("MetricLearning", lvl));
0208                      }),
0209                  py::arg("graphConstructor"), py::arg("edgeClassifiers"),
0210                  py::arg("trackBuilder"), py::arg("level"))
0211             .def("run", &ExaTrkXPipeline::run, py::arg("features"),
0212                  py::arg("moduleIds"), py::arg("spacepoints"),
0213                  py::arg("device") = Acts::Device::Cuda(0),
0214                  py::arg("hook") = Acts::ExaTrkXHook{},
0215                  py::arg("timing") = nullptr);
0216   }
0217 
0218   ACTS_PYTHON_DECLARE_ALGORITHM(
0219       ActsExamples::PrototracksToParameters, mex, "PrototracksToParameters",
0220       inputProtoTracks, inputSpacePoints, outputSeeds, outputParameters,
0221       outputProtoTracks, geometry, magneticField, buildTightSeeds);
0222 
0223   ACTS_PYTHON_DECLARE_ALGORITHM(
0224       ActsExamples::TrackFindingFromPrototrackAlgorithm, mex,
0225       "TrackFindingFromPrototrackAlgorithm", inputProtoTracks,
0226       inputMeasurements, inputInitialTrackParameters, outputTracks,
0227       measurementSelectorCfg, trackingGeometry, magneticField, findTracks, tag);
0228 }
0229 
0230 }  // namespace Acts::Python