// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : MethodDT  (Boosted Decision Trees)                                    *
 *                                                                                *
 *                                                                                *
 * Description:                                                                   *
 *      Analysis of Boosted Decision Trees                                        *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
 *      Or Cohen        <orcohenor@gmail.com>    - Weizmann Inst., Israel         *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (see tmva/doc/LICENSE)                                                         *
 **********************************************************************************/

#ifndef ROOT_TMVA_MethodDT
#define ROOT_TMVA_MethodDT

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodDT                                                             //
//                                                                      //
// Analysis of Single Decision Tree                                     //
//                                                                      //
//////////////////////////////////////////////////////////////////////////
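
// Usage sketch (assumes the standard TMVA Factory/DataLoader workflow; the
// file name, dataset name, and option values are illustrative placeholders,
// and only a few of the options declared in DeclareOptions() are shown):
//
//    // #include "TFile.h", "TMVA/Factory.h", "TMVA/DataLoader.h", "TMVA/Types.h"
//    TFile* outFile = TFile::Open("TMVA_DT.root", "RECREATE");
//    TMVA::Factory factory("TMVAClassification", outFile,
//                          "!V:!Silent:AnalysisType=Classification");
//    TMVA::DataLoader loader("dataset");
//    // ... loader.AddVariable(...), loader.AddSignalTree(...), etc. ...
//    factory.BookMethod(&loader, TMVA::Types::kDT, "DT",
//                       "!H:!V:nCuts=20:MaxDepth=3:MinNodeSize=5%");
//    factory.TrainAllMethods();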

#include <vector>
#include "TH1.h"
#include "TH2.h"
#include "TTree.h"
#include "TMVA/MethodBase.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/Event.h"

namespace TMVA {
   class MethodBoost;

   class MethodDT : public MethodBase {
   public:
      MethodDT( const TString& jobName,
                const TString& methodTitle,
                DataSetInfo& theData,
                const TString& theOption = "");

      MethodDT( DataSetInfo& dsi,
                const TString& theWeightFile);

      virtual ~MethodDT( void );

      virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );

      void Train( void );

      using MethodBase::ReadWeightsFromStream;

      // write weights to file
      void AddWeightsXMLTo( void* parent ) const;

      // read weights from file
      void ReadWeightsFromStream( std::istream& istr );
      void ReadWeightsFromXML   ( void* wghtnode );

      // calculate the MVA value
      Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );

      // the option handling methods
      void DeclareOptions();
      void ProcessOptions();
      void DeclareCompatibilityOptions();

      void GetHelpMessage() const;

      // ranking of input variables
      const Ranking* CreateRanking();

      Double_t PruneTree( );

      Double_t TestTreeQuality( DecisionTree *dt );

      Double_t GetPruneStrength () { return fPruneStrength; }

      void SetMinNodeSize(Double_t sizeInPercent);
      void SetMinNodeSize(TString sizeInPercent);

      Int_t GetNNodesBeforePruning(){return fTree->GetNNodesBeforePruning();}
      Int_t GetNNodes(){return fTree->GetNNodes();}

   private:
      // Init used in the various constructors
      void Init( void );

   private:


      std::vector<Event*>             fEventSample;     ///< the training events

      DecisionTree*                   fTree;            ///< the decision tree
      // options for the decision tree
      SeparationBase                 *fSepType;         ///< the separation criterion used in node splitting
      TString                         fSepTypeS;        ///< the separation criterion (option string) used in node splitting
      Int_t                           fMinNodeEvents;   ///< min number of events in a node
      Float_t                         fMinNodeSize;     ///< min percentage of training events in a node
      TString                         fMinNodeSizeS;    ///< string containing the min percentage of training events in a node

      Int_t                           fNCuts;           ///< number of grid points used for the cut scan in node splitting
      Bool_t                          fUseYesNoLeaf;    ///< use a yes/no (sig or bkg) decision in leaf nodes instead of the sig/bkg purity
      Double_t                        fNodePurityLimit; ///< purity limit for sig/bkg nodes
      UInt_t                          fMaxDepth;        ///< max depth


      Double_t                         fErrorFraction;   ///< ntuple var: misclassification error fraction
      Double_t                         fPruneStrength;   ///< a parameter to set the "amount" of pruning; needs to be adjusted
      DecisionTree::EPruneMethod       fPruneMethod;     ///< method used for pruning
      TString                          fPruneMethodS;    ///< prune method option string
      Bool_t                           fAutomatic;       ///< use the user-given prune strength, or determine it automatically from a validation sample
      Bool_t                           fRandomisedTrees; ///< choose a random subset of possible cut variables at each node during training
      Int_t                            fUseNvars;        ///< the number of variables used in the randomised tree splitting
      Bool_t                           fUsePoissonNvars; ///< fUseNvars is used as a Poisson mean, and the actual number of variables is drawn from that distribution at each step
      std::vector<Double_t>           fVariableImportance; ///< the relative importance of the different variables

      Double_t                        fDeltaPruneStrength; ///< step size in pruning, adjusted according to the experience of previous trees
      // debugging flags
      static const Int_t  fgDebugLevel = 0;     ///< debug level determining some printout/control plots etc.


      Bool_t fPruneBeforeBoost; ///< ancient variable, only needed for "CompatibilityOptions"

      ClassDef(MethodDT,0);  // Analysis of Decision Trees

   };
}
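
// Applying a trained MethodDT: a minimal sketch assuming the standard
// TMVA::Reader workflow, which reads back the weights written by
// AddWeightsXMLTo() and evaluates GetMvaValue() through EvaluateMVA().
// The variable name and weight-file path are placeholders, not taken from
// this header.
//
//    Float_t var1;
//    TMVA::Reader reader("!Color:!Silent");
//    reader.AddVariable("var1", &var1);
//    reader.BookMVA("DT", "dataset/weights/TMVAClassification_DT.weights.xml");
//    var1 = /* value for this event */ 0.f;
//    Double_t mva = reader.EvaluateMVA("DT");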

#endif