File indexing completed on 2025-01-18 10:11:10
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029 #ifndef ROOT_TMVA_RuleEnsemble
0030 #define ROOT_TMVA_RuleEnsemble
0031
0032 #include "TMath.h"
0033
0034 #include "TMVA/DecisionTree.h"
0035 #include "TMVA/Event.h"
0036 #include "TMVA/Rule.h"
0037 #include "TMVA/Types.h"
0038
0039 #include <vector>
0040
0041 class TH1F;
0042
0043 namespace TMVA {
0044
0045 class MethodBase;
0046 class RuleFit;
0047 class MethodRuleFit;
0048 class RuleEnsemble;
0049 class MsgLogger;
0050
0051 std::ostream& operator<<( std::ostream& os, const RuleEnsemble& event );
0052
0053 class RuleEnsemble {
0054
0055
0056 friend std::ostream& operator<< ( std::ostream& os, const RuleEnsemble& rules );
0057
0058 public:
0059
0060 enum ELearningModel { kFull=0, kRules=1, kLinear=2 };
0061
0062
0063 RuleEnsemble( RuleFit* rf );
0064
0065
0066 RuleEnsemble( const RuleEnsemble& other );
0067
0068
0069 RuleEnsemble();
0070
0071
0072 virtual ~RuleEnsemble();
0073
0074
0075 void Initialize( const RuleFit* rf );
0076
0077
0078 void SetMsgType( EMsgType t );
0079
0080
0081 void MakeModel();
0082
0083
0084 void MakeRules( const std::vector< const TMVA::DecisionTree *>& forest );
0085
0086
0087 void MakeLinearTerms();
0088
0089
0090 void SetModelLinear() { fLearningModel = kLinear; }
0091
0092
0093 void SetModelRules() { fLearningModel = kRules; }
0094
0095
0096 void SetModelFull() { fLearningModel = kFull; }
0097
0098
0099 void SetRules( const std::vector< TMVA::Rule *> & rules );
0100
0101
0102 void SetRuleFit( const RuleFit *rf ) { fRuleFit = rf; }
0103
0104
0105 void SetCoefficients( const std::vector< Double_t >& v );
0106 void SetCoefficient( UInt_t i, Double_t v ) { if (i<fRules.size()) fRules[i]->SetCoefficient(v); }
0107
0108 void SetOffset(Double_t v=0.0) { fOffset=v; }
0109 void AddOffset(Double_t v) { fOffset+=v; }
0110 void SetLinCoefficients( const std::vector< Double_t >& v ) { fLinCoefficients = v; }
0111 void SetLinCoefficient( UInt_t i, Double_t v ) { fLinCoefficients[i] = v; }
0112 void SetLinDM( const std::vector<Double_t> & xmin ) { fLinDM = xmin; }
0113 void SetLinDP( const std::vector<Double_t> & xmax ) { fLinDP = xmax; }
0114 void SetLinNorm( const std::vector<Double_t> & norm ) { fLinNorm = norm; }
0115
0116 Double_t CalcLinNorm( Double_t stdev ) { return ( stdev>0 ? fAverageRuleSigma/stdev : 1.0 ); }
0117
0118
0119 void ClearCoefficients( Double_t val=0 ) { for (UInt_t i=0; i<fRules.size(); i++) fRules[i]->SetCoefficient(val); }
0120 void ClearLinCoefficients( Double_t val=0 ) { for (UInt_t i=0; i<fLinCoefficients.size(); i++) fLinCoefficients[i]=val; }
0121 void ClearLinNorm( Double_t val=1.0 ) { for (UInt_t i=0; i<fLinNorm.size(); i++) fLinNorm[i]=val; }
0122
0123
0124 void SetRuleMinDist(Double_t d) { fRuleMinDist = d; }
0125
0126
0127 void SetImportanceCut(Double_t minimp=0) { fImportanceCut=minimp; }
0128
0129
0130 void SetLinQuantile(Double_t q) { fLinQuantile=q; }
0131
0132
0133 void SetAverageRuleSigma(Double_t v) { if (v>0.5) v=0.5; fAverageRuleSigma = v; fAverageSupport = 0.5*(1.0+TMath::Sqrt(1.0-4.0*v*v)); }
0134
0135
0136 Int_t CalcNRules( const TMVA::DecisionTree* dtree );
0137
0138 void FindNEndNodes( const TMVA::Node* node, Int_t& nendnodes );
0139
0140
0141 void SetEvent( const Event & e ) { fEvent = &e; fEventCacheOK = kFALSE; }
0142
0143
0144 void UpdateEventVal();
0145
0146
0147 void MakeRuleMap(const std::vector<const TMVA::Event *> *events=nullptr, UInt_t ifirst=0, UInt_t ilast=0);
0148
0149
0150 void ClearRuleMap() { fRuleMap.clear(); fRuleMapEvents=nullptr; }
0151
0152
0153
0154 Double_t EvalEvent() const;
0155 Double_t EvalEvent( const Event & e );
0156
0157
0158 Double_t EvalEvent( Double_t ofs,
0159 const std::vector<Double_t> & coefs,
0160 const std::vector<Double_t> & lincoefs) const;
0161 Double_t EvalEvent( const Event & e,
0162 Double_t ofs,
0163 const std::vector<Double_t> & coefs,
0164 const std::vector<Double_t> & lincoefs);
0165
0166
0167
0168 Double_t EvalEvent( UInt_t evtidx ) const;
0169 Double_t EvalEvent( UInt_t evtidx,
0170 Double_t ofs,
0171 const std::vector<Double_t> & coefs,
0172 const std::vector<Double_t> & lincoefs) const;
0173
0174
0175
0176 Double_t EvalLinEvent() const;
0177 Double_t EvalLinEvent( const std::vector<Double_t> & coefs ) const;
0178 Double_t EvalLinEvent( const Event &e );
0179 Double_t EvalLinEvent( const Event &e, UInt_t vind );
0180 Double_t EvalLinEvent( const Event &e, const std::vector<Double_t> & coefs );
0181
0182
0183 Double_t EvalLinEvent( UInt_t evtidx ) const;
0184 Double_t EvalLinEvent( UInt_t evtidx, const std::vector<Double_t> & coefs ) const;
0185 Double_t EvalLinEvent( UInt_t evtidx, UInt_t vind ) const;
0186 Double_t EvalLinEvent( UInt_t evtidx, UInt_t vind, Double_t coefs ) const;
0187
0188
0189 Double_t EvalLinEventRaw( UInt_t vind, const Event &e, Bool_t norm ) const;
0190 Double_t EvalLinEventRaw( UInt_t vind, UInt_t evtidx, Bool_t norm ) const;
0191
0192
0193 Double_t PdfLinear( Double_t & nsig, Double_t & ntot ) const;
0194
0195
0196 Double_t PdfRule( Double_t & nsig, Double_t & ntot ) const;
0197
0198
0199 Double_t FStar() const;
0200 Double_t FStar(const TMVA::Event & e );
0201
0202
0203 void SetImportanceRef(Double_t impref);
0204
0205
0206 void CalcRuleSupport();
0207
0208
0209 void CalcImportance();
0210
0211
0212 Double_t CalcRuleImportance();
0213
0214
0215 Double_t CalcLinImportance();
0216
0217
0218 void CalcVarImportance();
0219
0220
0221 void CleanupRules();
0222
0223
0224 void CleanupLinear();
0225
0226
0227 void RemoveSimilarRules();
0228
0229
0230 void RuleStatistics();
0231
0232
0233 void RuleResponseStats();
0234
0235
0236 void operator=( const RuleEnsemble& other ) { Copy( other ); }
0237
0238
0239 Double_t CoefficientRadius();
0240
0241
0242 void GetCoefficients( std::vector< Double_t >& v );
0243
0244
0245 const MethodRuleFit* GetMethodRuleFit() const;
0246 const MethodBase* GetMethodBase() const;
0247 const RuleFit* GetRuleFit() const { return fRuleFit; }
0248
0249 const std::vector<const TMVA::Event *>* GetTrainingEvents() const;
0250 const Event* GetTrainingEvent(UInt_t i) const;
0251 const Event* GetEvent() const { return fEvent; }
0252
0253 Bool_t DoLinear() const { return (fLearningModel==kFull) || (fLearningModel==kLinear); }
0254 Bool_t DoRules() const { return (fLearningModel==kFull) || (fLearningModel==kRules); }
0255 Bool_t DoOnlyRules() const { return (fLearningModel==kRules); }
0256 Bool_t DoOnlyLinear() const { return (fLearningModel==kLinear); }
0257 Bool_t DoFull() const { return (fLearningModel==kFull); }
0258 ELearningModel GetLearningModel() const { return fLearningModel; }
0259 Double_t GetImportanceCut() const { return fImportanceCut; }
0260 Double_t GetImportanceRef() const { return fImportanceRef; }
0261 Double_t GetOffset() const { return fOffset; }
0262 UInt_t GetNRules() const { return (DoRules() ? fRules.size():0); }
0263 const std::vector<TMVA::Rule*>& GetRulesConst() const { return fRules; }
0264 std::vector<TMVA::Rule*>& GetRules() { return fRules; }
0265 const std::vector< Double_t >& GetLinCoefficients() const { return fLinCoefficients; }
0266 const std::vector< Double_t >& GetLinNorm() const { return fLinNorm; }
0267 const std::vector< Double_t >& GetLinImportance() const { return fLinImportance; }
0268 const std::vector< Double_t >& GetVarImportance() const { return fVarImportance; }
0269 UInt_t GetNLinear() const { return (DoLinear() ? fLinNorm.size():0); }
0270 Double_t GetLinQuantile() const { return fLinQuantile; }
0271
0272 const Rule *GetRulesConst(int i) const { return fRules[i]; }
0273 Rule *GetRules(int i) { return fRules[i]; }
0274
0275 UInt_t GetRulesNCuts(int i) const { return fRules[i]->GetRuleCut()->GetNcuts(); }
0276 Double_t GetRuleMinDist() const { return fRuleMinDist; }
0277 Double_t GetLinCoefficients(int i) const { return fLinCoefficients[i]; }
0278 Double_t GetLinNorm(int i) const { return fLinNorm[i]; }
0279 Double_t GetLinDM(int i) const { return fLinDM[i]; }
0280 Double_t GetLinDP(int i) const { return fLinDP[i]; }
0281 Double_t GetLinImportance(int i) const { return fLinImportance[i]; }
0282 Double_t GetVarImportance(int i) const { return fVarImportance[i]; }
0283 Double_t GetRulePTag(int i) const { return fRulePTag[i]; }
0284 Double_t GetRulePSS(int i) const { return fRulePSS[i]; }
0285 Double_t GetRulePSB(int i) const { return fRulePSB[i]; }
0286 Double_t GetRulePBS(int i) const { return fRulePBS[i]; }
0287 Double_t GetRulePBB(int i) const { return fRulePBB[i]; }
0288
0289 Bool_t IsLinTermOK(int i) const { return fLinTermOK[i]; }
0290
0291 Double_t GetAverageSupport() const { return fAverageSupport; }
0292 Double_t GetAverageRuleSigma() const { return fAverageRuleSigma; }
0293 Double_t GetEventRuleVal(UInt_t i) const { return (fEventRuleVal[i] ? 1.0:0.0); }
0294 Double_t GetEventLinearVal(UInt_t i) const { return fEventLinearVal[i]; }
0295 Double_t GetEventLinearValNorm(UInt_t i) const { return fEventLinearVal[i]*fLinNorm[i]; }
0296
0297 const std::vector<UInt_t> & GetEventRuleMap(UInt_t evtidx) const { return fRuleMap[evtidx]; }
0298 const TMVA::Event *GetRuleMapEvent(UInt_t evtidx) const { return (*fRuleMapEvents)[evtidx]; }
0299 Bool_t IsRuleMapOK() const { return fRuleMapOK; }
0300
0301
0302 void PrintRuleGen() const;
0303
0304
0305 void Print() const;
0306
0307
0308 void PrintRaw ( std::ostream& os ) const;
0309 void* AddXMLTo ( void* parent ) const;
0310
0311
0312 void ReadRaw ( std::istream& istr );
0313 void ReadFromXML( void* wghtnode );
0314
0315
0316 private:
0317
0318
0319 void DeleteRules() { for (UInt_t i=0; i<fRules.size(); i++) delete fRules[i]; fRules.clear(); }
0320
0321
0322 void Copy( RuleEnsemble const& other );
0323
0324
0325 void ResetCoefficients();
0326
0327
0328 void MakeRulesFromTree( const DecisionTree *dtree );
0329
0330
0331 void AddRule( const Node *node );
0332
0333
0334 Rule *MakeTheRule( const Node *node );
0335
0336
0337 ELearningModel fLearningModel;
0338 Double_t fImportanceCut;
0339 Double_t fLinQuantile;
0340 Double_t fOffset;
0341 std::vector< TMVA::Rule* > fRules;
0342 std::vector< Char_t > fLinTermOK;
0343 std::vector< Double_t > fLinDP;
0344 std::vector< Double_t > fLinDM;
0345 std::vector< Double_t > fLinCoefficients;
0346 std::vector< Double_t > fLinNorm;
0347 std::vector< TH1F* > fLinPDFB;
0348 std::vector< TH1F* > fLinPDFS;
0349 std::vector< Double_t > fLinImportance;
0350 std::vector< Double_t > fVarImportance;
0351 Double_t fImportanceRef;
0352 Double_t fAverageSupport;
0353 Double_t fAverageRuleSigma;
0354
0355 std::vector< Double_t > fRuleVarFrac;
0356 std::vector< Double_t > fRulePSS;
0357 std::vector< Double_t > fRulePSB;
0358 std::vector< Double_t > fRulePBS;
0359 std::vector< Double_t > fRulePBB;
0360 std::vector< Double_t > fRulePTag;
0361 Double_t fRuleFSig;
0362 Double_t fRuleNCave;
0363 Double_t fRuleNCsig;
0364
0365 Double_t fRuleMinDist;
0366 UInt_t fNRulesGenerated;
0367
0368 const Event* fEvent;
0369 Bool_t fEventCacheOK;
0370 std::vector<Char_t> fEventRuleVal;
0371 std::vector<Double_t> fEventLinearVal;
0372
0373 Bool_t fRuleMapOK;
0374 std::vector< std::vector<UInt_t> > fRuleMap;
0375 UInt_t fRuleMapInd0;
0376 UInt_t fRuleMapInd1;
0377 const std::vector<const TMVA::Event *> *fRuleMapEvents;
0378
0379 const RuleFit* fRuleFit;
0380
0381 mutable MsgLogger* fLogger;
0382 MsgLogger& Log() const { return *fLogger; }
0383 };
0384 }
0385
0386
0387 inline void TMVA::RuleEnsemble::UpdateEventVal()
0388 {
0389
0390
0391
0392 if (fEventCacheOK) return;
0393
0394 if (DoRules()) {
0395 UInt_t nrules = fRules.size();
0396 fEventRuleVal.resize(nrules,kFALSE);
0397 for (UInt_t r=0; r<nrules; r++) {
0398 fEventRuleVal[r] = fRules[r]->EvalEvent(*fEvent);
0399 }
0400 }
0401 if (DoLinear()) {
0402 UInt_t nlin = fLinTermOK.size();
0403 fEventLinearVal.resize(nlin,0);
0404 for (UInt_t r=0; r<nlin; r++) {
0405 fEventLinearVal[r] = EvalLinEventRaw(r,*fEvent,kFALSE);
0406 }
0407 }
0408 fEventCacheOK = kTRUE;
0409 }
0410
0411
0412 inline Double_t TMVA::RuleEnsemble::EvalEvent() const
0413 {
0414
0415
0416 Int_t nrules = fRules.size();
0417 Double_t rval=fOffset;
0418 Double_t linear=0;
0419
0420
0421
0422
0423 if (DoRules()) {
0424 for ( Int_t i=0; i<nrules; i++ ) {
0425 if (fEventRuleVal[i])
0426 rval += fRules[i]->GetCoefficient();
0427 }
0428 }
0429
0430
0431
0432 if (DoLinear()) linear = EvalLinEvent();
0433 rval +=linear;
0434
0435 return rval;
0436 }
0437
0438
0439 inline Double_t TMVA::RuleEnsemble::EvalEvent( Double_t ofs,
0440 const std::vector<Double_t> & coefs,
0441 const std::vector<Double_t> & lincoefs ) const
0442 {
0443
0444
0445 Int_t nrules = fRules.size();
0446 Double_t rval = ofs;
0447 Double_t linear = 0;
0448
0449
0450
0451 if (DoRules()) {
0452 for ( Int_t i=0; i<nrules; i++ ) {
0453 if (fEventRuleVal[i])
0454 rval += coefs[i];
0455 }
0456 }
0457
0458
0459
0460 if (DoLinear()) linear = EvalLinEvent(lincoefs);
0461 rval +=linear;
0462
0463 return rval;
0464 }
0465
0466
0467 inline Double_t TMVA::RuleEnsemble::EvalEvent(const TMVA::Event & e)
0468 {
0469
0470 SetEvent(e);
0471 UpdateEventVal();
0472 return EvalEvent();
0473 }
0474
0475
0476 inline Double_t TMVA::RuleEnsemble::EvalEvent(const TMVA::Event & e,
0477 Double_t ofs,
0478 const std::vector<Double_t> & coefs,
0479 const std::vector<Double_t> & lincoefs )
0480 {
0481
0482 SetEvent(e);
0483 UpdateEventVal();
0484 return EvalEvent(ofs,coefs,lincoefs);
0485 }
0486
0487
0488 inline Double_t TMVA::RuleEnsemble::EvalEvent(UInt_t evtidx) const
0489 {
0490
0491 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1)) return 0;
0492
0493 Double_t rval=fOffset;
0494 if (DoRules()) {
0495 UInt_t nrules = fRuleMap[evtidx].size();
0496 UInt_t rind;
0497 for (UInt_t ir = 0; ir<nrules; ir++) {
0498 rind = fRuleMap[evtidx][ir];
0499 rval += fRules[rind]->GetCoefficient();
0500 }
0501 }
0502 if (DoLinear()) {
0503 UInt_t nlin = fLinTermOK.size();
0504 for (UInt_t r=0; r<nlin; r++) {
0505 if (fLinTermOK[r]) {
0506 rval += fLinCoefficients[r] * EvalLinEventRaw(r,*(*fRuleMapEvents)[evtidx],kTRUE);
0507 }
0508 }
0509 }
0510 return rval;
0511 }
0512
0513
0514 inline Double_t TMVA::RuleEnsemble::EvalEvent(UInt_t evtidx,
0515 Double_t ofs,
0516 const std::vector<Double_t> & coefs,
0517 const std::vector<Double_t> & lincoefs ) const
0518 {
0519
0520
0521 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1)) return 0;
0522 Double_t rval=ofs;
0523 if (DoRules()) {
0524 UInt_t nrules = fRuleMap[evtidx].size();
0525 UInt_t rind;
0526 for (UInt_t ir = 0; ir<nrules; ir++) {
0527 rind = fRuleMap[evtidx][ir];
0528 rval += coefs[rind];
0529 }
0530 }
0531 if (DoLinear()) {
0532 rval += EvalLinEvent( evtidx, lincoefs );
0533 }
0534 return rval;
0535 }
0536
0537
0538 inline Double_t TMVA::RuleEnsemble::EvalLinEventRaw( UInt_t vind, const TMVA::Event & e, Bool_t norm) const
0539 {
0540
0541
0542 Double_t val = e.GetValue(vind);
0543 Double_t rval = TMath::Min( fLinDP[vind], TMath::Max( fLinDM[vind], val ) );
0544 if (norm) rval *= fLinNorm[vind];
0545 return rval;
0546 }
0547
0548
0549 inline Double_t TMVA::RuleEnsemble::EvalLinEventRaw( UInt_t vind, UInt_t evtidx, Bool_t norm) const
0550 {
0551
0552
0553 Double_t val = (*fRuleMapEvents)[evtidx]->GetValue(vind);
0554 Double_t rval = TMath::Min( fLinDP[vind], TMath::Max( fLinDM[vind], val ) );
0555 if (norm) rval *= fLinNorm[vind];
0556 return rval;
0557 }
0558
0559
0560 inline Double_t TMVA::RuleEnsemble::EvalLinEvent() const
0561 {
0562
0563
0564 Double_t rval=0;
0565 for (UInt_t v=0; v<fLinTermOK.size(); v++) {
0566 if (fLinTermOK[v])
0567 rval += fLinCoefficients[v]*fEventLinearVal[v]*fLinNorm[v];
0568 }
0569 return rval;
0570 }
0571
0572
0573 inline Double_t TMVA::RuleEnsemble::EvalLinEvent(const std::vector<Double_t> & coefs) const
0574 {
0575
0576
0577 Double_t rval=0;
0578 for (UInt_t v=0; v<fLinTermOK.size(); v++) {
0579 if (fLinTermOK[v])
0580 rval += coefs[v]*fEventLinearVal[v]*fLinNorm[v];
0581 }
0582 return rval;
0583 }
0584
0585
0586 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( const TMVA::Event& e )
0587 {
0588
0589
0590 SetEvent(e);
0591 UpdateEventVal();
0592 return EvalLinEvent();
0593 }
0594
0595
0596 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( const TMVA::Event& e, UInt_t vind )
0597 {
0598
0599
0600 SetEvent(e);
0601 UpdateEventVal();
0602 return GetEventLinearValNorm(vind);
0603 }
0604
0605
0606 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( const TMVA::Event& e, const std::vector<Double_t> & coefs )
0607 {
0608
0609
0610 SetEvent(e);
0611 UpdateEventVal();
0612 return EvalLinEvent(coefs);
0613 }
0614
0615
0616 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( UInt_t evtidx, const std::vector<Double_t> & coefs ) const
0617 {
0618
0619 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1)) return 0;
0620 Double_t rval=0;
0621 UInt_t nlin = fLinTermOK.size();
0622 for (UInt_t r=0; r<nlin; r++) {
0623 if (fLinTermOK[r]) {
0624 rval += coefs[r] * EvalLinEventRaw(r,*(*fRuleMapEvents)[evtidx],kTRUE);
0625 }
0626 }
0627 return rval;
0628 }
0629
0630
0631 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( UInt_t evtidx ) const
0632 {
0633
0634 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1)) return 0;
0635 Double_t rval=0;
0636 UInt_t nlin = fLinTermOK.size();
0637 for (UInt_t r=0; r<nlin; r++) {
0638 if (fLinTermOK[r]) {
0639 rval += fLinCoefficients[r] * EvalLinEventRaw(r,*(*fRuleMapEvents)[evtidx],kTRUE);
0640 }
0641 }
0642 return rval;
0643 }
0644
0645
0646 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( UInt_t evtidx, UInt_t vind ) const
0647 {
0648
0649 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1)) return 0;
0650 Double_t rval;
0651 rval = fLinCoefficients[vind] * EvalLinEventRaw(vind,*(*fRuleMapEvents)[evtidx],kTRUE);
0652 return rval;
0653 }
0654
0655
0656 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( UInt_t evtidx, UInt_t vind, Double_t coefs ) const
0657 {
0658
0659 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1)) return 0;
0660 Double_t rval;
0661 rval = coefs * EvalLinEventRaw(vind,*(*fRuleMapEvents)[evtidx],kTRUE);
0662 return rval;
0663 }
0664
0665 #endif