File indexing completed on 2025-10-13 08:18:21
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsPlugins/Onnx/MLTrackClassifier.hpp"
0010
0011 #include <cassert>
0012 #include <stdexcept>
0013
0014
0015 ActsPlugins::MLTrackClassifier::TrackLabels
0016 ActsPlugins::MLTrackClassifier::predictTrackLabel(
0017 std::vector<float>& inputFeatures, double decisionThreshProb) const {
0018
0019 if (!((0. <= decisionThreshProb) && (decisionThreshProb <= 1.))) {
0020 throw std::invalid_argument(
0021 "predictTrackLabel: Decision threshold "
0022 "probability is not in [0, 1].");
0023 }
0024
0025
0026 std::vector<float> outputTensor = runONNXInference(inputFeatures);
0027
0028 float outputProbability = outputTensor[0];
0029
0030
0031
0032 if (outputProbability > decisionThreshProb) {
0033 return TrackLabels::eDuplicate;
0034 }
0035 return TrackLabels::eGood;
0036 }
0037
0038
0039 bool ActsPlugins::MLTrackClassifier::isDuplicate(
0040 std::vector<float>& inputFeatures, double decisionThreshProb) const {
0041 MLTrackClassifier::TrackLabels predictedLabel =
0042 MLTrackClassifier::predictTrackLabel(inputFeatures, decisionThreshProb);
0043 return predictedLabel == MLTrackClassifier::TrackLabels::eDuplicate;
0044 }