File indexing completed on 2025-01-18 10:10:58
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012 #ifndef ROOT_TMVA_CROSS_EVALUATION
0013 #define ROOT_TMVA_CROSS_EVALUATION
0014
#include "TGraph.h"
#include "TMultiGraph.h"
#include "TString.h"

#include <map>
#include <memory>
#include <vector>

#include "TMVA/IMethod.h"
#include "TMVA/Configurable.h"
#include "TMVA/Types.h"
#include "TMVA/DataSet.h"
#include "TMVA/Event.h"
#include <TMVA/Results.h>
#include <TMVA/Factory.h>
#include <TMVA/DataLoader.h>
#include <TMVA/OptionMap.h>
#include <TMVA/Envelope.h>
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044 namespace TMVA {
0045
0046 class CvSplitKFolds;
0047
0048 using EventCollection_t = std::vector<Event *>;
0049 using EventTypes_t = std::vector<Bool_t>;
0050 using EventOutputs_t = std::vector<Float_t>;
0051 using EventOutputsMulticlass_t = std::vector<std::vector<Float_t>>;
0052
0053 class CrossValidationFoldResult {
0054 public:
0055 CrossValidationFoldResult() {}
0056 CrossValidationFoldResult(UInt_t iFold)
0057 : fFold(iFold)
0058 {}
0059
0060 UInt_t fFold;
0061
0062 Float_t fROCIntegral;
0063 TGraph fROC;
0064
0065 Double_t fSig;
0066 Double_t fSep;
0067 Double_t fEff01;
0068 Double_t fEff10;
0069 Double_t fEff30;
0070 Double_t fEffArea;
0071 Double_t fTrainEff01;
0072 Double_t fTrainEff10;
0073 Double_t fTrainEff30;
0074 };
0075
0076
0077
0078 class CrossValidationResult {
0079 friend class CrossValidation;
0080
0081 private:
0082 std::map<UInt_t, Float_t> fROCs;
0083 std::shared_ptr<TMultiGraph> fROCCurves;
0084
0085 std::vector<Double_t> fSigs;
0086 std::vector<Double_t> fSeps;
0087 std::vector<Double_t> fEff01s;
0088 std::vector<Double_t> fEff10s;
0089 std::vector<Double_t> fEff30s;
0090 std::vector<Double_t> fEffAreas;
0091 std::vector<Double_t> fTrainEff01s;
0092 std::vector<Double_t> fTrainEff10s;
0093 std::vector<Double_t> fTrainEff30s;
0094
0095 public:
0096 CrossValidationResult(UInt_t numFolds);
0097 CrossValidationResult(const CrossValidationResult &);
0098 ~CrossValidationResult() { fROCCurves = nullptr; }
0099
0100 std::map<UInt_t, Float_t> GetROCValues() const { return fROCs; }
0101 Float_t GetROCAverage() const;
0102 Float_t GetROCStandardDeviation() const;
0103 TMultiGraph *GetROCCurves(Bool_t fLegend = kTRUE);
0104 TGraph *GetAvgROCCurve(UInt_t numSamples = 100) const;
0105 void Print() const;
0106
0107 TCanvas *Draw(const TString name = "CrossValidation") const;
0108 TCanvas *DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const;
0109
0110 std::vector<Double_t> GetSigValues() const { return fSigs; }
0111 std::vector<Double_t> GetSepValues() const { return fSeps; }
0112 std::vector<Double_t> GetEff01Values() const { return fEff01s; }
0113 std::vector<Double_t> GetEff10Values() const { return fEff10s; }
0114 std::vector<Double_t> GetEff30Values() const { return fEff30s; }
0115 std::vector<Double_t> GetEffAreaValues() const { return fEffAreas; }
0116 std::vector<Double_t> GetTrainEff01Values() const { return fTrainEff01s; }
0117 std::vector<Double_t> GetTrainEff10Values() const { return fTrainEff10s; }
0118 std::vector<Double_t> GetTrainEff30Values() const { return fTrainEff30s; }
0119
0120 private:
0121 void Fill(CrossValidationFoldResult const & fr);
0122 };
0123
0124 class CrossValidation : public Envelope {
0125
0126 public:
0127 explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options);
0128 explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TFile *outputFile, TString options);
0129 ~CrossValidation();
0130
0131 void InitOptions();
0132 void ParseOptions();
0133
0134 void SetNumFolds(UInt_t i);
0135 void SetSplitExpr(TString splitExpr);
0136
0137 UInt_t GetNumFolds() { return fNumFolds; }
0138 TString GetSplitExpr() { return fSplitExprString; }
0139
0140 Factory &GetFactory() { return *fFactory; }
0141
0142 const std::vector<CrossValidationResult> &GetResults() const;
0143
0144 void Evaluate();
0145
0146 private:
0147 CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap & methodInfo);
0148
0149 Types::EAnalysisType fAnalysisType;
0150 TString fAnalysisTypeStr;
0151 TString fSplitTypeStr;
0152 Bool_t fCorrelations;
0153 TString fCvFactoryOptions;
0154 Bool_t fDrawProgressBar;
0155 Bool_t fFoldFileOutput;
0156 Bool_t fFoldStatus;
0157 TString fJobName;
0158 UInt_t fNumFolds;
0159 UInt_t fNumWorkerProcs;
0160 TString fOutputFactoryOptions;
0161 TString fOutputEnsembling;
0162 TFile *fOutputFile;
0163 Bool_t fSilent;
0164 TString fSplitExprString;
0165 std::vector<CrossValidationResult> fResults;
0166 Bool_t fROC;
0167 TString fTransformations;
0168 Bool_t fVerbose;
0169 TString fVerboseLevel;
0170
0171 std::unique_ptr<Factory> fFoldFactory;
0172 std::unique_ptr<Factory> fFactory;
0173 std::unique_ptr<CvSplitKFolds> fSplit;
0174
0175 ClassDef(CrossValidation, 0);
0176 };
0177
0178 }
0179
0180 #endif