File indexing completed on 2025-09-16 08:27:29
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/EventData/MultiTrajectoryHelpers.hpp"
0012 #include "Acts/EventData/TrackContainer.hpp"
0013 #include "Acts/EventData/TrackContainerFrontendConcept.hpp"
0014 #include "Acts/EventData/TrackProxyConcept.hpp"
0015 #include "Acts/Plugins/Onnx/OnnxRuntimeBase.hpp"
0016 #include "Acts/TrackFinding/detail/AmbiguityTrackClustering.hpp"
0017
0018 #include <map>
0019 #include <unordered_map>
0020 #include <vector>
0021
0022 #include <onnxruntime_cxx_api.h>
0023
0024 namespace Acts {
0025
0026
0027 class AmbiguityTrackClassifier {
0028 public:
0029
0030
0031
0032 AmbiguityTrackClassifier(const char* modelPath)
0033 : m_env(ORT_LOGGING_LEVEL_WARNING, "MLClassifier"),
0034 m_duplicateClassifier(m_env, modelPath) {}
0035
0036
0037
0038
0039
0040
0041 template <TrackContainerFrontend track_container_t>
0042 std::vector<std::vector<float>> inferScores(
0043 std::unordered_map<std::size_t, std::vector<std::size_t>>& clusters,
0044 const track_container_t& tracks) const {
0045
0046
0047 int trackNb = 0;
0048 for (const auto& [_, val] : clusters) {
0049 trackNb += val.size();
0050 }
0051
0052 Acts::NetworkBatchInput networkInput(trackNb, 8);
0053 std::size_t inputID = 0;
0054
0055 for (const auto& [key, val] : clusters) {
0056 for (const auto& trackID : val) {
0057 auto track = tracks.getTrack(trackID);
0058 auto trajState = Acts::MultiTrajectoryHelpers::trajectoryState(
0059 tracks.trackStateContainer(), track.tipIndex());
0060 networkInput(inputID, 0) = trajState.nStates;
0061 networkInput(inputID, 1) = trajState.nMeasurements;
0062 networkInput(inputID, 2) = trajState.nOutliers;
0063 networkInput(inputID, 3) = trajState.nHoles;
0064 networkInput(inputID, 4) = trajState.NDF;
0065 networkInput(inputID, 5) = (trajState.chi2Sum * 1.0) /
0066 (trajState.NDF != 0 ? trajState.NDF : 1);
0067 networkInput(inputID, 6) = Acts::VectorHelpers::eta(track.momentum());
0068 networkInput(inputID, 7) = Acts::VectorHelpers::phi(track.momentum());
0069 inputID++;
0070 }
0071 }
0072
0073 std::vector<std::vector<float>> outputTensor =
0074 m_duplicateClassifier.runONNXInference(networkInput);
0075 return outputTensor;
0076 }
0077
0078
0079
0080
0081
0082
0083 std::vector<std::size_t> trackSelection(
0084 std::unordered_map<std::size_t, std::vector<std::size_t>>& clusters,
0085 std::vector<std::vector<float>>& outputTensor) const {
0086 std::vector<std::size_t> goodTracks;
0087 std::size_t iOut = 0;
0088
0089
0090 for (const auto& [key, val] : clusters) {
0091 std::size_t bestTrackID = 0;
0092 float bestTrackScore = 0;
0093 for (const auto& track : val) {
0094 if (outputTensor[iOut][0] > bestTrackScore) {
0095 bestTrackScore = outputTensor[iOut][0];
0096 bestTrackID = track;
0097 }
0098 iOut++;
0099 }
0100 goodTracks.push_back(bestTrackID);
0101 }
0102 return goodTracks;
0103 }
0104
0105 private:
0106
0107 Ort::Env m_env;
0108
0109 Acts::OnnxRuntimeBase m_duplicateClassifier;
0110 };
0111
0112 }