File indexing completed on 2025-01-18 10:11:00
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028 #ifndef ROOT_TMVA_MethodDL
0029 #define ROOT_TMVA_MethodDL
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039 #include "TString.h"
0040
0041 #include "TMVA/MethodBase.h"
0042 #include "TMVA/Types.h"
0043
0044 #include "TMVA/DNN/Architectures/Reference.h"
0045
0046
0047 #include "TMVA/DNN/Architectures/Cpu.h"
0048
0049
0050 #if 0
0051 #ifdef R__HAS_TMVAGPU
0052 #include "TMVA/DNN/Architectures/Cuda.h"
0053 #ifdef R__HAS_CUDNN
0054 #include "TMVA/DNN/Architectures/TCudnn.h"
0055 #endif
0056 #endif
0057 #endif
0058
0059 #include "TMVA/DNN/Functions.h"
0060 #include "TMVA/DNN/DeepNet.h"
0061
#include <map>
#include <memory>
#include <utility>
#include <vector>
0064
0065 #ifdef R__HAS_TMVAGPU
0066
0067 #endif
0068
0069 namespace TMVA {
0070
0071
0072 struct TTrainingSettings {
0073 size_t batchSize;
0074 size_t testInterval;
0075 size_t convergenceSteps;
0076 size_t maxEpochs;
0077 DNN::ERegularization regularization;
0078 DNN::EOptimizer optimizer;
0079 TString optimizerName;
0080 Double_t learningRate;
0081 Double_t momentum;
0082 Double_t weightDecay;
0083 std::vector<Double_t> dropoutProbabilities;
0084 std::map<TString,double> optimizerParams;
0085 bool multithreading;
0086 };
0087
0088
0089 class MethodDL : public MethodBase {
0090
0091 private:
0092
0093 using KeyValueVector_t = std::vector<std::map<TString, TString>>;
0094
0095
0096
0097
0098
0099
0100
0101
0102
0103 using ArchitectureImpl_t = TMVA::DNN::TCpu<Float_t>;
0104
0105
0106 using DeepNetImpl_t = TMVA::DNN::TDeepNet<ArchitectureImpl_t>;
0107 using MatrixImpl_t = typename ArchitectureImpl_t::Matrix_t;
0108 using TensorImpl_t = typename ArchitectureImpl_t::Tensor_t;
0109 using ScalarImpl_t = typename ArchitectureImpl_t::Scalar_t;
0110 using HostBufferImpl_t = typename ArchitectureImpl_t::HostBuffer_t;
0111
0112
0113 void DeclareOptions();
0114 void ProcessOptions();
0115
0116 void Init();
0117
0118
0119 void ParseInputLayout();
0120 void ParseBatchLayout();
0121
0122
0123
0124
0125
0126 template <typename Architecture_t, typename Layer_t>
0127 void CreateDeepNet(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0128 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets);
0129
0130 template <typename Architecture_t, typename Layer_t>
0131 void ParseDenseLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0132 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
0133
0134 template <typename Architecture_t, typename Layer_t>
0135 void ParseConvLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0136 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
0137
0138 template <typename Architecture_t, typename Layer_t>
0139 void ParseMaxPoolLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0140 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString,
0141 TString delim);
0142
0143 template <typename Architecture_t, typename Layer_t>
0144 void ParseReshapeLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0145 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString,
0146 TString delim);
0147
0148 template <typename Architecture_t, typename Layer_t>
0149 void ParseBatchNormLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0150 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString,
0151 TString delim);
0152
0153 enum ERecurrentLayerType { kLayerRNN = 0, kLayerLSTM = 1, kLayerGRU = 2 };
0154 template <typename Architecture_t, typename Layer_t>
0155 void ParseRecurrentLayer(ERecurrentLayerType type, DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0156 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
0157
0158
0159
0160 template <typename Architecture_t>
0161 void TrainDeepNet();
0162
0163
0164
0165 template <typename Architecture_t>
0166 std::vector<Double_t> PredictDeepNet(Long64_t firstEvt, Long64_t lastEvt, size_t batchSize, Bool_t logProgress);
0167
0168
0169
0170 void FillInputTensor();
0171
0172
0173 UInt_t GetNumValidationSamples();
0174
0175
0176
0177
0178 std::vector<size_t> fInputShape;
0179
0180
0181
0182 size_t fBatchDepth;
0183 size_t fBatchHeight;
0184 size_t fBatchWidth;
0185
0186 size_t fRandomSeed;
0187
0188 DNN::EInitialization fWeightInitialization;
0189 DNN::EOutputFunction fOutputFunction;
0190 DNN::ELossFunction fLossFunction;
0191
0192 TString fInputLayoutString;
0193 TString fBatchLayoutString;
0194 TString fLayoutString;
0195 TString fErrorStrategy;
0196 TString fTrainingStrategyString;
0197 TString fWeightInitializationString;
0198 TString fArchitectureString;
0199 TString fNumValidationString;
0200 bool fResume;
0201 bool fBuildNet;
0202
0203 KeyValueVector_t fSettings;
0204 std::vector<TTrainingSettings> fTrainingSettings;
0205
0206 TensorImpl_t fXInput;
0207 HostBufferImpl_t fXInputBuffer;
0208 std::unique_ptr<MatrixImpl_t> fYHat;
0209 std::unique_ptr<DeepNetImpl_t> fNet;
0210
0211
0212 ClassDef(MethodDL, 0);
0213
0214 protected:
0215
0216 void GetHelpMessage() const;
0217
0218 virtual std::vector<Double_t> GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress);
0219
0220
0221 public:
0222
0223 MethodDL(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption);
0224
0225
0226 MethodDL(DataSetInfo &theData, const TString &theWeightFile);
0227
0228
0229 virtual ~MethodDL();
0230
0231
0232
0233 KeyValueVector_t ParseKeyValueString(TString parseString, TString blockDelim, TString tokenDelim);
0234
0235
0236 Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
0237
0238
0239 void Train();
0240
0241 Double_t GetMvaValue(Double_t *err = nullptr, Double_t *errUpper = nullptr);
0242 virtual const std::vector<Float_t>& GetRegressionValues();
0243 virtual const std::vector<Float_t>& GetMulticlassValues();
0244
0245
0246 using MethodBase::ReadWeightsFromStream;
0247 void AddWeightsXMLTo(void *parent) const;
0248 void ReadWeightsFromXML(void *wghtnode);
0249 void ReadWeightsFromStream(std::istream &);
0250
0251
0252 const Ranking *CreateRanking();
0253
0254
0255 size_t GetInputDepth() const { return fInputShape[1]; }
0256 size_t GetInputHeight() const { return fInputShape[2]; }
0257 size_t GetInputWidth() const { return fInputShape[3]; }
0258 size_t GetInputDim() const { return fInputShape.size() - 2; }
0259 std::vector<size_t> GetInputShape() const { return fInputShape; }
0260
0261 size_t GetBatchSize() const { return fInputShape[0]; }
0262 size_t GetBatchDepth() const { return fBatchDepth; }
0263 size_t GetBatchHeight() const { return fBatchHeight; }
0264 size_t GetBatchWidth() const { return fBatchWidth; }
0265
0266 const DeepNetImpl_t & GetDeepNet() const { return *fNet; }
0267
0268 DNN::EInitialization GetWeightInitialization() const { return fWeightInitialization; }
0269 DNN::EOutputFunction GetOutputFunction() const { return fOutputFunction; }
0270 DNN::ELossFunction GetLossFunction() const { return fLossFunction; }
0271
0272 TString GetInputLayoutString() const { return fInputLayoutString; }
0273 TString GetBatchLayoutString() const { return fBatchLayoutString; }
0274 TString GetLayoutString() const { return fLayoutString; }
0275 TString GetErrorStrategyString() const { return fErrorStrategy; }
0276 TString GetTrainingStrategyString() const { return fTrainingStrategyString; }
0277 TString GetWeightInitializationString() const { return fWeightInitializationString; }
0278 TString GetArchitectureString() const { return fArchitectureString; }
0279
0280 const std::vector<TTrainingSettings> &GetTrainingSettings() const { return fTrainingSettings; }
0281 std::vector<TTrainingSettings> &GetTrainingSettings() { return fTrainingSettings; }
0282 const KeyValueVector_t &GetKeyValueSettings() const { return fSettings; }
0283 KeyValueVector_t &GetKeyValueSettings() { return fSettings; }
0284
0285
0286 void SetInputDepth (int inputDepth) { fInputShape[1] = inputDepth; }
0287 void SetInputHeight(int inputHeight) { fInputShape[2] = inputHeight; }
0288 void SetInputWidth (int inputWidth) { fInputShape[3] = inputWidth; }
0289 void SetInputShape (std::vector<size_t> inputShape) { fInputShape = std::move(inputShape); }
0290
0291 void SetBatchSize (size_t batchSize) { fInputShape[0] = batchSize; }
0292 void SetBatchDepth (size_t batchDepth) { fBatchDepth = batchDepth; }
0293 void SetBatchHeight(size_t batchHeight) { fBatchHeight = batchHeight; }
0294 void SetBatchWidth (size_t batchWidth) { fBatchWidth = batchWidth; }
0295
0296 void SetWeightInitialization(DNN::EInitialization weightInitialization)
0297 {
0298 fWeightInitialization = weightInitialization;
0299 }
0300 void SetOutputFunction (DNN::EOutputFunction outputFunction) { fOutputFunction = outputFunction; }
0301 void SetErrorStrategyString (TString errorStrategy) { fErrorStrategy = errorStrategy; }
0302 void SetTrainingStrategyString (TString trainingStrategyString) { fTrainingStrategyString = trainingStrategyString; }
0303 void SetWeightInitializationString(TString weightInitializationString)
0304 {
0305 fWeightInitializationString = weightInitializationString;
0306 }
0307 void SetArchitectureString (TString architectureString) { fArchitectureString = architectureString; }
0308 void SetLayoutString (TString layoutString) { fLayoutString = layoutString; }
0309 };
0310
0311 }
0312
0313 #endif