File indexing completed on 2024-11-15 09:56:15
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030 #ifndef ROOT_TMVA_MethodDNN
0031 #define ROOT_TMVA_MethodDNN
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041 #include <vector>
0042 #include <map>
0043 #include <string>
0044 #include <sstream>
0045
0046 #include "TString.h"
0047 #include "TTree.h"
0048 #include "TRandom3.h"
0049 #include "TH1F.h"
0050 #include "TMVA/MethodBase.h"
0051 #include "TMVA/NeuralNet.h"
0052
0053 #include "TMVA/Tools.h"
0054
0055 #include "TMVA/DNN/Net.h"
0056 #include "TMVA/DNN/Minimizers.h"
0057 #include "TMVA/DNN/Architectures/Reference.h"
0058
0059 #ifdef R__HAS_TMVACPU
0060 #define DNNCPU
0061 #endif
0062 #ifdef R__HAS_TMVAGPU
0063
0064 #endif
0065
0066 #ifdef DNNCPU
0067 #include "TMVA/DNN/Architectures/Cpu.h"
0068 #endif
0069
0070 #ifdef DNNCUDA
0071 #include "TMVA/DNN/Architectures/Cuda.h"
0072 #endif
0073
0074 namespace TMVA {
0075
0076 class MethodDNN : public MethodBase
0077 {
0078 friend struct TestMethodDNNValidationSize;
0079
0080 using Architecture_t = DNN::TReference<Float_t>;
0081 using Net_t = DNN::TNet<Architecture_t>;
0082 using Matrix_t = typename Architecture_t::Matrix_t;
0083 using Scalar_t = typename Architecture_t::Scalar_t;
0084
0085 private:
0086 using LayoutVector_t = std::vector<std::pair<int, DNN::EActivationFunction>>;
0087 using KeyValueVector_t = std::vector<std::map<TString, TString>>;
0088
0089 struct TTrainingSettings
0090 {
0091 size_t batchSize;
0092 size_t testInterval;
0093 size_t convergenceSteps;
0094 DNN::ERegularization regularization;
0095 Double_t learningRate;
0096 Double_t momentum;
0097 Double_t weightDecay;
0098 std::vector<Double_t> dropoutProbabilities;
0099 bool multithreading;
0100 };
0101
0102
0103 void DeclareOptions();
0104 void ProcessOptions();
0105
0106 UInt_t GetNumValidationSamples();
0107
0108
0109 void Init();
0110
0111 Net_t fNet;
0112 DNN::EInitialization fWeightInitialization;
0113 DNN::EOutputFunction fOutputFunction;
0114
0115 TString fLayoutString;
0116 TString fErrorStrategy;
0117 TString fTrainingStrategyString;
0118 TString fWeightInitializationString;
0119 TString fArchitectureString;
0120 TString fValidationSize;
0121 LayoutVector_t fLayout;
0122 std::vector<TTrainingSettings> fTrainingSettings;
0123 bool fResume;
0124
0125 KeyValueVector_t fSettings;
0126
0127 ClassDef(MethodDNN,0);
0128
0129 static inline void WriteMatrixXML(void *parent, const char *name,
0130 const TMatrixT<Double_t> &X);
0131 static inline void ReadMatrixXML(void *xml, const char *name,
0132 TMatrixT<Double_t> &X);
0133 protected:
0134
0135 void MakeClassSpecific( std::ostream&, const TString& ) const;
0136 void GetHelpMessage() const;
0137
0138 public:
0139
0140
0141 MethodDNN(const TString& jobName,
0142 const TString& methodTitle,
0143 DataSetInfo& theData,
0144 const TString& theOption);
0145 MethodDNN(DataSetInfo& theData,
0146 const TString& theWeightFile);
0147 virtual ~MethodDNN();
0148
0149 virtual Bool_t HasAnalysisType(Types::EAnalysisType type,
0150 UInt_t numberClasses,
0151 UInt_t numberTargets );
0152 LayoutVector_t ParseLayoutString(TString layerSpec);
0153 KeyValueVector_t ParseKeyValueString(TString parseString,
0154 TString blockDelim,
0155 TString tokenDelim);
0156 void Train();
0157 void TrainGpu();
0158 void TrainCpu();
0159
0160 virtual Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );
0161 virtual const std::vector<Float_t>& GetRegressionValues();
0162 virtual const std::vector<Float_t>& GetMulticlassValues();
0163
0164 using MethodBase::ReadWeightsFromStream;
0165
0166
0167 void AddWeightsXMLTo ( void* parent ) const;
0168
0169
0170 void ReadWeightsFromStream( std::istream & i );
0171 void ReadWeightsFromXML ( void* wghtnode );
0172
0173
0174 const Ranking* CreateRanking();
0175
0176 };
0177
0178 inline void MethodDNN::WriteMatrixXML(void *parent,
0179 const char *name,
0180 const TMatrixT<Double_t> &X)
0181 {
0182 std::stringstream matrixStringStream("");
0183 matrixStringStream.precision( 16 );
0184
0185 for (size_t i = 0; i < (size_t) X.GetNrows(); i++)
0186 {
0187 for (size_t j = 0; j < (size_t) X.GetNcols(); j++)
0188 {
0189 matrixStringStream << std::scientific << X(i,j) << " ";
0190 }
0191 }
0192 std::string s = matrixStringStream.str();
0193 void* matxml = gTools().xmlengine().NewChild(parent, nullptr, name);
0194 gTools().xmlengine().NewAttr(matxml, nullptr, "rows",
0195 gTools().StringFromInt((int)X.GetNrows()));
0196 gTools().xmlengine().NewAttr(matxml, nullptr, "cols",
0197 gTools().StringFromInt((int)X.GetNcols()));
0198 gTools().xmlengine().AddRawLine (matxml, s.c_str());
0199 }
0200
0201 inline void MethodDNN::ReadMatrixXML(void *xml,
0202 const char *name,
0203 TMatrixT<Double_t> &X)
0204 {
0205 void *matrixXML = gTools().GetChild(xml, name);
0206 size_t rows, cols;
0207 gTools().ReadAttr(matrixXML, "rows", rows);
0208 gTools().ReadAttr(matrixXML, "cols", cols);
0209
0210 const char * matrixString = gTools().xmlengine().GetNodeContent(matrixXML);
0211 std::stringstream matrixStringStream(matrixString);
0212
0213 for (size_t i = 0; i < rows; i++)
0214 {
0215 for (size_t j = 0; j < cols; j++)
0216 {
0217 matrixStringStream >> X(i,j);
0218 }
0219 }
0220 }
0221 }
0222
0223 #endif