File indexing completed on 2025-12-16 10:30:15
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040 #ifndef ROOT_TMVA_Factory
0041 #define ROOT_TMVA_Factory
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
#include <map>
#include <type_traits>
#include <utility>
#include <vector>

#include "TCut.h"

#include "TMVA/Configurable.h"
#include "TMVA/Types.h"
#include "TMVA/DataSet.h"
0060
0061 class TCanvas;
0062 class TDirectory;
0063 class TFile;
0064 class TGraph;
0065 class TH1F;
0066 class TMultiGraph;
0067 class TTree;
0068 namespace TMVA {
0069
0070 class IMethod;
0071 class MethodBase;
0072 class DataInputHandler;
0073 class DataSetInfo;
0074 class DataSetManager;
0075 class DataLoader;
0076 class ROCCurve;
0077 class VariableTransformBase;
0078
0079
0080 class Factory : public Configurable {
0081 friend class CrossValidation;
0082 public:
0083
0084 typedef std::vector<IMethod*> MVector;
0085 std::map<TString,MVector*> fMethodsMap;
0086
0087
0088 Factory( TString theJobName, TFile* theTargetFile, TString theOption = "" );
0089
0090
0091 Factory( TString theJobName, TString theOption = "" );
0092
0093
0094 virtual ~Factory();
0095
0096
0097
0098
0099
0100
0101
0102 class MethodName {
0103 public:
0104 template <typename T, typename = std::enable_if_t<std::is_constructible_v<TString, T &&>>>
0105 MethodName(T &&name) : fName(std::forward<T>(name))
0106 {
0107 }
0108 MethodName(Types::EMVA method) : fName(Types::Instance().GetMethodName(method)) {}
0109 TString const &tString() const { return fName; }
0110
0111 private:
0112 TString fName;
0113 };
0114
0115 MethodBase* BookMethod( DataLoader *loader, MethodName theMethodName, TString methodTitle, TString theOption = "" );
0116
0117
0118 std::map<TString,Double_t> OptimizeAllMethods (TString fomType="ROCIntegral", TString fitType="FitGA");
0119 void OptimizeAllMethodsForClassification(TString fomType="ROCIntegral", TString fitType="FitGA") { OptimizeAllMethods(fomType,fitType); }
0120 void OptimizeAllMethodsForRegression (TString fomType="ROCIntegral", TString fitType="FitGA") { OptimizeAllMethods(fomType,fitType); }
0121
0122
0123 void TrainAllMethods ();
0124 void TrainAllMethodsForClassification( void ) { TrainAllMethods(); }
0125 void TrainAllMethodsForRegression ( void ) { TrainAllMethods(); }
0126
0127
0128 void TestAllMethods();
0129
0130
0131 void EvaluateAllMethods( void );
0132 void EvaluateAllVariables(DataLoader *loader, TString options = "" );
0133
0134 TH1F* EvaluateImportance( DataLoader *loader,VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
0135
0136
0137 void DeleteAllMethods( void );
0138
0139
0140 IMethod* GetMethod( const TString& datasetname, const TString& title ) const;
0141 Bool_t HasMethod( const TString& datasetname, const TString& title ) const;
0142
0143 Bool_t Verbose( void ) const { return fVerbose; }
0144 void SetVerbose( Bool_t v=kTRUE );
0145
0146
0147
0148
0149
0150 virtual void MakeClass(const TString& datasetname , const TString& methodTitle = "" ) const;
0151
0152
0153
0154
0155
0156 void PrintHelpMessage(const TString& datasetname , const TString& methodTitle = "" ) const;
0157
0158 TDirectory* RootBaseDir() { return (TDirectory*)fgTargetFile; }
0159
0160 Bool_t IsSilentFile() const { return fSilentFile;}
0161 Bool_t IsModelPersistence() const { return fModelPersistence; }
0162
0163 Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass = 0,
0164 Types::ETreeType type = Types::kTesting);
0165 Double_t GetROCIntegral(TString datasetname, TString theMethodName, UInt_t iClass = 0,
0166 Types::ETreeType type = Types::kTesting);
0167
0168
0169
0170
0171 TGraph *GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles = kTRUE, UInt_t iClass = 0,
0172 Types::ETreeType type = Types::kTesting);
0173 TGraph *GetROCCurve(TString datasetname, TString theMethodName, Bool_t setTitles = kTRUE, UInt_t iClass = 0,
0174 Types::ETreeType type = Types::kTesting);
0175
0176
0177 TMultiGraph *GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass, Types::ETreeType type = Types::kTesting);
0178 TMultiGraph *GetROCCurveAsMultiGraph(TString datasetname, UInt_t iClass, Types::ETreeType type = Types::kTesting);
0179
0180
0181 TCanvas *GetROCCurve(DataLoader *loader, UInt_t iClass = 0, Types::ETreeType type = Types::kTesting);
0182 TCanvas *GetROCCurve(TString datasetname, UInt_t iClass = 0, Types::ETreeType type = Types::kTesting);
0183
0184 private:
0185
0186
0187 void Greetings();
0188
0189
0190 TH1F* EvaluateImportanceShort( DataLoader *loader,Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
0191
0192 TH1F* EvaluateImportanceAll( DataLoader *loader,Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
0193
0194 TH1F* EvaluateImportanceRandom( DataLoader *loader,UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
0195
0196 TH1F* GetImportance(const int nbits,std::vector<Double_t> importances,std::vector<TString> varNames);
0197
0198
0199 ROCCurve *GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass = 0,
0200 Types::ETreeType type = Types::kTesting);
0201 ROCCurve *GetROC(TString datasetname, TString theMethodName, UInt_t iClass = 0,
0202 Types::ETreeType type = Types::kTesting);
0203
0204 void WriteDataInformation(DataSetInfo& fDataSetInfo);
0205
0206 void SetInputTreesFromEventAssignTrees();
0207
0208 MethodBase* BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile);
0209
0210 private:
0211
0212
0213
0214 TFile* fgTargetFile;
0215
0216
0217 std::vector<TMVA::VariableTransformBase*> fDefaultTrfs;
0218
0219
0220 TString fOptions;
0221 TString fTransformations;
0222 Bool_t fVerbose;
0223 TString fVerboseLevel;
0224 Bool_t fCorrelations;
0225 Bool_t fROC;
0226 Bool_t fSilentFile;
0227
0228 TString fJobName;
0229
0230 Types::EAnalysisType fAnalysisType;
0231 Bool_t fModelPersistence;
0232
0233
0234 protected:
0235
0236 ClassDefOverride(Factory,0);
0237 };
0238
0239 }
0240
0241 #endif