|
||||
File indexing completed on 2025-01-18 10:10:58
0001 /********************************************************************************** 0002 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * 0003 * Package: TMVA * 0004 * Class : TMVA::DecisionTree * 0005 * * 0006 * * 0007 * Description: * 0008 * Implementation of a Decision Tree * 0009 * * 0010 * Authors (alphabetical): * 0011 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland * 0012 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany * 0013 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada * 0014 * Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada * 0015 * * 0016 * Copyright (c) 2005: * 0017 * CERN, Switzerland * 0018 * U. of Victoria, Canada * 0019 * MPI-K Heidelberg, Germany * 0020 * * 0021 * Redistribution and use in source and binary forms, with or without * 0022 * modification, are permitted according to the terms listed in LICENSE * 0023 * (http://mva.sourceforge.net/license.txt) * 0024 * * 0025 **********************************************************************************/ 0026 0027 #ifndef ROOT_TMVA_CostComplexityPruneTool 0028 #define ROOT_TMVA_CostComplexityPruneTool 0029 0030 //////////////////////////////////////////////////////////////////////////////////////////////////////////// 0031 // CostComplexityPruneTool - a class to prune a decision tree using the Cost Complexity method // 0032 // (see "Classification and Regression Trees" by Leo Breiman et al) // 0033 // // 0034 // Some definitions: // 0035 // // 0036 // T_max - the initial, usually highly overtrained tree, that is to be pruned back // 0037 // R(T) - quality index (Gini, misclassification rate, or other) of a tree T // 0038 // ~T - set of terminal nodes in T // 0039 // T' - the pruned subtree of T_max that has the best quality index R(T') // 0040 // alpha - the prune strength parameter in Cost Complexity pruning (R_alpha(T) = R(T) + alpha*|~T|) // 0041 // // 0042 // There are two running modes in CostComplexityPruneTool: (i) one may select a prune strength and prune // 0043 // the tree T_max until the criterion // 0044 // R(T) - R(t) // 0045 // alpha < ---------- // 0046 // |~T_t| - 1 // 0047 // // 0048 // is true for all nodes t in T, or (ii) the algorithm finds the sequence of critical points // 0049 // alpha_k < alpha_k+1 ... < alpha_K such that T_K = root(T_max) and then selects the optimally-pruned // 0050 // subtree, defined to be the subtree with the best quality index for the validation sample. // 0051 //////////////////////////////////////////////////////////////////////////////////////////////////////////// 0052 0053 #include "TMVA/SeparationBase.h" 0054 #include "TMVA/GiniIndex.h" 0055 #include "TMVA/DecisionTree.h" 0056 #include "TMVA/Event.h" 0057 #include "TMVA/IPruneTool.h" 0058 #include <vector> 0059 0060 namespace TMVA { 0061 0062 class CostComplexityPruneTool : public IPruneTool { 0063 public: 0064 CostComplexityPruneTool( SeparationBase* qualityIndex = nullptr ); 0065 virtual ~CostComplexityPruneTool( ); 0066 0067 // calculate the prune sequence for a given tree 0068 virtual PruningInfo* CalculatePruningInfo( DecisionTree* dt, const IPruneTool::EventSample* testEvents = nullptr, Bool_t isAutomatic = kFALSE ); 0069 0070 private: 0071 SeparationBase* fQualityIndexTool; ///<! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) } 0072 0073 std::vector<DecisionTreeNode*> fPruneSequence; ///<! map of weakest links (i.e., branches to prune) -> pruning index 0074 std::vector<Double_t> fPruneStrengthList; ///<! map of alpha -> pruning index 0075 std::vector<Double_t> fQualityIndexList; ///<! map of R(T) -> pruning index 0076 0077 Int_t fOptimalK; ///<! the optimal index of the prune sequence 0078 0079 private: 0080 // set the meta data used for cost complexity pruning 0081 void InitTreePruningMetaData( DecisionTreeNode* n ); 0082 0083 // optimize the pruning sequence 0084 void Optimize( DecisionTree* dt, Double_t weights ); 0085 0086 mutable MsgLogger* fLogger; //! output stream to save logging information 0087 MsgLogger& Log() const { return *fLogger; } 0088 0089 }; 0090 } 0091 0092 0093 #endif
[ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |