File indexing completed on 2025-12-16 09:27:59
0001
0002
0003
0004 #include <edm4hep/Vector3f.h>
0005 #include <fmt/core.h>
0006 #include <podio/RelationRange.h>
0007 #include <cmath>
0008 #include <cstddef>
0009 #include <gsl/pointers>
0010 #include <stdexcept>
0011
0012 #include "FarDetectorTransportationPostML.h"
0013 #include "services/particle/ParticleSvc.h"
0014
0015 namespace eicrecon {
0016
0017 void FarDetectorTransportationPostML::init() {
0018 m_beamE = m_cfg.beamE;
0019 auto& particleSvc = algorithms::ParticleSvc::instance();
0020 m_mass = particleSvc.particle(m_cfg.pdg_value).mass;
0021 m_charge = particleSvc.particle(m_cfg.pdg_value).charge;
0022 }
0023
0024 void FarDetectorTransportationPostML::process(
0025 const FarDetectorTransportationPostML::Input& input,
0026 const FarDetectorTransportationPostML::Output& output) const {
0027
0028 const auto [prediction_tensors, track_associations, beamElectrons] = input;
0029 auto [out_particles, out_associations] = output;
0030
0031
0032 if (beamElectrons != nullptr) {
0033 std::call_once(m_initBeamE, [&]() {
0034
0035 if (beamElectrons->empty()) {
0036 if (m_cfg.requireBeamElectron) {
0037 error("No beam electrons found");
0038 throw std::runtime_error("No beam electrons found");
0039 }
0040 return;
0041 }
0042 m_beamE = beamElectrons->at(0).getEnergy();
0043
0044 m_beamE = round(m_beamE);
0045 });
0046 }
0047
0048 if (prediction_tensors->size() != 1) {
0049 error("Expected to find a single tensor, found {}", prediction_tensors->size());
0050 throw std::runtime_error("");
0051 }
0052 edm4eic::Tensor prediction_tensor = (*prediction_tensors)[0];
0053
0054 if (prediction_tensor.shape_size() != 2) {
0055 error("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size());
0056 throw std::runtime_error(
0057 fmt::format("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size()));
0058 }
0059
0060 if (prediction_tensor.getShape(1) != 3) {
0061 error("Expected 2 values per cluster in the output tensor, got {}",
0062 prediction_tensor.getShape(0));
0063 throw std::runtime_error(
0064 fmt::format("Expected 2 values per cluster in the output tensor, got {}",
0065 prediction_tensor.getShape(0)));
0066 }
0067
0068 if (prediction_tensor.getElementType() != 1) {
0069 error("Expected a tensor of floats, but element type is {}",
0070 prediction_tensor.getElementType());
0071 throw std::runtime_error(fmt::format("Expected a tensor of floats, but element type is {}",
0072 prediction_tensor.getElementType()));
0073 }
0074
0075 auto prediction_tensor_data = prediction_tensor.getFloatData();
0076
0077
0078 if (prediction_tensor_data.size() % 3 != 0 || prediction_tensor.getShape(1) != 3) {
0079 error("The size of prediction_tensor_data is not a multiple of 3.");
0080 throw std::runtime_error("The size of prediction_tensor_data is not a multiple of 3.");
0081 }
0082
0083 edm4eic::MutableReconstructedParticle particle;
0084
0085
0086 for (std::size_t i = 0; i < static_cast<std::size_t>(prediction_tensor.getShape(0)); i++) {
0087
0088 std::size_t base_index = i * 3;
0089
0090 if (base_index + 2 >= prediction_tensor_data.size()) {
0091 error("Incomplete data for a prediction tensor at the end of the vector.");
0092 throw std::runtime_error("Incomplete data for a prediction tensor at the end of the vector.");
0093 }
0094
0095
0096 float px = prediction_tensor_data[base_index] * m_beamE;
0097 float py = prediction_tensor_data[base_index + 1] * m_beamE;
0098 float pz = prediction_tensor_data[base_index + 2] * m_beamE;
0099
0100
0101 double energy = sqrt(px * px + py * py + pz * pz + m_mass * m_mass);
0102
0103 particle = out_particles->create();
0104
0105 particle.setEnergy(energy);
0106 particle.setMomentum({px, py, pz});
0107 particle.setCharge(m_charge);
0108 particle.setMass(m_mass);
0109 particle.setPDG(m_cfg.pdg_value);
0110
0111
0112 if ((track_associations != nullptr) && (track_associations->size() > i)) {
0113
0114 auto association = track_associations->at(i);
0115 auto out_association = out_associations->create();
0116 out_association.setSim(association.getSim());
0117 out_association.setRec(particle);
0118 out_association.setWeight(association.getWeight());
0119 }
0120 }
0121
0122
0123 }
0124
0125 }