Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-23 09:22:34

0001 //
0002 // ********************************************************************
0003 // * License and Disclaimer                                           *
0004 // *                                                                  *
0005 // * The  Geant4 software  is  copyright of the Copyright Holders  of *
0006 // * the Geant4 Collaboration.  It is provided  under  the terms  and *
0007 // * conditions of the Geant4 Software License,  included in the file *
0008 // * LICENSE and available at  http://cern.ch/geant4/license .  These *
0009 // * include a list of copyright holders.                             *
0010 // *                                                                  *
0011 // * Neither the authors of this software system, nor their employing *
0012 // * institutes,nor the agencies providing financial support for this *
0013 // * work  make  any representation or  warranty, express or implied, *
0014 // * regarding  this  software system or assume any liability for its *
0015 // * use.  Please see the license in the file  LICENSE  and URL above *
0016 // * for the full disclaimer and the limitation of liability.         *
0017 // *                                                                  *
0018 // * This  code  implementation is the result of  the  scientific and *
0019 // * technical work of the GEANT4 collaboration.                      *
0020 // * By using,  copying,  modifying or  distributing the software (or *
0021 // * any work based  on the software)  you  agree  to acknowledge its *
0022 // * use  in  resulting  scientific  publications,  and indicate your *
0023 // * acceptance of all terms of the Geant4 Software license.          *
0024 // ********************************************************************
0025 //
0026 
0027 #ifdef USE_INFERENCE_ONNX
0028 #  ifndef PAR04ONNXINFERENCE_HH
0029 #    define PAR04ONNXINFERENCE_HH
0030 #    include "Par04InferenceInterface.hh"  // for Par04InferenceInterface
0031 #    include "core/session/onnxruntime_cxx_api.h"  // for Env, Session, SessionO...
0032 
0033 #    include <G4String.hh>  // for G4String
0034 #    include <G4Types.hh>  // for G4int, G4double
0035 #    include <memory>  // for unique_ptr
0036 #    include <vector>  // for vector
0037 
0038 #    include <core/session/onnxruntime_c_api.h>  // for OrtMemoryInfo
0039 
0040 /**
0041  * @brief Inference using the ONNX runtime.
0042  *
0043  * Creates an enviroment whcih manages an internal thread pool and creates an
0044  * inference session for the model saved as an ONNX file.
0045  * Runs the inference in the session using the input vector from Par04InferenceSetup.
0046  *
0047  **/
0048 
0049 class Par04OnnxInference : public Par04InferenceInterface
0050 {
0051   public:
0052     Par04OnnxInference(G4String, G4int, G4int, G4int,
0053                        G4int,  // For Execution Provider Runtime Flags (for now only CUDA)
0054                        std::vector<const char*>& cuda_keys, std::vector<const char*>& cuda_values,
0055                        G4String, G4String);
0056 
0057     Par04OnnxInference();
0058 
0059     /// Run inference
0060     /// @param[in] aGenVector Input latent space and conditions
0061     /// @param[out] aEnergies Model output = generated shower energies
0062     /// @param[in] aSize Size of the output
0063     void RunInference(std::vector<float> aGenVector, std::vector<G4double>& aEnergies, int aSize);
0064 
0065   private:
0066     /// Pointer to the ONNX enviroment
0067     std::unique_ptr<Ort::Env> fEnv;
0068     /// Pointer to the ONNX inference session
0069     std::unique_ptr<Ort::Session> fSession;
0070     /// ONNX settings
0071     Ort::SessionOptions fSessionOptions;
0072     /// ONNX memory info
0073     const OrtMemoryInfo* fInfo;
0074     struct MemoryInfo;
0075     /// the input names represent the names given to the model
0076     /// when defining  the model's architecture (if applicable)
0077     /// they can also be retrieved from model.summary()
0078     std::vector<const char*> fInames;
0079 };
0080 
0081 #  endif /* PAR04ONNXINFERENCE_HH */
0082 #endif