Warning: file /include/root/TMVA/MethodRuleFit.h was not indexed,
or was modified since the last indexing (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 #ifndef ROOT_TMVA_MethodRuleFit
0027 #define ROOT_TMVA_MethodRuleFit
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037 #include "TMVA/MethodBase.h"
0038 #include "TMatrixDfwd.h"
0039 #include "TVectorD.h"
0040 #include "TMVA/DecisionTree.h"
0041 #include "TMVA/RuleFit.h"
0042 #include <vector>
0043
0044 namespace TMVA {
0045
0046 class SeparationBase;
0047
0048 class MethodRuleFit : public MethodBase {
0049
0050 public:
0051
0052 MethodRuleFit( const TString& jobName,
0053 const TString& methodTitle,
0054 DataSetInfo& theData,
0055 const TString& theOption = "");
0056
0057 MethodRuleFit( DataSetInfo& theData,
0058 const TString& theWeightFile);
0059
0060 virtual ~MethodRuleFit( void );
0061
0062 Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t ) override;
0063
0064
0065 void Train( void ) override;
0066
0067 using MethodBase::ReadWeightsFromStream;
0068
0069
0070 void AddWeightsXMLTo ( void* parent ) const override;
0071
0072
0073 void ReadWeightsFromStream( std::istream& istr ) override;
0074 void ReadWeightsFromXML ( void* wghtnode ) override;
0075
0076
0077 Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr ) override;
0078
0079
0080 void WriteMonitoringHistosToFile( void ) const override;
0081
0082
0083 const Ranking* CreateRanking() override;
0084
0085 Bool_t UseBoost() const { return fUseBoost; }
0086
0087
0088 RuleFit* GetRuleFitPtr() { return &fRuleFit; }
0089 const RuleFit* GetRuleFitConstPtr() const { return &fRuleFit; }
0090 TDirectory* GetMethodBaseDir() const { return BaseDir(); }
0091 const std::vector<TMVA::Event*>& GetTrainingEvents() const { return fEventSample; }
0092 const std::vector<TMVA::DecisionTree*>& GetForest() const { return fForest; }
0093 Int_t GetNTrees() const { return fNTrees; }
0094 Double_t GetTreeEveFrac() const { return fTreeEveFrac; }
0095 const SeparationBase* GetSeparationBaseConst() const { return fSepType; }
0096 SeparationBase* GetSeparationBase() const { return fSepType; }
0097 TMVA::DecisionTree::EPruneMethod GetPruneMethod() const { return fPruneMethod; }
0098 Double_t GetPruneStrength() const { return fPruneStrength; }
0099 Double_t GetMinFracNEve() const { return fMinFracNEve; }
0100 Double_t GetMaxFracNEve() const { return fMaxFracNEve; }
0101 Int_t GetNCuts() const { return fNCuts; }
0102
0103 Int_t GetGDNPathSteps() const { return fGDNPathSteps; }
0104 Double_t GetGDPathStep() const { return fGDPathStep; }
0105 Double_t GetGDErrScale() const { return fGDErrScale; }
0106 Double_t GetGDPathEveFrac() const { return fGDPathEveFrac; }
0107 Double_t GetGDValidEveFrac() const { return fGDValidEveFrac; }
0108
0109 Double_t GetLinQuantile() const { return fLinQuantile; }
0110
0111 const TString GetRFWorkDir() const { return fRFWorkDir; }
0112 Int_t GetRFNrules() const { return fRFNrules; }
0113 Int_t GetRFNendnodes() const { return fRFNendnodes; }
0114
0115 protected:
0116
0117
0118 void MakeClassSpecific( std::ostream&, const TString& ) const override;
0119
0120 void MakeClassRuleCuts( std::ostream& ) const;
0121
0122 void MakeClassLinear( std::ostream& ) const;
0123
0124
0125 void GetHelpMessage() const override;
0126
0127
0128 void Init( void ) override;
0129
0130
0131 void InitEventSample( void );
0132
0133
0134 void InitMonitorNtuple();
0135
0136 void TrainTMVARuleFit();
0137 void TrainJFRuleFit();
0138
0139 private:
0140
0141
0142 template<typename T>
0143 inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax );
0144
0145 template<typename T>
0146 inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef );
0147
0148 template<typename T>
0149 inline Int_t VerifyRange( const T& var, const T& vmin, const T& vmax );
0150
0151
0152 void DeclareOptions() override;
0153 void ProcessOptions() override;
0154
0155 RuleFit fRuleFit;
0156 std::vector<TMVA::Event *> fEventSample;
0157 Double_t fSignalFraction;
0158
0159
0160 TTree *fMonitorNtuple;
0161 Double_t fNTImportance;
0162 Double_t fNTCoefficient;
0163 Double_t fNTSupport;
0164 Int_t fNTNcuts;
0165 Int_t fNTNvars;
0166 Double_t fNTPtag;
0167 Double_t fNTPss;
0168 Double_t fNTPsb;
0169 Double_t fNTPbs;
0170 Double_t fNTPbb;
0171 Double_t fNTSSB;
0172 Int_t fNTType;
0173
0174
0175 TString fRuleFitModuleS;
0176 Bool_t fUseRuleFitJF;
0177 TString fRFWorkDir;
0178 Int_t fRFNrules;
0179 Int_t fRFNendnodes;
0180 std::vector<DecisionTree *> fForest;
0181 Int_t fNTrees;
0182 Double_t fTreeEveFrac;
0183 SeparationBase *fSepType;
0184 Double_t fMinFracNEve;
0185 Double_t fMaxFracNEve;
0186 Int_t fNCuts;
0187 TString fSepTypeS;
0188 TString fPruneMethodS;
0189 TMVA::DecisionTree::EPruneMethod fPruneMethod;
0190 Double_t fPruneStrength;
0191 TString fForestTypeS;
0192 Bool_t fUseBoost;
0193
0194 Double_t fGDPathEveFrac;
0195 Double_t fGDValidEveFrac;
0196 Double_t fGDTau;
0197 Double_t fGDTauPrec;
0198 Double_t fGDTauMin;
0199 Double_t fGDTauMax;
0200 UInt_t fGDTauScan;
0201 Double_t fGDPathStep;
0202 Int_t fGDNPathSteps;
0203 Double_t fGDErrScale;
0204 Double_t fMinimp;
0205
0206 TString fModelTypeS;
0207 Double_t fRuleMinDist;
0208 Double_t fLinQuantile;
0209
0210 ClassDefOverride(MethodRuleFit,0);
0211 };
0212
0213 }
0214
0215
0216
0217 template<typename T>
0218 inline Int_t TMVA::MethodRuleFit::VerifyRange( const T& var, const T& vmin, const T& vmax )
0219 {
0220
0221 if (var>vmax) return 1;
0222 if (var<vmin) return -1;
0223 return 0;
0224 }
0225
0226
0227 template<typename T>
0228 inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax )
0229 {
0230
0231
0232 Int_t dir = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
0233 Bool_t modif=kFALSE;
0234 if (dir==1) {
0235 modif = kTRUE;
0236 var=vmax;
0237 }
0238 if (dir==-1) {
0239 modif = kTRUE;
0240 var=vmin;
0241 }
0242 if (modif) {
0243 mlog << kWARNING << "Option <" << varstr << "> " << (dir==1 ? "above":"below") << " allowed range. Reset to new value = " << var << Endl;
0244 }
0245 return modif;
0246 }
0247
0248
0249 template<typename T>
0250 inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef )
0251 {
0252
0253
0254 Int_t dir = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
0255 Bool_t modif=kFALSE;
0256 if (dir!=0) {
0257 modif = kTRUE;
0258 var=vdef;
0259 }
0260 if (modif) {
0261 mlog << kWARNING << "Option <" << varstr << "> " << (dir==1 ? "above":"below") << " allowed range. Reset to default value = " << var << Endl;
0262 }
0263 return modif;
0264 }
0265
0266
0267 #endif