File indexing completed on 2025-12-16 09:25:34
0001
0002
0003
0004
0005
0006
0007
0008
#include <boost/test/unit_test.hpp>

#include "Acts/Utilities/Helpers.hpp"
#include "ActsPlugins/Gnn/BoostTrackBuilding.hpp"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <optional>
#include <utility>
#include <vector>
0016
0017 using namespace Acts;
0018 using namespace ActsPlugins;
0019
0020 namespace ActsTests {
0021
0022 BOOST_AUTO_TEST_SUITE(GnnSuite)
0023
0024 BOOST_AUTO_TEST_CASE(test_track_building) {
0025
0026
0027
0028
0029 std::vector<int> spacepointIds(16);
0030 std::iota(spacepointIds.begin(), spacepointIds.end(), 100);
0031
0032
0033 std::vector<std::vector<int>> refTracks;
0034 for (auto t = 0ul; t < 4; ++t) {
0035 refTracks.emplace_back(spacepointIds.begin() + 4 * t,
0036 spacepointIds.begin() + 4 * (t + 1));
0037 }
0038
0039
0040 std::vector<std::int64_t> e0, e1;
0041 for (const auto &track : refTracks) {
0042 for (auto it = track.begin(); it != track.end() - 1; ++it) {
0043
0044 e0.push_back(*it - 100);
0045 e1.push_back(*std::next(it) - 100);
0046 }
0047 }
0048
0049 ExecutionContext execCtx{Device::Cpu(), {}};
0050
0051 auto edgeTensor = Tensor<std::int64_t>::Create({2, e0.size()}, execCtx);
0052 std::copy(e0.begin(), e0.end(), edgeTensor.data());
0053 std::copy(e1.begin(), e1.end(), edgeTensor.data() + e0.size());
0054
0055 auto dummyNodes = Tensor<float>::Create({spacepointIds.size(), 16}, execCtx);
0056 auto dummyWeights = Tensor<float>::Create({e0.size(), 1}, execCtx);
0057 std::fill(dummyWeights.data(), dummyWeights.data() + dummyWeights.size(),
0058 1.f);
0059
0060
0061 auto logger = getDefaultLogger("TestLogger", Logging::ERROR);
0062 BoostTrackBuilding trackBuilder({}, std::move(logger));
0063
0064 auto testTracks = trackBuilder({std::move(dummyNodes), std::move(edgeTensor),
0065 std::nullopt, std::move(dummyWeights)},
0066 spacepointIds);
0067
0068 BOOST_CHECK_EQUAL(testTracks.size(), refTracks.size());
0069
0070
0071 std::ranges::for_each(testTracks, [](auto &t) { std::ranges::sort(t); });
0072 std::ranges::sort(testTracks, std::less{}, [](auto &t) { return t.at(0); });
0073
0074 std::ranges::for_each(refTracks, [](auto &t) { std::ranges::sort(t); });
0075 std::ranges::sort(refTracks, std::less{}, [](auto &t) { return t.at(0); });
0076
0077
0078 for (const auto &refTrack : refTracks) {
0079 BOOST_CHECK(rangeContainsValue(testTracks, refTrack));
0080 }
0081 }
0082
0083 BOOST_AUTO_TEST_SUITE_END()
0084
0085 }