Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:25:34

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include <boost/test/unit_test.hpp>
0010 
0011 #include "ActsPlugins/Gnn/detail/CudaUtils.hpp"
0012 #include "ActsPlugins/Gnn/detail/JunctionRemoval.hpp"
0013 
0014 #include <algorithm>
0015 #include <numeric>
0016 
0017 using Vi = std::vector<std::int64_t>;
0018 using Vf = std::vector<float>;
0019 
0020 using namespace ActsPlugins::detail;
0021 
/// Run the CUDA junction removal on a small host-side edge list and compare
/// the surviving edges with an expected edge list.
///
/// The input edges and scores are copied to the device on a dedicated stream,
/// passed to `junctionRemovalCuda`, and the result is copied back to the host.
/// The returned edges are sorted lexicographically by (source, destination)
/// before the comparison, so the expected edges must be given in that order.
///
/// @param srcNodes Source node index of every input edge
/// @param dstNodes Destination node index of every input edge
/// @param scores Score of every input edge
/// @param expectedSrcNodes Source nodes of the expected output edges
/// @param expectedDstNodes Destination nodes of the expected output edges
void testJunctionRemoval(const Vi &srcNodes, const Vi &dstNodes,
                         const Vf &scores, const Vi &expectedSrcNodes,
                         const Vi &expectedDstNodes) {
  // Validate the fixture before touching the device: an empty edge list would
  // dereference an end iterator below, and mismatched lengths would make the
  // host-to-device copies read past the end of the shorter vector.
  BOOST_REQUIRE(!srcNodes.empty());
  BOOST_REQUIRE_EQUAL(srcNodes.size(), dstNodes.size());
  BOOST_REQUIRE_EQUAL(srcNodes.size(), scores.size());
  BOOST_REQUIRE_EQUAL(expectedSrcNodes.size(), expectedDstNodes.size());

  const std::size_t nEdges = srcNodes.size();
  const std::size_t nExpectedEdges = expectedSrcNodes.size();

  // The number of nodes is derived from the largest node index that occurs
  // in any edge.
  const std::int64_t maxNode =
      std::max(*std::max_element(srcNodes.begin(), srcNodes.end()),
               *std::max_element(dstNodes.begin(), dstNodes.end()));
  const std::size_t nNodes = static_cast<std::size_t>(maxNode) + 1;

  cudaStream_t stream{};
  ACTS_CUDA_CHECK(cudaStreamCreate(&stream));

  // Device-side copies of the input graph
  std::int64_t *cudaSrcNodes{}, *cudaDstNodes{};
  float *cudaScores{};
  ACTS_CUDA_CHECK(
      cudaMallocAsync(&cudaSrcNodes, nEdges * sizeof(std::int64_t), stream));
  ACTS_CUDA_CHECK(
      cudaMallocAsync(&cudaDstNodes, nEdges * sizeof(std::int64_t), stream));
  ACTS_CUDA_CHECK(cudaMallocAsync(&cudaScores, nEdges * sizeof(float), stream));

  ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaSrcNodes, srcNodes.data(),
                                  nEdges * sizeof(std::int64_t),
                                  cudaMemcpyHostToDevice, stream));
  ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaDstNodes, dstNodes.data(),
                                  nEdges * sizeof(std::int64_t),
                                  cudaMemcpyHostToDevice, stream));
  ACTS_CUDA_CHECK(cudaMemcpyAsync(cudaScores, scores.data(),
                                  nEdges * sizeof(float),
                                  cudaMemcpyHostToDevice, stream));

  auto [cudaSrcNodesOut, nEdgesOut] = junctionRemovalCuda(
      nEdges, nNodes, cudaScores, cudaSrcNodes, cudaDstNodes, stream);
  // The destination nodes are read from directly behind the nEdgesOut source
  // nodes in the returned buffer.
  auto cudaDstNodesOut = cudaSrcNodesOut + nEdgesOut;

  Vi srcNodesOut(nEdgesOut);
  Vi dstNodesOut(nEdgesOut);
  ACTS_CUDA_CHECK(cudaMemcpyAsync(srcNodesOut.data(), cudaSrcNodesOut,
                                  nEdgesOut * sizeof(std::int64_t),
                                  cudaMemcpyDeviceToHost, stream));
  ACTS_CUDA_CHECK(cudaMemcpyAsync(dstNodesOut.data(), cudaDstNodesOut,
                                  nEdgesOut * sizeof(std::int64_t),
                                  cudaMemcpyDeviceToHost, stream));
  // The host vectors are only valid once the asynchronous copies are done.
  ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));

  ACTS_CUDA_CHECK(cudaFreeAsync(cudaSrcNodes, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(cudaDstNodes, stream));
  ACTS_CUDA_CHECK(cudaFreeAsync(cudaScores, stream));
  // A single free: the destination nodes are addressed as an offset into this
  // same pointer.
  ACTS_CUDA_CHECK(cudaFreeAsync(cudaSrcNodesOut, stream));

  ACTS_CUDA_CHECK(cudaStreamDestroy(stream));

  // Sort edges by src and dst nodes before comparison. The comparator takes
  // the index type of the vector being sorted to avoid a narrowing conversion.
  Vi idxs(nEdgesOut);
  std::iota(idxs.begin(), idxs.end(), 0);
  std::sort(idxs.begin(), idxs.end(), [&](std::int64_t a, std::int64_t b) {
    if (srcNodesOut.at(a) != srcNodesOut.at(b)) {
      return srcNodesOut.at(a) < srcNodesOut.at(b);
    }
    return dstNodesOut.at(a) < dstNodesOut.at(b);
  });

  Vi srcNodesOutSorted(nEdgesOut);
  Vi dstNodesOutSorted(nEdgesOut);
  for (std::size_t i = 0; i < nEdgesOut; ++i) {
    srcNodesOutSorted.at(i) = srcNodesOut.at(idxs.at(i));
    dstNodesOutSorted.at(i) = dstNodesOut.at(idxs.at(i));
  }

  BOOST_REQUIRE_EQUAL(nEdgesOut, nExpectedEdges);
  BOOST_CHECK_EQUAL_COLLECTIONS(
      srcNodesOutSorted.begin(), srcNodesOutSorted.end(),
      expectedSrcNodes.begin(), expectedSrcNodes.end());
  BOOST_CHECK_EQUAL_COLLECTIONS(
      dstNodesOutSorted.begin(), dstNodesOutSorted.end(),
      expectedDstNodes.begin(), expectedDstNodes.end());
}
0099 
// A simple chain 0 -> 1 -> 2 without any junction: both edges are expected
// in the output, independent of their scores.
BOOST_AUTO_TEST_CASE(test_no_junction) {
  const Vi edgeSrc{0, 1};
  const Vi edgeDst{1, 2};
  const Vf edgeScores{0.3f, 0.9f};

  const Vi wantSrc{0, 1};
  const Vi wantDst{1, 2};

  testJunctionRemoval(edgeSrc, edgeDst, edgeScores, wantSrc, wantDst);
}
0109 
// Node 2 has two incoming edges (0 -> 2 with score 0.9, 1 -> 2 with score
// 0.3) and one outgoing edge. The expected output keeps the higher-scored
// incoming edge and the outgoing edge.
BOOST_AUTO_TEST_CASE(test_junction_2_in) {
  const Vi edgeSrc{0, 1, 2};
  const Vi edgeDst{2, 2, 3};
  const Vf edgeScores{0.9f, 0.3f, 0.9f};

  const Vi wantSrc{0, 2};
  const Vi wantDst{2, 3};

  testJunctionRemoval(edgeSrc, edgeDst, edgeScores, wantSrc, wantDst);
}
0119 
// Node 1 has one incoming edge and two outgoing edges (1 -> 2 with score
// 0.3, 1 -> 3 with score 0.9). The expected output keeps the incoming edge
// and the higher-scored outgoing edge.
BOOST_AUTO_TEST_CASE(test_junction_2_out) {
  const Vi edgeSrc{0, 1, 1};
  const Vi edgeDst{1, 2, 3};
  const Vf edgeScores{0.3f, 0.3f, 0.9f};

  const Vi wantSrc{0, 1};
  const Vi wantDst{1, 3};

  testJunctionRemoval(edgeSrc, edgeDst, edgeScores, wantSrc, wantDst);
}
0129 
// Node 2 has two incoming edges (scores 0.9 and 0.3) and two outgoing edges
// (scores 0.9 and 0.5). The expected output keeps only the best-scored edge
// on each side: 0 -> 2 and 2 -> 3.
BOOST_AUTO_TEST_CASE(test_junction_2_in_2_out) {
  const Vi edgeSrc{0, 1, 2, 2};
  const Vi edgeDst{2, 2, 3, 4};
  const Vf edgeScores{0.9f, 0.3f, 0.9f, 0.5f};

  const Vi wantSrc{0, 2};
  const Vi wantDst{2, 3};

  testJunctionRemoval(edgeSrc, edgeDst, edgeScores, wantSrc, wantDst);
}
0139 
// Node 3 has three incoming edges (scores 0.2, 0.3, 0.9) and three outgoing
// edges (scores 0.5, 0.9, 0.1). The expected output keeps only the
// best-scored edge on each side: 2 -> 3 and 3 -> 5.
BOOST_AUTO_TEST_CASE(test_junction_3_in_3_out) {
  const Vi edgeSrc{0, 1, 2, 3, 3, 3};
  const Vi edgeDst{3, 3, 3, 4, 5, 6};
  const Vf edgeScores{0.2f, 0.3f, 0.9f, 0.5f, 0.9f, 0.1f};

  const Vi wantSrc{2, 3};
  const Vi wantDst{3, 5};

  testJunctionRemoval(edgeSrc, edgeDst, edgeScores, wantSrc, wantDst);
}
0149 
// The junction node has two incoming edges: 1 -> junction (score 0.1) and
// 3 -> junction (score 0.9). The edge 0 -> 1 leading to the dropped
// low-scored edge is expected to stay in the output.
BOOST_AUTO_TEST_CASE(test_junction_leftover_edge) {
  const std::int64_t junction = 2;

  const Vi edgeSrc{0, 1, 3, junction};
  const Vi edgeDst{1, junction, junction, 4};
  const Vf edgeScores{0.9f, 0.1f, 0.9f, 0.9f};

  // Expected edges, ordered by (source, destination)
  const Vi wantSrc{0, junction, 3};
  const Vi wantDst{1, 4, junction};

  testJunctionRemoval(edgeSrc, edgeDst, edgeScores, wantSrc, wantDst);
}
0160 
// The junction node has no incoming edges and two outgoing edges (scores 0.9
// and 0.1). The expected output keeps only the higher-scored outgoing edge.
BOOST_AUTO_TEST_CASE(test_no_in_edges) {
  const std::int64_t junction = 0;

  const Vi edgeSrc{junction, junction};
  const Vi edgeDst{1, 2};
  const Vf edgeScores{0.9f, 0.1f};

  const Vi wantSrc{junction};
  const Vi wantDst{1};

  testJunctionRemoval(edgeSrc, edgeDst, edgeScores, wantSrc, wantDst);
}
0171 
// The junction node has two incoming edges (scores 0.9 and 0.1) and no
// outgoing edges. The expected output keeps only the higher-scored incoming
// edge.
BOOST_AUTO_TEST_CASE(test_no_out_edges) {
  const std::int64_t junction = 0;

  const Vi edgeSrc{1, 2};
  const Vi edgeDst{junction, junction};
  const Vf edgeScores{0.9f, 0.1f};

  const Vi wantSrc{1};
  const Vi wantDst{junction};

  testJunctionRemoval(edgeSrc, edgeDst, edgeScores, wantSrc, wantDst);
}
0182 
// Two junctions connected by an edge: the first has two incoming edges
// (scores 0.9 and 0.1), the second has two outgoing edges (scores 0.1 and
// 0.9). The expected output keeps the higher-scored edge at each junction
// plus the edge that connects them.
BOOST_AUTO_TEST_CASE(test_two_junctions) {
  const std::int64_t firstJunction = 2;
  const std::int64_t secondJunction = 3;

  const Vi edgeSrc{0, 1, firstJunction, secondJunction, secondJunction};
  const Vi edgeDst{firstJunction, firstJunction, secondJunction, 4, 5};
  const Vf edgeScores{0.9f, 0.1f, 0.9f, 0.1f, 0.9f};

  const Vi wantSrc{0, firstJunction, secondJunction};
  const Vi wantDst{firstJunction, secondJunction, 5};

  testJunctionRemoval(edgeSrc, edgeDst, edgeScores, wantSrc, wantDst);
}