Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 10:30:15

0001 // @(#)root/tmva $Id$
0002 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
0003 // Updated by: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer
0004 
0005 /**********************************************************************************
0006  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0007  * Package: TMVA                                                                  *
0008  * Class  : Factory                                                               *
0009  *                                                                                *
0010  *                                                                                *
0011  * Description:                                                                   *
0012  *      This is the main MVA steering class: it creates (books) all MVA methods,  *
0013  *      and guides them through the training, testing and evaluation phases.      *
0014  *                                                                                *
0015  * Authors (alphabetical):                                                        *
0016  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
0017  *      Joerg Stelzer   <stelzer@cern.ch>        - DESY, Germany                  *
0018  *      Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland           *
0019  *      Jan Therhaag          <Jan.Therhaag@cern.ch>   - U of Bonn, Germany       *
0020  *      Eckhard v. Toerne     <evt@uni-bonn.de>        - U of Bonn, Germany       *
0021  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
0022  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
0023  *      Omar Zapata     <Omar.Zapata@cern.ch>    - UdeA/ITM Colombia              *
0024  *      Lorenzo Moneta  <Lorenzo.Moneta@cern.ch> - CERN, Switzerland              *
0025  *      Sergei Gleyzer  <Sergei.Gleyzer@cern.ch> - U of Florida & CERN            *
0026  *                                                                                *
0027  * Copyright (c) 2005-2011:                                                       *
0028  *      CERN, Switzerland                                                         *
0029  *      U. of Victoria, Canada                                                    *
0030  *      MPI-K Heidelberg, Germany                                                 *
0031  *      U. of Bonn, Germany                                                       *
0032  *      UdeA/ITM, Colombia                                                        *
0033  *      U. of Florida, USA                                                        *
0034  *                                                                                *
0035  * Redistribution and use in source and binary forms, with or without             *
0036  * modification, are permitted according to the terms listed in LICENSE           *
0037  * (see tmva/doc/LICENSE)                                                         *
0038  **********************************************************************************/
0039 
0040 #ifndef ROOT_TMVA_Factory
0041 #define ROOT_TMVA_Factory
0042 
0043 //////////////////////////////////////////////////////////////////////////
0044 //                                                                      //
0045 // Factory                                                              //
0046 //                                                                      //
0047 // This is the main MVA steering class: it creates all MVA methods,     //
0048 // and guides them through the training, testing and evaluation         //
0049 // phases                                                               //
0050 //                                                                      //
0051 //////////////////////////////////////////////////////////////////////////
0052 
0053 #include <vector>
0054 #include <map>
0055 #include "TCut.h"
0056 
0057 #include "TMVA/Configurable.h"
0058 #include "TMVA/Types.h"
0059 #include "TMVA/DataSet.h"
0060 
0061 class TCanvas;
0062 class TDirectory;
0063 class TFile;
0064 class TGraph;
0065 class TH1F;
0066 class TMultiGraph;
0067 class TTree;
0068 namespace TMVA {
0069 
0070    class IMethod;
0071    class MethodBase;
0072    class DataInputHandler;
0073    class DataSetInfo;
0074    class DataSetManager;
0075    class DataLoader;
0076    class ROCCurve;
0077    class VariableTransformBase;
0078 
0079 
0080    class Factory : public Configurable { // main MVA steering class: books the MVA methods and guides them through training, testing and evaluation
0081       friend class CrossValidation; // CrossValidation may access the private section of Factory
0082    public:
0083 
0084       typedef std::vector<IMethod*> MVector; // vector of IMethod pointers; fMethodsMap stores pointers to such vectors
0085       std::map<TString,MVector*>  fMethodsMap;// maps a name (TString) to the vector of all methods of the dataset with that name
0086 
0087       // no default constructor: every constructor takes at least a job name; this one also takes the target file
0088       Factory( TString theJobName, TFile* theTargetFile, TString theOption = "" );
0089 
0090       // constructor to work without a file (compare the fSilentFile member)
0091       Factory( TString theJobName, TString theOption = "" );
0092 
0093       // virtual destructor, defined out of line
0094       virtual ~Factory();
0095 
0096       // GetName() is not overridden here: use the inherited one (the original note says TName::GetName, presumably TNamed) and define the correct name in the constructor
0097       //virtual const char*  GetName() const { return "Factory"; }
0098 
0099       // Internal wrapper type that can be constructed either like a TString or
0100       // from a Types::EMVA enum value and stores the resolved TString. This
0101       // avoids the need for multiple overloads of BookMethod.
0102       class MethodName {
0103       public:
0104          template <typename T, typename = std::enable_if_t<std::is_constructible_v<TString, T &&>>> // only enabled for argument types a TString can be constructed from
0105          MethodName(T &&name) : fName(std::forward<T>(name)) // non-explicit, so string-like arguments convert implicitly; forwards to the TString constructor
0106          {
0107          }
0108          MethodName(Types::EMVA method) : fName(Types::Instance().GetMethodName(method)) {} // non-explicit; the name is looked up through Types::Instance()
0109          TString const &tString() const { return fName; } // read access to the stored name
0110          // NOTE(review): std::enable_if_t / std::forward are used above, but <type_traits> and <utility> are not included directly by this header - confirm they arrive transitively
0111       private:
0112          TString fName; // the resolved method name
0113       };
0114       // book a method for the given DataLoader; theMethodName accepts a string-like value or a Types::EMVA value (see MethodName)
0115       MethodBase* BookMethod( DataLoader *loader, MethodName theMethodName, TString methodTitle, TString theOption = "" );
0116 
0117       // optimize all booked methods (well, if desired by the method); the two wrappers below forward to OptimizeAllMethods and discard its result
0118       std::map<TString,Double_t> OptimizeAllMethods                 (TString fomType="ROCIntegral", TString fitType="FitGA");
0119       void OptimizeAllMethodsForClassification(TString fomType="ROCIntegral", TString fitType="FitGA") { OptimizeAllMethods(fomType,fitType); }
0120       void OptimizeAllMethodsForRegression    (TString fomType="ROCIntegral", TString fitType="FitGA") { OptimizeAllMethods(fomType,fitType); }
0121 
0122       // training for all booked methods; the two wrappers below simply call TrainAllMethods()
0123       void TrainAllMethods                 ();
0124       void TrainAllMethodsForClassification( void ) { TrainAllMethods(); }
0125       void TrainAllMethodsForRegression    ( void ) { TrainAllMethods(); }
0126 
0127       // testing
0128       void TestAllMethods();
0129 
0130       // performance evaluation
0131       void EvaluateAllMethods( void );
0132       void EvaluateAllVariables(DataLoader *loader, TString options = "" );
0133       // variable importance; NOTE(review): vitype presumably selects between the private EvaluateImportanceShort/All/Random helpers - confirm in the implementation
0134       TH1F* EvaluateImportance( DataLoader *loader,VIType vitype, Types::EMVA theMethod,  TString methodTitle, const char *theOption = "" );
0135 
0136       // delete all methods and reset the method vector
0137       void DeleteAllMethods( void );
0138 
0139       // accessors: look a method up / test for it by dataset name and title
0140       IMethod* GetMethod( const TString& datasetname, const TString& title ) const;
0141       Bool_t   HasMethod( const TString& datasetname, const TString& title ) const;
0142 
0143       Bool_t Verbose( void ) const { return fVerbose; }
0144       void SetVerbose( Bool_t v=kTRUE );
0145 
0146       // make ROOT-independent C++ class for classifier response
0147       // (classifier-specific implementation)
0148       // If no classifier name is given, presumably classes are made for all booked
0149       // classifiers (the original note spoke of printing help messages, as for PrintHelpMessage) - confirm
0150       virtual void MakeClass(const TString& datasetname , const TString& methodTitle = "" ) const;
0151 
0152       // prints classifier-specific help messages, dedicated to
0153       // help with the optimisation and configuration options tuning.
0154       // If no classifier name is given, help messages for all booked
0155       // classifiers are printed
0156       void PrintHelpMessage(const TString& datasetname , const TString& methodTitle = "" ) const;
0157       // the target file as a TDirectory; NOTE(review): C-style cast between types that this header only forward-declares - confirm TFile is a complete type wherever this is compiled
0158       TDirectory* RootBaseDir() { return (TDirectory*)fgTargetFile; }
0159       // read access to the fSilentFile and fModelPersistence flags
0160       Bool_t IsSilentFile() const { return fSilentFile;}
0161       Bool_t IsModelPersistence() const { return fModelPersistence; }
0162       // ROC integral for one method, addressed either through its DataLoader or by dataset name
0163       Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass = 0,
0164                               Types::ETreeType type = Types::kTesting);
0165       Double_t GetROCIntegral(TString datasetname, TString theMethodName, UInt_t iClass = 0,
0166                               Types::ETreeType type = Types::kTesting);
0167 
0168       // Methods to get a TGraph for an indicated method in dataset.
0169       // Optional title and axis added with setTitles=kTRUE.
0170       // Argument iClass used in multiclass settings, otherwise ignored.
0171       TGraph *GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles = kTRUE, UInt_t iClass = 0,
0172                           Types::ETreeType type = Types::kTesting);
0173       TGraph *GetROCCurve(TString datasetname, TString theMethodName, Bool_t setTitles = kTRUE, UInt_t iClass = 0,
0174                           Types::ETreeType type = Types::kTesting);
0175 
0176       // Methods to get a TMultiGraph for a given class and all methods in dataset.
0177       TMultiGraph *GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass, Types::ETreeType type = Types::kTesting);
0178       TMultiGraph *GetROCCurveAsMultiGraph(TString datasetname, UInt_t iClass, Types::ETreeType type = Types::kTesting);
0179 
0180       // Draw all ROC curves of a given class for all methods in the dataset.
0181       TCanvas *GetROCCurve(DataLoader *loader, UInt_t iClass = 0, Types::ETreeType type = Types::kTesting);
0182       TCanvas *GetROCCurve(TString datasetname, UInt_t iClass = 0, Types::ETreeType type = Types::kTesting);
0183 
0184    private:
0185 
0186       // the beautiful greeting message
0187       void Greetings();
0188 
0189       // evaluate the simple case, i.e. removing one variable at a time
0190       TH1F* EvaluateImportanceShort( DataLoader *loader,Types::EMVA theMethod,  TString methodTitle, const char *theOption = "" );
0191       // evaluate all combinations of variables
0192       TH1F* EvaluateImportanceAll( DataLoader *loader,Types::EMVA theMethod,  TString methodTitle, const char *theOption = "" );
0193       // evaluate randomly, given a number of seeds
0194       TH1F* EvaluateImportanceRandom( DataLoader *loader,UInt_t nseeds, Types::EMVA theMethod,  TString methodTitle, const char *theOption = "" );
0195       // NOTE(review): presumably builds the importance histogram from the values and variable names; both vectors are taken by value
0196       TH1F* GetImportance(const int nbits,std::vector<Double_t> importances,std::vector<TString> varNames);
0197 
0198       // Helpers for public facing ROC methods
0199       ROCCurve *GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass = 0,
0200                        Types::ETreeType type = Types::kTesting);
0201       ROCCurve *GetROC(TString datasetname, TString theMethodName, UInt_t iClass = 0,
0202                        Types::ETreeType type = Types::kTesting);
0203 
0204       void WriteDataInformation(DataSetInfo&     fDataSetInfo); // note: fDataSetInfo is a parameter here, despite the member-style name
0205 
0206       void SetInputTreesFromEventAssignTrees();
0207       // NOTE(review): presumably books a method of the given type from an existing weight file - confirm
0208       MethodBase* BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile);
0209 
0210    private: // repeated access specifier: this section was already private
0211 
0212       // data members
0213       // note: fgTargetFile is declared without "static", i.e. it is a per-object member
0214       TFile*                             fgTargetFile;     ///<! ROOT output file
0215 
0216 
0217       std::vector<TMVA::VariableTransformBase*> fDefaultTrfs;     ///<! list of transformations on default DataSet
0218 
0219       // option strings and configuration flags
0220       TString                                   fOptions;         ///<! option string given by construction (presently only "V")
0221       TString                                   fTransformations; ///<! list of transformations to test
0222       Bool_t                                    fVerbose;         ///<! verbose mode
0223       TString                                   fVerboseLevel;    ///<! verbosity level, controls granularity of logging
0224       Bool_t                                    fCorrelations;    ///<! enable to calculate correlations
0225       Bool_t                                    fROC;             ///<! enable to calculate ROC values
0226       Bool_t                                    fSilentFile;      ///<! used in constructor without file
0227 
0228       TString                                   fJobName;         ///<! jobname, used as extension in weight file names
0229 
0230       Types::EAnalysisType                      fAnalysisType;    ///<! the training type
0231       Bool_t                                    fModelPersistence;///<! option to save the trained model in xml file or using serialization
0232 
0233 
0234    protected:
0235 
0236       ClassDefOverride(Factory,0);  // The factory creates all MVA methods, and performs their training and testing
0237    };
0238 
0239 } // namespace TMVA
0240 
0241 #endif