Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:15:30

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2024 Dmitry Kalinkin
0003 
0004 #include <edm4eic/EDM4eicVersion.h>
0005 
0006 #if EDM4EIC_VERSION_MAJOR >= 8
#include <cstddef>
#include <fmt/core.h>
#include <gsl/pointers>
#include <stdexcept>
#include <string>
0011 
0012 #include "CalorimeterParticleIDPostML.h"
0013 
0014 namespace eicrecon {
0015 
0016   void CalorimeterParticleIDPostML::init() {
0017     // Nothing
0018   }
0019 
0020   void CalorimeterParticleIDPostML::process(
0021       const CalorimeterParticleIDPostML::Input& input,
0022       const CalorimeterParticleIDPostML::Output& output) const {
0023 
0024     const auto [in_clusters, in_assocs, prediction_tensors] = input;
0025     auto [out_clusters, out_assocs, out_particle_ids] = output;
0026 
0027     if (prediction_tensors->size() != 1) {
0028       error("Expected to find a single tensor, found {}", prediction_tensors->size());
0029       throw std::runtime_error("");
0030     }
0031     edm4eic::Tensor prediction_tensor = (*prediction_tensors)[0];
0032 
0033     if (prediction_tensor.shape_size() != 2) {
0034       error("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size());
0035       throw std::runtime_error(fmt::format("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size()));
0036     }
0037 
0038     if (prediction_tensor.getShape(0) != in_clusters->size()) {
0039       error("Length mismatch between tensor's 0th axis and number of clusters: {} != {}", prediction_tensor.getShape(0), in_clusters->size());
0040       throw std::runtime_error(fmt::format("Length mismatch between tensor's 0th axis and number of clusters: {} != {}", prediction_tensor.getShape(0), in_clusters->size()));
0041     }
0042 
0043     if (prediction_tensor.getShape(1) != 2) {
0044       error("Expected 2 values per cluster in the output tensor, got {}", prediction_tensor.getShape(0));
0045       throw std::runtime_error(fmt::format("Expected 2 values per cluster in the output tensor, got {}", prediction_tensor.getShape(0)));
0046     }
0047 
0048     if (prediction_tensor.getElementType() != 1) { // 1 - float
0049       error("Expected a tensor of floats, but element type is {}", prediction_tensor.getElementType());
0050       throw std::runtime_error(fmt::format("Expected a tensor of floats, but element type is {}", prediction_tensor.getElementType()));
0051     }
0052 
0053     for (size_t cluster_ix = 0; cluster_ix < in_clusters->size(); cluster_ix++) {
0054       edm4eic::Cluster in_cluster = (*in_clusters)[cluster_ix];
0055       edm4eic::MutableCluster out_cluster = in_cluster.clone();
0056       out_clusters->push_back(out_cluster);
0057 
0058       float prob_pion = prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 0);
0059       float prob_electron = prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 1);
0060 
0061       out_cluster.addToParticleIDs(out_particle_ids->create(
0062         0,            // std::int32_t type
0063         211,          // std::int32_t PDG
0064         0,            // std::int32_t algorithmType
0065         prob_pion     // float likelihood
0066       ));
0067       out_cluster.addToParticleIDs(out_particle_ids->create(
0068         0,            // std::int32_t type
0069         11,           // std::int32_t PDG
0070         0,            // std::int32_t algorithmType
0071         prob_electron // float likelihood
0072       ));
0073 
0074       // propagate associations
0075       for (auto in_assoc : *in_assocs) {
0076         if (in_assoc.getRec() == in_cluster) {
0077           auto out_assoc = in_assoc.clone();
0078           out_assoc.setRec(out_cluster);
0079           out_assocs->push_back(out_assoc);
0080         }
0081       }
0082     }
0083   }
0084 
0085 } // namespace eicrecon
0086 #endif