File indexing completed on 2025-10-28 08:01:26
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027 #ifdef USE_INFERENCE_TORCH
0028 # include "Par04TorchInference.hh"
0029
0030 # include "Par04InferenceInterface.hh" // for Par04InferenceInterface
0031
0032 # include <algorithm> // for copy, max
0033 # include <cassert> // for assert
0034 # include <cstddef> // for size_t
0035 # include <cstdint> // for int64_t
0036 # include <torch/torch.h>
0037 # include <utility> // for move
0038
0039
0040
0041 Par04TorchInference::Par04TorchInference(G4String modelPath) : Par04InferenceInterface()
0042 {
0043 fModule = torch::jit::load(modelPath);
0044 }
0045
0046
0047
0048 void Par04TorchInference::RunInference(std::vector<float> aGenVector,
0049 std::vector<G4double>& aEnergies, int aSize)
0050 {
0051 std::vector<torch::jit::IValue> genInput;
0052
0053 if (aGenVector.size()!=8) {
0054
0055
0056
0057 int latentSize = aGenVector.size() - 4;
0058
0059 std::vector<float> latent;
0060 for (int i = 0; i < latentSize; i++) {
0061 latent.push_back(aGenVector[i]);
0062 }
0063 std::vector<float> energy;
0064 energy.push_back(aGenVector[latentSize + 1]);
0065 std::vector<float> angle;
0066 angle.push_back(aGenVector[latentSize + 2]);
0067 std::vector<float> geo;
0068 for (int i = latentSize + 2; i < latentSize + 4; i++) {
0069 geo.push_back(aGenVector[i]);
0070 }
0071
0072
0073 torch::Tensor latentVector = torch::tensor(latent);
0074 torch::Tensor eTensor = torch::tensor(energy);
0075 torch::Tensor angleTensor = torch::tensor(angle);
0076 torch::Tensor geoTensor = torch::tensor(geo);
0077
0078 genInput.push_back(latentVector);
0079 genInput.push_back(eTensor);
0080 genInput.push_back(angleTensor);
0081 genInput.push_back(geoTensor);
0082 } else {
0083
0084 torch::Tensor conditions = torch::tensor(aGenVector);
0085 genInput.push_back(conditions);
0086 }
0087
0088 torch::NoGradGuard no_grad;
0089
0090 at::Tensor outTensor = fModule.forward(genInput).toTensor().contiguous();
0091
0092 std::vector<G4double> output(outTensor.data_ptr<float>(),
0093 outTensor.data_ptr<float>() + outTensor.numel());
0094
0095 aEnergies.assign(aSize, 0);
0096 for (int i = 0; i < aSize; i++) {
0097 aEnergies[i] = output[i];
0098 }
0099 }
0100
0101 #endif