File indexing completed on 2026-05-03 07:42:40
0001
0002
0003
0004 #include <edm4eic/EDM4eicVersion.h>
0005 #include <edm4hep/MCParticle.h>
0006 #include <fmt/format.h>
0007 #include <podio/detail/Link.h>
0008 #include <podio/detail/LinkCollectionImpl.h>
0009 #include <cstddef>
0010 #include <gsl/pointers>
0011 #include <memory>
0012 #include <stdexcept>
0013
0014 #include "CalorimeterParticleIDPostML.h"
0015
0016 namespace eicrecon {
0017
0018 void CalorimeterParticleIDPostML::init() {
0019
0020 }
0021
0022 void CalorimeterParticleIDPostML::process(const CalorimeterParticleIDPostML::Input& input,
0023 const CalorimeterParticleIDPostML::Output& output) const {
0024
0025 const auto [in_clusters, in_assocs, prediction_tensors] = input;
0026 #if EDM4EIC_BUILD_VERSION >= EDM4EIC_VERSION(8, 7, 0)
0027 auto [out_clusters, out_links, out_assocs, out_particle_ids] = output;
0028 #else
0029 auto [out_clusters, out_assocs, out_particle_ids] = output;
0030 #endif
0031
0032 if (prediction_tensors->size() != 1) {
0033 error("Expected to find a single tensor, found {}", prediction_tensors->size());
0034 throw std::runtime_error("");
0035 }
0036 edm4eic::Tensor prediction_tensor = (*prediction_tensors)[0];
0037
0038 if (prediction_tensor.shape_size() != 2) {
0039 error("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size());
0040 throw std::runtime_error(
0041 fmt::format("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size()));
0042 }
0043
0044 if (prediction_tensor.getShape(0) != static_cast<long>(in_clusters->size())) {
0045 error("Length mismatch between tensor's 0th axis and number of clusters: {} != {}",
0046 prediction_tensor.getShape(0), in_clusters->size());
0047 throw std::runtime_error(
0048 fmt::format("Length mismatch between tensor's 0th axis and number of clusters: {} != {}",
0049 prediction_tensor.getShape(0), in_clusters->size()));
0050 }
0051
0052 if (prediction_tensor.getShape(1) != 2) {
0053 error("Expected 2 values per cluster in the output tensor, got {}",
0054 prediction_tensor.getShape(0));
0055 throw std::runtime_error(
0056 fmt::format("Expected 2 values per cluster in the output tensor, got {}",
0057 prediction_tensor.getShape(0)));
0058 }
0059
0060 if (prediction_tensor.getElementType() != 1) {
0061 error("Expected a tensor of floats, but element type is {}",
0062 prediction_tensor.getElementType());
0063 throw std::runtime_error(fmt::format("Expected a tensor of floats, but element type is {}",
0064 prediction_tensor.getElementType()));
0065 }
0066
0067 for (std::size_t cluster_ix = 0; cluster_ix < in_clusters->size(); cluster_ix++) {
0068 edm4eic::Cluster in_cluster = (*in_clusters)[cluster_ix];
0069 edm4eic::MutableCluster out_cluster = in_cluster.clone();
0070 out_clusters->push_back(out_cluster);
0071
0072 float prob_pion =
0073 prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 0);
0074 float prob_electron =
0075 prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 1);
0076
0077 out_cluster.addToParticleIDs(out_particle_ids->create(0,
0078 211,
0079 0,
0080 prob_pion
0081 ));
0082 out_cluster.addToParticleIDs(out_particle_ids->create(0,
0083 11,
0084 0,
0085 prob_electron
0086 ));
0087
0088
0089 for (auto in_assoc : *in_assocs) {
0090 if (in_assoc.getRec() == in_cluster) {
0091 #if EDM4EIC_BUILD_VERSION >= EDM4EIC_VERSION(8, 7, 0)
0092 auto out_link = out_links->create();
0093 out_link.setFrom(out_cluster);
0094 out_link.setTo(in_assoc.getSim());
0095 out_link.setWeight(in_assoc.getWeight());
0096 #endif
0097 auto out_assoc = in_assoc.clone();
0098 out_assoc.setRec(out_cluster);
0099 out_assocs->push_back(out_assoc);
0100 }
0101 }
0102 }
0103 }
0104
0105 }