|
||||
File indexing completed on 2025-01-18 10:10:59
0001 /********************************************************************************** 0002 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * 0003 * Package: TMVA * 0004 * Class : TMVA::DecisionTree * 0005 * * 0006 * * 0007 * Description: * 0008 * Implementation of a Decision Tree * 0009 * * 0010 * Authors (alphabetical): * 0011 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland * 0012 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany * 0013 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada * 0014 * Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada * 0015 * * 0016 * Copyright (c) 2005: * 0017 * CERN, Switzerland * 0018 * U. of Victoria, Canada * 0019 * MPI-K Heidelberg, Germany * 0020 * * 0021 * Redistribution and use in source and binary forms, with or without * 0022 * modification, are permitted according to the terms listed in LICENSE * 0023 * (http://mva.sourceforge.net/license.txt) * 0024 * * 0025 **********************************************************************************/ 0026 0027 #ifndef ROOT_TMVA_ExpectedErrorPruneTool 0028 #define ROOT_TMVA_ExpectedErrorPruneTool 0029 0030 ///////////////////////////////////////////////////////////////////////////////////////////////////////////// 0031 // ExpectedErrorPruneTool - a helper class to prune a decision tree using the expected error (C4.5) method // 0032 // // 0033 // Uses an upper limit on the error made by the classification done by each node. If the S/S+B of the node // 0034 // is f, then according to the training sample, the error rate (fraction of misclassified events by this // 0035 // node) is (1-f). Now f has a statistical error according to the binomial distribution hence the error on // 0036 // f can be estimated (same error as the binomial error for efficiency calculations // 0037 // ( sigma = sqrt(eff(1-eff)/nEvts ) ) // 0038 // // 0039 // This tool prunes branches from a tree if the expected error of a node is less than that of the sum of // 0040 // the error in its descendants. // 0041 // // 0042 ///////////////////////////////////////////////////////////////////////////////////////////////////////////// 0043 0044 #include <vector> 0045 0046 #include "TMath.h" 0047 0048 #include "TMVA/IPruneTool.h" 0049 0050 namespace TMVA { 0051 0052 class MsgLogger; 0053 0054 class ExpectedErrorPruneTool : public IPruneTool { 0055 public: 0056 ExpectedErrorPruneTool( ); 0057 virtual ~ExpectedErrorPruneTool( ); 0058 0059 // returns the PruningInfo object for a given tree and test sample 0060 virtual PruningInfo* CalculatePruningInfo( DecisionTree* dt, const IPruneTool::EventSample* testEvents = nullptr, 0061 Bool_t isAutomatic = kFALSE ); 0062 0063 public: 0064 // set the increment dalpha with which to scan for the optimal prune strength 0065 inline void SetPruneStrengthIncrement( Double_t dalpha ) { fDeltaPruneStrength = dalpha; } 0066 0067 private: 0068 void FindListOfNodes( DecisionTreeNode* node ); 0069 Double_t GetNodeError( DecisionTreeNode* node ) const; 0070 Double_t GetSubTreeError( DecisionTreeNode* node ) const; 0071 Int_t CountNodes( DecisionTreeNode* node, Int_t icount = 0 ); 0072 0073 Double_t fDeltaPruneStrength; ///<! the stepsize for optimizing the pruning strength parameter 0074 Double_t fNodePurityLimit; ///<! the purity limit for labelling a terminal node as signal 0075 std::vector<DecisionTreeNode*> fPruneSequence; ///<! the (optimal) prune sequence 0076 // std::multimap<const Double_t, Double_t> fQualityMap; ///<! map of tree quality <=> prune strength 0077 mutable MsgLogger* fLogger; ///<! message logger 0078 MsgLogger& Log() const { return *fLogger; } 0079 }; 0080 0081 inline Int_t ExpectedErrorPruneTool::CountNodes( DecisionTreeNode* node, Int_t icount ) { 0082 DecisionTreeNode* l = (DecisionTreeNode*)node->GetLeft(); 0083 DecisionTreeNode* r = (DecisionTreeNode*)node->GetRight(); 0084 Int_t counter = icount + 1; // count this node 0085 if(!node->IsTerminal() && l && r) { 0086 counter = CountNodes(l,counter); 0087 counter = CountNodes(r,counter); 0088 } 0089 return counter; 0090 } 0091 } 0092 0093 #endif 0094
[ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |