Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-03-29 07:52:01

0001 //
0002 // ********************************************************************
0003 // * License and Disclaimer                                           *
0004 // *                                                                  *
0005 // * The  Geant4 software  is  copyright of the Copyright Holders  of *
0006 // * the Geant4 Collaboration.  It is provided  under  the terms  and *
0007 // * conditions of the Geant4 Software License,  included in the file *
0008 // * LICENSE and available at  http://cern.ch/geant4/license .  These *
0009 // * include a list of copyright holders.                             *
0010 // *                                                                  *
0011 // * Neither the authors of this software system, nor their employing *
0012 // * institutes,nor the agencies providing financial support for this *
0013 // * work  make  any representation or  warranty, express or implied, *
0014 // * regarding  this  software system or assume any liability for its *
0015 // * use.  Please see the license in the file  LICENSE  and URL above *
0016 // * for the full disclaimer and the limitation of liability.         *
0017 // *                                                                  *
0018 // * This  code  implementation is the result of  the  scientific and *
0019 // * technical work of the GEANT4 collaboration.                      *
0020 // * By using,  copying,  modifying or  distributing the software (or *
0021 // * any work based  on the software)  you  agree  to acknowledge its *
0022 // * use  in  resulting  scientific  publications,  and indicate your *
0023 // * acceptance of all terms of the Geant4 Software license.          *
0024 // ********************************************************************
0025 //
0026 /// \file Par04OnnxInference.cc
0027 /// \brief Implementation of the Par04OnnxInference class
0028 
0029 #ifdef USE_INFERENCE_ONNX
#  include "Par04OnnxInference.hh"

#  include "Par04InferenceInterface.hh"  // for Par04InferenceInterface

#  include <onnxruntime_cxx_api.h>  // for Value, Session, Env
#  include <algorithm>  // for copy, copy_n, max, min
#  include <cassert>  // for assert
#  include <cstddef>  // for size_t
#  include <cstdint>  // for int64_t
#  include <iostream>  // for cerr
#  include <memory>  // for unique_ptr, make_unique
#  include <utility>  // for move
#  include <vector>  // for vector
0040 
0041 #  ifdef USE_CUDA
0042 #    include "cuda_runtime_api.h"
0043 #  endif
0044 
0045 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0046 
0047 Par04OnnxInference::Par04OnnxInference(G4String modelPath, G4int profileFlag, G4int optimizeFlag,
0048                                        G4int intraOpNumThreads, G4int cudaFlag,
0049                                        std::vector<const char*>& cuda_keys,
0050                                        std::vector<const char*>& cuda_values,
0051                                        G4String ModelSavePath, G4String profilingOutputSavePath)
0052 
0053   : Par04InferenceInterface()
0054 {
0055   // initialization of the enviroment and inference session
0056   auto envLocal = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ENV");
0057   fEnv = std::move(envLocal);
0058   // Creating a OrtApi Class variable for getting access to C api, necessary for
0059   // CUDA
0060   const auto& ortApi = Ort::GetApi();
0061   fSessionOptions.SetIntraOpNumThreads(intraOpNumThreads);
0062   // graph optimizations of the model
0063   // if the flag is not set to true none of the optimizations will be applied
0064   // if it is set to true all the optimizations will be applied
0065   if (optimizeFlag) {
0066     fSessionOptions.SetOptimizedModelFilePath("opt-graph");
0067     fSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
0068     // ORT_ENABLE_BASIC #### ORT_ENABLE_EXTENDED
0069   }
0070   else
0071     fSessionOptions.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
0072 #  ifdef USE_CUDA
0073   if (cudaFlag) {
0074     OrtCUDAProviderOptionsV2* fCudaOptions = nullptr;
0075     // Initialize the CUDA provider options, fCudaOptions should now point to a
0076     // valid CUDA configuration.
0077     (void)ortApi.CreateCUDAProviderOptions(&fCudaOptions);
0078     // Update the CUDA provider options
0079     (void)ortApi.UpdateCUDAProviderOptions(fCudaOptions, cuda_keys.data(), cuda_values.data(),
0080                                            cuda_keys.size());
0081     // Append the CUDA execution provider to the session options, indicating to
0082     // use CUDA for execution
0083     (void)ortApi.SessionOptionsAppendExecutionProvider_CUDA_V2(fSessionOptions, fCudaOptions);
0084   }
0085 #  endif
0086   // save json file for model execution profiling
0087   if (profileFlag) fSessionOptions.EnableProfiling("opt.json");
0088 
0089   auto sessionLocal = std::make_unique<Ort::Session>(*fEnv, modelPath, fSessionOptions);
0090   fSession = std::move(sessionLocal);
0091   fInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);
0092 }
0093 
0094 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0095 
0096 void Par04OnnxInference::RunInference(std::vector<float> aGenVector,
0097                                       std::vector<G4double>& aEnergies, int aSize)
0098 {
0099   // input nodes
0100   Ort::AllocatorWithDefaultOptions allocator;
0101 #  if ORT_API_VERSION < 13
0102   // Before 1.13 we have to roll our own unique_ptr wrapper here
0103   auto allocDeleter = [&allocator](char* p) {
0104     allocator.Free(p);
0105   };
0106   using AllocatedStringPtr = std::unique_ptr<char, decltype(allocDeleter)>;
0107 #  endif
0108   std::vector<int64_t> input_node_dims;
0109   size_t num_input_nodes = fSession->GetInputCount();
0110   std::vector<const char*> input_node_names(num_input_nodes);
0111   for (std::size_t i = 0; i < num_input_nodes; i++) {
0112 #  if ORT_API_VERSION < 13
0113     const auto input_name =
0114       AllocatedStringPtr(fSession->GetInputName(i, allocator), allocDeleter).release();
0115 #  else
0116     const auto input_name = fSession->GetInputNameAllocated(i, allocator).release();
0117 #  endif
0118     fInames = {input_name};
0119     input_node_names[i] = input_name;
0120     Ort::TypeInfo type_info = fSession->GetInputTypeInfo(i);
0121     auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
0122     input_node_dims = tensor_info.GetShape();
0123     for (std::size_t j = 0; j < input_node_dims.size(); j++) {
0124       if (input_node_dims[j] < 0) input_node_dims[j] = 1;
0125     }
0126   }
0127   // output nodes
0128   std::vector<int64_t> output_node_dims;
0129   size_t num_output_nodes = fSession->GetOutputCount();
0130   std::vector<const char*> output_node_names(num_output_nodes);
0131   for (std::size_t i = 0; i < num_output_nodes; i++) {
0132 #  if ORT_API_VERSION < 13
0133     const auto output_name =
0134       AllocatedStringPtr(fSession->GetOutputName(i, allocator), allocDeleter).release();
0135 #  else
0136     const auto output_name = fSession->GetOutputNameAllocated(i, allocator).release();
0137 #  endif
0138     output_node_names[i] = output_name;
0139     Ort::TypeInfo type_info = fSession->GetOutputTypeInfo(i);
0140     auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
0141     output_node_dims = tensor_info.GetShape();
0142     for (std::size_t j = 0; j < output_node_dims.size(); j++) {
0143       if (output_node_dims[j] < 0) output_node_dims[j] = 1;
0144     }
0145   }
0146 
0147   // create input tensor object from data values
0148   std::vector<int64_t> dims = {1, (unsigned)(aGenVector.size())};
0149   Ort::Value Input_noise_tensor = Ort::Value::CreateTensor<float>(
0150     fInfo, aGenVector.data(), aGenVector.size(), dims.data(), dims.size());
0151   assert(Input_noise_tensor.IsTensor());
0152   std::vector<Ort::Value> ort_inputs;
0153   ort_inputs.push_back(std::move(Input_noise_tensor));
0154   // run the inference session
0155   std::vector<Ort::Value> ort_outputs =
0156     fSession->Run(Ort::RunOptions{nullptr}, fInames.data(), ort_inputs.data(), ort_inputs.size(),
0157                   output_node_names.data(), output_node_names.size());
0158   // get pointer to output tensor float values
0159   float* floatarr = ort_outputs.front().GetTensorMutableData<float>();
0160   aEnergies.assign(aSize, 0);
0161   for (int i = 0; i < aSize; ++i)
0162     aEnergies[i] = floatarr[i];
0163 }
0164 
0165 #endif