File indexing completed on 2025-12-16 09:28:08
0001
0002
0003
0004 #pragma once
0005
0006 #include "algorithms/onnx/ONNXInference.h"
0007 #include "services/algorithms_init/AlgorithmsInit_service.h"
0008 #include "extensions/jana/JOmniFactory.h"
0009
0010 namespace eicrecon {
0011
0012 class ONNXInference_factory : public JOmniFactory<ONNXInference_factory, ONNXInferenceConfig> {
0013
0014 public:
0015 using AlgoT = eicrecon::ONNXInference;
0016
0017 private:
0018 std::unique_ptr<AlgoT> m_algo;
0019
0020 VariadicPodioInput<edm4eic::Tensor> m_input_tensors{this};
0021
0022 VariadicPodioOutput<edm4eic::Tensor> m_output_tensors{this};
0023
0024 ParameterRef<std::string> m_modelPath{this, "modelPath", config().modelPath};
0025
0026 Service<AlgorithmsInit_service> m_algorithmsInit{this};
0027
0028 public:
0029 void Configure() {
0030 m_algo = std::make_unique<AlgoT>(GetPrefix());
0031 m_algo->level(static_cast<algorithms::LogLevel>(logger()->level()));
0032 m_algo->applyConfig(config());
0033 m_algo->init();
0034 }
0035
0036 void Process(int32_t , uint64_t ) {
0037 std::vector<gsl::not_null<const edm4eic::TensorCollection*>> in_collections;
0038 for (const auto& in_collection : m_input_tensors()) {
0039 in_collections.push_back(gsl::not_null<const edm4eic::TensorCollection*>{in_collection});
0040 }
0041
0042 std::vector<gsl::not_null<edm4eic::TensorCollection*>> out_collections;
0043 for (const auto& out_collection : m_output_tensors()) {
0044 out_collections.push_back(gsl::not_null<edm4eic::TensorCollection*>{out_collection.get()});
0045 }
0046
0047 m_algo->process(in_collections, out_collections);
0048 }
0049 };
0050
0051 }