File indexing completed on 2025-01-18 10:10:57
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024 #ifndef ROOT_TMVA_CCTreeWrapper
0025 #define ROOT_TMVA_CCTreeWrapper
0026
0027 #include "TMVA/Event.h"
0028 #include "TMVA/SeparationBase.h"
0029 #include "TMVA/DecisionTree.h"
0030 #include "TMVA/DataSet.h"
0031 #include "TMVA/Version.h"
0032 #include <vector>
0033 #include <string>
0034 #include <sstream>
0035
0036 namespace TMVA {
0037
0038 class CCTreeWrapper {
0039
0040 public:
0041
0042 typedef std::vector<Event*> EventList;
0043
0044
0045
0046
0047
0048
0049 class CCTreeNode : virtual public Node {
0050
0051 public:
0052
0053 CCTreeNode( DecisionTreeNode* n = nullptr );
0054 virtual ~CCTreeNode( );
0055
0056 virtual Node* CreateNode() const { return new CCTreeNode(); }
0057
0058
0059 inline void SetNLeafDaughters( Int_t N ) { fNLeafDaughters = (N > 0 ? N : 0); }
0060
0061
0062 inline Int_t GetNLeafDaughters() const { return fNLeafDaughters; }
0063
0064
0065 inline void SetNodeResubstitutionEstimate( Double_t R ) { fNodeResubstitutionEstimate = (R >= 0 ? R : 0.0); }
0066
0067
0068 inline Double_t GetNodeResubstitutionEstimate( ) const { return fNodeResubstitutionEstimate; }
0069
0070
0071
0072 inline void SetResubstitutionEstimate( Double_t R ) { fResubstitutionEstimate = (R >= 0 ? R : 0.0); }
0073
0074
0075 inline Double_t GetResubstitutionEstimate( ) const { return fResubstitutionEstimate; }
0076
0077
0078
0079
0080
0081
0082 inline void SetAlphaC( Double_t alpha ) { fAlphaC = alpha; }
0083
0084
0085 inline Double_t GetAlphaC( ) const { return fAlphaC; }
0086
0087
0088 inline void SetMinAlphaC( Double_t alpha ) { fMinAlphaC = alpha; }
0089
0090
0091 inline Double_t GetMinAlphaC( ) const { return fMinAlphaC; }
0092
0093
0094 inline DecisionTreeNode* GetDTNode( ) const { return fDTNode; }
0095
0096
0097 inline CCTreeNode* GetLeftDaughter( ) { return dynamic_cast<CCTreeNode*>(GetLeft()); }
0098 inline CCTreeNode* GetRightDaughter( ) { return dynamic_cast<CCTreeNode*>(GetRight()); }
0099 inline CCTreeNode* GetMother( ) { return dynamic_cast<CCTreeNode*>(GetParent()); }
0100
0101
0102 virtual void Print( std::ostream& os ) const;
0103
0104
0105 virtual void PrintRec ( std::ostream& os ) const;
0106
0107 virtual void AddAttributesToNode(void* node) const;
0108 virtual void AddContentToNode(std::stringstream& s) const;
0109
0110
0111
0112 inline virtual Bool_t GoesRight( const Event& e ) const { return GetDTNode() ?
0113 GetDTNode()->GoesRight(e) : false; }
0114
0115
0116 inline virtual Bool_t GoesLeft ( const Event& e ) const { return GetDTNode() ?
0117 GetDTNode()->GoesLeft(e) : false; }
0118
0119 virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
0120 virtual void ReadContent(std::stringstream& s);
0121 virtual Bool_t ReadDataRecord( std::istream& in, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
0122
0123 private:
0124
0125 Int_t fNLeafDaughters;
0126 Double_t fNodeResubstitutionEstimate;
0127 Double_t fResubstitutionEstimate;
0128 Double_t fAlphaC;
0129 Double_t fMinAlphaC;
0130 DecisionTreeNode* fDTNode;
0131 };
0132
0133 CCTreeWrapper( DecisionTree* T, SeparationBase* qualityIndex );
0134 ~CCTreeWrapper( );
0135
0136
0137 Double_t CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf = false );
0138
0139 Double_t TestTreeQuality( const EventList* validationSample );
0140 Double_t TestTreeQuality( const DataSet* validationSample );
0141
0142
0143 void PruneNode( CCTreeNode* t );
0144
0145 void InitTree( CCTreeNode* t );
0146
0147
0148 CCTreeNode* GetRoot() { return fRoot; }
0149 private:
0150 SeparationBase* fQualityIndex;
0151 DecisionTree* fDTParent;
0152 CCTreeNode* fRoot;
0153 };
0154
0155 }
0156
0157 #endif
0158
0159
0160