Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:10:58

0001 // @(#)root/tmva $Id$
0002 // Author: Omar Zapata, Thomas James Stevenson, Pourya Vakilipourtakalou, Kim Albertsson
0003 
0004 /*************************************************************************
0005  * Copyright (C) 2018, Rene Brun and Fons Rademakers.                    *
0006  * All rights reserved.                                                  *
0007  *                                                                       *
0008  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0009  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0010  *************************************************************************/
0011 
0012 #ifndef ROOT_TMVA_CROSS_EVALUATION
0013 #define ROOT_TMVA_CROSS_EVALUATION
0014 
#include "TGraph.h"
#include "TMultiGraph.h"
#include "TString.h"

#include <map>
#include <memory>
#include <vector>

#include "TMVA/IMethod.h"
#include "TMVA/Configurable.h"
#include "TMVA/Types.h"
#include "TMVA/DataSet.h"
#include "TMVA/Event.h"
#include <TMVA/Results.h>
#include <TMVA/Factory.h>
#include <TMVA/DataLoader.h>
#include <TMVA/OptionMap.h>
#include <TMVA/Envelope.h>

// Global ROOT classes that the declarations below use only through pointers.
class TCanvas;
class TFile;
0031 
0032 /*! \class TMVA::CrossValidationResult
0033  * Class to save the results of cross validation.
0034  * The metric used for classification is the ROC curve: you can retrieve the
0035  * ROC curves, ROC integrals, ROC average and ROC standard deviation.
0036 \ingroup TMVA
0037 */
0038 
0039 /*! \class TMVA::CrossValidation
0040  * Class to perform cross validation, splitting the dataloader into folds.
0041 \ingroup TMVA
0042 */
0043 
0044 namespace TMVA {
0045 
0046 class CvSplitKFolds;
0047 
0048 using EventCollection_t = std::vector<Event *>;
0049 using EventTypes_t = std::vector<Bool_t>;
0050 using EventOutputs_t = std::vector<Float_t>;
0051 using EventOutputsMulticlass_t = std::vector<std::vector<Float_t>>;
0052 
0053 class CrossValidationFoldResult {
0054 public:
0055    CrossValidationFoldResult() {} // For multi-proc serialisation
0056    CrossValidationFoldResult(UInt_t iFold)
0057    : fFold(iFold)
0058    {}
0059 
0060    UInt_t fFold;
0061 
0062    Float_t fROCIntegral;
0063    TGraph fROC;
0064 
0065    Double_t fSig;
0066    Double_t fSep;
0067    Double_t fEff01;
0068    Double_t fEff10;
0069    Double_t fEff30;
0070    Double_t fEffArea;
0071    Double_t fTrainEff01;
0072    Double_t fTrainEff10;
0073    Double_t fTrainEff30;
0074 };
0075 
0076 // Used internally to keep per-fold aggregate statistics
0077 // such as ROC curves, ROC integrals and efficiencies.
0078 class CrossValidationResult {
0079    friend class CrossValidation;
0080 
0081 private:
0082    std::map<UInt_t, Float_t> fROCs;
0083    std::shared_ptr<TMultiGraph> fROCCurves;
0084 
0085    std::vector<Double_t> fSigs;
0086    std::vector<Double_t> fSeps;
0087    std::vector<Double_t> fEff01s;
0088    std::vector<Double_t> fEff10s;
0089    std::vector<Double_t> fEff30s;
0090    std::vector<Double_t> fEffAreas;
0091    std::vector<Double_t> fTrainEff01s;
0092    std::vector<Double_t> fTrainEff10s;
0093    std::vector<Double_t> fTrainEff30s;
0094 
0095 public:
0096    CrossValidationResult(UInt_t numFolds);
0097    CrossValidationResult(const CrossValidationResult &);
0098    ~CrossValidationResult() { fROCCurves = nullptr; }
0099 
0100    std::map<UInt_t, Float_t> GetROCValues() const { return fROCs; }
0101    Float_t GetROCAverage() const;
0102    Float_t GetROCStandardDeviation() const;
0103    TMultiGraph *GetROCCurves(Bool_t fLegend = kTRUE);
0104    TGraph *GetAvgROCCurve(UInt_t numSamples = 100) const;
0105    void Print() const;
0106 
0107    TCanvas *Draw(const TString name = "CrossValidation") const;
0108    TCanvas *DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const;
0109 
0110    std::vector<Double_t> GetSigValues() const { return fSigs; }
0111    std::vector<Double_t> GetSepValues() const { return fSeps; }
0112    std::vector<Double_t> GetEff01Values() const { return fEff01s; }
0113    std::vector<Double_t> GetEff10Values() const { return fEff10s; }
0114    std::vector<Double_t> GetEff30Values() const { return fEff30s; }
0115    std::vector<Double_t> GetEffAreaValues() const { return fEffAreas; }
0116    std::vector<Double_t> GetTrainEff01Values() const { return fTrainEff01s; }
0117    std::vector<Double_t> GetTrainEff10Values() const { return fTrainEff10s; }
0118    std::vector<Double_t> GetTrainEff30Values() const { return fTrainEff30s; }
0119 
0120 private:
0121    void Fill(CrossValidationFoldResult const & fr);
0122 };
0123 
0124 class CrossValidation : public Envelope {
0125 
0126 public:
0127    explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options);
0128    explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TFile *outputFile, TString options);
0129    ~CrossValidation();
0130 
0131    void InitOptions();
0132    void ParseOptions();
0133 
0134    void SetNumFolds(UInt_t i);
0135    void SetSplitExpr(TString splitExpr);
0136 
0137    UInt_t GetNumFolds() { return fNumFolds; }
0138    TString GetSplitExpr() { return fSplitExprString; }
0139 
0140    Factory &GetFactory() { return *fFactory; }
0141 
0142    const std::vector<CrossValidationResult> &GetResults() const;
0143 
0144    void Evaluate();
0145 
0146 private:
0147    CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap & methodInfo);
0148 
0149    Types::EAnalysisType fAnalysisType;
0150    TString fAnalysisTypeStr;
0151    TString fSplitTypeStr;
0152    Bool_t fCorrelations;
0153    TString fCvFactoryOptions;
0154    Bool_t fDrawProgressBar;
0155    Bool_t fFoldFileOutput;    ///<! If true: generate output file for each fold
0156    Bool_t fFoldStatus;        ///<! If true: dataset is prepared
0157    TString fJobName;
0158    UInt_t fNumFolds;          ///<! Number of folds to prepare
0159    UInt_t fNumWorkerProcs;    ///<! Number of processes to use for fold evaluation. (Default, no parallel evaluation)
0160    TString fOutputFactoryOptions;
0161    TString fOutputEnsembling; ///<! How to combine output of individual folds
0162    TFile *fOutputFile;
0163    Bool_t fSilent;
0164    TString fSplitExprString;
0165    std::vector<CrossValidationResult> fResults; ///<!
0166    Bool_t fROC;
0167    TString fTransformations;
0168    Bool_t fVerbose;
0169    TString fVerboseLevel;
0170 
0171    std::unique_ptr<Factory> fFoldFactory;
0172    std::unique_ptr<Factory> fFactory;
0173    std::unique_ptr<CvSplitKFolds> fSplit;
0174 
0175    ClassDef(CrossValidation, 0);
0176    };
0177 
0178 } // namespace TMVA
0179 
0180 #endif // ROOT_TMVA_CROSS_EVALUATION