Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-14 08:15:51

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2024, Simon Gardner
0003 
0004 #include <TMVA/IMethod.h>
0005 #include <TString.h>
0006 #include <edm4eic/Cov6f.h>
0007 #include <edm4eic/vector_utils.h>
0008 #include <edm4hep/Vector2f.h>
0009 #include <edm4hep/Vector3f.h>
0010 #include <edm4hep/utils/vector_utils.h>
0011 #include <fmt/core.h>
0012 #include <cmath>
0013 #include <cstddef>
0014 #include <cstdint>
0015 #include <exception>
0016 #include <gsl/pointers>
0017 #include <stdexcept>
0018 #include <vector>
0019 
0020 #include "FarDetectorMLReconstruction.h"
0021 #include "algorithms/fardetectors/FarDetectorMLReconstructionConfig.h"
0022 
0023 namespace eicrecon {
0024 
0025 void FarDetectorMLReconstruction::init() {
0026 
0027   m_reader = new TMVA::Reader("!Color:!Silent");
0028   // Create a set of variables and declare them to the reader
0029   // - the variable names MUST corresponds in name and type to those given in the weight file(s) used
0030   m_reader->AddVariable("LowQ2Tracks[0].loc.a", &nnInput[FarDetectorMLNNIndexIn::PosY]);
0031   m_reader->AddVariable("LowQ2Tracks[0].loc.b", &nnInput[FarDetectorMLNNIndexIn::PosZ]);
0032   m_reader->AddVariable("sin(LowQ2Tracks[0].phi)*sin(LowQ2Tracks[0].theta)",
0033                         &nnInput[FarDetectorMLNNIndexIn::DirX]);
0034   m_reader->AddVariable("cos(LowQ2Tracks[0].phi)*sin(LowQ2Tracks[0].theta)",
0035                         &nnInput[FarDetectorMLNNIndexIn::DirY]);
0036 
0037   // Locate and load the weight file
0038   // TODO - Add functionality to select passed by configuration
0039   if (!m_cfg.modelPath.empty()) {
0040     try {
0041       m_method =
0042           dynamic_cast<TMVA::MethodBase*>(m_reader->BookMVA(m_cfg.methodName, m_cfg.modelPath));
0043     } catch (std::exception& e) {
0044       error(fmt::format("Failed to load method {} from file {}: {}", m_cfg.methodName,
0045                         m_cfg.modelPath, e.what()));
0046     }
0047 
0048   } else {
0049     error("No model path provided for FarDetectorMLReconstruction");
0050   }
0051 }
0052 
0053 void FarDetectorMLReconstruction::process(const FarDetectorMLReconstruction::Input& input,
0054                                           const FarDetectorMLReconstruction::Output& output) const {
0055 
0056   const auto [inputProjectedTracks, beamElectrons, inputFittedTracks, inputFittedAssociations] =
0057       input;
0058   auto [outputFarDetectorMLTrajectories, outputFarDetectorMLTrackParameters,
0059         outputFarDetectorMLTracks, outputAssociations] = output;
0060 
0061   //Set beam energy from first MCBeamElectron, using std::call_once
0062   std::call_once(m_initBeamE, [&]() {
0063     // Check if beam electrons are present
0064     if (beamElectrons->empty()) { // NOLINT(clang-analyzer-core.CallAndMessage)
0065       if (m_cfg.requireBeamElectron) {
0066         critical("No beam electrons found");
0067         throw std::runtime_error("No beam electrons found");
0068       }
0069       return;
0070     }
0071     m_beamE = beamElectrons->at(0).getEnergy();
0072     //Round beam energy to nearest GeV - Should be 5, 10 or 18GeV
0073     m_beamE = round(m_beamE);
0074   });
0075 
0076   // Reconstructed particle members which don't change
0077   std::int32_t type = 0; // Check?
0078   float charge      = -1;
0079 
0080   for (std::size_t i = 0; i < inputProjectedTracks->size(); i++) {
0081     // Get the track parameters
0082     auto track = (*inputProjectedTracks)[i];
0083 
0084     auto pos        = track.getLoc();
0085     auto trackphi   = track.getPhi();
0086     auto tracktheta = track.getTheta();
0087 
0088     nnInput[FarDetectorMLNNIndexIn::PosY] = pos.a;
0089     nnInput[FarDetectorMLNNIndexIn::PosZ] = pos.b;
0090     nnInput[FarDetectorMLNNIndexIn::DirX] = sin(trackphi) * sin(tracktheta);
0091     nnInput[FarDetectorMLNNIndexIn::DirY] = cos(trackphi) * sin(tracktheta);
0092 
0093     auto values = m_method->GetRegressionValues();
0094 
0095     edm4hep::Vector3f momentum = {values[FarDetectorMLNNIndexOut::MomX],
0096                                   values[FarDetectorMLNNIndexOut::MomY],
0097                                   values[FarDetectorMLNNIndexOut::MomZ]};
0098 
0099     // log out the momentum components and magnitude
0100     trace("Prescaled Output Momentum: x {}, y {}, z {}", values[FarDetectorMLNNIndexOut::MomX],
0101           values[FarDetectorMLNNIndexOut::MomY], values[FarDetectorMLNNIndexOut::MomZ]);
0102     trace("Prescaled Momentum: {}", edm4eic::magnitude(momentum));
0103 
0104     // Scale momentum magnitude
0105     momentum = momentum * m_beamE;
0106     trace("Scaled Momentum: {}", edm4eic::magnitude(momentum));
0107 
0108     // Track parameter variables
0109     // TODO: Add time and momentum errors
0110     // Plane Point
0111     edm4hep::Vector2f loc(0, 0); // Vertex estimate
0112     uint64_t surface = 0;        //Not used in this context
0113     float theta      = edm4eic::anglePolar(momentum);
0114     float phi        = edm4eic::angleAzimuthal(momentum);
0115     float qOverP     = charge / edm4eic::magnitude(momentum);
0116     float time       = 0;
0117     // PDG
0118     int32_t pdg = 11;
0119     // Point Error
0120     edm4eic::Cov6f error;
0121 
0122     edm4eic::TrackParameters params = outputFarDetectorMLTrackParameters->create(
0123         type, surface, loc, theta, phi, qOverP, time, pdg, error);
0124 
0125     auto trajectory = outputFarDetectorMLTrajectories->create();
0126     trajectory.addToTrackParameters(params);
0127 
0128     int32_t trackType          = 0;
0129     edm4hep::Vector3f position = {0, 0, 0};
0130     float timeError            = 0;
0131     float charge               = -1;
0132     float chi2                 = 0;
0133     uint32_t ndf               = 0;
0134 
0135     auto outTrack = outputFarDetectorMLTracks->create(trackType, position, momentum, error, time,
0136                                                       timeError, charge, chi2, ndf, pdg);
0137     outTrack.setTrajectory(trajectory);
0138 
0139     // Propagate the track associations
0140     // The order of the tracks needs to be the same in both collections with no filtering
0141     for (auto assoc : *inputFittedAssociations) {
0142       if (assoc.getRec() == (*inputFittedTracks)[i]) {
0143         auto outAssoc = assoc.clone();
0144         outAssoc.setRec(outTrack);
0145         outputAssociations->push_back(outAssoc);
0146       }
0147     }
0148   }
0149 }
0150 
0151 } // namespace eicrecon