File indexing completed on 2025-06-21 08:10:33
0001
0002
0003
0004
0005
0006
0007
0008
#include <boost/test/unit_test.hpp>

#include <Acts/Plugins/ExaTrkX/detail/CudaUtils.hpp>
#include <Acts/Plugins/ExaTrkX/detail/JunctionRemoval.hpp>

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

using Vi = std::vector<std::int64_t>;
using Vf = std::vector<float>;
0019
/// Runs Acts::detail::junctionRemovalCuda on a small host-side graph and
/// checks which edges survive.
///
/// Edge i of the input graph goes from srcNodes[i] to dstNodes[i] and carries
/// the score scores[i]. The graph is uploaded to the device and processed by
/// junctionRemovalCuda; the resulting edge list is then copied back. Both the
/// produced and the expected edge lists are sorted lexicographically by
/// (source, target) before they are compared. Therefore neither the output
/// order of the device code nor the order in which a test case lists its
/// expected edges affects the outcome.
///
/// @param srcNodes         source node of each input edge; must be non-empty
/// @param dstNodes         target node of each input edge (same length)
/// @param scores           score of each input edge (same length)
/// @param expectedSrcNodes source node of each edge expected to survive
/// @param expectedDstNodes target node of each edge expected to survive
///                         (same length as expectedSrcNodes)
void testJunctionRemoval(const Vi &srcNodes, const Vi &dstNodes,
                         const Vf &scores, const Vi &expectedSrcNodes,
                         const Vi &expectedDstNodes) {
  // The three input vectors describe the same edges. A length mismatch would
  // make the host-to-device copies below read past the end of a host buffer.
  BOOST_REQUIRE_EQUAL(dstNodes.size(), srcNodes.size());
  BOOST_REQUIRE_EQUAL(scores.size(), srcNodes.size());
  BOOST_REQUIRE_EQUAL(expectedDstNodes.size(), expectedSrcNodes.size());
  // The node count is derived from the largest node index, which needs at
  // least one edge: std::max_element returns end() for an empty range, and
  // dereferencing that is undefined behavior.
  BOOST_REQUIRE_EQUAL(srcNodes.empty(), false);

  const std::size_t nEdges = srcNodes.size();
  // Node count = one past the largest index seen in either endpoint list.
  const std::size_t nNodes =
      std::max(*std::max_element(srcNodes.begin(), srcNodes.end()),
               *std::max_element(dstNodes.begin(), dstNodes.end())) +
      1;
  const std::size_t nExpectedEdges = expectedSrcNodes.size();

  cudaStream_t stream{};
  ACTS_CUDA_CHECK(cudaStreamCreate(&stream));

  // Upload the input graph; all device work below is ordered on `stream`.
  std::int64_t *cudaSrcNodes{}, *cudaDstNodes{};
  float *cudaScores{};
  ACTS_CUDA_CHECK(
      cudaMallocAsync(&cudaSrcNodes, nEdges * sizeof(std::int64_t), stream));
  ACTS_CUDA_CHECK(
      cudaMallocAsync(&cudaDstNodes, nEdges * sizeof(std::int64_t), stream));
  ACTS_CUDA_CHECK(cudaMallocAsync(&cudaScores, nEdges * sizeof(float), stream));

  ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaSrcNodes, srcNodes.data(),
                                  nEdges * sizeof(std::int64_t),
                                  cudaMemcpyHostToDevice, stream));
  ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaDstNodes, dstNodes.data(),
                                  nEdges * sizeof(std::int64_t),
                                  cudaMemcpyHostToDevice, stream));
  ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaScores, scores.data(),
                                  nEdges * sizeof(float),
                                  cudaMemcpyHostToDevice, stream));

  auto [cudaSrcNodesOut, nEdgesOut] = Acts::detail::junctionRemovalCuda(
      nEdges, nNodes, cudaScores, cudaSrcNodes, cudaDstNodes, stream);
  // The returned buffer holds the surviving source nodes immediately
  // followed by the surviving target nodes.
  auto cudaDstNodesOut = cudaSrcNodesOut + nEdgesOut;

  Vi srcNodesOut(nEdgesOut);
  Vi dstNodesOut(nEdgesOut);
  ACTS_CUDA_CHECK(cudaMemcpyAsync(srcNodesOut.data(), cudaSrcNodesOut,
                                  nEdgesOut * sizeof(std::int64_t),
                                  cudaMemcpyDeviceToHost, stream));
  ACTS_CUDA_CHECK(cudaMemcpyAsync(dstNodesOut.data(), cudaDstNodesOut,
                                  nEdgesOut * sizeof(std::int64_t),
                                  cudaMemcpyDeviceToHost, stream));
  // The host vectors only hold valid data once the copies have completed.
  ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));

  // cudaSrcNodesOut also backs cudaDstNodesOut, so it is freed exactly once.
  ACTS_CUDA_CHECK(cudaFreeAsync(cudaSrcNodes, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(cudaDstNodes, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(cudaScores, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(cudaSrcNodesOut, stream));

  ACTS_CUDA_CHECK(cudaStreamDestroy(stream));

  // Compare contents only once the edge counts are known to agree.
  BOOST_REQUIRE_EQUAL(nEdgesOut, nExpectedEdges);

  // Returns copies of the parallel (src, dst) edge lists, reordered
  // lexicographically by (src, dst). Requires dst.size() == src.size().
  auto sortEdges = [](const Vi &src, const Vi &dst) {
    std::vector<std::size_t> order(src.size());
    std::iota(order.begin(), order.end(), std::size_t{0});
    std::sort(order.begin(), order.end(), [&](std::size_t a, std::size_t b) {
      if (src.at(a) != src.at(b)) {
        return src.at(a) < src.at(b);
      }
      return dst.at(a) < dst.at(b);
    });

    Vi sortedSrc(order.size());
    Vi sortedDst(order.size());
    for (std::size_t i = 0; i < order.size(); ++i) {
      sortedSrc.at(i) = src.at(order.at(i));
      sortedDst.at(i) = dst.at(order.at(i));
    }
    return std::make_pair(std::move(sortedSrc), std::move(sortedDst));
  };

  // Sort both sides so the comparison is independent of edge order.
  const auto [srcNodesOutSorted, dstNodesOutSorted] =
      sortEdges(srcNodesOut, dstNodesOut);
  const auto [expectedSrcSorted, expectedDstSorted] =
      sortEdges(expectedSrcNodes, expectedDstNodes);

  BOOST_CHECK_EQUAL_COLLECTIONS(
      srcNodesOutSorted.begin(), srcNodesOutSorted.end(),
      expectedSrcSorted.begin(), expectedSrcSorted.end());
  BOOST_CHECK_EQUAL_COLLECTIONS(
      dstNodesOutSorted.begin(), dstNodesOutSorted.end(),
      expectedDstSorted.begin(), expectedDstSorted.end());
}
0097
BOOST_AUTO_TEST_CASE(test_no_junction) {
  // Chain 0 -> 1 -> 2: no node has more than one incoming or outgoing edge,
  // and both edges are expected to survive.
  testJunctionRemoval(/*srcNodes=*/{0, 1},
                      /*dstNodes=*/{1, 2},
                      /*scores=*/{0.3f, 0.9f},
                      /*expectedSrcNodes=*/{0, 1},
                      /*expectedDstNodes=*/{1, 2});
}
0107
BOOST_AUTO_TEST_CASE(test_junction_2_in) {
  // Node 2 has two incoming edges, 0 -> 2 (0.9) and 1 -> 2 (0.3). Only the
  // higher-scoring one is expected to survive, together with 2 -> 3.
  testJunctionRemoval(/*srcNodes=*/{0, 1, 2},
                      /*dstNodes=*/{2, 2, 3},
                      /*scores=*/{0.9f, 0.3f, 0.9f},
                      /*expectedSrcNodes=*/{0, 2},
                      /*expectedDstNodes=*/{2, 3});
}
0117
BOOST_AUTO_TEST_CASE(test_junction_2_out) {
  // Node 1 has two outgoing edges, 1 -> 2 (0.3) and 1 -> 3 (0.9). Only the
  // higher-scoring one is expected to survive, together with 0 -> 1.
  testJunctionRemoval(/*srcNodes=*/{0, 1, 1},
                      /*dstNodes=*/{1, 2, 3},
                      /*scores=*/{0.3f, 0.3f, 0.9f},
                      /*expectedSrcNodes=*/{0, 1},
                      /*expectedDstNodes=*/{1, 3});
}
0127
BOOST_AUTO_TEST_CASE(test_junction_2_in_2_out) {
  // Node 2 has incoming edges 0 -> 2 (0.9) and 1 -> 2 (0.3), and outgoing
  // edges 2 -> 3 (0.9) and 2 -> 4 (0.5). Only the best edge on each side is
  // expected to survive.
  testJunctionRemoval(/*srcNodes=*/{0, 1, 2, 2},
                      /*dstNodes=*/{2, 2, 3, 4},
                      /*scores=*/{0.9f, 0.3f, 0.9f, 0.5f},
                      /*expectedSrcNodes=*/{0, 2},
                      /*expectedDstNodes=*/{2, 3});
}
0137
BOOST_AUTO_TEST_CASE(test_junction_3_in_3_out) {
  // Node 3 has three incoming edges (from 0, 1, 2 with scores 0.2, 0.3, 0.9)
  // and three outgoing edges (to 4, 5, 6 with scores 0.5, 0.9, 0.1). Only
  // 2 -> 3 and 3 -> 5 are expected to survive.
  testJunctionRemoval(/*srcNodes=*/{0, 1, 2, 3, 3, 3},
                      /*dstNodes=*/{3, 3, 3, 4, 5, 6},
                      /*scores=*/{0.2f, 0.3f, 0.9f, 0.5f, 0.9f, 0.1f},
                      /*expectedSrcNodes=*/{2, 3},
                      /*expectedDstNodes=*/{3, 5});
}
0147
BOOST_AUTO_TEST_CASE(test_junction_leftover_edge) {
  // The junction node has incoming edges 1 -> junction (0.1) and
  // 3 -> junction (0.9), plus a single outgoing edge junction -> 4. The weak
  // incoming edge is expected to be dropped. Edge 0 -> 1, which leads into
  // that edge's source, is expected to remain.
  constexpr std::int64_t junction = 2;
  testJunctionRemoval(/*srcNodes=*/{0, 1, 3, junction},
                      /*dstNodes=*/{1, junction, junction, 4},
                      /*scores=*/{0.9f, 0.1f, 0.9f, 0.9f},
                      /*expectedSrcNodes=*/{0, junction, 3},
                      /*expectedDstNodes=*/{1, 4, junction});
}
0158
BOOST_AUTO_TEST_CASE(test_no_in_edges) {
  // The junction node has no incoming edges and two outgoing ones,
  // junction -> 1 (0.9) and junction -> 2 (0.1). Only the higher-scoring
  // edge is expected to survive.
  constexpr std::int64_t junction = 0;
  testJunctionRemoval(/*srcNodes=*/{junction, junction},
                      /*dstNodes=*/{1, 2},
                      /*scores=*/{0.9f, 0.1f},
                      /*expectedSrcNodes=*/{junction},
                      /*expectedDstNodes=*/{1});
}
0169
BOOST_AUTO_TEST_CASE(test_no_out_edges) {
  // The junction node has no outgoing edges and two incoming ones,
  // 1 -> junction (0.9) and 2 -> junction (0.1). Only the higher-scoring
  // edge is expected to survive.
  constexpr std::int64_t junction = 0;
  testJunctionRemoval(/*srcNodes=*/{1, 2},
                      /*dstNodes=*/{junction, junction},
                      /*scores=*/{0.9f, 0.1f},
                      /*expectedSrcNodes=*/{1},
                      /*expectedDstNodes=*/{junction});
}
0180
BOOST_AUTO_TEST_CASE(test_two_junctions) {
  // Two adjacent junctions connected by a single edge:
  //  - the first junction has incoming edges from 0 (0.9) and 1 (0.1),
  //  - the second junction has outgoing edges to 4 (0.1) and 5 (0.9).
  // Only the best edge at each junction and the link between them are
  // expected to survive.
  constexpr std::int64_t firstJunction = 2;
  constexpr std::int64_t secondJunction = 3;
  testJunctionRemoval(
      /*srcNodes=*/{0, 1, firstJunction, secondJunction, secondJunction},
      /*dstNodes=*/{firstJunction, firstJunction, secondJunction, 4, 5},
      /*scores=*/{0.9f, 0.1f, 0.9f, 0.1f, 0.9f},
      /*expectedSrcNodes=*/{0, firstJunction, secondJunction},
      /*expectedDstNodes=*/{firstJunction, secondJunction, 5});
}