Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:28:08

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2023 - 2024, Wouter Deconinck, Simon Gardener, Dmitry Kalinkin
0003 
0004 #pragma once
0005 
0006 #include "algorithms/onnx/ONNXInference.h"
0007 #include "services/algorithms_init/AlgorithmsInit_service.h"
0008 #include "extensions/jana/JOmniFactory.h"
0009 
0010 namespace eicrecon {
0011 
0012 class ONNXInference_factory : public JOmniFactory<ONNXInference_factory, ONNXInferenceConfig> {
0013 
0014 public:
0015   using AlgoT = eicrecon::ONNXInference;
0016 
0017 private:
0018   std::unique_ptr<AlgoT> m_algo;
0019 
0020   VariadicPodioInput<edm4eic::Tensor> m_input_tensors{this};
0021 
0022   VariadicPodioOutput<edm4eic::Tensor> m_output_tensors{this};
0023 
0024   ParameterRef<std::string> m_modelPath{this, "modelPath", config().modelPath};
0025 
0026   Service<AlgorithmsInit_service> m_algorithmsInit{this};
0027 
0028 public:
0029   void Configure() {
0030     m_algo = std::make_unique<AlgoT>(GetPrefix());
0031     m_algo->level(static_cast<algorithms::LogLevel>(logger()->level()));
0032     m_algo->applyConfig(config());
0033     m_algo->init();
0034   }
0035 
0036   void Process(int32_t /* run_number */, uint64_t /* event_number */) {
0037     std::vector<gsl::not_null<const edm4eic::TensorCollection*>> in_collections;
0038     for (const auto& in_collection : m_input_tensors()) {
0039       in_collections.push_back(gsl::not_null<const edm4eic::TensorCollection*>{in_collection});
0040     }
0041 
0042     std::vector<gsl::not_null<edm4eic::TensorCollection*>> out_collections;
0043     for (const auto& out_collection : m_output_tensors()) {
0044       out_collections.push_back(gsl::not_null<edm4eic::TensorCollection*>{out_collection.get()});
0045     }
0046 
0047     m_algo->process(in_collections, out_collections);
0048   }
0049 };
0050 
0051 } // namespace eicrecon