Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:49

0001 #ifndef ROOT_TMVA_CCPruner
0002 #define ROOT_TMVA_CCPruner
0003 /**********************************************************************************
0004  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0005  * Package: TMVA                                                                  *
0006  * Class  : CCPruner                                                              *
0007  *                                             *
0008  *                                                                                *
0009  * Description: Cost Complexity Pruning                                           *
0010  *
0011  * Author: Doug Schouten (dschoute@sfu.ca)
0012  *
0013  *                                                                                *
0014  * Copyright (c) 2007:                                                            *
0015  *      CERN, Switzerland                                                         *
0016  *      MPI-K Heidelberg, Germany                                                 *
0017  *      U. of Texas at Austin, USA                                                *
0018  *                                                                                *
0019  * Redistribution and use in source and binary forms, with or without             *
0020  * modification, are permitted according to the terms listed in LICENSE           *
0021  * (see tmva/doc/LICENSE)                                          *
0022  **********************************************************************************/
0023 
0024 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
0025 // CCPruner - a helper class to prune a decision tree using the Cost Complexity method                    //
0026 // (see Classification and Regression Trees by Leo Breiman et al)                                         //
0027 //                                                                                                        //
0028 // Some definitions:                                                                                      //
0029 //                                                                                                        //
0030 // T_max - the initial, usually highly overtrained tree, that is to be pruned back                        //
0031 // R(T) - quality index (Gini, misclassification rate, or other) of a tree T                              //
0032 // ~T - set of terminal nodes in T; T_t - the branch (subtree) of T rooted at node t                      //
0033 // T' - the pruned subtree of T_max that has the best quality index R(T')                                 //
0034 // alpha - the prune strength parameter in Cost Complexity pruning (R_alpha(T) = R(T) + alpha * |~T|)     //
0035 //                                                                                                        //
0036 // There are two running modes in CCPruner: (i) one may select a prune strength and prune back            //
0037 // the tree T_max until the criterion                                                                     //
0038 //             R(t) - R(T_t)                                                                              //
0039 //  alpha <    -------------                                                                              //
0040 //              |~T_t| - 1                                                                                //
0041 //                                                                                                        //
0042 // is true for all non-terminal nodes t in T, or (ii) the algorithm finds the sequence of critical points //
0043 // alpha_k < alpha_k+1 ... < alpha_K such that T_K = root(T_max) and then selects the optimally-pruned    //
0044 // subtree, defined to be the subtree with the best quality index for the validation sample.              //
0045 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
0046 
0047 
0048 #include "TMVA/DecisionTree.h"
0049 
0050 /* #ifndef ROOT_TMVA_DecisionTreeNode */
0051 /* #include "TMVA/DecisionTreeNode.h" */
0052 /* #endif */
0053 
0054 #include "TMVA/Event.h"
0055 #include <vector>
0056 
0057 namespace TMVA {
0058    class DataSet;
0059    class DecisionTreeNode;
0060    class SeparationBase;
0061 
0062    class CCPruner {
0063    public:
0064       typedef std::vector<Event*> EventList;
0065 
0066       CCPruner( DecisionTree* t_max,
0067                 const EventList* validationSample,
0068                 SeparationBase* qualityIndex = nullptr );
0069 
0070       CCPruner( DecisionTree* t_max,
0071                 const DataSet* validationSample,
0072                 SeparationBase* qualityIndex = nullptr );
0073 
0074       ~CCPruner( );
0075 
0076       // set the pruning strength parameter alpha (if alpha < 0, the optimal alpha is calculated)
0077       void SetPruneStrength( Float_t alpha = -1.0 );
0078 
0079       void Optimize( );
0080 
0081       // return the list of pruning locations to define the optimal subtree T' of T_max
0082       std::vector<TMVA::DecisionTreeNode*> GetOptimalPruneSequence( ) const;
0083 
0084       // return the quality index from the validation sample for the optimal subtree T'
0085       inline Float_t GetOptimalQualityIndex( ) const { return (fOptimalK >= 0 && fQualityIndexList.size() > 0 ?
0086                                                                fQualityIndexList[fOptimalK] : -1.0); }
0087 
0088       // return the prune strength (=alpha) corresponding to the prune sequence
0089       inline Float_t GetOptimalPruneStrength( ) const { return (fOptimalK >= 0 && fPruneStrengthList.size() > 0 ?
0090                                                                 fPruneStrengthList[fOptimalK] : -1.0); }
0091 
0092    private:
0093       Float_t              fAlpha;             ///<! regularization parameter in CC pruning
0094       const EventList*     fValidationSample;  ///<! the event sample to select the optimally-pruned tree
0095       const DataSet*       fValidationDataSet; ///<! the event sample to select the optimally-pruned tree
0096       SeparationBase*      fQualityIndex;      ///<! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
0097       Bool_t               fOwnQIndex;         ///<! flag indicates if fQualityIndex is owned by this
0098 
0099       DecisionTree*        fTree;              ///<! (pruned) decision tree
0100 
0101       std::vector<TMVA::DecisionTreeNode*> fPruneSequence; ///<! map of weakest links (i.e., branches to prune) -> pruning index
0102       std::vector<Float_t> fPruneStrengthList; ///<! map of alpha -> pruning index
0103       std::vector<Float_t> fQualityIndexList;  ///<! map of R(T) -> pruning index
0104 
0105       Int_t                fOptimalK;          ///<! index of the optimal tree in the pruned tree sequence
0106       Bool_t               fDebug;             ///<! debug flag
0107    };
0108 }
0109 
0110 inline void TMVA::CCPruner::SetPruneStrength( Float_t alpha ) {
0111    fAlpha = (alpha > 0 ? alpha : 0.0);
0112 }
0113 
0114 
0115 #endif
0116 
0117