Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-23 09:22:35

0001 //
0002 // ********************************************************************
0003 // * License and Disclaimer                                           *
0004 // *                                                                  *
0005 // * The  Geant4 software  is  copyright of the Copyright Holders  of *
0006 // * the Geant4 Collaboration.  It is provided  under  the terms  and *
0007 // * conditions of the Geant4 Software License,  included in the file *
0008 // * LICENSE and available at  http://cern.ch/geant4/license .  These *
0009 // * include a list of copyright holders.                             *
0010 // *                                                                  *
0011 // * Neither the authors of this software system, nor their employing *
0012 // * institutes,nor the agencies providing financial support for this *
0013 // * work  make  any representation or  warranty, express or implied, *
0014 // * regarding  this  software system or assume any liability for its *
0015 // * use.  Please see the license in the file  LICENSE  and URL above *
0016 // * for the full disclaimer and the limitation of liability.         *
0017 // *                                                                  *
0018 // * This  code  implementation is the result of  the  scientific and *
0019 // * technical work of the GEANT4 collaboration.                      *
0020 // * By using,  copying,  modifying or  distributing the software (or *
0021 // * any work based  on the software)  you  agree  to acknowledge its *
0022 // * use  in  resulting  scientific  publications,  and indicate your *
0023 // * acceptance of all terms of the Geant4 Software license.          *
0024 // ********************************************************************
0025 //
0026 #ifdef USE_INFERENCE
0027 #  include "Par04InferenceSetup.hh"
0028 
0029 #  include "Par04InferenceInterface.hh"  // for Par04InferenceInterface
0030 #  include "Par04InferenceMessenger.hh"  // for Par04InferenceMessenger
0031 #  ifdef USE_INFERENCE_ONNX
0032 #    include "Par04OnnxInference.hh"  // for Par04OnnxInference
0033 #  endif
0034 #  ifdef USE_INFERENCE_LWTNN
0035 #    include "Par04LwtnnInference.hh"  // for Par04LwtnnInference
0036 #  endif
0037 #  ifdef USE_INFERENCE_TORCH
0038 #    include "Par04TorchInference.hh"  // for Par04TorchInference
0039 #  endif
0040 #  include "CLHEP/Random/RandGauss.h"  // for RandGauss
0041 
0042 #  include "G4RotationMatrix.hh"  // for G4RotationMatrix
0043 
0044 #  include <CLHEP/Units/SystemOfUnits.h>  // for pi, GeV, deg
0045 #  include <CLHEP/Vector/Rotation.h>  // for HepRotation
0046 #  include <CLHEP/Vector/ThreeVector.h>  // for Hep3Vector
0047 #  include <G4Exception.hh>  // for G4Exception
0048 #  include <G4ExceptionSeverity.hh>  // for FatalException
0049 #  include <G4ThreeVector.hh>  // for G4ThreeVector
0050 #  include <algorithm>  // for max, copy
0051 #  include <cmath>  // for cos, sin
0052 #  include <string>  // for char_traits, basic_string
0053 
0054 #  include <ext/alloc_traits.h>  // for __alloc_traits<>::value_type
0055 
0056 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0057 
0058 Par04InferenceSetup::Par04InferenceSetup() : fInferenceMessenger(new Par04InferenceMessenger(this))
0059 {}
0060 
0061 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0062 
0063 Par04InferenceSetup::~Par04InferenceSetup() {}
0064 
0065 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0066 
0067 G4bool Par04InferenceSetup::IfTrigger(G4double aEnergy)
0068 {
0069   /// Energy of electrons used in training dataset
0070   if (aEnergy > 1 * CLHEP::GeV || aEnergy < 1024 * CLHEP::GeV) return true;
0071   return false;
0072 }
0073 
0074 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0075 
0076 void Par04InferenceSetup::SetInferenceLibrary(G4String aName)
0077 {
0078   fInferenceLibrary = aName;
0079 
0080 #  ifdef USE_INFERENCE_ONNX
0081   if (fInferenceLibrary == "ONNX")
0082     fInferenceInterface = std::unique_ptr<Par04InferenceInterface>(new Par04OnnxInference(
0083       fModelPathName, fProfileFlag, fOptimizationFlag, fIntraOpNumThreads, fCudaFlag, cuda_keys,
0084       cuda_values, fModelSavePath, fProfilingOutputSavePath));
0085 #  endif
0086 #  ifdef USE_INFERENCE_LWTNN
0087   if (fInferenceLibrary == "LWTNN")
0088     fInferenceInterface =
0089       std::unique_ptr<Par04InferenceInterface>(new Par04LwtnnInference(fModelPathName));
0090 #  endif
0091 #  ifdef USE_INFERENCE_TORCH
0092   if (fInferenceLibrary == "TORCH")
0093     fInferenceInterface =
0094       std::unique_ptr<Par04InferenceInterface>(new Par04TorchInference(fModelPathName));
0095 #  endif
0096 
0097   CheckInferenceLibrary();
0098 }
0099 
0100 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0101 
0102 void Par04InferenceSetup::CheckInferenceLibrary()
0103 {
0104   G4String msg = "Please choose inference library from available libraries (";
0105 #  ifdef USE_INFERENCE_ONNX
0106   msg += "ONNX,";
0107 #  endif
0108 #  ifdef USE_INFERENCE_LWTNN
0109   msg += "LWTNN,";
0110 #  endif
0111 #  ifdef USE_INFERENCE_TORCH
0112   msg += "TORCH";
0113 #  endif
0114   if (fInferenceInterface == nullptr)
0115     G4Exception("Par04InferenceSetup::CheckInferenceLibrary()", "InvalidSetup", FatalException,
0116                 (msg + "). Current name: " + fInferenceLibrary).c_str());
0117 }
0118 
0119 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0120 
0121 void Par04InferenceSetup::GetEnergies(std::vector<G4double>& aEnergies, G4double aInitialEnergy,
0122                                       G4float aInitialAngle)
0123 {
0124   // First check if inference library was set correctly
0125   CheckInferenceLibrary();
0126   // size represents the size of the output vector
0127   int size = fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z();
0128 
0129   // randomly sample from a gaussian distribution in the latent space
0130   std::vector<G4float> genVector(fSizeLatentVector + fSizeConditionVector, 0);
0131   for (int i = 0; i < fSizeLatentVector; ++i) {
0132     genVector[i] = CLHEP::RandGauss::shoot(0., 1.);
0133   }
0134 
0135   // Vector of condition
0136   // this is application specific it depdens on what the model was condition on
0137   // and it depends on how the condition values were encoded at the training
0138   // time in this example the energy of each particle is normlaized to the
0139   // highest energy in the considered range (1GeV-500GeV) the angle is also is
0140   // normlaized to the highest angle in the considered range (0-90 in dergrees)
0141   // the model in this example was trained on two detector geometries PBW04
0142   // and SiW  a one hot encoding vector is used to represent the geometry with
0143   // [0,1] for PBW04 and [1,0] for SiW
0144   // 1. energy
0145   genVector[fSizeLatentVector] = aInitialEnergy / fMaxEnergy;
0146   // 2. angle
0147   genVector[fSizeLatentVector + 1] = (aInitialAngle / (CLHEP::deg)) / fMaxAngle;
0148   // 3. geometry
0149   genVector[fSizeLatentVector + 2] = 0;
0150   genVector[fSizeLatentVector + 3] = 1;
0151 
0152   // Run the inference
0153   fInferenceInterface->RunInference(genVector, aEnergies, size);
0154 
0155   // After the inference rescale back to the initial energy (in this example the
0156   // energies of cells were normalized to the energy of the particle)
0157   for (int i = 0; i < size; ++i) {
0158     aEnergies[i] = aEnergies[i] * aInitialEnergy;
0159   }
0160 }
0161 
0162 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
0163 
0164 void Par04InferenceSetup::GetPositions(std::vector<G4ThreeVector>& aPositions, G4ThreeVector pos0,
0165                                        G4ThreeVector direction)
0166 {
0167   aPositions.resize(fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z());
0168 
0169   // Calculate rotation matrix along the particle momentum direction
0170   // It will rotate the shower axes to match the incoming particle direction
0171   G4RotationMatrix rotMatrix = G4RotationMatrix();
0172   double particleTheta = direction.theta();
0173   double particlePhi = direction.phi();
0174   rotMatrix.rotateZ(-particlePhi);
0175   rotMatrix.rotateY(-particleTheta);
0176   G4RotationMatrix rotMatrixInv = CLHEP::inverseOf(rotMatrix);
0177 
0178   int cpt = 0;
0179   for (G4int iCellR = 0; iCellR < fMeshNumber.x(); iCellR++) {
0180     for (G4int iCellPhi = 0; iCellPhi < fMeshNumber.y(); iCellPhi++) {
0181       for (G4int iCellZ = 0; iCellZ < fMeshNumber.z(); iCellZ++) {
0182         aPositions[cpt] =
0183           pos0
0184           + rotMatrixInv
0185               * G4ThreeVector(
0186                 (iCellR + 0.5) * fMeshSize.x()
0187                   * std::cos((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
0188                 (iCellR + 0.5) * fMeshSize.x()
0189                   * std::sin((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
0190                 (iCellZ + 0.5) * fMeshSize.z());
0191         cpt++;
0192       }
0193     }
0194   }
0195 }
0196 
0197 #endif