Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-11-03 09:01:32

#include <array>
#include <cmath>
#include <iostream>
#include <string>

#include "ROOT/RDataFrame.hxx"
#include "TFile.h"
#include "TTree.h"
#include "edm4hep/MCParticle.h"
#include "edm4hep/SimTrackerHit.h"
0007 
0008 void cleanData(const TString inputFile="/home/simong/EIC/detector_benchmarks_anl/sim_output/beamline/acceptanceTestcurrent.edm4hep.root", const TString outputFile="test.root", const double BeamEnergy=18.0, const int desired_cellID = 66757, const bool appendTruth = true) {
0009     
0010     float momentum_tolerance = 0.01; // Define the momentum tolerance for filtering
0011 
0012     // Create a ROOT DataFrame to read the input files
0013     ROOT::RDataFrame df("events", inputFile);
0014 
0015     //Filter on events with only a single MCParticle which hasn't scattered as it enters the drift volume
0016     auto filterDF = df.Define("SimParticles", "MCParticles[MCParticles.generatorStatus==1]")
0017                     .Filter("SimParticles.size()==1")
0018                     .Define("beamlineHit", "BackwardsBeamlineHits[BackwardsBeamlineHits.cellID=="+std::to_string(desired_cellID)+"]")
0019                     .Filter("beamlineHit.size()==1")
0020                     .Define("features", [BeamEnergy](const ROOT::VecOps::RVec<edm4hep::SimTrackerHitData>& hit) {
0021                         return std::array<float, 6>{
0022                             static_cast<float>(hit[0].position.x),
0023                             static_cast<float>(hit[0].position.y),
0024                             static_cast<float>(hit[0].position.z),
0025                             static_cast<float>(hit[0].momentum.x / BeamEnergy),
0026                             static_cast<float>(hit[0].momentum.y / BeamEnergy),
0027                             static_cast<float>(hit[0].momentum.z / BeamEnergy)
0028                         };
0029                     }, {"beamlineHit"})
0030                     .Define("targets", [BeamEnergy](const ROOT::VecOps::RVec<edm4hep::MCParticleData>& mcps) {
0031                             return std::array<float, 3>{
0032                                 static_cast<float>(mcps[0].momentum.x / BeamEnergy),
0033                                 static_cast<float>(mcps[0].momentum.y / BeamEnergy),
0034                                 static_cast<float>(mcps[0].momentum.z / BeamEnergy)
0035                             };
0036                         }, {"SimParticles"})
0037                     .Define("features_momentum", [](const std::array<float, 6>& features) {
0038                         return std::sqrt(features[3] * features[3] + features[4] * features[4] + features[5] * features[5]);
0039                     }, {"features"})
0040                     .Define("targets_momentum", [](const std::array<float, 3>& targets) {
0041                         return std::sqrt(targets[0] * targets[0] + targets[1] * targets[1] + targets[2] * targets[2]);
0042                     }, {"targets"})
0043                     .Filter([momentum_tolerance](float features_momentum, float targets_momentum) {
0044                         float relative_difference = std::abs(features_momentum - targets_momentum) / targets_momentum;
0045                         return relative_difference <= momentum_tolerance;
0046                     }, {"features_momentum", "targets_momentum"});
0047 
0048     auto taggerDF = filterDF.Filter("_TaggerTrackerFeatureTensor_shape[0]==1");
0049 
0050     // Save the filtered data to a new ROOT file
0051     taggerDF.Snapshot("events", outputFile, {"_TaggerTrackerFeatureTensor_floatData","_TaggerTrackerFeatureTensor_shape","_TaggerTrackerTargetTensor_floatData","features_momentum","targets_momentum"});
0052 
0053     // Print the size of the original DataFrame
0054     // std::cout << "Original DataFrame size: " << df.Count().GetValue() << std::endl;
0055     // // Print the size of the filtered DataFrame
0056     // std::cout << "Filtered DataFrame size: " << filterDF.Count().GetValue() << std::endl;
0057     
0058     // std::cout << "Tagger filtered DataFrame size" << taggerDF.Count().GetValue() << std::endl;
0059 
0060     // std::cout << "Filtered data saved to " << outputFile << std::endl;
0061 
0062     // If appendTruth is true, add the truth information
0063     if (appendTruth) {
0064         // Open the output file in update mode
0065         ROOT::RDF::RSnapshotOptions opts;
0066         opts.fMode = "update";
0067         auto aliasDF = filterDF.Redefine("_TaggerTrackerFeatureTensor_floatData", "features")
0068                                 .Redefine("_TaggerTrackerTargetTensor_floatData", "targets");
0069         // filterDF = Concatenate({aliasDF, filterDF});
0070         aliasDF.Snapshot("truthevents", outputFile, {"_TaggerTrackerFeatureTensor_floatData", "_TaggerTrackerFeatureTensor_shape", "_TaggerTrackerTargetTensor_floatData"}, opts);
0071         std::cout << "Truth information appended to " << outputFile << std::endl;
0072 
0073         // std::cout << "Total events after appending truth: " << aliasDF.Count().GetValue() << std::endl;
0074     }
0075 
0076 
0077 
0078 
0079 }