File indexing completed on 2025-02-23 09:22:35
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 #ifdef USE_INFERENCE
0027 # include "Par04InferenceSetup.hh"
0028
0029 # include "Par04InferenceInterface.hh" // for Par04InferenceInterface
0030 # include "Par04InferenceMessenger.hh" // for Par04InferenceMessenger
0031 # ifdef USE_INFERENCE_ONNX
0032 # include "Par04OnnxInference.hh" // for Par04OnnxInference
0033 # endif
0034 # ifdef USE_INFERENCE_LWTNN
0035 # include "Par04LwtnnInference.hh" // for Par04LwtnnInference
0036 # endif
0037 # ifdef USE_INFERENCE_TORCH
0038 # include "Par04TorchInference.hh" // for Par04TorchInference
0039 # endif
0040 # include "CLHEP/Random/RandGauss.h" // for RandGauss
0041
0042 # include "G4RotationMatrix.hh" // for G4RotationMatrix
0043
0044 # include <CLHEP/Units/SystemOfUnits.h> // for pi, GeV, deg
0045 # include <CLHEP/Vector/Rotation.h> // for HepRotation
0046 # include <CLHEP/Vector/ThreeVector.h> // for Hep3Vector
0047 # include <G4Exception.hh> // for G4Exception
0048 # include <G4ExceptionSeverity.hh> // for FatalException
0049 # include <G4ThreeVector.hh> // for G4ThreeVector
0050 # include <algorithm> // for max, copy
0051 # include <cmath> // for cos, sin
0052 # include <string> // for char_traits, basic_string
0053
0054 # include <ext/alloc_traits.h> // for __alloc_traits<>::value_type
0055
0056
0057
0058 Par04InferenceSetup::Par04InferenceSetup() : fInferenceMessenger(new Par04InferenceMessenger(this))
0059 {}
0060
0061
0062
0063 Par04InferenceSetup::~Par04InferenceSetup() {}
0064
0065
0066
0067 G4bool Par04InferenceSetup::IfTrigger(G4double aEnergy)
0068 {
0069
0070 if (aEnergy > 1 * CLHEP::GeV || aEnergy < 1024 * CLHEP::GeV) return true;
0071 return false;
0072 }
0073
0074
0075
0076 void Par04InferenceSetup::SetInferenceLibrary(G4String aName)
0077 {
0078 fInferenceLibrary = aName;
0079
0080 # ifdef USE_INFERENCE_ONNX
0081 if (fInferenceLibrary == "ONNX")
0082 fInferenceInterface = std::unique_ptr<Par04InferenceInterface>(new Par04OnnxInference(
0083 fModelPathName, fProfileFlag, fOptimizationFlag, fIntraOpNumThreads, fCudaFlag, cuda_keys,
0084 cuda_values, fModelSavePath, fProfilingOutputSavePath));
0085 # endif
0086 # ifdef USE_INFERENCE_LWTNN
0087 if (fInferenceLibrary == "LWTNN")
0088 fInferenceInterface =
0089 std::unique_ptr<Par04InferenceInterface>(new Par04LwtnnInference(fModelPathName));
0090 # endif
0091 # ifdef USE_INFERENCE_TORCH
0092 if (fInferenceLibrary == "TORCH")
0093 fInferenceInterface =
0094 std::unique_ptr<Par04InferenceInterface>(new Par04TorchInference(fModelPathName));
0095 # endif
0096
0097 CheckInferenceLibrary();
0098 }
0099
0100
0101
0102 void Par04InferenceSetup::CheckInferenceLibrary()
0103 {
0104 G4String msg = "Please choose inference library from available libraries (";
0105 # ifdef USE_INFERENCE_ONNX
0106 msg += "ONNX,";
0107 # endif
0108 # ifdef USE_INFERENCE_LWTNN
0109 msg += "LWTNN,";
0110 # endif
0111 # ifdef USE_INFERENCE_TORCH
0112 msg += "TORCH";
0113 # endif
0114 if (fInferenceInterface == nullptr)
0115 G4Exception("Par04InferenceSetup::CheckInferenceLibrary()", "InvalidSetup", FatalException,
0116 (msg + "). Current name: " + fInferenceLibrary).c_str());
0117 }
0118
0119
0120
0121 void Par04InferenceSetup::GetEnergies(std::vector<G4double>& aEnergies, G4double aInitialEnergy,
0122 G4float aInitialAngle)
0123 {
0124
0125 CheckInferenceLibrary();
0126
0127 int size = fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z();
0128
0129
0130 std::vector<G4float> genVector(fSizeLatentVector + fSizeConditionVector, 0);
0131 for (int i = 0; i < fSizeLatentVector; ++i) {
0132 genVector[i] = CLHEP::RandGauss::shoot(0., 1.);
0133 }
0134
0135
0136
0137
0138
0139
0140
0141
0142
0143
0144
0145 genVector[fSizeLatentVector] = aInitialEnergy / fMaxEnergy;
0146
0147 genVector[fSizeLatentVector + 1] = (aInitialAngle / (CLHEP::deg)) / fMaxAngle;
0148
0149 genVector[fSizeLatentVector + 2] = 0;
0150 genVector[fSizeLatentVector + 3] = 1;
0151
0152
0153 fInferenceInterface->RunInference(genVector, aEnergies, size);
0154
0155
0156
0157 for (int i = 0; i < size; ++i) {
0158 aEnergies[i] = aEnergies[i] * aInitialEnergy;
0159 }
0160 }
0161
0162
0163
0164 void Par04InferenceSetup::GetPositions(std::vector<G4ThreeVector>& aPositions, G4ThreeVector pos0,
0165 G4ThreeVector direction)
0166 {
0167 aPositions.resize(fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z());
0168
0169
0170
0171 G4RotationMatrix rotMatrix = G4RotationMatrix();
0172 double particleTheta = direction.theta();
0173 double particlePhi = direction.phi();
0174 rotMatrix.rotateZ(-particlePhi);
0175 rotMatrix.rotateY(-particleTheta);
0176 G4RotationMatrix rotMatrixInv = CLHEP::inverseOf(rotMatrix);
0177
0178 int cpt = 0;
0179 for (G4int iCellR = 0; iCellR < fMeshNumber.x(); iCellR++) {
0180 for (G4int iCellPhi = 0; iCellPhi < fMeshNumber.y(); iCellPhi++) {
0181 for (G4int iCellZ = 0; iCellZ < fMeshNumber.z(); iCellZ++) {
0182 aPositions[cpt] =
0183 pos0
0184 + rotMatrixInv
0185 * G4ThreeVector(
0186 (iCellR + 0.5) * fMeshSize.x()
0187 * std::cos((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
0188 (iCellR + 0.5) * fMeshSize.x()
0189 * std::sin((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
0190 (iCellZ + 0.5) * fMeshSize.z());
0191 cpt++;
0192 }
0193 }
0194 }
0195 }
0196
0197 #endif