Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-03 07:55:53

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2024 Dmitry Kalinkin
0003 
0004 #include <edm4eic/EDM4eicVersion.h>
0005 
0006 #if EDM4EIC_VERSION_MAJOR >= 8
0007 #include <cstddef>
0008 #include <cstdint>
0009 #include <edm4hep/MCParticle.h>
0010 #include <edm4hep/Vector3f.h>
0011 #include <edm4hep/utils/vector_utils.h>
0012 #include <fmt/core.h>
0013 #include <cmath>
0014 #include <stdexcept>
0015 
0016 #include <gsl/pointers>
0017 
0018 #include "CalorimeterParticleIDPreML.h"
0019 
0020 namespace eicrecon {
0021 
0022 void CalorimeterParticleIDPreML::init() {
0023   // Nothing
0024 }
0025 
0026 void CalorimeterParticleIDPreML::process(const CalorimeterParticleIDPreML::Input& input,
0027                                          const CalorimeterParticleIDPreML::Output& output) const {
0028 
0029   const auto [clusters, cluster_assocs]  = input;
0030   auto [feature_tensors, target_tensors] = output;
0031 
0032   edm4eic::MutableTensor feature_tensor = feature_tensors->create();
0033   feature_tensor.addToShape(clusters->size());
0034   feature_tensor.addToShape(11);    // p, E/p, azimuthal, polar, 7 shape parameters
0035   feature_tensor.setElementType(1); // 1 - float
0036 
0037   edm4eic::MutableTensor target_tensor;
0038   if (cluster_assocs != nullptr) {
0039     target_tensor = target_tensors->create();
0040     target_tensor.addToShape(clusters->size());
0041     target_tensor.addToShape(2);     // is electron, is hadron
0042     target_tensor.setElementType(7); // 7 - int64
0043   }
0044 
0045   for (edm4eic::Cluster cluster : *clusters) {
0046     double momentum = NAN;
0047     {
0048       // FIXME: use track momentum once matching to tracks becomes available
0049       edm4eic::MCRecoClusterParticleAssociation best_assoc;
0050       for (auto assoc : *cluster_assocs) {
0051         if (assoc.getRec() == cluster) {
0052           if ((not best_assoc.isAvailable()) || (assoc.getWeight() > best_assoc.getWeight())) {
0053             best_assoc = assoc;
0054           }
0055         }
0056       }
0057       if (best_assoc.isAvailable()) {
0058         momentum = edm4hep::utils::magnitude(best_assoc.getSim().getMomentum());
0059       } else {
0060         warning("Can't find association for cluster. Skipping...");
0061         continue;
0062       }
0063     }
0064 
0065     feature_tensor.addToFloatData(momentum);
0066     feature_tensor.addToFloatData(cluster.getEnergy() / momentum);
0067     auto pos = cluster.getPosition();
0068     feature_tensor.addToFloatData(edm4hep::utils::anglePolar(pos));
0069     feature_tensor.addToFloatData(edm4hep::utils::angleAzimuthal(pos));
0070     for (std::size_t par_ix = 0; par_ix < cluster.shapeParameters_size(); par_ix++) {
0071       feature_tensor.addToFloatData(cluster.getShapeParameters(par_ix));
0072     }
0073 
0074     if (cluster_assocs != nullptr) {
0075       edm4eic::MCRecoClusterParticleAssociation best_assoc;
0076       for (auto assoc : *cluster_assocs) {
0077         if (assoc.getRec() == cluster) {
0078           if ((not best_assoc.isAvailable()) || (assoc.getWeight() > best_assoc.getWeight())) {
0079             best_assoc = assoc;
0080           }
0081         }
0082       }
0083       int64_t is_electron = 0;
0084       int64_t is_pion     = 0;
0085       if (best_assoc.isAvailable()) {
0086         is_electron = static_cast<int64_t>(best_assoc.getSim().getPDG() == 11);
0087         is_pion     = static_cast<int64_t>(best_assoc.getSim().getPDG() != 11);
0088       }
0089       target_tensor.addToInt64Data(is_pion);
0090       target_tensor.addToInt64Data(is_electron);
0091     }
0092   }
0093 
0094   std::size_t expected_num_entries = feature_tensor.getShape(0) * feature_tensor.getShape(1);
0095   if (feature_tensor.floatData_size() != expected_num_entries) {
0096     error("Inconsistent output tensor shape and element count: {} != {}",
0097           feature_tensor.floatData_size(), expected_num_entries);
0098     throw std::runtime_error(
0099         fmt::format("Inconsistent output tensor shape and element count: {} != {}",
0100                     feature_tensor.floatData_size(), expected_num_entries));
0101   }
0102 }
0103 
0104 } // namespace eicrecon
0105 #endif