Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-10-22 07:55:36

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2024 Dmitry Kalinkin
0003 
0004 #include <cstddef>
0005 #include <cstdint>
0006 #include <edm4hep/MCParticle.h>
0007 #include <edm4hep/Vector3f.h>
0008 #include <edm4hep/utils/vector_utils.h>
0009 #include <fmt/core.h>
0010 #include <cmath>
0011 #include <stdexcept>
0012 
0013 #include <gsl/pointers>
0014 
0015 #include "CalorimeterParticleIDPreML.h"
0016 
0017 namespace eicrecon {
0018 
0019 void CalorimeterParticleIDPreML::init() {
0020   // Nothing
0021 }
0022 
0023 void CalorimeterParticleIDPreML::process(const CalorimeterParticleIDPreML::Input& input,
0024                                          const CalorimeterParticleIDPreML::Output& output) const {
0025 
0026   const auto [clusters, cluster_assocs]  = input;
0027   auto [feature_tensors, target_tensors] = output;
0028 
0029   edm4eic::MutableTensor feature_tensor = feature_tensors->create();
0030   feature_tensor.addToShape(clusters->size());
0031   feature_tensor.addToShape(11);    // p, E/p, azimuthal, polar, 7 shape parameters
0032   feature_tensor.setElementType(1); // 1 - float
0033 
0034   edm4eic::MutableTensor target_tensor;
0035   if (cluster_assocs != nullptr) {
0036     target_tensor = target_tensors->create();
0037     target_tensor.addToShape(clusters->size());
0038     target_tensor.addToShape(2);     // is electron, is hadron
0039     target_tensor.setElementType(7); // 7 - int64
0040   }
0041 
0042   for (edm4eic::Cluster cluster : *clusters) {
0043     double momentum = NAN;
0044     {
0045       // FIXME: use track momentum once matching to tracks becomes available
0046       edm4eic::MCRecoClusterParticleAssociation best_assoc;
0047       for (auto assoc : *cluster_assocs) {
0048         if (assoc.getRec() == cluster) {
0049           if ((not best_assoc.isAvailable()) || (assoc.getWeight() > best_assoc.getWeight())) {
0050             best_assoc = assoc;
0051           }
0052         }
0053       }
0054       if (best_assoc.isAvailable()) {
0055         momentum = edm4hep::utils::magnitude(best_assoc.getSim().getMomentum());
0056       } else {
0057         warning("Can't find association for cluster. Skipping...");
0058         continue;
0059       }
0060     }
0061 
0062     feature_tensor.addToFloatData(momentum);
0063     feature_tensor.addToFloatData(cluster.getEnergy() / momentum);
0064     auto pos = cluster.getPosition();
0065     feature_tensor.addToFloatData(edm4hep::utils::anglePolar(pos));
0066     feature_tensor.addToFloatData(edm4hep::utils::angleAzimuthal(pos));
0067     for (std::size_t par_ix = 0; par_ix < cluster.shapeParameters_size(); par_ix++) {
0068       feature_tensor.addToFloatData(cluster.getShapeParameters(par_ix));
0069     }
0070 
0071     if (cluster_assocs != nullptr) {
0072       edm4eic::MCRecoClusterParticleAssociation best_assoc;
0073       for (auto assoc : *cluster_assocs) {
0074         if (assoc.getRec() == cluster) {
0075           if ((not best_assoc.isAvailable()) || (assoc.getWeight() > best_assoc.getWeight())) {
0076             best_assoc = assoc;
0077           }
0078         }
0079       }
0080       int64_t is_electron = 0;
0081       int64_t is_pion     = 0;
0082       if (best_assoc.isAvailable()) {
0083         is_electron = static_cast<int64_t>(best_assoc.getSim().getPDG() == 11);
0084         is_pion     = static_cast<int64_t>(best_assoc.getSim().getPDG() != 11);
0085       }
0086       target_tensor.addToInt64Data(is_pion);
0087       target_tensor.addToInt64Data(is_electron);
0088     }
0089   }
0090 
0091   std::size_t expected_num_entries = feature_tensor.getShape(0) * feature_tensor.getShape(1);
0092   if (feature_tensor.floatData_size() != expected_num_entries) {
0093     error("Inconsistent output tensor shape and element count: {} != {}",
0094           feature_tensor.floatData_size(), expected_num_entries);
0095     throw std::runtime_error(
0096         fmt::format("Inconsistent output tensor shape and element count: {} != {}",
0097                     feature_tensor.floatData_size(), expected_num_entries));
0098   }
0099 }
0100 
0101 } // namespace eicrecon