Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:23:03

0001 // @(#)root/tmva $Id$
0002 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : RuleFit                                                               *
0008  *                                             *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      A class implementing various fits of rule ensembles                       *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA      *
0015  *      Helge Voss         <Helge.Voss@cern.ch>         - MPI-KP Heidelberg, Ger. *
0016  *                                                                                *
0017  * Copyright (c) 2005:                                                            *
0018  *      CERN, Switzerland                                                         *
0019  *      Iowa State U.                                                             *
0020  *      MPI-K Heidelberg, Germany                                                 *
0021  *                                                                                *
0022  * Redistribution and use in source and binary forms, with or without             *
0023  * modification, are permitted according to the terms listed in LICENSE           *
0024  * (see tmva/doc/LICENSE)                                          *
0025  **********************************************************************************/
0026 
0027 #ifndef ROOT_TMVA_RuleFit
0028 #define ROOT_TMVA_RuleFit
0029 
0030 #include "TMVA/DecisionTree.h"
0031 #include "TMVA/RuleEnsemble.h"
0032 #include "TMVA/RuleFitParams.h"
0033 #include "TMVA/Event.h"
0034 
0035 #include <algorithm>
0036 #include <random>
0037 #include <vector>
0038 
0039 namespace TMVA {
0040 
0041 
0042    class MethodBase;
0043    class MethodRuleFit;
0044    class MsgLogger;
0045 
0046    class RuleFit {
0047 
0048    public:
0049 
0050       // main constructor
0051       RuleFit( const TMVA::MethodBase *rfbase );
0052 
0053       // empty constructor
0054       RuleFit( void );
0055 
0056       virtual ~RuleFit( void );
0057 
0058       void InitNEveEff();
0059       void InitPtrs( const TMVA::MethodBase *rfbase );
0060       void Initialize(  const TMVA::MethodBase *rfbase );
0061 
0062       void SetMsgType( EMsgType t );
0063 
0064       void SetTrainingEvents( const std::vector<const TMVA::Event *> & el );
0065 
0066       void ReshuffleEvents()
0067       {
0068          std::shuffle(fTrainingEventsRndm.begin(), fTrainingEventsRndm.end(), fRNGEngine);
0069       }
0070 
0071       void SetMethodBase( const MethodBase *rfbase );
0072 
0073       // make the forest of trees for rule generation
0074       void MakeForest();
0075 
0076       // build a tree
0077       void BuildTree( TMVA::DecisionTree *dt );
0078 
0079       // save event weights
0080       void SaveEventWeights();
0081 
0082       // restore saved event weights
0083       void RestoreEventWeights();
0084 
0085       // boost events based on the given tree
0086       void Boost( TMVA::DecisionTree *dt );
0087 
0088       // calculate and print some statistics on the given forest
0089       void ForestStatistics();
0090 
0091       // calculate the discriminating variable for the given event
0092       Double_t EvalEvent( const Event& e );
0093 
0094       // calculate sum of
0095       Double_t CalcWeightSum( const std::vector<const TMVA::Event *> *events, UInt_t neve=0 );
0096 
0097       // do the fitting of the coefficients
0098       void     FitCoefficients();
0099 
0100       // calculate variable and rule importance from a set of events
0101       void     CalcImportance();
0102 
0103       // set usage of linear term
0104       void     SetModelLinear()                      { fRuleEnsemble.SetModelLinear(); }
0105       // set usage of rules
0106       void     SetModelRules()                       { fRuleEnsemble.SetModelRules(); }
0107       // set usage of linear term
0108       void     SetModelFull()                        { fRuleEnsemble.SetModelFull(); }
0109       // set minimum importance allowed
0110       void     SetImportanceCut( Double_t minimp=0 ) { fRuleEnsemble.SetImportanceCut(minimp); }
0111       // set minimum rule distance - see RuleEnsemble
0112       void     SetRuleMinDist( Double_t d )          { fRuleEnsemble.SetRuleMinDist(d); }
0113       // set path related parameters
0114       void     SetGDTau( Double_t t=0.0 )       { fRuleFitParams.SetGDTau(t); }
0115       void     SetGDPathStep( Double_t s=0.01 ) { fRuleFitParams.SetGDPathStep(s); }
0116       void     SetGDNPathSteps( Int_t n=100 )   { fRuleFitParams.SetGDNPathSteps(n); }
0117       // make visualization histograms
0118       void     SetVisHistsUseImp( Bool_t f ) { fVisHistsUseImp = f; }
0119       void     UseImportanceVisHists()       { fVisHistsUseImp = kTRUE; }
0120       void     UseCoefficientsVisHists()     { fVisHistsUseImp = kFALSE; }
0121       void     MakeVisHists();
0122       void     FillVisHistCut(const Rule * rule, std::vector<TH2F *> & hlist);
0123       void     FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist);
0124       void     FillCut(TH2F* h2,const TMVA::Rule *rule,Int_t vind);
0125       void     FillLin(TH2F* h2,Int_t vind);
0126       void     FillCorr(TH2F* h2,const TMVA::Rule *rule,Int_t v1, Int_t v2);
0127       void     NormVisHists(std::vector<TH2F *> & hlist);
0128       void     MakeDebugHists();
0129       Bool_t   GetCorrVars(TString & title, TString & var1, TString & var2);
0130       // accessors
0131       UInt_t        GetNTreeSample()            const { return fNTreeSample; }
0132       Double_t      GetNEveEff()                const { return fNEveEffTrain; } // reweighted number of events = sum(wi)
0133       const Event*  GetTrainingEvent(UInt_t i)  const { return static_cast< const Event *>(fTrainingEvents[i]); }
0134       Double_t      GetTrainingEventWeight(UInt_t i)  const { return fTrainingEvents[i]->GetWeight(); }
0135 
0136       //      const Event*  GetTrainingEvent(UInt_t i, UInt_t isub)  const { return &(fTrainingEvents[fSubsampleEvents[isub]])[i]; }
0137 
0138       const std::vector< const TMVA::Event * > & GetTrainingEvents()  const { return fTrainingEvents; }
0139       //      const std::vector< Int_t >               & GetSubsampleEvents() const { return fSubsampleEvents; }
0140 
0141       //      void  GetSubsampleEvents(Int_t sub, UInt_t & ibeg, UInt_t & iend) const;
0142       void  GetRndmSampleEvents(std::vector< const TMVA::Event * > & evevec, UInt_t nevents);
0143       //
0144       const std::vector< const TMVA::DecisionTree *> & GetForest()     const { return fForest; }
0145       const RuleEnsemble                       & GetRuleEnsemble()     const { return fRuleEnsemble; }
0146       RuleEnsemble                       * GetRuleEnsemblePtr()        { return &fRuleEnsemble; }
0147       const RuleFitParams                      & GetRuleFitParams()    const { return fRuleFitParams; }
0148       RuleFitParams                      * GetRuleFitParamsPtr()       { return &fRuleFitParams; }
0149       const MethodRuleFit                      * GetMethodRuleFit()    const { return fMethodRuleFit; }
0150       const MethodBase                         * GetMethodBase()       const { return fMethodBase; }
0151 
0152    private:
0153 
0154       // copy constructor
0155       RuleFit( const RuleFit & other );
0156 
0157       // copy method
0158       void Copy( const RuleFit & other );
0159 
0160       std::vector<const TMVA::Event *>    fTrainingEvents;      ///< all training events
0161       std::vector<const TMVA::Event *>    fTrainingEventsRndm;  ///< idem, but randomly shuffled
0162       std::vector<Double_t>               fEventWeights;        ///< original weights of the events - follows fTrainingEvents
0163       UInt_t                              fNTreeSample;         ///< number of events in sub sample = frac*neve
0164 
0165       Double_t                            fNEveEffTrain;    ///< reweighted number of events = sum(wi)
0166       std::vector< const TMVA::DecisionTree *>  fForest;    ///< the input forest of decision trees
0167       RuleEnsemble                        fRuleEnsemble;    ///< the ensemble of rules
0168       RuleFitParams                       fRuleFitParams;   ///< fit rule parameters
0169       const MethodRuleFit                *fMethodRuleFit;   ///< pointer the method which initialized this RuleFit instance
0170       const MethodBase                   *fMethodBase;      ///< pointer the method base which initialized this RuleFit instance
0171       Bool_t                              fVisHistsUseImp;  ///< if true, use importance as weight; else coef in vis hists
0172 
0173       mutable MsgLogger*                  fLogger;   ///<! message logger
0174       MsgLogger& Log() const { return *fLogger; }
0175 
0176       static const Int_t randSEED = 0; // set to 1 for debugging purposes or to zero for random seeds
0177       std::default_random_engine fRNGEngine;
0178 
0179       ClassDef(RuleFit,0);  // Calculations for Friedman's RuleFit method
0180    };
0181 }
0182 
0183 #endif