Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-06-09 07:54:05

0001 //
0002 // ********************************************************************
0003 // * License and Disclaimer                                           *
0004 // *                                                                  *
0005 // * The  Geant4 software  is  copyright of the Copyright Holders  of *
0006 // * the Geant4 Collaboration.  It is provided  under  the terms  and *
0007 // * conditions of the Geant4 Software License,  included in the file *
0008 // * LICENSE and available at  http://cern.ch/geant4/license .  These *
0009 // * include a list of copyright holders.                             *
0010 // *                                                                  *
0011 // * Neither the authors of this software system, nor their employing *
0012 // * institutes,nor the agencies providing financial support for this *
0013 // * work  make  any representation or  warranty, express or implied, *
0014 // * regarding  this  software system or assume any liability for its *
0015 // * use.  Please see the license in the file  LICENSE  and URL above *
0016 // * for the full disclaimer and the limitation of liability.         *
0017 // *                                                                  *
0018 // * This  code  implementation is the result of  the  scientific and *
0019 // * technical work of the GEANT4 collaboration.                      *
0020 // * By using,  copying,  modifying or  distributing the software (or *
0021 // * any work based  on the software)  you  agree  to acknowledge its *
0022 // * use  in  resulting  scientific  publications,  and indicate your *
0023 // * acceptance of all terms of the Geant4 Software license.          *
0024 // ********************************************************************
0025 //
0026 /// \file Par04TorchInference.cc
0027 /// \brief Implementation of the Par04TorchInference class
0028 
0029 #ifdef USE_INFERENCE_TORCH
0030 #  include "Par04TorchInference.hh"
0031 
0032 #  include "Par04InferenceInterface.hh"  // for Par04InferenceInterface
0033 
0034 #  include <algorithm>  // for copy, max
0035 #  include <cassert>  // for assert
0036 #  include <cstddef>  // for size_t
0037 #  include <cstdint>  // for int64_t
0038 #  include <torch/torch.h>
0039 #  include <utility>  // for move
0040 
0041 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0042 
0043 Par04TorchInference::Par04TorchInference(G4String modelPath) : Par04InferenceInterface()
0044 {
0045   fModule = torch::jit::load(modelPath);
0046 }
0047 
0048 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0049 
0050 void Par04TorchInference::RunInference(std::vector<float> aGenVector,
0051                                        std::vector<G4double>& aEnergies, int aSize)
0052 {
0053   std::vector<torch::jit::IValue> genInput;
0054 
0055   if (aGenVector.size()!=8) {
0056     // VAE
0057     // latentSize : size of the latent space
0058     // 4 is the size of the condition vector
0059     int latentSize = aGenVector.size() - 4;
0060     // split into latent and condition vectors
0061     std::vector<float> latent;
0062     for (int i = 0; i < latentSize; i++) {
0063       latent.push_back(aGenVector[i]);
0064     }
0065     std::vector<float> energy;
0066     energy.push_back(aGenVector[latentSize + 1]);
0067     std::vector<float> angle;
0068     angle.push_back(aGenVector[latentSize + 2]);
0069     std::vector<float> geo;
0070     for (int i = latentSize + 2; i < latentSize + 4; i++) {
0071       geo.push_back(aGenVector[i]);
0072     }
0073 
0074     // convert vectors to tensors
0075     torch::Tensor latentVector = torch::tensor(latent);
0076     torch::Tensor eTensor = torch::tensor(energy);
0077     torch::Tensor angleTensor = torch::tensor(angle);
0078     torch::Tensor geoTensor = torch::tensor(geo);
0079 
0080     genInput.push_back(latentVector);
0081     genInput.push_back(eTensor);
0082     genInput.push_back(angleTensor);
0083     genInput.push_back(geoTensor);
0084   } else {
0085     // CaloDiT-2
0086     torch::Tensor conditions = torch::tensor(aGenVector);
0087     genInput.push_back(conditions);
0088   }
0089   // equivalent to torch.no_grad()
0090   torch::NoGradGuard no_grad;
0091 
0092   at::Tensor outTensor = fModule.forward(genInput).toTensor().contiguous();
0093 
0094   std::vector<G4double> output(outTensor.data_ptr<float>(),
0095                                outTensor.data_ptr<float>() + outTensor.numel());
0096 
0097   aEnergies.assign(aSize, 0);
0098   for (int i = 0; i < aSize; i++) {
0099     aEnergies[i] = output[i];
0100   }
0101 }
0102 
0103 #endif