File indexing completed on 2025-01-18 09:12:25
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/Onnx/MLTrackClassifier.hpp"
0010
0011 #include <cassert>
0012 #include <stdexcept>
0013
0014
0015 Acts::MLTrackClassifier::TrackLabels Acts::MLTrackClassifier::predictTrackLabel(
0016 std::vector<float>& inputFeatures, double decisionThreshProb) const {
0017
0018 if (!((0. <= decisionThreshProb) && (decisionThreshProb <= 1.))) {
0019 throw std::invalid_argument(
0020 "predictTrackLabel: Decision threshold "
0021 "probability is not in [0, 1].");
0022 }
0023
0024
0025 std::vector<float> outputTensor = runONNXInference(inputFeatures);
0026
0027 float outputProbability = outputTensor[0];
0028
0029
0030
0031 if (outputProbability > decisionThreshProb) {
0032 return TrackLabels::eDuplicate;
0033 }
0034 return TrackLabels::eGood;
0035 }
0036
0037
0038 bool Acts::MLTrackClassifier::isDuplicate(std::vector<float>& inputFeatures,
0039 double decisionThreshProb) const {
0040 Acts::MLTrackClassifier::TrackLabels predictedLabel =
0041 Acts::MLTrackClassifier::predictTrackLabel(inputFeatures,
0042 decisionThreshProb);
0043 return predictedLabel == Acts::MLTrackClassifier::TrackLabels::eDuplicate;
0044 }