File indexing completed on 2025-06-30 08:50:16
0001
0002
0003
0004
0005 #ifndef ROOT_TMVA_Classification
0006 #define ROOT_TMVA_Classification
0007
#include <TString.h>
#include <TMultiGraph.h>
#include <vector>
#include <map>
#include <tuple>

#include <TMVA/IMethod.h>
#include <TMVA/MethodBase.h>
#include <TMVA/Configurable.h>
#include <TMVA/Types.h>
#include <TMVA/DataSet.h>
#include <TMVA/Event.h>
#include <TMVA/Results.h>
#include <TMVA/ResultsClassification.h>
#include <TMVA/ResultsMulticlass.h>
#include <TMVA/Factory.h>
#include <TMVA/DataLoader.h>
#include <TMVA/OptionMap.h>
#include <TMVA/Envelope.h>
0025 #include <TMVA/Envelope.h>
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073
0074
0075
0076
0077
0078
0079
0080
0081
0082
0083
0084
0085
0086
0087
0088
0089
0090
0091
0092
0093
0094
0095
0096
0097
0098
0099
0100
0101
0102
0103
0104
0105
0106
0107
0108
0109
0110
0111
0112
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122 namespace TMVA {
0123 class ResultsClassification;
0124 namespace Experimental {
0125 class ClassificationResult : public TObject {
0126 friend class Classification;
0127
0128 private:
0129 OptionMap fMethod;
0130 TString fDataLoaderName;
0131 std::map<UInt_t, std::vector<std::tuple<Float_t, Float_t, Bool_t>>> fMvaTrain;
0132 std::map<UInt_t, std::vector<std::tuple<Float_t, Float_t, Bool_t>>> fMvaTest;
0133 std::vector<TString> fClassNames;
0134
0135 Bool_t IsMethod(TString methodname, TString methodtitle);
0136 Bool_t fIsCuts;
0137 Double_t fROCIntegral;
0138
0139 public:
0140 ClassificationResult();
0141 ClassificationResult(const ClassificationResult &cr);
0142 ~ClassificationResult() {}
0143
0144 const TString GetMethodName() const { return fMethod.GetValue<TString>("MethodName"); }
0145 const TString GetMethodTitle() const { return fMethod.GetValue<TString>("MethodTitle"); }
0146 ROCCurve *GetROC(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0147 Double_t GetROCIntegral(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0148 TString GetDataLoaderName() { return fDataLoaderName; }
0149 Bool_t IsCutsMethod() { return fIsCuts; }
0150
0151 void Show();
0152
0153 TGraph *GetROCGraph(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0154 ClassificationResult &operator=(const ClassificationResult &r);
0155
0156 ClassDef(ClassificationResult, 3);
0157 };
0158
0159 class Classification : public Envelope {
0160 std::vector<ClassificationResult> fResults;
0161 std::vector<IMethod *> fIMethods;
0162 Types::EAnalysisType fAnalysisType;
0163 Bool_t fCorrelations;
0164 Bool_t fROC;
0165 public:
0166 explicit Classification(DataLoader *loader, TFile *file, TString options);
0167 explicit Classification(DataLoader *loader, TString options);
0168 ~Classification();
0169
0170 virtual void Train();
0171 virtual void TrainMethod(TString methodname, TString methodtitle);
0172 virtual void TrainMethod(Types::EMVA method, TString methodtitle);
0173
0174 virtual void Test();
0175 virtual void TestMethod(TString methodname, TString methodtitle);
0176 virtual void TestMethod(Types::EMVA method, TString methodtitle);
0177
0178 virtual void Evaluate();
0179
0180 std::vector<ClassificationResult> &GetResults();
0181
0182 MethodBase *GetMethod(TString methodname, TString methodtitle);
0183
0184 protected:
0185 TString GetMethodOptions(TString methodname, TString methodtitle);
0186 Bool_t HasMethodObject(TString methodname, TString methodtitle, Int_t &index);
0187 Bool_t IsCutsMethod(TMVA::MethodBase *method);
0188 TMVA::ROCCurve *
0189 GetROC(TMVA::MethodBase *method, UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0190 TMVA::ROCCurve *GetROC(TString methodname, TString methodtitle, UInt_t iClass = 0,
0191 TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0192
0193 Double_t GetROCIntegral(TString methodname, TString methodtitle, UInt_t iClass = 0);
0194
0195 ClassificationResult &GetResults(TString methodname, TString methodtitle);
0196 void CopyFrom(TDirectory *src, TFile *file);
0197 void MergeFiles();
0198
0199 ClassDef(Classification, 0);
0200 };
0201 }
0202 }
0203
0204 #endif