File indexing completed on 2025-01-18 09:27:42
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include "Acts/Plugins/Onnx/OnnxRuntimeBase.hpp"

#include <cstddef>
#include <limits>
#include <vector>

#include <onnxruntime_cxx_api.h>
0016
0017 namespace Acts {
0018
0019
0020 class SeedClassifier {
0021 public:
0022
0023
0024
0025 SeedClassifier(const char* modelPath)
0026 : m_env(ORT_LOGGING_LEVEL_WARNING, "MLSeedClassifier"),
0027 m_duplicateClassifier(m_env, modelPath){};
0028
0029
0030
0031
0032
0033 std::vector<std::vector<float>> inferScores(
0034 Acts::NetworkBatchInput& networkInput) const {
0035
0036 std::vector<std::vector<float>> outputTensor =
0037 m_duplicateClassifier.runONNXInference(networkInput);
0038 return outputTensor;
0039 }
0040
0041
0042
0043
0044
0045
0046
0047 std::vector<std::size_t> seedSelection(
0048 std::vector<std::vector<std::size_t>>& clusters,
0049 std::vector<std::vector<float>>& outputTensor,
0050 float minSeedScore = 0.1) const {
0051 std::vector<std::size_t> goodSeeds;
0052
0053
0054 for (const auto& cluster : clusters) {
0055 std::size_t bestseedID = 0;
0056 float bestSeedScore = 0;
0057 for (const auto& seed : cluster) {
0058 if (outputTensor[seed][0] > bestSeedScore) {
0059 bestSeedScore = outputTensor[seed][0];
0060 bestseedID = seed;
0061 }
0062 }
0063 if (bestSeedScore >= minSeedScore) {
0064 goodSeeds.push_back(bestseedID);
0065 }
0066 }
0067 return goodSeeds;
0068 }
0069
0070
0071
0072
0073
0074
0075
0076 std::vector<std::size_t> solveAmbiguity(
0077 std::vector<std::vector<std::size_t>>& clusters,
0078 Acts::NetworkBatchInput& networkInput, float minSeedScore = 0.1) const {
0079 std::vector<std::vector<float>> outputTensor = inferScores(networkInput);
0080 std::vector<std::size_t> goodSeeds =
0081 seedSelection(clusters, outputTensor, minSeedScore);
0082 return goodSeeds;
0083 }
0084
0085 private:
0086
0087 Ort::Env m_env;
0088
0089 Acts::OnnxRuntimeBase m_duplicateClassifier;
0090 };
0091
0092 }