Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-11-16 09:02:32

0001 #include "TROOT.h"
0002 #include "TChain.h"
0003 #include "TFile.h"
0004 #include "TEfficiency.h"
0005 #include "TH1.h"
0006 #include "TGraphErrors.h"
0007 #include "TCut.h"
0008 #include "TCanvas.h"
0009 #include "TStyle.h"
0010 #include "TLegend.h"
0011 #include "TMath.h"
0012 #include "TLine.h"
0013 #include "TLatex.h"
0014 #include "TRatioPlot.h"
0015 
0016 #include "TMVA/Factory.h"
0017 #include "TMVA/DataLoader.h"
0018 #include "TMVA/Tools.h"
0019 
0020 #include <glob.h>
0021 #include <iostream>
0022 #include <iomanip>
0023 #include <vector>
0024 
0025 #include "PlotFunctions.h"
0026 
0027 
0028 void StrangeJetClassification(TString dir, TString input, TString filePattern = "*/out.root")
0029 {
0030   // Global options
0031   gStyle->SetOptStat(0);
0032 
0033   // Create the TCanvas
0034   TCanvas* pad = new TCanvas("pad","",800,600);
0035   TLegend * legend = nullptr;
0036   TH1F* htemplate = nullptr;
0037 
0038   auto default_data = new TChain("tree");
0039   default_data->SetTitle(input.Data());
0040   auto files = fileVector(Form("%s/%s/%s", dir.Data(), input.Data(), filePattern.Data()));
0041 
0042   for (auto file : files) 
0043     {
0044       default_data->Add(file.c_str());
0045     }
0046 
0047   // Create the signal and background trees
0048 
0049   auto signal_train = default_data->CopyTree("jet_flavor==3", "", TMath::Floor(default_data->GetEntries()/1.0));
0050   std::cout << "Signal Tree (Training): " << signal_train->GetEntries() << std::endl;
0051 
0052   auto background_train = default_data->CopyTree("(jet_flavor<3||jet_flavor==21)", "", TMath::Floor(default_data->GetEntries()/1.0));
0053   std::cout << "Background Tree (Training): " << background_train->GetEntries() << std::endl;
0054 
0055   // auto u_background_train = default_data->CopyTree("(jet_flavor==2)", "", TMath::Floor(default_data->GetEntries()/1.0));
0056   // std::cout << "u Background Tree (Training): " << u_background_train->GetEntries() << std::endl;
0057 
0058   // auto d_background_train = default_data->CopyTree("(jet_flavor==1)", "", TMath::Floor(default_data->GetEntries()/1.0));
0059   // std::cout << "d Background Tree (Training): " << d_background_train->GetEntries() << std::endl;
0060 
0061   // auto g_background_train = default_data->CopyTree("(jet_flavor==21)", "", TMath::Floor(default_data->GetEntries()/1.0));
0062   // std::cout << "g Background Tree (Training): " << g_background_train->GetEntries() << std::endl;
0063 
0064   // Create the TMVA tools
0065   TMVA::Tools::Instance();
0066 
0067   auto outputFile = TFile::Open("StrangeJetClassification_Results.root", "RECREATE");
0068 
0069   TMVA::Factory factory("TMVAClassification", outputFile,
0070             "!V:ROC:!Correlations:!Silent:Color:DrawProgressBar:AnalysisType=Classification" );
0071 
0072 
0073   TMVA::DataLoader loader("dataset");
0074 
0075   loader.AddVariable("jet_pt");
0076   loader.AddVariable("jet_eta");
0077   loader.AddVariable("jet_nconstituents");
0078   loader.AddVariable("jet_Ks_leading_zhadron");
0079   loader.AddVariable("jet_K_leading_zhadron");
0080   loader.AddVariable("jet_charge");
0081   loader.AddVariable("jet_Ks_sumpt");
0082   loader.AddVariable("jet_K_sumpt");
0083   loader.AddVariable("jet_ehadoveremratio");
0084   loader.AddSpectator("jet_flavor");
0085 
0086   loader.AddSignalTree( signal_train, 1.0 );
0087   loader.AddBackgroundTree( background_train, 1.0 );
0088   // loader.AddTree( signal_train, "strange_jets" );
0089   // loader.AddTree( u_background_train, "up jets" );
0090   // loader.AddTree( d_background_train, "down jets" );
0091   // loader.AddTree( g_background_train, "gluon jets" );
0092 
0093   loader.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==3"), 
0094                     TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<3||jet_flavor==21)"),
0095                     "nTrain_Signal=50000:nTrain_Background=100000:nTest_Signal=50000:nTest_Background=100000:SplitMode=Random:NormMode=NumEvents:!V" );
0096 
0097   // Declare the classification method(s)
0098   factory.BookMethod(&loader,TMVA::Types::kBDT, "BDT",
0099              "!V:NTrees=1000:MinNodeSize=2.5%:MaxDepth=4:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20" );
0100 
0101   // Train
0102   factory.TrainAllMethods();
0103 
0104   // Test
0105   factory.TestAllMethods();
0106   factory.EvaluateAllMethods();
0107 
0108   // Plot a ROC Curve
0109   pad->cd();
0110   pad = factory.GetROCCurve(&loader);
0111   pad->Draw();
0112 
0113   pad->SaveAs("StrangeJetClassification_ROC.pdf");
0114                     
0115 
0116 }
0117