Warning: file /include/root/TMVA/RuleFit.h was not indexed,
or was modified since it was last indexed (in which case cross-reference links may be missing, inaccurate or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027 #ifndef ROOT_TMVA_RuleFit
0028 #define ROOT_TMVA_RuleFit
0029
0030 #include "TMVA/DecisionTree.h"
0031 #include "TMVA/RuleEnsemble.h"
0032 #include "TMVA/RuleFitParams.h"
0033 #include "TMVA/Event.h"
0034
0035 #include <algorithm>
0036 #include <random>
0037 #include <vector>
0038
0039 namespace TMVA {
0040
0041
0042 class MethodBase;
0043 class MethodRuleFit;
0044 class MsgLogger;
0045
0046 class RuleFit {
0047
0048 public:
0049
0050
0051 RuleFit( const TMVA::MethodBase *rfbase );
0052
0053
0054 RuleFit( void );
0055
0056 virtual ~RuleFit( void );
0057
0058 void InitNEveEff();
0059 void InitPtrs( const TMVA::MethodBase *rfbase );
0060 void Initialize( const TMVA::MethodBase *rfbase );
0061
0062 void SetMsgType( EMsgType t );
0063
0064 void SetTrainingEvents( const std::vector<const TMVA::Event *> & el );
0065
0066 void ReshuffleEvents()
0067 {
0068 std::shuffle(fTrainingEventsRndm.begin(), fTrainingEventsRndm.end(), fRNGEngine);
0069 }
0070
0071 void SetMethodBase( const MethodBase *rfbase );
0072
0073
0074 void MakeForest();
0075
0076
0077 void BuildTree( TMVA::DecisionTree *dt );
0078
0079
0080 void SaveEventWeights();
0081
0082
0083 void RestoreEventWeights();
0084
0085
0086 void Boost( TMVA::DecisionTree *dt );
0087
0088
0089 void ForestStatistics();
0090
0091
0092 Double_t EvalEvent( const Event& e );
0093
0094
0095 Double_t CalcWeightSum( const std::vector<const TMVA::Event *> *events, UInt_t neve=0 );
0096
0097
0098 void FitCoefficients();
0099
0100
0101 void CalcImportance();
0102
0103
0104 void SetModelLinear() { fRuleEnsemble.SetModelLinear(); }
0105
0106 void SetModelRules() { fRuleEnsemble.SetModelRules(); }
0107
0108 void SetModelFull() { fRuleEnsemble.SetModelFull(); }
0109
0110 void SetImportanceCut( Double_t minimp=0 ) { fRuleEnsemble.SetImportanceCut(minimp); }
0111
0112 void SetRuleMinDist( Double_t d ) { fRuleEnsemble.SetRuleMinDist(d); }
0113
0114 void SetGDTau( Double_t t=0.0 ) { fRuleFitParams.SetGDTau(t); }
0115 void SetGDPathStep( Double_t s=0.01 ) { fRuleFitParams.SetGDPathStep(s); }
0116 void SetGDNPathSteps( Int_t n=100 ) { fRuleFitParams.SetGDNPathSteps(n); }
0117
0118 void SetVisHistsUseImp( Bool_t f ) { fVisHistsUseImp = f; }
0119 void UseImportanceVisHists() { fVisHistsUseImp = kTRUE; }
0120 void UseCoefficientsVisHists() { fVisHistsUseImp = kFALSE; }
0121 void MakeVisHists();
0122 void FillVisHistCut(const Rule * rule, std::vector<TH2F *> & hlist);
0123 void FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist);
0124 void FillCut(TH2F* h2,const TMVA::Rule *rule,Int_t vind);
0125 void FillLin(TH2F* h2,Int_t vind);
0126 void FillCorr(TH2F* h2,const TMVA::Rule *rule,Int_t v1, Int_t v2);
0127 void NormVisHists(std::vector<TH2F *> & hlist);
0128 void MakeDebugHists();
0129 Bool_t GetCorrVars(TString & title, TString & var1, TString & var2);
0130
0131 UInt_t GetNTreeSample() const { return fNTreeSample; }
0132 Double_t GetNEveEff() const { return fNEveEffTrain; }
0133 const Event* GetTrainingEvent(UInt_t i) const { return static_cast< const Event *>(fTrainingEvents[i]); }
0134 Double_t GetTrainingEventWeight(UInt_t i) const { return fTrainingEvents[i]->GetWeight(); }
0135
0136
0137
0138 const std::vector< const TMVA::Event * > & GetTrainingEvents() const { return fTrainingEvents; }
0139
0140
0141
0142 void GetRndmSampleEvents(std::vector< const TMVA::Event * > & evevec, UInt_t nevents);
0143
0144 const std::vector< const TMVA::DecisionTree *> & GetForest() const { return fForest; }
0145 const RuleEnsemble & GetRuleEnsemble() const { return fRuleEnsemble; }
0146 RuleEnsemble * GetRuleEnsemblePtr() { return &fRuleEnsemble; }
0147 const RuleFitParams & GetRuleFitParams() const { return fRuleFitParams; }
0148 RuleFitParams * GetRuleFitParamsPtr() { return &fRuleFitParams; }
0149 const MethodRuleFit * GetMethodRuleFit() const { return fMethodRuleFit; }
0150 const MethodBase * GetMethodBase() const { return fMethodBase; }
0151
0152 private:
0153
0154
0155 RuleFit( const RuleFit & other );
0156
0157
0158 void Copy( const RuleFit & other );
0159
0160 std::vector<const TMVA::Event *> fTrainingEvents;
0161 std::vector<const TMVA::Event *> fTrainingEventsRndm;
0162 std::vector<Double_t> fEventWeights;
0163 UInt_t fNTreeSample;
0164
0165 Double_t fNEveEffTrain;
0166 std::vector< const TMVA::DecisionTree *> fForest;
0167 RuleEnsemble fRuleEnsemble;
0168 RuleFitParams fRuleFitParams;
0169 const MethodRuleFit *fMethodRuleFit;
0170 const MethodBase *fMethodBase;
0171 Bool_t fVisHistsUseImp;
0172
0173 mutable MsgLogger* fLogger;
0174 MsgLogger& Log() const { return *fLogger; }
0175
0176 static const Int_t randSEED = 0;
0177 std::default_random_engine fRNGEngine;
0178
0179 ClassDef(RuleFit,0);
0180 };
0181 }
0182
0183 #endif