Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-10-13 08:28:37

0001 //
0002 // ********************************************************************
0003 // * License and Disclaimer                                           *
0004 // *                                                                  *
0005 // * The  Geant4 software  is  copyright of the Copyright Holders  of *
0006 // * the Geant4 Collaboration.  It is provided  under  the terms  and *
0007 // * conditions of the Geant4 Software License,  included in the file *
0008 // * LICENSE and available at  http://cern.ch/geant4/license .  These *
0009 // * include a list of copyright holders.                             *
0010 // *                                                                  *
0011 // * Neither the authors of this software system, nor their employing *
0012 // * institutes,nor the agencies providing financial support for this *
0013 // * work  make  any representation or  warranty, express or implied, *
0014 // * regarding  this  software system or assume any liability for its *
0015 // * use.  Please see the license in the file  LICENSE  and URL above *
0016 // * for the full disclaimer and the limitation of liability.         *
0017 // *                                                                  *
0018 // * This  code  implementation is the result of  the  scientific and *
0019 // * technical work of the GEANT4 collaboration.                      *
0020 // * By using,  copying,  modifying or  distributing the software (or *
0021 // * any work based  on the software)  you  agree  to acknowledge its *
0022 // * use  in  resulting  scientific  publications,  and indicate your *
0023 // * acceptance of all terms of the Geant4 Software license.          *
0024 // ********************************************************************
0025 //
0026 #ifdef USE_INFERENCE
#  include "Par04InferenceSetup.hh"

#  include "Par04InferenceInterface.hh"  // for Par04InferenceInterface
#  include "Par04InferenceMessenger.hh"  // for Par04InferenceMessenger
#  ifdef USE_INFERENCE_ONNX
#    include "Par04OnnxInference.hh"  // for Par04OnnxInference
#  endif
#  ifdef USE_INFERENCE_LWTNN
#    include "Par04LwtnnInference.hh"  // for Par04LwtnnInference
#  endif
#  ifdef USE_INFERENCE_TORCH
#    include "Par04TorchInference.hh"  // for Par04TorchInference
#  endif
#  include "CLHEP/Random/RandGauss.h"  // for RandGauss

#  include "G4RotationMatrix.hh"  // for G4RotationMatrix

#  include <CLHEP/Units/SystemOfUnits.h>  // for pi, GeV, deg
#  include <CLHEP/Vector/Rotation.h>  // for HepRotation
#  include <CLHEP/Vector/ThreeVector.h>  // for Hep3Vector
#  include <G4Exception.hh>  // for G4Exception
#  include <G4ExceptionSeverity.hh>  // for FatalException
#  include <G4ThreeVector.hh>  // for G4ThreeVector
#  include <algorithm>  // for max, copy
#  include <cmath>  // for cos, sin
#  include <memory>  // for unique_ptr, make_unique
#  include <string>  // for char_traits, basic_string
#  include <vector>  // for vector

#  include <ext/alloc_traits.h>  // for __alloc_traits<>::value_type
0055 
0056 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0057 
0058 Par04InferenceSetup::Par04InferenceSetup() : fInferenceMessenger(new Par04InferenceMessenger(this))
0059 {}
0060 
0061 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0062 
0063 Par04InferenceSetup::~Par04InferenceSetup() {}
0064 
0065 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0066 
0067 G4bool Par04InferenceSetup::IfTrigger(G4double aEnergy)
0068 {
0069   /// Energy of electrons used in training dataset
0070   if (aEnergy > 1 * CLHEP::GeV || aEnergy < 1024 * CLHEP::GeV) return true;
0071   return false;
0072 }
0073 
0074 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0075 
0076 void Par04InferenceSetup::SetInferenceLibrary(G4String aName)
0077 {
0078   fInferenceLibrary = aName;
0079 
0080 #  ifdef USE_INFERENCE_ONNX
0081   if (fInferenceLibrary == "ONNX")
0082     fInferenceInterface = std::unique_ptr<Par04InferenceInterface>(new Par04OnnxInference(
0083       fModelPathName, fProfileFlag, fOptimizationFlag, fIntraOpNumThreads, fCudaFlag, cuda_keys,
0084       cuda_values, fModelSavePath, fProfilingOutputSavePath));
0085 #  endif
0086 #  ifdef USE_INFERENCE_LWTNN
0087   if (fInferenceLibrary == "LWTNN")
0088     fInferenceInterface =
0089       std::unique_ptr<Par04InferenceInterface>(new Par04LwtnnInference(fModelPathName));
0090 #  endif
0091 #  ifdef USE_INFERENCE_TORCH
0092   if (fInferenceLibrary == "TORCH")
0093     fInferenceInterface =
0094       std::unique_ptr<Par04InferenceInterface>(new Par04TorchInference(fModelPathName));
0095 #  endif
0096 
0097   CheckInferenceLibrary();
0098 }
0099 
0100 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0101 
0102 void Par04InferenceSetup::CheckInferenceLibrary()
0103 {
0104   G4String msg = "Please choose inference library from available libraries (";
0105 #  ifdef USE_INFERENCE_ONNX
0106   msg += "ONNX,";
0107 #  endif
0108 #  ifdef USE_INFERENCE_LWTNN
0109   msg += "LWTNN,";
0110 #  endif
0111 #  ifdef USE_INFERENCE_TORCH
0112   msg += "TORCH";
0113 #  endif
0114   if (fInferenceInterface == nullptr)
0115     G4Exception("Par04InferenceSetup::CheckInferenceLibrary()", "InvalidSetup", FatalException,
0116                 (msg + "). Current name: " + fInferenceLibrary).c_str());
0117 }
0118 
0119 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0120 
0121 void Par04InferenceSetup::GetEnergies(std::vector<G4double>& aEnergies, G4double aInitialEnergy,
0122                                       G4float aTheta, G4float aPhi)
0123 {
0124   // First check if inference library was set correctly
0125   CheckInferenceLibrary();
0126   // size represents the size of the output vector
0127   int size = fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z();
0128   std::vector<G4float> genVector;
0129 
0130   if (fModelType == "VAE")
0131   {
0132     genVector.assign(fSizeLatentVector + fSizeConditionVector, 0);
0133 
0134   // randomly sample from a gaussian distribution in the latent space
0135   for (int i = 0; i < fSizeLatentVector; ++i) {
0136     genVector[i] = CLHEP::RandGauss::shoot(0., 1.);
0137   }
0138 
0139   // Vector of condition
0140   // this is application specific it depdens on what the model was condition on
0141   // and it depends on how the condition values were encoded at the training
0142   // time in this example the energy of each particle is normlaized to the
0143   // highest energy in the considered range (1GeV-500GeV) the angle is also is
0144   // normlaized to the highest angle in the considered range (0-90 in dergrees)
0145   // the model in this example was trained on two detector geometries PBW04
0146   // and SiW  a one hot encoding vector is used to represent the geometry with
0147   // [0,1] for PBW04 and [1,0] for SiW
0148   // 1. energy
0149   genVector[fSizeLatentVector] = aInitialEnergy / fMaxEnergy;
0150   // 2. angle
0151   genVector[fSizeLatentVector + 1] = (aTheta / (CLHEP::deg)) / fMaxAngle;
0152   // 3. geometry
0153   genVector[fSizeLatentVector + 2] = 0;
0154   genVector[fSizeLatentVector + 3] = 1;
0155   } else if (fModelType == "CaloDiT-2")
0156   {
0157     // fSizeLatentVector & fSizeConditionVector are ignored for CaloDiT-2
0158     // Conditions (dim) are energy (1), phi (1), theta (1) and geo (5)
0159     // The energy range here is 1 GeV - 1TeV, phi goes from 0 to 2pi,
0160     // and theta goes from 0.87 to 2.27.
0161     // And, geo is one-hot encoding describing the 4 geometries the model
0162     // is trained on.
0163     // Order of the geo condition is Par04SiW (this one), Par04SciPb, ODD, FCCeeCLD
0164     // As CaloDiT-2 is trained on these 4 detectors, it can be quickly adapted to
0165     // any new detector (see CaloDiT-2 readme for adaptation) of your choice. Thus
0166     // reusing the knowledge from these previous detectors.
0167     // To use the adapted model, make the following changes for inference:
0168     // genVector[3] = 0.0; (turning OFF Par04SiW)
0169     // genVector[7] = 1.0; (turning ON a new detector)
0170     genVector.assign(8, 0);
0171 
0172     genVector[0] = aInitialEnergy / 1000;  // convert to GeV
0173     genVector[1] = aPhi;
0174     genVector[2] = aTheta;
0175     genVector[3] = 1.0;  //Par04SiW
0176   }
0177   // Run the inference
0178   fInferenceInterface->RunInference(genVector, aEnergies, size);
0179 
0180   // After the inference rescale back to the initial energy
0181 
0182   if (fModelType == "VAE")
0183   // For VAE, energies of cells were normalized to the energy of the particle
0184   {
0185   for (int i = 0; i < size; ++i) {
0186     aEnergies[i] = aEnergies[i] * aInitialEnergy;
0187     }
0188   } else if (fModelType == "CaloDiT-2")
0189   // For CaloDiT-2, energies were scaled by a factor of 1000
0190   {
0191     for (int i = 0; i < size; ++i){
0192       aEnergies[i] = aEnergies[i] * 1000;
0193     }
0194   }
0195 }
0196 
0197 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0198 
0199 void Par04InferenceSetup::GetPositions(std::vector<G4ThreeVector>& aPositions, G4ThreeVector pos0,
0200                                        G4ThreeVector direction)
0201 {
0202   aPositions.resize(fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z());
0203 
0204   // Calculate rotation matrix along the particle momentum direction
0205   // It will rotate the shower axes to match the incoming particle direction
0206   G4RotationMatrix rotMatrix = G4RotationMatrix();
0207   double particleTheta = direction.theta();
0208   double particlePhi = direction.phi();
0209   rotMatrix.rotateZ(-particlePhi);
0210   rotMatrix.rotateY(-particleTheta);
0211   G4RotationMatrix rotMatrixInv = CLHEP::inverseOf(rotMatrix);
0212 
0213   int cpt = 0;
0214   for (G4int iCellR = 0; iCellR < fMeshNumber.x(); iCellR++) {
0215     for (G4int iCellPhi = 0; iCellPhi < fMeshNumber.y(); iCellPhi++) {
0216       for (G4int iCellZ = 0; iCellZ < fMeshNumber.z(); iCellZ++) {
0217         aPositions[cpt] =
0218           pos0
0219           + rotMatrixInv
0220               * G4ThreeVector(
0221                 (iCellR + 0.5) * fMeshSize.x()
0222                   * std::cos((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
0223                 (iCellR + 0.5) * fMeshSize.x()
0224                   * std::sin((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
0225                 (iCellZ + 0.5) * fMeshSize.z());
0226         cpt++;
0227       }
0228     }
0229   }
0230 }
0231 
0232 #endif