File indexing completed on 2025-01-30 10:22:51
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030 #ifndef ROOT_TMVA_DecisionTreeNode
0031 #define ROOT_TMVA_DecisionTreeNode
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041 #include "TMVA/Node.h"
0042
0043 #include "TMVA/Version.h"
0044
0045 #include <sstream>
0046 #include <vector>
0047 #include <string>
0048
0049 namespace TMVA {
0050
0051 class DTNodeTrainingInfo
0052 {
0053 public:
0054 DTNodeTrainingInfo():fSampleMin(),
0055 fSampleMax(),
0056 fNodeR(0),fSubTreeR(0),fAlpha(0),fG(0),fNTerminal(0),
0057 fNB(0),fNS(0),fSumTarget(0),fSumTarget2(0),fCC(0),
0058 fNSigEvents ( 0 ), fNBkgEvents ( 0 ),
0059 fNEvents ( -1 ),
0060 fNSigEvents_unweighted ( 0 ),
0061 fNBkgEvents_unweighted ( 0 ),
0062 fNEvents_unweighted ( 0 ),
0063 fNSigEvents_unboosted ( 0 ),
0064 fNBkgEvents_unboosted ( 0 ),
0065 fNEvents_unboosted ( 0 ),
0066 fSeparationIndex (-1 ),
0067 fSeparationGain ( -1 )
0068 {
0069 }
0070 std::vector< Float_t > fSampleMin;
0071 std::vector< Float_t > fSampleMax;
0072 Double_t fNodeR;
0073 Double_t fSubTreeR;
0074 Double_t fAlpha;
0075 Double_t fG;
0076 Int_t fNTerminal;
0077 Double_t fNB;
0078 Double_t fNS;
0079 Float_t fSumTarget;
0080 Float_t fSumTarget2;
0081 Double_t fCC;
0082
0083 Float_t fNSigEvents;
0084 Float_t fNBkgEvents;
0085 Float_t fNEvents;
0086 Float_t fNSigEvents_unweighted;
0087 Float_t fNBkgEvents_unweighted;
0088 Float_t fNEvents_unweighted;
0089 Float_t fNSigEvents_unboosted;
0090 Float_t fNBkgEvents_unboosted;
0091 Float_t fNEvents_unboosted;
0092 Float_t fSeparationIndex;
0093 Float_t fSeparationGain;
0094
0095
0096 DTNodeTrainingInfo(const DTNodeTrainingInfo& n) :
0097 fSampleMin(),fSampleMax(),
0098 fNodeR(n.fNodeR), fSubTreeR(n.fSubTreeR),
0099 fAlpha(n.fAlpha), fG(n.fG),
0100 fNTerminal(n.fNTerminal),
0101 fNB(n.fNB), fNS(n.fNS),
0102 fSumTarget(0),fSumTarget2(0),
0103 fCC(0),
0104 fNSigEvents ( n.fNSigEvents ), fNBkgEvents ( n.fNBkgEvents ),
0105 fNEvents ( n.fNEvents ),
0106 fNSigEvents_unweighted ( n.fNSigEvents_unweighted ),
0107 fNBkgEvents_unweighted ( n.fNBkgEvents_unweighted ),
0108 fNEvents_unweighted ( n.fNEvents_unweighted ),
0109 fSeparationIndex( n.fSeparationIndex ),
0110 fSeparationGain ( n.fSeparationGain )
0111 { }
0112 };
0113
0114 class Event;
0115 class MsgLogger;
0116
0117 class DecisionTreeNode: public Node {
0118
0119 public:
0120
0121
0122 DecisionTreeNode ();
0123
0124 DecisionTreeNode (Node* p, char pos);
0125
0126
0127 DecisionTreeNode (const DecisionTreeNode &n, DecisionTreeNode* parent = nullptr);
0128
0129
0130 virtual ~DecisionTreeNode();
0131
0132 virtual Node* CreateNode() const { return new DecisionTreeNode(); }
0133
0134 inline void SetNFisherCoeff(Int_t nvars){fFisherCoeff.resize(nvars);}
0135 inline UInt_t GetNFisherCoeff() const { return fFisherCoeff.size();}
0136
0137 void SetFisherCoeff(Int_t ivar, Double_t coeff);
0138
0139 Double_t GetFisherCoeff(Int_t ivar) const {return fFisherCoeff.at(ivar);}
0140
0141
0142 virtual Bool_t GoesRight( const Event & ) const;
0143
0144
0145 virtual Bool_t GoesLeft ( const Event & ) const;
0146
0147
0148 void SetSelector( Short_t i) { fSelector = i; }
0149
0150 Short_t GetSelector() const { return fSelector; }
0151
0152
0153 void SetCutValue ( Float_t c ) { fCutValue = c; }
0154
0155 Float_t GetCutValue ( void ) const { return fCutValue; }
0156
0157
0158 void SetCutType( Bool_t t ) { fCutType = t; }
0159
0160 Bool_t GetCutType( void ) const { return fCutType; }
0161
0162
0163 void SetNodeType( Int_t t ) { fNodeType = t;}
0164
0165 Int_t GetNodeType( void ) const { return fNodeType; }
0166
0167
0168 Float_t GetPurity( void ) const { return fPurity;}
0169
0170 void SetPurity( void );
0171
0172
0173 void SetResponse( Float_t r ) { fResponse = r;}
0174
0175
0176 Float_t GetResponse( void ) const { return fResponse;}
0177
0178
0179 void SetRMS( Float_t r ) { fRMS = r;}
0180
0181
0182 Float_t GetRMS( void ) const { return fRMS;}
0183
0184
0185 void SetNSigEvents( Float_t s ) { if(fTrainInfo) fTrainInfo->fNSigEvents = s; }
0186
0187
0188 void SetNBkgEvents( Float_t b ) { if(fTrainInfo) fTrainInfo->fNBkgEvents = b; }
0189
0190
0191 void SetNEvents( Float_t nev ){ if(fTrainInfo) fTrainInfo->fNEvents =nev ; }
0192
0193
0194 void SetNSigEvents_unweighted( Float_t s ) { if(fTrainInfo) fTrainInfo->fNSigEvents_unweighted = s; }
0195
0196
0197 void SetNBkgEvents_unweighted( Float_t b ) { if(fTrainInfo) fTrainInfo->fNBkgEvents_unweighted = b; }
0198
0199
0200 void SetNEvents_unweighted( Float_t nev ){ if(fTrainInfo) fTrainInfo->fNEvents_unweighted =nev ; }
0201
0202
0203 void SetNSigEvents_unboosted( Float_t s ) { if(fTrainInfo) fTrainInfo->fNSigEvents_unboosted = s; }
0204
0205
0206 void SetNBkgEvents_unboosted( Float_t b ) { if(fTrainInfo) fTrainInfo->fNBkgEvents_unboosted = b; }
0207
0208
0209 void SetNEvents_unboosted( Float_t nev ){ if(fTrainInfo) fTrainInfo->fNEvents_unboosted =nev ; }
0210
0211
0212 void IncrementNSigEvents( Float_t s ) { if(fTrainInfo) fTrainInfo->fNSigEvents += s; }
0213
0214
0215 void IncrementNBkgEvents( Float_t b ) { if(fTrainInfo) fTrainInfo->fNBkgEvents += b; }
0216
0217
0218 void IncrementNEvents( Float_t nev ){ if(fTrainInfo) fTrainInfo->fNEvents +=nev ; }
0219
0220
0221 void IncrementNSigEvents_unweighted( ) { if(fTrainInfo) fTrainInfo->fNSigEvents_unweighted += 1; }
0222
0223
0224 void IncrementNBkgEvents_unweighted( ) { if(fTrainInfo) fTrainInfo->fNBkgEvents_unweighted += 1; }
0225
0226
0227 void IncrementNEvents_unweighted( ){ if(fTrainInfo) fTrainInfo->fNEvents_unweighted +=1 ; }
0228
0229
0230 Float_t GetNSigEvents( void ) const { return fTrainInfo ? fTrainInfo->fNSigEvents : -1.; }
0231
0232
0233 Float_t GetNBkgEvents( void ) const { return fTrainInfo ? fTrainInfo->fNBkgEvents : -1.; }
0234
0235
0236 Float_t GetNEvents( void ) const { return fTrainInfo ? fTrainInfo->fNEvents : -1.; }
0237
0238
0239 Float_t GetNSigEvents_unweighted( void ) const { return fTrainInfo ? fTrainInfo->fNSigEvents_unweighted : -1.; }
0240
0241
0242 Float_t GetNBkgEvents_unweighted( void ) const { return fTrainInfo ? fTrainInfo->fNBkgEvents_unweighted : -1.; }
0243
0244
0245 Float_t GetNEvents_unweighted( void ) const { return fTrainInfo ? fTrainInfo->fNEvents_unweighted : -1.; }
0246
0247
0248 Float_t GetNSigEvents_unboosted( void ) const { return fTrainInfo ? fTrainInfo->fNSigEvents_unboosted : -1.; }
0249
0250
0251 Float_t GetNBkgEvents_unboosted( void ) const { return fTrainInfo ? fTrainInfo->fNBkgEvents_unboosted : -1.; }
0252
0253
0254 Float_t GetNEvents_unboosted( void ) const { return fTrainInfo ? fTrainInfo->fNEvents_unboosted : -1.; }
0255
0256
0257 void SetSeparationIndex( Float_t sep ){ if(fTrainInfo) fTrainInfo->fSeparationIndex =sep ; }
0258
0259
0260 Float_t GetSeparationIndex( void ) const { return fTrainInfo ? fTrainInfo->fSeparationIndex : -1.; }
0261
0262
0263 void SetSeparationGain( Float_t sep ){ if(fTrainInfo) fTrainInfo->fSeparationGain =sep ; }
0264
0265
0266 Float_t GetSeparationGain( void ) const { return fTrainInfo ? fTrainInfo->fSeparationGain : -1.; }
0267
0268
0269 virtual void Print( std::ostream& os ) const;
0270
0271
0272 virtual void PrintRec( std::ostream& os ) const;
0273
0274 virtual void AddAttributesToNode(void* node) const;
0275 virtual void AddContentToNode(std::stringstream& s) const;
0276
0277
0278 void ClearNodeAndAllDaughters();
0279
0280
0281
0282
0283 inline virtual DecisionTreeNode* GetLeft( ) const { return static_cast<DecisionTreeNode*>(fLeft); }
0284 inline virtual DecisionTreeNode* GetRight( ) const { return static_cast<DecisionTreeNode*>(fRight); }
0285 inline virtual DecisionTreeNode* GetParent( ) const { return static_cast<DecisionTreeNode*>(fParent); }
0286
0287
0288 inline virtual void SetLeft (Node* l) { fLeft = l;}
0289 inline virtual void SetRight (Node* r) { fRight = r;}
0290 inline virtual void SetParent(Node* p) { fParent = p;}
0291
0292
0293 inline void SetNodeR( Double_t r ) { if(fTrainInfo) fTrainInfo->fNodeR = r; }
0294
0295 inline Double_t GetNodeR( ) const { return fTrainInfo ? fTrainInfo->fNodeR : -1.; }
0296
0297
0298 inline void SetSubTreeR( Double_t r ) { if(fTrainInfo) fTrainInfo->fSubTreeR = r; }
0299
0300 inline Double_t GetSubTreeR( ) const { return fTrainInfo ? fTrainInfo->fSubTreeR : -1.; }
0301
0302
0303
0304
0305
0306 inline void SetAlpha( Double_t alpha ) { if(fTrainInfo) fTrainInfo->fAlpha = alpha; }
0307
0308 inline Double_t GetAlpha( ) const { return fTrainInfo ? fTrainInfo->fAlpha : -1.; }
0309
0310
0311 inline void SetAlphaMinSubtree( Double_t g ) { if(fTrainInfo) fTrainInfo->fG = g; }
0312
0313 inline Double_t GetAlphaMinSubtree( ) const { return fTrainInfo ? fTrainInfo->fG : -1.; }
0314
0315
0316 inline void SetNTerminal( Int_t n ) { if(fTrainInfo) fTrainInfo->fNTerminal = n; }
0317
0318 inline Int_t GetNTerminal( ) const { return fTrainInfo ? fTrainInfo->fNTerminal : -1.; }
0319
0320
0321 inline void SetNBValidation( Double_t b ) { if(fTrainInfo) fTrainInfo->fNB = b; }
0322
0323 inline void SetNSValidation( Double_t s ) { if(fTrainInfo) fTrainInfo->fNS = s; }
0324
0325 inline Double_t GetNBValidation( ) const { return fTrainInfo ? fTrainInfo->fNB : -1.; }
0326
0327 inline Double_t GetNSValidation( ) const { return fTrainInfo ? fTrainInfo->fNS : -1.; }
0328
0329
0330 inline void SetSumTarget(Float_t t) {if(fTrainInfo) fTrainInfo->fSumTarget = t; }
0331
0332 inline void SetSumTarget2(Float_t t2){if(fTrainInfo) fTrainInfo->fSumTarget2 = t2; }
0333
0334
0335 inline void AddToSumTarget(Float_t t) {if(fTrainInfo) fTrainInfo->fSumTarget += t; }
0336
0337 inline void AddToSumTarget2(Float_t t2){if(fTrainInfo) fTrainInfo->fSumTarget2 += t2; }
0338
0339
0340 inline Float_t GetSumTarget() const {return fTrainInfo? fTrainInfo->fSumTarget : -9999;}
0341
0342 inline Float_t GetSumTarget2() const {return fTrainInfo? fTrainInfo->fSumTarget2: -9999;}
0343
0344
0345
0346 void ResetValidationData( );
0347
0348
0349 inline Bool_t IsTerminal() const { return fIsTerminalNode; }
0350 inline void SetTerminal( Bool_t s = kTRUE ) { fIsTerminalNode = s; }
0351 void PrintPrune( std::ostream& os ) const ;
0352 void PrintRecPrune( std::ostream& os ) const;
0353
0354 void SetCC(Double_t cc);
0355
0356 Double_t GetCC() const {return (fTrainInfo? fTrainInfo->fCC : -1.);}
0357
0358 Float_t GetSampleMin(UInt_t ivar) const;
0359 Float_t GetSampleMax(UInt_t ivar) const;
0360 void SetSampleMin(UInt_t ivar, Float_t xmin);
0361 void SetSampleMax(UInt_t ivar, Float_t xmax);
0362
0363 static void SetIsTraining(bool on);
0364 static void SetTmvaVersionCode(UInt_t code);
0365
0366 static bool IsTraining();
0367 static UInt_t GetTmvaVersionCode();
0368
0369 virtual Bool_t ReadDataRecord( std::istream& is, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
0370 virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
0371 virtual void ReadContent(std::stringstream& s);
0372
0373 protected:
0374
0375 static MsgLogger& Log();
0376
0377 static bool fgIsTraining;
0378 static UInt_t fgTmva_Version_Code;
0379
0380 std::vector<Double_t> fFisherCoeff;
0381
0382 Float_t fCutValue;
0383 Bool_t fCutType;
0384 Short_t fSelector;
0385
0386 Float_t fResponse;
0387 Float_t fRMS;
0388 Int_t fNodeType;
0389 Float_t fPurity;
0390
0391 Bool_t fIsTerminalNode;
0392
0393 mutable DTNodeTrainingInfo* fTrainInfo;
0394
0395 private:
0396
0397 ClassDef(DecisionTreeNode,0);
0398 };
0399 }
0400
0401 #endif