File indexing completed on 2025-07-05 08:15:15
0001
0002
0003
0004 #include <edm4eic/EDM4eicVersion.h>
0005
0006 #if EDM4EIC_VERSION_MAJOR >= 8
0007 #include <cstddef>
0008 #include <fmt/core.h>
0009 #include <gsl/pointers>
0010 #include <stdexcept>
0011
0012 #include "CalorimeterParticleIDPostML.h"
0013
0014 namespace eicrecon {
0015
0016 void CalorimeterParticleIDPostML::init() {
0017
0018 }
0019
0020 void CalorimeterParticleIDPostML::process(const CalorimeterParticleIDPostML::Input& input,
0021 const CalorimeterParticleIDPostML::Output& output) const {
0022
0023 const auto [in_clusters, in_assocs, prediction_tensors] = input;
0024 auto [out_clusters, out_assocs, out_particle_ids] = output;
0025
0026 if (prediction_tensors->size() != 1) {
0027 error("Expected to find a single tensor, found {}", prediction_tensors->size());
0028 throw std::runtime_error("");
0029 }
0030 edm4eic::Tensor prediction_tensor = (*prediction_tensors)[0];
0031
0032 if (prediction_tensor.shape_size() != 2) {
0033 error("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size());
0034 throw std::runtime_error(
0035 fmt::format("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size()));
0036 }
0037
0038 if (prediction_tensor.getShape(0) != static_cast<long>(in_clusters->size())) {
0039 error("Length mismatch between tensor's 0th axis and number of clusters: {} != {}",
0040 prediction_tensor.getShape(0), in_clusters->size());
0041 throw std::runtime_error(
0042 fmt::format("Length mismatch between tensor's 0th axis and number of clusters: {} != {}",
0043 prediction_tensor.getShape(0), in_clusters->size()));
0044 }
0045
0046 if (prediction_tensor.getShape(1) != 2) {
0047 error("Expected 2 values per cluster in the output tensor, got {}",
0048 prediction_tensor.getShape(0));
0049 throw std::runtime_error(
0050 fmt::format("Expected 2 values per cluster in the output tensor, got {}",
0051 prediction_tensor.getShape(0)));
0052 }
0053
0054 if (prediction_tensor.getElementType() != 1) {
0055 error("Expected a tensor of floats, but element type is {}",
0056 prediction_tensor.getElementType());
0057 throw std::runtime_error(fmt::format("Expected a tensor of floats, but element type is {}",
0058 prediction_tensor.getElementType()));
0059 }
0060
0061 for (std::size_t cluster_ix = 0; cluster_ix < in_clusters->size(); cluster_ix++) {
0062 edm4eic::Cluster in_cluster = (*in_clusters)[cluster_ix];
0063 edm4eic::MutableCluster out_cluster = in_cluster.clone();
0064 out_clusters->push_back(out_cluster);
0065
0066 float prob_pion =
0067 prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 0);
0068 float prob_electron =
0069 prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 1);
0070
0071 out_cluster.addToParticleIDs(out_particle_ids->create(0,
0072 211,
0073 0,
0074 prob_pion
0075 ));
0076 out_cluster.addToParticleIDs(out_particle_ids->create(0,
0077 11,
0078 0,
0079 prob_electron
0080 ));
0081
0082
0083 for (auto in_assoc : *in_assocs) {
0084 if (in_assoc.getRec() == in_cluster) {
0085 auto out_assoc = in_assoc.clone();
0086 out_assoc.setRec(out_cluster);
0087 out_assocs->push_back(out_assoc);
0088 }
0089 }
0090 }
0091 }
0092
0093 }
0094 #endif