File indexing completed on 2025-12-16 09:25:34
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include <boost/test/unit_test.hpp>
0010
0011 #include "ActsPlugins/Gnn/detail/CudaUtils.hpp"
0012 #include "ActsPlugins/Gnn/detail/JunctionRemoval.hpp"
0013
0014 #include <algorithm>
0015 #include <numeric>
0016
0017 using Vi = std::vector<std::int64_t>;
0018 using Vf = std::vector<float>;
0019
0020 using namespace ActsPlugins::detail;
0021
0022 void testJunctionRemoval(const Vi &srcNodes, const Vi &dstNodes,
0023 const Vf &scores, const Vi &expectedSrcNodes,
0024 const Vi &expectedDstNodes) {
0025 std::size_t nEdges = srcNodes.size();
0026 std::size_t nNodes =
0027 std::max(*std::max_element(srcNodes.begin(), srcNodes.end()),
0028 *std::max_element(dstNodes.begin(), dstNodes.end())) +
0029 1;
0030 std::size_t nExpectedEdges = expectedSrcNodes.size();
0031
0032 cudaStream_t stream{};
0033 ACTS_CUDA_CHECK(cudaStreamCreate(&stream));
0034
0035 std::int64_t *cudaSrcNodes{}, *cudaDstNodes{};
0036 float *cudaScores{};
0037 ACTS_CUDA_CHECK(
0038 cudaMallocAsync(&cudaSrcNodes, nEdges * sizeof(std::int64_t), stream));
0039 ACTS_CUDA_CHECK(
0040 cudaMallocAsync(&cudaDstNodes, nEdges * sizeof(std::int64_t), stream));
0041 ACTS_CUDA_CHECK(cudaMallocAsync(&cudaScores, nEdges * sizeof(float), stream));
0042
0043 ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaSrcNodes, srcNodes.data(),
0044 nEdges * sizeof(std::int64_t),
0045 cudaMemcpyHostToDevice, stream));
0046 ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaDstNodes, dstNodes.data(),
0047 nEdges * sizeof(std::int64_t),
0048 cudaMemcpyHostToDevice, stream));
0049 ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaScores, scores.data(),
0050 nEdges * sizeof(float),
0051 cudaMemcpyHostToDevice, stream));
0052
0053 auto [cudaSrcNodesOut, nEdgesOut] = junctionRemovalCuda(
0054 nEdges, nNodes, cudaScores, cudaSrcNodes, cudaDstNodes, stream);
0055 auto cudaDstNodesOut = cudaSrcNodesOut + nEdgesOut;
0056
0057 Vi srcNodesOut(nEdgesOut);
0058 Vi dstNodesOut(nEdgesOut);
0059 ACTS_CUDA_CHECK(cudaMemcpyAsync(srcNodesOut.data(), cudaSrcNodesOut,
0060 nEdgesOut * sizeof(std::int64_t),
0061 cudaMemcpyDeviceToHost, stream));
0062 ACTS_CUDA_CHECK(cudaMemcpyAsync(dstNodesOut.data(), cudaDstNodesOut,
0063 nEdgesOut * sizeof(std::int64_t),
0064 cudaMemcpyDeviceToHost, stream));
0065 ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
0066
0067 ACTS_CUDA_CHECK(cudaFreeAsync(cudaSrcNodes, stream));
0068 ACTS_CUDA_CHECK(cudaFreeAsync(cudaDstNodes, stream));
0069 ACTS_CUDA_CHECK(cudaFreeAsync(cudaScores, stream));
0070 ACTS_CUDA_CHECK(cudaFreeAsync(cudaSrcNodesOut, stream));
0071
0072 ACTS_CUDA_CHECK(cudaStreamDestroy(stream));
0073
0074
0075 Vi idxs(nEdgesOut);
0076 std::iota(idxs.begin(), idxs.end(), 0);
0077 std::sort(idxs.begin(), idxs.end(), [&](int a, int b) {
0078 if (srcNodesOut.at(a) != srcNodesOut.at(b)) {
0079 return srcNodesOut.at(a) < srcNodesOut.at(b);
0080 }
0081 return dstNodesOut.at(a) < dstNodesOut.at(b);
0082 });
0083
0084 Vi srcNodesOutSorted(nEdgesOut);
0085 Vi dstNodesOutSorted(nEdgesOut);
0086 for (std::size_t i = 0; i < nEdgesOut; ++i) {
0087 srcNodesOutSorted.at(i) = srcNodesOut.at(idxs.at(i));
0088 dstNodesOutSorted.at(i) = dstNodesOut.at(idxs.at(i));
0089 }
0090
0091 BOOST_REQUIRE_EQUAL(nEdgesOut, nExpectedEdges);
0092 BOOST_CHECK_EQUAL_COLLECTIONS(
0093 srcNodesOutSorted.begin(), srcNodesOutSorted.end(),
0094 expectedSrcNodes.begin(), expectedSrcNodes.end());
0095 BOOST_CHECK_EQUAL_COLLECTIONS(
0096 dstNodesOutSorted.begin(), dstNodesOutSorted.end(),
0097 expectedDstNodes.begin(), expectedDstNodes.end());
0098 }
0099
0100 BOOST_AUTO_TEST_CASE(test_no_junction) {
0101 Vi srcNodes = {0, 1};
0102 Vi dstNodes = {1, 2};
0103 Vf scores = {0.3f, 0.9f};
0104 Vi expectedSrcNodes = {0, 1};
0105 Vi expectedDstNodes = {1, 2};
0106 testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0107 expectedDstNodes);
0108 }
0109
0110 BOOST_AUTO_TEST_CASE(test_junction_2_in) {
0111 Vi srcNodes = {0, 1, 2};
0112 Vi dstNodes = {2, 2, 3};
0113 Vf scores = {0.9f, 0.3f, 0.9f};
0114 Vi expectedSrcNodes = {0, 2};
0115 Vi expectedDstNodes = {2, 3};
0116 testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0117 expectedDstNodes);
0118 }
0119
0120 BOOST_AUTO_TEST_CASE(test_junction_2_out) {
0121 Vi srcNodes = {0, 1, 1};
0122 Vi dstNodes = {1, 2, 3};
0123 Vf scores = {0.3f, 0.3f, 0.9f};
0124 Vi expectedSrcNodes = {0, 1};
0125 Vi expectedDstNodes = {1, 3};
0126 testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0127 expectedDstNodes);
0128 }
0129
0130 BOOST_AUTO_TEST_CASE(test_junction_2_in_2_out) {
0131 Vi srcNodes = {0, 1, 2, 2};
0132 Vi dstNodes = {2, 2, 3, 4};
0133 Vf scores = {0.9f, 0.3f, 0.9f, 0.5f};
0134 Vi expectedSrcNodes = {0, 2};
0135 Vi expectedDstNodes = {2, 3};
0136 testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0137 expectedDstNodes);
0138 }
0139
0140 BOOST_AUTO_TEST_CASE(test_junction_3_in_3_out) {
0141 Vi srcNodes = {0, 1, 2, 3, 3, 3};
0142 Vi dstNodes = {3, 3, 3, 4, 5, 6};
0143 Vf scores = {0.2f, 0.3f, 0.9f, 0.5f, 0.9f, 0.1f};
0144 Vi expectedSrcNodes = {2, 3};
0145 Vi expectedDstNodes = {3, 5};
0146 testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0147 expectedDstNodes);
0148 }
0149
0150 BOOST_AUTO_TEST_CASE(test_junction_leftover_edge) {
0151 int j = 2;
0152 Vi srcNodes = {0, 1, 3, j};
0153 Vi dstNodes = {1, j, j, 4};
0154 Vf scores = {0.9f, 0.1f, 0.9f, 0.9f};
0155 Vi expectedSrcNodes = {0, j, 3};
0156 Vi expectedDstNodes = {1, 4, j};
0157 testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0158 expectedDstNodes);
0159 }
0160
0161 BOOST_AUTO_TEST_CASE(test_no_in_edges) {
0162 int j = 0;
0163 Vi srcNodes = {j, j};
0164 Vi dstNodes = {1, 2};
0165 Vf scores = {0.9f, 0.1f};
0166 Vi expectedSrcNodes = {j};
0167 Vi expectedDstNodes = {1};
0168 testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0169 expectedDstNodes);
0170 }
0171
0172 BOOST_AUTO_TEST_CASE(test_no_out_edges) {
0173 int j = 0;
0174 Vi srcNodes = {1, 2};
0175 Vi dstNodes = {j, j};
0176 Vf scores = {0.9f, 0.1f};
0177 Vi expectedSrcNodes = {1};
0178 Vi expectedDstNodes = {j};
0179 testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0180 expectedDstNodes);
0181 }
0182
0183 BOOST_AUTO_TEST_CASE(test_two_junctions) {
0184 int j1 = 2;
0185 int j2 = 3;
0186 Vi srcNodes = {0, 1, j1, j2, j2};
0187 Vi dstNodes = {j1, j1, j2, 4, 5};
0188 Vf scores = {0.9f, 0.1f, 0.9f, 0.1f, 0.9f};
0189 Vi expectedSrcNodes = {0, j1, j2};
0190 Vi expectedDstNodes = {j1, j2, 5};
0191
0192 testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0193 expectedDstNodes);
0194 }