File indexing completed on 2025-01-18 09:27:42
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include "Acts/EventData/MultiTrajectoryHelpers.hpp"
#include "Acts/EventData/TrackContainer.hpp"
#include "Acts/Plugins/Onnx/OnnxRuntimeBase.hpp"
#include "Acts/TrackFinding/detail/AmbiguityTrackClustering.hpp"
#include "Acts/Utilities/VectorHelpers.hpp"

#include <cstddef>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>

#include <onnxruntime_cxx_api.h>
0021
0022 namespace Acts {
0023
0024
0025 class AmbiguityTrackClassifier {
0026 public:
0027
0028
0029
0030 AmbiguityTrackClassifier(const char* modelPath)
0031 : m_env(ORT_LOGGING_LEVEL_WARNING, "MLClassifier"),
0032 m_duplicateClassifier(m_env, modelPath) {}
0033
0034
0035
0036
0037
0038
0039 template <typename track_container_t, typename traj_t,
0040 template <typename> class holder_t>
0041 std::vector<std::vector<float>> inferScores(
0042 std::unordered_map<std::size_t, std::vector<std::size_t>>& clusters,
0043 const Acts::TrackContainer<track_container_t, traj_t, holder_t>& tracks)
0044 const {
0045
0046
0047 int trackNb = 0;
0048 for (const auto& [_, val] : clusters) {
0049 trackNb += val.size();
0050 }
0051
0052 Acts::NetworkBatchInput networkInput(trackNb, 8);
0053 std::size_t inputID = 0;
0054
0055 for (const auto& [key, val] : clusters) {
0056 for (const auto& trackID : val) {
0057 auto track = tracks.getTrack(trackID);
0058 auto trajState = Acts::MultiTrajectoryHelpers::trajectoryState(
0059 tracks.trackStateContainer(), track.tipIndex());
0060 networkInput(inputID, 0) = trajState.nStates;
0061 networkInput(inputID, 1) = trajState.nMeasurements;
0062 networkInput(inputID, 2) = trajState.nOutliers;
0063 networkInput(inputID, 3) = trajState.nHoles;
0064 networkInput(inputID, 4) = trajState.NDF;
0065 networkInput(inputID, 5) = (trajState.chi2Sum * 1.0) /
0066 (trajState.NDF != 0 ? trajState.NDF : 1);
0067 networkInput(inputID, 6) = Acts::VectorHelpers::eta(track.momentum());
0068 networkInput(inputID, 7) = Acts::VectorHelpers::phi(track.momentum());
0069 inputID++;
0070 }
0071 }
0072
0073 std::vector<std::vector<float>> outputTensor =
0074 m_duplicateClassifier.runONNXInference(networkInput);
0075 return outputTensor;
0076 }
0077
0078
0079
0080
0081
0082
0083 std::vector<std::size_t> trackSelection(
0084 std::unordered_map<std::size_t, std::vector<std::size_t>>& clusters,
0085 std::vector<std::vector<float>>& outputTensor) const {
0086 std::vector<std::size_t> goodTracks;
0087 std::size_t iOut = 0;
0088
0089
0090 for (const auto& [key, val] : clusters) {
0091 std::size_t bestTrackID = 0;
0092 float bestTrackScore = 0;
0093 for (const auto& track : val) {
0094 if (outputTensor[iOut][0] > bestTrackScore) {
0095 bestTrackScore = outputTensor[iOut][0];
0096 bestTrackID = track;
0097 }
0098 iOut++;
0099 }
0100 goodTracks.push_back(bestTrackID);
0101 }
0102 return goodTracks;
0103 }
0104
0105
0106
0107
0108
0109
0110 template <typename track_container_t, typename traj_t,
0111 template <typename> class holder_t>
0112 std::vector<std::size_t> solveAmbiguity(
0113 std::unordered_map<std::size_t, std::vector<std::size_t>>& clusters,
0114 const Acts::TrackContainer<track_container_t, traj_t, holder_t>& tracks)
0115 const {
0116 std::vector<std::vector<float>> outputTensor =
0117 inferScores(clusters, tracks);
0118 std::vector<std::size_t> goodTracks =
0119 trackSelection(clusters, outputTensor);
0120
0121 return goodTracks;
0122 }
0123
0124 private:
0125
0126 Ort::Env m_env;
0127
0128 Acts::OnnxRuntimeBase m_duplicateClassifier;
0129 };
0130
0131 }