File indexing completed on 2024-11-16 09:02:32
0001 #include "TROOT.h"
0002 #include "TChain.h"
0003 #include "TFile.h"
0004 #include "TEfficiency.h"
0005 #include "TH1.h"
0006 #include "TGraphErrors.h"
0007 #include "TCut.h"
0008 #include "TCanvas.h"
0009 #include "TStyle.h"
0010 #include "TLegend.h"
0011 #include "TMath.h"
0012 #include "TLine.h"
0013 #include "TLatex.h"
0014 #include "TRatioPlot.h"
0015
0016 #include "TMVA/Factory.h"
0017 #include "TMVA/DataLoader.h"
0018 #include "TMVA/Tools.h"
0019
0020 #include <glob.h>
0021 #include <iostream>
0022 #include <iomanip>
0023 #include <vector>
0024
0025 #include "PlotFunctions.h"
0026
0027
0028 void StrangeJetClassification(TString dir, TString input, TString filePattern = "*/out.root")
0029 {
0030
0031 gStyle->SetOptStat(0);
0032
0033
0034 TCanvas* pad = new TCanvas("pad","",800,600);
0035 TLegend * legend = nullptr;
0036 TH1F* htemplate = nullptr;
0037
0038 auto default_data = new TChain("tree");
0039 default_data->SetTitle(input.Data());
0040 auto files = fileVector(Form("%s/%s/%s", dir.Data(), input.Data(), filePattern.Data()));
0041
0042 for (auto file : files)
0043 {
0044 default_data->Add(file.c_str());
0045 }
0046
0047
0048
0049 auto signal_train = default_data->CopyTree("jet_flavor==3", "", TMath::Floor(default_data->GetEntries()/1.0));
0050 std::cout << "Signal Tree (Training): " << signal_train->GetEntries() << std::endl;
0051
0052 auto background_train = default_data->CopyTree("(jet_flavor<3||jet_flavor==21)", "", TMath::Floor(default_data->GetEntries()/1.0));
0053 std::cout << "Background Tree (Training): " << background_train->GetEntries() << std::endl;
0054
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065 TMVA::Tools::Instance();
0066
0067 auto outputFile = TFile::Open("StrangeJetClassification_Results.root", "RECREATE");
0068
0069 TMVA::Factory factory("TMVAClassification", outputFile,
0070 "!V:ROC:!Correlations:!Silent:Color:DrawProgressBar:AnalysisType=Classification" );
0071
0072
0073 TMVA::DataLoader loader("dataset");
0074
0075 loader.AddVariable("jet_pt");
0076 loader.AddVariable("jet_eta");
0077 loader.AddVariable("jet_nconstituents");
0078 loader.AddVariable("jet_Ks_leading_zhadron");
0079 loader.AddVariable("jet_K_leading_zhadron");
0080 loader.AddVariable("jet_charge");
0081 loader.AddVariable("jet_Ks_sumpt");
0082 loader.AddVariable("jet_K_sumpt");
0083 loader.AddVariable("jet_ehadoveremratio");
0084 loader.AddSpectator("jet_flavor");
0085
0086 loader.AddSignalTree( signal_train, 1.0 );
0087 loader.AddBackgroundTree( background_train, 1.0 );
0088
0089
0090
0091
0092
0093 loader.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==3"),
0094 TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<3||jet_flavor==21)"),
0095 "nTrain_Signal=50000:nTrain_Background=100000:nTest_Signal=50000:nTest_Background=100000:SplitMode=Random:NormMode=NumEvents:!V" );
0096
0097
0098 factory.BookMethod(&loader,TMVA::Types::kBDT, "BDT",
0099 "!V:NTrees=1000:MinNodeSize=2.5%:MaxDepth=4:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20" );
0100
0101
0102 factory.TrainAllMethods();
0103
0104
0105 factory.TestAllMethods();
0106 factory.EvaluateAllMethods();
0107
0108
0109 pad->cd();
0110 pad = factory.GetROCCurve(&loader);
0111 pad->Draw();
0112
0113 pad->SaveAs("StrangeJetClassification_ROC.pdf");
0114
0115
0116 }
0117