File indexing completed on 2025-02-23 09:22:35
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 #ifdef USE_INFERENCE_ONNX
#  include "Par04OnnxInference.hh"

#  include "Par04InferenceInterface.hh"  // for Par04InferenceInterface

#  include <algorithm>  // for copy, copy_n, max, min
#  include <cassert>  // for assert
#  include <cstddef>  // for size_t
#  include <cstdint>  // for int64_t
#  include <iostream>  // for cerr
#  include <memory>  // for unique_ptr, make_unique
#  include <utility>  // for move
#  include <vector>  // for vector

#  include <core/session/onnxruntime_cxx_api.h>  // for Value, Session, Env
#  ifdef USE_CUDA
#    include "cuda_runtime_api.h"
#  endif
0041
0042
0043
0044 Par04OnnxInference::Par04OnnxInference(G4String modelPath, G4int profileFlag, G4int optimizeFlag,
0045 G4int intraOpNumThreads, G4int cudaFlag,
0046 std::vector<const char*>& cuda_keys,
0047 std::vector<const char*>& cuda_values,
0048 G4String ModelSavePath, G4String profilingOutputSavePath)
0049
0050 : Par04InferenceInterface()
0051 {
0052
0053 auto envLocal = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ENV");
0054 fEnv = std::move(envLocal);
0055
0056
0057 const auto& ortApi = Ort::GetApi();
0058 fSessionOptions.SetIntraOpNumThreads(intraOpNumThreads);
0059
0060
0061
0062 if (optimizeFlag) {
0063 fSessionOptions.SetOptimizedModelFilePath("opt-graph");
0064 fSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
0065
0066 }
0067 else
0068 fSessionOptions.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
0069 # ifdef USE_CUDA
0070 if (cudaFlag) {
0071 OrtCUDAProviderOptionsV2* fCudaOptions = nullptr;
0072
0073
0074 (void)ortApi.CreateCUDAProviderOptions(&fCudaOptions);
0075
0076 (void)ortApi.UpdateCUDAProviderOptions(fCudaOptions, cuda_keys.data(), cuda_values.data(),
0077 cuda_keys.size());
0078
0079
0080 (void)ortApi.SessionOptionsAppendExecutionProvider_CUDA_V2(fSessionOptions, fCudaOptions);
0081 }
0082 # endif
0083
0084 if (profileFlag) fSessionOptions.EnableProfiling("opt.json");
0085
0086 auto sessionLocal = std::make_unique<Ort::Session>(*fEnv, modelPath, fSessionOptions);
0087 fSession = std::move(sessionLocal);
0088 fInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);
0089 }
0090
0091
0092
0093 void Par04OnnxInference::RunInference(std::vector<float> aGenVector,
0094 std::vector<G4double>& aEnergies, int aSize)
0095 {
0096
0097 Ort::AllocatorWithDefaultOptions allocator;
0098 # if ORT_API_VERSION < 13
0099
0100 auto allocDeleter = [&allocator](char* p) {
0101 allocator.Free(p);
0102 };
0103 using AllocatedStringPtr = std::unique_ptr<char, decltype(allocDeleter)>;
0104 # endif
0105 std::vector<int64_t> input_node_dims;
0106 size_t num_input_nodes = fSession->GetInputCount();
0107 std::vector<const char*> input_node_names(num_input_nodes);
0108 for (std::size_t i = 0; i < num_input_nodes; i++) {
0109 # if ORT_API_VERSION < 13
0110 const auto input_name =
0111 AllocatedStringPtr(fSession->GetInputName(i, allocator), allocDeleter).release();
0112 # else
0113 const auto input_name = fSession->GetInputNameAllocated(i, allocator).release();
0114 # endif
0115 fInames = {input_name};
0116 input_node_names[i] = input_name;
0117 Ort::TypeInfo type_info = fSession->GetInputTypeInfo(i);
0118 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
0119 input_node_dims = tensor_info.GetShape();
0120 for (std::size_t j = 0; j < input_node_dims.size(); j++) {
0121 if (input_node_dims[j] < 0) input_node_dims[j] = 1;
0122 }
0123 }
0124
0125 std::vector<int64_t> output_node_dims;
0126 size_t num_output_nodes = fSession->GetOutputCount();
0127 std::vector<const char*> output_node_names(num_output_nodes);
0128 for (std::size_t i = 0; i < num_output_nodes; i++) {
0129 # if ORT_API_VERSION < 13
0130 const auto output_name =
0131 AllocatedStringPtr(fSession->GetOutputName(i, allocator), allocDeleter).release();
0132 # else
0133 const auto output_name = fSession->GetOutputNameAllocated(i, allocator).release();
0134 # endif
0135 output_node_names[i] = output_name;
0136 Ort::TypeInfo type_info = fSession->GetOutputTypeInfo(i);
0137 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
0138 output_node_dims = tensor_info.GetShape();
0139 for (std::size_t j = 0; j < output_node_dims.size(); j++) {
0140 if (output_node_dims[j] < 0) output_node_dims[j] = 1;
0141 }
0142 }
0143
0144
0145 std::vector<int64_t> dims = {1, (unsigned)(aGenVector.size())};
0146 Ort::Value Input_noise_tensor = Ort::Value::CreateTensor<float>(
0147 fInfo, aGenVector.data(), aGenVector.size(), dims.data(), dims.size());
0148 assert(Input_noise_tensor.IsTensor());
0149 std::vector<Ort::Value> ort_inputs;
0150 ort_inputs.push_back(std::move(Input_noise_tensor));
0151
0152 std::vector<Ort::Value> ort_outputs =
0153 fSession->Run(Ort::RunOptions{nullptr}, fInames.data(), ort_inputs.data(), ort_inputs.size(),
0154 output_node_names.data(), output_node_names.size());
0155
0156 float* floatarr = ort_outputs.front().GetTensorMutableData<float>();
0157 aEnergies.assign(aSize, 0);
0158 for (int i = 0; i < aSize; ++i)
0159 aEnergies[i] = floatarr[i];
0160 }
0161
0162 #endif