![]() |
|
|||
File indexing completed on 2025-02-23 09:22:34
//
// ********************************************************************
// * License and Disclaimer                                           *
// *                                                                  *
// * The Geant4 software is copyright of the Copyright Holders of    *
// * the Geant4 Collaboration. It is provided under the terms and    *
// * conditions of the Geant4 Software License, included in the file *
// * LICENSE and available at http://cern.ch/geant4/license . These  *
// * include a list of copyright holders.                            *
// *                                                                  *
// * Neither the authors of this software system, nor their employing *
// * institutes,nor the agencies providing financial support for this *
// * work make any representation or warranty, express or implied,   *
// * regarding this software system or assume any liability for its  *
// * use. Please see the license in the file LICENSE and URL above   *
// * for the full disclaimer and the limitation of liability.        *
// *                                                                  *
// * This code implementation is the result of the scientific and    *
// * technical work of the GEANT4 collaboration.                     *
// * By using, copying, modifying or distributing the software (or   *
// * any work based on the software) you agree to acknowledge its    *
// * use in resulting scientific publications, and indicate your     *
// * acceptance of all terms of the Geant4 Software license.         *
// ********************************************************************
//

#ifdef USE_INFERENCE_ONNX
#  ifndef PAR04ONNXINFERENCE_HH
#  define PAR04ONNXINFERENCE_HH
#  include "Par04InferenceInterface.hh"  // for Par04InferenceInterface
#  include "core/session/onnxruntime_cxx_api.h"  // for Env, Session, SessionO...

#  include <G4String.hh>  // for G4String
#  include <G4Types.hh>  // for G4int, G4double
#  include <memory>  // for unique_ptr
#  include <vector>  // for vector

#  include <core/session/onnxruntime_c_api.h>  // for OrtMemoryInfo

/**
 * @brief Inference using the ONNX runtime.
 *
 * Creates an environment which manages an internal thread pool and creates an
 * inference session for the model saved as an ONNX file.
 * Runs the inference in the session using the input vector from Par04InferenceSetup.
 *
 **/

class Par04OnnxInference : public Par04InferenceInterface
{
  public:
    /// Sets up the ONNX environment, session options and inference session.
    /// Parameters are unnamed in this declaration; their semantics are fixed
    /// by the implementation file.
    /// NOTE(review): presumably the first G4String is the path to the ONNX
    /// model file and the G4int parameters configure the session
    /// (profiling / optimisation / thread settings) — confirm against the
    /// corresponding .cc file.
    Par04OnnxInference(G4String, G4int, G4int, G4int,
                       G4int,  // For Execution Provider Runtime Flags (for now only CUDA)
                       std::vector<const char*>& cuda_keys, std::vector<const char*>& cuda_values,
                       G4String, G4String);

    /// Default constructor; members are left in their default-initialized state.
    Par04OnnxInference();

    /// Run inference
    /// @param[in] aGenVector Input latent space and conditions
    /// @param[out] aEnergies Model output = generated shower energies
    /// @param[in] aSize Size of the output
    void RunInference(std::vector<float> aGenVector, std::vector<G4double>& aEnergies, int aSize);

  private:
    /// Pointer to the ONNX environment
    std::unique_ptr<Ort::Env> fEnv;
    /// Pointer to the ONNX inference session
    std::unique_ptr<Ort::Session> fSession;
    /// ONNX settings
    Ort::SessionOptions fSessionOptions;
    /// ONNX memory info; non-owning raw pointer (owned by the ONNX runtime)
    /// NOTE(review): lifetime management is not visible in this header —
    /// confirm against the implementation file.
    const OrtMemoryInfo* fInfo;
    /// NOTE(review): forward declaration that is not referenced elsewhere in
    /// this header — looks unused; confirm before removing.
    struct MemoryInfo;
    /// the input names represent the names given to the model
    /// when defining the model's architecture (if applicable)
    /// they can also be retrieved from model.summary()
    std::vector<const char*> fInames;
};

#  endif /* PAR04ONNXINFERENCE_HH */
#endif
[ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |
![]() ![]() |