Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:12:25

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/Onnx/MLTrackClassifier.hpp"
0010 
0011 #include <cassert>
0012 #include <stdexcept>
0013 
0014 // prediction function
0015 Acts::MLTrackClassifier::TrackLabels Acts::MLTrackClassifier::predictTrackLabel(
0016     std::vector<float>& inputFeatures, double decisionThreshProb) const {
0017   // check that the decision threshold is a probability
0018   if (!((0. <= decisionThreshProb) && (decisionThreshProb <= 1.))) {
0019     throw std::invalid_argument(
0020         "predictTrackLabel: Decision threshold "
0021         "probability is not in [0, 1].");
0022   }
0023 
0024   // run the model over the input
0025   std::vector<float> outputTensor = runONNXInference(inputFeatures);
0026   // this is binary classification, so only need first value
0027   float outputProbability = outputTensor[0];
0028 
0029   // the output layer computes how confident the network is that the track is a
0030   // duplicate, so need to convert that to a label
0031   if (outputProbability > decisionThreshProb) {
0032     return TrackLabels::eDuplicate;
0033   }
0034   return TrackLabels::eGood;
0035 }
0036 
0037 // function that checks if the predicted track label is duplicate
0038 bool Acts::MLTrackClassifier::isDuplicate(std::vector<float>& inputFeatures,
0039                                           double decisionThreshProb) const {
0040   Acts::MLTrackClassifier::TrackLabels predictedLabel =
0041       Acts::MLTrackClassifier::predictTrackLabel(inputFeatures,
0042                                                  decisionThreshProb);
0043   return predictedLabel == Acts::MLTrackClassifier::TrackLabels::eDuplicate;
0044 }