File indexing completed on 2025-09-14 08:15:51
0001
0002
0003
0004 #include <TMVA/IMethod.h>
0005 #include <TString.h>
0006 #include <edm4eic/Cov6f.h>
0007 #include <edm4eic/vector_utils.h>
0008 #include <edm4hep/Vector2f.h>
0009 #include <edm4hep/Vector3f.h>
0010 #include <edm4hep/utils/vector_utils.h>
0011 #include <fmt/core.h>
0012 #include <cmath>
0013 #include <cstddef>
0014 #include <cstdint>
0015 #include <exception>
0016 #include <gsl/pointers>
0017 #include <stdexcept>
0018 #include <vector>
0019
0020 #include "FarDetectorMLReconstruction.h"
0021 #include "algorithms/fardetectors/FarDetectorMLReconstructionConfig.h"
0022
0023 namespace eicrecon {
0024
0025 void FarDetectorMLReconstruction::init() {
0026
0027 m_reader = new TMVA::Reader("!Color:!Silent");
0028
0029
0030 m_reader->AddVariable("LowQ2Tracks[0].loc.a", &nnInput[FarDetectorMLNNIndexIn::PosY]);
0031 m_reader->AddVariable("LowQ2Tracks[0].loc.b", &nnInput[FarDetectorMLNNIndexIn::PosZ]);
0032 m_reader->AddVariable("sin(LowQ2Tracks[0].phi)*sin(LowQ2Tracks[0].theta)",
0033 &nnInput[FarDetectorMLNNIndexIn::DirX]);
0034 m_reader->AddVariable("cos(LowQ2Tracks[0].phi)*sin(LowQ2Tracks[0].theta)",
0035 &nnInput[FarDetectorMLNNIndexIn::DirY]);
0036
0037
0038
0039 if (!m_cfg.modelPath.empty()) {
0040 try {
0041 m_method =
0042 dynamic_cast<TMVA::MethodBase*>(m_reader->BookMVA(m_cfg.methodName, m_cfg.modelPath));
0043 } catch (std::exception& e) {
0044 error(fmt::format("Failed to load method {} from file {}: {}", m_cfg.methodName,
0045 m_cfg.modelPath, e.what()));
0046 }
0047
0048 } else {
0049 error("No model path provided for FarDetectorMLReconstruction");
0050 }
0051 }
0052
0053 void FarDetectorMLReconstruction::process(const FarDetectorMLReconstruction::Input& input,
0054 const FarDetectorMLReconstruction::Output& output) const {
0055
0056 const auto [inputProjectedTracks, beamElectrons, inputFittedTracks, inputFittedAssociations] =
0057 input;
0058 auto [outputFarDetectorMLTrajectories, outputFarDetectorMLTrackParameters,
0059 outputFarDetectorMLTracks, outputAssociations] = output;
0060
0061
0062 std::call_once(m_initBeamE, [&]() {
0063
0064 if (beamElectrons->empty()) {
0065 if (m_cfg.requireBeamElectron) {
0066 critical("No beam electrons found");
0067 throw std::runtime_error("No beam electrons found");
0068 }
0069 return;
0070 }
0071 m_beamE = beamElectrons->at(0).getEnergy();
0072
0073 m_beamE = round(m_beamE);
0074 });
0075
0076
0077 std::int32_t type = 0;
0078 float charge = -1;
0079
0080 for (std::size_t i = 0; i < inputProjectedTracks->size(); i++) {
0081
0082 auto track = (*inputProjectedTracks)[i];
0083
0084 auto pos = track.getLoc();
0085 auto trackphi = track.getPhi();
0086 auto tracktheta = track.getTheta();
0087
0088 nnInput[FarDetectorMLNNIndexIn::PosY] = pos.a;
0089 nnInput[FarDetectorMLNNIndexIn::PosZ] = pos.b;
0090 nnInput[FarDetectorMLNNIndexIn::DirX] = sin(trackphi) * sin(tracktheta);
0091 nnInput[FarDetectorMLNNIndexIn::DirY] = cos(trackphi) * sin(tracktheta);
0092
0093 auto values = m_method->GetRegressionValues();
0094
0095 edm4hep::Vector3f momentum = {values[FarDetectorMLNNIndexOut::MomX],
0096 values[FarDetectorMLNNIndexOut::MomY],
0097 values[FarDetectorMLNNIndexOut::MomZ]};
0098
0099
0100 trace("Prescaled Output Momentum: x {}, y {}, z {}", values[FarDetectorMLNNIndexOut::MomX],
0101 values[FarDetectorMLNNIndexOut::MomY], values[FarDetectorMLNNIndexOut::MomZ]);
0102 trace("Prescaled Momentum: {}", edm4eic::magnitude(momentum));
0103
0104
0105 momentum = momentum * m_beamE;
0106 trace("Scaled Momentum: {}", edm4eic::magnitude(momentum));
0107
0108
0109
0110
0111 edm4hep::Vector2f loc(0, 0);
0112 uint64_t surface = 0;
0113 float theta = edm4eic::anglePolar(momentum);
0114 float phi = edm4eic::angleAzimuthal(momentum);
0115 float qOverP = charge / edm4eic::magnitude(momentum);
0116 float time = 0;
0117
0118 int32_t pdg = 11;
0119
0120 edm4eic::Cov6f error;
0121
0122 edm4eic::TrackParameters params = outputFarDetectorMLTrackParameters->create(
0123 type, surface, loc, theta, phi, qOverP, time, pdg, error);
0124
0125 auto trajectory = outputFarDetectorMLTrajectories->create();
0126 trajectory.addToTrackParameters(params);
0127
0128 int32_t trackType = 0;
0129 edm4hep::Vector3f position = {0, 0, 0};
0130 float timeError = 0;
0131 float charge = -1;
0132 float chi2 = 0;
0133 uint32_t ndf = 0;
0134
0135 auto outTrack = outputFarDetectorMLTracks->create(trackType, position, momentum, error, time,
0136 timeError, charge, chi2, ndf, pdg);
0137 outTrack.setTrajectory(trajectory);
0138
0139
0140
0141 for (auto assoc : *inputFittedAssociations) {
0142 if (assoc.getRec() == (*inputFittedTracks)[i]) {
0143 auto outAssoc = assoc.clone();
0144 outAssoc.setRec(outTrack);
0145 outputAssociations->push_back(outAssoc);
0146 }
0147 }
0148 }
0149 }
0150
0151 }