File indexing completed on 2025-02-23 09:22:34
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 #ifdef USE_INFERENCE
0027 # ifndef PAR04INFEERENCESETUP_HH
0028 # define PAR04INFEERENCESETUP_HH
0029
0030 # include "CLHEP/Units/SystemOfUnits.h" // for mm
0031
0032 # include "G4ThreeVector.hh" // for G4ThreeVector
0033
0034 # include <G4String.hh> // for G4String
0035 # include <G4SystemOfUnits.hh> // for mm
0036 # include <G4Types.hh> // for G4int, G4double, G4bool, G4f...
0037 # include <memory> // for unique_ptr
0038 # include <vector> // for vector
0039 class Par04DetectorConstruction;
0040 class Par04InferenceInterface;
0041 class Par04InferenceMessenger;
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058 class Par04InferenceSetup
0059 {
0060 public:
0061 Par04InferenceSetup();
0062 ~Par04InferenceSetup();
0063
0064
0065
0066
0067 G4bool IfTrigger(G4double aEnergy);
0068
0069
0070
0071 inline void SetMeshSize(const G4ThreeVector& aSize) { fMeshSize = aSize; };
0072
0073
0074
0075 inline G4ThreeVector GetMeshSize() const { return fMeshSize; };
0076
0077
0078
0079 inline void SetMeshNumber(const G4ThreeVector& aSize) { fMeshNumber = aSize; };
0080
0081
0082
0083 inline G4ThreeVector GetMeshNumber() const { return fMeshNumber; };
0084
0085 inline void SetSizeConditionVector(G4int aNumber) { fSizeConditionVector = aNumber; };
0086
0087 inline G4int GetSizeConditionVector() const { return fSizeConditionVector; };
0088
0089 inline void SetSizeLatentVector(G4int aNumber) { fSizeLatentVector = aNumber; };
0090
0091 inline G4int GetSizeLatentVector() const { return fSizeLatentVector; };
0092
0093 inline void SetModelPathName(G4String aName) { fModelPathName = aName; };
0094
0095 inline G4String GetModelPathName() const { return fModelPathName; };
0096
0097 inline void SetProfileFlag(G4int aNumber) { fProfileFlag = aNumber; };
0098
0099 inline G4int GetProfileFlag() const { return fProfileFlag; };
0100
0101 inline void SetOptimizationFlag(G4int aNumber) { fOptimizationFlag = aNumber; };
0102
0103 inline G4int GetOptimizationFlag() const { return fOptimizationFlag; };
0104
0105 inline G4String GetInferenceLibrary() const { return fInferenceLibrary; };
0106
0107
0108 void SetInferenceLibrary(G4String aName);
0109
0110 void CheckInferenceLibrary();
0111
0112 inline void SetMeshNbOfCells(G4ThreeVector aNb) { fMeshNumber = aNb; };
0113
0114
0115 inline void SetMeshNbOfCells(G4int aIndex, G4double aNb) { fMeshNumber[aIndex] = aNb; };
0116
0117 inline G4ThreeVector GetMeshNbOfCells() const { return fMeshNumber; };
0118
0119 inline void SetMeshSizeOfCells(G4ThreeVector aNb) { fMeshSize = aNb; };
0120
0121
0122 inline void SetMeshSizeOfCells(G4int aIndex, G4double aNb) { fMeshSize[aIndex] = aNb; };
0123
0124 inline G4ThreeVector GetMeshSizeOfCells() const { return fMeshSize; };
0125
0126
0127 inline void SetCudaFlag(G4int aNumber) { fCudaFlag = aNumber; };
0128 inline G4int GetCudaFlag() const { return fCudaFlag; };
0129
0130
0131 inline void SetCudaDeviceId(G4String aNumber) { fCudaDeviceId = aNumber; };
0132 inline G4String GetCudaDeviceId() const { return fCudaDeviceId; };
0133 inline void SetCudaGpuMemLimit(G4String aNumber) { fCudaGpuMemLimit = aNumber; };
0134 inline G4String GetCudaGpuMemLimit() const { return fCudaGpuMemLimit; };
0135 inline void SetCudaArenaExtendedStrategy(G4String aNumber)
0136 {
0137 fCudaArenaExtendedStrategy = aNumber;
0138 };
0139 inline G4String GetCudaArenaExtendedStrategy() const { return fCudaArenaExtendedStrategy; };
0140 inline void SetCudaCudnnConvAlgoSearch(G4String aNumber)
0141 {
0142 fCudaCudnnConvAlgoSearch = aNumber;
0143 };
0144 inline G4String GetCudaCudnnConvAlgoSearch() const { return fCudaCudnnConvAlgoSearch; };
0145 inline void SetCudaDoCopyInDefaultStream(G4String aNumber)
0146 {
0147 fCudaDoCopyInDefaultStream = aNumber;
0148 };
0149 inline G4String GetCudaDoCopyInDefaultStream() const { return fCudaDoCopyInDefaultStream; };
0150 inline void SetCudaCudnnConvUseMaxWorkspace(G4String aNumber)
0151 {
0152 fCudaCudnnConvUseMaxWorkspace = aNumber;
0153 };
0154 inline G4String GetCudaCudnnConvUseMaxWorkspace() const
0155 {
0156 return fCudaCudnnConvUseMaxWorkspace;
0157 };
0158
0159
0160
0161
0162
0163 void GetEnergies(std::vector<G4double>& aEnergies, G4double aParticleEnergy,
0164 G4float aInitialAngle);
0165
0166
0167
0168
0169
0170
0171
0172
0173
0174 void GetPositions(std::vector<G4ThreeVector>& aDepositsPositions,
0175 G4ThreeVector aParticlePosition, G4ThreeVector aParticleDirection);
0176
0177 private:
0178
0179
0180
0181
0182 G4ThreeVector fMeshSize = G4ThreeVector(2.325 * CLHEP::mm, 1, 3.4 * CLHEP::mm);
0183
0184
0185
0186 G4ThreeVector fMeshNumber = G4ThreeVector(18, 50, 45);
0187
0188 std::unique_ptr<Par04InferenceInterface> fInferenceInterface;
0189
0190 Par04InferenceMessenger* fInferenceMessenger;
0191
0192 float fMaxEnergy = 1024000.0;
0193
0194 float fMaxAngle = 90.0;
0195
0196 G4String fInferenceLibrary = "ONNX";
0197
0198 G4int fSizeLatentVector = 10;
0199
0200 G4int fSizeConditionVector = 4;
0201
0202 G4String fModelPathName = "MLModels/Generator.onnx";
0203
0204
0205 G4bool fProfileFlag = false;
0206
0207 G4bool fOptimizationFlag = false;
0208
0209 G4String fModelSavePath = "MLModels/Optimized-Generator.onnx";
0210
0211 G4String fProfilingOutputSavePath = "opt.json";
0212
0213 G4int fIntraOpNumThreads = 1;
0214
0215
0216 G4bool fCudaFlag = false;
0217
0218
0219 G4String fCudaDeviceId = "0";
0220 G4String fCudaGpuMemLimit = "2147483648";
0221 G4String fCudaArenaExtendedStrategy = "kSameAsRequested";
0222 G4String fCudaCudnnConvAlgoSearch = "DEFAULT";
0223 G4String fCudaDoCopyInDefaultStream = "1";
0224 G4String fCudaCudnnConvUseMaxWorkspace = "1";
0225 std::vector<const char*> cuda_keys{
0226 "device_id",
0227 "gpu_mem_limit",
0228 "arena_extend_strategy",
0229 "cudnn_conv_algo_search",
0230 "do_copy_in_default_stream",
0231 "cudnn_conv_use_max_workspace",
0232 };
0233 std::vector<const char*> cuda_values{
0234 fCudaDeviceId.c_str(),
0235 fCudaGpuMemLimit.c_str(),
0236 fCudaArenaExtendedStrategy.c_str(),
0237 fCudaCudnnConvAlgoSearch.c_str(),
0238 fCudaDoCopyInDefaultStream.c_str(),
0239 fCudaCudnnConvUseMaxWorkspace.c_str(),
0240 };
0241 };
0242
0243 # endif
0244 #endif