Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-11 07:53:34

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2024, Simon Gardner
0003 
0004 #include <TMVA/IMethod.h>
0005 #include <edm4eic/Cov6f.h>
0006 #include <edm4eic/vector_utils.h>
0007 #include <edm4hep/Vector2f.h>
0008 #include <edm4hep/Vector3f.h>
0009 #include <edm4hep/utils/vector_utils.h>
0010 #include <fmt/core.h>
0011 #include <cmath>
0012 #include <cstddef>
0013 #include <cstdint>
0014 #include <exception>
0015 #include <gsl/pointers>
0016 #include <stdexcept>
0017 #include <vector>
0018 
0019 #include "FarDetectorMLReconstruction.h"
0020 #include "algorithms/fardetectors/FarDetectorMLReconstructionConfig.h"
0021 
0022 namespace eicrecon {
0023 
0024 void FarDetectorMLReconstruction::init() {
0025 
0026   m_reader = new TMVA::Reader("!Color:!Silent");
0027   // Create a set of variables and declare them to the reader
0028   // - the variable names MUST corresponds in name and type to those given in the weight file(s) used
0029   m_reader->AddVariable("LowQ2Tracks[0].loc.a", &nnInput[FarDetectorMLNNIndexIn::PosY]);
0030   m_reader->AddVariable("LowQ2Tracks[0].loc.b", &nnInput[FarDetectorMLNNIndexIn::PosZ]);
0031   m_reader->AddVariable("sin(LowQ2Tracks[0].phi)*sin(LowQ2Tracks[0].theta)",
0032                         &nnInput[FarDetectorMLNNIndexIn::DirX]);
0033   m_reader->AddVariable("cos(LowQ2Tracks[0].phi)*sin(LowQ2Tracks[0].theta)",
0034                         &nnInput[FarDetectorMLNNIndexIn::DirY]);
0035 
0036   // Locate and load the weight file
0037   // TODO - Add functionality to select passed by configuration
0038   if (!m_cfg.modelPath.empty()) {
0039     try {
0040       m_method =
0041           dynamic_cast<TMVA::MethodBase*>(m_reader->BookMVA(m_cfg.methodName, m_cfg.modelPath));
0042     } catch (std::exception& e) {
0043       error(fmt::format("Failed to load method {} from file {}: {}", m_cfg.methodName,
0044                         m_cfg.modelPath, e.what()));
0045     }
0046 
0047   } else {
0048     error("No model path provided for FarDetectorMLReconstruction");
0049   }
0050 }
0051 
0052 void FarDetectorMLReconstruction::process(const FarDetectorMLReconstruction::Input& input,
0053                                           const FarDetectorMLReconstruction::Output& output) const {
0054 
0055   const auto [inputProjectedTracks, beamElectrons, inputFittedTracks, inputFittedAssociations] =
0056       input;
0057   auto [outputFarDetectorMLTrajectories, outputFarDetectorMLTrackParameters,
0058         outputFarDetectorMLTracks, outputAssociations] = output;
0059 
0060   //Set beam energy from first MCBeamElectron, using std::call_once
0061   std::call_once(m_initBeamE, [&]() {
0062     // Check if beam electrons are present
0063     if (beamElectrons->empty()) { // NOLINT(clang-analyzer-core.CallAndMessage)
0064       if (m_cfg.requireBeamElectron) {
0065         critical("No beam electrons found");
0066         throw std::runtime_error("No beam electrons found");
0067       }
0068       return;
0069     }
0070     m_beamE = beamElectrons->at(0).getEnergy();
0071     //Round beam energy to nearest GeV - Should be 5, 10 or 18GeV
0072     m_beamE = round(m_beamE);
0073   });
0074 
0075   // Reconstructed particle members which don't change
0076   std::int32_t type = 0; // Check?
0077   float charge      = -1;
0078 
0079   for (std::size_t i = 0; i < inputProjectedTracks->size(); i++) {
0080     // Get the track parameters
0081     auto track = (*inputProjectedTracks)[i];
0082 
0083     auto pos        = track.getLoc();
0084     auto trackphi   = track.getPhi();
0085     auto tracktheta = track.getTheta();
0086 
0087     nnInput[FarDetectorMLNNIndexIn::PosY] = pos.a;
0088     nnInput[FarDetectorMLNNIndexIn::PosZ] = pos.b;
0089     nnInput[FarDetectorMLNNIndexIn::DirX] = sin(trackphi) * sin(tracktheta);
0090     nnInput[FarDetectorMLNNIndexIn::DirY] = cos(trackphi) * sin(tracktheta);
0091 
0092     auto values = m_method->GetRegressionValues();
0093 
0094     edm4hep::Vector3f momentum = {values[FarDetectorMLNNIndexOut::MomX],
0095                                   values[FarDetectorMLNNIndexOut::MomY],
0096                                   values[FarDetectorMLNNIndexOut::MomZ]};
0097 
0098     // log out the momentum components and magnitude
0099     trace("Prescaled Output Momentum: x {}, y {}, z {}", values[FarDetectorMLNNIndexOut::MomX],
0100           values[FarDetectorMLNNIndexOut::MomY], values[FarDetectorMLNNIndexOut::MomZ]);
0101     trace("Prescaled Momentum: {}", edm4eic::magnitude(momentum));
0102 
0103     // Scale momentum magnitude
0104     momentum = momentum * m_beamE;
0105     trace("Scaled Momentum: {}", edm4eic::magnitude(momentum));
0106 
0107     // Track parameter variables
0108     // TODO: Add time and momentum errors
0109     // Plane Point
0110     edm4hep::Vector2f loc(0, 0); // Vertex estimate
0111     uint64_t surface = 0;        //Not used in this context
0112     float theta      = edm4eic::anglePolar(momentum);
0113     float phi        = edm4eic::angleAzimuthal(momentum);
0114     float qOverP     = charge / edm4eic::magnitude(momentum);
0115     float time       = 0;
0116     // PDG
0117     int32_t pdg = 11;
0118     // Point Error
0119     edm4eic::Cov6f error;
0120 
0121     edm4eic::TrackParameters params = outputFarDetectorMLTrackParameters->create(
0122         type, surface, loc, theta, phi, qOverP, time, pdg, error);
0123 
0124     auto trajectory = outputFarDetectorMLTrajectories->create();
0125     trajectory.addToTrackParameters(params);
0126 
0127     int32_t trackType          = 0;
0128     edm4hep::Vector3f position = {0, 0, 0};
0129     float timeError            = 0;
0130     float charge               = -1;
0131     float chi2                 = 0;
0132     uint32_t ndf               = 0;
0133 
0134     auto outTrack = outputFarDetectorMLTracks->create(trackType, position, momentum, error, time,
0135                                                       timeError, charge, chi2, ndf, pdg);
0136     outTrack.setTrajectory(trajectory);
0137 
0138     // Propagate the track associations
0139     // The order of the tracks needs to be the same in both collections with no filtering
0140     for (auto assoc : *inputFittedAssociations) {
0141       if (assoc.getRec() == (*inputFittedTracks)[i]) {
0142         auto outAssoc = assoc.clone();
0143         outAssoc.setRec(outTrack);
0144         outputAssociations->push_back(outAssoc);
0145       }
0146     }
0147   }
0148 }
0149 
0150 } // namespace eicrecon