File indexing completed on 2025-01-30 10:23:03
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031 #ifndef ROOT_TMVA_Rule
0032 #define ROOT_TMVA_Rule
0033
0034 #include "TMath.h"
0035 #include <vector>
0036 #include <iostream>
0037
0038 #include "TMVA/DecisionTree.h"
0039 #include "TMVA/Event.h"
0040 #include "TMVA/RuleCut.h"
0041
0042 namespace TMVA {
0043
0044 class RuleEnsemble;
0045 class MsgLogger;
0046 class Rule;
0047
0048 std::ostream& operator<<( std::ostream& os, const Rule & rule );
0049
0050 class Rule {
0051
0052
0053 friend std::ostream& operator<< ( std::ostream& os, const Rule & rule );
0054
0055 public:
0056
0057
0058 Rule( RuleEnsemble *re, const std::vector< const TMVA::Node * > & nodes );
0059
0060
0061 Rule( RuleEnsemble *re );
0062
0063
0064 Rule( const Rule & other ) { Copy( other ); }
0065
0066
0067 Rule();
0068
0069 virtual ~Rule();
0070
0071
0072 void SetMsgType( EMsgType t );
0073
0074
0075 void SetRuleEnsemble( const RuleEnsemble *re ) { fRuleEnsemble = re; }
0076
0077
0078 void SetRuleCut( RuleCut *rc ) { fCut = rc; }
0079
0080
0081 void SetNorm(Double_t norm) { fNorm = (norm>0 ? 1.0/norm:1.0); }
0082
0083
0084 void SetCoefficient(Double_t v) { fCoefficient=v; }
0085
0086
0087 void SetSupport(Double_t v) { fSupport=v; fSigma = TMath::Sqrt(v*(1.0-v));}
0088
0089
0090 void SetSSB(Double_t v) { fSSB=v; }
0091
0092
0093 void SetSSBNeve(Double_t v) { fSSBNeve=v; }
0094
0095
0096 void SetImportanceRef(Double_t v) { fImportanceRef=(v>0 ? v:1.0); }
0097
0098
0099 void CalcImportance() { fImportance = TMath::Abs(fCoefficient)*fSigma; }
0100
0101
0102 Double_t GetRelImportance() const { return fImportance/fImportanceRef; }
0103
0104
0105
0106
0107
0108 inline Bool_t EvalEvent( const Event& e ) const;
0109
0110
0111 Bool_t Equal( const Rule & other, Bool_t useCutValue, Double_t maxdist ) const;
0112
0113
0114 Double_t RuleDist( const Rule & other, Bool_t useCutValue ) const;
0115
0116
0117 Double_t GetSSB() const { return fSSB; }
0118 Double_t GetSSBNeve() const { return fSSBNeve; }
0119 Bool_t IsSignalRule() const { return (fSSB>0.5); }
0120
0121
0122 void operator=( const Rule & other ) { Copy( other ); }
0123
0124
0125 Bool_t operator==( const Rule & other ) const;
0126
0127 Bool_t operator<( const Rule & other ) const;
0128
0129
0130 UInt_t GetNumVarsUsed() const { return fCut->GetNvars(); }
0131
0132
0133 UInt_t GetNcuts() const { return fCut->GetNcuts(); }
0134
0135
0136 Bool_t ContainsVariable(UInt_t iv) const;
0137
0138
0139 const RuleCut* GetRuleCut() const { return fCut; }
0140 const RuleEnsemble* GetRuleEnsemble() const { return fRuleEnsemble; }
0141 Double_t GetCoefficient() const { return fCoefficient; }
0142 Double_t GetSupport() const { return fSupport; }
0143 Double_t GetSigma() const { return fSigma; }
0144 Double_t GetNorm() const { return fNorm; }
0145 Double_t GetImportance() const { return fImportance; }
0146 Double_t GetImportanceRef() const { return fImportanceRef; }
0147
0148
0149 void PrintLogger( const char *title=nullptr ) const;
0150
0151
0152 void PrintRaw ( std::ostream& os ) const;
0153 void* AddXMLTo ( void* parent ) const;
0154
0155 void ReadRaw ( std::istream& os );
0156 void ReadFromXML( void* wghtnode );
0157
0158 private:
0159
0160
0161 void SetSigma(Double_t v) { fSigma=v; }
0162
0163
0164 void Print( std::ostream& os ) const;
0165
0166
0167 void Copy( const Rule & other );
0168
0169
0170 const TString & GetVarName( Int_t i) const;
0171
0172 RuleCut* fCut;
0173 Double_t fNorm;
0174 Double_t fSupport;
0175 Double_t fSigma;
0176 Double_t fCoefficient;
0177 Double_t fImportance;
0178 Double_t fImportanceRef;
0179 const RuleEnsemble* fRuleEnsemble;
0180 Double_t fSSB;
0181 Double_t fSSBNeve;
0182
0183 mutable MsgLogger* fLogger;
0184 MsgLogger& Log() const { return *fLogger; }
0185
0186 };
0187
0188 }
0189
0190
0191 inline Bool_t TMVA::Rule::EvalEvent( const TMVA::Event& e ) const
0192 {
0193
0194
0195
0196 return fCut->EvalEvent(e);
0197 }
0198
0199 #endif