Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:10

0001 // @(#)root/tmva $Id$
0002 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : RuleFitAPI                                                            *
0008  *                                             *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Interface to Friedman's RuleFit method                                    *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Andreas Hoecker    <Andreas.Hocker@cern.ch>     - CERN, Switzerland       *
0015  *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA      *
0016  *      Helge Voss         <Helge.Voss@cern.ch>         - MPI-KP Heidelberg, Ger. *
0017  *      Kai Voss           <Kai.Voss@cern.ch>           - U. of Victoria, Canada  *
0018  *                                                                                *
0019  * Copyright (c) 2005:                                                            *
0020  *      CERN, Switzerland                                                         *
0021  *      U. of Victoria, Canada                                                    *
0022  *      MPI-KP Heidelberg, Germany                                                *
0023  *      LAPP, Annecy, France                                                      *
0024  *                                                                                *
0025  * Redistribution and use in source and binary forms, with or without             *
0026  * modification, are permitted according to the terms listed in LICENSE           *
0027  *                                                                                *
0028  **********************************************************************************/
0029 
0030 #ifndef ROOT_TMVA_RuleFitAPI
0031 #define ROOT_TMVA_RuleFitAPI
0032 
0033 //////////////////////////////////////////////////////////////////////////
0034 //                                                                      //
0035 // RuleFitAPI                                                           //
0036 //                                                                      //
0037 // J Friedman's RuleFit method                                          //
0038 //                                                                      //
0039 //////////////////////////////////////////////////////////////////////////
0040 
0041 #include <fstream>
0042 #include <vector>
0043 
0044 #include "TMVA/MsgLogger.h"
0045 
0046 namespace TMVA {
0047 
0048    class MethodRuleFit;
0049    class RuleFit;
0050 
0051    class RuleFitAPI {
0052 
0053    public:
0054 
0055       RuleFitAPI( const TMVA::MethodRuleFit *rfbase, TMVA::RuleFit *rulefit, EMsgType minType );
0056 
0057       virtual ~RuleFitAPI();
0058 
0059       // welcome message
0060       void WelcomeMessage();
0061 
0062       // message on howto get the binary
0063       void HowtoSetupRF();
0064 
0065       // Set RuleFit working directory
0066       void SetRFWorkDir(const char * wdir);
0067 
0068       // Check RF work dir - aborts if it fails
0069       void CheckRFWorkDir();
0070 
0071       // run rf_go.exe in various modes
0072       inline void TrainRuleFit();
0073       inline void TestRuleFit();
0074       inline void VarImp();
0075 
0076       // read result into MethodRuleFit
0077       Bool_t ReadModelSum();
0078 
0079       // Get working directory
0080       const TString GetRFWorkDir() const { return fRFWorkDir; }
0081 
0082    protected:
0083 
0084       enum ERFMode    { kRfRegress=1, kRfClass=2 };          // RuleFit modes, default=Class
0085       enum EModel     { kRfLinear=0, kRfRules=1, kRfBoth=2 }; // models, default=Both (rules+linear)
0086       enum ERFProgram { kRfTrain=0, kRfPredict, kRfVarimp };    // rf_go.exe running mode
0087 
0088       // integer parameters
0089       typedef struct {
0090          Int_t mode;
0091          Int_t lmode;
0092          Int_t n;
0093          Int_t p;
0094          Int_t max_rules;
0095          Int_t tree_size;
0096          Int_t path_speed;
0097          Int_t path_xval;
0098          Int_t path_steps;
0099          Int_t path_testfreq;
0100          Int_t tree_store;
0101          Int_t cat_store;
0102       } IntParms;
0103 
0104       // float parameters
0105       typedef struct {
0106          Float_t  xmiss;
0107          Float_t  trim_qntl;
0108          Float_t  huber;
0109          Float_t  inter_supp;
0110          Float_t  memory_par;
0111          Float_t  samp_fract;
0112          Float_t  path_inc;
0113          Float_t  conv_fac;
0114       } RealParms;
0115 
0116       // setup
0117       void InitRuleFit();
0118       void FillRealParmsDef();
0119       void FillIntParmsDef();
0120       void ImportSetup();
0121       void SetTrainParms();
0122       void SetTestParms();
0123 
0124       // run
0125       Int_t  RunRuleFit();
0126 
0127       // set rf_go.exe running mode
0128       void SetRFTrain()   { fRFProgram = kRfTrain; }
0129       void SetRFPredict() { fRFProgram = kRfPredict; }
0130       void SetRFVarimp()  { fRFProgram = kRfVarimp; }
0131 
0132       // handle rulefit files
0133       inline TString GetRFName(TString name);
0134       inline Bool_t  OpenRFile(TString name, std::ofstream & f);
0135       inline Bool_t  OpenRFile(TString name, std::ifstream & f);
0136 
0137       // read/write binary files
0138       inline Bool_t WriteInt(std::ofstream &   f, const Int_t   *v, Int_t n=1);
0139       inline Bool_t WriteFloat(std::ofstream & f, const Float_t *v, Int_t n=1);
0140       inline Int_t  ReadInt(std::ifstream & f,   Int_t *v, Int_t n=1) const;
0141       inline Int_t  ReadFloat(std::ifstream & f, Float_t *v, Int_t n=1) const;
0142 
0143       // write rf_go.exe i/o files
0144       Bool_t WriteAll();
0145       Bool_t WriteIntParms();
0146       Bool_t WriteRealParms();
0147       Bool_t WriteLx();
0148       Bool_t WriteProgram();
0149       Bool_t WriteRealVarImp();
0150       Bool_t WriteRfOut();
0151       Bool_t WriteRfStatus();
0152       Bool_t WriteRuleFitMod();
0153       Bool_t WriteRuleFitSum();
0154       Bool_t WriteTrain();
0155       Bool_t WriteVarNames();
0156       Bool_t WriteVarImp();
0157       Bool_t WriteYhat();
0158       Bool_t WriteTest();
0159 
0160       // read rf_go.exe i/o files
0161       Bool_t ReadYhat();
0162       Bool_t ReadIntParms();
0163       Bool_t ReadRealParms();
0164       Bool_t ReadLx();
0165       Bool_t ReadProgram();
0166       Bool_t ReadRealVarImp();
0167       Bool_t ReadRfOut();
0168       Bool_t ReadRfStatus();
0169       Bool_t ReadRuleFitMod();
0170       Bool_t ReadRuleFitSum();
0171       Bool_t ReadTrainX();
0172       Bool_t ReadTrainY();
0173       Bool_t ReadTrainW();
0174       Bool_t ReadVarNames();
0175       Bool_t ReadVarImp();
0176 
0177    private:
0178       // prevent empty constructor from being used
0179       RuleFitAPI();
0180       const MethodRuleFit *fMethodRuleFit; ///< parent method - set in constructor
0181       RuleFit             *fRuleFit;       ///< non const ptr to RuleFit class in MethodRuleFit
0182       //
0183       std::vector<Float_t> fRFYhat;      ///< score results from test sample
0184       std::vector<Float_t> fRFVarImp;    ///< variable importances
0185       std::vector<Int_t>   fRFVarImpInd; ///< variable index
0186       TString              fRFWorkDir;   ///< working directory
0187       IntParms             fRFIntParms;  ///< integer parameters
0188       RealParms            fRFRealParms; ///< real parameters
0189       std::vector<int>     fRFLx;        ///< variable selector
0190       ERFProgram           fRFProgram;   ///< what to run
0191       TString              fModelType;   ///< model type string
0192 
0193       mutable MsgLogger    fLogger;      ///<! message logger
0194 
0195       ClassDef(RuleFitAPI,0);        // Friedman's RuleFit method
0196 
0197    };
0198 
0199 } // namespace TMVA
0200 
0201 //_______________________________________________________________________
0202 void TMVA::RuleFitAPI::TrainRuleFit()
0203 {
0204    // run rf_go.exe to train the model
0205    SetTrainParms();
0206    WriteAll();
0207    RunRuleFit();
0208 }
0209 
0210 //_______________________________________________________________________
0211 void TMVA::RuleFitAPI::TestRuleFit()
0212 {
0213    // run rf_go.exe with the test data
0214    SetTestParms();
0215    WriteAll();
0216    RunRuleFit();
0217    ReadYhat(); // read in the scores
0218 }
0219 
0220 //_______________________________________________________________________
0221 void TMVA::RuleFitAPI::VarImp()
0222 {
0223    // run rf_go.exe to get the variable importance
0224    SetRFVarimp();
0225    WriteAll();
0226    RunRuleFit();
0227    ReadVarImp(); // read in the variable importances
0228 }
0229 
0230 //_______________________________________________________________________
0231 TString TMVA::RuleFitAPI::GetRFName(TString name)
0232 {
0233    // get the name including the rulefit directory
0234    return fRFWorkDir+"/"+name;
0235 }
0236 
0237 //_______________________________________________________________________
0238 Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ofstream & f)
0239 {
0240    // open a file for writing in the rulefit directory
0241    TString fullName = GetRFName(name);
0242    f.open(fullName);
0243    if (!f.is_open()) {
0244       fLogger << kERROR << "Error opening RuleFit file for output: "
0245               << fullName << Endl;
0246       return kFALSE;
0247    }
0248    return kTRUE;
0249 }
0250 
0251 //_______________________________________________________________________
0252 Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ifstream & f)
0253 {
0254    // open a file for reading in the rulefit directory
0255    TString fullName = GetRFName(name);
0256    f.open(fullName);
0257    if (!f.is_open()) {
0258       fLogger << kERROR << "Error opening RuleFit file for input: "
0259               << fullName << Endl;
0260       return kFALSE;
0261    }
0262    return kTRUE;
0263 }
0264 
0265 //_______________________________________________________________________
0266 Bool_t TMVA::RuleFitAPI::WriteInt(std::ofstream &   f, const Int_t   *v, Int_t n)
0267 {
0268    // write an int
0269    if (!f.is_open()) return kFALSE;
0270    return (Bool_t)f.write(reinterpret_cast<char const *>(v), n*sizeof(Int_t));
0271 }
0272 
0273 //_______________________________________________________________________
0274 Bool_t TMVA::RuleFitAPI::WriteFloat(std::ofstream & f, const Float_t *v, Int_t n)
0275 {
0276    // write a float
0277    if (!f.is_open()) return kFALSE;
0278    return (Bool_t)f.write(reinterpret_cast<char const *>(v), n*sizeof(Float_t));
0279 }
0280 
0281 //_______________________________________________________________________
0282 Int_t TMVA::RuleFitAPI::ReadInt(std::ifstream & f,   Int_t *v, Int_t n) const
0283 {
0284    // read an int
0285    if (!f.is_open()) return 0;
0286    if (f.read(reinterpret_cast<char *>(v), n*sizeof(Int_t))) return 1;
0287    return 0;
0288 }
0289 
0290 //_______________________________________________________________________
0291 Int_t TMVA::RuleFitAPI::ReadFloat(std::ifstream & f, Float_t *v, Int_t n) const
0292 {
0293    // read a float
0294    if (!f.is_open()) return 0;
0295    if (f.read(reinterpret_cast<char *>(v), n*sizeof(Float_t))) return 1;
0296    return 0;
0297 }
0298 
0299 #endif // RuleFitAPI_H