Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:10:57

0001 
0002 /**********************************************************************************
0003  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0004  * Package: TMVA                                                                  *
0005  * Class  : CCTreeWrapper                                                         *
0006  *                                             *
0007  *                                                                                *
0008  * Description: a light wrapper of a decision tree, used to perform in-place      *
0009  *              cost complexity pruning                                           *
0010  *                                                                                *
0011  * Author: Doug Schouten (dschoute@sfu.ca)                                        *
0012  *                                                                                *
0013  *                                                                                *
0014  * Copyright (c) 2007:                                                            *
0015  *      CERN, Switzerland                                                         *
0016  *      MPI-K Heidelberg, Germany                                                 *
0017  *      U. of Texas at Austin, USA                                                *
0018  *                                                                                *
0019  * Redistribution and use in source and binary forms, with or without             *
0020  * modification, are permitted according to the terms listed in LICENSE           *
0021  * (see tmva/doc/LICENSE)                                          *
0022  **********************************************************************************/
0023 
0024 #ifndef ROOT_TMVA_CCTreeWrapper
0025 #define ROOT_TMVA_CCTreeWrapper
0026 
0027 #include "TMVA/Event.h"
0028 #include "TMVA/SeparationBase.h"
0029 #include "TMVA/DecisionTree.h"
0030 #include "TMVA/DataSet.h"
0031 #include "TMVA/Version.h"
0032 #include <vector>
0033 #include <string>
0034 #include <sstream>
0035 
0036 namespace TMVA {
0037 
0038    class CCTreeWrapper {
0039 
0040    public:
0041 
0042       typedef std::vector<Event*> EventList;
0043 
0044       /////////////////////////////////////////////////////////////
0045       // CCTreeNode - a light wrapper of a decision tree node    //
0046       //                                                         //
0047       /////////////////////////////////////////////////////////////
0048 
0049       class CCTreeNode : virtual public Node {
0050 
0051       public:
0052 
0053          CCTreeNode( DecisionTreeNode* n = nullptr );
0054          virtual ~CCTreeNode( );
0055 
0056          virtual Node* CreateNode() const { return new CCTreeNode(); }
0057 
0058          // set |~T_t|, the number of terminal descendants of node t
0059          inline void SetNLeafDaughters( Int_t N ) { fNLeafDaughters = (N > 0 ? N : 0); }
0060 
0061          // return |~T_t|
0062          inline Int_t GetNLeafDaughters() const { return fNLeafDaughters; }
0063 
0064          // set R(t), the node resubstitution estimate (Gini, misclassification, etc.) for the node t
0065          inline void SetNodeResubstitutionEstimate( Double_t R ) { fNodeResubstitutionEstimate = (R >= 0 ? R : 0.0); }
0066 
0067          // return R(t) for node t
0068          inline Double_t GetNodeResubstitutionEstimate( ) const { return fNodeResubstitutionEstimate; }
0069 
0070          // set R(T_t) = sum[t' in ~T_t]{ R(t) }, the resubstitution estimate for the branch rooted at
0071          // node t (it is an estimate because it is calculated from the training dataset, i.e., the original tree)
0072          inline void SetResubstitutionEstimate( Double_t R ) { fResubstitutionEstimate = (R >= 0 ?  R : 0.0); }
0073 
0074          // return R(T_t) for node t
0075          inline Double_t GetResubstitutionEstimate( ) const { return fResubstitutionEstimate; }
0076 
0077          // set the critical point of alpha
0078          //             R(t) - R(T_t)
0079          //  alpha_c <  ------------- := g(t)
0080          //              |~T_t| - 1
0081          // which is the value of alpha such that the branch rooted at node t is pruned
0082          inline void SetAlphaC( Double_t alpha ) { fAlphaC = alpha; }
0083 
0084          // get the critical alpha value for this node
0085          inline Double_t GetAlphaC( ) const { return fAlphaC; }
0086 
0087          // set the minimum critical alpha value for descendants of node t ( G(t) = min(alpha_c, g(t_l), g(t_r)) )
0088          inline void SetMinAlphaC( Double_t alpha ) { fMinAlphaC = alpha; }
0089 
0090          // get the minimum critical alpha value
0091          inline Double_t GetMinAlphaC( ) const { return fMinAlphaC; }
0092 
0093          // get the pointer to the wrapped DT node
0094          inline DecisionTreeNode* GetDTNode( ) const { return fDTNode; }
0095 
0096          // get pointers to children, mother in the CC tree
0097          inline CCTreeNode* GetLeftDaughter( ) { return dynamic_cast<CCTreeNode*>(GetLeft()); }
0098          inline CCTreeNode* GetRightDaughter( ) { return dynamic_cast<CCTreeNode*>(GetRight()); }
0099          inline CCTreeNode* GetMother( ) { return dynamic_cast<CCTreeNode*>(GetParent()); }
0100 
0101          // printout of the node (can be read in with ReadDataRecord)
0102          virtual void Print( std::ostream& os ) const;
0103 
0104          // recursive printout of the node and its daughters
0105          virtual void PrintRec ( std::ostream& os ) const;
0106 
0107          virtual void AddAttributesToNode(void* node) const;
0108          virtual void AddContentToNode(std::stringstream& s) const;
0109 
0110 
0111          // test event if it descends the tree at this node to the right
0112          inline virtual Bool_t GoesRight( const Event& e ) const { return GetDTNode() ?
0113                                                                            GetDTNode()->GoesRight(e) : false; }
0114 
0115          // test event if it descends the tree at this node to the left
0116          inline virtual Bool_t GoesLeft ( const Event& e ) const { return GetDTNode() ?
0117                                                                            GetDTNode()->GoesLeft(e) : false; }
0118          // initialize a node from a data record
0119          virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
0120          virtual void ReadContent(std::stringstream& s);
0121          virtual Bool_t ReadDataRecord( std::istream& in, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
0122 
0123       private:
0124 
0125          Int_t fNLeafDaughters; //! number of terminal descendants
0126          Double_t fNodeResubstitutionEstimate; //! R(t) = misclassification rate for node t
0127          Double_t fResubstitutionEstimate; //! R(T_t) = sum[t' in ~T_t]{ R(t) }
0128          Double_t fAlphaC; //! critical point, g(t) = alpha_c(t)
0129          Double_t fMinAlphaC; //! G(t), minimum critical point of t and its descendants
0130          DecisionTreeNode* fDTNode; //! pointer to wrapped node in the decision tree
0131       };
0132 
0133       CCTreeWrapper( DecisionTree* T,  SeparationBase* qualityIndex );
0134       ~CCTreeWrapper( );
0135 
0136       // return the decision tree output for an event
0137       Double_t CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf = false );
0138       // return the misclassification rate of a pruned tree for a validation event sample
0139       Double_t TestTreeQuality( const EventList* validationSample );
0140       Double_t TestTreeQuality( const DataSet* validationSample );
0141 
0142       // remove the branch rooted at node t
0143       void PruneNode( CCTreeNode* t );
0144       // initialize the node t and all its descendants
0145       void InitTree( CCTreeNode* t );
0146 
0147       // return the root node for this tree
0148       CCTreeNode* GetRoot() { return fRoot; }
0149    private:
0150       SeparationBase* fQualityIndex;  ///<! pointer to the used quality index calculator
0151       DecisionTree* fDTParent;        ///<! pointer to underlying DecisionTree
0152       CCTreeNode* fRoot;              ///<! the root node of the (wrapped) decision Tree
0153    };
0154 
0155 }
0156 
0157 #endif
0158 
0159 
0160