Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-05 08:15:15

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2024 Dmitry Kalinkin
0003 
0004 #include <edm4eic/EDM4eicVersion.h>
0005 
0006 #if EDM4EIC_VERSION_MAJOR >= 8
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>

#include <fmt/core.h>
#include <gsl/pointers>
0011 
0012 #include "CalorimeterParticleIDPostML.h"
0013 
0014 namespace eicrecon {
0015 
0016 void CalorimeterParticleIDPostML::init() {
0017   // Nothing
0018 }
0019 
0020 void CalorimeterParticleIDPostML::process(const CalorimeterParticleIDPostML::Input& input,
0021                                           const CalorimeterParticleIDPostML::Output& output) const {
0022 
0023   const auto [in_clusters, in_assocs, prediction_tensors] = input;
0024   auto [out_clusters, out_assocs, out_particle_ids]       = output;
0025 
0026   if (prediction_tensors->size() != 1) {
0027     error("Expected to find a single tensor, found {}", prediction_tensors->size());
0028     throw std::runtime_error("");
0029   }
0030   edm4eic::Tensor prediction_tensor = (*prediction_tensors)[0];
0031 
0032   if (prediction_tensor.shape_size() != 2) {
0033     error("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size());
0034     throw std::runtime_error(
0035         fmt::format("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size()));
0036   }
0037 
0038   if (prediction_tensor.getShape(0) != static_cast<long>(in_clusters->size())) {
0039     error("Length mismatch between tensor's 0th axis and number of clusters: {} != {}",
0040           prediction_tensor.getShape(0), in_clusters->size());
0041     throw std::runtime_error(
0042         fmt::format("Length mismatch between tensor's 0th axis and number of clusters: {} != {}",
0043                     prediction_tensor.getShape(0), in_clusters->size()));
0044   }
0045 
0046   if (prediction_tensor.getShape(1) != 2) {
0047     error("Expected 2 values per cluster in the output tensor, got {}",
0048           prediction_tensor.getShape(0));
0049     throw std::runtime_error(
0050         fmt::format("Expected 2 values per cluster in the output tensor, got {}",
0051                     prediction_tensor.getShape(0)));
0052   }
0053 
0054   if (prediction_tensor.getElementType() != 1) { // 1 - float
0055     error("Expected a tensor of floats, but element type is {}",
0056           prediction_tensor.getElementType());
0057     throw std::runtime_error(fmt::format("Expected a tensor of floats, but element type is {}",
0058                                          prediction_tensor.getElementType()));
0059   }
0060 
0061   for (std::size_t cluster_ix = 0; cluster_ix < in_clusters->size(); cluster_ix++) {
0062     edm4eic::Cluster in_cluster         = (*in_clusters)[cluster_ix];
0063     edm4eic::MutableCluster out_cluster = in_cluster.clone();
0064     out_clusters->push_back(out_cluster);
0065 
0066     float prob_pion =
0067         prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 0);
0068     float prob_electron =
0069         prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 1);
0070 
0071     out_cluster.addToParticleIDs(out_particle_ids->create(0,        // std::int32_t type
0072                                                           211,      // std::int32_t PDG
0073                                                           0,        // std::int32_t algorithmType
0074                                                           prob_pion // float likelihood
0075                                                           ));
0076     out_cluster.addToParticleIDs(out_particle_ids->create(0,  // std::int32_t type
0077                                                           11, // std::int32_t PDG
0078                                                           0,  // std::int32_t algorithmType
0079                                                           prob_electron // float likelihood
0080                                                           ));
0081 
0082     // propagate associations
0083     for (auto in_assoc : *in_assocs) {
0084       if (in_assoc.getRec() == in_cluster) {
0085         auto out_assoc = in_assoc.clone();
0086         out_assoc.setRec(out_cluster);
0087         out_assocs->push_back(out_assoc);
0088       }
0089     }
0090   }
0091 }
0092 
0093 } // namespace eicrecon
0094 #endif