Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:10:57

0001 // @(#)root/tmva $Id$ 2017
0002 // Authors:  Omar Zapata, Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne,
0003 // Jan Therhaag
0004 
0005 #ifndef ROOT_TMVA_Classification
0006 #define ROOT_TMVA_Classification
0007 
#include <TString.h>
#include <TMultiGraph.h>
#include <map>
#include <tuple>
#include <vector>

#include <TMVA/IMethod.h>
#include <TMVA/MethodBase.h>
#include <TMVA/Configurable.h>
#include <TMVA/Types.h>
#include <TMVA/DataSet.h>
#include <TMVA/Event.h>
#include <TMVA/Results.h>
#include <TMVA/ResultsClassification.h>
#include <TMVA/ResultsMulticlass.h>
#include <TMVA/Factory.h>
#include <TMVA/DataLoader.h>
#include <TMVA/OptionMap.h>
#include <TMVA/Envelope.h>
0026 
/*! \class TMVA::Experimental::ClassificationResult
 * Class to save the results of a classifier.
 * Every booked machine learning method has an object holding its results
 * from the classification process; this class stores the MVA outputs,
 * the data loader name, and the ML method name and title.
 * You can display the results by calling the method Show, get the ROC integral with the
 * method GetROCIntegral, or get the TMVA::ROCCurve object by calling GetROC.
\ingroup TMVA
*/
0036 
/*! \class TMVA::Experimental::Classification
 * Class to perform two-class classification.
 * The first step before any analysis is to prepare the data.
 * To do that, create a TMVA::DataLoader object and configure in it
 * the variables and the number of events to train/test.
 * The class TMVA::Experimental::Classification needs a TMVA::DataLoader object,
 * optionally a TFile object to save the results, and some extra options in a string
 * like "V:Color:Transformations=I;D;P;U;G:Silent:DrawProgressBar:ModelPersistence:Jobs=2" where:
 * V                = verbose output
 * Color            = coloured screen output
 * Silent           = batch mode: boolean silent flag inhibiting any output from TMVA
 * Transformations  = list of transformations to test.
 * DrawProgressBar  = draw progress bar to display training and testing.
 * ModelPersistence = to save the trained model in xml or serialized files.
 * Jobs             = number of ML methods to train/test in parallel using MultiProc; requires calling the Evaluate method.
 * Basic example:
0054  * \code
0055 void classification(UInt_t jobs = 2)
0056 {
0057    TMVA::Tools::Instance();
0058 
0059    TFile *input(0);
0060    TString fname = "./tmva_class_example.root";
0061    if (!gSystem->AccessPathName(fname)) {
0062       input = TFile::Open(fname); // check if file in local directory exists
0063    } else {
0064       TFile::SetCacheFileDir(".");
0065       input = TFile::Open("http://root.cern/files/tmva_class_example.root", "CACHEREAD");
0066    }
0067    if (!input) {
0068       std::cout << "ERROR: could not open data file" << std::endl;
0069       exit(1);
0070    }
0071 
0072    // Register the training and test trees
0073 
0074    TTree *signalTree = (TTree *)input->Get("TreeS");
0075    TTree *background = (TTree *)input->Get("TreeB");
0076 
0077    TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
0078 
0079    dataloader->AddVariable("myvar1 := var1+var2", 'F');
0080    dataloader->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
0081    dataloader->AddVariable("var3", "Variable 3", "units", 'F');
0082    dataloader->AddVariable("var4", "Variable 4", "units", 'F');
0083 
0084    dataloader->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
0085    dataloader->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');
0086 
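   // global event weights per tree (see below for setting event-wise weights)
   Double_t signalWeight = 1.0;
   Double_t backgroundWeight = 1.0;

   // register the signal and background trees with their global weights
   dataloader->AddSignalTree(signalTree, signalWeight);
   dataloader->AddBackgroundTree(background, backgroundWeight);

   // event-wise weights for the background events
   dataloader->SetBackgroundWeightExpression("weight");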
0092 
0093    TMVA::Experimental::Classification *cl = new TMVA::Experimental::Classification(dataloader, Form("Jobs=%d", jobs));
0094 
0095    cl->BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=2000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:"
0096                                              "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
0097    cl->BookMethod(TMVA::Types::kSVM, "SVM", "Gamma=0.25:Tol=0.001:VarTransform=Norm");
0098 
0099    cl->Evaluate(); // Train and Test all methods
0100 
0101    auto &results = cl->GetResults();
0102 
0103    TCanvas *c = new TCanvas(Form("ROC"));
0104    c->SetTitle("ROC-Integral Curve");
0105 
0106    auto mg = new TMultiGraph();
0107    for (UInt_t i = 0; i < results.size(); i++) {
0108       auto roc = results[i].GetROCGraph();
0109       roc->SetLineColorAlpha(i + 1, 0.1);
0110       mg->Add(roc);
0111    }
0112    mg->Draw("AL");
0113    mg->GetXaxis()->SetTitle(" Signal Efficiency ");
0114    mg->GetYaxis()->SetTitle(" Background Rejection ");
0115    c->BuildLegend(0.15, 0.15, 0.3, 0.3);
0116    c->Draw();
0117 
0118    delete cl;
0119 }
0120  * \endcode
0121  *
0122 \ingroup TMVA
0123 */
0124 
0125 namespace TMVA {
0126 class ResultsClassification;
0127 namespace Experimental {
0128 class ClassificationResult : public TObject {
0129    friend class Classification;
0130 
0131 private:
0132    OptionMap fMethod;
0133    TString fDataLoaderName;
0134    std::map<UInt_t, std::vector<std::tuple<Float_t, Float_t, Bool_t>>> fMvaTrain; ///< Mvas for two-class classification
0135    std::map<UInt_t, std::vector<std::tuple<Float_t, Float_t, Bool_t>>> fMvaTest;  ///< Mvas for two-class and multiclass classification
0136    std::vector<TString> fClassNames;
0137 
0138    Bool_t IsMethod(TString methodname, TString methodtitle);
0139    Bool_t fIsCuts;        ///< if it is a method cuts need special output
0140    Double_t fROCIntegral;
0141 
0142 public:
0143    ClassificationResult();
0144    ClassificationResult(const ClassificationResult &cr);
0145    ~ClassificationResult() {}
0146 
0147    const TString GetMethodName() const { return fMethod.GetValue<TString>("MethodName"); }
0148    const TString GetMethodTitle() const { return fMethod.GetValue<TString>("MethodTitle"); }
0149    ROCCurve *GetROC(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0150    Double_t GetROCIntegral(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0151    TString GetDataLoaderName() { return fDataLoaderName; }
0152    Bool_t IsCutsMethod() { return fIsCuts; }
0153 
0154    void Show();
0155 
0156    TGraph *GetROCGraph(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0157    ClassificationResult &operator=(const ClassificationResult &r);
0158 
0159    ClassDef(ClassificationResult, 3);
0160 };
0161 
0162 class Classification : public Envelope {
0163    std::vector<ClassificationResult> fResults; ///<!
0164    std::vector<IMethod *> fIMethods;           ///<! vector of objects with booked methods
0165    Types::EAnalysisType fAnalysisType;         ///<!
0166    Bool_t fCorrelations;                       ///<!
0167    Bool_t fROC;                                ///<!
0168 public:
0169    explicit Classification(DataLoader *loader, TFile *file, TString options);
0170    explicit Classification(DataLoader *loader, TString options);
0171    ~Classification();
0172 
0173    virtual void Train();
0174    virtual void TrainMethod(TString methodname, TString methodtitle);
0175    virtual void TrainMethod(Types::EMVA method, TString methodtitle);
0176 
0177    virtual void Test();
0178    virtual void TestMethod(TString methodname, TString methodtitle);
0179    virtual void TestMethod(Types::EMVA method, TString methodtitle);
0180 
0181    virtual void Evaluate();
0182 
0183    std::vector<ClassificationResult> &GetResults();
0184 
0185    MethodBase *GetMethod(TString methodname, TString methodtitle);
0186 
0187 protected:
0188    TString GetMethodOptions(TString methodname, TString methodtitle);
0189    Bool_t HasMethodObject(TString methodname, TString methodtitle, Int_t &index);
0190    Bool_t IsCutsMethod(TMVA::MethodBase *method);
0191    TMVA::ROCCurve *
0192    GetROC(TMVA::MethodBase *method, UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0193    TMVA::ROCCurve *GetROC(TString methodname, TString methodtitle, UInt_t iClass = 0,
0194                           TMVA::Types::ETreeType type = TMVA::Types::kTesting);
0195 
0196    Double_t GetROCIntegral(TString methodname, TString methodtitle, UInt_t iClass = 0);
0197 
0198    ClassificationResult &GetResults(TString methodname, TString methodtitle);
0199    void CopyFrom(TDirectory *src, TFile *file);
0200    void MergeFiles();
0201 
0202    ClassDef(Classification, 0);
0203 };
0204 } // namespace Experimental
0205 } // namespace TMVA
0206 
0207 #endif // ROOT_TMVA_Classification