File indexing completed on 2026-06-09 07:54:05
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029 #ifdef USE_INFERENCE_TORCH
0030 # include "Par04TorchInference.hh"
0031
0032 # include "Par04InferenceInterface.hh" // for Par04InferenceInterface
0033
0034 # include <algorithm> // for copy, max
0035 # include <cassert> // for assert
0036 # include <cstddef> // for size_t
0037 # include <cstdint> // for int64_t
0038 # include <torch/torch.h>
0039 # include <utility> // for move
0040
0041
0042
0043 Par04TorchInference::Par04TorchInference(G4String modelPath) : Par04InferenceInterface()
0044 {
0045 fModule = torch::jit::load(modelPath);
0046 }
0047
0048
0049
0050 void Par04TorchInference::RunInference(std::vector<float> aGenVector,
0051 std::vector<G4double>& aEnergies, int aSize)
0052 {
0053 std::vector<torch::jit::IValue> genInput;
0054
0055 if (aGenVector.size()!=8) {
0056
0057
0058
0059 int latentSize = aGenVector.size() - 4;
0060
0061 std::vector<float> latent;
0062 for (int i = 0; i < latentSize; i++) {
0063 latent.push_back(aGenVector[i]);
0064 }
0065 std::vector<float> energy;
0066 energy.push_back(aGenVector[latentSize + 1]);
0067 std::vector<float> angle;
0068 angle.push_back(aGenVector[latentSize + 2]);
0069 std::vector<float> geo;
0070 for (int i = latentSize + 2; i < latentSize + 4; i++) {
0071 geo.push_back(aGenVector[i]);
0072 }
0073
0074
0075 torch::Tensor latentVector = torch::tensor(latent);
0076 torch::Tensor eTensor = torch::tensor(energy);
0077 torch::Tensor angleTensor = torch::tensor(angle);
0078 torch::Tensor geoTensor = torch::tensor(geo);
0079
0080 genInput.push_back(latentVector);
0081 genInput.push_back(eTensor);
0082 genInput.push_back(angleTensor);
0083 genInput.push_back(geoTensor);
0084 } else {
0085
0086 torch::Tensor conditions = torch::tensor(aGenVector);
0087 genInput.push_back(conditions);
0088 }
0089
0090 torch::NoGradGuard no_grad;
0091
0092 at::Tensor outTensor = fModule.forward(genInput).toTensor().contiguous();
0093
0094 std::vector<G4double> output(outTensor.data_ptr<float>(),
0095 outTensor.data_ptr<float>() + outTensor.numel());
0096
0097 aEnergies.assign(aSize, 0);
0098 for (int i = 0; i < aSize; i++) {
0099 aEnergies[i] = output[i];
0100 }
0101 }
0102
0103 #endif