File indexing completed on 2025-01-18 10:10:57
0001
0002
0003
0004
0005 #ifndef ROOT_TMVA_Classification
0006 #define ROOT_TMVA_Classification
0007
0008 #include <TString.h>
0009 #include <TMultiGraph.h>
0010 #include <vector>
0011 #include <map>
0012
0013 #include <TMVA/IMethod.h>
0014 #include <TMVA/MethodBase.h>
0015 #include <TMVA/Configurable.h>
0016 #include <TMVA/Types.h>
0017 #include <TMVA/DataSet.h>
0018 #include <TMVA/Event.h>
0019 #include <TMVA/Results.h>
0020 #include <TMVA/ResultsClassification.h>
0021 #include <TMVA/ResultsMulticlass.h>
0022 #include <TMVA/Factory.h>
0023 #include <TMVA/DataLoader.h>
0024 #include <TMVA/OptionMap.h>
0025 #include <TMVA/Envelope.h>
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073
0074
0075
0076
0077
0078
0079
0080
0081
0082
0083
0084
0085
0086
0087
0088
0089
0090
0091
0092
0093
0094
0095
0096
0097
0098
0099
0100
0101
0102
0103
0104
0105
0106
0107
0108
0109
0110
0111
0112
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122
0123
0124
0125 namespace TMVA {
0126 class ResultsClassification;
0127 namespace Experimental {
/// Value-like record attached to one classification method. It derives from
/// TObject, declares ROOT class version 3 via ClassDef, and grants the
/// Classification class friend access to its private members.
/// All member functions without an inline body are defined outside this header.
0128 class ClassificationResult : public TObject {
0129 friend class Classification;
0130
0131 private:
/// Option map of the method; GetMethodName() and GetMethodTitle() read its
/// "MethodName" and "MethodTitle" entries.
0132 OptionMap fMethod;
/// Name returned by GetDataLoaderName().
0133 TString fDataLoaderName;
/// Maps from a UInt_t key to a vector of (Float_t, Float_t, Bool_t) tuples,
/// one map for the training sample and one for the test sample.
/// NOTE(review): presumably keyed by class index with tuples of
/// (MVA output, event weight, is-target-class) -- confirm in the implementation.
0134 std::map<UInt_t, std::vector<std::tuple<Float_t, Float_t, Bool_t>>> fMvaTrain;
0135 std::map<UInt_t, std::vector<std::tuple<Float_t, Float_t, Bool_t>>> fMvaTest;
/// List of class names (how it is filled is not visible in this header).
0136 std::vector<TString> fClassNames;
0137
/// Private predicate taking a method name and title by value.
/// NOTE(review): presumably tells whether this result belongs to that method -- confirm.
0138 Bool_t IsMethod(TString methodname, TString methodtitle);
/// Flag returned by IsCutsMethod().
0139 Bool_t fIsCuts;
/// Stored ROC integral value.
/// NOTE(review): relation to GetROCIntegral() is not visible here -- confirm.
0140 Double_t fROCIntegral;
0141
0142 public:
/// Default constructor, copy constructor and an empty inline destructor.
0143 ClassificationResult();
0144 ClassificationResult(const ClassificationResult &cr);
0145 ~ClassificationResult() {}
0146
/// Return, by value, the TString stored under "MethodName" / "MethodTitle" in fMethod.
0147 const TString GetMethodName() const { return fMethod.GetValue<TString>("MethodName"); }
0148 const TString GetMethodTitle() const { return fMethod.GetValue<TString>("MethodTitle"); }
/// ROC accessors; both default to iClass = 0 and the kTesting tree type.
/// NOTE(review): ownership of the returned ROCCurve pointer is not visible here -- confirm.
0149 ROCCurve *GetROC(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0150 Double_t GetROCIntegral(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
/// Return a copy of fDataLoaderName (non-const member function).
0151 TString GetDataLoaderName() { return fDataLoaderName; }
/// Return fIsCuts (non-const member function).
0152 Bool_t IsCutsMethod() { return fIsCuts; }
0153
/// NOTE(review): presumably prints a summary of this result -- confirm.
0154 void Show();
0155
/// Graph accessor with the same defaults as GetROC() (iClass = 0, kTesting).
/// NOTE(review): ownership of the returned TGraph pointer is not visible here -- confirm.
0156 TGraph *GetROCGraph(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
/// Copy assignment; returns a reference to a ClassificationResult.
0157 ClassificationResult &operator=(const ClassificationResult &r);
0158
/// ROOT dictionary hook, class version 3.
0159 ClassDef(ClassificationResult, 3);
0160 };
0161
/// Classification driver derived from Envelope. It keeps a vector of
/// ClassificationResult objects, a vector of IMethod pointers, an analysis
/// type and two boolean switches, and declares train / test / evaluate entry
/// points plus protected lookup helpers.
/// All member functions are only declared here; their definitions are outside this header.
0162 class Classification : public Envelope {
/// The members below precede the first access specifier and are therefore private.
/// Results collection handed out by the public GetResults().
0163 std::vector<ClassificationResult> fResults;
/// Raw IMethod pointers.
/// NOTE(review): ownership / deletion responsibility is not visible here -- confirm.
0164 std::vector<IMethod *> fIMethods;
/// Analysis type (Types::EAnalysisType).
0165 Types::EAnalysisType fAnalysisType;
/// Boolean switches.
/// NOTE(review): presumably enable correlation output and ROC computation -- confirm.
0166 Bool_t fCorrelations;
0167 Bool_t fROC;
0168 public:
/// Explicit constructors: one takes a DataLoader, a TFile and an option string,
/// the other a DataLoader and an option string only.
/// NOTE(review): ownership of loader/file is not visible here -- confirm.
0169 explicit Classification(DataLoader *loader, TFile *file, TString options);
0170 explicit Classification(DataLoader *loader, TString options);
0171 ~Classification();
0172
/// Virtual training entry points: no-argument form, plus per-method forms that
/// select the method by name or by Types::EMVA value together with a title.
0173 virtual void Train();
0174 virtual void TrainMethod(TString methodname, TString methodtitle);
0175 virtual void TrainMethod(Types::EMVA method, TString methodtitle);
0176
/// Virtual testing entry points with the same overload set as the training ones.
0177 virtual void Test();
0178 virtual void TestMethod(TString methodname, TString methodtitle);
0179 virtual void TestMethod(Types::EMVA method, TString methodtitle);
0180
/// Virtual evaluation entry point.
/// NOTE(review): presumably runs training and testing for the booked methods -- confirm.
0181 virtual void Evaluate();
0182
/// Return a mutable reference to a vector of ClassificationResult.
0183 std::vector<ClassificationResult> &GetResults();
0184
/// Look up a method by name and title and return it as MethodBase pointer.
/// NOTE(review): behaviour when no such method exists is not visible here -- confirm.
0185 MethodBase *GetMethod(TString methodname, TString methodtitle);
0186
0187 protected:
/// Return an option string for the given method name/title.
0188 TString GetMethodOptions(TString methodname, TString methodtitle);
/// Boolean lookup with a non-const Int_t reference parameter.
/// NOTE(review): index is presumably set to the position of the match -- confirm.
0189 Bool_t HasMethodObject(TString methodname, TString methodtitle, Int_t &index);
/// Predicate on a MethodBase pointer.
/// NOTE(review): presumably true for the rectangular-cuts method -- confirm.
0190 Bool_t IsCutsMethod(TMVA::MethodBase *method);
/// ROC accessors, by method pointer or by name/title; iClass defaults to 0 and
/// type to kTesting in both overloads.
0191 TMVA::ROCCurve *
0192 GetROC(TMVA::MethodBase *method, UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0193 TMVA::ROCCurve *GetROC(TString methodname, TString methodtitle, UInt_t iClass = 0,
0194 TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0195
/// ROC integral by method name/title; iClass defaults to 0.
0196 Double_t GetROCIntegral(TString methodname, TString methodtitle, UInt_t iClass = 0);
0197
/// Return a mutable reference to a single ClassificationResult selected by name/title.
0198 ClassificationResult &GetResults(TString methodname, TString methodtitle);
/// File helpers taking a source TDirectory and a TFile, or no arguments.
/// NOTE(review): presumably used to gather per-method output into one file -- confirm.
0199 void CopyFrom(TDirectory *src, TFile *file);
0200 void MergeFiles();
0201
/// ROOT dictionary hook, class version 0.
0202 ClassDef(Classification, 0);
0203 };
0204 }
0205 }
0206
0207 #endif