Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-06-30 07:55:42

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2022 - 2024 Sylvester Joosten, Dmitry Romanov, Wouter Deconinck, Dmitry Kalinkin
0003 
0004 #pragma once
0005 
0006 #include <algorithms/algorithm.h>
0007 #include <cstdint>
0008 #include <onnxruntime_cxx_api.h>
0009 #include <string>
0010 #include <string_view>
0011 #include <vector>
0012 #include <edm4eic/TensorCollection.h>
0013 
0014 #include "algorithms/interfaces/WithPodConfig.h"
0015 #include "algorithms/onnx/ONNXInferenceConfig.h"
0016 
0017 namespace eicrecon {
0018 
0019 using ONNXInferenceAlgorithm =
0020     algorithms::Algorithm<algorithms::Input<std::vector<edm4eic::TensorCollection>>,
0021                           algorithms::Output<std::vector<edm4eic::TensorCollection>>>;
0022 
0023 class ONNXInference : public ONNXInferenceAlgorithm, public WithPodConfig<ONNXInferenceConfig> {
0024 
0025 public:
0026   ONNXInference(std::string_view name)
0027       : ONNXInferenceAlgorithm{name, {"inputTensors"}, {"outputTensors"}, ""} {}
0028 
0029   void init() final;
0030   void process(const Input&, const Output&) const final;
0031 
0032 private:
0033   mutable Ort::Env m_env{nullptr};
0034   mutable Ort::Session m_session{nullptr};
0035 
0036   std::vector<std::string> m_input_names;
0037   std::vector<const char*> m_input_names_char;
0038   std::vector<std::vector<std::int64_t>> m_input_shapes;
0039 
0040   std::vector<std::string> m_output_names;
0041   std::vector<const char*> m_output_names_char;
0042   std::vector<std::vector<std::int64_t>> m_output_shapes;
0043 };
0044 
0045 } // namespace eicrecon