File indexing completed on 2025-12-16 09:25:34
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include <boost/test/unit_test.hpp>
0010
0011 #include "ActsPlugins/Gnn/detail/TensorVectorConversion.hpp"
0012
0013 #include <iostream>
0014
0015 #include <torch/torch.h>
0016
0017 using namespace ActsPlugins::detail;
0018
0019 namespace ActsTests {
0020
0021 BOOST_AUTO_TEST_SUITE(GnnSuite)
0022
0023 BOOST_AUTO_TEST_CASE(test_vector_tensor_conversion_int_2cols) {
0024 std::vector<std::int64_t> start_vec = {
0025
0026 0, 1,
0027 1, 2,
0028 2, 3,
0029 3, 4
0030
0031 };
0032
0033 auto tensor = vectorToTensor2D(start_vec, 2).clone();
0034
0035 BOOST_CHECK_EQUAL(tensor.options().dtype(), torch::kInt64);
0036 BOOST_CHECK_EQUAL(tensor.sizes().size(), 2);
0037 BOOST_CHECK_EQUAL(tensor.size(0), 4);
0038 BOOST_CHECK_EQUAL(tensor.size(1), 2);
0039
0040 BOOST_CHECK_EQUAL(tensor[0][0].item<std::int64_t>(), 0);
0041 BOOST_CHECK_EQUAL(tensor[0][1].item<std::int64_t>(), 1);
0042
0043 BOOST_CHECK_EQUAL(tensor[1][0].item<std::int64_t>(), 1);
0044 BOOST_CHECK_EQUAL(tensor[1][1].item<std::int64_t>(), 2);
0045
0046 BOOST_CHECK_EQUAL(tensor[2][0].item<std::int64_t>(), 2);
0047 BOOST_CHECK_EQUAL(tensor[2][1].item<std::int64_t>(), 3);
0048
0049 BOOST_CHECK_EQUAL(tensor[3][0].item<std::int64_t>(), 3);
0050 BOOST_CHECK_EQUAL(tensor[3][1].item<std::int64_t>(), 4);
0051
0052 auto test_vec = tensor2DToVector<std::int64_t>(tensor);
0053
0054 BOOST_CHECK_EQUAL(test_vec, start_vec);
0055 }
0056
0057 BOOST_AUTO_TEST_CASE(test_vector_tensor_conversion_float_3cols) {
0058 std::vector<float> start_vec = {
0059
0060 0.f, 0.f, 0.f,
0061 1.f, 1.f, 1.f,
0062 2.f, 2.f, 2.f,
0063 3.f, 3.f, 3.f
0064
0065 };
0066
0067 auto tensor = vectorToTensor2D(start_vec, 3).clone();
0068
0069 BOOST_CHECK_EQUAL(tensor.options().dtype(), torch::kFloat32);
0070 BOOST_CHECK_EQUAL(tensor.sizes().size(), 2);
0071 BOOST_CHECK_EQUAL(tensor.size(0), 4);
0072 BOOST_CHECK_EQUAL(tensor.size(1), 3);
0073
0074 for (auto i : {0, 1, 2, 3}) {
0075 BOOST_CHECK_EQUAL(tensor[i][0].item<std::int64_t>(), static_cast<float>(i));
0076 BOOST_CHECK_EQUAL(tensor[i][1].item<std::int64_t>(), static_cast<float>(i));
0077 BOOST_CHECK_EQUAL(tensor[i][2].item<std::int64_t>(), static_cast<float>(i));
0078 }
0079
0080 auto test_vec = tensor2DToVector<float>(tensor);
0081
0082 BOOST_CHECK_EQUAL(test_vec, start_vec);
0083 }
0084
0085 BOOST_AUTO_TEST_CASE(test_slicing) {
0086 std::vector<float> start_vec = {
0087
0088 0.f, 4.f, 0.f,
0089 1.f, 5.f, 1.f,
0090 2.f, 6.f, 2.f,
0091 3.f, 7.f, 3.f
0092
0093 };
0094
0095 auto tensor = vectorToTensor2D(start_vec, 3).clone();
0096
0097 using namespace torch::indexing;
0098 tensor = tensor.index({Slice{}, Slice{0, None, 2}});
0099
0100 BOOST_CHECK_EQUAL(tensor.size(0), 4);
0101 BOOST_CHECK_EQUAL(tensor.size(1), 2);
0102
0103 const std::vector<float> ref_vec = {
0104
0105 0.f, 0.f,
0106 1.f, 1.f,
0107 2.f, 2.f,
0108 3.f, 3.f,
0109
0110 };
0111
0112 const auto test_vec = tensor2DToVector<float>(tensor);
0113
0114 BOOST_CHECK_EQUAL(test_vec, ref_vec);
0115 }
0116
0117 BOOST_AUTO_TEST_SUITE_END()
0118
0119 }