File indexing completed on 2025-07-02 07:51:55
0001
0002
0003
0004
0005
0006
0007
0008
#include "Acts/Plugins/ExaTrkX/TruthGraphMetricsHook.hpp"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <utility>
#include <vector>
0012
0013 namespace {
0014
0015 auto cantorize(std::vector<std::int64_t> edgeIndex,
0016 const Acts::Logger& logger) {
0017
0018
0019 std::vector<Acts::detail::CantorEdge<std::int64_t>> cantorEdgeIndex;
0020 cantorEdgeIndex.reserve(edgeIndex.size() / 2);
0021
0022 for (auto it = edgeIndex.begin(); it != edgeIndex.end(); it += 2) {
0023 cantorEdgeIndex.emplace_back(*it, *std::next(it));
0024 }
0025
0026 std::ranges::sort(cantorEdgeIndex,
0027 std::less<Acts::detail::CantorEdge<std::int64_t>>{});
0028
0029 auto new_end = std::unique(cantorEdgeIndex.begin(), cantorEdgeIndex.end());
0030 if (new_end != cantorEdgeIndex.end()) {
0031 ACTS_WARNING("Graph not unique ("
0032 << std::distance(new_end, cantorEdgeIndex.end())
0033 << " duplicates)");
0034 cantorEdgeIndex.erase(new_end, cantorEdgeIndex.end());
0035 }
0036
0037 return cantorEdgeIndex;
0038 }
0039
0040 }
0041
0042 Acts::TruthGraphMetricsHook::TruthGraphMetricsHook(
0043 const std::vector<std::int64_t>& truthGraph,
0044 std::unique_ptr<const Acts::Logger> l)
0045 : m_logger(std::move(l)) {
0046 m_truthGraphCantor = cantorize(truthGraph, logger());
0047 }
0048
0049 void Acts::TruthGraphMetricsHook::operator()(
0050 const PipelineTensors& tensors, const ExecutionContext& execCtx) const {
0051 auto edgeIndexTensor =
0052 tensors.edgeIndex.clone({Device::Cpu(), execCtx.stream});
0053
0054 const auto numEdges = edgeIndexTensor.shape().at(1);
0055 if (numEdges == 0) {
0056 ACTS_WARNING("no edges, cannot compute metrics");
0057 return;
0058 }
0059
0060
0061 std::vector<std::int64_t> edgeIndexTransposed;
0062 edgeIndexTransposed.reserve(edgeIndexTensor.size());
0063 for (auto i = 0ul; i < numEdges; ++i) {
0064 edgeIndexTransposed.push_back(*(edgeIndexTensor.data() + i));
0065 edgeIndexTransposed.push_back(*(edgeIndexTensor.data() + numEdges + i));
0066 }
0067
0068 auto predGraphCantor = cantorize(edgeIndexTransposed, logger());
0069
0070
0071 std::vector<Acts::detail::CantorEdge<std::int64_t>> intersection;
0072 intersection.reserve(
0073 std::max(predGraphCantor.size(), m_truthGraphCantor.size()));
0074
0075 std::set_intersection(predGraphCantor.begin(), predGraphCantor.end(),
0076 m_truthGraphCantor.begin(), m_truthGraphCantor.end(),
0077 std::back_inserter(intersection));
0078
0079 ACTS_DEBUG("Intersection size " << intersection.size());
0080 const float intersectionSizeFloat = intersection.size();
0081 const float eff = intersectionSizeFloat / m_truthGraphCantor.size();
0082 const float pur = intersectionSizeFloat / predGraphCantor.size();
0083
0084 ACTS_INFO("Efficiency=" << eff << ", purity=" << pur);
0085 }