Warning: file /include/root/TMVA/DecisionTreeNode.h was not indexed
or was modified since it was last indexed (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030 #ifndef ROOT_TMVA_DecisionTreeNode
0031 #define ROOT_TMVA_DecisionTreeNode
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041 #include "TMVA/Node.h"
0042
0043 #include "TMVA/Version.h"
0044
0045 #include <sstream>
0046 #include <vector>
0047 #include <string>
0048
0049 namespace TMVA {
0050
0051 class DTNodeTrainingInfo
0052 {
0053 public:
0054 DTNodeTrainingInfo():fSampleMin(),
0055 fSampleMax(),
0056 fNodeR(0),fSubTreeR(0),fAlpha(0),fG(0),fNTerminal(0),
0057 fNB(0),fNS(0),fSumTarget(0),fSumTarget2(0),fCC(0),
0058 fNSigEvents ( 0 ), fNBkgEvents ( 0 ),
0059 fNEvents ( -1 ),
0060 fNSigEvents_unweighted ( 0 ),
0061 fNBkgEvents_unweighted ( 0 ),
0062 fNEvents_unweighted ( 0 ),
0063 fNSigEvents_unboosted ( 0 ),
0064 fNBkgEvents_unboosted ( 0 ),
0065 fNEvents_unboosted ( 0 ),
0066 fSeparationIndex (-1 ),
0067 fSeparationGain ( -1 )
0068 {
0069 }
0070 std::vector< Float_t > fSampleMin;
0071 std::vector< Float_t > fSampleMax;
0072 Double_t fNodeR;
0073 Double_t fSubTreeR;
0074 Double_t fAlpha;
0075 Double_t fG;
0076 Int_t fNTerminal;
0077 Double_t fNB;
0078 Double_t fNS;
0079 Float_t fSumTarget;
0080 Float_t fSumTarget2;
0081 Double_t fCC;
0082
0083 Float_t fNSigEvents;
0084 Float_t fNBkgEvents;
0085 Float_t fNEvents;
0086 Float_t fNSigEvents_unweighted;
0087 Float_t fNBkgEvents_unweighted;
0088 Float_t fNEvents_unweighted;
0089 Float_t fNSigEvents_unboosted;
0090 Float_t fNBkgEvents_unboosted;
0091 Float_t fNEvents_unboosted;
0092 Float_t fSeparationIndex;
0093 Float_t fSeparationGain;
0094
0095
0096 DTNodeTrainingInfo(const DTNodeTrainingInfo& n) :
0097 fSampleMin(),fSampleMax(),
0098 fNodeR(n.fNodeR), fSubTreeR(n.fSubTreeR),
0099 fAlpha(n.fAlpha), fG(n.fG),
0100 fNTerminal(n.fNTerminal),
0101 fNB(n.fNB), fNS(n.fNS),
0102 fSumTarget(0),fSumTarget2(0),
0103 fCC(0),
0104 fNSigEvents ( n.fNSigEvents ), fNBkgEvents ( n.fNBkgEvents ),
0105 fNEvents ( n.fNEvents ),
0106 fNSigEvents_unweighted ( n.fNSigEvents_unweighted ),
0107 fNBkgEvents_unweighted ( n.fNBkgEvents_unweighted ),
0108 fNEvents_unweighted ( n.fNEvents_unweighted ),
0109 fSeparationIndex( n.fSeparationIndex ),
0110 fSeparationGain ( n.fSeparationGain )
0111 { }
0112 };
0113
0114 class Event;
0115 class MsgLogger;
0116
0117 class DecisionTreeNode: public Node {
0118
0119 public:
0120
0121
0122 DecisionTreeNode ();
0123
0124 DecisionTreeNode (Node* p, char pos);
0125
0126
0127 DecisionTreeNode (const DecisionTreeNode &n, DecisionTreeNode* parent = nullptr);
0128
0129
0130 virtual ~DecisionTreeNode();
0131
0132 Node* CreateNode() const override { return new DecisionTreeNode(); }
0133
0134 inline void SetNFisherCoeff(Int_t nvars){fFisherCoeff.resize(nvars);}
0135 inline UInt_t GetNFisherCoeff() const { return fFisherCoeff.size();}
0136
0137 void SetFisherCoeff(Int_t ivar, Double_t coeff);
0138
0139 Double_t GetFisherCoeff(Int_t ivar) const {return fFisherCoeff.at(ivar);}
0140
0141
0142 Bool_t GoesRight( const Event & ) const override;
0143
0144
0145 Bool_t GoesLeft ( const Event & ) const override;
0146
0147
0148 void SetSelector( Short_t i) { fSelector = i; }
0149
0150 Short_t GetSelector() const { return fSelector; }
0151
0152
0153 void SetCutValue ( Float_t c ) { fCutValue = c; }
0154
0155 Float_t GetCutValue ( void ) const { return fCutValue; }
0156
0157
0158 void SetCutType( Bool_t t ) { fCutType = t; }
0159
0160 Bool_t GetCutType( void ) const { return fCutType; }
0161
0162
0163 void SetNodeType( Int_t t ) { fNodeType = t;}
0164
0165 Int_t GetNodeType( void ) const { return fNodeType; }
0166
0167
0168 Float_t GetPurity( void ) const { return fPurity;}
0169
0170 void SetPurity( void );
0171
0172
0173 void SetResponse( Float_t r ) { fResponse = r;}
0174
0175
0176 Float_t GetResponse( void ) const { return fResponse;}
0177
0178
0179 void SetRMS( Float_t r ) { fRMS = r;}
0180
0181
0182 Float_t GetRMS( void ) const { return fRMS;}
0183
0184
0185 void SetNSigEvents( Float_t s ) { if(fTrainInfo) fTrainInfo->fNSigEvents = s; }
0186
0187
0188 void SetNBkgEvents( Float_t b ) { if(fTrainInfo) fTrainInfo->fNBkgEvents = b; }
0189
0190
0191 void SetNEvents( Float_t nev ){ if(fTrainInfo) fTrainInfo->fNEvents =nev ; }
0192
0193
0194 void SetNSigEvents_unweighted( Float_t s ) { if(fTrainInfo) fTrainInfo->fNSigEvents_unweighted = s; }
0195
0196
0197 void SetNBkgEvents_unweighted( Float_t b ) { if(fTrainInfo) fTrainInfo->fNBkgEvents_unweighted = b; }
0198
0199
0200 void SetNEvents_unweighted( Float_t nev ){ if(fTrainInfo) fTrainInfo->fNEvents_unweighted =nev ; }
0201
0202
0203 void SetNSigEvents_unboosted( Float_t s ) { if(fTrainInfo) fTrainInfo->fNSigEvents_unboosted = s; }
0204
0205
0206 void SetNBkgEvents_unboosted( Float_t b ) { if(fTrainInfo) fTrainInfo->fNBkgEvents_unboosted = b; }
0207
0208
0209 void SetNEvents_unboosted( Float_t nev ){ if(fTrainInfo) fTrainInfo->fNEvents_unboosted =nev ; }
0210
0211
0212 void IncrementNSigEvents( Float_t s ) { if(fTrainInfo) fTrainInfo->fNSigEvents += s; }
0213
0214
0215 void IncrementNBkgEvents( Float_t b ) { if(fTrainInfo) fTrainInfo->fNBkgEvents += b; }
0216
0217
0218 void IncrementNEvents( Float_t nev ){ if(fTrainInfo) fTrainInfo->fNEvents +=nev ; }
0219
0220
0221 void IncrementNSigEvents_unweighted( ) { if(fTrainInfo) fTrainInfo->fNSigEvents_unweighted += 1; }
0222
0223
0224 void IncrementNBkgEvents_unweighted( ) { if(fTrainInfo) fTrainInfo->fNBkgEvents_unweighted += 1; }
0225
0226
0227 void IncrementNEvents_unweighted( ){ if(fTrainInfo) fTrainInfo->fNEvents_unweighted +=1 ; }
0228
0229
0230 Float_t GetNSigEvents( void ) const { return fTrainInfo ? fTrainInfo->fNSigEvents : -1.; }
0231
0232
0233 Float_t GetNBkgEvents( void ) const { return fTrainInfo ? fTrainInfo->fNBkgEvents : -1.; }
0234
0235
0236 Float_t GetNEvents( void ) const { return fTrainInfo ? fTrainInfo->fNEvents : -1.; }
0237
0238
0239 Float_t GetNSigEvents_unweighted( void ) const { return fTrainInfo ? fTrainInfo->fNSigEvents_unweighted : -1.; }
0240
0241
0242 Float_t GetNBkgEvents_unweighted( void ) const { return fTrainInfo ? fTrainInfo->fNBkgEvents_unweighted : -1.; }
0243
0244
0245 Float_t GetNEvents_unweighted( void ) const { return fTrainInfo ? fTrainInfo->fNEvents_unweighted : -1.; }
0246
0247
0248 Float_t GetNSigEvents_unboosted( void ) const { return fTrainInfo ? fTrainInfo->fNSigEvents_unboosted : -1.; }
0249
0250
0251 Float_t GetNBkgEvents_unboosted( void ) const { return fTrainInfo ? fTrainInfo->fNBkgEvents_unboosted : -1.; }
0252
0253
0254 Float_t GetNEvents_unboosted( void ) const { return fTrainInfo ? fTrainInfo->fNEvents_unboosted : -1.; }
0255
0256
0257 void SetSeparationIndex( Float_t sep ){ if(fTrainInfo) fTrainInfo->fSeparationIndex =sep ; }
0258
0259
0260 Float_t GetSeparationIndex( void ) const { return fTrainInfo ? fTrainInfo->fSeparationIndex : -1.; }
0261
0262
0263 void SetSeparationGain( Float_t sep ){ if(fTrainInfo) fTrainInfo->fSeparationGain =sep ; }
0264
0265
0266 Float_t GetSeparationGain( void ) const { return fTrainInfo ? fTrainInfo->fSeparationGain : -1.; }
0267
0268
0269 void Print( std::ostream& os ) const override;
0270
0271
0272 void PrintRec( std::ostream& os ) const override;
0273
0274 void AddAttributesToNode(void* node) const override;
0275 void AddContentToNode(std::stringstream& s) const override;
0276
0277
0278 void ClearNodeAndAllDaughters();
0279
0280
0281
0282
0283 inline DecisionTreeNode* GetLeft( ) const override { return static_cast<DecisionTreeNode*>(fLeft); }
0284 inline DecisionTreeNode* GetRight( ) const override { return static_cast<DecisionTreeNode*>(fRight); }
0285 inline DecisionTreeNode* GetParent( ) const override { return static_cast<DecisionTreeNode*>(fParent); }
0286
0287
0288 inline void SetLeft (Node* l) override { fLeft = l;}
0289 inline void SetRight (Node* r) override { fRight = r;}
0290 inline void SetParent(Node* p) override { fParent = p;}
0291
0292
0293 inline void SetNodeR( Double_t r ) { if(fTrainInfo) fTrainInfo->fNodeR = r; }
0294
0295 inline Double_t GetNodeR( ) const { return fTrainInfo ? fTrainInfo->fNodeR : -1.; }
0296
0297
0298 inline void SetSubTreeR( Double_t r ) { if(fTrainInfo) fTrainInfo->fSubTreeR = r; }
0299
0300 inline Double_t GetSubTreeR( ) const { return fTrainInfo ? fTrainInfo->fSubTreeR : -1.; }
0301
0302
0303
0304
0305
0306 inline void SetAlpha( Double_t alpha ) { if(fTrainInfo) fTrainInfo->fAlpha = alpha; }
0307
0308 inline Double_t GetAlpha( ) const { return fTrainInfo ? fTrainInfo->fAlpha : -1.; }
0309
0310
0311 inline void SetAlphaMinSubtree( Double_t g ) { if(fTrainInfo) fTrainInfo->fG = g; }
0312
0313 inline Double_t GetAlphaMinSubtree( ) const { return fTrainInfo ? fTrainInfo->fG : -1.; }
0314
0315
0316 inline void SetNTerminal( Int_t n ) { if(fTrainInfo) fTrainInfo->fNTerminal = n; }
0317
0318 inline Int_t GetNTerminal( ) const { return fTrainInfo ? fTrainInfo->fNTerminal : -1.; }
0319
0320
0321 inline void SetNBValidation( Double_t b ) { if(fTrainInfo) fTrainInfo->fNB = b; }
0322
0323 inline void SetNSValidation( Double_t s ) { if(fTrainInfo) fTrainInfo->fNS = s; }
0324
0325 inline Double_t GetNBValidation( ) const { return fTrainInfo ? fTrainInfo->fNB : -1.; }
0326
0327 inline Double_t GetNSValidation( ) const { return fTrainInfo ? fTrainInfo->fNS : -1.; }
0328
0329
0330 inline void SetSumTarget(Float_t t) {if(fTrainInfo) fTrainInfo->fSumTarget = t; }
0331
0332 inline void SetSumTarget2(Float_t t2){if(fTrainInfo) fTrainInfo->fSumTarget2 = t2; }
0333
0334
0335 inline void AddToSumTarget(Float_t t) {if(fTrainInfo) fTrainInfo->fSumTarget += t; }
0336
0337 inline void AddToSumTarget2(Float_t t2){if(fTrainInfo) fTrainInfo->fSumTarget2 += t2; }
0338
0339
0340 inline Float_t GetSumTarget() const {return fTrainInfo? fTrainInfo->fSumTarget : -9999;}
0341
0342 inline Float_t GetSumTarget2() const {return fTrainInfo? fTrainInfo->fSumTarget2: -9999;}
0343
0344
0345
0346 void ResetValidationData( );
0347
0348
0349 inline Bool_t IsTerminal() const { return fIsTerminalNode; }
0350 inline void SetTerminal( Bool_t s = kTRUE ) { fIsTerminalNode = s; }
0351 void PrintPrune( std::ostream& os ) const ;
0352 void PrintRecPrune( std::ostream& os ) const;
0353
0354 void SetCC(Double_t cc);
0355
0356 Double_t GetCC() const {return (fTrainInfo? fTrainInfo->fCC : -1.);}
0357
0358 Float_t GetSampleMin(UInt_t ivar) const;
0359 Float_t GetSampleMax(UInt_t ivar) const;
0360 void SetSampleMin(UInt_t ivar, Float_t xmin);
0361 void SetSampleMax(UInt_t ivar, Float_t xmax);
0362
0363 static void SetIsTraining(bool on);
0364 static void SetTmvaVersionCode(UInt_t code);
0365
0366 static bool IsTraining();
0367 static UInt_t GetTmvaVersionCode();
0368
0369 Bool_t ReadDataRecord( std::istream& is, UInt_t tmva_Version_Code = TMVA_VERSION_CODE ) override;
0370 void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE ) override;
0371 void ReadContent(std::stringstream& s) override;
0372
0373 protected:
0374
0375 static MsgLogger& Log();
0376
0377 static bool fgIsTraining;
0378 static UInt_t fgTmva_Version_Code;
0379
0380 std::vector<Double_t> fFisherCoeff;
0381
0382 Float_t fCutValue;
0383 Bool_t fCutType;
0384 Short_t fSelector;
0385
0386 Float_t fResponse;
0387 Float_t fRMS;
0388 Int_t fNodeType;
0389 Float_t fPurity;
0390
0391 Bool_t fIsTerminalNode;
0392
0393 mutable DTNodeTrainingInfo* fTrainInfo;
0394
0395 private:
0396
0397 ClassDefOverride(DecisionTreeNode,0);
0398 };
0399 }
0400
0401 #endif