Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 09:15:13

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
0010 
0011 #include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
0012 #include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"
0013 
0014 #include <algorithm>
0015 
0016 #include <torch/torch.h>
0017 
0018 using namespace torch::indexing;
0019 
0020 namespace {
0021 
0022 auto cantorize(std::vector<std::int64_t> edgeIndex,
0023                const Acts::Logger& logger) {
0024   // Use cantor pairing to store truth graph, so we can easily use set
0025   // operations to compute efficiency and purity
0026   std::vector<Acts::detail::CantorEdge<std::int64_t>> cantorEdgeIndex;
0027   cantorEdgeIndex.reserve(edgeIndex.size() / 2);
0028 
0029   for (auto it = edgeIndex.begin(); it != edgeIndex.end(); it += 2) {
0030     cantorEdgeIndex.emplace_back(*it, *std::next(it));
0031   }
0032 
0033   std::ranges::sort(cantorEdgeIndex,
0034                     std::less<Acts::detail::CantorEdge<std::int64_t>>{});
0035 
0036   auto new_end = std::unique(cantorEdgeIndex.begin(), cantorEdgeIndex.end());
0037   if (new_end != cantorEdgeIndex.end()) {
0038     ACTS_WARNING("Graph not unique ("
0039                  << std::distance(new_end, cantorEdgeIndex.end())
0040                  << " duplicates)");
0041     cantorEdgeIndex.erase(new_end, cantorEdgeIndex.end());
0042   }
0043 
0044   return cantorEdgeIndex;
0045 }
0046 
0047 }  // namespace
0048 
0049 Acts::TorchTruthGraphMetricsHook::TorchTruthGraphMetricsHook(
0050     const std::vector<std::int64_t>& truthGraph,
0051     std::unique_ptr<const Acts::Logger> l)
0052     : m_logger(std::move(l)) {
0053   m_truthGraphCantor = cantorize(truthGraph, logger());
0054 }
0055 
0056 void Acts::TorchTruthGraphMetricsHook::operator()(const std::any&,
0057                                                   const std::any& edges,
0058                                                   const std::any&) const {
0059   auto edgeIndexTensor =
0060       std::any_cast<torch::Tensor>(edges).to(torch::kCPU).contiguous();
0061   ACTS_VERBOSE("edge index tensor: " << detail::TensorDetails{edgeIndexTensor});
0062 
0063   const auto numEdges = edgeIndexTensor.size(1);
0064   if (numEdges == 0) {
0065     ACTS_WARNING("no edges, cannot compute metrics");
0066     return;
0067   }
0068   ACTS_VERBOSE("Edge index slice:\n"
0069                << edgeIndexTensor.index(
0070                       {Slice(0, 2), Slice(0, std::min(numEdges, 10l))}));
0071 
0072   // We need to transpose the edges here for the right memory layout
0073   const auto edgeIndex =
0074       Acts::detail::tensor2DToVector<std::int64_t>(edgeIndexTensor.t().clone());
0075 
0076   ACTS_VERBOSE("Edge vector:\n"
0077                << (detail::RangePrinter{
0078                       edgeIndex.begin(),
0079                       edgeIndex.begin() + std::min(numEdges, 10l)}));
0080 
0081   auto predGraphCantor = cantorize(edgeIndex, logger());
0082 
0083   // Calculate intersection
0084   std::vector<Acts::detail::CantorEdge<std::int64_t>> intersection;
0085   intersection.reserve(
0086       std::max(predGraphCantor.size(), m_truthGraphCantor.size()));
0087 
0088   std::set_intersection(predGraphCantor.begin(), predGraphCantor.end(),
0089                         m_truthGraphCantor.begin(), m_truthGraphCantor.end(),
0090                         std::back_inserter(intersection));
0091 
0092   ACTS_DEBUG("Intersection size " << intersection.size());
0093   const float intersectionSizeFloat = intersection.size();
0094   const float eff = intersectionSizeFloat / m_truthGraphCantor.size();
0095   const float pur = intersectionSizeFloat / predGraphCantor.size();
0096 
0097   ACTS_INFO("Efficiency=" << eff << ", purity=" << pur);
0098 }