Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 09:39:00

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2024 - 2025 Simon Gardner
0003 
0004 #include <edm4eic/EDM4eicVersion.h>
0005 
0006 #if EDM4EIC_VERSION_MAJOR >= 8
0007 #include <cmath>
0008 #include <cstddef>
0009 #include <fmt/core.h>
0010 #include <gsl/pointers>
0011 #include <podio/RelationRange.h>
0012 #include <stdexcept>
0013 
0014 #include "FarDetectorTransportationPostML.h"
0015 
0016 namespace eicrecon {
0017 
0018   void FarDetectorTransportationPostML::init() {
0019 
0020     m_beamE = m_cfg.beamE;
0021 
0022   }
0023 
0024   void FarDetectorTransportationPostML::process(
0025       const FarDetectorTransportationPostML::Input& input,
0026       const FarDetectorTransportationPostML::Output& output) const {
0027 
0028     const auto [prediction_tensors,beamElectrons] = input;
0029     auto [out_particles] = output;
0030 
0031     //Set beam energy from first MCBeamElectron, using std::call_once
0032     if (beamElectrons)
0033     {
0034       std::call_once(m_initBeamE,[&](){
0035         // Check if beam electrons are present
0036         if(beamElectrons->size() == 0){
0037           error("No beam electrons found keeping default 10GeV beam energy.");
0038           return;
0039         }
0040         m_beamE = beamElectrons->at(0).getEnergy();
0041         //Round beam energy to nearest GeV - Should be 5, 10 or 18GeV
0042         m_beamE = round(m_beamE);
0043       });
0044     }
0045 
0046     if (prediction_tensors->size() != 1) {
0047       error("Expected to find a single tensor, found {}", prediction_tensors->size());
0048       throw std::runtime_error("");
0049     }
0050     edm4eic::Tensor prediction_tensor = (*prediction_tensors)[0];
0051 
0052     if (prediction_tensor.shape_size() != 2) {
0053       error("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size());
0054       throw std::runtime_error(fmt::format("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size()));
0055     }
0056 
0057     if (prediction_tensor.getShape(1) != 3) {
0058       error("Expected 2 values per cluster in the output tensor, got {}", prediction_tensor.getShape(0));
0059       throw std::runtime_error(fmt::format("Expected 2 values per cluster in the output tensor, got {}", prediction_tensor.getShape(0)));
0060     }
0061 
0062     if (prediction_tensor.getElementType() != 1) { // 1 - float
0063       error("Expected a tensor of floats, but element type is {}", prediction_tensor.getElementType());
0064       throw std::runtime_error(fmt::format("Expected a tensor of floats, but element type is {}", prediction_tensor.getElementType()));
0065     }
0066 
0067     auto prediction_tensor_data = prediction_tensor.getFloatData();
0068 
0069     // Ensure the size of prediction_tensor_data is a multiple of its shape
0070     if (prediction_tensor_data.size() % 3 != 0) {
0071         error("The size of prediction_tensor_data is not a multiple of 3.");
0072         throw std::runtime_error("The size of prediction_tensor_data is not a multiple of 3.");
0073     }
0074 
0075 
0076     edm4eic::MutableReconstructedParticle particle;
0077 
0078     // Iterate over the prediction_tensor_data in steps of three
0079     for (size_t i = 0; i < prediction_tensor_data.size(); i += 3) {
0080         if (i + 2 >= prediction_tensor_data.size()) {
0081             error("Incomplete data for a prediction tensor at the end of the vector.");
0082             throw std::runtime_error("Incomplete data for a prediction tensor at the end of the vector.");
0083         }
0084 
0085         // Extract the current prediction
0086         float px = prediction_tensor_data[i] * m_beamE;
0087         float py = prediction_tensor_data[i + 1] * m_beamE;
0088         float pz = prediction_tensor_data[i + 2] * m_beamE;
0089 
0090         // Calculate reconstructed electron energy
0091         double energy = sqrt(px * px + py * py + pz * pz + 0.000511 * 0.000511);
0092 
0093         particle = out_particles->create();
0094 
0095         particle.setEnergy(energy);
0096         particle.setMomentum({px, py, pz});
0097         particle.setCharge(-1);
0098         particle.setMass(0.000511);
0099         particle.setPDG(11);
0100     }
0101 
0102     // TODO: Implement the association of the reconstructed particles with the tracks
0103 
0104   }
0105 
0106 } // namespace eicrecon
0107 #endif