File indexing completed on 2025-01-18 09:15:30
0001
0002
0003
0004 #include <edm4eic/EDM4eicVersion.h>
0005
0006 #if EDM4EIC_VERSION_MAJOR >= 8
0007 #include <cstddef>
0008 #include <fmt/core.h>
0009 #include <gsl/pointers>
0010 #include <stdexcept>
0011
0012 #include "CalorimeterParticleIDPostML.h"
0013
0014 namespace eicrecon {
0015
0016 void CalorimeterParticleIDPostML::init() {
0017
0018 }
0019
0020 void CalorimeterParticleIDPostML::process(
0021 const CalorimeterParticleIDPostML::Input& input,
0022 const CalorimeterParticleIDPostML::Output& output) const {
0023
0024 const auto [in_clusters, in_assocs, prediction_tensors] = input;
0025 auto [out_clusters, out_assocs, out_particle_ids] = output;
0026
0027 if (prediction_tensors->size() != 1) {
0028 error("Expected to find a single tensor, found {}", prediction_tensors->size());
0029 throw std::runtime_error("");
0030 }
0031 edm4eic::Tensor prediction_tensor = (*prediction_tensors)[0];
0032
0033 if (prediction_tensor.shape_size() != 2) {
0034 error("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size());
0035 throw std::runtime_error(fmt::format("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size()));
0036 }
0037
0038 if (prediction_tensor.getShape(0) != in_clusters->size()) {
0039 error("Length mismatch between tensor's 0th axis and number of clusters: {} != {}", prediction_tensor.getShape(0), in_clusters->size());
0040 throw std::runtime_error(fmt::format("Length mismatch between tensor's 0th axis and number of clusters: {} != {}", prediction_tensor.getShape(0), in_clusters->size()));
0041 }
0042
0043 if (prediction_tensor.getShape(1) != 2) {
0044 error("Expected 2 values per cluster in the output tensor, got {}", prediction_tensor.getShape(0));
0045 throw std::runtime_error(fmt::format("Expected 2 values per cluster in the output tensor, got {}", prediction_tensor.getShape(0)));
0046 }
0047
0048 if (prediction_tensor.getElementType() != 1) {
0049 error("Expected a tensor of floats, but element type is {}", prediction_tensor.getElementType());
0050 throw std::runtime_error(fmt::format("Expected a tensor of floats, but element type is {}", prediction_tensor.getElementType()));
0051 }
0052
0053 for (size_t cluster_ix = 0; cluster_ix < in_clusters->size(); cluster_ix++) {
0054 edm4eic::Cluster in_cluster = (*in_clusters)[cluster_ix];
0055 edm4eic::MutableCluster out_cluster = in_cluster.clone();
0056 out_clusters->push_back(out_cluster);
0057
0058 float prob_pion = prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 0);
0059 float prob_electron = prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 1);
0060
0061 out_cluster.addToParticleIDs(out_particle_ids->create(
0062 0,
0063 211,
0064 0,
0065 prob_pion
0066 ));
0067 out_cluster.addToParticleIDs(out_particle_ids->create(
0068 0,
0069 11,
0070 0,
0071 prob_electron
0072 ));
0073
0074
0075 for (auto in_assoc : *in_assocs) {
0076 if (in_assoc.getRec() == in_cluster) {
0077 auto out_assoc = in_assoc.clone();
0078 out_assoc.setRec(out_cluster);
0079 out_assocs->push_back(out_assoc);
0080 }
0081 }
0082 }
0083 }
0084
0085 }
0086 #endif