Warning, file /geant4/examples/extended/parameterisations/Par04/src/Par04InferenceSetup.cc was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 #ifdef USE_INFERENCE
0027 # include "Par04InferenceSetup.hh"
0028
0029 # include "Par04InferenceInterface.hh" // for Par04InferenceInterface
0030 # include "Par04InferenceMessenger.hh" // for Par04InferenceMessenger
0031 # ifdef USE_INFERENCE_ONNX
0032 # include "Par04OnnxInference.hh" // for Par04OnnxInference
0033 # endif
0034 # ifdef USE_INFERENCE_LWTNN
0035 # include "Par04LwtnnInference.hh" // for Par04LwtnnInference
0036 # endif
0037 # ifdef USE_INFERENCE_TORCH
0038 # include "Par04TorchInference.hh" // for Par04TorchInference
0039 # endif
0040 # include "CLHEP/Random/RandGauss.h" // for RandGauss
0041
0042 # include "G4RotationMatrix.hh" // for G4RotationMatrix
0043
0044 # include <CLHEP/Units/SystemOfUnits.h> // for pi, GeV, deg
0045 # include <CLHEP/Vector/Rotation.h> // for HepRotation
0046 # include <CLHEP/Vector/ThreeVector.h> // for Hep3Vector
0047 # include <G4Exception.hh> // for G4Exception
0048 # include <G4ExceptionSeverity.hh> // for FatalException
0049 # include <G4ThreeVector.hh> // for G4ThreeVector
0050 # include <algorithm> // for max, copy
0051 # include <cmath> // for cos, sin
0052 # include <string> // for char_traits, basic_string
0053
0054 # include <ext/alloc_traits.h> // for __alloc_traits<>::value_type
0055
0056
0057
0058 Par04InferenceSetup::Par04InferenceSetup() : fInferenceMessenger(new Par04InferenceMessenger(this))
0059 {}
0060
0061
0062
0063 Par04InferenceSetup::~Par04InferenceSetup() {}
0064
0065
0066
0067 G4bool Par04InferenceSetup::IfTrigger(G4double aEnergy)
0068 {
0069
0070 if (aEnergy > 1 * CLHEP::GeV || aEnergy < 1024 * CLHEP::GeV) return true;
0071 return false;
0072 }
0073
0074
0075
0076 void Par04InferenceSetup::SetInferenceLibrary(G4String aName)
0077 {
0078 fInferenceLibrary = aName;
0079
0080 # ifdef USE_INFERENCE_ONNX
0081 if (fInferenceLibrary == "ONNX")
0082 fInferenceInterface = std::unique_ptr<Par04InferenceInterface>(new Par04OnnxInference(
0083 fModelPathName, fProfileFlag, fOptimizationFlag, fIntraOpNumThreads, fCudaFlag, cuda_keys,
0084 cuda_values, fModelSavePath, fProfilingOutputSavePath));
0085 # endif
0086 # ifdef USE_INFERENCE_LWTNN
0087 if (fInferenceLibrary == "LWTNN")
0088 fInferenceInterface =
0089 std::unique_ptr<Par04InferenceInterface>(new Par04LwtnnInference(fModelPathName));
0090 # endif
0091 # ifdef USE_INFERENCE_TORCH
0092 if (fInferenceLibrary == "TORCH")
0093 fInferenceInterface =
0094 std::unique_ptr<Par04InferenceInterface>(new Par04TorchInference(fModelPathName));
0095 # endif
0096
0097 CheckInferenceLibrary();
0098 }
0099
0100
0101
0102 void Par04InferenceSetup::CheckInferenceLibrary()
0103 {
0104 G4String msg = "Please choose inference library from available libraries (";
0105 # ifdef USE_INFERENCE_ONNX
0106 msg += "ONNX,";
0107 # endif
0108 # ifdef USE_INFERENCE_LWTNN
0109 msg += "LWTNN,";
0110 # endif
0111 # ifdef USE_INFERENCE_TORCH
0112 msg += "TORCH";
0113 # endif
0114 if (fInferenceInterface == nullptr)
0115 G4Exception("Par04InferenceSetup::CheckInferenceLibrary()", "InvalidSetup", FatalException,
0116 (msg + "). Current name: " + fInferenceLibrary).c_str());
0117 }
0118
0119
0120
0121 void Par04InferenceSetup::GetEnergies(std::vector<G4double>& aEnergies, G4double aInitialEnergy,
0122 G4float aTheta, G4float aPhi)
0123 {
0124
0125 CheckInferenceLibrary();
0126
0127 int size = fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z();
0128 std::vector<G4float> genVector;
0129
0130 if (fModelType == "VAE")
0131 {
0132 genVector.assign(fSizeLatentVector + fSizeConditionVector, 0);
0133
0134
0135 for (int i = 0; i < fSizeLatentVector; ++i) {
0136 genVector[i] = CLHEP::RandGauss::shoot(0., 1.);
0137 }
0138
0139
0140
0141
0142
0143
0144
0145
0146
0147
0148
0149 genVector[fSizeLatentVector] = aInitialEnergy / fMaxEnergy;
0150
0151 genVector[fSizeLatentVector + 1] = (aTheta / (CLHEP::deg)) / fMaxAngle;
0152
0153 genVector[fSizeLatentVector + 2] = 0;
0154 genVector[fSizeLatentVector + 3] = 1;
0155 } else if (fModelType == "CaloDiT-2")
0156 {
0157
0158
0159
0160
0161
0162
0163
0164
0165
0166
0167
0168
0169
0170 genVector.assign(8, 0);
0171
0172 genVector[0] = aInitialEnergy / 1000;
0173 genVector[1] = aPhi;
0174 genVector[2] = aTheta;
0175 genVector[3] = 1.0;
0176 }
0177
0178 fInferenceInterface->RunInference(genVector, aEnergies, size);
0179
0180
0181
0182 if (fModelType == "VAE")
0183
0184 {
0185 for (int i = 0; i < size; ++i) {
0186 aEnergies[i] = aEnergies[i] * aInitialEnergy;
0187 }
0188 } else if (fModelType == "CaloDiT-2")
0189
0190 {
0191 for (int i = 0; i < size; ++i){
0192 aEnergies[i] = aEnergies[i] * 1000;
0193 }
0194 }
0195 }
0196
0197
0198
0199 void Par04InferenceSetup::GetPositions(std::vector<G4ThreeVector>& aPositions, G4ThreeVector pos0,
0200 G4ThreeVector direction)
0201 {
0202 aPositions.resize(fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z());
0203
0204
0205
0206 G4RotationMatrix rotMatrix = G4RotationMatrix();
0207 double particleTheta = direction.theta();
0208 double particlePhi = direction.phi();
0209 rotMatrix.rotateZ(-particlePhi);
0210 rotMatrix.rotateY(-particleTheta);
0211 G4RotationMatrix rotMatrixInv = CLHEP::inverseOf(rotMatrix);
0212
0213 int cpt = 0;
0214 for (G4int iCellR = 0; iCellR < fMeshNumber.x(); iCellR++) {
0215 for (G4int iCellPhi = 0; iCellPhi < fMeshNumber.y(); iCellPhi++) {
0216 for (G4int iCellZ = 0; iCellZ < fMeshNumber.z(); iCellZ++) {
0217 aPositions[cpt] =
0218 pos0
0219 + rotMatrixInv
0220 * G4ThreeVector(
0221 (iCellR + 0.5) * fMeshSize.x()
0222 * std::cos((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
0223 (iCellR + 0.5) * fMeshSize.x()
0224 * std::sin((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
0225 (iCellZ + 0.5) * fMeshSize.z());
0226 cpt++;
0227 }
0228 }
0229 }
0230 }
0231
0232 #endif