File indexing completed on 2025-10-31 09:16:11
0001 
0002 
0003 
0004 
0005 
0006 
0007 
0008 
0009 
0010 
0011 
0012 
0013 
0014 
0015 
0016 
0017 
0018 
0019 
0020 
0021 
0022 
0023 
0024 
0025 
0026 
0027 
0028 
0029 
0030 
0031 
0032 #ifndef ROOT_TMVA_DecisionTree
0033 #define ROOT_TMVA_DecisionTree
0034 
0035 
0036 
0037 
0038 
0039 
0040 
0041 
0042 
0043 #include "TH2.h"
0044 #include <vector>
0045 
0046 #include "TMVA/Types.h"
0047 #include "TMVA/DecisionTreeNode.h"
0048 #include "TMVA/BinaryTree.h"
0049 #include "TMVA/BinarySearchTree.h"
0050 #include "TMVA/SeparationBase.h"
0051 #include "TMVA/RegressionVariance.h"
0052 #include "TMVA/DataSetInfo.h"
0053 
0054 #ifdef R__USE_IMT
0055 #include <ROOT/TThreadExecutor.hxx>
0056 #include "TSystem.h"
0057 #endif
0058 
0059 class TRandom3;
0060 
0061 namespace TMVA {
0062 
0063    class Event;
0064 
0065    class DecisionTree : public BinaryTree {
0066       // Decision tree type derived from BinaryTree; its nodes are handled as DecisionTreeNode (see GetRoot()/CreateNode()).
0067    private:
0068 
0069       static const Int_t fgRandomSeed; // default value of the constructor's iSeed argument; defined out of line
0070 
0071    public:
0072 
0073       typedef std::vector<TMVA::Event*> EventList;
0074       typedef std::vector<const TMVA::Event*> EventConstList;
0075 
0076       // Default constructor; CreateTree() below returns a tree built with it.
0077       DecisionTree( void );
0078 
0079       // Main constructor. Only sepType, minSize and nCuts are required; the unnamed DataSetInfo* defaults to nullptr and iSeed to fgRandomSeed.
0080       DecisionTree( SeparationBase *sepType, Float_t minSize,
0081                     Int_t nCuts, DataSetInfo* = nullptr,
0082                     UInt_t cls =0,
0083                     Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
0084                     UInt_t nMaxDepth=9999999,
0085                     Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
0086                     Int_t treeID = 0);
0087 
0088       // Copy constructor. NOTE(review): no copy-assignment operator is declared next to it - check that the implicit one is safe with the raw-pointer members.
0089       DecisionTree (const DecisionTree &d);
0090 
0091       virtual ~DecisionTree( void );
0092 
0093       // Root access and factories: GetRoot() static_casts fRoot to DecisionTreeNode*; CreateNode() ignores its UInt_t argument; CreateNode()/CreateTree() return raw pointers to freshly new'd objects.
0094       virtual DecisionTreeNode* GetRoot() const { return static_cast<TMVA::DecisionTreeNode*>(fRoot); }
0095       virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
0096       virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
0097       static  DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
0098       virtual const char* ClassName() const { return "DecisionTree"; }
0099 
0100       
0101 
0102       // Build the tree from a sample of const events; node defaults to nullptr.
0103       // NOTE(review): the meaning of the UInt_t result is not visible in this header - confirm in the implementation.
0104       UInt_t BuildTree( const EventConstList & eventSample,
0105                         DecisionTreeNode *node = nullptr);
0106       // Node training: TrainNode() is a thin inline wrapper that always dispatches to TrainNodeFast().
0107 
0108       Double_t TrainNode( const EventConstList & eventSample,  DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
0109       Double_t TrainNodeFast( const EventConstList & eventSample,  DecisionTreeNode *node );
0110       Double_t TrainNodeFull( const EventConstList & eventSample,  DecisionTreeNode *node );
0111       void    GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
0112       std::vector<Double_t>  GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);
0113 
0114       
0115       // Takes a sample of mutable Event pointers (EventList, not EventConstList).
0116 
0117       void FillTree( const EventList & eventSample);
0118 
0119       
0120       // Single-event variant: one event and one node.
0121       void FillEvent( const TMVA::Event & event,
0122                       TMVA::DecisionTreeNode *node  );
0123 
0124       // Evaluate one event against the tree; UseYesNoLeaf defaults to kFALSE. GetEventNode() returns a DecisionTreeNode* for the given event.
0125 
0126       Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;
0127       TMVA::DecisionTreeNode* GetEventNode(const TMVA::Event & e) const;
0128 
0129       // Variable importance: the whole vector, or the single entry selected by ivar.
0130       std::vector< Double_t > GetVariableImportance();
0131 
0132       Double_t GetVariableImportance(UInt_t ivar);
0133 
0134       // NOTE(review): presumably resets the tree's nodes - what exactly is released is not visible here, confirm in the implementation.
0135 
0136       void ClearTree();
0137 
0138       // Pruning configuration. SetPruneMethod() called without an argument selects kCostComplexityPruning.
0139       enum EPruneMethod { kExpectedErrorPruning=0, kCostComplexityPruning, kNoPruning };
0140       void SetPruneMethod( EPruneMethod m = kCostComplexityPruning ) { fPruneMethod = m; }
0141 
0142       // Prune the tree; the validation sample is optional (nullptr by default).
0143       Double_t PruneTree( const EventConstList* validationSample = nullptr );
0144 
0145       // Accessors for fPruneStrength.
0146       void SetPruneStrength( Double_t p ) { fPruneStrength = p; }
0147       Double_t GetPruneStrength( ) const { return fPruneStrength; }
0148 
0149       // Validation-sample helper; declared const.
0150       void ApplyValidationSample( const EventConstList* validationSample ) const;
0151 
0152       // Quality figure for the pruned tree; dt defaults to nullptr and mode to 0.
0153       Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = nullptr, Int_t mode = 0 ) const;
0154 
0155       // Per-event counterpart taking a single const Event*; declared const, returns nothing.
0156       void CheckEventWithPrunedTree( const TMVA::Event* ) const;
0157 
0158       // Sum of weights over a validation sample.
0159       Double_t GetSumWeights( const EventConstList* validationSample ) const;
0160 
0161       void SetNodePurityLimit( Double_t p ) { fNodePurityLimit = p; }
0162       Double_t GetNodePurityLimit( ) const { return fNodePurityLimit; }
0163 
0164       void DescendTree( Node *n = nullptr );
0165       void SetParentTreeInNodes( Node *n = nullptr );
0166 
0167       
0168       // Look up a node from a (sequence, depth) pair.
0169       // NOTE(review): how 'sequence' encodes the path is not visible in this header - confirm in the implementation.
0170       Node* GetNode( ULong_t sequence, UInt_t depth );
0171 
0172       UInt_t CleanTree(DecisionTreeNode *node = nullptr);
0173 
0174       void PruneNode(TMVA::DecisionTreeNode *node);
0175 
0176       
0177       // NOTE(review): presumably differs from PruneNode() in how the pruned subtree is treated - confirm in the implementation.
0178       void PruneNodeInPlace( TMVA::DecisionTreeNode* node );
0179       // Returns the cached node count; while the cache is still 0 it is first filled from GetNNodes().
0180       Int_t GetNNodesBeforePruning() { if (!fNNodesBeforePruning) fNNodesBeforePruning = GetNNodes(); return fNNodesBeforePruning; }
0181 
0182 
0183       UInt_t CountLeafNodes(TMVA::Node *n = nullptr);
0184       // Accessors for fTreeID; the getter is const so it can be used on const trees.
0185       void  SetTreeID(Int_t treeID) { fTreeID = treeID; }
0186       Int_t GetTreeID() const { return fTreeID; }
0187       // Analysis-type helpers and simple inline setters; DoRegression() is true exactly when fAnalysisType == Types::kRegression.
0188       Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }
0189       void SetAnalysisType (Types::EAnalysisType t) { fAnalysisType = t;}
0190       Types::EAnalysisType GetAnalysisType() const { return fAnalysisType; }
0191       inline void SetUseFisherCuts(Bool_t t=kTRUE)  { fUseFisherCuts = t;}
0192       inline void SetMinLinCorrForFisher(Double_t min){fMinLinCorrForFisher = min;}
0193       inline void SetUseExclusiveVars(Bool_t t=kTRUE){fUseExclusiveVars = t;}
0194       inline void SetNVars(Int_t n){fNvars = n;}
0195 
0196    private:
0197       
0198 
0199       
0200       
0201 
0202       // NOTE(review): takes EventList by value, so the vector is copied on each call; signature kept because the definition lives out of line.
0203       Double_t SamplePurity(EventList eventSample);
0204       // NOTE(review): members without an accessor in this header are presumably initialised from the like-named constructor arguments - confirm in the implementation.
0205       UInt_t    fNvars;               ///< written by SetNVars()
0206       Int_t     fNCuts;               
0207       Bool_t    fUseFisherCuts;       ///< written by SetUseFisherCuts()
0208       Double_t  fMinLinCorrForFisher; ///< written by SetMinLinCorrForFisher()
0209       Bool_t    fUseExclusiveVars;    ///< written by SetUseExclusiveVars()
0210 
0211       SeparationBase *fSepType;       ///< raw pointer; NOTE(review): ownership is not visible in this header
0212       RegressionVariance *fRegType;   ///< raw pointer; NOTE(review): ownership is not visible in this header
0213 
0214       Double_t  fMinSize;             
0215       Double_t  fMinNodeSize;         
0216       Double_t  fMinSepGain;          
0217 
0218       Bool_t    fUseSearchTree;       
0219       Double_t  fPruneStrength;       ///< accessed via Set/GetPruneStrength()
0220 
0221       EPruneMethod fPruneMethod;      ///< written by SetPruneMethod()
0222       Int_t    fNNodesBeforePruning;  ///< cache for GetNNodesBeforePruning(); 0 means "not cached yet"
0223 
0224       Double_t  fNodePurityLimit;     ///< accessed via Set/GetNodePurityLimit()
0225 
0226       Bool_t    fRandomisedTree;      
0227       Int_t     fUseNvars;            
0228       Bool_t    fUsePoissonNvars;     
0229 
0230       TRandom3  *fMyTrandom;          ///< raw pointer to a TRandom3; NOTE(review): ownership is not visible in this header
0231 
0232       std::vector< Double_t > fVariableImportance; 
0233 
0234       UInt_t     fMaxDepth;           
0235       UInt_t     fSigClass;           
0236       static const Int_t  fgDebugLevel = 0; ///< compile-time constant, fixed at 0
0237       Int_t     fTreeID;              ///< accessed via Set/GetTreeID()
0238 
0239       Types::EAnalysisType  fAnalysisType; ///< accessed via Set/GetAnalysisType(); tested in DoRegression()
0240 
0241       DataSetInfo*  fDataSetInfo;
0242 
0243       ClassDef(DecisionTree,0);               // ROOT ClassDef macro, class version 0
0244    };
0245 
0246 } // namespace TMVA
0247 
0248 #endif