Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:23:03

0001 // @(#)root/tmva $Id$
0002 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : Rule                                                                  *
0008  *                                                                                *
0009  * Description:                                                                   *
0010  *      A class describing a 'rule'                                               *
0011  *      Each internal node of a tree defines a rule from all the parental nodes.  *
0012  *      A rule consists of at least 2 nodes.                                      *
0013  *      Input: a decision tree (in the constructor)                               *
0014  *             its coefficient                                                    *
0015  *                                                                                *
0016  *                                                                                *
0017  * Authors (alphabetical):                                                        *
0018  *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA      *
0019  *      Helge Voss         <Helge.Voss@cern.ch>         - MPI-KP Heidelberg, Ger. *
0020  *                                                                                *
0021  * Copyright (c) 2005:                                                            *
0022  *      CERN, Switzerland                                                         *
0023  *      Iowa State U.                                                             *
0024  *      MPI-K Heidelberg, Germany                                                 *
0025  *                                                                                *
0026  * Redistribution and use in source and binary forms, with or without             *
0027  * modification, are permitted according to the terms listed in LICENSE           *
0028  * (see tmva/doc/LICENSE)                                          *
0029  **********************************************************************************/
0030 
0031 #ifndef ROOT_TMVA_Rule
0032 #define ROOT_TMVA_Rule
0033 
0034 #include "TMath.h"
0035 #include <vector>
0036 #include <iostream>
0037 
0038 #include "TMVA/DecisionTree.h"
0039 #include "TMVA/Event.h"
0040 #include "TMVA/RuleCut.h"
0041 
0042 namespace TMVA {
0043 
0044    class RuleEnsemble;
0045    class MsgLogger;
0046    class Rule;
0047 
0048    std::ostream& operator<<( std::ostream& os, const Rule & rule );
0049 
0050    class Rule {
0051 
0052       // output operator for a Rule
0053       friend std::ostream& operator<< ( std::ostream& os, const Rule & rule );
0054 
0055    public:
0056 
0057       // main constructor
0058       Rule( RuleEnsemble *re, const std::vector< const TMVA::Node * > & nodes );
0059 
0060       // main constructor
0061       Rule( RuleEnsemble *re );
0062 
0063       // copy constructor
0064       Rule( const Rule & other ) { Copy( other ); }
0065 
0066       // empty constructor
0067       Rule();
0068 
0069       virtual ~Rule();
0070 
0071       // set message type
0072       void SetMsgType( EMsgType t );
0073 
0074       // set RuleEnsemble ptr
0075       void SetRuleEnsemble( const RuleEnsemble *re ) { fRuleEnsemble = re; }
0076 
0077       // set RuleCut ptr
0078       void SetRuleCut( RuleCut *rc )           { fCut = rc; }
0079 
0080       // set Rule norm
0081       void SetNorm(Double_t norm)       { fNorm = (norm>0 ? 1.0/norm:1.0); }
0082 
0083       // set coefficient
0084       void SetCoefficient(Double_t v)   { fCoefficient=v; }
0085 
0086       // set support
0087       void SetSupport(Double_t v)       { fSupport=v; fSigma = TMath::Sqrt(v*(1.0-v));}
0088 
0089       // set s/(s+b)
0090       void SetSSB(Double_t v)           { fSSB=v; }
0091 
0092       // set N(eve) accepted by rule
0093       void SetSSBNeve(Double_t v)       { fSSBNeve=v; }
0094 
0095       // set reference importance
0096       void SetImportanceRef(Double_t v) { fImportanceRef=(v>0 ? v:1.0); }
0097 
0098       // calculate importance
0099       void CalcImportance()             { fImportance = TMath::Abs(fCoefficient)*fSigma; }
0100 
0101       // get the relative importance
0102       Double_t GetRelImportance()  const { return fImportance/fImportanceRef; }
0103 
0104       // evaluate the Rule for the given Event using the coefficient
0105       //      inline Double_t EvalEvent( const Event& e, Bool_t norm ) const;
0106 
0107       // evaluate the Rule for the given Event, not using normalization or the coefficient
0108       inline Bool_t EvalEvent( const Event& e ) const;
0109 
0110       // test if two rules are equal
0111       Bool_t Equal( const Rule & other, Bool_t useCutValue, Double_t maxdist ) const;
0112 
0113       // get distance between two equal (ie apart from the cut values) rules
0114       Double_t RuleDist( const Rule & other, Bool_t useCutValue ) const;
0115 
0116       // returns true if the trained S/(S+B) of the last node is > 0.5
0117       Double_t GetSSB()       const { return fSSB; }
0118       Double_t GetSSBNeve()   const { return fSSBNeve; }
0119       Bool_t   IsSignalRule() const { return (fSSB>0.5); }
0120 
0121       // copy operator
0122       void operator=( const Rule & other )  { Copy( other ); }
0123 
0124       // identical operator
0125       Bool_t operator==( const Rule & other ) const;
0126 
0127       Bool_t operator<( const Rule & other ) const;
0128 
0129       // get number of variables used in Rule
0130       UInt_t GetNumVarsUsed() const { return fCut->GetNvars(); }
0131 
0132       // get number of cuts in Rule
0133       UInt_t GetNcuts() const { return fCut->GetNcuts(); }
0134 
0135       // check if variable is used by the rule
0136       Bool_t ContainsVariable(UInt_t iv) const;
0137 
0138       // accessors
0139       const RuleCut*      GetRuleCut()       const { return fCut; }
0140       const RuleEnsemble* GetRuleEnsemble()  const { return fRuleEnsemble; }
0141       Double_t            GetCoefficient()   const { return fCoefficient; }
0142       Double_t            GetSupport()       const { return fSupport; }
0143       Double_t            GetSigma()         const { return fSigma; }
0144       Double_t            GetNorm()          const { return fNorm; }
0145       Double_t            GetImportance()    const { return fImportance; }
0146       Double_t            GetImportanceRef() const { return fImportanceRef; }
0147 
0148       // print the rule using flogger
0149       void PrintLogger( const char *title=nullptr ) const;
0150 
0151       // print just the raw info, used for weight file generation
0152       void  PrintRaw   ( std::ostream& os  ) const; // obsolete
0153       void* AddXMLTo   ( void* parent ) const;
0154 
0155       void  ReadRaw    ( std::istream& os    ); // obsolete
0156       void  ReadFromXML( void* wghtnode );
0157 
0158    private:
0159 
0160       // set sigma - don't use this as non private!
0161       void SetSigma(Double_t v)         { fSigma=v; }
0162 
0163       // print info about the Rule
0164       void Print( std::ostream& os ) const;
0165 
0166       // copy from another rule
0167       void Copy( const Rule & other );
0168 
0169       // get the name of variable with index i
0170       const TString & GetVarName( Int_t i) const;
0171 
0172       RuleCut*             fCut;           ///< all cuts associated with the rule
0173       Double_t             fNorm;          ///< normalization - usually 1.0/t(k)
0174       Double_t             fSupport;       ///< s(k)
0175       Double_t             fSigma;         ///< t(k) = sqrt(s*(1-s))
0176       Double_t             fCoefficient;   ///< rule coeff. a(k)
0177       Double_t             fImportance;    ///< importance of rule
0178       Double_t             fImportanceRef; ///< importance ref
0179       const RuleEnsemble*  fRuleEnsemble;  ///< pointer to parent RuleEnsemble
0180       Double_t             fSSB;           ///< S/(S+B) for rule
0181       Double_t             fSSBNeve;       ///< N(events) reaching the last node in reevaluation
0182 
0183       mutable MsgLogger*   fLogger;        ///<! message logger
0184       MsgLogger& Log() const { return *fLogger; }
0185 
0186    };
0187 
0188 } // end of TMVA namespace
0189 
0190 //_______________________________________________________________________
0191 inline Bool_t TMVA::Rule::EvalEvent( const TMVA::Event& e ) const
0192 {
0193    // Checks if event is accepted by rule.
0194    // Return true if yes and false if not.
0195    //
0196    return fCut->EvalEvent(e);
0197 }
0198 
0199 #endif