File indexing completed on 2025-01-18 10:11:01
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 #ifndef ROOT_TMVA_MethodRuleFit
0027 #define ROOT_TMVA_MethodRuleFit
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037 #include "TMVA/MethodBase.h"
0038 #include "TMatrixDfwd.h"
0039 #include "TVectorD.h"
0040 #include "TMVA/DecisionTree.h"
0041 #include "TMVA/RuleFit.h"
0042 #include <vector>
0043
0044 namespace TMVA {
0045
0046 class SeparationBase;
0047
0048 class MethodRuleFit : public MethodBase {
0049 // MVA method derived from MethodBase. It holds a RuleFit object (fRuleFit), a vector of DecisionTree pointers (fForest) and the option/monitoring members declared below.
0050 public:
0051 // Constructor taking a job name, a method title, the dataset info and an option string (empty by default).
0052 MethodRuleFit( const TString& jobName,
0053 const TString& methodTitle,
0054 DataSetInfo& theData,
0055 const TString& theOption = "");
0056 // Constructor taking the dataset info and a weight-file string (theWeightFile).
0057 MethodRuleFit( DataSetInfo& theData,
0058 const TString& theWeightFile);
0059 // Virtual destructor.
0060 virtual ~MethodRuleFit( void );
0061 // Analysis-type query; the third UInt_t parameter is unnamed in this declaration.
0062 virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t );
0063
0064 // Training entry point (defined outside this header).
0065 void Train( void );
0066 // Keeps the ReadWeightsFromStream overloads of MethodBase visible next to the one declared below.
0067 using MethodBase::ReadWeightsFromStream;
0068
0069 // Weight output; the parent XML node is passed as an opaque void*.
0070 void AddWeightsXMLTo ( void* parent ) const;
0071
0072 // Weight input, either from a text stream or from an XML node passed as an opaque void*.
0073 void ReadWeightsFromStream( std::istream& istr );
0074 void ReadWeightsFromXML ( void* wghtnode );
0075
0076 // Returns the MVA value; both error pointers are optional and default to nullptr.
0077 Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );
0078
0079 // Const member taking no arguments. NOTE(review): presumably writes the monitoring histograms to file - confirm in the implementation.
0080 void WriteMonitoringHistosToFile( void ) const;
0081
0082 // Returns a pointer to a const Ranking; ownership of that object is not visible from this header.
0083 const Ranking* CreateRanking();
0084 // Returns the fUseBoost flag.
0085 Bool_t UseBoost() const { return fUseBoost; }
0086
0087 // Inline accessors. Each returns the data member named in its body; GetMethodBaseDir() forwards to BaseDir().
0088 RuleFit* GetRuleFitPtr() { return &fRuleFit; }
0089 const RuleFit* GetRuleFitConstPtr() const { return &fRuleFit; }
0090 TDirectory* GetMethodBaseDir() const { return BaseDir(); }
0091 const std::vector<TMVA::Event*>& GetTrainingEvents() const { return fEventSample; }
0092 const std::vector<TMVA::DecisionTree*>& GetForest() const { return fForest; }
0093 Int_t GetNTrees() const { return fNTrees; }
0094 Double_t GetTreeEveFrac() const { return fTreeEveFrac; }
0095 const SeparationBase* GetSeparationBaseConst() const { return fSepType; }
0096 SeparationBase* GetSeparationBase() const { return fSepType; }
0097 TMVA::DecisionTree::EPruneMethod GetPruneMethod() const { return fPruneMethod; }
0098 Double_t GetPruneStrength() const { return fPruneStrength; }
0099 Double_t GetMinFracNEve() const { return fMinFracNEve; }
0100 Double_t GetMaxFracNEve() const { return fMaxFracNEve; }
0101 Int_t GetNCuts() const { return fNCuts; }
0102 // Accessors for a subset of the fGD* members.
0103 Int_t GetGDNPathSteps() const { return fGDNPathSteps; }
0104 Double_t GetGDPathStep() const { return fGDPathStep; }
0105 Double_t GetGDErrScale() const { return fGDErrScale; }
0106 Double_t GetGDPathEveFrac() const { return fGDPathEveFrac; }
0107 Double_t GetGDValidEveFrac() const { return fGDValidEveFrac; }
0108 // Accessor for fLinQuantile.
0109 Double_t GetLinQuantile() const { return fLinQuantile; }
0110 // Accessors for the fRF* members; GetRFWorkDir() returns the string by value (a copy).
0111 const TString GetRFWorkDir() const { return fRFWorkDir; }
0112 Int_t GetRFNrules() const { return fRFNrules; }
0113 Int_t GetRFNendnodes() const { return fRFNendnodes; }
0114
0115 protected:
0116
0117 // MakeClass* helpers: each receives an output stream. NOTE(review): presumably they emit standalone C++ code for the trained method - confirm in the implementation.
0118 void MakeClassSpecific( std::ostream&, const TString& ) const;
0119
0120 void MakeClassRuleCuts( std::ostream& ) const;
0121
0122 void MakeClassLinear( std::ostream& ) const;
0123
0124 // Help-message hook (const, no arguments).
0125 void GetHelpMessage() const;
0126
0127 // Initialisation hook.
0128 void Init( void );
0129
0130 // NOTE(review): presumably fills fEventSample, the vector exposed by GetTrainingEvents() - confirm.
0131 void InitEventSample( void );
0132
0133 // NOTE(review): presumably sets up fMonitorNtuple using the fNT* members below - confirm.
0134 void InitMonitorNtuple();
0135 // Two separate training routines. NOTE(review): the choice between them presumably depends on fUseRuleFitJF - confirm in Train().
0136 void TrainTMVARuleFit();
0137 void TrainJFRuleFit();
0138
0139 private:
0140
0141 // Clamps var to vmax or vmin when it lies outside [vmin,vmax]; if var was changed, logs a warning naming varstr to mlog and returns kTRUE, otherwise returns kFALSE. Defined after the class in this header.
0142 template<typename T>
0143 inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax );
0144 // Same check, but an out-of-range var is reset to vdef instead of being clamped to the violated bound.
0145 template<typename T>
0146 inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef );
0147 // Check only, var is not modified: returns 1 if var>vmax, -1 if var<vmin, 0 otherwise.
0148 template<typename T>
0149 inline Int_t VerifyRange( const T& var, const T& vmin, const T& vmax );
0150
0151 // Option handling hooks.
0152 void DeclareOptions();
0153 void ProcessOptions();
0154 // Core state: the RuleFit object (held by value), the event sample (vector of Event pointers) and the signal fraction.
0155 RuleFit fRuleFit;
0156 std::vector<TMVA::Event *> fEventSample;
0157 Double_t fSignalFraction;
0158
0159 // Monitoring ntuple pointer followed by the fNT* scalars. NOTE(review): the fNT* members are presumably the per-entry buffers of that ntuple - confirm.
0160 TTree *fMonitorNtuple;
0161 Double_t fNTImportance;
0162 Double_t fNTCoefficient;
0163 Double_t fNTSupport;
0164 Int_t fNTNcuts;
0165 Int_t fNTNvars;
0166 Double_t fNTPtag;
0167 Double_t fNTPss;
0168 Double_t fNTPsb;
0169 Double_t fNTPbs;
0170 Double_t fNTPbb;
0171 Double_t fNTSSB;
0172 Int_t fNTType;
0173
0174 // Configuration members. Several have a TString twin with suffix S; NOTE(review): presumably the raw option text that ProcessOptions() decodes - confirm.
0175 TString fRuleFitModuleS;
0176 Bool_t fUseRuleFitJF;
0177 TString fRFWorkDir;
0178 Int_t fRFNrules;
0179 Int_t fRFNendnodes;
0180 std::vector<DecisionTree *> fForest;
0181 Int_t fNTrees;
0182 Double_t fTreeEveFrac;
0183 SeparationBase *fSepType;
0184 Double_t fMinFracNEve;
0185 Double_t fMaxFracNEve;
0186 Int_t fNCuts;
0187 TString fSepTypeS;
0188 TString fPruneMethodS;
0189 TMVA::DecisionTree::EPruneMethod fPruneMethod;
0190 Double_t fPruneStrength;
0191 TString fForestTypeS;
0192 Bool_t fUseBoost;
0193 // fGD* members (some are exposed through the GetGD* accessors above) and fMinimp.
0194 Double_t fGDPathEveFrac;
0195 Double_t fGDValidEveFrac;
0196 Double_t fGDTau;
0197 Double_t fGDTauPrec;
0198 Double_t fGDTauMin;
0199 Double_t fGDTauMax;
0200 UInt_t fGDTauScan;
0201 Double_t fGDPathStep;
0202 Int_t fGDNPathSteps;
0203 Double_t fGDErrScale;
0204 Double_t fMinimp;
0205 // Model-type string, fRuleMinDist and fLinQuantile (the latter is returned by GetLinQuantile()).
0206 TString fModelTypeS;
0207 Double_t fRuleMinDist;
0208 Double_t fLinQuantile;
0209 // ROOT ClassDef macro; the second argument (class version) is 0.
0210 ClassDef(MethodRuleFit,0);
0211 };
0212
0213 }
0214
0215
0216
0217 template<typename T>
0218 inline Int_t TMVA::MethodRuleFit::VerifyRange( const T& var, const T& vmin, const T& vmax )
0219 {
0220 // Locates var relative to the closed interval [vmin,vmax] without modifying it.
0221 // Result: +1 if var is above vmax, -1 if it is below vmin, 0 otherwise.
0222 // The upper bound is tested first; the lower bound is only compared when that test fails.
0223 return (var>vmax) ? 1 : ((var<vmin) ? -1 : 0);
0224 }
0225
0226
0227 template<typename T>
0228 inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax )
0229 {
0230 // Forces var into [vmin,vmax]. If var had to be moved onto a bound, a warning
0231 // naming the option (varstr) is sent to mlog and kTRUE is returned.
0232 const Int_t side = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
0233
0234 // Already inside the allowed range: leave var untouched and emit nothing.
0235 if (side==0) return kFALSE;
0236
0237 // side is +1 (above vmax) or -1 (below vmin): snap var to the violated bound.
0238 const Bool_t above = (side==1);
0239 if (above) {
0240 var=vmax;
0241 } else {
0242 var=vmin;
0243 }
0244 mlog << kWARNING << "Option <" << varstr << "> " << (above ? "above":"below") << " allowed range. Reset to new value = " << var << Endl;
0245 return kTRUE;
0246 }
0247
0248
0249 template<typename T>
0250 inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef )
0251 {
0252 // Range check with fallback: an out-of-range var is replaced by vdef (it is not
0253 // clamped to a bound). In that case a warning goes to mlog and kTRUE is returned.
0254 const Int_t side = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
0255
0256 // Within [vmin,vmax]: nothing to change, nothing to report.
0257 if (side==0) return kFALSE;
0258
0259 // Outside the range on either side: fall back to the default value, then
0260 // report which side was violated together with the value now in use.
0261 var=vdef;
0262 mlog << kWARNING << "Option <" << varstr << "> " << (side==1 ? "above":"below") << " allowed range. Reset to default value = " << var << Endl;
0263 return kTRUE;
0264 }
0265
0266
0267 #endif