Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:10:58

0001 /**********************************************************************************
0002  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0003  * Package: TMVA                                                                  *
0004  * Class  : TMVA::CostComplexityPruneTool                                         *
0005  *                                                                                *
0006  *                                                                                *
0007  * Description:                                                                   *
0008  *      Cost Complexity pruning tool for a Decision Tree                          *
0009  *                                                                                *
0010  * Authors (alphabetical):                                                        *
0011  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
0012  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
0013  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
0014  *      Doug Schouten   <dschoute@sfu.ca>        - Simon Fraser U., Canada        *
0015  *                                                                                *
0016  * Copyright (c) 2005:                                                            *
0017  *      CERN, Switzerland                                                         *
0018  *      U. of Victoria, Canada                                                    *
0019  *      MPI-K Heidelberg, Germany                                                 *
0020  *                                                                                *
0021  * Redistribution and use in source and binary forms, with or without             *
0022  * modification, are permitted according to the terms listed in LICENSE           *
0023  * (http://mva.sourceforge.net/license.txt)                                       *
0024  *                                                                                *
0025  **********************************************************************************/
0026 
0027 #ifndef ROOT_TMVA_CostComplexityPruneTool
0028 #define ROOT_TMVA_CostComplexityPruneTool
0029 
0030 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
0031 // CostComplexityPruneTool - a class to prune a decision tree using the Cost Complexity method            //
0032 // (see "Classification and Regression Trees" by Leo Breiman et al)                                       //
0033 //                                                                                                        //
0034 // Some definitions:                                                                                      //
0035 //                                                                                                        //
0036 // T_max - the initial, usually highly overtrained tree, that is to be pruned back                        //
0037 // R(T) - quality index (Gini, misclassification rate, or other) of a tree T                              //
0038 // ~T - set of terminal nodes in T                                                                        //
0039 // T' - the pruned subtree of T_max that has the best quality index R(T')                                 //
0040 // alpha - the prune strength parameter in Cost Complexity pruning (R_alpha(T) = R(T) + alpha*|~T|)       //
0041 //                                                                                                        //
0042 // There are two running modes in CostComplexityPruneTool: (i) one may select a prune strength and prune  //
0043 // the tree T_max until the criterion                                                                     //
0044 //             R(t) - R(T_t)                                                                              //
0045 //  alpha <    -------------                                                                              //
0046 //              |~T_t| - 1                                                                                //
0047 //                                                                                                        //
0048 // is true for all nodes t in T, or (ii) the algorithm finds the sequence of critical points              //
0049 // alpha_k < alpha_k+1 ... < alpha_K such that T_K = root(T_max) and then selects the optimally-pruned    //
0050 // subtree, defined to be the subtree with the best quality index for the validation sample.              //
0051 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
0052 
0053 #include "TMVA/SeparationBase.h"
0054 #include "TMVA/GiniIndex.h"
0055 #include "TMVA/DecisionTree.h"
0056 #include "TMVA/Event.h"
0057 #include "TMVA/IPruneTool.h"
0058 #include <vector>
0059 
0060 namespace TMVA {
0061 
0062    class CostComplexityPruneTool : public IPruneTool { // IPruneTool that prunes a decision tree with the Cost Complexity method
0063    public:
0064       CostComplexityPruneTool( SeparationBase* qualityIndex = nullptr ); // qualityIndex: separation criterion; presumably kept in fQualityIndexTool - TODO confirm in the implementation
0065       virtual ~CostComplexityPruneTool( );
0066       // Calculate the prune sequence for the tree dt; testEvents is optional (nullptr by default), isAutomatic presumably selects running mode (ii) - TODO confirm.
0067       // NOTE: default arguments of a virtual function are bound by the static type at the call site, so they should match the IPruneTool declaration (not visible here).
0068       virtual PruningInfo* CalculatePruningInfo( DecisionTree* dt, const IPruneTool::EventSample* testEvents = nullptr, Bool_t isAutomatic = kFALSE );
0069       // NOTE(review): the '///<!' and '//!' comment prefixes below appear to be ROOT's markers for non-persistent members - keep the '!' when rewording.
0070    private:
0071       SeparationBase* fQualityIndexTool;             ///<! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
0072 
0073       std::vector<DecisionTreeNode*> fPruneSequence; ///<! weakest links (i.e., branches to prune), one entry per pruning index
0074       std::vector<Double_t> fPruneStrengthList;      ///<! prune strength alpha, one entry per pruning index
0075       std::vector<Double_t> fQualityIndexList;       ///<! quality index R(T), one entry per pruning index
0076 
0077       Int_t fOptimalK;                               ///<! the optimal index of the prune sequence
0078 
0079    private:
0080       // set the meta data used for cost complexity pruning, starting at node n (presumably for its whole subtree - TODO confirm)
0081       void InitTreePruningMetaData( DecisionTreeNode* n );
0082 
0083       // optimize the pruning sequence of dt; the meaning of weights is not visible here (presumably a total event weight - TODO confirm)
0084       void Optimize( DecisionTree* dt, Double_t weights );
0085 
0086       mutable MsgLogger* fLogger; //! output stream to save logging information
0087       MsgLogger& Log() const { return *fLogger; } // dereferences fLogger unchecked: it must point to a valid MsgLogger before Log() is used
0088 
0089    };
0090 }
0091 
0092 
0093 #endif