Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:01

0001 // @(#)root/tmva $Id$
0002 // Author: Fredrik Tegenfeldt
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : MethodRuleFit                                                         *
0008  *                                                                                *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Friedman's RuleFit method                                                 *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA      *
0015  *                                                                                *
0016  * Copyright (c) 2005:                                                            *
0017  *      CERN, Switzerland                                                         *
0018  *      Iowa State U.                                                             *
0019  *      MPI-K Heidelberg, Germany                                                 *
0020  *                                                                                *
0021  * Redistribution and use in source and binary forms, with or without             *
0022  * modification, are permitted according to the terms listed in LICENSE           *
0023  *                                                                                *
0024  **********************************************************************************/
0025 
0026 #ifndef ROOT_TMVA_MethodRuleFit
0027 #define ROOT_TMVA_MethodRuleFit
0028 
0029 //////////////////////////////////////////////////////////////////////////
0030 //                                                                      //
0031 // MethodRuleFit                                                        //
0032 //                                                                      //
0033 // J Friedman's RuleFit method                                          //
0034 //                                                                      //
0035 //////////////////////////////////////////////////////////////////////////
0036 
0037 #include "TMVA/MethodBase.h"
0038 #include "TMatrixDfwd.h"
0039 #include "TVectorD.h"
0040 #include "TMVA/DecisionTree.h"
0041 #include "TMVA/RuleFit.h"
0042 #include <vector>
0043 
0044 namespace TMVA {
0045 
0046    class SeparationBase;
0047 
0048    class MethodRuleFit : public MethodBase {
0049       // MethodBase-derived TMVA method for Friedman's RuleFit; holds a RuleFit instance, the forest and all option values.
0050    public:
0051       // constructor taking job name, method title, data set info and an (optional) option string
0052       MethodRuleFit( const TString& jobName,
0053                      const TString& methodTitle,
0054                      DataSetInfo& theData,
0055                      const TString& theOption = "");
0056       // constructor taking the data set info and a weight file
0057       MethodRuleFit( DataSetInfo& theData,
0058                      const TString& theWeightFile);
0059 
0060       virtual ~MethodRuleFit( void );
0061 
0062       virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ );
0063 
0064       // training method
0065       void Train( void );
0066       // keep the MethodBase overloads visible next to the std::istream overload declared below
0067       using MethodBase::ReadWeightsFromStream;
0068 
0069       // write weights to file (XML, below the given parent node)
0070       void AddWeightsXMLTo     ( void* parent ) const;
0071 
0072       // read weights from file (plain stream or XML node)
0073       void ReadWeightsFromStream( std::istream& istr );
0074       void ReadWeightsFromXML   ( void* wghtnode );
0075 
0076       // calculate the MVA value; both error pointers are optional and default to nullptr
0077       Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );
0078 
0079       // write method specific histos to target file
0080       void WriteMonitoringHistosToFile( void ) const;
0081 
0082       // ranking of input variables
0083       const Ranking* CreateRanking();
0084       // value of the fUseBoost flag
0085       Bool_t                                   UseBoost()           const   { return fUseBoost; }
0086 
0087       // accessors (each one simply returns the corresponding data member)
0088       RuleFit*                                 GetRuleFitPtr()              { return &fRuleFit; }
0089       const RuleFit*                           GetRuleFitConstPtr() const   { return &fRuleFit; }
0090       TDirectory*                              GetMethodBaseDir()   const   { return BaseDir(); }
0091       const std::vector<TMVA::Event*>&         GetTrainingEvents()  const   { return fEventSample; }
0092       const std::vector<TMVA::DecisionTree*>&  GetForest()          const   { return fForest; }
0093       Int_t                                    GetNTrees()          const   { return fNTrees; }
0094       Double_t                                 GetTreeEveFrac()     const   { return fTreeEveFrac; }
0095       const SeparationBase*                    GetSeparationBaseConst() const { return fSepType; }
0096       SeparationBase*                          GetSeparationBase()  const   { return fSepType; }
0097       TMVA::DecisionTree::EPruneMethod         GetPruneMethod()     const   { return fPruneMethod; }
0098       Double_t                                 GetPruneStrength()   const   { return fPruneStrength; }
0099       Double_t                                 GetMinFracNEve()     const   { return fMinFracNEve; }
0100       Double_t                                 GetMaxFracNEve()     const   { return fMaxFracNEve; }
0101       Int_t                                    GetNCuts()           const   { return fNCuts; }
0102       // GD path parameters
0103       Int_t                                    GetGDNPathSteps()    const   { return fGDNPathSteps; }
0104       Double_t                                 GetGDPathStep()      const   { return fGDPathStep; }
0105       Double_t                                 GetGDErrScale()      const   { return fGDErrScale; }
0106       Double_t                                 GetGDPathEveFrac()   const   { return fGDPathEveFrac; }
0107       Double_t                                 GetGDValidEveFrac()  const   { return fGDValidEveFrac; }
0108       // quantile cut (see fLinQuantile)
0109       Double_t                                 GetLinQuantile()     const   { return fLinQuantile; }
0110       // settings of Friedmans module; GetRFWorkDir returns a copy of the string
0111       const TString                            GetRFWorkDir()       const   { return fRFWorkDir; }
0112       Int_t                                    GetRFNrules()        const   { return fRFNrules; }
0113       Int_t                                    GetRFNendnodes()     const   { return fRFNendnodes; }
0114 
0115    protected:
0116 
0117       // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
0118       void MakeClassSpecific( std::ostream&, const TString& ) const;
0119 
0120       void MakeClassRuleCuts( std::ostream& ) const;
0121 
0122       void MakeClassLinear( std::ostream& ) const;
0123 
0124       // get help message text
0125       void GetHelpMessage() const;
0126 
0127       // initialize rulefit
0128       void Init( void );
0129 
0130       // copy all training events into a stl::vector
0131       void InitEventSample( void );
0132 
0133       // initialize monitor ntuple
0134       void InitMonitorNtuple();
0135       // training entry points for the two rulefit modules (TMVA implementation / J.Friedmans module)
0136       void TrainTMVARuleFit();
0137       void TrainJFRuleFit();
0138 
0139    private:
0140 
0141       // check variable range, warn through mlog and set var to lower or upper limit if out of range; returns kTRUE if var was changed
0142       template<typename T>
0143          inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax );
0144       // same check, but an out-of-range var is set to the default vdef instead of the nearest limit
0145       template<typename T>
0146          inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef );
0147       // pure check without side effects: returns +1 if var>vmax, -1 if var<vmin, 0 otherwise
0148       template<typename T>
0149          inline Int_t VerifyRange( const T& var, const T& vmin, const T& vmax );
0150 
0151       // the option handling methods
0152       void DeclareOptions();
0153       void ProcessOptions();
0154 
0155       RuleFit                      fRuleFit;        ///< RuleFit instance
0156       std::vector<TMVA::Event *>   fEventSample;    ///< the complete training sample
0157       Double_t                     fSignalFraction; ///< scalefactor for bkg events to modify initial s/b fraction in training data
0158 
0159       // ntuple
0160       TTree                       *fMonitorNtuple;  ///< pointer to monitor rule ntuple
0161       Double_t                     fNTImportance;   ///< ntuple: rule importance
0162       Double_t                     fNTCoefficient;  ///< ntuple: rule coefficient
0163       Double_t                     fNTSupport;      ///< ntuple: rule support
0164       Int_t                        fNTNcuts;        ///< ntuple: rule number of cuts
0165       Int_t                        fNTNvars;        ///< ntuple: rule number of vars
0166       Double_t                     fNTPtag;         ///< ntuple: rule P(tag)
0167       Double_t                     fNTPss;          ///< ntuple: rule P(tag s, true s)
0168       Double_t                     fNTPsb;          ///< ntuple: rule P(tag s, true b)
0169       Double_t                     fNTPbs;          ///< ntuple: rule P(tag b, true s)
0170       Double_t                     fNTPbb;          ///< ntuple: rule P(tag b, true b)
0171       Double_t                     fNTSSB;          ///< ntuple: rule S/(S+B)
0172       Int_t                        fNTType;         ///< ntuple: rule type (+1->signal, -1->bkg)
0173 
0174       // options
0175       TString                      fRuleFitModuleS;///< which rulefit module to use
0176       Bool_t                       fUseRuleFitJF;  ///< if true interface with J.Friedmans RuleFit module
0177       TString                      fRFWorkDir;     ///< working directory from Friedmans module
0178       Int_t                        fRFNrules;      ///< max number of rules (only Friedmans module)
0179       Int_t                        fRFNendnodes;   ///< number of end nodes, going by the name (only Friedmans module) - TODO confirm
0180       std::vector<DecisionTree *>  fForest;        ///< the forest
0181       Int_t                        fNTrees;        ///< number of trees in forest
0182       Double_t                     fTreeEveFrac;   ///< fraction of events used for training each tree
0183       SeparationBase              *fSepType;       ///< the separation used in node splitting
0184       Double_t                     fMinFracNEve;   ///< min fraction of number events
0185       Double_t                     fMaxFracNEve;   ///< max fraction of number events
0186       Int_t                        fNCuts;         ///< grid used in cut applied in node splitting
0187       TString                      fSepTypeS;        ///< forest generation: separation type - see DecisionTree
0188       TString                      fPruneMethodS;    ///< forest generation: prune method - see DecisionTree
0189       TMVA::DecisionTree::EPruneMethod fPruneMethod; ///< forest generation: method used for pruning - see DecisionTree
0190       Double_t                     fPruneStrength;   ///< forest generation: prune strength - see DecisionTree
0191       TString                      fForestTypeS;     ///< forest generation: how the trees are generated
0192       Bool_t                       fUseBoost;        ///< use boosted events for forest generation
0193       //
0194       Double_t                     fGDPathEveFrac;  ///< GD path: fraction of subsamples used for the fitting
0195       Double_t                     fGDValidEveFrac; ///< GD path: fraction of subsamples presumably used for validation - TODO confirm
0196       Double_t                     fGDTau;          ///< GD path: def threshold fraction [0..1]
0197       Double_t                     fGDTauPrec;      ///< GD path: precision of estimated tau
0198       Double_t                     fGDTauMin;       ///< GD path: min threshold fraction [0..1]
0199       Double_t                     fGDTauMax;       ///< GD path: max threshold fraction [0..1]
0200       UInt_t                       fGDTauScan;      ///< GD path: number of points to scan
0201       Double_t                     fGDPathStep;     ///< GD path: step size in path
0202       Int_t                        fGDNPathSteps;   ///< GD path: number of steps
0203       Double_t                     fGDErrScale;     ///< GD path: stop (error scale of the stopping criterion, judging by the name - TODO confirm)
0204       Double_t                     fMinimp;         ///< rule/linear: minimum importance
0205       //
0206       TString                      fModelTypeS;     ///< rule ensemble: which model (rule,linear or both)
0207       Double_t                     fRuleMinDist;    ///< rule min distance - see RuleEnsemble
0208       Double_t                     fLinQuantile;    ///< quantile cut to remove outliers - see RuleEnsemble
0209 
0210       ClassDef(MethodRuleFit,0);  // Friedman's RuleFit method
0211    };
0212 
0213 } // namespace TMVA
0214 
0215 
0216 //_______________________________________________________________________
0217 template<typename T>
0218 inline Int_t TMVA::MethodRuleFit::VerifyRange( const T& var, const T& vmin, const T& vmax )
0219 {
0220    // Classify var with respect to the interval [vmin,vmax] (limits count as inside):
0221    //   +1 -> var above vmax,  -1 -> var below vmin,  0 -> var inside.
0222    // The upper limit is tested first and wins if both comparisons hold.
0223    return (var>vmax) ? 1 : ( (var<vmin) ? -1 : 0 );
0224 }
0225 
0226 //_______________________________________________________________________
0227 template<typename T>
0228 inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax )
0229 {
0230    // Clamp an option value into [vmin,vmax] and report the change.
0231    //   mlog   : logger receiving the kWARNING message
0232    //   varstr : option name quoted in the message
0233    //   var    : value under test, overwritten with the violated limit
0234    // Returns kTRUE if var had to be changed, kFALSE otherwise.
0235    const Int_t side = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
0236    const Bool_t tooHigh = (side==1);
0237    const Bool_t tooLow  = (side==-1);
0238    if (!tooHigh && !tooLow) return kFALSE;
0239    // pull the value back onto the limit it crossed
0240    if (tooHigh) var=vmax;
0241    else         var=vmin;
0242    // the message shows the value after the reset
0243    mlog << kWARNING << "Option <" << varstr << "> " << (tooHigh ? "above":"below")
0244         << " allowed range. Reset to new value = " << var << Endl;
0245    return kTRUE;
0246 }
0247 
0248 //_______________________________________________________________________
0249 template<typename T>
0250 inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef )
0251 {
0252    // Check var against [vmin,vmax]; when it lies outside, replace it by vdef
0253    // and emit a kWARNING naming the option (varstr) and the violated side.
0254    // Returns kTRUE if var was replaced, kFALSE if it was left untouched.
0255    const Int_t side = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
0256    if (side==0) return kFALSE;
0257    // out of range: fall back to the caller supplied default
0258    var=vdef;
0259    // any non-zero side other than +1 is worded as "below"
0260    mlog << kWARNING << "Option <" << varstr << "> "
0261         << (side==1 ? "above":"below")
0262         << " allowed range. Reset to default value = " << var << Endl;
0263    return kTRUE;
0264 }
0265 
0266 
0267 #endif // ROOT_TMVA_MethodRuleFit