Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-03 07:42:40

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2024 Dmitry Kalinkin
0003 
0004 #include <edm4eic/EDM4eicVersion.h>
0005 #include <edm4hep/MCParticle.h>
0006 #include <fmt/format.h>
0007 #include <podio/detail/Link.h>
0008 #include <podio/detail/LinkCollectionImpl.h>
0009 #include <cstddef>
0010 #include <gsl/pointers>
0011 #include <memory>
0012 #include <stdexcept>
0013 
0014 #include "CalorimeterParticleIDPostML.h"
0015 
0016 namespace eicrecon {
0017 
0018 void CalorimeterParticleIDPostML::init() {
0019   // Nothing
0020 }
0021 
0022 void CalorimeterParticleIDPostML::process(const CalorimeterParticleIDPostML::Input& input,
0023                                           const CalorimeterParticleIDPostML::Output& output) const {
0024 
0025   const auto [in_clusters, in_assocs, prediction_tensors] = input;
0026 #if EDM4EIC_BUILD_VERSION >= EDM4EIC_VERSION(8, 7, 0)
0027   auto [out_clusters, out_links, out_assocs, out_particle_ids] = output;
0028 #else
0029   auto [out_clusters, out_assocs, out_particle_ids] = output;
0030 #endif
0031 
0032   if (prediction_tensors->size() != 1) {
0033     error("Expected to find a single tensor, found {}", prediction_tensors->size());
0034     throw std::runtime_error("");
0035   }
0036   edm4eic::Tensor prediction_tensor = (*prediction_tensors)[0];
0037 
0038   if (prediction_tensor.shape_size() != 2) {
0039     error("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size());
0040     throw std::runtime_error(
0041         fmt::format("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size()));
0042   }
0043 
0044   if (prediction_tensor.getShape(0) != static_cast<long>(in_clusters->size())) {
0045     error("Length mismatch between tensor's 0th axis and number of clusters: {} != {}",
0046           prediction_tensor.getShape(0), in_clusters->size());
0047     throw std::runtime_error(
0048         fmt::format("Length mismatch between tensor's 0th axis and number of clusters: {} != {}",
0049                     prediction_tensor.getShape(0), in_clusters->size()));
0050   }
0051 
0052   if (prediction_tensor.getShape(1) != 2) {
0053     error("Expected 2 values per cluster in the output tensor, got {}",
0054           prediction_tensor.getShape(0));
0055     throw std::runtime_error(
0056         fmt::format("Expected 2 values per cluster in the output tensor, got {}",
0057                     prediction_tensor.getShape(0)));
0058   }
0059 
0060   if (prediction_tensor.getElementType() != 1) { // 1 - float
0061     error("Expected a tensor of floats, but element type is {}",
0062           prediction_tensor.getElementType());
0063     throw std::runtime_error(fmt::format("Expected a tensor of floats, but element type is {}",
0064                                          prediction_tensor.getElementType()));
0065   }
0066 
0067   for (std::size_t cluster_ix = 0; cluster_ix < in_clusters->size(); cluster_ix++) {
0068     edm4eic::Cluster in_cluster         = (*in_clusters)[cluster_ix];
0069     edm4eic::MutableCluster out_cluster = in_cluster.clone();
0070     out_clusters->push_back(out_cluster);
0071 
0072     float prob_pion =
0073         prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 0);
0074     float prob_electron =
0075         prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 1);
0076 
0077     out_cluster.addToParticleIDs(out_particle_ids->create(0,        // std::int32_t type
0078                                                           211,      // std::int32_t PDG
0079                                                           0,        // std::int32_t algorithmType
0080                                                           prob_pion // float likelihood
0081                                                           ));
0082     out_cluster.addToParticleIDs(out_particle_ids->create(0,  // std::int32_t type
0083                                                           11, // std::int32_t PDG
0084                                                           0,  // std::int32_t algorithmType
0085                                                           prob_electron // float likelihood
0086                                                           ));
0087 
0088     // propagate associations
0089     for (auto in_assoc : *in_assocs) {
0090       if (in_assoc.getRec() == in_cluster) {
0091 #if EDM4EIC_BUILD_VERSION >= EDM4EIC_VERSION(8, 7, 0)
0092         auto out_link = out_links->create();
0093         out_link.setFrom(out_cluster);
0094         out_link.setTo(in_assoc.getSim());
0095         out_link.setWeight(in_assoc.getWeight());
0096 #endif
0097         auto out_assoc = in_assoc.clone();
0098         out_assoc.setRec(out_cluster);
0099         out_assocs->push_back(out_assoc);
0100       }
0101     }
0102   }
0103 }
0104 
0105 } // namespace eicrecon