File indexing completed on 2024-09-28 07:03:08
0001 #include "TROOT.h"
0002 #include "TChain.h"
0003 #include "TFile.h"
0004 #include "TEfficiency.h"
0005 #include "TH1.h"
0006 #include "TGraphErrors.h"
0007 #include "TCut.h"
0008 #include "TCanvas.h"
0009 #include "TStyle.h"
0010 #include "TLegend.h"
0011 #include "TMath.h"
0012 #include "TLine.h"
0013 #include "TLatex.h"
0014 #include "TRatioPlot.h"
0015
0016 #include "TMVA/Factory.h"
0017 #include "TMVA/DataLoader.h"
0018 #include "TMVA/Tools.h"
0019
0020 #include <glob.h>
0021 #include <iostream>
0022 #include <iomanip>
0023 #include <vector>
0024
0025 #include "PlotFunctions.h"
0026
0027
0028 void CharmJetClassification(TString dir, TString input, TString filePattern = "*/out.root")
0029 {
0030
0031 gStyle->SetOptStat(0);
0032
0033
0034 TCanvas *pad = new TCanvas("pad",
0035 "",
0036 800,
0037 600);
0038 TLegend *legend = nullptr;
0039 TH1F *htemplate = nullptr;
0040
0041 auto default_data = new TChain("tree");
0042 default_data->SetTitle(input.Data());
0043 auto files = fileVector(Form("%s/%s/%s", dir.Data(), input.Data(), filePattern.Data()));
0044
0045 for (auto file : files)
0046 {
0047 default_data->Add(file.c_str());
0048 }
0049
0050
0051
0052 auto signal_train = default_data->CopyTree("jet_flavor==4 && jet_n>0", "", TMath::Floor(default_data->GetEntries() / 1.0));
0053 std::cout << "Signal Tree (Training): " << signal_train->GetEntries() << std::endl;
0054
0055 auto background_train = default_data->CopyTree("(jet_flavor<4||jet_flavor==21) && jet_n>0", "", TMath::Floor(default_data->GetEntries() / 1.0));
0056 std::cout << "Background Tree (Training): " << background_train->GetEntries() << std::endl;
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073
0074 TMVA::Tools::Instance();
0075
0076 auto outputFile = TFile::Open("CharmJetClassification_Results.root", "RECREATE");
0077
0078 TMVA::Factory factory("TMVAClassification",
0079 outputFile,
0080 "!V:ROC:!Correlations:!Silent:Color:DrawProgressBar:AnalysisType=Classification");
0081
0082
0083
0084
0085
0086 TMVA::DataLoader loader_ktagger("dataset_ktagger");
0087 TMVA::DataLoader loader_etagger("dataset_etagger");
0088 TMVA::DataLoader loader_mutagger("dataset_mutagger");
0089 TMVA::DataLoader loader_ip3dtagger("dataset_ip3dtagger");
0090
0091
0092
0093
0094
0095
0096
0097 loader_ktagger.AddSpectator("jet_pt");
0098 loader_ktagger.AddSpectator("jet_eta");
0099 loader_ktagger.AddVariable("jet_k1_pt");
0100 loader_ktagger.AddVariable("jet_k1_sIP3D");
0101 loader_ktagger.AddVariable("jet_k2_pt");
0102 loader_ktagger.AddVariable("jet_k2_sIP3D");
0103 loader_ktagger.AddSpectator("jet_flavor");
0104 loader_ktagger.AddSpectator("met_et");
0105 loader_ktagger.AddSignalTree(signal_train, 1.0);
0106 loader_ktagger.AddBackgroundTree(background_train, 1.0);
0107
0108
0109
0110 loader_etagger.AddSpectator("jet_pt");
0111 loader_etagger.AddSpectator("jet_eta");
0112 loader_etagger.AddVariable("jet_e1_pt");
0113 loader_etagger.AddVariable("jet_e1_sIP3D");
0114 loader_etagger.AddVariable("jet_e2_pt");
0115 loader_etagger.AddVariable("jet_e2_sIP3D");
0116 loader_etagger.AddSpectator("jet_flavor");
0117 loader_etagger.AddSpectator("met_et");
0118 loader_etagger.AddSignalTree(signal_train, 1.0);
0119 loader_etagger.AddBackgroundTree(background_train, 1.0);
0120
0121 loader_mutagger.AddSpectator("jet_pt");
0122 loader_mutagger.AddSpectator("jet_eta");
0123 loader_mutagger.AddVariable("jet_mu1_pt");
0124 loader_mutagger.AddVariable("jet_mu1_sIP3D");
0125 loader_mutagger.AddVariable("jet_mu2_pt");
0126 loader_mutagger.AddVariable("jet_mu2_sIP3D");
0127 loader_mutagger.AddSpectator("jet_flavor");
0128 loader_mutagger.AddSpectator("met_et");
0129 loader_mutagger.AddSignalTree(signal_train, 1.0);
0130 loader_mutagger.AddBackgroundTree(background_train, 1.0);
0131
0132
0133 loader_ip3dtagger.AddSpectator("jet_pt");
0134 loader_ip3dtagger.AddSpectator("jet_eta");
0135 loader_ip3dtagger.AddVariable("jet_t1_pt");
0136 loader_ip3dtagger.AddVariable("jet_t1_sIP3D");
0137 loader_ip3dtagger.AddVariable("jet_t2_pt");
0138 loader_ip3dtagger.AddVariable("jet_t2_sIP3D");
0139 loader_ip3dtagger.AddVariable("jet_t3_pt");
0140 loader_ip3dtagger.AddVariable("jet_t3_sIP3D");
0141 loader_ip3dtagger.AddVariable("jet_t4_pt");
0142 loader_ip3dtagger.AddVariable("jet_t4_sIP3D");
0143 loader_ip3dtagger.AddSpectator("jet_flavor");
0144 loader_ip3dtagger.AddSpectator("met_et");
0145 loader_ip3dtagger.AddSignalTree(signal_train, 1.0);
0146 loader_ip3dtagger.AddBackgroundTree(background_train, 1.0);
0147
0148
0149
0150
0151
0152
0153
0154
0155
0156
0157
0158
0159
0160 loader_ktagger.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
0161 TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
0162 "nTrain_Signal=50000:nTrain_Background=500000:nTest_Signal=50000:nTest_Background=500000:SplitMode=Random:NormMode=NumEvents:!V");
0163 loader_etagger.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
0164 TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
0165 "nTrain_Signal=100000:nTrain_Background=1000000:nTest_Signal=100000:nTest_Background=1000000:SplitMode=Random:NormMode=NumEvents:!V");
0166 loader_mutagger.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
0167 TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
0168 "nTrain_Signal=100000:nTrain_Background=1000000:nTest_Signal=100000:nTest_Background=1000000:SplitMode=Random:NormMode=NumEvents:!V");
0169 loader_ip3dtagger.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
0170 TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
0171 "nTrain_Signal=10000:nTrain_Background=100000:nTest_Signal=10000:nTest_Background=100000:SplitMode=Random:NormMode=NumEvents:!V");
0172
0173
0174
0175
0176
0177
0178
0179 factory.BookMethod(&loader_ktagger, TMVA::Types::kMLP, "CharmKTagger",
0180 "!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=1000:HiddenLayers=N+12:TestRate=5:!UseRegulator");
0181
0182
0183
0184
0185
0186
0187
0188
0189
0190
0191
0192
0193
0194 factory.TrainAllMethods();
0195
0196
0197 factory.TestAllMethods();
0198 factory.EvaluateAllMethods();
0199
0200
0201 pad->cd();
0202 pad = factory.GetROCCurve(&loader_ktagger);
0203
0204
0205
0206 pad->Draw();
0207
0208 pad->SaveAs("CharmJetClassification_ROC.pdf");
0209 }