Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:52

0001 // @(#)root/tmva $Id$
0002 // Author: Omar Zapata, Thomas James Stevenson.
0003 
0004 
0005 #ifndef ROOT_TMVA_HyperParameterOptimisation
0006 #define ROOT_TMVA_HyperParameterOptimisation
0007 
0008 
0009 #include "TString.h"
0010 #include <vector>
0011 #include <map>
0012 
0013 #include "TMultiGraph.h"
0014 
0015 #include "TMVA/IMethod.h"
0016 #include "TMVA/Configurable.h"
0017 #include "TMVA/Types.h"
0018 #include "TMVA/DataSet.h"
0019 #include "TMVA/Event.h"
0020 #include <TMVA/Results.h>
0021 
0022 #include <TMVA/Factory.h>
0023 
0024 #include <TMVA/DataLoader.h>
0025 
0026 #include <TMVA/Envelope.h>
0027 
0028 namespace TMVA {
0029 
0030    class HyperParameterOptimisationResult
0031    {
0032      friend class HyperParameterOptimisation;
0033    private:
0034       Float_t fROCAVG;
0035       std::vector<Float_t> fROCs;
0036       std::vector<Double_t> fSigs;
0037       std::vector<Double_t> fSeps;
0038       std::vector<Double_t> fEff01s;
0039       std::vector<Double_t> fEff10s;
0040       std::vector<Double_t> fEff30s;
0041       std::vector<Double_t> fEffAreas;
0042       std::vector<Double_t> fTrainEff01s;
0043       std::vector<Double_t> fTrainEff10s;
0044       std::vector<Double_t> fTrainEff30s;
0045       std::shared_ptr<TMultiGraph> fROCCurves;
0046       TString fMethodName;
0047 
0048    public:
0049        HyperParameterOptimisationResult();
0050        ~HyperParameterOptimisationResult();
0051 
0052        std::vector<std::map<TString,Double_t> > fFoldParameters;
0053 
0054        std::vector<Float_t> GetROCValues(){return fROCs;}
0055        Float_t GetROCAverage(){return fROCAVG;}
0056        TMultiGraph *GetROCCurves(Bool_t fLegend=kTRUE);
0057 
0058        void Print() const ;
0059 //        TCanvas* Draw(const TString name="HyperParameterOptimisation") const;
0060 
0061        std::vector<Double_t> GetSigValues(){return fSigs;}
0062        std::vector<Double_t> GetSepValues(){return fSeps;}
0063        std::vector<Double_t> GetEff01Values(){return fEff01s;}
0064        std::vector<Double_t> GetEff10Values(){return fEff10s;}
0065        std::vector<Double_t> GetEff30Values(){return fEff30s;}
0066        std::vector<Double_t> GetEffAreaValues(){return fEffAreas;}
0067        std::vector<Double_t> GetTrainEff01Values(){return fTrainEff01s;}
0068        std::vector<Double_t> GetTrainEff10Values(){return fTrainEff10s;}
0069        std::vector<Double_t> GetTrainEff30Values(){return fTrainEff30s;}
0070 
0071    };
0072 
0073    class HyperParameterOptimisation : public Envelope {
0074    public:
0075 
0076        HyperParameterOptimisation(DataLoader *dataloader);
0077        ~HyperParameterOptimisation();
0078 
0079        void SetFitter(TString fitType){fFitType=fitType;}
0080        TString GetFiiter(){return fFitType;}
0081 
0082 
0083        //Figure of Merit (FOM) default Separation
0084        void SetFOMType(TString ftype){fFomType=ftype;}
0085        TString GetFOMType(){return fFitType;}
0086 
0087        void SetNumFolds(UInt_t folds);
0088        UInt_t GetNumFolds(){return fNumFolds;}
0089 
0090        virtual void Evaluate();
0091        const HyperParameterOptimisationResult& GetResults() const {return fResults;}
0092 
0093 
0094    private:
0095        TString                           fFomType;     ///<!
0096        TString                           fFitType;     ///<!
0097        UInt_t                            fNumFolds;    ///<!
0098        Bool_t                            fFoldStatus;  ///<!
0099        HyperParameterOptimisationResult  fResults;     ///<!
0100        std::unique_ptr<Factory>          fClassifier;  ///<!
0101 
0102    public:
0103        ClassDef(HyperParameterOptimisation,0);
0104    };
0105 }
0106 
0107 
0108 #endif
0109 
0110 
0111