Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:25:34

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include <boost/test/unit_test.hpp>
0010 
0011 #include "ActsPlugins/Gnn/detail/TensorVectorConversion.hpp"
0012 
0013 #include <iostream>
0014 
0015 #include <torch/torch.h>
0016 
0017 using namespace ActsPlugins::detail;
0018 
0019 namespace ActsTests {
0020 
0021 BOOST_AUTO_TEST_SUITE(GnnSuite)
0022 
0023 BOOST_AUTO_TEST_CASE(test_vector_tensor_conversion_int_2cols) {
0024   std::vector<std::int64_t> start_vec = {
0025       // clang-format off
0026     0, 1,
0027     1, 2,
0028     2, 3,
0029     3, 4
0030       // clang-format on
0031   };
0032 
0033   auto tensor = vectorToTensor2D(start_vec, 2).clone();
0034 
0035   BOOST_CHECK_EQUAL(tensor.options().dtype(), torch::kInt64);
0036   BOOST_CHECK_EQUAL(tensor.sizes().size(), 2);
0037   BOOST_CHECK_EQUAL(tensor.size(0), 4);
0038   BOOST_CHECK_EQUAL(tensor.size(1), 2);
0039 
0040   BOOST_CHECK_EQUAL(tensor[0][0].item<std::int64_t>(), 0);
0041   BOOST_CHECK_EQUAL(tensor[0][1].item<std::int64_t>(), 1);
0042 
0043   BOOST_CHECK_EQUAL(tensor[1][0].item<std::int64_t>(), 1);
0044   BOOST_CHECK_EQUAL(tensor[1][1].item<std::int64_t>(), 2);
0045 
0046   BOOST_CHECK_EQUAL(tensor[2][0].item<std::int64_t>(), 2);
0047   BOOST_CHECK_EQUAL(tensor[2][1].item<std::int64_t>(), 3);
0048 
0049   BOOST_CHECK_EQUAL(tensor[3][0].item<std::int64_t>(), 3);
0050   BOOST_CHECK_EQUAL(tensor[3][1].item<std::int64_t>(), 4);
0051 
0052   auto test_vec = tensor2DToVector<std::int64_t>(tensor);
0053 
0054   BOOST_CHECK_EQUAL(test_vec, start_vec);
0055 }
0056 
0057 BOOST_AUTO_TEST_CASE(test_vector_tensor_conversion_float_3cols) {
0058   std::vector<float> start_vec = {
0059       // clang-format off
0060     0.f, 0.f, 0.f,
0061     1.f, 1.f, 1.f,
0062     2.f, 2.f, 2.f,
0063     3.f, 3.f, 3.f
0064       // clang-format on
0065   };
0066 
0067   auto tensor = vectorToTensor2D(start_vec, 3).clone();
0068 
0069   BOOST_CHECK_EQUAL(tensor.options().dtype(), torch::kFloat32);
0070   BOOST_CHECK_EQUAL(tensor.sizes().size(), 2);
0071   BOOST_CHECK_EQUAL(tensor.size(0), 4);
0072   BOOST_CHECK_EQUAL(tensor.size(1), 3);
0073 
0074   for (auto i : {0, 1, 2, 3}) {
0075     BOOST_CHECK_EQUAL(tensor[i][0].item<std::int64_t>(), static_cast<float>(i));
0076     BOOST_CHECK_EQUAL(tensor[i][1].item<std::int64_t>(), static_cast<float>(i));
0077     BOOST_CHECK_EQUAL(tensor[i][2].item<std::int64_t>(), static_cast<float>(i));
0078   }
0079 
0080   auto test_vec = tensor2DToVector<float>(tensor);
0081 
0082   BOOST_CHECK_EQUAL(test_vec, start_vec);
0083 }
0084 
0085 BOOST_AUTO_TEST_CASE(test_slicing) {
0086   std::vector<float> start_vec = {
0087       // clang-format off
0088     0.f, 4.f, 0.f,
0089     1.f, 5.f, 1.f,
0090     2.f, 6.f, 2.f,
0091     3.f, 7.f, 3.f
0092       // clang-format on
0093   };
0094 
0095   auto tensor = vectorToTensor2D(start_vec, 3).clone();
0096 
0097   using namespace torch::indexing;
0098   tensor = tensor.index({Slice{}, Slice{0, None, 2}});
0099 
0100   BOOST_CHECK_EQUAL(tensor.size(0), 4);
0101   BOOST_CHECK_EQUAL(tensor.size(1), 2);
0102 
0103   const std::vector<float> ref_vec = {
0104       // clang-format off
0105     0.f, 0.f,
0106     1.f, 1.f,
0107     2.f, 2.f,
0108     3.f, 3.f,
0109       // clang-format on
0110   };
0111 
0112   const auto test_vec = tensor2DToVector<float>(tensor);
0113 
0114   BOOST_CHECK_EQUAL(test_vec, ref_vec);
0115 }
0116 
0117 BOOST_AUTO_TEST_SUITE_END()
0118 
0119 }  // namespace ActsTests