File indexing completed on 2025-12-16 09:25:34
0001
0002
0003
0004
0005
0006
0007
0008
#include <boost/test/unit_test.hpp>

#include "ActsPlugins/Gnn/TruthGraphMetricsHook.hpp"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
0014
0015 using namespace Acts;
0016 using namespace ActsPlugins;
0017
0018 void testTruthTestGraph(std::vector<std::int64_t> &truthGraph,
0019 std::vector<std::int64_t> &testGraph,
0020 const std::string &resStr) {
0021 std::stringstream ss;
0022 auto logger = getDefaultLogger("Test", Logging::INFO, &ss);
0023
0024 TruthGraphMetricsHook hook(truthGraph, std::move(logger));
0025
0026 auto numTestEdges = testGraph.size() / 2;
0027 auto edgeIndexTensor =
0028 Tensor<std::int64_t>::Create({2, numTestEdges}, {Device::Cpu(), {}});
0029
0030
0031 for (auto i = 0ul; i < numTestEdges; ++i) {
0032 *(edgeIndexTensor.data() + i) = testGraph.at(2 * i);
0033 *(edgeIndexTensor.data() + numTestEdges + i) = testGraph.at(2 * i + 1);
0034 }
0035
0036 PipelineTensors tensors{Tensor<float>::Create({1, 1}, {Device::Cpu(), {}}),
0037 std::move(edgeIndexTensor),
0038 {},
0039 {}};
0040
0041 hook(tensors, {Device::Cpu(), {}});
0042
0043 const auto str = ss.str();
0044
0045 auto begin = str.begin() + str.find("Efficiency");
0046 BOOST_CHECK_EQUAL(std::string(begin, str.end() - 1), resStr);
0047 }
0048
0049 namespace ActsTests {
0050
0051 BOOST_AUTO_TEST_SUITE(GnnSuite)
0052
0053 BOOST_AUTO_TEST_CASE(same_graph) {
0054
0055 std::vector<std::int64_t> truthGraph = {
0056 1,2,
0057 2,3,
0058 3,4,
0059 4,5,
0060 };
0061
0062
0063
0064 std::vector<std::int64_t> testGraph = {
0065 3,4,
0066 4,5,
0067 1,2,
0068 2,3,
0069 };
0070
0071
0072 testTruthTestGraph(truthGraph, testGraph, "Efficiency=1, purity=1");
0073 }
0074
0075
0076 BOOST_AUTO_TEST_CASE(same_graph_large_numbers) {
0077
0078 std::int64_t k = 100'000;
0079
0080 std::vector<std::int64_t> truthGraph = {
0081 1,2,
0082 2,3,
0083 3,4,
0084 4,5,
0085 };
0086 // clang-format on
0087 std::transform(truthGraph.begin(), truthGraph.end(), truthGraph.begin(),
0088 [&](auto i) { return k + i; });
0089
0090 // clang-format off
0091 std::vector<std::int64_t> testGraph = {
0092 3,4,
0093 4,5,
0094 1,2,
0095 2,3,
0096 };
0097 // clang-format on
0098 std::transform(testGraph.begin(), testGraph.end(), testGraph.begin(),
0099 [&](auto i) { return k + i; });
0100
0101 testTruthTestGraph(truthGraph, testGraph, "Efficiency=1, purity=1");
0102 }
0103
0104 BOOST_AUTO_TEST_CASE(fifty_fifty) {
0105 // clang-format off
0106 std::vector<std::int64_t> truthGraph = {
0107 1,2,
0108 2,3,
0109 3,4,
0110 4,5,
0111 };
0112 // clang-format on
0113
0114 // clang-format off
0115 std::vector<std::int64_t> testGraph = {
0116 3,4,
0117 4,5,
0118 6,9,
0119 5,1,
0120 };
0121 // clang-format on
0122
0123 testTruthTestGraph(truthGraph, testGraph, "Efficiency=0.5, purity=0.5");
0124 }
0125
0126 BOOST_AUTO_TEST_SUITE_END()
0127
0128 } // namespace ActsTests