File indexing completed on 2025-07-11 07:53:34
0001
0002
0003
0004 #include <TMVA/IMethod.h>
0005 #include <edm4eic/Cov6f.h>
0006 #include <edm4eic/vector_utils.h>
0007 #include <edm4hep/Vector2f.h>
0008 #include <edm4hep/Vector3f.h>
0009 #include <edm4hep/utils/vector_utils.h>
0010 #include <fmt/core.h>
0011 #include <cmath>
0012 #include <cstddef>
0013 #include <cstdint>
0014 #include <exception>
0015 #include <gsl/pointers>
0016 #include <stdexcept>
0017 #include <vector>
0018
0019 #include "FarDetectorMLReconstruction.h"
0020 #include "algorithms/fardetectors/FarDetectorMLReconstructionConfig.h"
0021
0022 namespace eicrecon {
0023
0024 void FarDetectorMLReconstruction::init() {
0025
0026 m_reader = new TMVA::Reader("!Color:!Silent");
0027
0028
0029 m_reader->AddVariable("LowQ2Tracks[0].loc.a", &nnInput[FarDetectorMLNNIndexIn::PosY]);
0030 m_reader->AddVariable("LowQ2Tracks[0].loc.b", &nnInput[FarDetectorMLNNIndexIn::PosZ]);
0031 m_reader->AddVariable("sin(LowQ2Tracks[0].phi)*sin(LowQ2Tracks[0].theta)",
0032 &nnInput[FarDetectorMLNNIndexIn::DirX]);
0033 m_reader->AddVariable("cos(LowQ2Tracks[0].phi)*sin(LowQ2Tracks[0].theta)",
0034 &nnInput[FarDetectorMLNNIndexIn::DirY]);
0035
0036
0037
0038 if (!m_cfg.modelPath.empty()) {
0039 try {
0040 m_method =
0041 dynamic_cast<TMVA::MethodBase*>(m_reader->BookMVA(m_cfg.methodName, m_cfg.modelPath));
0042 } catch (std::exception& e) {
0043 error(fmt::format("Failed to load method {} from file {}: {}", m_cfg.methodName,
0044 m_cfg.modelPath, e.what()));
0045 }
0046
0047 } else {
0048 error("No model path provided for FarDetectorMLReconstruction");
0049 }
0050 }
0051
0052 void FarDetectorMLReconstruction::process(const FarDetectorMLReconstruction::Input& input,
0053 const FarDetectorMLReconstruction::Output& output) const {
0054
0055 const auto [inputProjectedTracks, beamElectrons, inputFittedTracks, inputFittedAssociations] =
0056 input;
0057 auto [outputFarDetectorMLTrajectories, outputFarDetectorMLTrackParameters,
0058 outputFarDetectorMLTracks, outputAssociations] = output;
0059
0060
0061 std::call_once(m_initBeamE, [&]() {
0062
0063 if (beamElectrons->empty()) {
0064 if (m_cfg.requireBeamElectron) {
0065 critical("No beam electrons found");
0066 throw std::runtime_error("No beam electrons found");
0067 }
0068 return;
0069 }
0070 m_beamE = beamElectrons->at(0).getEnergy();
0071
0072 m_beamE = round(m_beamE);
0073 });
0074
0075
0076 std::int32_t type = 0;
0077 float charge = -1;
0078
0079 for (std::size_t i = 0; i < inputProjectedTracks->size(); i++) {
0080
0081 auto track = (*inputProjectedTracks)[i];
0082
0083 auto pos = track.getLoc();
0084 auto trackphi = track.getPhi();
0085 auto tracktheta = track.getTheta();
0086
0087 nnInput[FarDetectorMLNNIndexIn::PosY] = pos.a;
0088 nnInput[FarDetectorMLNNIndexIn::PosZ] = pos.b;
0089 nnInput[FarDetectorMLNNIndexIn::DirX] = sin(trackphi) * sin(tracktheta);
0090 nnInput[FarDetectorMLNNIndexIn::DirY] = cos(trackphi) * sin(tracktheta);
0091
0092 auto values = m_method->GetRegressionValues();
0093
0094 edm4hep::Vector3f momentum = {values[FarDetectorMLNNIndexOut::MomX],
0095 values[FarDetectorMLNNIndexOut::MomY],
0096 values[FarDetectorMLNNIndexOut::MomZ]};
0097
0098
0099 trace("Prescaled Output Momentum: x {}, y {}, z {}", values[FarDetectorMLNNIndexOut::MomX],
0100 values[FarDetectorMLNNIndexOut::MomY], values[FarDetectorMLNNIndexOut::MomZ]);
0101 trace("Prescaled Momentum: {}", edm4eic::magnitude(momentum));
0102
0103
0104 momentum = momentum * m_beamE;
0105 trace("Scaled Momentum: {}", edm4eic::magnitude(momentum));
0106
0107
0108
0109
0110 edm4hep::Vector2f loc(0, 0);
0111 uint64_t surface = 0;
0112 float theta = edm4eic::anglePolar(momentum);
0113 float phi = edm4eic::angleAzimuthal(momentum);
0114 float qOverP = charge / edm4eic::magnitude(momentum);
0115 float time = 0;
0116
0117 int32_t pdg = 11;
0118
0119 edm4eic::Cov6f error;
0120
0121 edm4eic::TrackParameters params = outputFarDetectorMLTrackParameters->create(
0122 type, surface, loc, theta, phi, qOverP, time, pdg, error);
0123
0124 auto trajectory = outputFarDetectorMLTrajectories->create();
0125 trajectory.addToTrackParameters(params);
0126
0127 int32_t trackType = 0;
0128 edm4hep::Vector3f position = {0, 0, 0};
0129 float timeError = 0;
0130 float charge = -1;
0131 float chi2 = 0;
0132 uint32_t ndf = 0;
0133
0134 auto outTrack = outputFarDetectorMLTracks->create(trackType, position, momentum, error, time,
0135 timeError, charge, chi2, ndf, pdg);
0136 outTrack.setTrajectory(trajectory);
0137
0138
0139
0140 for (auto assoc : *inputFittedAssociations) {
0141 if (assoc.getRec() == (*inputFittedTracks)[i]) {
0142 auto outAssoc = assoc.clone();
0143 outAssoc.setRec(outTrack);
0144 outputAssociations->push_back(outAssoc);
0145 }
0146 }
0147 }
0148 }
0149
0150 }