File indexing completed on 2025-01-18 09:13:09
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include <boost/test/unit_test.hpp>
0010
0011 #include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
0012
0013 #include <iostream>
0014
0015 #include <torch/torch.h>
0016
0017 BOOST_AUTO_TEST_CASE(test_vector_tensor_conversion_int_2cols) {
0018 std::vector<std::int64_t> start_vec = {
0019
0020 0, 1,
0021 1, 2,
0022 2, 3,
0023 3, 4
0024
0025 };
0026
0027 auto tensor = Acts::detail::vectorToTensor2D(start_vec, 2).clone();
0028
0029 BOOST_CHECK_EQUAL(tensor.options().dtype(), torch::kInt64);
0030 BOOST_CHECK_EQUAL(tensor.sizes().size(), 2);
0031 BOOST_CHECK_EQUAL(tensor.size(0), 4);
0032 BOOST_CHECK_EQUAL(tensor.size(1), 2);
0033
0034 BOOST_CHECK_EQUAL(tensor[0][0].item<std::int64_t>(), 0);
0035 BOOST_CHECK_EQUAL(tensor[0][1].item<std::int64_t>(), 1);
0036
0037 BOOST_CHECK_EQUAL(tensor[1][0].item<std::int64_t>(), 1);
0038 BOOST_CHECK_EQUAL(tensor[1][1].item<std::int64_t>(), 2);
0039
0040 BOOST_CHECK_EQUAL(tensor[2][0].item<std::int64_t>(), 2);
0041 BOOST_CHECK_EQUAL(tensor[2][1].item<std::int64_t>(), 3);
0042
0043 BOOST_CHECK_EQUAL(tensor[3][0].item<std::int64_t>(), 3);
0044 BOOST_CHECK_EQUAL(tensor[3][1].item<std::int64_t>(), 4);
0045
0046 auto test_vec = Acts::detail::tensor2DToVector<std::int64_t>(tensor);
0047
0048 BOOST_CHECK_EQUAL(test_vec, start_vec);
0049 }
0050
0051 BOOST_AUTO_TEST_CASE(test_vector_tensor_conversion_float_3cols) {
0052 std::vector<float> start_vec = {
0053
0054 0.f, 0.f, 0.f,
0055 1.f, 1.f, 1.f,
0056 2.f, 2.f, 2.f,
0057 3.f, 3.f, 3.f
0058
0059 };
0060
0061 auto tensor = Acts::detail::vectorToTensor2D(start_vec, 3).clone();
0062
0063 BOOST_CHECK_EQUAL(tensor.options().dtype(), torch::kFloat32);
0064 BOOST_CHECK_EQUAL(tensor.sizes().size(), 2);
0065 BOOST_CHECK_EQUAL(tensor.size(0), 4);
0066 BOOST_CHECK_EQUAL(tensor.size(1), 3);
0067
0068 for (auto i : {0, 1, 2, 3}) {
0069 BOOST_CHECK_EQUAL(tensor[i][0].item<std::int64_t>(), static_cast<float>(i));
0070 BOOST_CHECK_EQUAL(tensor[i][1].item<std::int64_t>(), static_cast<float>(i));
0071 BOOST_CHECK_EQUAL(tensor[i][2].item<std::int64_t>(), static_cast<float>(i));
0072 }
0073
0074 auto test_vec = Acts::detail::tensor2DToVector<float>(tensor);
0075
0076 BOOST_CHECK_EQUAL(test_vec, start_vec);
0077 }
0078
0079 BOOST_AUTO_TEST_CASE(test_slicing) {
0080 std::vector<float> start_vec = {
0081
0082 0.f, 4.f, 0.f,
0083 1.f, 5.f, 1.f,
0084 2.f, 6.f, 2.f,
0085 3.f, 7.f, 3.f
0086
0087 };
0088
0089 auto tensor = Acts::detail::vectorToTensor2D(start_vec, 3).clone();
0090
0091 using namespace torch::indexing;
0092 tensor = tensor.index({Slice{}, Slice{0, None, 2}});
0093
0094 BOOST_CHECK_EQUAL(tensor.size(0), 4);
0095 BOOST_CHECK_EQUAL(tensor.size(1), 2);
0096
0097 const std::vector<float> ref_vec = {
0098
0099 0.f, 0.f,
0100 1.f, 1.f,
0101 2.f, 2.f,
0102 3.f, 3.f,
0103
0104 };
0105
0106 const auto test_vec = Acts::detail::tensor2DToVector<float>(tensor);
0107
0108 BOOST_CHECK_EQUAL(test_vec, ref_vec);
0109 }