|
||||
File indexing completed on 2025-01-30 10:22:49
0001 #ifndef ROOT_TMVA_CCPruner 0002 #define ROOT_TMVA_CCPruner 0003 /********************************************************************************** 0004 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * 0005 * Package: TMVA * 0006 * Class : CCPruner * 0007 * * 0008 * * 0009 * Description: Cost Complexity Pruning * 0010 * 0011 * Author: Doug Schouten (dschoute@sfu.ca) 0012 * 0013 * * 0014 * Copyright (c) 2007: * 0015 * CERN, Switzerland * 0016 * MPI-K Heidelberg, Germany * 0017 * U. of Texas at Austin, USA * 0018 * * 0019 * Redistribution and use in source and binary forms, with or without * 0020 * modification, are permitted according to the terms listed in LICENSE * 0021 * (see tmva/doc/LICENSE) * 0022 **********************************************************************************/ 0023 0024 //////////////////////////////////////////////////////////////////////////////////////////////////////////// 0025 // CCPruner - a helper class to prune a decision tree using the Cost Complexity method // 0026 // (see Classification and Regression Trees by Leo Breiman et al) // 0027 // // 0028 // Some definitions: // 0029 // // 0030 // T_max - the initial, usually highly overtrained tree, that is to be pruned back // 0031 // R(T) - quality index (Gini, misclassification rate, or other) of a tree T // 0032 // ~T - set of terminal nodes in T // 0033 // T' - the pruned subtree of T_max that has the best quality index R(T') // 0034 // alpha - the prune strength parameter in Cost Complexity pruning (R_alpha(T) = R(T) + alpha// |~T|) // 0035 // // 0036 // There are two running modes in CCPruner: (i) one may select a prune strength and prune back // 0037 // the tree T_max until the criterion // 0038 // R(T) - R(t) // 0039 // alpha < ---------- // 0040 // |~T_t| - 1 // 0041 // // 0042 // is true for all nodes t in T, or (ii) the algorithm finds the sequence of critical points // 0043 // alpha_k < alpha_k+1 ... < alpha_K such that T_K = root(T_max) and then selects the optimally-pruned // 0044 // subtree, defined to be the subtree with the best quality index for the validation sample. // 0045 //////////////////////////////////////////////////////////////////////////////////////////////////////////// 0046 0047 0048 #include "TMVA/DecisionTree.h" 0049 0050 /* #ifndef ROOT_TMVA_DecisionTreeNode */ 0051 /* #include "TMVA/DecisionTreeNode.h" */ 0052 /* #endif */ 0053 0054 #include "TMVA/Event.h" 0055 #include <vector> 0056 0057 namespace TMVA { 0058 class DataSet; 0059 class DecisionTreeNode; 0060 class SeparationBase; 0061 0062 class CCPruner { 0063 public: 0064 typedef std::vector<Event*> EventList; 0065 0066 CCPruner( DecisionTree* t_max, 0067 const EventList* validationSample, 0068 SeparationBase* qualityIndex = nullptr ); 0069 0070 CCPruner( DecisionTree* t_max, 0071 const DataSet* validationSample, 0072 SeparationBase* qualityIndex = nullptr ); 0073 0074 ~CCPruner( ); 0075 0076 // set the pruning strength parameter alpha (if alpha < 0, the optimal alpha is calculated) 0077 void SetPruneStrength( Float_t alpha = -1.0 ); 0078 0079 void Optimize( ); 0080 0081 // return the list of pruning locations to define the optimal subtree T' of T_max 0082 std::vector<TMVA::DecisionTreeNode*> GetOptimalPruneSequence( ) const; 0083 0084 // return the quality index from the validation sample for the optimal subtree T' 0085 inline Float_t GetOptimalQualityIndex( ) const { return (fOptimalK >= 0 && fQualityIndexList.size() > 0 ? 0086 fQualityIndexList[fOptimalK] : -1.0); } 0087 0088 // return the prune strength (=alpha) corresponding to the prune sequence 0089 inline Float_t GetOptimalPruneStrength( ) const { return (fOptimalK >= 0 && fPruneStrengthList.size() > 0 ? 0090 fPruneStrengthList[fOptimalK] : -1.0); } 0091 0092 private: 0093 Float_t fAlpha; ///<! regularization parameter in CC pruning 0094 const EventList* fValidationSample; ///<! the event sample to select the optimally-pruned tree 0095 const DataSet* fValidationDataSet; ///<! the event sample to select the optimally-pruned tree 0096 SeparationBase* fQualityIndex; ///<! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) } 0097 Bool_t fOwnQIndex; ///<! flag indicates if fQualityIndex is owned by this 0098 0099 DecisionTree* fTree; ///<! (pruned) decision tree 0100 0101 std::vector<TMVA::DecisionTreeNode*> fPruneSequence; ///<! map of weakest links (i.e., branches to prune) -> pruning index 0102 std::vector<Float_t> fPruneStrengthList; ///<! map of alpha -> pruning index 0103 std::vector<Float_t> fQualityIndexList; ///<! map of R(T) -> pruning index 0104 0105 Int_t fOptimalK; ///<! index of the optimal tree in the pruned tree sequence 0106 Bool_t fDebug; ///<! debug flag 0107 }; 0108 } 0109 0110 inline void TMVA::CCPruner::SetPruneStrength( Float_t alpha ) { 0111 fAlpha = (alpha > 0 ? alpha : 0.0); 0112 } 0113 0114 0115 #endif 0116 0117
[ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |