File indexing completed on 2025-01-18 10:10:58
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032 #ifndef ROOT_TMVA_DecisionTree
0033 #define ROOT_TMVA_DecisionTree
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043 #include "TH2.h"
0044 #include <vector>
0045
0046 #include "TMVA/Types.h"
0047 #include "TMVA/DecisionTreeNode.h"
0048 #include "TMVA/BinaryTree.h"
0049 #include "TMVA/BinarySearchTree.h"
0050 #include "TMVA/SeparationBase.h"
0051 #include "TMVA/RegressionVariance.h"
0052 #include "TMVA/DataSetInfo.h"
0053
0054 #ifdef R__USE_IMT
0055 #include <ROOT/TThreadExecutor.hxx>
0056 #include "TSystem.h"
0057 #endif
0058
0059 class TRandom3;
0060
0061 namespace TMVA {
0062
0063 class Event;
0064
0065 class DecisionTree : public BinaryTree {
0066
0067 private:
0068
0069 static const Int_t fgRandomSeed;
0070
0071 public:
0072
0073 typedef std::vector<TMVA::Event*> EventList;
0074 typedef std::vector<const TMVA::Event*> EventConstList;
0075
0076
0077 DecisionTree( void );
0078
0079
0080 DecisionTree( SeparationBase *sepType, Float_t minSize,
0081 Int_t nCuts, DataSetInfo* = nullptr,
0082 UInt_t cls =0,
0083 Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
0084 UInt_t nMaxDepth=9999999,
0085 Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
0086 Int_t treeID = 0);
0087
0088
0089 DecisionTree (const DecisionTree &d);
0090
0091 virtual ~DecisionTree( void );
0092
0093
0094 virtual DecisionTreeNode* GetRoot() const { return static_cast<TMVA::DecisionTreeNode*>(fRoot); }
0095 virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
0096 virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
0097 static DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
0098 virtual const char* ClassName() const { return "DecisionTree"; }
0099
0100
0101
0102
0103
0104 UInt_t BuildTree( const EventConstList & eventSample,
0105 DecisionTreeNode *node = nullptr);
0106
0107
0108 Double_t TrainNode( const EventConstList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
0109 Double_t TrainNodeFast( const EventConstList & eventSample, DecisionTreeNode *node );
0110 Double_t TrainNodeFull( const EventConstList & eventSample, DecisionTreeNode *node );
0111 void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
0112 std::vector<Double_t> GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);
0113
0114
0115
0116
0117 void FillTree( const EventList & eventSample);
0118
0119
0120
0121 void FillEvent( const TMVA::Event & event,
0122 TMVA::DecisionTreeNode *node );
0123
0124
0125
0126 Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;
0127 TMVA::DecisionTreeNode* GetEventNode(const TMVA::Event & e) const;
0128
0129
0130 std::vector< Double_t > GetVariableImportance();
0131
0132 Double_t GetVariableImportance(UInt_t ivar);
0133
0134
0135
0136 void ClearTree();
0137
0138
0139 enum EPruneMethod { kExpectedErrorPruning=0, kCostComplexityPruning, kNoPruning };
0140 void SetPruneMethod( EPruneMethod m = kCostComplexityPruning ) { fPruneMethod = m; }
0141
0142
0143 Double_t PruneTree( const EventConstList* validationSample = nullptr );
0144
0145
0146 void SetPruneStrength( Double_t p ) { fPruneStrength = p; }
0147 Double_t GetPruneStrength( ) const { return fPruneStrength; }
0148
0149
0150 void ApplyValidationSample( const EventConstList* validationSample ) const;
0151
0152
0153 Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = nullptr, Int_t mode = 0 ) const;
0154
0155
0156 void CheckEventWithPrunedTree( const TMVA::Event* ) const;
0157
0158
0159 Double_t GetSumWeights( const EventConstList* validationSample ) const;
0160
0161 void SetNodePurityLimit( Double_t p ) { fNodePurityLimit = p; }
0162 Double_t GetNodePurityLimit( ) const { return fNodePurityLimit; }
0163
0164 void DescendTree( Node *n = nullptr );
0165 void SetParentTreeInNodes( Node *n = nullptr );
0166
0167
0168
0169
0170 Node* GetNode( ULong_t sequence, UInt_t depth );
0171
0172 UInt_t CleanTree(DecisionTreeNode *node = nullptr);
0173
0174 void PruneNode(TMVA::DecisionTreeNode *node);
0175
0176
0177
0178 void PruneNodeInPlace( TMVA::DecisionTreeNode* node );
0179
0180 Int_t GetNNodesBeforePruning(){return (fNNodesBeforePruning)?fNNodesBeforePruning:fNNodesBeforePruning=GetNNodes();}
0181
0182
0183 UInt_t CountLeafNodes(TMVA::Node *n = nullptr);
0184
0185 void SetTreeID(Int_t treeID){fTreeID = treeID;};
0186 Int_t GetTreeID(){return fTreeID;};
0187
0188 Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }
0189 void SetAnalysisType (Types::EAnalysisType t) { fAnalysisType = t;}
0190 Types::EAnalysisType GetAnalysisType ( void ) { return fAnalysisType;}
0191 inline void SetUseFisherCuts(Bool_t t=kTRUE) { fUseFisherCuts = t;}
0192 inline void SetMinLinCorrForFisher(Double_t min){fMinLinCorrForFisher = min;}
0193 inline void SetUseExclusiveVars(Bool_t t=kTRUE){fUseExclusiveVars = t;}
0194 inline void SetNVars(Int_t n){fNvars = n;}
0195
0196 private:
0197
0198
0199
0200
0201
0202
0203 Double_t SamplePurity(EventList eventSample);
0204
0205 UInt_t fNvars;
0206 Int_t fNCuts;
0207 Bool_t fUseFisherCuts;
0208 Double_t fMinLinCorrForFisher;
0209 Bool_t fUseExclusiveVars;
0210
0211 SeparationBase *fSepType;
0212 RegressionVariance *fRegType;
0213
0214 Double_t fMinSize;
0215 Double_t fMinNodeSize;
0216 Double_t fMinSepGain;
0217
0218 Bool_t fUseSearchTree;
0219 Double_t fPruneStrength;
0220
0221 EPruneMethod fPruneMethod;
0222 Int_t fNNodesBeforePruning;
0223
0224 Double_t fNodePurityLimit;
0225
0226 Bool_t fRandomisedTree;
0227 Int_t fUseNvars;
0228 Bool_t fUsePoissonNvars;
0229
0230 TRandom3 *fMyTrandom;
0231
0232 std::vector< Double_t > fVariableImportance;
0233
0234 UInt_t fMaxDepth;
0235 UInt_t fSigClass;
0236 static const Int_t fgDebugLevel = 0;
0237 Int_t fTreeID;
0238
0239 Types::EAnalysisType fAnalysisType;
0240
0241 DataSetInfo* fDataSetInfo;
0242
0243 ClassDef(DecisionTree,0);
0244 };
0245
0246 }
0247
0248 #endif