Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:13:09

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include <boost/test/unit_test.hpp>
0010 
0011 #include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
0012 
0013 #include <iostream>
0014 
0015 #include <torch/torch.h>
0016 
0017 BOOST_AUTO_TEST_CASE(test_vector_tensor_conversion_int_2cols) {
0018   std::vector<std::int64_t> start_vec = {
0019       // clang-format off
0020     0, 1,
0021     1, 2,
0022     2, 3,
0023     3, 4
0024       // clang-format on
0025   };
0026 
0027   auto tensor = Acts::detail::vectorToTensor2D(start_vec, 2).clone();
0028 
0029   BOOST_CHECK_EQUAL(tensor.options().dtype(), torch::kInt64);
0030   BOOST_CHECK_EQUAL(tensor.sizes().size(), 2);
0031   BOOST_CHECK_EQUAL(tensor.size(0), 4);
0032   BOOST_CHECK_EQUAL(tensor.size(1), 2);
0033 
0034   BOOST_CHECK_EQUAL(tensor[0][0].item<std::int64_t>(), 0);
0035   BOOST_CHECK_EQUAL(tensor[0][1].item<std::int64_t>(), 1);
0036 
0037   BOOST_CHECK_EQUAL(tensor[1][0].item<std::int64_t>(), 1);
0038   BOOST_CHECK_EQUAL(tensor[1][1].item<std::int64_t>(), 2);
0039 
0040   BOOST_CHECK_EQUAL(tensor[2][0].item<std::int64_t>(), 2);
0041   BOOST_CHECK_EQUAL(tensor[2][1].item<std::int64_t>(), 3);
0042 
0043   BOOST_CHECK_EQUAL(tensor[3][0].item<std::int64_t>(), 3);
0044   BOOST_CHECK_EQUAL(tensor[3][1].item<std::int64_t>(), 4);
0045 
0046   auto test_vec = Acts::detail::tensor2DToVector<std::int64_t>(tensor);
0047 
0048   BOOST_CHECK_EQUAL(test_vec, start_vec);
0049 }
0050 
0051 BOOST_AUTO_TEST_CASE(test_vector_tensor_conversion_float_3cols) {
0052   std::vector<float> start_vec = {
0053       // clang-format off
0054     0.f, 0.f, 0.f,
0055     1.f, 1.f, 1.f,
0056     2.f, 2.f, 2.f,
0057     3.f, 3.f, 3.f
0058       // clang-format on
0059   };
0060 
0061   auto tensor = Acts::detail::vectorToTensor2D(start_vec, 3).clone();
0062 
0063   BOOST_CHECK_EQUAL(tensor.options().dtype(), torch::kFloat32);
0064   BOOST_CHECK_EQUAL(tensor.sizes().size(), 2);
0065   BOOST_CHECK_EQUAL(tensor.size(0), 4);
0066   BOOST_CHECK_EQUAL(tensor.size(1), 3);
0067 
0068   for (auto i : {0, 1, 2, 3}) {
0069     BOOST_CHECK_EQUAL(tensor[i][0].item<std::int64_t>(), static_cast<float>(i));
0070     BOOST_CHECK_EQUAL(tensor[i][1].item<std::int64_t>(), static_cast<float>(i));
0071     BOOST_CHECK_EQUAL(tensor[i][2].item<std::int64_t>(), static_cast<float>(i));
0072   }
0073 
0074   auto test_vec = Acts::detail::tensor2DToVector<float>(tensor);
0075 
0076   BOOST_CHECK_EQUAL(test_vec, start_vec);
0077 }
0078 
0079 BOOST_AUTO_TEST_CASE(test_slicing) {
0080   std::vector<float> start_vec = {
0081       // clang-format off
0082     0.f, 4.f, 0.f,
0083     1.f, 5.f, 1.f,
0084     2.f, 6.f, 2.f,
0085     3.f, 7.f, 3.f
0086       // clang-format on
0087   };
0088 
0089   auto tensor = Acts::detail::vectorToTensor2D(start_vec, 3).clone();
0090 
0091   using namespace torch::indexing;
0092   tensor = tensor.index({Slice{}, Slice{0, None, 2}});
0093 
0094   BOOST_CHECK_EQUAL(tensor.size(0), 4);
0095   BOOST_CHECK_EQUAL(tensor.size(1), 2);
0096 
0097   const std::vector<float> ref_vec = {
0098       // clang-format off
0099     0.f, 0.f,
0100     1.f, 1.f,
0101     2.f, 2.f,
0102     3.f, 3.f,
0103       // clang-format on
0104   };
0105 
0106   const auto test_vec = Acts::detail::tensor2DToVector<float>(tensor);
0107 
0108   BOOST_CHECK_EQUAL(test_vec, ref_vec);
0109 }