Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-06-21 08:10:33

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
#include <boost/test/unit_test.hpp>

#include <Acts/Plugins/ExaTrkX/detail/CudaUtils.hpp>
#include <Acts/Plugins/ExaTrkX/detail/JunctionRemoval.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>
0016 
// Host-side containers for the edge lists used throughout these tests:
// Vi holds node indices (one entry per edge endpoint), Vf holds edge scores.
typedef std::vector<std::int64_t> Vi;
typedef std::vector<float> Vf;
0019 
0020 void testJunctionRemoval(const Vi &srcNodes, const Vi &dstNodes,
0021                          const Vf &scores, const Vi &expectedSrcNodes,
0022                          const Vi &expectedDstNodes) {
0023   std::size_t nEdges = srcNodes.size();
0024   std::size_t nNodes =
0025       std::max(*std::max_element(srcNodes.begin(), srcNodes.end()),
0026                *std::max_element(dstNodes.begin(), dstNodes.end())) +
0027       1;
0028   std::size_t nExpectedEdges = expectedSrcNodes.size();
0029 
0030   cudaStream_t stream{};
0031   ACTS_CUDA_CHECK(cudaStreamCreate(&stream));
0032 
0033   std::int64_t *cudaSrcNodes{}, *cudaDstNodes{};
0034   float *cudaScores{};
0035   ACTS_CUDA_CHECK(
0036       cudaMallocAsync(&cudaSrcNodes, nEdges * sizeof(std::int64_t), stream));
0037   ACTS_CUDA_CHECK(
0038       cudaMallocAsync(&cudaDstNodes, nEdges * sizeof(std::int64_t), stream));
0039   ACTS_CUDA_CHECK(cudaMallocAsync(&cudaScores, nEdges * sizeof(float), stream));
0040 
0041   ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaSrcNodes, srcNodes.data(),
0042                                   nEdges * sizeof(std::int64_t),
0043                                   cudaMemcpyHostToDevice, stream));
0044   ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaDstNodes, dstNodes.data(),
0045                                   nEdges * sizeof(std::int64_t),
0046                                   cudaMemcpyHostToDevice, stream));
0047   ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaScores, scores.data(),
0048                                   nEdges * sizeof(float),
0049                                   cudaMemcpyHostToDevice, stream));
0050 
0051   auto [cudaSrcNodesOut, nEdgesOut] = Acts::detail::junctionRemovalCuda(
0052       nEdges, nNodes, cudaScores, cudaSrcNodes, cudaDstNodes, stream);
0053   auto cudaDstNodesOut = cudaSrcNodesOut + nEdgesOut;
0054 
0055   Vi srcNodesOut(nEdgesOut);
0056   Vi dstNodesOut(nEdgesOut);
0057   ACTS_CUDA_CHECK(cudaMemcpyAsync(srcNodesOut.data(), cudaSrcNodesOut,
0058                                   nEdgesOut * sizeof(std::int64_t),
0059                                   cudaMemcpyDeviceToHost, stream));
0060   ACTS_CUDA_CHECK(cudaMemcpyAsync(dstNodesOut.data(), cudaDstNodesOut,
0061                                   nEdgesOut * sizeof(std::int64_t),
0062                                   cudaMemcpyDeviceToHost, stream));
0063   ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
0064 
0065   ACTS_CUDA_CHECK(cudaFreeAsync(cudaSrcNodes, stream));
0066   ACTS_CUDA_CHECK(cudaFreeAsync(cudaDstNodes, stream));
0067   ACTS_CUDA_CHECK(cudaFreeAsync(cudaScores, stream));
0068   ACTS_CUDA_CHECK(cudaFreeAsync(cudaSrcNodesOut, stream));
0069 
0070   ACTS_CUDA_CHECK(cudaStreamDestroy(stream));
0071 
0072   // Sort edges by src and dst nodes before comparison
0073   Vi idxs(nEdgesOut);
0074   std::iota(idxs.begin(), idxs.end(), 0);
0075   std::sort(idxs.begin(), idxs.end(), [&](int a, int b) {
0076     if (srcNodesOut.at(a) != srcNodesOut.at(b)) {
0077       return srcNodesOut.at(a) < srcNodesOut.at(b);
0078     }
0079     return dstNodesOut.at(a) < dstNodesOut.at(b);
0080   });
0081 
0082   Vi srcNodesOutSorted(nEdgesOut);
0083   Vi dstNodesOutSorted(nEdgesOut);
0084   for (std::size_t i = 0; i < nEdgesOut; ++i) {
0085     srcNodesOutSorted.at(i) = srcNodesOut.at(idxs.at(i));
0086     dstNodesOutSorted.at(i) = dstNodesOut.at(idxs.at(i));
0087   }
0088 
0089   BOOST_REQUIRE_EQUAL(nEdgesOut, nExpectedEdges);
0090   BOOST_CHECK_EQUAL_COLLECTIONS(
0091       srcNodesOutSorted.begin(), srcNodesOutSorted.end(),
0092       expectedSrcNodes.begin(), expectedSrcNodes.end());
0093   BOOST_CHECK_EQUAL_COLLECTIONS(
0094       dstNodesOutSorted.begin(), dstNodesOutSorted.end(),
0095       expectedDstNodes.begin(), expectedDstNodes.end());
0096 }
0097 
0098 BOOST_AUTO_TEST_CASE(test_no_junction) {
0099   Vi srcNodes = {0, 1};
0100   Vi dstNodes = {1, 2};
0101   Vf scores = {0.3f, 0.9f};
0102   Vi expectedSrcNodes = {0, 1};
0103   Vi expectedDstNodes = {1, 2};
0104   testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0105                       expectedDstNodes);
0106 }
0107 
0108 BOOST_AUTO_TEST_CASE(test_junction_2_in) {
0109   Vi srcNodes = {0, 1, 2};
0110   Vi dstNodes = {2, 2, 3};
0111   Vf scores = {0.9f, 0.3f, 0.9f};
0112   Vi expectedSrcNodes = {0, 2};
0113   Vi expectedDstNodes = {2, 3};
0114   testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0115                       expectedDstNodes);
0116 }
0117 
0118 BOOST_AUTO_TEST_CASE(test_junction_2_out) {
0119   Vi srcNodes = {0, 1, 1};
0120   Vi dstNodes = {1, 2, 3};
0121   Vf scores = {0.3f, 0.3f, 0.9f};
0122   Vi expectedSrcNodes = {0, 1};
0123   Vi expectedDstNodes = {1, 3};
0124   testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0125                       expectedDstNodes);
0126 }
0127 
0128 BOOST_AUTO_TEST_CASE(test_junction_2_in_2_out) {
0129   Vi srcNodes = {0, 1, 2, 2};
0130   Vi dstNodes = {2, 2, 3, 4};
0131   Vf scores = {0.9f, 0.3f, 0.9f, 0.5f};
0132   Vi expectedSrcNodes = {0, 2};
0133   Vi expectedDstNodes = {2, 3};
0134   testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0135                       expectedDstNodes);
0136 }
0137 
0138 BOOST_AUTO_TEST_CASE(test_junction_3_in_3_out) {
0139   Vi srcNodes = {0, 1, 2, 3, 3, 3};
0140   Vi dstNodes = {3, 3, 3, 4, 5, 6};
0141   Vf scores = {0.2f, 0.3f, 0.9f, 0.5f, 0.9f, 0.1f};
0142   Vi expectedSrcNodes = {2, 3};
0143   Vi expectedDstNodes = {3, 5};
0144   testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0145                       expectedDstNodes);
0146 }
0147 
0148 BOOST_AUTO_TEST_CASE(test_junction_leftover_edge) {
0149   int j = 2;
0150   Vi srcNodes = {0, 1, 3, j};
0151   Vi dstNodes = {1, j, j, 4};
0152   Vf scores = {0.9f, 0.1f, 0.9f, 0.9f};
0153   Vi expectedSrcNodes = {0, j, 3};
0154   Vi expectedDstNodes = {1, 4, j};
0155   testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0156                       expectedDstNodes);
0157 }
0158 
0159 BOOST_AUTO_TEST_CASE(test_no_in_edges) {
0160   int j = 0;
0161   Vi srcNodes = {j, j};
0162   Vi dstNodes = {1, 2};
0163   Vf scores = {0.9f, 0.1f};
0164   Vi expectedSrcNodes = {j};
0165   Vi expectedDstNodes = {1};
0166   testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0167                       expectedDstNodes);
0168 }
0169 
0170 BOOST_AUTO_TEST_CASE(test_no_out_edges) {
0171   int j = 0;
0172   Vi srcNodes = {1, 2};
0173   Vi dstNodes = {j, j};
0174   Vf scores = {0.9f, 0.1f};
0175   Vi expectedSrcNodes = {1};
0176   Vi expectedDstNodes = {j};
0177   testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0178                       expectedDstNodes);
0179 }
0180 
0181 BOOST_AUTO_TEST_CASE(test_two_junctions) {
0182   int j1 = 2;
0183   int j2 = 3;
0184   Vi srcNodes = {0, 1, j1, j2, j2};
0185   Vi dstNodes = {j1, j1, j2, 4, 5};
0186   Vf scores = {0.9f, 0.1f, 0.9f, 0.1f, 0.9f};
0187   Vi expectedSrcNodes = {0, j1, j2};
0188   Vi expectedDstNodes = {j1, j2, 5};
0189 
0190   testJunctionRemoval(srcNodes, dstNodes, scores, expectedSrcNodes,
0191                       expectedDstNodes);
0192 }