File indexing completed on 2025-12-11 09:40:23
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/EventData/MultiTrajectoryHelpers.hpp"
0012 #include "Acts/EventData/TrackContainer.hpp"
0013 #include "Acts/EventData/TrackContainerFrontendConcept.hpp"
0014 #include "Acts/EventData/TrackProxyConcept.hpp"
0015 #include "Acts/TrackFinding/detail/AmbiguityTrackClustering.hpp"
0016 #include "Acts/Utilities/VectorHelpers.hpp"
0017 #include "ActsPlugins/Onnx/OnnxRuntimeBase.hpp"
0018
0019 #include <map>
0020 #include <unordered_map>
0021 #include <vector>
0022
0023 #include <onnxruntime_cxx_api.h>
0024
0025 namespace ActsPlugins {
0026
0027
0028 class AmbiguityTrackClassifier {
0029 public:
0030
0031
0032
0033 AmbiguityTrackClassifier(const char* modelPath)
0034 : m_env(ORT_LOGGING_LEVEL_WARNING, "MLClassifier"),
0035 m_duplicateClassifier(m_env, modelPath) {}
0036
0037
0038
0039
0040
0041
0042 template <Acts::TrackContainerFrontend track_container_t>
0043 std::vector<std::vector<float>> inferScores(
0044 std::unordered_map<std::size_t, std::vector<std::size_t>>& clusters,
0045 const track_container_t& tracks) const {
0046
0047
0048 int trackNb = 0;
0049 for (const auto& [_, val] : clusters) {
0050 trackNb += val.size();
0051 }
0052
0053 NetworkBatchInput networkInput(trackNb, 8);
0054 std::size_t inputID = 0;
0055
0056 for (const auto& [key, val] : clusters) {
0057 for (const auto& trackID : val) {
0058 auto track = tracks.getTrack(trackID);
0059 auto trajState = Acts::MultiTrajectoryHelpers::trajectoryState(
0060 tracks.trackStateContainer(), track.tipIndex());
0061 networkInput(inputID, 0) = trajState.nStates;
0062 networkInput(inputID, 1) = trajState.nMeasurements;
0063 networkInput(inputID, 2) = trajState.nOutliers;
0064 networkInput(inputID, 3) = trajState.nHoles;
0065 networkInput(inputID, 4) = trajState.NDF;
0066 networkInput(inputID, 5) = (trajState.chi2Sum * 1.0) /
0067 (trajState.NDF != 0 ? trajState.NDF : 1);
0068 networkInput(inputID, 6) = Acts::VectorHelpers::eta(track.momentum());
0069 networkInput(inputID, 7) = Acts::VectorHelpers::phi(track.momentum());
0070 inputID++;
0071 }
0072 }
0073
0074 std::vector<std::vector<float>> outputTensor =
0075 m_duplicateClassifier.runONNXInference(networkInput);
0076 return outputTensor;
0077 }
0078
0079
0080
0081
0082
0083
0084 std::vector<std::size_t> trackSelection(
0085 std::unordered_map<std::size_t, std::vector<std::size_t>>& clusters,
0086 std::vector<std::vector<float>>& outputTensor) const {
0087 std::vector<std::size_t> goodTracks;
0088 std::size_t iOut = 0;
0089
0090
0091 for (const auto& [key, val] : clusters) {
0092 std::size_t bestTrackID = 0;
0093 float bestTrackScore = 0;
0094 for (const auto& track : val) {
0095 if (outputTensor[iOut][0] > bestTrackScore) {
0096 bestTrackScore = outputTensor[iOut][0];
0097 bestTrackID = track;
0098 }
0099 iOut++;
0100 }
0101 goodTracks.push_back(bestTrackID);
0102 }
0103 return goodTracks;
0104 }
0105
0106 private:
0107
0108 Ort::Env m_env;
0109
0110 OnnxRuntimeBase m_duplicateClassifier;
0111 };
0112
0113 }