File indexing completed on 2025-01-18 09:13:09
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include <boost/test/unit_test.hpp>
0010
0011 #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
0012
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
0015
0016 #include <torch/torch.h>
0017
0018 void testTruthTestGraph(std::vector<std::int64_t> &truthGraph,
0019 std::vector<std::int64_t> &testGraph,
0020 const std::string &resStr) {
0021 std::stringstream ss;
0022 auto logger = Acts::getDefaultLogger("Test", Acts::Logging::INFO, &ss);
0023
0024 Acts::TorchTruthGraphMetricsHook hook(truthGraph, std::move(logger));
0025
0026 auto opts = torch::TensorOptions().dtype(torch::kInt64);
0027 const auto edgeTensor =
0028 torch::from_blob(testGraph.data(),
0029 {static_cast<long>(testGraph.size() / 2), 2}, opts)
0030 .transpose(0, 1);
0031
0032 hook({}, edgeTensor, {});
0033
0034 const auto str = ss.str();
0035
0036 auto begin = str.begin() + str.find("Efficiency");
0037 BOOST_CHECK_EQUAL(std::string(begin, str.end() - 1), resStr);
0038 }
0039
0040 BOOST_AUTO_TEST_CASE(same_graph) {
0041
0042 std::vector<std::int64_t> truthGraph = {
0043 1,2,
0044 2,3,
0045 3,4,
0046 4,5,
0047 };
0048
0049
0050
0051 std::vector<std::int64_t> testGraph = {
0052 3,4,
0053 4,5,
0054 1,2,
0055 2,3,
0056 };
0057
0058
0059 testTruthTestGraph(truthGraph, testGraph, "Efficiency=1, purity=1");
0060 }
0061
0062
0063 BOOST_AUTO_TEST_CASE(same_graph_large_numbers) {
0064
0065 std::int64_t k = 100'000;
0066
0067 std::vector<std::int64_t> truthGraph = {
0068 1,2,
0069 2,3,
0070 3,4,
0071 4,5,
0072 };
0073 // clang-format on
0074 std::transform(truthGraph.begin(), truthGraph.end(), truthGraph.begin(),
0075 [&](auto i) { return k + i; });
0076
0077 // clang-format off
0078 std::vector<std::int64_t> testGraph = {
0079 3,4,
0080 4,5,
0081 1,2,
0082 2,3,
0083 };
0084 // clang-format on
0085 std::transform(testGraph.begin(), testGraph.end(), testGraph.begin(),
0086 [&](auto i) { return k + i; });
0087
0088 testTruthTestGraph(truthGraph, testGraph, "Efficiency=1, purity=1");
0089 }
0090
0091 BOOST_AUTO_TEST_CASE(fifty_fifty) {
0092 // clang-format off
0093 std::vector<std::int64_t> truthGraph = {
0094 1,2,
0095 2,3,
0096 3,4,
0097 4,5,
0098 };
0099 // clang-format on
0100
0101 // clang-format off
0102 std::vector<std::int64_t> testGraph = {
0103 3,4,
0104 4,5,
0105 6,9,
0106 5,1,
0107 };
0108 // clang-format on
0109
0110 testTruthTestGraph(truthGraph, testGraph, "Efficiency=0.5, purity=0.5");
0111 }