Warning: file /include/root/TMVA/RuleFitAPI.h was not indexed,
or was modified since it was last indexed (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030 #ifndef ROOT_TMVA_RuleFitAPI
0031 #define ROOT_TMVA_RuleFitAPI
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041 #include <fstream>
0042 #include <vector>
0043
0044 #include "TMVA/MsgLogger.h"
0045
0046 namespace TMVA {
0047
0048 class MethodRuleFit;
0049 class RuleFit;
0050
0051 class RuleFitAPI {
0052
0053 public:
0054
0055 RuleFitAPI( const TMVA::MethodRuleFit *rfbase, TMVA::RuleFit *rulefit, EMsgType minType );
0056
0057 virtual ~RuleFitAPI();
0058
0059
0060 void WelcomeMessage();
0061
0062
0063 void HowtoSetupRF();
0064
0065
0066 void SetRFWorkDir(const char * wdir);
0067
0068
0069 void CheckRFWorkDir();
0070
0071
0072 inline void TrainRuleFit();
0073 inline void TestRuleFit();
0074 inline void VarImp();
0075
0076
0077 Bool_t ReadModelSum();
0078
0079
0080 const TString GetRFWorkDir() const { return fRFWorkDir; }
0081
0082 protected:
0083
0084 enum ERFMode { kRfRegress=1, kRfClass=2 };
0085 enum EModel { kRfLinear=0, kRfRules=1, kRfBoth=2 };
0086 enum ERFProgram { kRfTrain=0, kRfPredict, kRfVarimp };
0087
0088
0089 typedef struct {
0090 Int_t mode;
0091 Int_t lmode;
0092 Int_t n;
0093 Int_t p;
0094 Int_t max_rules;
0095 Int_t tree_size;
0096 Int_t path_speed;
0097 Int_t path_xval;
0098 Int_t path_steps;
0099 Int_t path_testfreq;
0100 Int_t tree_store;
0101 Int_t cat_store;
0102 } IntParms;
0103
0104
0105 typedef struct {
0106 Float_t xmiss;
0107 Float_t trim_qntl;
0108 Float_t huber;
0109 Float_t inter_supp;
0110 Float_t memory_par;
0111 Float_t samp_fract;
0112 Float_t path_inc;
0113 Float_t conv_fac;
0114 } RealParms;
0115
0116
0117 void InitRuleFit();
0118 void FillRealParmsDef();
0119 void FillIntParmsDef();
0120 void ImportSetup();
0121 void SetTrainParms();
0122 void SetTestParms();
0123
0124
0125 Int_t RunRuleFit();
0126
0127
0128 void SetRFTrain() { fRFProgram = kRfTrain; }
0129 void SetRFPredict() { fRFProgram = kRfPredict; }
0130 void SetRFVarimp() { fRFProgram = kRfVarimp; }
0131
0132
0133 inline TString GetRFName(TString name);
0134 inline Bool_t OpenRFile(TString name, std::ofstream & f);
0135 inline Bool_t OpenRFile(TString name, std::ifstream & f);
0136
0137
0138 inline Bool_t WriteInt(std::ofstream & f, const Int_t *v, Int_t n=1);
0139 inline Bool_t WriteFloat(std::ofstream & f, const Float_t *v, Int_t n=1);
0140 inline Int_t ReadInt(std::ifstream & f, Int_t *v, Int_t n=1) const;
0141 inline Int_t ReadFloat(std::ifstream & f, Float_t *v, Int_t n=1) const;
0142
0143
0144 Bool_t WriteAll();
0145 Bool_t WriteIntParms();
0146 Bool_t WriteRealParms();
0147 Bool_t WriteLx();
0148 Bool_t WriteProgram();
0149 Bool_t WriteRealVarImp();
0150 Bool_t WriteRfOut();
0151 Bool_t WriteRfStatus();
0152 Bool_t WriteRuleFitMod();
0153 Bool_t WriteRuleFitSum();
0154 Bool_t WriteTrain();
0155 Bool_t WriteVarNames();
0156 Bool_t WriteVarImp();
0157 Bool_t WriteYhat();
0158 Bool_t WriteTest();
0159
0160
0161 Bool_t ReadYhat();
0162 Bool_t ReadIntParms();
0163 Bool_t ReadRealParms();
0164 Bool_t ReadLx();
0165 Bool_t ReadProgram();
0166 Bool_t ReadRealVarImp();
0167 Bool_t ReadRfOut();
0168 Bool_t ReadRfStatus();
0169 Bool_t ReadRuleFitMod();
0170 Bool_t ReadRuleFitSum();
0171 Bool_t ReadTrainX();
0172 Bool_t ReadTrainY();
0173 Bool_t ReadTrainW();
0174 Bool_t ReadVarNames();
0175 Bool_t ReadVarImp();
0176
0177 private:
0178
0179 RuleFitAPI();
0180 const MethodRuleFit *fMethodRuleFit;
0181 RuleFit *fRuleFit;
0182
0183 std::vector<Float_t> fRFYhat;
0184 std::vector<Float_t> fRFVarImp;
0185 std::vector<Int_t> fRFVarImpInd;
0186 TString fRFWorkDir;
0187 IntParms fRFIntParms;
0188 RealParms fRFRealParms;
0189 std::vector<int> fRFLx;
0190 ERFProgram fRFProgram;
0191 TString fModelType;
0192
0193 mutable MsgLogger fLogger;
0194
0195 ClassDef(RuleFitAPI,0);
0196
0197 };
0198
0199 }
0200
0201
0202 void TMVA::RuleFitAPI::TrainRuleFit()
0203 {
0204
0205 SetTrainParms();
0206 WriteAll();
0207 RunRuleFit();
0208 }
0209
0210
0211 void TMVA::RuleFitAPI::TestRuleFit()
0212 {
0213
0214 SetTestParms();
0215 WriteAll();
0216 RunRuleFit();
0217 ReadYhat();
0218 }
0219
0220
0221 void TMVA::RuleFitAPI::VarImp()
0222 {
0223
0224 SetRFVarimp();
0225 WriteAll();
0226 RunRuleFit();
0227 ReadVarImp();
0228 }
0229
0230
0231 TString TMVA::RuleFitAPI::GetRFName(TString name)
0232 {
0233
0234 return fRFWorkDir+"/"+name;
0235 }
0236
0237
0238 Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ofstream & f)
0239 {
0240
0241 TString fullName = GetRFName(name);
0242 f.open(fullName);
0243 if (!f.is_open()) {
0244 fLogger << kERROR << "Error opening RuleFit file for output: "
0245 << fullName << Endl;
0246 return kFALSE;
0247 }
0248 return kTRUE;
0249 }
0250
0251
0252 Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ifstream & f)
0253 {
0254
0255 TString fullName = GetRFName(name);
0256 f.open(fullName);
0257 if (!f.is_open()) {
0258 fLogger << kERROR << "Error opening RuleFit file for input: "
0259 << fullName << Endl;
0260 return kFALSE;
0261 }
0262 return kTRUE;
0263 }
0264
0265
0266 Bool_t TMVA::RuleFitAPI::WriteInt(std::ofstream & f, const Int_t *v, Int_t n)
0267 {
0268
0269 if (!f.is_open()) return kFALSE;
0270 return (Bool_t)f.write(reinterpret_cast<char const *>(v), n*sizeof(Int_t));
0271 }
0272
0273
0274 Bool_t TMVA::RuleFitAPI::WriteFloat(std::ofstream & f, const Float_t *v, Int_t n)
0275 {
0276
0277 if (!f.is_open()) return kFALSE;
0278 return (Bool_t)f.write(reinterpret_cast<char const *>(v), n*sizeof(Float_t));
0279 }
0280
0281
0282 Int_t TMVA::RuleFitAPI::ReadInt(std::ifstream & f, Int_t *v, Int_t n) const
0283 {
0284
0285 if (!f.is_open()) return 0;
0286 if (f.read(reinterpret_cast<char *>(v), n*sizeof(Int_t))) return 1;
0287 return 0;
0288 }
0289
0290
0291 Int_t TMVA::RuleFitAPI::ReadFloat(std::ifstream & f, Float_t *v, Int_t n) const
0292 {
0293
0294 if (!f.is_open()) return 0;
0295 if (f.read(reinterpret_cast<char *>(v), n*sizeof(Float_t))) return 1;
0296 return 0;
0297 }
0298
0299 #endif