Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-09-27 07:02:58

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2024, Simon Gardner
0003 
#include <TMVA/IMethod.h>
#include <edm4eic/Cov6f.h>
#include <edm4eic/vector_utils.h>
#include <edm4hep/Vector2f.h>
#include <edm4hep/Vector3f.h>
#include <edm4hep/utils/vector_utils.h>
#include <fmt/core.h>
#include <cmath>
#include <cstdint>
#include <exception>
#include <gsl/pointers>
#include <mutex>
#include <stdexcept>
#include <vector>

#include "FarDetectorMLReconstruction.h"
#include "algorithms/fardetectors/FarDetectorMLReconstructionConfig.h"
0019 
0020 namespace eicrecon {
0021 
0022 
0023   void FarDetectorMLReconstruction::init() {
0024 
0025     m_reader = new TMVA::Reader( "!Color:!Silent" );
0026     // Create a set of variables and declare them to the reader
0027     // - the variable names MUST corresponds in name and type to those given in the weight file(s) used
0028     m_reader->AddVariable( "LowQ2Tracks[0].loc.a", &nnInput[FarDetectorMLNNIndexIn::PosY] );
0029     m_reader->AddVariable( "LowQ2Tracks[0].loc.b", &nnInput[FarDetectorMLNNIndexIn::PosZ] );
0030     m_reader->AddVariable( "sin(LowQ2Tracks[0].phi)*sin(LowQ2Tracks[0].theta)", &nnInput[FarDetectorMLNNIndexIn::DirX] );
0031     m_reader->AddVariable( "cos(LowQ2Tracks[0].phi)*sin(LowQ2Tracks[0].theta)", &nnInput[FarDetectorMLNNIndexIn::DirY] );
0032 
0033     // Locate and load the weight file
0034     // TODO - Add functionality to select passed by configuration
0035     bool methodFound = false;
0036     if(!m_cfg.modelPath.empty()){
0037       try{
0038         m_method = dynamic_cast<TMVA::MethodBase*>(m_reader->BookMVA( m_cfg.methodName, m_cfg.modelPath ));
0039       }
0040       catch(std::exception &e){
0041         error(fmt::format("Failed to load method {} from file {}: {}", m_cfg.methodName, m_cfg.modelPath, e.what()));
0042       }
0043 
0044     } else {
0045       error("No model path provided for FarDetectorMLReconstruction");
0046     }
0047   }
0048 
0049 
0050 
0051   void FarDetectorMLReconstruction::process(
0052       const FarDetectorMLReconstruction::Input& input,
0053       const FarDetectorMLReconstruction::Output& output) {
0054 
0055     const auto [inputTracks,beamElectrons] = input;
0056     auto [outputFarDetectorMLTrajectories, outputFarDetectorMLTrackParameters, outputFarDetectorMLTracks] = output;
0057 
0058     //Set beam energy from first MCBeamElectron, using std::call_once
0059     std::call_once(m_initBeamE,[&](){
0060       // Check if beam electrons are present
0061       if(beamElectrons->size() == 0){
0062         error("No beam electrons found keeping default 10GeV beam energy.");
0063         return;
0064       }
0065       m_beamE = beamElectrons->at(0).getEnergy();
0066       //Round beam energy to nearest GeV - Should be 5, 10 or 18GeV
0067       m_beamE = round(m_beamE);
0068     });
0069 
0070     // Reconstructed particle members which don't change
0071     std::int32_t type   = 0; // Check?
0072     float        charge = -1;
0073 
0074     for(const auto& track: *inputTracks){
0075 
0076       auto pos        = track.getLoc();
0077       auto trackphi   = track.getPhi();
0078       auto tracktheta = track.getTheta();
0079 
0080       nnInput[FarDetectorMLNNIndexIn::PosY] = pos.a;
0081       nnInput[FarDetectorMLNNIndexIn::PosZ] = pos.b;
0082       nnInput[FarDetectorMLNNIndexIn::DirX] = sin(trackphi)*sin(tracktheta);
0083       nnInput[FarDetectorMLNNIndexIn::DirY] = cos(trackphi)*sin(tracktheta);
0084 
0085       auto values = m_method->GetRegressionValues();
0086 
0087       edm4hep::Vector3f momentum = {values[FarDetectorMLNNIndexOut::MomX],values[FarDetectorMLNNIndexOut::MomY],values[FarDetectorMLNNIndexOut::MomZ]};
0088 
0089       // log out the momentum components and magnitude
0090       trace("Prescaled Output Momentum: x {}, y {}, z {}",values[FarDetectorMLNNIndexOut::MomX],values[FarDetectorMLNNIndexOut::MomY],values[FarDetectorMLNNIndexOut::MomZ]);
0091       trace("Prescaled Momentum: {}",edm4eic::magnitude(momentum));
0092 
0093       // Scale momentum magnitude
0094       momentum = momentum*m_beamE;
0095       trace("Scaled Momentum: {}",edm4eic::magnitude(momentum));
0096 
0097       // Track parameter variables
0098       // TODO: Add time and momentum errors
0099       // Plane Point
0100       edm4hep::Vector2f loc(0,0); // Vertex estimate
0101       uint64_t surface = 0; //Not used in this context
0102       float theta   = edm4eic::anglePolar(momentum);
0103       float phi     = edm4eic::angleAzimuthal(momentum);
0104       float qOverP  = charge/edm4eic::magnitude(momentum);
0105       float time;
0106       // PDG
0107       int32_t pdg = 11;
0108       // Point Error
0109       edm4eic::Cov6f error;
0110 
0111       edm4eic::TrackParameters params =  outputFarDetectorMLTrackParameters->create(type,surface,loc,theta,phi,qOverP,time,pdg,error);
0112 
0113       auto trajectory = outputFarDetectorMLTrajectories->create();
0114       trajectory.addToTrackParameters(params);
0115 
0116       int32_t trackType = 0;
0117       edm4hep::Vector3f position = {0,0,0};
0118       float timeError;
0119       float charge    = -1;
0120       float chi2      = 0;
0121       uint32_t ndf    = 0;
0122 
0123       auto outTrack      = outputFarDetectorMLTracks->create(trackType,position,momentum,error,time,timeError,charge,chi2,ndf);
0124       outTrack.setTrajectory(trajectory);
0125 
0126     }
0127 
0128   }
0129 
0130 }