Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-11-03 09:01:48

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2024 Dmitry Kalinkin
0003 
0004 #include <cstddef>
0005 #include <fmt/core.h>
0006 #include <gsl/pointers>
0007 #include <stdexcept>
0008 
0009 #include "CalorimeterParticleIDPostML.h"
0010 
0011 namespace eicrecon {
0012 
0013 void CalorimeterParticleIDPostML::init() {
0014   // Nothing
0015 }
0016 
0017 void CalorimeterParticleIDPostML::process(const CalorimeterParticleIDPostML::Input& input,
0018                                           const CalorimeterParticleIDPostML::Output& output) const {
0019 
0020   const auto [in_clusters, in_assocs, prediction_tensors] = input;
0021   auto [out_clusters, out_assocs, out_particle_ids]       = output;
0022 
0023   if (prediction_tensors->size() != 1) {
0024     error("Expected to find a single tensor, found {}", prediction_tensors->size());
0025     throw std::runtime_error("");
0026   }
0027   edm4eic::Tensor prediction_tensor = (*prediction_tensors)[0];
0028 
0029   if (prediction_tensor.shape_size() != 2) {
0030     error("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size());
0031     throw std::runtime_error(
0032         fmt::format("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size()));
0033   }
0034 
0035   if (prediction_tensor.getShape(0) != static_cast<long>(in_clusters->size())) {
0036     error("Length mismatch between tensor's 0th axis and number of clusters: {} != {}",
0037           prediction_tensor.getShape(0), in_clusters->size());
0038     throw std::runtime_error(
0039         fmt::format("Length mismatch between tensor's 0th axis and number of clusters: {} != {}",
0040                     prediction_tensor.getShape(0), in_clusters->size()));
0041   }
0042 
0043   if (prediction_tensor.getShape(1) != 2) {
0044     error("Expected 2 values per cluster in the output tensor, got {}",
0045           prediction_tensor.getShape(0));
0046     throw std::runtime_error(
0047         fmt::format("Expected 2 values per cluster in the output tensor, got {}",
0048                     prediction_tensor.getShape(0)));
0049   }
0050 
0051   if (prediction_tensor.getElementType() != 1) { // 1 - float
0052     error("Expected a tensor of floats, but element type is {}",
0053           prediction_tensor.getElementType());
0054     throw std::runtime_error(fmt::format("Expected a tensor of floats, but element type is {}",
0055                                          prediction_tensor.getElementType()));
0056   }
0057 
0058   for (std::size_t cluster_ix = 0; cluster_ix < in_clusters->size(); cluster_ix++) {
0059     edm4eic::Cluster in_cluster         = (*in_clusters)[cluster_ix];
0060     edm4eic::MutableCluster out_cluster = in_cluster.clone();
0061     out_clusters->push_back(out_cluster);
0062 
0063     float prob_pion =
0064         prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 0);
0065     float prob_electron =
0066         prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 1);
0067 
0068     out_cluster.addToParticleIDs(out_particle_ids->create(0,        // std::int32_t type
0069                                                           211,      // std::int32_t PDG
0070                                                           0,        // std::int32_t algorithmType
0071                                                           prob_pion // float likelihood
0072                                                           ));
0073     out_cluster.addToParticleIDs(out_particle_ids->create(0,  // std::int32_t type
0074                                                           11, // std::int32_t PDG
0075                                                           0,  // std::int32_t algorithmType
0076                                                           prob_electron // float likelihood
0077                                                           ));
0078 
0079     // propagate associations
0080     for (auto in_assoc : *in_assocs) {
0081       if (in_assoc.getRec() == in_cluster) {
0082         auto out_assoc = in_assoc.clone();
0083         out_assoc.setRec(out_cluster);
0084         out_assocs->push_back(out_assoc);
0085       }
0086     }
0087   }
0088 }
0089 
0090 } // namespace eicrecon