Warning: file /include/root/TMVA/MethodDNN.h was not indexed,
or was modified since the last indexing (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030 #ifndef ROOT_TMVA_MethodDNN
0031 #define ROOT_TMVA_MethodDNN
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041 #include <vector>
0042 #include <map>
0043 #include <string>
0044 #include <sstream>
0045
0046 #include "TString.h"
0047 #include "TTree.h"
0048 #include "TRandom3.h"
0049 #include "TH1F.h"
0050 #include "TMVA/MethodBase.h"
0051 #include "TMVA/NeuralNet.h"
0052
0053 #include "TMVA/Tools.h"
0054
0055 #include "TMVA/DNN/Net.h"
0056 #include "TMVA/DNN/Minimizers.h"
0057 #include "TMVA/DNN/Architectures/Reference.h"
0058
0059 #ifdef R__HAS_TMVACPU
0060 #define DNNCPU
0061 #endif
0062 #ifdef R__HAS_TMVAGPU
0063
0064 #endif
0065
0066 #ifdef DNNCPU
0067 #include "TMVA/DNN/Architectures/Cpu.h"
0068 #endif
0069
0070 #ifdef DNNCUDA
0071 #include "TMVA/DNN/Architectures/Cuda.h"
0072 #endif
0073
0074 namespace TMVA {
0075
0076 class MethodDNN : public MethodBase
0077 {
0078 friend struct TestMethodDNNValidationSize;
0079
0080 using Architecture_t = DNN::TReference<Float_t>;
0081 using Net_t = DNN::TNet<Architecture_t>;
0082 using Matrix_t = typename Architecture_t::Matrix_t;
0083 using Scalar_t = typename Architecture_t::Scalar_t;
0084
0085 private:
0086 using LayoutVector_t = std::vector<std::pair<int, DNN::EActivationFunction>>;
0087 using KeyValueVector_t = std::vector<std::map<TString, TString>>;
0088
0089 struct TTrainingSettings
0090 {
0091 size_t batchSize;
0092 size_t testInterval;
0093 size_t convergenceSteps;
0094 DNN::ERegularization regularization;
0095 Double_t learningRate;
0096 Double_t momentum;
0097 Double_t weightDecay;
0098 std::vector<Double_t> dropoutProbabilities;
0099 bool multithreading;
0100 };
0101
0102
0103 void DeclareOptions();
0104 void ProcessOptions();
0105
0106 UInt_t GetNumValidationSamples();
0107
0108
0109 void Init();
0110
0111 Net_t fNet;
0112 DNN::EInitialization fWeightInitialization;
0113 DNN::EOutputFunction fOutputFunction;
0114
0115 TString fLayoutString;
0116 TString fErrorStrategy;
0117 TString fTrainingStrategyString;
0118 TString fWeightInitializationString;
0119 TString fArchitectureString;
0120 TString fValidationSize;
0121 LayoutVector_t fLayout;
0122 std::vector<TTrainingSettings> fTrainingSettings;
0123 bool fResume;
0124
0125 KeyValueVector_t fSettings;
0126
0127 ClassDef(MethodDNN,0);
0128
0129 static inline void WriteMatrixXML(void *parent, const char *name,
0130 const TMatrixT<Double_t> &X);
0131 static inline void ReadMatrixXML(void *xml, const char *name,
0132 TMatrixT<Double_t> &X);
0133 protected:
0134
0135 void MakeClassSpecific( std::ostream&, const TString& ) const;
0136 void GetHelpMessage() const;
0137
0138 public:
0139
0140
0141 MethodDNN(const TString& jobName,
0142 const TString& methodTitle,
0143 DataSetInfo& theData,
0144 const TString& theOption);
0145 MethodDNN(DataSetInfo& theData,
0146 const TString& theWeightFile);
0147 virtual ~MethodDNN();
0148
0149 virtual Bool_t HasAnalysisType(Types::EAnalysisType type,
0150 UInt_t numberClasses,
0151 UInt_t numberTargets );
0152 LayoutVector_t ParseLayoutString(TString layerSpec);
0153 KeyValueVector_t ParseKeyValueString(TString parseString,
0154 TString blockDelim,
0155 TString tokenDelim);
0156 void Train();
0157 void TrainGpu();
0158 void TrainCpu();
0159
0160 virtual Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );
0161 virtual const std::vector<Float_t>& GetRegressionValues();
0162 virtual const std::vector<Float_t>& GetMulticlassValues();
0163
0164 using MethodBase::ReadWeightsFromStream;
0165
0166
0167 void AddWeightsXMLTo ( void* parent ) const;
0168
0169
0170 void ReadWeightsFromStream( std::istream & i );
0171 void ReadWeightsFromXML ( void* wghtnode );
0172
0173
0174 const Ranking* CreateRanking();
0175
0176 };
0177
0178 inline void MethodDNN::WriteMatrixXML(void *parent,
0179 const char *name,
0180 const TMatrixT<Double_t> &X)
0181 {
0182 std::stringstream matrixStringStream("");
0183 matrixStringStream.precision( 16 );
0184
0185 for (size_t i = 0; i < (size_t) X.GetNrows(); i++)
0186 {
0187 for (size_t j = 0; j < (size_t) X.GetNcols(); j++)
0188 {
0189 matrixStringStream << std::scientific << X(i,j) << " ";
0190 }
0191 }
0192 std::string s = matrixStringStream.str();
0193 void* matxml = gTools().xmlengine().NewChild(parent, nullptr, name);
0194 gTools().xmlengine().NewAttr(matxml, nullptr, "rows",
0195 gTools().StringFromInt((int)X.GetNrows()));
0196 gTools().xmlengine().NewAttr(matxml, nullptr, "cols",
0197 gTools().StringFromInt((int)X.GetNcols()));
0198 gTools().xmlengine().AddRawLine (matxml, s.c_str());
0199 }
0200
0201 inline void MethodDNN::ReadMatrixXML(void *xml,
0202 const char *name,
0203 TMatrixT<Double_t> &X)
0204 {
0205 void *matrixXML = gTools().GetChild(xml, name);
0206 size_t rows, cols;
0207 gTools().ReadAttr(matrixXML, "rows", rows);
0208 gTools().ReadAttr(matrixXML, "cols", cols);
0209
0210 const char * matrixString = gTools().xmlengine().GetNodeContent(matrixXML);
0211 std::stringstream matrixStringStream(matrixString);
0212
0213 for (size_t i = 0; i < rows; i++)
0214 {
0215 for (size_t j = 0; j < cols; j++)
0216 {
0217 matrixStringStream >> X(i,j);
0218 }
0219 }
0220 }
0221 }
0222
0223 #endif