Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:10:58

0001 // @(#)root/tmva $Id$
0002 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Jan Therhaag, Eckhard von Toerne
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : DecisionTree                                                          *
0008  *                                                                                *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Implementation of a Decision Tree                                         *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
0015  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
0016  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
0017  *      Jan Therhaag       <Jan.Therhaag@cern.ch>     - U of Bonn, Germany        *
0018  *      Eckhard v. Toerne  <evt@uni-bonn.de>          - U of Bonn, Germany        *
0019  *                                                                                *
0020  * Copyright (c) 2005-2011:                                                       *
0021  *      CERN, Switzerland                                                         *
0022  *      U. of Victoria, Canada                                                    *
0023  *      MPI-K Heidelberg, Germany                                                 *
0024  *      U. of Bonn, Germany                                                       *
0025  *                                                                                *
0026  * Redistribution and use in source and binary forms, with or without             *
0027  * modification, are permitted according to the terms listed in LICENSE           *
0028  * (http://mva.sourceforge.net/license.txt)                                       *
0029  *                                                                                *
0030  **********************************************************************************/
0031 
0032 #ifndef ROOT_TMVA_DecisionTree
0033 #define ROOT_TMVA_DecisionTree
0034 
0035 //////////////////////////////////////////////////////////////////////////
0036 //                                                                      //
0037 // DecisionTree                                                         //
0038 //                                                                      //
0039 // Implementation of a Decision Tree                                    //
0040 //                                                                      //
0041 //////////////////////////////////////////////////////////////////////////
0042 
0043 #include "TH2.h"
0044 #include <vector>
0045 
0046 #include "TMVA/Types.h"
0047 #include "TMVA/DecisionTreeNode.h"
0048 #include "TMVA/BinaryTree.h"
0049 #include "TMVA/BinarySearchTree.h"
0050 #include "TMVA/SeparationBase.h"
0051 #include "TMVA/RegressionVariance.h"
0052 #include "TMVA/DataSetInfo.h"
0053 
0054 #ifdef R__USE_IMT
0055 #include <ROOT/TThreadExecutor.hxx>
0056 #include "TSystem.h"
0057 #endif
0058 
0059 class TRandom3;
0060 
0061 namespace TMVA {
0062 
0063    class Event;
0064 
   // Implementation of a binary decision tree: grown from training events by
   // recursive node splitting (BuildTree/TrainNode*), evaluated event-by-event
   // with CheckEvent, and optionally pruned (expected-error or cost-complexity
   // pruning). Supports both classification and regression, randomised-tree
   // (forest-style) variable sub-sampling, and multivariate Fisher splits.
   class DecisionTree : public BinaryTree {

   private:

      static const Int_t fgRandomSeed; // default seed for the training constructor; set nonzero for debugging and zero for random seeds

   public:

      // collections of training events: owning-pointer and read-only variants
      typedef std::vector<TMVA::Event*> EventList;
      typedef std::vector<const TMVA::Event*> EventConstList;

      // the constructor needed for the "reading" of the decision tree from weight files
      DecisionTree( void );

      // the constructor needed for constructing the decision tree via training with events;
      // sepType:   node-splitting (separation) criterion
      // minSize:   minimum node size
      // nCuts:     number of grid points per variable in the cut scan
      // cls:       index of the class treated as signal
      // randomisedTree/useNvars/usePoissonNvars: per-split random variable sub-sampling
      DecisionTree( SeparationBase *sepType, Float_t minSize,
                    Int_t nCuts, DataSetInfo* = nullptr,
                    UInt_t cls =0,
                    Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
                    UInt_t nMaxDepth=9999999,
                    Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
                    Int_t treeID = 0);

      // copy constructor
      DecisionTree (const DecisionTree &d);

      virtual ~DecisionTree( void );

      // Retrieves the address of the root node
      virtual DecisionTreeNode* GetRoot() const { return static_cast<TMVA::DecisionTreeNode*>(fRoot); }
      // factory methods used by the BinaryTree base-class machinery
      virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
      virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
      // re-create a tree from its XML (weight file) representation; the version code
      // allows reading trees written by older TMVA releases
      static  DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
      virtual const char* ClassName() const { return "DecisionTree"; }

      // building of a tree by recursively splitting the nodes

      //      UInt_t BuildTree( const EventList & eventSample,
      //                        DecisionTreeNode *node = nullptr);
      UInt_t BuildTree( const EventConstList & eventSample,
                        DecisionTreeNode *node = nullptr);
      // determine the way how a node is split (which variable, which cut value)

      Double_t TrainNode( const EventConstList & eventSample,  DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
      Double_t TrainNodeFast( const EventConstList & eventSample,  DecisionTreeNode *node );
      Double_t TrainNodeFull( const EventConstList & eventSample,  DecisionTreeNode *node );
      // pick the random subset of variables to be considered at a split (randomised trees)
      void    GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
      // coefficients of the Fisher discriminant used for multivariate (linear-combination) splits
      std::vector<Double_t>  GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);

      // fill a tree with a given structure already (just see how many signal/background
      // events end up in each node)

      void FillTree( const EventList & eventSample);

      // fill the existing decision tree structure by passing events
      // in from the top node and seeing where they happen to end up
      void FillEvent( const TMVA::Event & event,
                      TMVA::DecisionTreeNode *node  );

      // classify an event by descending the tree;
      // returns: 1 = Signal (right),  -1 = Bkg (left)
      // NOTE(review): return type is Double_t — with UseYesNoLeaf=kFALSE the leaf
      // purity is presumably returned instead of +-1; confirm in the implementation.

      Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;
      // return the leaf node the event ends up in
      TMVA::DecisionTreeNode* GetEventNode(const TMVA::Event & e) const;

      // return the individual relative variable importance
      std::vector< Double_t > GetVariableImportance();

      // relative importance of a single variable (by variable index)
      Double_t GetVariableImportance(UInt_t ivar);

      // clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree

      void ClearTree();

      // set pruning method
      enum EPruneMethod { kExpectedErrorPruning=0, kCostComplexityPruning, kNoPruning };
      void SetPruneMethod( EPruneMethod m = kCostComplexityPruning ) { fPruneMethod = m; }

      // recursive pruning of the tree, validation sample required for automatic pruning
      Double_t PruneTree( const EventConstList* validationSample = nullptr );

      // manage the pruning strength parameter (iff < 0 -> automate the pruning process)
      void SetPruneStrength( Double_t p ) { fPruneStrength = p; }
      Double_t GetPruneStrength( ) const { return fPruneStrength; }

      // apply pruning validation sample to a decision tree
      void ApplyValidationSample( const EventConstList* validationSample ) const;

      // return the misclassification rate of a pruned tree
      Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = nullptr, Int_t mode = 0 ) const;

      // pass a single validation event through a pruned decision tree
      void CheckEventWithPrunedTree( const TMVA::Event* ) const;

      // calculate the normalization factor for a pruning validation sample
      Double_t GetSumWeights( const EventConstList* validationSample ) const;

      // purity above which a node is declared a signal node
      void SetNodePurityLimit( Double_t p ) { fNodePurityLimit = p; }
      Double_t GetNodePurityLimit( ) const { return fNodePurityLimit; }

      // recursively walk the tree starting at node n (root if n is nullptr)
      void DescendTree( Node *n = nullptr );
      // set the parent-tree pointer in n and its daughter nodes
      void SetParentTreeInNodes( Node *n = nullptr );

      // retrieve node from the tree. Its position (up to a maximal tree depth of 64)
      // is coded as a sequence of left-right moves starting from the root, coded as
      // 0-1 bit patterns stored in the "long-integer" together with the depth
      Node* GetNode( ULong_t sequence, UInt_t depth );

      // clean up the subtree below "node"; returns a node count — confirm exact
      // semantics (removed vs remaining nodes) in the implementation
      UInt_t CleanTree(DecisionTreeNode *node = nullptr);

      // prune away the subtree below the node (descendants are deleted)
      void PruneNode(TMVA::DecisionTreeNode *node);

      // prune a node from the tree without deleting its descendants; allows one to
      // effectively prune a tree many times without making deep copies
      void PruneNodeInPlace( TMVA::DecisionTreeNode* node );

      // node count before pruning; lazily cached — falls back to the current GetNNodes() on first call if not yet set
      Int_t GetNNodesBeforePruning(){return (fNNodesBeforePruning)?fNNodesBeforePruning:fNNodesBeforePruning=GetNNodes();}

      // count the leaf nodes in the subtree below n (root if n is nullptr)
      UInt_t CountLeafNodes(TMVA::Node *n = nullptr);

      // tree ID bookkeeping: useful for debugging when many trees are grown (e.g. in a forest)
      void  SetTreeID(Int_t treeID){fTreeID = treeID;};
      Int_t GetTreeID(){return fTreeID;};

      // analysis-type (classification vs regression) accessors
      Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }
      void SetAnalysisType (Types::EAnalysisType t) { fAnalysisType = t;}
      Types::EAnalysisType GetAnalysisType ( void ) { return fAnalysisType;}
      // configuration of multivariate (Fisher-discriminant) node splits
      inline void SetUseFisherCuts(Bool_t t=kTRUE)  { fUseFisherCuts = t;}
      inline void SetMinLinCorrForFisher(Double_t min){fMinLinCorrForFisher = min;}
      inline void SetUseExclusiveVars(Bool_t t=kTRUE){fUseExclusiveVars = t;}
      inline void SetNVars(Int_t n){fNvars = n;}

   private:
      // utility functions

      // calculate the Purity out of the number of sig and bkg events collected
      // from individual samples.

      // calculates the purity S/(S+B) of a given event sample
      Double_t SamplePurity(EventList eventSample);

      UInt_t    fNvars;               ///< number of variables used to separate S and B
      Int_t     fNCuts;               ///< number of grid points in variable cut scans
      Bool_t    fUseFisherCuts;       ///< use multivariate splits using the Fisher criterion
      Double_t  fMinLinCorrForFisher; ///< the minimum linear correlation between two variables demanded for use in Fisher criterion in node splitting
      Bool_t    fUseExclusiveVars;    ///< individual variables already used in Fisher criterion are not anymore analysed individually for node splitting

      SeparationBase *fSepType;       ///< the separation criterion (classification splits)
      RegressionVariance *fRegType;   ///< the separation criterion used in regression

      Double_t  fMinSize;             ///< min number of events in node
      Double_t  fMinNodeSize;         ///< min fraction of training events in node
      Double_t  fMinSepGain;          ///< min separation gain required to perform node splitting

      Bool_t    fUseSearchTree;       ///< cut scan done with binary trees or simple event loop.
      Double_t  fPruneStrength;       ///< a parameter to set the "amount" of pruning..needs to be adjusted

      EPruneMethod fPruneMethod;      ///< method used for pruning
      Int_t    fNNodesBeforePruning;  ///< remember this one (in case of pruning, it allows to monitor the before/after)

      Double_t  fNodePurityLimit;     ///< purity limit to decide whether a node is signal

      Bool_t    fRandomisedTree;      ///< choose at each node splitting a random set of variables
      Int_t     fUseNvars;            ///< the number of variables used in randomised trees
      Bool_t    fUsePoissonNvars;     ///< use "fUseNvars" not as fixed number but as mean of a Poisson distribution in each split

      TRandom3  *fMyTrandom;          ///< random number generator for randomised trees

      std::vector< Double_t > fVariableImportance; ///< the relative importance of the different variables

      UInt_t     fMaxDepth;           ///< max depth
      UInt_t     fSigClass;           ///< class which is treated as signal when building the tree
      static const Int_t  fgDebugLevel = 0; ///< debug level determining some printout/control plots etc.
      Int_t     fTreeID;              ///< just an ID number given to the tree.. makes debugging easier as tree knows who he is.

      Types::EAnalysisType  fAnalysisType; ///< kClassification(=0=false) or kRegression(=1=true)

      DataSetInfo*  fDataSetInfo;     ///< dataset meta information; may be nullptr (see training constructor default)

      ClassDef(DecisionTree,0);               // implementation of a Decision Tree
   };
0245 
0246 } // namespace TMVA
0247 
0248 #endif