// File indexing completed on 2025-01-30 10:23:03
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027 #ifndef ROOT_TMVA_RuleFit
0028 #define ROOT_TMVA_RuleFit
0029
0030 #include "TMVA/DecisionTree.h"
0031 #include "TMVA/RuleEnsemble.h"
0032 #include "TMVA/RuleFitParams.h"
0033 #include "TMVA/Event.h"
0034
0035 #include <algorithm>
0036 #include <random>
0037 #include <vector>
0038
0039 namespace TMVA {
0040
0041
0042 class MethodBase;
0043 class MethodRuleFit;
0044 class MsgLogger;
0045
0046 class RuleFit {
0047
0048 public:
0049
0050
0051 RuleFit( const TMVA::MethodBase *rfbase );
0052
0053
0054 RuleFit( void );
0055
0056 virtual ~RuleFit( void );
0057
0058 void InitNEveEff();
0059 void InitPtrs( const TMVA::MethodBase *rfbase );
0060 void Initialize( const TMVA::MethodBase *rfbase );
0061
0062 void SetMsgType( EMsgType t );
0063
0064 void SetTrainingEvents( const std::vector<const TMVA::Event *> & el );
0065
0066 void ReshuffleEvents()
0067 {
0068 std::shuffle(fTrainingEventsRndm.begin(), fTrainingEventsRndm.end(), fRNGEngine);
0069 }
0070
0071 void SetMethodBase( const MethodBase *rfbase );
0072
0073
0074 void MakeForest();
0075
0076
0077 void BuildTree( TMVA::DecisionTree *dt );
0078
0079
0080 void SaveEventWeights();
0081
0082
0083 void RestoreEventWeights();
0084
0085
0086 void Boost( TMVA::DecisionTree *dt );
0087
0088
0089 void ForestStatistics();
0090
0091
0092 Double_t EvalEvent( const Event& e );
0093
0094
0095 Double_t CalcWeightSum( const std::vector<const TMVA::Event *> *events, UInt_t neve=0 );
0096
0097
0098 void FitCoefficients();
0099
0100
0101 void CalcImportance();
0102
0103
0104 void SetModelLinear() { fRuleEnsemble.SetModelLinear(); }
0105
0106 void SetModelRules() { fRuleEnsemble.SetModelRules(); }
0107
0108 void SetModelFull() { fRuleEnsemble.SetModelFull(); }
0109
0110 void SetImportanceCut( Double_t minimp=0 ) { fRuleEnsemble.SetImportanceCut(minimp); }
0111
0112 void SetRuleMinDist( Double_t d ) { fRuleEnsemble.SetRuleMinDist(d); }
0113
0114 void SetGDTau( Double_t t=0.0 ) { fRuleFitParams.SetGDTau(t); }
0115 void SetGDPathStep( Double_t s=0.01 ) { fRuleFitParams.SetGDPathStep(s); }
0116 void SetGDNPathSteps( Int_t n=100 ) { fRuleFitParams.SetGDNPathSteps(n); }
0117
0118 void SetVisHistsUseImp( Bool_t f ) { fVisHistsUseImp = f; }
0119 void UseImportanceVisHists() { fVisHistsUseImp = kTRUE; }
0120 void UseCoefficientsVisHists() { fVisHistsUseImp = kFALSE; }
0121 void MakeVisHists();
0122 void FillVisHistCut(const Rule * rule, std::vector<TH2F *> & hlist);
0123 void FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist);
0124 void FillCut(TH2F* h2,const TMVA::Rule *rule,Int_t vind);
0125 void FillLin(TH2F* h2,Int_t vind);
0126 void FillCorr(TH2F* h2,const TMVA::Rule *rule,Int_t v1, Int_t v2);
0127 void NormVisHists(std::vector<TH2F *> & hlist);
0128 void MakeDebugHists();
0129 Bool_t GetCorrVars(TString & title, TString & var1, TString & var2);
0130
0131 UInt_t GetNTreeSample() const { return fNTreeSample; }
0132 Double_t GetNEveEff() const { return fNEveEffTrain; }
0133 const Event* GetTrainingEvent(UInt_t i) const { return static_cast< const Event *>(fTrainingEvents[i]); }
0134 Double_t GetTrainingEventWeight(UInt_t i) const { return fTrainingEvents[i]->GetWeight(); }
0135
0136
0137
0138 const std::vector< const TMVA::Event * > & GetTrainingEvents() const { return fTrainingEvents; }
0139
0140
0141
0142 void GetRndmSampleEvents(std::vector< const TMVA::Event * > & evevec, UInt_t nevents);
0143
0144 const std::vector< const TMVA::DecisionTree *> & GetForest() const { return fForest; }
0145 const RuleEnsemble & GetRuleEnsemble() const { return fRuleEnsemble; }
0146 RuleEnsemble * GetRuleEnsemblePtr() { return &fRuleEnsemble; }
0147 const RuleFitParams & GetRuleFitParams() const { return fRuleFitParams; }
0148 RuleFitParams * GetRuleFitParamsPtr() { return &fRuleFitParams; }
0149 const MethodRuleFit * GetMethodRuleFit() const { return fMethodRuleFit; }
0150 const MethodBase * GetMethodBase() const { return fMethodBase; }
0151
0152 private:
0153
0154
0155 RuleFit( const RuleFit & other );
0156
0157
0158 void Copy( const RuleFit & other );
0159
0160 std::vector<const TMVA::Event *> fTrainingEvents;
0161 std::vector<const TMVA::Event *> fTrainingEventsRndm;
0162 std::vector<Double_t> fEventWeights;
0163 UInt_t fNTreeSample;
0164
0165 Double_t fNEveEffTrain;
0166 std::vector< const TMVA::DecisionTree *> fForest;
0167 RuleEnsemble fRuleEnsemble;
0168 RuleFitParams fRuleFitParams;
0169 const MethodRuleFit *fMethodRuleFit;
0170 const MethodBase *fMethodBase;
0171 Bool_t fVisHistsUseImp;
0172
0173 mutable MsgLogger* fLogger;
0174 MsgLogger& Log() const { return *fLogger; }
0175
0176 static const Int_t randSEED = 0;
0177 std::default_random_engine fRNGEngine;
0178
0179 ClassDef(RuleFit,0);
0180 };
0181 }
0182
0183 #endif