File indexing completed on 2025-01-30 09:15:13
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
0010
0011 #include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
0012 #include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"
0013
0014 #include <algorithm>
0015
0016 #include <torch/torch.h>
0017
0018 using namespace torch::indexing;
0019
0020 namespace {
0021
0022 auto cantorize(std::vector<std::int64_t> edgeIndex,
0023 const Acts::Logger& logger) {
0024
0025
0026 std::vector<Acts::detail::CantorEdge<std::int64_t>> cantorEdgeIndex;
0027 cantorEdgeIndex.reserve(edgeIndex.size() / 2);
0028
0029 for (auto it = edgeIndex.begin(); it != edgeIndex.end(); it += 2) {
0030 cantorEdgeIndex.emplace_back(*it, *std::next(it));
0031 }
0032
0033 std::ranges::sort(cantorEdgeIndex,
0034 std::less<Acts::detail::CantorEdge<std::int64_t>>{});
0035
0036 auto new_end = std::unique(cantorEdgeIndex.begin(), cantorEdgeIndex.end());
0037 if (new_end != cantorEdgeIndex.end()) {
0038 ACTS_WARNING("Graph not unique ("
0039 << std::distance(new_end, cantorEdgeIndex.end())
0040 << " duplicates)");
0041 cantorEdgeIndex.erase(new_end, cantorEdgeIndex.end());
0042 }
0043
0044 return cantorEdgeIndex;
0045 }
0046
0047 }
0048
0049 Acts::TorchTruthGraphMetricsHook::TorchTruthGraphMetricsHook(
0050 const std::vector<std::int64_t>& truthGraph,
0051 std::unique_ptr<const Acts::Logger> l)
0052 : m_logger(std::move(l)) {
0053 m_truthGraphCantor = cantorize(truthGraph, logger());
0054 }
0055
0056 void Acts::TorchTruthGraphMetricsHook::operator()(const std::any&,
0057 const std::any& edges,
0058 const std::any&) const {
0059 auto edgeIndexTensor =
0060 std::any_cast<torch::Tensor>(edges).to(torch::kCPU).contiguous();
0061 ACTS_VERBOSE("edge index tensor: " << detail::TensorDetails{edgeIndexTensor});
0062
0063 const auto numEdges = edgeIndexTensor.size(1);
0064 if (numEdges == 0) {
0065 ACTS_WARNING("no edges, cannot compute metrics");
0066 return;
0067 }
0068 ACTS_VERBOSE("Edge index slice:\n"
0069 << edgeIndexTensor.index(
0070 {Slice(0, 2), Slice(0, std::min(numEdges, 10l))}));
0071
0072
0073 const auto edgeIndex =
0074 Acts::detail::tensor2DToVector<std::int64_t>(edgeIndexTensor.t().clone());
0075
0076 ACTS_VERBOSE("Edge vector:\n"
0077 << (detail::RangePrinter{
0078 edgeIndex.begin(),
0079 edgeIndex.begin() + std::min(numEdges, 10l)}));
0080
0081 auto predGraphCantor = cantorize(edgeIndex, logger());
0082
0083
0084 std::vector<Acts::detail::CantorEdge<std::int64_t>> intersection;
0085 intersection.reserve(
0086 std::max(predGraphCantor.size(), m_truthGraphCantor.size()));
0087
0088 std::set_intersection(predGraphCantor.begin(), predGraphCantor.end(),
0089 m_truthGraphCantor.begin(), m_truthGraphCantor.end(),
0090 std::back_inserter(intersection));
0091
0092 ACTS_DEBUG("Intersection size " << intersection.size());
0093 const float intersectionSizeFloat = intersection.size();
0094 const float eff = intersectionSizeFloat / m_truthGraphCantor.size();
0095 const float pur = intersectionSizeFloat / predGraphCantor.size();
0096
0097 ACTS_INFO("Efficiency=" << eff << ", purity=" << pur);
0098 }