File indexing completed on 2025-01-30 10:22:52
0001
0002
0003
0004
0005 #ifndef ROOT_TMVA_HyperParameterOptimisation
0006 #define ROOT_TMVA_HyperParameterOptimisation
0007
0008
0009 #include "TString.h"
0010 #include <vector>
0011 #include <map>
0012
0013 #include "TMultiGraph.h"
0014
0015 #include "TMVA/IMethod.h"
0016 #include "TMVA/Configurable.h"
0017 #include "TMVA/Types.h"
0018 #include "TMVA/DataSet.h"
0019 #include "TMVA/Event.h"
0020 #include <TMVA/Results.h>
0021
0022 #include <TMVA/Factory.h>
0023
0024 #include <TMVA/DataLoader.h>
0025
0026 #include <TMVA/Envelope.h>
0027
0028 namespace TMVA {
0029
0030 class HyperParameterOptimisationResult
0031 {
0032 friend class HyperParameterOptimisation;
0033 private:
0034 Float_t fROCAVG;
0035 std::vector<Float_t> fROCs;
0036 std::vector<Double_t> fSigs;
0037 std::vector<Double_t> fSeps;
0038 std::vector<Double_t> fEff01s;
0039 std::vector<Double_t> fEff10s;
0040 std::vector<Double_t> fEff30s;
0041 std::vector<Double_t> fEffAreas;
0042 std::vector<Double_t> fTrainEff01s;
0043 std::vector<Double_t> fTrainEff10s;
0044 std::vector<Double_t> fTrainEff30s;
0045 std::shared_ptr<TMultiGraph> fROCCurves;
0046 TString fMethodName;
0047
0048 public:
0049 HyperParameterOptimisationResult();
0050 ~HyperParameterOptimisationResult();
0051
0052 std::vector<std::map<TString,Double_t> > fFoldParameters;
0053
0054 std::vector<Float_t> GetROCValues(){return fROCs;}
0055 Float_t GetROCAverage(){return fROCAVG;}
0056 TMultiGraph *GetROCCurves(Bool_t fLegend=kTRUE);
0057
0058 void Print() const ;
0059
0060
0061 std::vector<Double_t> GetSigValues(){return fSigs;}
0062 std::vector<Double_t> GetSepValues(){return fSeps;}
0063 std::vector<Double_t> GetEff01Values(){return fEff01s;}
0064 std::vector<Double_t> GetEff10Values(){return fEff10s;}
0065 std::vector<Double_t> GetEff30Values(){return fEff30s;}
0066 std::vector<Double_t> GetEffAreaValues(){return fEffAreas;}
0067 std::vector<Double_t> GetTrainEff01Values(){return fTrainEff01s;}
0068 std::vector<Double_t> GetTrainEff10Values(){return fTrainEff10s;}
0069 std::vector<Double_t> GetTrainEff30Values(){return fTrainEff30s;}
0070
0071 };
0072
0073 class HyperParameterOptimisation : public Envelope {
0074 public:
0075
0076 HyperParameterOptimisation(DataLoader *dataloader);
0077 ~HyperParameterOptimisation();
0078
0079 void SetFitter(TString fitType){fFitType=fitType;}
0080 TString GetFiiter(){return fFitType;}
0081
0082
0083
0084 void SetFOMType(TString ftype){fFomType=ftype;}
0085 TString GetFOMType(){return fFitType;}
0086
0087 void SetNumFolds(UInt_t folds);
0088 UInt_t GetNumFolds(){return fNumFolds;}
0089
0090 virtual void Evaluate();
0091 const HyperParameterOptimisationResult& GetResults() const {return fResults;}
0092
0093
0094 private:
0095 TString fFomType;
0096 TString fFitType;
0097 UInt_t fNumFolds;
0098 Bool_t fFoldStatus;
0099 HyperParameterOptimisationResult fResults;
0100 std::unique_ptr<Factory> fClassifier;
0101
0102 public:
0103 ClassDef(HyperParameterOptimisation,0);
0104 };
0105 }
0106
0107
0108 #endif
0109
0110
0111