File indexing completed on 2025-12-16 09:24:28
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsPlugins/Gnn/TruthGraphMetricsHook.hpp"
0010
0011 #include <algorithm>
0012
0013 using namespace Acts;
0014
0015 namespace {
0016
0017 auto cantorize(std::vector<std::int64_t> edgeIndex, const Logger& logger) {
0018
0019
0020 std::vector<ActsPlugins::detail::CantorEdge<std::int64_t>> cantorEdgeIndex;
0021 cantorEdgeIndex.reserve(edgeIndex.size() / 2);
0022
0023 for (auto it = edgeIndex.begin(); it != edgeIndex.end(); it += 2) {
0024 cantorEdgeIndex.emplace_back(*it, *std::next(it));
0025 }
0026
0027 std::ranges::sort(cantorEdgeIndex,
0028 std::less<ActsPlugins::detail::CantorEdge<std::int64_t>>{});
0029
0030 auto new_end = std::unique(cantorEdgeIndex.begin(), cantorEdgeIndex.end());
0031 if (new_end != cantorEdgeIndex.end()) {
0032 ACTS_WARNING("Graph not unique ("
0033 << std::distance(new_end, cantorEdgeIndex.end())
0034 << " duplicates)");
0035 cantorEdgeIndex.erase(new_end, cantorEdgeIndex.end());
0036 }
0037
0038 return cantorEdgeIndex;
0039 }
0040
0041 }
0042
0043 ActsPlugins::TruthGraphMetricsHook::TruthGraphMetricsHook(
0044 const std::vector<std::int64_t>& truthGraph,
0045 std::unique_ptr<const Logger> l)
0046 : m_logger(std::move(l)) {
0047 m_truthGraphCantor = cantorize(truthGraph, logger());
0048 }
0049
0050 void ActsPlugins::TruthGraphMetricsHook::operator()(
0051 const PipelineTensors& tensors, const ExecutionContext& execCtx) const {
0052 auto edgeIndexTensor =
0053 tensors.edgeIndex.clone({Device::Cpu(), execCtx.stream});
0054
0055 const auto numEdges = edgeIndexTensor.shape().at(1);
0056 if (numEdges == 0) {
0057 ACTS_WARNING("no edges, cannot compute metrics");
0058 return;
0059 }
0060
0061
0062 std::vector<std::int64_t> edgeIndexTransposed;
0063 edgeIndexTransposed.reserve(edgeIndexTensor.size());
0064 for (auto i = 0ul; i < numEdges; ++i) {
0065 edgeIndexTransposed.push_back(*(edgeIndexTensor.data() + i));
0066 edgeIndexTransposed.push_back(*(edgeIndexTensor.data() + numEdges + i));
0067 }
0068
0069 auto predGraphCantor = cantorize(edgeIndexTransposed, logger());
0070
0071
0072 std::vector<ActsPlugins::detail::CantorEdge<std::int64_t>> intersection;
0073 intersection.reserve(
0074 std::max(predGraphCantor.size(), m_truthGraphCantor.size()));
0075
0076 std::set_intersection(predGraphCantor.begin(), predGraphCantor.end(),
0077 m_truthGraphCantor.begin(), m_truthGraphCantor.end(),
0078 std::back_inserter(intersection));
0079
0080 ACTS_DEBUG("Intersection size " << intersection.size());
0081 const float intersectionSizeFloat = intersection.size();
0082 const float eff = intersectionSizeFloat / m_truthGraphCantor.size();
0083 const float pur = intersectionSizeFloat / predGraphCantor.size();
0084
0085 ACTS_INFO("Efficiency=" << eff << ", purity=" << pur);
0086 }