Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-06-26 07:05:40

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2022, 2023 Sylvester Joosten, Dmitry Romanov, Wouter Deconinck
0003 
0004 #pragma once
0005 
0006 #include <algorithms/algorithm.h>
0007 #include <cstdint>
0008 #include <edm4eic/InclusiveKinematicsCollection.h>
0009 #include <onnxruntime_cxx_api.h>
0010 #include <string>
0011 #include <string_view>
0012 #include <vector>
0013 
0014 #include "algorithms/interfaces/WithPodConfig.h"
0015 #include "algorithms/onnx/InclusiveKinematicsMLConfig.h"
0016 
0017 namespace eicrecon {
0018 
0019 using InclusiveKinematicsMLAlgorithm =
0020     algorithms::Algorithm<algorithms::Input<edm4eic::InclusiveKinematicsCollection,
0021                                             edm4eic::InclusiveKinematicsCollection>,
0022                           algorithms::Output<edm4eic::InclusiveKinematicsCollection>>;
0023 
0024 class InclusiveKinematicsML : public InclusiveKinematicsMLAlgorithm,
0025                               public WithPodConfig<InclusiveKinematicsMLConfig> {
0026 
0027 public:
0028   InclusiveKinematicsML(std::string_view name)
0029       : InclusiveKinematicsMLAlgorithm{name,
0030                                        {"inclusiveKinematicsElectron", "inclusiveKinematicsDA"},
0031                                        {"inclusiveKinematicsML"},
0032                                        "Determine inclusive kinematics using combined ML method."} {
0033   }
0034 
0035   void init() final;
0036   void process(const Input&, const Output&) const final;
0037 
0038 private:
0039   mutable Ort::Session m_session{nullptr};
0040 
0041   std::vector<std::string> m_input_names;
0042   std::vector<const char*> m_input_names_char;
0043   std::vector<std::vector<std::int64_t>> m_input_shapes;
0044 
0045   std::vector<std::string> m_output_names;
0046   std::vector<const char*> m_output_names_char;
0047   std::vector<std::vector<std::int64_t>> m_output_shapes;
0048 };
0049 
0050 } // namespace eicrecon