File indexing completed on 2025-01-18 10:11:10
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030 #ifndef ROOT_TMVA_RuleFitAPI
0031 #define ROOT_TMVA_RuleFitAPI
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041 #include <fstream>
0042 #include <vector>
0043
0044 #include "TMVA/MsgLogger.h"
0045
0046 namespace TMVA {
0047 // Forward declarations: this header only refers to these types through pointers.
0048 class MethodRuleFit;
0049 class RuleFit;
0050
// RuleFitAPI: file-based bridge between TMVA's RuleFit method and a RuleFit
// program. Inputs (parameters, data, variable names, ...) are written by the
// Write* methods into files under the work directory fRFWorkDir (file paths are
// built by GetRFName as "<work dir>/<name>"), RunRuleFit() is called, and
// outputs are loaded back with the Read* methods.
// NOTE(review): RunRuleFit() is defined elsewhere; presumably it runs an
// external RuleFit executable inside fRFWorkDir — confirm against the .cxx.
0051 class RuleFitAPI {
0052
0053 public:
0054
// rfbase / rulefit: the calling MethodRuleFit and its RuleFit object
// (presumably stored in fMethodRuleFit / fRuleFit; ownership cannot be
// determined from this header — TODO confirm). minType: presumably the
// minimum message level for fLogger — confirm.
0055 RuleFitAPI( const TMVA::MethodRuleFit *rfbase, TMVA::RuleFit *rulefit, EMsgType minType );
0056
0057 virtual ~RuleFitAPI();
0058
0059 // Presumably print welcome / setup instructions to the log (implemented elsewhere).
0060 void WelcomeMessage();
0061
0062
0063 void HowtoSetupRF();
0064
0065 // Set the directory used to exchange files with the RuleFit program (stored in fRFWorkDir).
0066 void SetRFWorkDir(const char * wdir);
0067
0068 // Presumably validates the work directory setup (implemented elsewhere).
0069 void CheckRFWorkDir();
0070
0071 // High-level passes, defined inline after the class:
//   TrainRuleFit: SetTrainParms(), WriteAll(), RunRuleFit()
//   TestRuleFit:  SetTestParms(),  WriteAll(), RunRuleFit(), ReadYhat()
//   VarImp:       SetRFVarimp(),   WriteAll(), RunRuleFit(), ReadVarImp()
// None of them inspects the return values of WriteAll() / RunRuleFit().
0072 inline void TrainRuleFit();
0073 inline void TestRuleFit();
0074 inline void VarImp();
0075
0076 // Presumably reads the model summary produced by the RuleFit program (implemented elsewhere).
0077 Bool_t ReadModelSum();
0078
0079 // Returns a copy of the current work directory path.
0080 const TString GetRFWorkDir() const { return fRFWorkDir; }
0081
0082 protected:
0083
// Presumably the values used for IntParms::mode (regression vs classification) — confirm.
0084 enum ERFMode { kRfRegress=1, kRfClass=2 };
// Presumably the values used for IntParms::lmode (linear terms, rules, or both) — confirm.
0085 enum EModel { kRfLinear=0, kRfRules=1, kRfBoth=2 };
// Program pass to perform; stored in fRFProgram by SetRFTrain/SetRFPredict/SetRFVarimp.
0086 enum ERFProgram { kRfTrain=0, kRfPredict, kRfVarimp };
0087
0088 // Integer parameter block (held in fRFIntParms).
// NOTE(review): if this struct is written to file as one raw Int_t block
// (e.g. via WriteInt in WriteIntParms), field order and sizes must match the
// consumer exactly — confirm before reordering or adding fields.
0089 typedef struct {
0090 Int_t mode;
0091 Int_t lmode;
0092 Int_t n;
0093 Int_t p;
0094 Int_t max_rules;
0095 Int_t tree_size;
0096 Int_t path_speed;
0097 Int_t path_xval;
0098 Int_t path_steps;
0099 Int_t path_testfreq;
0100 Int_t tree_store;
0101 Int_t cat_store;
0102 } IntParms;
0103
0104 // Real-valued parameter block (held in fRFRealParms); same layout caveat as IntParms.
0105 typedef struct {
0106 Float_t xmiss;
0107 Float_t trim_qntl;
0108 Float_t huber;
0109 Float_t inter_supp;
0110 Float_t memory_par;
0111 Float_t samp_fract;
0112 Float_t path_inc;
0113 Float_t conv_fac;
0114 } RealParms;
0115
0116 // Setup helpers (implemented elsewhere).
0117 void InitRuleFit();
0118 void FillRealParmsDef();
0119 void FillIntParmsDef();
0120 void ImportSetup();
0121 void SetTrainParms();
0122 void SetTestParms();
0123
0124 // Runs the RuleFit program (implemented elsewhere); the Int_t result is not checked by the inline passes.
0125 Int_t RunRuleFit();
0126
0127 // Select which program pass fRFProgram refers to.
0128 void SetRFTrain() { fRFProgram = kRfTrain; }
0129 void SetRFPredict() { fRFProgram = kRfPredict; }
0130 void SetRFVarimp() { fRFProgram = kRfVarimp; }
0131
0132 // File helpers: GetRFName returns fRFWorkDir + "/" + name; OpenRFile opens that path and logs kERROR on failure.
// NOTE(review): OpenRFile uses the default (text) open mode although
// WriteInt/WriteFloat/ReadInt/ReadFloat transfer raw bytes — confirm whether
// binary mode is needed on platforms that translate line endings.
0133 inline TString GetRFName(TString name);
0134 inline Bool_t OpenRFile(TString name, std::ofstream & f);
0135 inline Bool_t OpenRFile(TString name, std::ifstream & f);
0136
0137 // Raw-byte I/O of n values: Write* return kTRUE/kFALSE, Read* return 1 on success and 0 on failure.
0138 inline Bool_t WriteInt(std::ofstream & f, const Int_t *v, Int_t n=1);
0139 inline Bool_t WriteFloat(std::ofstream & f, const Float_t *v, Int_t n=1);
0140 inline Int_t ReadInt(std::ifstream & f, Int_t *v, Int_t n=1) const;
0141 inline Int_t ReadFloat(std::ifstream & f, Float_t *v, Int_t n=1) const;
0142
0143 // Writers for the individual exchange files (implemented elsewhere); WriteAll is called by every inline pass.
0144 Bool_t WriteAll();
0145 Bool_t WriteIntParms();
0146 Bool_t WriteRealParms();
0147 Bool_t WriteLx();
0148 Bool_t WriteProgram();
0149 Bool_t WriteRealVarImp();
0150 Bool_t WriteRfOut();
0151 Bool_t WriteRfStatus();
0152 Bool_t WriteRuleFitMod();
0153 Bool_t WriteRuleFitSum();
0154 Bool_t WriteTrain();
0155 Bool_t WriteVarNames();
0156 Bool_t WriteVarImp();
0157 Bool_t WriteYhat();
0158 Bool_t WriteTest();
0159
0160 // Readers for the exchange files (implemented elsewhere); ReadYhat/ReadVarImp are used by TestRuleFit/VarImp.
0161 Bool_t ReadYhat();
0162 Bool_t ReadIntParms();
0163 Bool_t ReadRealParms();
0164 Bool_t ReadLx();
0165 Bool_t ReadProgram();
0166 Bool_t ReadRealVarImp();
0167 Bool_t ReadRfOut();
0168 Bool_t ReadRfStatus();
0169 Bool_t ReadRuleFitMod();
0170 Bool_t ReadRuleFitSum();
0171 Bool_t ReadTrainX();
0172 Bool_t ReadTrainY();
0173 Bool_t ReadTrainW();
0174 Bool_t ReadVarNames();
0175 Bool_t ReadVarImp();
0176
0177 private:
0178
// Declared private and not defined in this header; presumably to forbid default construction.
0179 RuleFitAPI();
0180 const MethodRuleFit *fMethodRuleFit;
0181 RuleFit *fRuleFit;
0182
// Results read back from the RuleFit program (presumably filled by ReadYhat / ReadVarImp).
0183 std::vector<Float_t> fRFYhat;
0184 std::vector<Float_t> fRFVarImp;
0185 std::vector<Int_t> fRFVarImpInd;
// Directory in which all exchange files live (prefix used by GetRFName).
0186 TString fRFWorkDir;
0187 IntParms fRFIntParms;
0188 RealParms fRFRealParms;
0189 std::vector<int> fRFLx;
// Current program pass (see SetRFTrain/SetRFPredict/SetRFVarimp).
0190 ERFProgram fRFProgram;
0191 TString fModelType;
0192
// mutable so that it can be used from const member functions.
0193 mutable MsgLogger fLogger;
0194
// ROOT class-dictionary macro (class version 0).
0195 ClassDef(RuleFitAPI,0);
0196
0197 };
0198
0199 } // namespace TMVA
0200
0201
0202 void TMVA::RuleFitAPI::TrainRuleFit()
0203 {
0204
0205 SetTrainParms();
0206 WriteAll();
0207 RunRuleFit();
0208 }
0209
0210
0211 void TMVA::RuleFitAPI::TestRuleFit()
0212 {
0213
0214 SetTestParms();
0215 WriteAll();
0216 RunRuleFit();
0217 ReadYhat();
0218 }
0219
0220
0221 void TMVA::RuleFitAPI::VarImp()
0222 {
0223
0224 SetRFVarimp();
0225 WriteAll();
0226 RunRuleFit();
0227 ReadVarImp();
0228 }
0229
0230
0231 TString TMVA::RuleFitAPI::GetRFName(TString name)
0232 {
0233
0234 return fRFWorkDir+"/"+name;
0235 }
0236
0237
0238 Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ofstream & f)
0239 {
0240
0241 TString fullName = GetRFName(name);
0242 f.open(fullName);
0243 if (!f.is_open()) {
0244 fLogger << kERROR << "Error opening RuleFit file for output: "
0245 << fullName << Endl;
0246 return kFALSE;
0247 }
0248 return kTRUE;
0249 }
0250
0251
0252 Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ifstream & f)
0253 {
0254
0255 TString fullName = GetRFName(name);
0256 f.open(fullName);
0257 if (!f.is_open()) {
0258 fLogger << kERROR << "Error opening RuleFit file for input: "
0259 << fullName << Endl;
0260 return kFALSE;
0261 }
0262 return kTRUE;
0263 }
0264
0265
0266 Bool_t TMVA::RuleFitAPI::WriteInt(std::ofstream & f, const Int_t *v, Int_t n)
0267 {
0268
0269 if (!f.is_open()) return kFALSE;
0270 return (Bool_t)f.write(reinterpret_cast<char const *>(v), n*sizeof(Int_t));
0271 }
0272
0273
0274 Bool_t TMVA::RuleFitAPI::WriteFloat(std::ofstream & f, const Float_t *v, Int_t n)
0275 {
0276
0277 if (!f.is_open()) return kFALSE;
0278 return (Bool_t)f.write(reinterpret_cast<char const *>(v), n*sizeof(Float_t));
0279 }
0280
0281
0282 Int_t TMVA::RuleFitAPI::ReadInt(std::ifstream & f, Int_t *v, Int_t n) const
0283 {
0284
0285 if (!f.is_open()) return 0;
0286 if (f.read(reinterpret_cast<char *>(v), n*sizeof(Int_t))) return 1;
0287 return 0;
0288 }
0289
0290
0291 Int_t TMVA::RuleFitAPI::ReadFloat(std::ifstream & f, Float_t *v, Int_t n) const
0292 {
0293
0294 if (!f.is_open()) return 0;
0295 if (f.read(reinterpret_cast<char *>(v), n*sizeof(Float_t))) return 1;
0296 return 0;
0297 }
0298
0299 #endif