Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-09-28 07:03:08

0001 #include "TROOT.h"
0002 #include "TChain.h"
0003 #include "TFile.h"
0004 #include "TEfficiency.h"
0005 #include "TH1.h"
0006 #include "TGraphErrors.h"
0007 #include "TCut.h"
0008 #include "TCanvas.h"
0009 #include "TStyle.h"
0010 #include "TLegend.h"
0011 #include "TMath.h"
0012 #include "TLine.h"
0013 #include "TLatex.h"
0014 #include "TRatioPlot.h"
0015 
0016 #include "TMVA/Factory.h"
0017 #include "TMVA/DataLoader.h"
0018 #include "TMVA/Tools.h"
0019 
0020 #include <glob.h>
0021 #include <iostream>
0022 #include <iomanip>
0023 #include <vector>
0024 
0025 #include "PlotFunctions.h"
0026 
0027 
0028 void CharmJetClassification(TString dir, TString input, TString filePattern = "*/out.root")
0029 {
0030   // Global options
0031   gStyle->SetOptStat(0);
0032 
0033   // Create the TCanvas
0034   TCanvas *pad = new TCanvas("pad",
0035                              "",
0036                              800,
0037                              600);
0038   TLegend *legend    = nullptr;
0039   TH1F    *htemplate = nullptr;
0040 
0041   auto default_data = new TChain("tree");
0042   default_data->SetTitle(input.Data());
0043   auto files = fileVector(Form("%s/%s/%s", dir.Data(), input.Data(), filePattern.Data()));
0044 
0045   for (auto file : files)
0046   {
0047     default_data->Add(file.c_str());
0048   }
0049 
0050   // Create the signal and background trees
0051 
0052   auto signal_train = default_data->CopyTree("jet_flavor==4 && jet_n>0", "", TMath::Floor(default_data->GetEntries() / 1.0));
0053   std::cout << "Signal Tree (Training): " << signal_train->GetEntries() << std::endl;
0054 
0055   auto background_train = default_data->CopyTree("(jet_flavor<4||jet_flavor==21) && jet_n>0", "", TMath::Floor(default_data->GetEntries() / 1.0));
0056   std::cout << "Background Tree (Training): " << background_train->GetEntries() << std::endl;
0057 
0058   // auto u_background_train = default_data->CopyTree("(jet_flavor==2)", "",
0059   // TMath::Floor(default_data->GetEntries()/1.0));
0060   // std::cout << "u Background Tree (Training): " <<
0061   // u_background_train->GetEntries() << std::endl;
0062 
0063   // auto d_background_train = default_data->CopyTree("(jet_flavor==1)", "",
0064   // TMath::Floor(default_data->GetEntries()/1.0));
0065   // std::cout << "d Background Tree (Training): " <<
0066   // d_background_train->GetEntries() << std::endl;
0067 
0068   // auto g_background_train = default_data->CopyTree("(jet_flavor==21)", "",
0069   // TMath::Floor(default_data->GetEntries()/1.0));
0070   // std::cout << "g Background Tree (Training): " <<
0071   // g_background_train->GetEntries() << std::endl;
0072 
0073   // Create the TMVA tools
0074   TMVA::Tools::Instance();
0075 
0076   auto outputFile = TFile::Open("CharmJetClassification_Results.root", "RECREATE");
0077 
0078   TMVA::Factory factory("TMVAClassification",
0079                         outputFile,
0080                         "!V:ROC:!Correlations:!Silent:Color:DrawProgressBar:AnalysisType=Classification");
0081 
0082   //
0083   // Kaon Tagger
0084   //
0085 
0086   TMVA::DataLoader loader_ktagger("dataset_ktagger");
0087   TMVA::DataLoader loader_etagger("dataset_etagger");
0088   TMVA::DataLoader loader_mutagger("dataset_mutagger");
0089   TMVA::DataLoader loader_ip3dtagger("dataset_ip3dtagger");
0090 
0091   // loader_ktagger.AddVariable("jet_pt",  "Jet p_{T}",          "GeV", 'F',
0092   // 0.0,
0093   // 1000.0);
0094   // loader_ktagger.AddVariable("jet_eta", "Jet Pseudorapidity", "",    'F',
0095   // -5.0, 5.0);
0096 
0097   loader_ktagger.AddSpectator("jet_pt");
0098   loader_ktagger.AddSpectator("jet_eta");
0099   loader_ktagger.AddVariable("jet_k1_pt");
0100   loader_ktagger.AddVariable("jet_k1_sIP3D");
0101   loader_ktagger.AddVariable("jet_k2_pt");
0102   loader_ktagger.AddVariable("jet_k2_sIP3D");
0103   loader_ktagger.AddSpectator("jet_flavor");
0104   loader_ktagger.AddSpectator("met_et");
0105   loader_ktagger.AddSignalTree(signal_train, 1.0);
0106   loader_ktagger.AddBackgroundTree(background_train, 1.0);
0107 
0108   // loader_ktagger.AddVariable("jet_nconstituents");
0109 
0110   loader_etagger.AddSpectator("jet_pt");
0111   loader_etagger.AddSpectator("jet_eta");
0112   loader_etagger.AddVariable("jet_e1_pt");
0113   loader_etagger.AddVariable("jet_e1_sIP3D");
0114   loader_etagger.AddVariable("jet_e2_pt");
0115   loader_etagger.AddVariable("jet_e2_sIP3D");
0116   loader_etagger.AddSpectator("jet_flavor");
0117   loader_etagger.AddSpectator("met_et");
0118   loader_etagger.AddSignalTree(signal_train, 1.0);
0119   loader_etagger.AddBackgroundTree(background_train, 1.0);
0120 
0121   loader_mutagger.AddSpectator("jet_pt");
0122   loader_mutagger.AddSpectator("jet_eta");
0123   loader_mutagger.AddVariable("jet_mu1_pt");
0124   loader_mutagger.AddVariable("jet_mu1_sIP3D");
0125   loader_mutagger.AddVariable("jet_mu2_pt");
0126   loader_mutagger.AddVariable("jet_mu2_sIP3D");
0127   loader_mutagger.AddSpectator("jet_flavor");
0128   loader_mutagger.AddSpectator("met_et");
0129   loader_mutagger.AddSignalTree(signal_train, 1.0);
0130   loader_mutagger.AddBackgroundTree(background_train, 1.0);
0131 
0132 
0133   loader_ip3dtagger.AddSpectator("jet_pt");
0134   loader_ip3dtagger.AddSpectator("jet_eta");
0135   loader_ip3dtagger.AddVariable("jet_t1_pt");
0136   loader_ip3dtagger.AddVariable("jet_t1_sIP3D");
0137   loader_ip3dtagger.AddVariable("jet_t2_pt");
0138   loader_ip3dtagger.AddVariable("jet_t2_sIP3D");
0139   loader_ip3dtagger.AddVariable("jet_t3_pt");
0140   loader_ip3dtagger.AddVariable("jet_t3_sIP3D");
0141   loader_ip3dtagger.AddVariable("jet_t4_pt");
0142   loader_ip3dtagger.AddVariable("jet_t4_sIP3D");
0143   loader_ip3dtagger.AddSpectator("jet_flavor");
0144   loader_ip3dtagger.AddSpectator("met_et");
0145   loader_ip3dtagger.AddSignalTree(signal_train, 1.0);
0146   loader_ip3dtagger.AddBackgroundTree(background_train, 1.0);
0147 
0148   //  loader.AddVariable("jet_sip3dtag", "sIP3D Jet-Level Tag", "", 'B', -10,
0149   // 10);
0150   // loader.AddVariable("jet_charge");
0151 
0152   // loader.AddVariable("jet_ehadoveremratio");
0153 
0154 
0155   // loader.AddTree( signal_train, "strange_jets" );
0156   // loader.AddTree( u_background_train, "up jets" );
0157   // loader.AddTree( d_background_train, "down jets" );
0158   // loader.AddTree( g_background_train, "gluon jets" );
0159 
0160   loader_ktagger.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
0161                                             TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
0162                                             "nTrain_Signal=50000:nTrain_Background=500000:nTest_Signal=50000:nTest_Background=500000:SplitMode=Random:NormMode=NumEvents:!V");
0163   loader_etagger.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
0164                                             TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
0165                                             "nTrain_Signal=100000:nTrain_Background=1000000:nTest_Signal=100000:nTest_Background=1000000:SplitMode=Random:NormMode=NumEvents:!V");
0166   loader_mutagger.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
0167                                              TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
0168                                              "nTrain_Signal=100000:nTrain_Background=1000000:nTest_Signal=100000:nTest_Background=1000000:SplitMode=Random:NormMode=NumEvents:!V");
0169   loader_ip3dtagger.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
0170                                                TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
0171                                                "nTrain_Signal=10000:nTrain_Background=100000:nTest_Signal=10000:nTest_Background=100000:SplitMode=Random:NormMode=NumEvents:!V");
0172 
0173   // Declare the classification method(s)
0174   // factory.BookMethod(&loader,TMVA::Types::kBDT, "BDT",
0175   //
0176   //
0177   //  "!V:NTrees=1000:MinNodeSize=2.5%:MaxDepth=4:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20"
0178   // );
0179   factory.BookMethod(&loader_ktagger,    TMVA::Types::kMLP, "CharmKTagger",
0180                      "!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=1000:HiddenLayers=N+12:TestRate=5:!UseRegulator");
0181   // factory.BookMethod(&loader_etagger,    TMVA::Types::kMLP, "CharmETagger",
0182   //                    "!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=1000:HiddenLayers=N+12:TestRate=5:!UseRegulator");
0183   // factory.BookMethod(&loader_mutagger,   TMVA::Types::kMLP, "CharmMuTagger",
0184   //                    "!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=1000:HiddenLayers=N+12:TestRate=5:!UseRegulator");
0185   // factory.BookMethod(&loader_ip3dtagger, TMVA::Types::kMLP, "CharmIP3DTagger",
0186   //                    "!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=1000:HiddenLayers=N+16:TestRate=5:!UseRegulator");
0187 
0188   // factory.BookMethod(&loader, TMVA::Types::kMLP,
0189   // "CharmETagger","!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=600:HiddenLayers=N+8:TestRate=5:!UseRegulator");
0190   // factory.BookMethod(&loader, TMVA::Types::kMLP,
0191   // "CharmMuTagger","!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=600:HiddenLayers=N+8:TestRate=5:!UseRegulator");
0192 
0193   // Train
0194   factory.TrainAllMethods();
0195 
0196   // Test
0197   factory.TestAllMethods();
0198   factory.EvaluateAllMethods();
0199 
0200   // Plot a ROC Curve
0201   pad->cd();
0202   pad = factory.GetROCCurve(&loader_ktagger);
0203   // pad = factory.GetROCCurve(&loader_etagger);
0204   // pad = factory.GetROCCurve(&loader_mutagger);
0205   // pad = factory.GetROCCurve(&loader_ip3dtagger);
0206   pad->Draw();
0207 
0208   pad->SaveAs("CharmJetClassification_ROC.pdf");
0209 }