File indexing completed on 2025-01-18 10:11:01
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038 #ifndef ROOT_TMVA_MethodMLP
0039 #define ROOT_TMVA_MethodMLP
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049 #include <vector>
0050 #include <utility>
0051 #include "TString.h"
0052 #include "TTree.h"
0053 #include "TRandom3.h"
0054 #include "TH1F.h"
0055 #include "TMatrixDfwd.h"
0056
0057 #include "TMVA/IFitterTarget.h"
0058 #include "TMVA/MethodBase.h"
0059 #include "TMVA/MethodANNBase.h"
0060 #include "TMVA/TNeuron.h"
0061 #include "TMVA/TActivation.h"
0062 #include "TMVA/ConvergenceTest.h"
0063 // Build-time switch for the optional Minuit-based minimiser. Defining and then
// immediately undefining MethodMLP_UseMinuit__ leaves it UNDEFINED, so every
// `#ifdef MethodMLP_UseMinuit__` section of this header (MinuitMinimize,
// GetThisPtr, IFCN, FCN, fNumberOfWeights, fgThis) is compiled out.
// Deleting the #undef line would compile those members in.
// NOTE(review): identifiers containing a double underscore are reserved for the
// implementation in C++; consider renaming (all #ifdef sites must change together).
0064 #define MethodMLP_UseMinuit__
0065 #undef MethodMLP_UseMinuit__
0066
0067 namespace TMVA {
0068 // MethodMLP: a TMVA method derived from MethodANNBase that also implements the
// IFitterTarget and ConvergenceTest interfaces (see the base list below).
// The training back-end is chosen via ETrainingMethod (kBP, kBFGS, kGA); judging
// by the private declarations, these map to BackPropagationMinimize, BFGSMinimize
// and GeneticMinimize respectively — confirm the dispatch in the implementation.
// NOTE(review): declarations that look like base-class overrides (Train(),
// GetMvaValue, Init, DeclareOptions, ProcessOptions, MakeClassSpecific,
// GetHelpMessage, EstimatorFunction, HasAnalysisType) carry no `override`;
// adding it (together with ClassDefOverride) would let the compiler catch
// signature drift against the bases.
0069 class MethodMLP : public MethodANNBase, public IFitterTarget, public ConvergenceTest {
0070
0071 public:
0072
0073 // Constructors: (1) from job name, method title, dataset info and option string; (2) from dataset info and a weight-file name.
0074 MethodMLP( const TString& jobName,
0075 const TString& methodTitle,
0076 DataSetInfo& theData,
0077 const TString& theOption );
0078
0079 MethodMLP( DataSetInfo& theData,
0080 const TString& theWeightFile );
0081
0082 virtual ~MethodMLP();
0083 // Reports whether this method supports the given analysis type with the given numbers of classes and targets.
0084 virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );
0085 // Public training entry point; a private overload Train( Int_t nEpochs ) is declared further down.
0086 void Train();
0087 // Estimator for a given parameter vector; EstimatorFunction presumably implements the IFitterTarget callback — confirm.
0088 Double_t ComputeEstimator ( std::vector<Double_t>& parameters );
0089 Double_t EstimatorFunction( std::vector<Double_t>& parameters );
0090 // Training algorithm (stored in fTrainingMethod) and back-propagation update mode (stored in fBPMode, see also fBatchSize).
0091 enum ETrainingMethod { kBP=0, kBFGS, kGA };
0092 enum EBPTrainingMode { kSequential=0, kBatch };
0093 // HasInverseHessian() just returns fCalculateErrors (NOTE(review): could be const). GetMvaValue's err/errUpper are optional output pointers (default nullptr).
0094 bool HasInverseHessian() { return fCalculateErrors; }
0095 Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );
0096
0097 protected:
0098 // Emits method-specific output to the given stream (name suggests code for a generated standalone class — confirm).
0099
0100 void MakeClassSpecific( std::ostream&, const TString& ) const;
0101 // Help text for this method.
0102
0103 void GetHelpMessage() const;
0104
0105
0106 private:
0107 // Option declaration and processing.
0108
0109 void DeclareOptions();
0110 void ProcessOptions();
0111 // Epoch-count training driver, initialisation and learning-rate setup.
0112
0113 void Train( Int_t nEpochs );
0114 void Init();
0115 void InitializeLearningRates();
0116 // Estimator over the given tree type (default kTraining); the meaning of iEpoch = -1 is not visible here (presumably "no specific epoch").
0117
0118 Double_t CalculateEstimator( Types::ETreeType treeType = Types::kTraining, Int_t iEpoch = -1 );
0119
0120 // --- BFGS minimisation helpers (Hessian approximation, search direction, line search — judging by names) ---
0121 void BFGSMinimize( Int_t nEpochs );
0122 void SetGammaDelta( TMatrixD &Gamma, TMatrixD &Delta, std::vector<Double_t> &Buffer );
0123 void SteepestDir( TMatrixD &Dir );
0124 Bool_t GetHessian( TMatrixD &Hessian, TMatrixD &Gamma, TMatrixD &Delta );
0125 void SetDir( TMatrixD &Hessian, TMatrixD &Dir );
0126 Double_t DerivDir( TMatrixD &Dir );
0127 Bool_t LineSearch( TMatrixD &Dir, std::vector<Double_t> &Buffer, Double_t* dError=nullptr );
0128 void ComputeDEDw();
0129 void SimulateEvent( const Event* ev );
0130 void SetDirWeights( std::vector<Double_t> &Origin, TMatrixD &Dir, Double_t alpha );
0131 Double_t GetError();
0132 Double_t GetMSEErr( const Event* ev, UInt_t index = 0 ); // presumably per-event mean-squared error — confirm
0133 Double_t GetCEErr( const Event* ev, UInt_t index = 0 ); // presumably per-event cross-entropy error — confirm
0134
0135 // --- Back-propagation helpers ---
0136 void BackPropagationMinimize( Int_t nEpochs );
0137 void TrainOneEpoch();
0138 void Shuffle( Int_t* index, Int_t n );
0139 void DecaySynapseWeights(Bool_t lateEpoch );
0140 void TrainOneEvent( Int_t ievt);
0141 Double_t GetDesiredOutput( const Event* ev );
0142 void UpdateNetwork( Double_t desired, Double_t eventWeight=1.0 );
0143 void UpdateNetwork(const std::vector<Float_t>& desired, Double_t eventWeight=1.0);
0144 void CalculateNeuronDeltas();
0145 void UpdateSynapses();
0146 void AdjustSynapseWeights();
0147
0148 // Alternative single-event training step; branchVar and type are passed by non-const reference.
0149 void TrainOneEventFast( Int_t ievt, Float_t*& branchVar, Int_t& type );
0150
0151 // --- Genetic-algorithm minimisation ---
0152 void GeneticMinimize();
0153
0154 // --- Optional Minuit minimisation: compiled only when MethodMLP_UseMinuit__ is defined (the #define/#undef pair near the top of this header leaves it undefined) ---
0155 #ifdef MethodMLP_UseMinuit__
0156 // IFCN is static (C-style callback shape); presumably it forwards to FCN through GetThisPtr()/fgThis — confirm.
0157 void MinuitMinimize();
0158 static MethodMLP* GetThisPtr();
0159 static void IFCN( Int_t& npars, Double_t* grad, Double_t &f, Double_t* fitPars, Int_t ifl );
0160 void FCN( Int_t& npars, Double_t* grad, Double_t &f, Double_t* fitPars, Int_t ifl );
0161 #endif
0162 // --- Regulator / error-estimation state (fCalculateErrors is what HasInverseHessian() returns) ---
0163
0164 bool fUseRegulator;
0165 bool fCalculateErrors;
0166 Double_t fPrior;
0167 std::vector<Double_t> fPriorDev;
0168 void GetApproxInvHessian ( TMatrixD& InvHessian, bool regulate=true );
0169 void UpdateRegulators();
0170 void UpdatePriors();
0171 Int_t fUpdateLimit;
0172 // Selected training method and its option-string form (presumably parsed into fTrainingMethod in ProcessOptions — confirm).
0173 ETrainingMethod fTrainingMethod;
0174 TString fTrainMethodS;
0175 // Event-sampling controls for training and testing.
0176 Float_t fSamplingFraction;
0177 Float_t fSamplingEpoch;
0178 Float_t fSamplingWeight;
0179 Bool_t fSamplingTraining;
0180 Bool_t fSamplingTesting;
0181 // BFGS state (names suggest last line-search step, tau parameter and reset period — confirm).
0182
0183 Double_t fLastAlpha;
0184 Double_t fTau;
0185 Int_t fResetStep;
0186 // Back-propagation parameters; fBpModeS is presumably the string form of fBPMode.
0187
0188 Double_t fLearnRate;
0189 Double_t fDecayRate;
0190 EBPTrainingMode fBPMode;
0191 TString fBpModeS;
0192 Int_t fBatchSize;
0193 Int_t fTestRate;
0194 Bool_t fEpochMon;
0195 // Genetic-algorithm parameters.
0196
0197 Int_t fGA_nsteps;
0198 Int_t fGA_preCalc;
0199 Int_t fGA_SC_steps;
0200 Int_t fGA_SC_rate;
0201 Double_t fGA_SC_factor;
0202 // NOTE(review): raw pointer member; its ownership and lifetime are not visible in this header — check allocation/deletion in the implementation.
0203
0204 std::vector<std::pair<Float_t,Float_t> >* fDeviationsFromTargets;
0205
0206 Float_t fWeightRange;
0207
0208 #ifdef MethodMLP_UseMinuit__
0209 // Minuit-only state; fgThis presumably backs GetThisPtr() for the static IFCN callback.
0210 Int_t fNumberOfWeights;
0211 static MethodMLP* fgThis;
0212 #endif
0213
0214 // Compile-time printing controls.
0215 static const Int_t fgPRINT_ESTIMATOR_INC = 10;
0216 static const Bool_t fgPRINT_SEQ = kFALSE;
0217 static const Bool_t fgPRINT_BATCH = kFALSE;
0218 // ROOT dictionary macro with class version 0 (in ROOT, version 0 conventionally means the class is not streamed to files — confirm).
0219 ClassDef(MethodMLP,0);
0220 };
0221
0222 } // namespace TMVA
0223
0224 #endif