File indexing completed on 2026-03-29 07:52:01
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029 #ifdef USE_INFERENCE_ONNX
0030 # include "Par04OnnxInference.hh"
0031
0032 # include "Par04InferenceInterface.hh" // for Par04InferenceInterface
0033
0034 # include <onnxruntime_cxx_api.h> // for Value, Session, Env
0035 # include <algorithm> // for copy, max
0036 # include <cassert> // for assert
0037 # include <cstddef> // for size_t
0038 # include <cstdint> // for int64_t
0039 # include <utility> // for move
0040
0041 # ifdef USE_CUDA
0042 # include "cuda_runtime_api.h"
0043 # endif
0044
0045
0046
0047 Par04OnnxInference::Par04OnnxInference(G4String modelPath, G4int profileFlag, G4int optimizeFlag,
0048 G4int intraOpNumThreads, G4int cudaFlag,
0049 std::vector<const char*>& cuda_keys,
0050 std::vector<const char*>& cuda_values,
0051 G4String ModelSavePath, G4String profilingOutputSavePath)
0052
0053 : Par04InferenceInterface()
0054 {
0055
0056 auto envLocal = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ENV");
0057 fEnv = std::move(envLocal);
0058
0059
0060 const auto& ortApi = Ort::GetApi();
0061 fSessionOptions.SetIntraOpNumThreads(intraOpNumThreads);
0062
0063
0064
0065 if (optimizeFlag) {
0066 fSessionOptions.SetOptimizedModelFilePath("opt-graph");
0067 fSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
0068
0069 }
0070 else
0071 fSessionOptions.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
0072 # ifdef USE_CUDA
0073 if (cudaFlag) {
0074 OrtCUDAProviderOptionsV2* fCudaOptions = nullptr;
0075
0076
0077 (void)ortApi.CreateCUDAProviderOptions(&fCudaOptions);
0078
0079 (void)ortApi.UpdateCUDAProviderOptions(fCudaOptions, cuda_keys.data(), cuda_values.data(),
0080 cuda_keys.size());
0081
0082
0083 (void)ortApi.SessionOptionsAppendExecutionProvider_CUDA_V2(fSessionOptions, fCudaOptions);
0084 }
0085 # endif
0086
0087 if (profileFlag) fSessionOptions.EnableProfiling("opt.json");
0088
0089 auto sessionLocal = std::make_unique<Ort::Session>(*fEnv, modelPath, fSessionOptions);
0090 fSession = std::move(sessionLocal);
0091 fInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);
0092 }
0093
0094
0095
// Runs a single inference pass: wraps aGenVector as a (1 x N) float tensor,
// executes the session, and copies the first aSize output values (converted
// to G4double) into aEnergies.
void Par04OnnxInference::RunInference(std::vector<float> aGenVector,
                                      std::vector<G4double>& aEnergies, int aSize)
{
  // Default allocator used by ORT to hand out input/output name strings.
  Ort::AllocatorWithDefaultOptions allocator;
# if ORT_API_VERSION < 13
  // Pre-1.13 API: GetInputName/GetOutputName return raw char* that must be
  // freed through the same allocator; emulate the 1.13+ AllocatedStringPtr.
  auto allocDeleter = [&allocator](char* p) {
    allocator.Free(p);
  };
  using AllocatedStringPtr = std::unique_ptr<char, decltype(allocDeleter)>;
# endif
  // --- Gather model input names (and clamp dynamic/-1 dims to 1). ---
  std::vector<int64_t> input_node_dims;
  size_t num_input_nodes = fSession->GetInputCount();
  std::vector<const char*> input_node_names(num_input_nodes);
  for (std::size_t i = 0; i < num_input_nodes; i++) {
    // NOTE(review): .release() deliberately leaks the name string each call —
    // the raw pointer is stored in the fInames member and must stay valid
    // after this function returns. Freeing it here would dangle fInames;
    // confirm no other reader of fInames before changing this.
# if ORT_API_VERSION < 13
    const auto input_name =
      AllocatedStringPtr(fSession->GetInputName(i, allocator), allocDeleter).release();
# else
    const auto input_name = fSession->GetInputNameAllocated(i, allocator).release();
# endif
    // NOTE(review): overwritten every iteration — only the LAST input name
    // survives in fInames. Correct for single-input models; verify intent if
    // a multi-input model is ever used (Run below feeds fInames, not
    // input_node_names).
    fInames = {input_name};
    input_node_names[i] = input_name;
    Ort::TypeInfo type_info = fSession->GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    input_node_dims = tensor_info.GetShape();
    // Dynamic axes are reported as negative; replace with 1 (batch of one).
    for (std::size_t j = 0; j < input_node_dims.size(); j++) {
      if (input_node_dims[j] < 0) input_node_dims[j] = 1;
    }
  }

  // --- Gather model output names the same way (same deliberate leak). ---
  std::vector<int64_t> output_node_dims;
  size_t num_output_nodes = fSession->GetOutputCount();
  std::vector<const char*> output_node_names(num_output_nodes);
  for (std::size_t i = 0; i < num_output_nodes; i++) {
# if ORT_API_VERSION < 13
    const auto output_name =
      AllocatedStringPtr(fSession->GetOutputName(i, allocator), allocDeleter).release();
# else
    const auto output_name = fSession->GetOutputNameAllocated(i, allocator).release();
# endif
    output_node_names[i] = output_name;
    Ort::TypeInfo type_info = fSession->GetOutputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    output_node_dims = tensor_info.GetShape();
    for (std::size_t j = 0; j < output_node_dims.size(); j++) {
      if (output_node_dims[j] < 0) output_node_dims[j] = 1;
    }
  }

  // Wrap the generator vector as a (1 x size) tensor; CreateTensor does not
  // copy, so aGenVector must stay alive until Run() completes (it does — it
  // is a by-value parameter of this function).
  // NOTE(review): the (unsigned) cast narrows on 64-bit size_t before
  // widening to int64_t; harmless for realistic sizes but worth cleaning up.
  std::vector<int64_t> dims = {1, (unsigned)(aGenVector.size())};
  Ort::Value Input_noise_tensor = Ort::Value::CreateTensor<float>(
    fInfo, aGenVector.data(), aGenVector.size(), dims.data(), dims.size());
  assert(Input_noise_tensor.IsTensor());
  std::vector<Ort::Value> ort_inputs;
  ort_inputs.push_back(std::move(Input_noise_tensor));

  // Execute the model: single input (fInames), all declared outputs.
  std::vector<Ort::Value> ort_outputs =
    fSession->Run(Ort::RunOptions{nullptr}, fInames.data(), ort_inputs.data(), ort_inputs.size(),
                  output_node_names.data(), output_node_names.size());

  // Copy the first aSize floats of the first output into the caller's buffer.
  // NOTE(review): assumes the first output tensor holds at least aSize
  // elements — not checked against output_node_dims; confirm with the model.
  float* floatarr = ort_outputs.front().GetTensorMutableData<float>();
  aEnergies.assign(aSize, 0);
  for (int i = 0; i < aSize; ++i)
    aEnergies[i] = floatarr[i];
}
0164
0165 #endif