//
// ********************************************************************
// * License and Disclaimer                                           *
// *                                                                  *
// * The  Geant4 software  is  copyright of the Copyright Holders  of *
// * the Geant4 Collaboration.  It is provided  under  the terms  and *
// * conditions of the Geant4 Software License,  included in the file *
// * LICENSE and available at  http://cern.ch/geant4/license .  These *
// * include a list of copyright holders.                             *
// *                                                                  *
// * Neither the authors of this software system, nor their employing *
// * institutes,nor the agencies providing financial support for this *
// * work  make  any representation or  warranty, express or implied, *
// * regarding  this  software system or assume any liability for its *
// * use.  Please see the license in the file  LICENSE  and URL above *
// * for the full disclaimer and the limitation of liability.         *
// *                                                                  *
// * This  code  implementation is the result of  the  scientific and *
// * technical work of the GEANT4 collaboration.                      *
// * By using,  copying,  modifying or  distributing the software (or *
// * any work based  on the software)  you  agree  to acknowledge its *
// * use  in  resulting  scientific  publications,  and indicate your *
// * acceptance of all terms of the Geant4 Software license.          *
// ********************************************************************
//
#ifdef USE_INFERENCE_ONNX
#  include "Par04OnnxInference.hh"

#  include "Par04InferenceInterface.hh"  // for Par04InferenceInterface

#  include <algorithm>  // for copy, max
#  include <cassert>  // for assert
#  include <cstddef>  // for size_t
#  include <cstdint>  // for int64_t
#  include <memory>  // for unique_ptr, make_unique
#  include <utility>  // for move

#  include <core/session/onnxruntime_cxx_api.h>  // for Value, Session, Env
#  ifdef USE_CUDA
#    include "cuda_runtime_api.h"
#  endif
//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......

Par04OnnxInference::Par04OnnxInference(G4String modelPath, G4int profileFlag, G4int optimizeFlag,
                                       G4int intraOpNumThreads, G4int cudaFlag,
                                       std::vector<const char*>& cuda_keys,
                                       std::vector<const char*>& cuda_values,
                                       G4String ModelSavePath, G4String profilingOutputSavePath)
  : Par04InferenceInterface()
{
  // initialization of the environment and inference session
  auto envLocal = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ENV");
  fEnv = std::move(envLocal);
  // An OrtApi handle gives access to the C API, which is needed for the CUDA
  // provider configuration below
  [[maybe_unused]] const auto& ortApi = Ort::GetApi();
  fSessionOptions.SetIntraOpNumThreads(intraOpNumThreads);
  // graph optimizations of the model: if the flag is set, all optimizations
  // are applied (ORT_ENABLE_BASIC and ORT_ENABLE_EXTENDED are the available
  // intermediate levels); otherwise none are
  if (optimizeFlag) {
    fSessionOptions.SetOptimizedModelFilePath(ModelSavePath.c_str());
    fSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
  }
  else {
    fSessionOptions.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
  }
#  ifdef USE_CUDA
  if (cudaFlag) {
    OrtCUDAProviderOptionsV2* cudaOptions = nullptr;
    // Initialize the CUDA provider options; cudaOptions then points to a
    // valid CUDA configuration
    (void)ortApi.CreateCUDAProviderOptions(&cudaOptions);
    // Update the CUDA provider options with the user-supplied key/value pairs
    (void)ortApi.UpdateCUDAProviderOptions(cudaOptions, cuda_keys.data(), cuda_values.data(),
                                           cuda_keys.size());
    // Append the CUDA execution provider to the session options, indicating
    // that execution should run on CUDA
    (void)ortApi.SessionOptionsAppendExecutionProvider_CUDA_V2(fSessionOptions, cudaOptions);
    // the session options now hold a copy, so the provider options can go
    ortApi.ReleaseCUDAProviderOptions(cudaOptions);
  }
#  endif
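  // Note (added sketch): the three C-API calls above return an OrtStatus*
  // that is discarded here. A more defensive variant, using the same ortApi
  // handle, could check and release each status, e.g.:
  //
  //   if (OrtStatus* status = ortApi.CreateCUDAProviderOptions(&cudaOptions)) {
  //     G4cout << ortApi.GetErrorMessage(status) << G4endl;
  //     ortApi.ReleaseStatus(status);
  //   }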
  // save a json file for model execution profiling
  if (profileFlag) fSessionOptions.EnableProfiling(profilingOutputSavePath.c_str());

  auto sessionLocal = std::make_unique<Ort::Session>(*fEnv, modelPath.c_str(), fSessionOptions);
  fSession = std::move(sessionLocal);
  fInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);
}
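
// Usage sketch (added, not part of the example): every value below is
// hypothetical and only illustrates the constructor and RunInference
// arguments; in the Par04 example these are normally set via UI commands.
//
//   std::vector<const char*> cudaKeys = {"device_id"};
//   std::vector<const char*> cudaValues = {"0"};
//   Par04OnnxInference inference("model.onnx", /*profileFlag=*/0,
//                                /*optimizeFlag=*/1, /*intraOpNumThreads=*/1,
//                                /*cudaFlag=*/0, cudaKeys, cudaValues,
//                                "opt-graph", "opt.json");
//   std::vector<float> genVector(10, 0.f);  // latent + condition values
//   std::vector<G4double> energies;
//   inference.RunInference(genVector, energies, /*aSize=*/5);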

//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......

void Par04OnnxInference::RunInference(std::vector<float> aGenVector,
                                      std::vector<G4double>& aEnergies, int aSize)
{
  // input nodes
  Ort::AllocatorWithDefaultOptions allocator;
#  if ORT_API_VERSION < 13
  // Before 1.13 we have to roll our own unique_ptr wrapper here
  auto allocDeleter = [&allocator](char* p) {
    allocator.Free(p);
  };
  using AllocatedStringPtr = std::unique_ptr<char, decltype(allocDeleter)>;
#  else
  // from 1.13 onwards the C++ API ships an owning wrapper
  using AllocatedStringPtr = Ort::AllocatedStringPtr;
#  endif
  std::vector<int64_t> input_node_dims;
  size_t num_input_nodes = fSession->GetInputCount();
  std::vector<const char*> input_node_names(num_input_nodes);
  // keep the allocator-owned name strings alive until after the Run() call,
  // so the raw pointers below stay valid and nothing is leaked
  std::vector<AllocatedStringPtr> input_name_owners;
  for (std::size_t i = 0; i < num_input_nodes; i++) {
#  if ORT_API_VERSION < 13
    input_name_owners.emplace_back(fSession->GetInputName(i, allocator), allocDeleter);
#  else
    input_name_owners.push_back(fSession->GetInputNameAllocated(i, allocator));
#  endif
    const char* input_name = input_name_owners.back().get();
    fInames = {input_name};  // the Run() call below assumes a single input node
    input_node_names[i] = input_name;
    Ort::TypeInfo type_info = fSession->GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    input_node_dims = tensor_info.GetShape();
    // dynamic axes are reported as -1; pin them to 1 (batch size of one)
    for (std::size_t j = 0; j < input_node_dims.size(); j++) {
      if (input_node_dims[j] < 0) input_node_dims[j] = 1;
    }
  }
  // output nodes
  std::vector<int64_t> output_node_dims;
  size_t num_output_nodes = fSession->GetOutputCount();
  std::vector<const char*> output_node_names(num_output_nodes);
  std::vector<AllocatedStringPtr> output_name_owners;
  for (std::size_t i = 0; i < num_output_nodes; i++) {
#  if ORT_API_VERSION < 13
    output_name_owners.emplace_back(fSession->GetOutputName(i, allocator), allocDeleter);
#  else
    output_name_owners.push_back(fSession->GetOutputNameAllocated(i, allocator));
#  endif
    output_node_names[i] = output_name_owners.back().get();
    Ort::TypeInfo type_info = fSession->GetOutputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    output_node_dims = tensor_info.GetShape();
    // dynamic axes are reported as -1; pin them to 1 (batch size of one)
    for (std::size_t j = 0; j < output_node_dims.size(); j++) {
      if (output_node_dims[j] < 0) output_node_dims[j] = 1;
    }
  }
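  // Debugging sketch (added; assumes G4cout is reachable through the Geant4
  // headers already included): the discovered node names can be printed to
  // verify the model interface, e.g.:
  //
  //   for (std::size_t i = 0; i < num_input_nodes; ++i)
  //     G4cout << "input " << i << ": " << input_node_names[i] << G4endl;
  //   for (std::size_t i = 0; i < num_output_nodes; ++i)
  //     G4cout << "output " << i << ": " << output_node_names[i] << G4endl;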

  // create the input tensor object from the data values
  std::vector<int64_t> dims = {1, static_cast<int64_t>(aGenVector.size())};
  Ort::Value Input_noise_tensor = Ort::Value::CreateTensor<float>(
    fInfo, aGenVector.data(), aGenVector.size(), dims.data(), dims.size());
  assert(Input_noise_tensor.IsTensor());
  std::vector<Ort::Value> ort_inputs;
  ort_inputs.push_back(std::move(Input_noise_tensor));
  // run the inference session
  std::vector<Ort::Value> ort_outputs =
    fSession->Run(Ort::RunOptions{nullptr}, fInames.data(), ort_inputs.data(), ort_inputs.size(),
                  output_node_names.data(), output_node_names.size());
  // copy the first aSize values of the output tensor into aEnergies
  float* floatarr = ort_outputs.front().GetTensorMutableData<float>();
  aEnergies.assign(floatarr, floatarr + aSize);
}

#endif