File indexing completed on 2025-12-03 09:19:22
0001
0002
0003
#include <cstddef>
#include <fmt/core.h>
#include <gsl/pointers>
#include <stdexcept>
#include <string>

#include "CalorimeterParticleIDPostML.h"
0010
0011 namespace eicrecon {
0012
0013 void CalorimeterParticleIDPostML::init() {
0014
0015 }
0016
0017 void CalorimeterParticleIDPostML::process(const CalorimeterParticleIDPostML::Input& input,
0018 const CalorimeterParticleIDPostML::Output& output) const {
0019
0020 const auto [in_clusters, in_assocs, prediction_tensors] = input;
0021 auto [out_clusters, out_assocs, out_particle_ids] = output;
0022
0023 if (prediction_tensors->size() != 1) {
0024 error("Expected to find a single tensor, found {}", prediction_tensors->size());
0025 throw std::runtime_error("");
0026 }
0027 edm4eic::Tensor prediction_tensor = (*prediction_tensors)[0];
0028
0029 if (prediction_tensor.shape_size() != 2) {
0030 error("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size());
0031 throw std::runtime_error(
0032 fmt::format("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size()));
0033 }
0034
0035 if (prediction_tensor.getShape(0) != static_cast<long>(in_clusters->size())) {
0036 error("Length mismatch between tensor's 0th axis and number of clusters: {} != {}",
0037 prediction_tensor.getShape(0), in_clusters->size());
0038 throw std::runtime_error(
0039 fmt::format("Length mismatch between tensor's 0th axis and number of clusters: {} != {}",
0040 prediction_tensor.getShape(0), in_clusters->size()));
0041 }
0042
0043 if (prediction_tensor.getShape(1) != 2) {
0044 error("Expected 2 values per cluster in the output tensor, got {}",
0045 prediction_tensor.getShape(0));
0046 throw std::runtime_error(
0047 fmt::format("Expected 2 values per cluster in the output tensor, got {}",
0048 prediction_tensor.getShape(0)));
0049 }
0050
0051 if (prediction_tensor.getElementType() != 1) {
0052 error("Expected a tensor of floats, but element type is {}",
0053 prediction_tensor.getElementType());
0054 throw std::runtime_error(fmt::format("Expected a tensor of floats, but element type is {}",
0055 prediction_tensor.getElementType()));
0056 }
0057
0058 for (std::size_t cluster_ix = 0; cluster_ix < in_clusters->size(); cluster_ix++) {
0059 edm4eic::Cluster in_cluster = (*in_clusters)[cluster_ix];
0060 edm4eic::MutableCluster out_cluster = in_cluster.clone();
0061 out_clusters->push_back(out_cluster);
0062
0063 float prob_pion =
0064 prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 0);
0065 float prob_electron =
0066 prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 1);
0067
0068 out_cluster.addToParticleIDs(out_particle_ids->create(0,
0069 211,
0070 0,
0071 prob_pion
0072 ));
0073 out_cluster.addToParticleIDs(out_particle_ids->create(0,
0074 11,
0075 0,
0076 prob_electron
0077 ));
0078
0079
0080 for (auto in_assoc : *in_assocs) {
0081 if (in_assoc.getRec() == in_cluster) {
0082 auto out_assoc = in_assoc.clone();
0083 out_assoc.setRec(out_cluster);
0084 out_assocs->push_back(out_assoc);
0085 }
0086 }
0087 }
0088 }
0089
0090 }