Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:24:14

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 // Test include(s).
0012 #include "detray/test/device/execute_host_test.hpp"
0013 #include "detray/test/device/matrix_fixture.hpp"
0014 #include "detray/test/device/sycl/execute_sycl_test.hpp"
0015 #include "detray/test/device/sycl/sycl_test_fixture.hpp"
0016 #include "detray/test/device/transform_fixture.hpp"
0017 #include "detray/test/device/vector_fixture.hpp"
0018 
0019 // GoogleTest include(s).
0020 #include <gtest/gtest.h>
0021 
0022 // SYCL include(s).
0023 #include <sycl/sycl.hpp>
0024 
0025 namespace detray::test::sycl {
0026 
0027 /// Test case class, to be specialised for the different plugins
0028 template <detray::concepts::algebra A>
0029 class sycl_vector_test : public sycl_test_fixture<vector_fixture<A>> {};
0030 
0031 /// Test case class, to be specialised for the different plugins
0032 template <detray::concepts::algebra A>
0033 class sycl_matrix_test : public sycl_test_fixture<matrix_fixture<A>> {};
0034 
0035 /// Test case class, to be specialised for the different plugins
0036 template <detray::concepts::algebra A>
0037 class sycl_transform_test : public sycl_test_fixture<transform_fixture<A>> {};
0038 
0039 TYPED_TEST_SUITE_P(sycl_vector_test);
0040 TYPED_TEST_SUITE_P(sycl_matrix_test);
0041 TYPED_TEST_SUITE_P(sycl_transform_test);
0042 
0043 /// Test for some basic 2D "vector operations"
0044 TYPED_TEST_P(sycl_vector_test, vector_2d_ops) {
0045   // Don't run the test at double precision, if the SYCL device doesn't
0046   // support it.
0047   if ((typeid(dscalar<TypeParam>) == typeid(double)) &&
0048       (this->m_queue.get_device().has(::sycl::aspect::fp64) == false)) {
0049     GTEST_SKIP();
0050   }
0051 
0052   // Run the test on the host, and on the/a device.
0053   execute_host_test<vector_2d_ops_functor<TypeParam>>(
0054       this->m_p1->size(), vecmem::get_data(*(this->m_p1)),
0055       vecmem::get_data(*(this->m_p2)),
0056       vecmem::get_data(*(this->m_output_host)));
0057 
0058   execute_sycl_test<vector_2d_ops_functor<TypeParam>>(
0059       this->m_queue, this->m_p1->size(), vecmem::get_data(*(this->m_p1)),
0060       vecmem::get_data(*(this->m_p2)),
0061       vecmem::get_data(*(this->m_output_device)));
0062 
0063   // Compare the outputs.
0064   this->compareOutputs();
0065 }
0066 
0067 /// Test for some basic 3D "vector operations"
0068 TYPED_TEST_P(sycl_vector_test, vector_3d_ops) {
0069   // Don't run the test at double precision, if the SYCL device doesn't
0070   // support it.
0071   if ((typeid(dscalar<TypeParam>) == typeid(double)) &&
0072       (this->m_queue.get_device().has(::sycl::aspect::fp64) == false)) {
0073     GTEST_SKIP();
0074   }
0075 
0076   // This test is just not numerically stable at float precision in optimized
0077   // mode on some backends. :-( (Cough... HIP... cough...)
0078 #ifdef NDEBUG
0079   if (typeid(dscalar<TypeParam>) == typeid(float)) {
0080     GTEST_SKIP();
0081   }
0082 #endif  // NDEBUG
0083 
0084   // Run the test on the host, and on the/a device.
0085   execute_host_test<vector_3d_ops_functor<TypeParam>>(
0086       this->m_v1->size(), vecmem::get_data(*(this->m_v1)),
0087       vecmem::get_data(*(this->m_v2)),
0088       vecmem::get_data(*(this->m_output_host)));
0089 
0090   execute_sycl_test<vector_3d_ops_functor<TypeParam>>(
0091       this->m_queue, this->m_v1->size(), vecmem::get_data(*(this->m_v1)),
0092       vecmem::get_data(*(this->m_v2)),
0093       vecmem::get_data(*(this->m_output_device)));
0094 
0095   // Compare the outputs.
0096   this->compareOutputs();
0097 }
0098 /// Test for handling matrices
0099 TYPED_TEST_P(sycl_matrix_test, matrix64_ops) {
0100   // Don't run the test at double precision, if the SYCL device doesn't
0101   // support it.
0102   if ((typeid(dscalar<TypeParam>) == typeid(double)) &&
0103       (this->m_queue.get_device().has(::sycl::aspect::fp64) == false)) {
0104     GTEST_SKIP();
0105   }
0106 
0107   // Run the test on the host, and on the/a device.
0108   execute_host_test<matrix64_ops_functor<TypeParam>>(
0109       this->m_m1->size(), vecmem::get_data(*(this->m_m1)),
0110       vecmem::get_data(*(this->m_output_host)));
0111 
0112   execute_sycl_test<matrix64_ops_functor<TypeParam>>(
0113       this->m_queue, this->m_m1->size(), vecmem::get_data(*(this->m_m1)),
0114       vecmem::get_data(*(this->m_output_device)));
0115 
0116   // Compare the outputs.
0117   this->compareOutputs();
0118 }
0119 
0120 /// Test for handling matrices
0121 TYPED_TEST_P(sycl_matrix_test, matrix22_ops) {
0122   // Don't run the test at double precision, if the SYCL device doesn't
0123   // support it.
0124   if ((typeid(dscalar<TypeParam>) == typeid(double)) &&
0125       (this->m_queue.get_device().has(::sycl::aspect::fp64) == false)) {
0126     GTEST_SKIP();
0127   }
0128 
0129   // Run the test on the host, and on the/a device.
0130   execute_host_test<matrix22_ops_functor<TypeParam>>(
0131       this->m_m2->size(), vecmem::get_data(*(this->m_m2)),
0132       vecmem::get_data(*(this->m_output_host)));
0133 
0134   execute_sycl_test<matrix22_ops_functor<TypeParam>>(
0135       this->m_queue, this->m_m2->size(), vecmem::get_data(*(this->m_m2)),
0136       vecmem::get_data(*(this->m_output_device)));
0137 
0138   // Compare the outputs.
0139   this->compareOutputs();
0140 }
0141 
0142 /// Test for some operations with @c transform3
0143 TYPED_TEST_P(sycl_transform_test, transform3D) {
0144   // Don't run the test at double precision, if the SYCL device doesn't
0145   // support it.
0146   if ((typeid(dscalar<TypeParam>) == typeid(double)) &&
0147       (this->m_queue.get_device().has(::sycl::aspect::fp64) == false)) {
0148     GTEST_SKIP();
0149   }
0150 
0151   // Run the test on the host, and on the/a device.
0152   execute_host_test<transform3_ops_functor<TypeParam>>(
0153       this->m_t1->size(), vecmem::get_data(*(this->m_t1)),
0154       vecmem::get_data(*(this->m_t2)), vecmem::get_data(*(this->m_t3)),
0155       vecmem::get_data(*(this->m_v1)), vecmem::get_data(*(this->m_v2)),
0156       vecmem::get_data(*(this->m_output_host)));
0157 
0158   execute_sycl_test<transform3_ops_functor<TypeParam>>(
0159       this->m_queue, this->m_t1->size(), vecmem::get_data(*(this->m_t1)),
0160       vecmem::get_data(*(this->m_t2)), vecmem::get_data(*(this->m_t3)),
0161       vecmem::get_data(*(this->m_v1)), vecmem::get_data(*(this->m_v2)),
0162       vecmem::get_data(*(this->m_output_device)));
0163 
0164   // Compare the outputs.
0165   this->compareOutputs();
0166 }
0167 
0168 }  // namespace detray::test::sycl