Warning, file /include/root/TMVA/MethodDL.h was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028 #ifndef ROOT_TMVA_MethodDL
0029 #define ROOT_TMVA_MethodDL
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039 #include "TString.h"
0040
0041 #include "TMVA/MethodBase.h"
0042 #include "TMVA/Types.h"
0043
0044 #include "TMVA/DNN/Architectures/Reference.h"
0045
0046
0047 #include "TMVA/DNN/Architectures/Cpu.h"
0048
0049
0050 #if 0
0051 #ifdef R__HAS_TMVAGPU
0052 #include "TMVA/DNN/Architectures/Cuda.h"
0053 #ifdef R__HAS_CUDNN
0054 #include "TMVA/DNN/Architectures/TCudnn.h"
0055 #endif
0056 #endif
0057 #endif
0058
0059 #include "TMVA/DNN/Functions.h"
0060 #include "TMVA/DNN/DeepNet.h"
0061
0062 #include <vector>
0063 #include <map>
0064
0065 #ifdef R__HAS_TMVAGPU
0066
0067 #endif
0068
0069 namespace TMVA {
0070
0071
0072 struct TTrainingSettings {
0073 size_t batchSize;
0074 size_t testInterval;
0075 size_t convergenceSteps;
0076 size_t maxEpochs;
0077 DNN::ERegularization regularization;
0078 DNN::EOptimizer optimizer;
0079 TString optimizerName;
0080 Double_t learningRate;
0081 Double_t momentum;
0082 Double_t weightDecay;
0083 std::vector<Double_t> dropoutProbabilities;
0084 std::map<TString,double> optimizerParams;
0085 bool multithreading;
0086 };
0087
0088
0089 class MethodDL : public MethodBase {
0090
0091 private:
0092
0093 using KeyValueVector_t = std::vector<std::map<TString, TString>>;
0094
0095
0096
0097
0098
0099
0100
0101
0102
0103 using ArchitectureImpl_t = TMVA::DNN::TCpu<Float_t>;
0104
0105
0106 using DeepNetImpl_t = TMVA::DNN::TDeepNet<ArchitectureImpl_t>;
0107 using MatrixImpl_t = typename ArchitectureImpl_t::Matrix_t;
0108 using TensorImpl_t = typename ArchitectureImpl_t::Tensor_t;
0109 using ScalarImpl_t = typename ArchitectureImpl_t::Scalar_t;
0110 using HostBufferImpl_t = typename ArchitectureImpl_t::HostBuffer_t;
0111
0112
0113 void DeclareOptions() override;
0114 void ProcessOptions() override;
0115
0116 void Init() override;
0117
0118
0119 void ParseInputLayout();
0120 void ParseBatchLayout();
0121
0122
0123
0124
0125
0126 template <typename Architecture_t, typename Layer_t>
0127 void CreateDeepNet(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0128 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets);
0129
0130 template <typename Architecture_t, typename Layer_t>
0131 void ParseDenseLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0132 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
0133
0134 template <typename Architecture_t, typename Layer_t>
0135 void ParseConvLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0136 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
0137
0138 template <typename Architecture_t, typename Layer_t>
0139 void ParseMaxPoolLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0140 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString,
0141 TString delim);
0142
0143 template <typename Architecture_t, typename Layer_t>
0144 void ParseReshapeLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0145 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString,
0146 TString delim);
0147
0148 template <typename Architecture_t, typename Layer_t>
0149 void ParseBatchNormLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0150 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString,
0151 TString delim);
0152
0153 enum ERecurrentLayerType { kLayerRNN = 0, kLayerLSTM = 1, kLayerGRU = 2 };
0154 template <typename Architecture_t, typename Layer_t>
0155 void ParseRecurrentLayer(ERecurrentLayerType type, DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0156 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
0157
0158
0159
0160 template <typename Architecture_t>
0161 void TrainDeepNet();
0162
0163
0164
0165 template <typename Architecture_t>
0166 std::vector<Double_t> PredictDeepNet(Long64_t firstEvt, Long64_t lastEvt, size_t batchSize, Bool_t logProgress);
0167
0168
0169
0170 void FillInputTensor();
0171
0172
0173 UInt_t GetNumValidationSamples();
0174
0175
0176
0177
0178 std::vector<size_t> fInputShape;
0179
0180
0181
0182 size_t fBatchDepth;
0183 size_t fBatchHeight;
0184 size_t fBatchWidth;
0185
0186 size_t fRandomSeed;
0187
0188 DNN::EInitialization fWeightInitialization;
0189 DNN::EOutputFunction fOutputFunction;
0190 DNN::ELossFunction fLossFunction;
0191
0192 TString fInputLayoutString;
0193 TString fBatchLayoutString;
0194 TString fLayoutString;
0195 TString fErrorStrategy;
0196 TString fTrainingStrategyString;
0197 TString fWeightInitializationString;
0198 TString fArchitectureString;
0199 TString fNumValidationString;
0200 bool fResume;
0201 bool fBuildNet;
0202
0203 KeyValueVector_t fSettings;
0204 std::vector<TTrainingSettings> fTrainingSettings;
0205
0206 TensorImpl_t fXInput;
0207 HostBufferImpl_t fXInputBuffer;
0208 std::unique_ptr<MatrixImpl_t> fYHat;
0209 std::unique_ptr<DeepNetImpl_t> fNet;
0210
0211
0212 ClassDefOverride(MethodDL, 0);
0213
0214 protected:
0215
0216 void GetHelpMessage() const override;
0217
0218 std::vector<Double_t> GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress) override;
0219
0220
0221 public:
0222
0223 MethodDL(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption);
0224
0225
0226 MethodDL(DataSetInfo &theData, const TString &theWeightFile);
0227
0228
0229 virtual ~MethodDL();
0230
0231
0232
0233 KeyValueVector_t ParseKeyValueString(TString parseString, TString blockDelim, TString tokenDelim);
0234
0235
0236 Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets) override;
0237
0238
0239 void Train() override;
0240
0241 Double_t GetMvaValue(Double_t *err = nullptr, Double_t *errUpper = nullptr) override;
0242 const std::vector<Float_t>& GetRegressionValues() override;
0243 const std::vector<Float_t>& GetMulticlassValues() override;
0244
0245
0246 using MethodBase::ReadWeightsFromStream;
0247 void AddWeightsXMLTo(void *parent) const override;
0248 void ReadWeightsFromXML(void *wghtnode) override;
0249 void ReadWeightsFromStream(std::istream &) override;
0250
0251
0252 const Ranking *CreateRanking() override;
0253
0254
0255 size_t GetInputDepth() const { return fInputShape[1]; }
0256 size_t GetInputHeight() const { return fInputShape[2]; }
0257 size_t GetInputWidth() const { return fInputShape[3]; }
0258 size_t GetInputDim() const { return fInputShape.size() - 2; }
0259 std::vector<size_t> GetInputShape() const { return fInputShape; }
0260
0261 size_t GetBatchSize() const { return fInputShape[0]; }
0262 size_t GetBatchDepth() const { return fBatchDepth; }
0263 size_t GetBatchHeight() const { return fBatchHeight; }
0264 size_t GetBatchWidth() const { return fBatchWidth; }
0265
0266 const DeepNetImpl_t & GetDeepNet() const { return *fNet; }
0267
0268 DNN::EInitialization GetWeightInitialization() const { return fWeightInitialization; }
0269 DNN::EOutputFunction GetOutputFunction() const { return fOutputFunction; }
0270 DNN::ELossFunction GetLossFunction() const { return fLossFunction; }
0271
0272 TString GetInputLayoutString() const { return fInputLayoutString; }
0273 TString GetBatchLayoutString() const { return fBatchLayoutString; }
0274 TString GetLayoutString() const { return fLayoutString; }
0275 TString GetErrorStrategyString() const { return fErrorStrategy; }
0276 TString GetTrainingStrategyString() const { return fTrainingStrategyString; }
0277 TString GetWeightInitializationString() const { return fWeightInitializationString; }
0278 TString GetArchitectureString() const { return fArchitectureString; }
0279
0280 const std::vector<TTrainingSettings> &GetTrainingSettings() const { return fTrainingSettings; }
0281 std::vector<TTrainingSettings> &GetTrainingSettings() { return fTrainingSettings; }
0282 const KeyValueVector_t &GetKeyValueSettings() const { return fSettings; }
0283 KeyValueVector_t &GetKeyValueSettings() { return fSettings; }
0284
0285
0286 void SetInputDepth (int inputDepth) { fInputShape[1] = inputDepth; }
0287 void SetInputHeight(int inputHeight) { fInputShape[2] = inputHeight; }
0288 void SetInputWidth (int inputWidth) { fInputShape[3] = inputWidth; }
0289 void SetInputShape (std::vector<size_t> inputShape) { fInputShape = std::move(inputShape); }
0290
0291 void SetBatchSize (size_t batchSize) { fInputShape[0] = batchSize; }
0292 void SetBatchDepth (size_t batchDepth) { fBatchDepth = batchDepth; }
0293 void SetBatchHeight(size_t batchHeight) { fBatchHeight = batchHeight; }
0294 void SetBatchWidth (size_t batchWidth) { fBatchWidth = batchWidth; }
0295
0296 void SetWeightInitialization(DNN::EInitialization weightInitialization)
0297 {
0298 fWeightInitialization = weightInitialization;
0299 }
0300 void SetOutputFunction (DNN::EOutputFunction outputFunction) { fOutputFunction = outputFunction; }
0301 void SetErrorStrategyString (TString errorStrategy) { fErrorStrategy = errorStrategy; }
0302 void SetTrainingStrategyString (TString trainingStrategyString) { fTrainingStrategyString = trainingStrategyString; }
0303 void SetWeightInitializationString(TString weightInitializationString)
0304 {
0305 fWeightInitializationString = weightInitializationString;
0306 }
0307 void SetArchitectureString (TString architectureString) { fArchitectureString = architectureString; }
0308 void SetLayoutString (TString layoutString) { fLayoutString = layoutString; }
0309 };
0310
0311 }
0312
0313 #endif