Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:27:42

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2023 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
#include "Acts/Plugins/Onnx/OnnxRuntimeBase.hpp"

#include <cmath>
#include <cstddef>
#include <vector>

#include <onnxruntime_cxx_api.h>
0016 
0017 namespace Acts {
0018 
0019 /// Onnx model implementation for seed scoring and selection
0020 class SeedClassifier {
0021  public:
0022   /// Construct the scoring algorithm.
0023   ///
0024   /// @param modelPath path to the model file
0025   SeedClassifier(const char* modelPath)
0026       : m_env(ORT_LOGGING_LEVEL_WARNING, "MLSeedClassifier"),
0027         m_duplicateClassifier(m_env, modelPath){};
0028 
0029   /// Compute a score for each seed to be used in the seed selection
0030   ///
0031   /// @param networkInput input of the network
0032   /// @return a vector of vector of seed score. Due to the architecture of the network each seed only have a size 1 score vector.
0033   std::vector<std::vector<float>> inferScores(
0034       Acts::NetworkBatchInput& networkInput) const {
0035     // Use the network to compute a score for all the Seeds.
0036     std::vector<std::vector<float>> outputTensor =
0037         m_duplicateClassifier.runONNXInference(networkInput);
0038     return outputTensor;
0039   }
0040 
0041   /// Select the seed associated with each cluster based on the score vector
0042   ///
0043   /// @param clusters is a vector of clusters, each cluster corresponds to a vector of seedIDs
0044   /// @param outputTensor is the score vector obtained from inferScores.
0045   /// @param minSeedScore is the minimum score a seed needs to be selected
0046   /// @return a vector of seedIDs corresponding tho the good seeds
0047   std::vector<std::size_t> seedSelection(
0048       std::vector<std::vector<std::size_t>>& clusters,
0049       std::vector<std::vector<float>>& outputTensor,
0050       float minSeedScore = 0.1) const {
0051     std::vector<std::size_t> goodSeeds;
0052     // Loop over all the cluster and only keep the seed with the highest score
0053     // in each cluster
0054     for (const auto& cluster : clusters) {
0055       std::size_t bestseedID = 0;
0056       float bestSeedScore = 0;
0057       for (const auto& seed : cluster) {
0058         if (outputTensor[seed][0] > bestSeedScore) {
0059           bestSeedScore = outputTensor[seed][0];
0060           bestseedID = seed;
0061         }
0062       }
0063       if (bestSeedScore >= minSeedScore) {
0064         goodSeeds.push_back(bestseedID);
0065       }
0066     }
0067     return goodSeeds;
0068   }
0069 
0070   /// Select the seed associated with each cluster
0071   ///
0072   /// @param clusters is a map of clusters, each cluster correspond to a vector of seed ID
0073   /// @param networkInput input of the network
0074   /// @param minSeedScore is the minimum score a seed need to be selected
0075   /// @return a vector of seedID corresponding the the good seeds
0076   std::vector<std::size_t> solveAmbiguity(
0077       std::vector<std::vector<std::size_t>>& clusters,
0078       Acts::NetworkBatchInput& networkInput, float minSeedScore = 0.1) const {
0079     std::vector<std::vector<float>> outputTensor = inferScores(networkInput);
0080     std::vector<std::size_t> goodSeeds =
0081         seedSelection(clusters, outputTensor, minSeedScore);
0082     return goodSeeds;
0083   }
0084 
0085  private:
0086   // ONNX environment
0087   Ort::Env m_env;
0088   // ONNX model for the duplicate neural network
0089   Acts::OnnxRuntimeBase m_duplicateClassifier;
0090 };
0091 
0092 }  // namespace Acts