Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:51

0001 // @(#)root/tmva $Id$
0002 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : DecisionTreeNode                                                      *
0008  *                                                                                *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Node for the Decision Tree                                                *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
0015  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
0016  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
0017  *      Eckhard von Toerne <evt@physik.uni-bonn.de>  - U. of Bonn, Germany        *
0018  *                                                                                *
0019  * Copyright (c) 2009:                                                            *
0020  *      CERN, Switzerland                                                         *
0021  *      U. of Victoria, Canada                                                    *
0022  *      MPI-K Heidelberg, Germany                                                 *
0023  *      U. of Bonn, Germany                                                        *
0024  *                                                                                *
0025  * Redistribution and use in source and binary forms, with or without             *
0026  * modification, are permitted according to the terms listed in LICENSE           *
0027  * (see tmva/doc/LICENSE)                                                          *
0028  **********************************************************************************/
0029 
0030 #ifndef ROOT_TMVA_DecisionTreeNode
0031 #define ROOT_TMVA_DecisionTreeNode
0032 
0033 //////////////////////////////////////////////////////////////////////////
0034 //                                                                      //
0035 // DecisionTreeNode                                                     //
0036 //                                                                      //
0037 // Node for the Decision Tree                                           //
0038 //                                                                      //
0039 //////////////////////////////////////////////////////////////////////////
0040 
0041 #include "TMVA/Node.h"
0042 
0043 #include "TMVA/Version.h"
0044 
0045 #include <sstream>
0046 #include <vector>
0047 #include <string>
0048 
0049 namespace TMVA {
0050 
0051    class DTNodeTrainingInfo
0052    {
0053    public:
0054    DTNodeTrainingInfo():fSampleMin(),
0055          fSampleMax(),
0056          fNodeR(0),fSubTreeR(0),fAlpha(0),fG(0),fNTerminal(0),
0057          fNB(0),fNS(0),fSumTarget(0),fSumTarget2(0),fCC(0),
0058          fNSigEvents ( 0 ), fNBkgEvents ( 0 ),
0059          fNEvents ( -1 ),
0060          fNSigEvents_unweighted ( 0 ),
0061          fNBkgEvents_unweighted ( 0 ),
0062          fNEvents_unweighted ( 0 ),
0063          fNSigEvents_unboosted ( 0 ),
0064          fNBkgEvents_unboosted ( 0 ),
0065          fNEvents_unboosted ( 0 ),
0066          fSeparationIndex (-1 ),
0067          fSeparationGain ( -1 )
0068             {
0069             }
0070       std::vector< Float_t >  fSampleMin; ///< the minima for each ivar of the sample on the node during training
0071       std::vector< Float_t >  fSampleMax; ///< the maxima for each ivar of the sample on the node during training
0072       Double_t fNodeR;                    ///< node resubstitution estimate, R(t)
0073       Double_t fSubTreeR;                 ///< R(T) = Sum(R(t) : t in ~T)
0074       Double_t fAlpha;                    ///< critical alpha for this node
0075       Double_t fG;                        ///< minimum alpha in subtree rooted at this node
0076       Int_t    fNTerminal;                ///< number of terminal nodes in subtree rooted at this node
0077       Double_t fNB;                       ///< sum of weights of background events from the pruning sample in this node
0078       Double_t fNS;                       ///< ditto for the signal events
0079       Float_t  fSumTarget;                ///< sum of weight*target  used for the calculation of the variance (regression)
0080       Float_t  fSumTarget2;               ///< sum of weight*target^2 used for the calculation of the variance (regression)
0081       Double_t fCC;                       ///< debug variable for cost complexity pruning ..
0082 
0083       Float_t  fNSigEvents;               ///< sum of weights of signal event in the node
0084       Float_t  fNBkgEvents;               ///< sum of weights of backgr event in the node
0085       Float_t  fNEvents;                  ///< number of events in that entered the node (during training)
0086       Float_t  fNSigEvents_unweighted;    ///< sum of signal event in the node
0087       Float_t  fNBkgEvents_unweighted;    ///< sum of backgr event in the node
0088       Float_t  fNEvents_unweighted;       ///< number of events in that entered the node (during training)
0089       Float_t  fNSigEvents_unboosted;     ///< sum of signal event in the node
0090       Float_t  fNBkgEvents_unboosted;     ///< sum of backgr event in the node
0091       Float_t  fNEvents_unboosted;        ///< number of events in that entered the node (during training)
0092       Float_t  fSeparationIndex;          ///< measure of "purity" (separation between S and B) AT this node
0093       Float_t  fSeparationGain;           ///< measure of "purity", separation, or information gained BY this nodes selection
0094 
0095       // copy constructor
0096    DTNodeTrainingInfo(const DTNodeTrainingInfo& n) :
0097       fSampleMin(),fSampleMax(),          ///< Samplemin and max are reset in copy constructor
0098          fNodeR(n.fNodeR), fSubTreeR(n.fSubTreeR),
0099          fAlpha(n.fAlpha), fG(n.fG),
0100          fNTerminal(n.fNTerminal),
0101          fNB(n.fNB), fNS(n.fNS),
0102          fSumTarget(0),fSumTarget2(0),    ///< SumTarget reset in copy constructor
0103          fCC(0),
0104          fNSigEvents ( n.fNSigEvents ), fNBkgEvents ( n.fNBkgEvents ),
0105          fNEvents ( n.fNEvents ),
0106          fNSigEvents_unweighted ( n.fNSigEvents_unweighted ),
0107          fNBkgEvents_unweighted ( n.fNBkgEvents_unweighted ),
0108          fNEvents_unweighted ( n.fNEvents_unweighted ),
0109          fSeparationIndex( n.fSeparationIndex ),
0110          fSeparationGain ( n.fSeparationGain )
0111             { }
0112    };
0113 
0114    class Event;
0115    class MsgLogger;
0116 
0117    class DecisionTreeNode: public Node {
0118 
0119    public:
0120 
0121       // constructor of an essentially "empty" node floating in space
0122       DecisionTreeNode ();
0123       // constructor of a daughter node as a daughter of 'p'
0124       DecisionTreeNode (Node* p, char pos);
0125 
0126       // copy constructor
0127       DecisionTreeNode (const DecisionTreeNode &n, DecisionTreeNode* parent = nullptr);
0128 
0129       // destructor
0130       virtual ~DecisionTreeNode();
0131 
0132       virtual Node* CreateNode() const { return new DecisionTreeNode(); }
0133 
0134       inline void SetNFisherCoeff(Int_t nvars){fFisherCoeff.resize(nvars);}
0135       inline UInt_t GetNFisherCoeff() const { return fFisherCoeff.size();}
0136       // set fisher coefficients
0137       void SetFisherCoeff(Int_t ivar, Double_t coeff);
0138       /// get fisher coefficients
0139       Double_t GetFisherCoeff(Int_t ivar) const {return fFisherCoeff.at(ivar);}
0140 
0141       // test event if it descends the tree at this node to the right
0142       virtual Bool_t GoesRight( const Event & ) const;
0143 
0144       // test event if it descends the tree at this node to the left
0145       virtual Bool_t GoesLeft ( const Event & ) const;
0146 
0147       /// set index of variable used for discrimination at this node
0148       void SetSelector( Short_t i) { fSelector = i; }
0149       /// return index of variable used for discrimination at this node
0150       Short_t GetSelector() const { return fSelector; }
0151 
0152       /// set the cut value applied at this node
0153       void  SetCutValue ( Float_t c ) { fCutValue  = c; }
0154       /// return the cut value applied at this node
0155       Float_t GetCutValue ( void ) const { return fCutValue;  }
0156 
0157       /// set true: if event variable > cutValue ==> signal , false otherwise
0158       void SetCutType( Bool_t t   ) { fCutType = t; }
0159       /// return kTRUE: Cuts select signal, kFALSE: Cuts select bkg
0160       Bool_t GetCutType( void ) const { return fCutType; }
0161 
0162       /// set node type: 1 signal node, -1 bkg leave, 0 intermediate Node
0163       void  SetNodeType( Int_t t ) { fNodeType = t;}
0164       /// return node type: 1 signal node, -1 bkg leave, 0 intermediate Node
0165       Int_t GetNodeType( void ) const { return fNodeType; }
0166 
0167       /// return  S/(S+B) (purity) at this node (from  training)
0168       Float_t GetPurity( void ) const { return fPurity;}
0169       // calculate S/(S+B) (purity) at this node (from  training)
0170       void SetPurity( void );
0171 
0172       /// set the response of the node (for regression)
0173       void SetResponse( Float_t r ) { fResponse = r;}
0174 
0175       /// return the response of the node (for regression)
0176       Float_t GetResponse( void ) const { return fResponse;}
0177 
0178       /// set the RMS of the response of the node (for regression)
0179       void SetRMS( Float_t r ) { fRMS = r;}
0180 
0181       /// return the RMS of the response of the node (for regression)
0182       Float_t GetRMS( void ) const { return fRMS;}
0183 
0184       /// set the sum of the signal weights in the node, if traininfo defined
0185       void SetNSigEvents( Float_t s ) { if(fTrainInfo) fTrainInfo->fNSigEvents = s; }
0186 
0187       /// set the sum of the backgr weights in the node, if traininfo defined
0188       void SetNBkgEvents( Float_t b ) { if(fTrainInfo) fTrainInfo->fNBkgEvents = b; }
0189 
0190       /// set the number of events that entered the node (during training), if traininfo defined
0191       void SetNEvents( Float_t nev ){ if(fTrainInfo) fTrainInfo->fNEvents =nev ; }
0192 
0193       /// set the sum of the unweighted signal events in the node, if traininfo defined
0194       void SetNSigEvents_unweighted( Float_t s ) { if(fTrainInfo) fTrainInfo->fNSigEvents_unweighted = s; }
0195 
0196       /// set the sum of the unweighted backgr events in the node, if traininfo defined
0197       void SetNBkgEvents_unweighted( Float_t b ) { if(fTrainInfo) fTrainInfo->fNBkgEvents_unweighted = b; }
0198 
0199       /// set the number of unweighted events that entered the node (during training), if traininfo defined
0200       void SetNEvents_unweighted( Float_t nev ){ if(fTrainInfo) fTrainInfo->fNEvents_unweighted =nev ; }
0201 
0202       /// set the sum of the unboosted signal events in the node, if traininfo defined
0203       void SetNSigEvents_unboosted( Float_t s ) { if(fTrainInfo) fTrainInfo->fNSigEvents_unboosted = s; }
0204 
0205       /// set the sum of the unboosted backgr events in the node, if traininfo defined
0206       void SetNBkgEvents_unboosted( Float_t b ) { if(fTrainInfo) fTrainInfo->fNBkgEvents_unboosted = b; }
0207 
0208       /// set the number of unboosted events that entered the node (during training), if traininfo defined
0209       void SetNEvents_unboosted( Float_t nev ){ if(fTrainInfo) fTrainInfo->fNEvents_unboosted =nev ; }
0210 
0211       /// increment the sum of the signal weights in the node, if traininfo defined
0212       void IncrementNSigEvents( Float_t s ) { if(fTrainInfo) fTrainInfo->fNSigEvents += s; }
0213 
0214       /// increment the sum of the backgr weights in the node, if traininfo defined
0215       void IncrementNBkgEvents( Float_t b ) { if(fTrainInfo) fTrainInfo->fNBkgEvents += b; }
0216 
0217       // increment the number of events that entered the node (during training), if traininfo defined
0218       void IncrementNEvents( Float_t nev ){ if(fTrainInfo) fTrainInfo->fNEvents +=nev ; }
0219 
0220       /// increment the sum of the signal weights in the node, if traininfo defined
0221       void IncrementNSigEvents_unweighted( ) { if(fTrainInfo) fTrainInfo->fNSigEvents_unweighted += 1; }
0222 
0223       /// increment the sum of the backgr weights in the node, if traininfo defined
0224       void IncrementNBkgEvents_unweighted( ) { if(fTrainInfo) fTrainInfo->fNBkgEvents_unweighted += 1; }
0225 
0226       /// increment the number of events that entered the node (during training), if traininfo defined
0227       void IncrementNEvents_unweighted( ){ if(fTrainInfo) fTrainInfo->fNEvents_unweighted +=1 ; }
0228 
0229       /// return the sum of the signal weights in the node, or -1 if traininfo undefined
0230       Float_t GetNSigEvents( void ) const  { return fTrainInfo ? fTrainInfo->fNSigEvents : -1.; }
0231 
0232       /// return the sum of the backgr weights in the node, or -1 if traininfo undefined
0233       Float_t GetNBkgEvents( void ) const  { return fTrainInfo ? fTrainInfo->fNBkgEvents : -1.; }
0234 
0235       /// return  the number of events that entered the node (during training), or -1 if traininfo undefined
0236       Float_t GetNEvents( void ) const  { return fTrainInfo ? fTrainInfo->fNEvents : -1.; }
0237 
0238       // return the sum of unweighted signal weights in the node, or -1 if traininfo undefined
0239       Float_t GetNSigEvents_unweighted( void ) const  { return fTrainInfo ? fTrainInfo->fNSigEvents_unweighted : -1.; }
0240 
0241       /// return the sum of unweighted backgr weights in the node, or -1 if traininfo undefined
0242       Float_t GetNBkgEvents_unweighted( void ) const  { return fTrainInfo ? fTrainInfo->fNBkgEvents_unweighted : -1.; }
0243 
0244       /// return  the number of unweighted events that entered the node (during training), or -1 if traininfo undefined
0245       Float_t GetNEvents_unweighted( void ) const  { return fTrainInfo ? fTrainInfo->fNEvents_unweighted : -1.; }
0246 
0247       /// return the sum of unboosted signal weights in the node, or -1 if traininfo undefined
0248       Float_t GetNSigEvents_unboosted( void ) const  { return fTrainInfo ? fTrainInfo->fNSigEvents_unboosted : -1.; }
0249 
0250       /// return the sum of unboosted backgr weights in the node, or -1 if traininfo undefined
0251       Float_t GetNBkgEvents_unboosted( void ) const  { return fTrainInfo ? fTrainInfo->fNBkgEvents_unboosted : -1.; }
0252 
0253       /// return  the number of unboosted events that entered the node (during training), or -1 if traininfo undefined
0254       Float_t GetNEvents_unboosted( void ) const  { return fTrainInfo ? fTrainInfo->fNEvents_unboosted : -1.; }
0255 
0256       /// set the chosen index, measure of "purity" (separation between S and B) AT this node, if traininfo defined
0257       void SetSeparationIndex( Float_t sep ){ if(fTrainInfo) fTrainInfo->fSeparationIndex =sep ; }
0258 
0259       /// return the separation index AT this node, or 0 if traininfo undefined
0260       Float_t GetSeparationIndex( void ) const  { return fTrainInfo ? fTrainInfo->fSeparationIndex : -1.; }
0261 
0262       /// set the separation, or information gained BY this node's selection, if traininfo defined
0263       void SetSeparationGain( Float_t sep ){ if(fTrainInfo) fTrainInfo->fSeparationGain =sep ; }
0264 
0265       /// return the gain in separation obtained by this node's selection, or -1 if traininfo undefined
0266       Float_t GetSeparationGain( void ) const  { return fTrainInfo ? fTrainInfo->fSeparationGain : -1.; }
0267 
0268       // printout of the node
0269       virtual void Print( std::ostream& os ) const;
0270 
0271       // recursively print the node and its daughters (--> print the 'tree')
0272       virtual void PrintRec( std::ostream&  os ) const;
0273 
0274       virtual void AddAttributesToNode(void* node) const;
0275       virtual void AddContentToNode(std::stringstream& s) const;
0276 
0277       // recursively clear the nodes content (S/N etc, but not the cut criteria)
0278       void ClearNodeAndAllDaughters();
0279 
0280       // get pointers to children, mother in the tree
0281 
0282       // return pointer to the left/right daughter or parent node
0283       inline virtual DecisionTreeNode* GetLeft( )   const { return static_cast<DecisionTreeNode*>(fLeft); }
0284       inline virtual DecisionTreeNode* GetRight( )  const { return static_cast<DecisionTreeNode*>(fRight); }
0285       inline virtual DecisionTreeNode* GetParent( ) const { return static_cast<DecisionTreeNode*>(fParent); }
0286 
0287       // set pointer to the left/right daughter and parent node
0288       inline virtual void SetLeft  (Node* l) { fLeft   = l;}
0289       inline virtual void SetRight (Node* r) { fRight  = r;}
0290       inline virtual void SetParent(Node* p) { fParent = p;}
0291 
0292       /// set the node resubstitution estimate, R(t), for Cost Complexity pruning, if traininfo defined
0293       inline void SetNodeR( Double_t r ) { if(fTrainInfo) fTrainInfo->fNodeR = r;    }
0294       /// return the node resubstitution estimate, R(t), for Cost Complexity pruning, or -1 if traininfo undefined
0295       inline Double_t GetNodeR( ) const  { return fTrainInfo ? fTrainInfo->fNodeR : -1.; }
0296 
0297       /// set the resubstitution estimate, R(T_t), of the tree rooted at this node, if traininfo defined
0298       inline void SetSubTreeR( Double_t r ) { if(fTrainInfo) fTrainInfo->fSubTreeR = r;    }
0299       /// return the resubstitution estimate, R(T_t), of the tree rooted at this node, or -1 if traininfo undefined
0300       inline Double_t GetSubTreeR( ) const  { return fTrainInfo ? fTrainInfo->fSubTreeR : -1.; }
0301 
0302       //                             R(t) - R(T_t)
0303       // the critical point alpha =  -------------
0304       //                              |~T_t| - 1
0305       /// set the critical point alpha, if traininfo defined
0306       inline void SetAlpha( Double_t alpha ) { if(fTrainInfo) fTrainInfo->fAlpha = alpha; }
0307       /// return the critical point alpha, or -1 if traininfo undefined
0308       inline Double_t GetAlpha( ) const      { return fTrainInfo ? fTrainInfo->fAlpha : -1.;  }
0309 
0310       /// set the minimum alpha in the tree rooted at this node, if traininfo defined
0311       inline void SetAlphaMinSubtree( Double_t g ) { if(fTrainInfo) fTrainInfo->fG = g;    }
0312       /// return the minimum alpha in the tree rooted at this node, or -1 if traininfo undefined
0313       inline Double_t GetAlphaMinSubtree( ) const  { return fTrainInfo ? fTrainInfo->fG : -1.; }
0314 
0315       /// set number of terminal nodes in the subtree rooted here, if traininfo defined
0316       inline void SetNTerminal( Int_t n ) { if(fTrainInfo) fTrainInfo->fNTerminal = n;    }
0317       /// return number of terminal nodes in the subtree rooted here, or -1 if traininfo undefined
0318       inline Int_t GetNTerminal( ) const  { return fTrainInfo ? fTrainInfo->fNTerminal : -1.; }
0319 
0320       /// set number of background events from the pruning validation sample, if traininfo defined
0321       inline void SetNBValidation( Double_t b ) { if(fTrainInfo) fTrainInfo->fNB = b; }
0322       /// set number of signal events from the pruning validation sample, if traininfo defined
0323       inline void SetNSValidation( Double_t s ) { if(fTrainInfo) fTrainInfo->fNS = s; }
0324       /// return number of background events from the pruning validation sample, or -1 if traininfo undefined
0325       inline Double_t GetNBValidation( ) const  { return fTrainInfo ? fTrainInfo->fNB : -1.; }
0326       /// return number of signal events from the pruning validation sample, or -1 if traininfo undefined
0327       inline Double_t GetNSValidation( ) const  { return fTrainInfo ? fTrainInfo->fNS : -1.; }
0328 
0329       /// set sum target, if traininfo defined
0330       inline void SetSumTarget(Float_t t)  {if(fTrainInfo) fTrainInfo->fSumTarget = t; }
0331       /// set sum target 2, if traininfo defined
0332       inline void SetSumTarget2(Float_t t2){if(fTrainInfo) fTrainInfo->fSumTarget2 = t2; }
0333 
0334       /// add to sum target, if traininfo defined
0335       inline void AddToSumTarget(Float_t t)  {if(fTrainInfo) fTrainInfo->fSumTarget += t; }
0336       /// add to sum target 2, if traininfo defined
0337       inline void AddToSumTarget2(Float_t t2){if(fTrainInfo) fTrainInfo->fSumTarget2 += t2; }
0338 
0339       /// return sum target, or -9999 if traininfo undefined
0340       inline Float_t GetSumTarget()  const {return fTrainInfo? fTrainInfo->fSumTarget : -9999;}
0341       /// return sum target 2, or -9999 if traininfo undefined
0342       inline Float_t GetSumTarget2() const {return fTrainInfo? fTrainInfo->fSumTarget2: -9999;}
0343 
0344 
0345       // reset the pruning validation data
0346       void ResetValidationData( );
0347 
0348       /// flag indicates whether this node is terminal
0349       inline Bool_t IsTerminal() const            { return fIsTerminalNode; }
0350       inline void SetTerminal( Bool_t s = kTRUE ) { fIsTerminalNode = s;    }
0351       void PrintPrune( std::ostream& os ) const ;
0352       void PrintRecPrune( std::ostream& os ) const;
0353 
0354       void     SetCC(Double_t cc);
0355       /// return CC, or -1 if traininfo undefined
0356       Double_t GetCC() const {return (fTrainInfo? fTrainInfo->fCC : -1.);}
0357 
0358       Float_t GetSampleMin(UInt_t ivar) const;
0359       Float_t GetSampleMax(UInt_t ivar) const;
0360       void     SetSampleMin(UInt_t ivar, Float_t xmin);
0361       void     SetSampleMax(UInt_t ivar, Float_t xmax);
0362 
0363       static void SetIsTraining(bool on);
0364       static void SetTmvaVersionCode(UInt_t code);
0365 
0366       static bool IsTraining();
0367       static UInt_t GetTmvaVersionCode();
0368 
0369       virtual Bool_t ReadDataRecord( std::istream& is, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
0370       virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
0371       virtual void ReadContent(std::stringstream& s);
0372 
0373    protected:
0374 
0375       static MsgLogger& Log();
0376 
0377       static bool fgIsTraining;           ///< static variable to flag training phase in which we need fTrainInfo
0378       static UInt_t fgTmva_Version_Code;  ///< set only when read from weightfile
0379 
0380       std::vector<Double_t> fFisherCoeff; ///< the fisher coeff (offset at the last element)
0381 
0382       Float_t  fCutValue;                 ///< cut value applied on this node to discriminate bkg against sig
0383       Bool_t   fCutType;                  ///< true: if event variable > cutValue ==> signal , false otherwise
0384       Short_t  fSelector;                 ///< index of variable used in node selection (decision tree)
0385 
0386       Float_t  fResponse;                 ///< response value in case of regression
0387       Float_t  fRMS;                      ///< response RMS of the regression node
0388       Int_t    fNodeType;                 ///< Type of node: -1 == Bkg-leaf, 1 == Signal-leaf, 0 = internal
0389       Float_t  fPurity;                   ///< the node purity
0390 
0391       Bool_t   fIsTerminalNode;           ///<! flag to set node as terminal (i.e., without deleting its descendants)
0392 
0393       mutable DTNodeTrainingInfo* fTrainInfo;
0394 
0395    private:
0396 
0397       ClassDef(DecisionTreeNode,0); // Node for the Decision Tree
0398    };
0399 } // namespace TMVA
0400 
0401 #endif