File indexing completed on 2024-09-27 07:02:58
0001
0002
0003
0004 #include <TMVA/IMethod.h>
0005 #include <edm4eic/Cov6f.h>
0006 #include <edm4eic/vector_utils.h>
0007 #include <edm4hep/Vector2f.h>
0008 #include <edm4hep/Vector3f.h>
0009 #include <edm4hep/utils/vector_utils.h>
0010 #include <fmt/core.h>
0011 #include <cmath>
0012 #include <cstdint>
0013 #include <exception>
0014 #include <gsl/pointers>
0015 #include <vector>
0016
0017 #include "FarDetectorMLReconstruction.h"
0018 #include "algorithms/fardetectors/FarDetectorMLReconstructionConfig.h"
0019
0020 namespace eicrecon {
0021
0022
0023 void FarDetectorMLReconstruction::init() {
0024
0025 m_reader = new TMVA::Reader( "!Color:!Silent" );
0026
0027
0028 m_reader->AddVariable( "LowQ2Tracks[0].loc.a", &nnInput[FarDetectorMLNNIndexIn::PosY] );
0029 m_reader->AddVariable( "LowQ2Tracks[0].loc.b", &nnInput[FarDetectorMLNNIndexIn::PosZ] );
0030 m_reader->AddVariable( "sin(LowQ2Tracks[0].phi)*sin(LowQ2Tracks[0].theta)", &nnInput[FarDetectorMLNNIndexIn::DirX] );
0031 m_reader->AddVariable( "cos(LowQ2Tracks[0].phi)*sin(LowQ2Tracks[0].theta)", &nnInput[FarDetectorMLNNIndexIn::DirY] );
0032
0033
0034
0035 bool methodFound = false;
0036 if(!m_cfg.modelPath.empty()){
0037 try{
0038 m_method = dynamic_cast<TMVA::MethodBase*>(m_reader->BookMVA( m_cfg.methodName, m_cfg.modelPath ));
0039 }
0040 catch(std::exception &e){
0041 error(fmt::format("Failed to load method {} from file {}: {}", m_cfg.methodName, m_cfg.modelPath, e.what()));
0042 }
0043
0044 } else {
0045 error("No model path provided for FarDetectorMLReconstruction");
0046 }
0047 }
0048
0049
0050
0051 void FarDetectorMLReconstruction::process(
0052 const FarDetectorMLReconstruction::Input& input,
0053 const FarDetectorMLReconstruction::Output& output) {
0054
0055 const auto [inputTracks,beamElectrons] = input;
0056 auto [outputFarDetectorMLTrajectories, outputFarDetectorMLTrackParameters, outputFarDetectorMLTracks] = output;
0057
0058
0059 std::call_once(m_initBeamE,[&](){
0060
0061 if(beamElectrons->size() == 0){
0062 error("No beam electrons found keeping default 10GeV beam energy.");
0063 return;
0064 }
0065 m_beamE = beamElectrons->at(0).getEnergy();
0066
0067 m_beamE = round(m_beamE);
0068 });
0069
0070
0071 std::int32_t type = 0;
0072 float charge = -1;
0073
0074 for(const auto& track: *inputTracks){
0075
0076 auto pos = track.getLoc();
0077 auto trackphi = track.getPhi();
0078 auto tracktheta = track.getTheta();
0079
0080 nnInput[FarDetectorMLNNIndexIn::PosY] = pos.a;
0081 nnInput[FarDetectorMLNNIndexIn::PosZ] = pos.b;
0082 nnInput[FarDetectorMLNNIndexIn::DirX] = sin(trackphi)*sin(tracktheta);
0083 nnInput[FarDetectorMLNNIndexIn::DirY] = cos(trackphi)*sin(tracktheta);
0084
0085 auto values = m_method->GetRegressionValues();
0086
0087 edm4hep::Vector3f momentum = {values[FarDetectorMLNNIndexOut::MomX],values[FarDetectorMLNNIndexOut::MomY],values[FarDetectorMLNNIndexOut::MomZ]};
0088
0089
0090 trace("Prescaled Output Momentum: x {}, y {}, z {}",values[FarDetectorMLNNIndexOut::MomX],values[FarDetectorMLNNIndexOut::MomY],values[FarDetectorMLNNIndexOut::MomZ]);
0091 trace("Prescaled Momentum: {}",edm4eic::magnitude(momentum));
0092
0093
0094 momentum = momentum*m_beamE;
0095 trace("Scaled Momentum: {}",edm4eic::magnitude(momentum));
0096
0097
0098
0099
0100 edm4hep::Vector2f loc(0,0);
0101 uint64_t surface = 0;
0102 float theta = edm4eic::anglePolar(momentum);
0103 float phi = edm4eic::angleAzimuthal(momentum);
0104 float qOverP = charge/edm4eic::magnitude(momentum);
0105 float time;
0106
0107 int32_t pdg = 11;
0108
0109 edm4eic::Cov6f error;
0110
0111 edm4eic::TrackParameters params = outputFarDetectorMLTrackParameters->create(type,surface,loc,theta,phi,qOverP,time,pdg,error);
0112
0113 auto trajectory = outputFarDetectorMLTrajectories->create();
0114 trajectory.addToTrackParameters(params);
0115
0116 int32_t trackType = 0;
0117 edm4hep::Vector3f position = {0,0,0};
0118 float timeError;
0119 float charge = -1;
0120 float chi2 = 0;
0121 uint32_t ndf = 0;
0122
0123 auto outTrack = outputFarDetectorMLTracks->create(trackType,position,momentum,error,time,timeError,charge,chi2,ndf);
0124 outTrack.setTrajectory(trajectory);
0125
0126 }
0127
0128 }
0129
0130 }