File indexing completed on 2026-05-27 07:24:13
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011
#include "detray/definitions/detail/cuda_definitions.hpp"

#include <algorithm>
#include <cstddef>
#include <utility>
0016
namespace detray::test::cuda {

/// Generic test kernel: every thread whose global index @c i is below
/// @p array_sizes default-constructs a @c functor_t and calls
/// @c functor_t()(i, args...).
///
/// Only the x dimensions of the grid and block are used, so a
/// one-dimensional launch is expected. Threads with an index of at least
/// @p array_sizes return without doing any work, which allows the grid to
/// over-cover the index range.
///
/// @tparam functor_t default-constructible callable, invocable with
///                   (std::size_t, Args...)
/// @param array_sizes number of indices to process
/// @param args        arguments handed to every functor invocation
template <class functor_t, typename... Args>
__global__ void cuda_test_kernel(std::size_t array_sizes, Args... args) {

    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise
    // evaluated in 32-bit unsigned arithmetic and wraps for index ranges
    // beyond 2^32.
    const std::size_t i =
        static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (i >= array_sizes) {
        return;
    }

    // Args are deduced as by-value types, so this hands the (thread-local)
    // kernel arguments to the functor as rvalues.
    functor_t()(i, std::forward<Args>(args)...);
}

/// Runs @c cuda_test_kernel<functor_t> over the index range
/// [0, array_sizes) and waits for the device to finish.
///
/// Uses at most 256 threads per block (fewer if @p array_sizes is smaller)
/// and as many blocks as needed to cover the range. Launch errors
/// (cudaGetLastError) and errors raised during execution
/// (cudaDeviceSynchronize) are passed to DETRAY_CUDA_ERROR_CHECK.
///
/// An empty range (@p array_sizes == 0) launches nothing and returns
/// immediately.
///
/// @tparam functor_t default-constructible callable, invocable with
///                   (std::size_t, Args...) in device code
/// @param array_sizes number of indices to process
/// @param args        arguments forwarded to the kernel
template <class functor_t, class... Args>
void execute_cuda_test(std::size_t array_sizes, Args... args) {

    // An empty range would yield a zero block size: that is an invalid
    // launch configuration and would divide by zero below.
    if (array_sizes == 0u) {
        return;
    }

    // Compute the configuration in std::size_t instead of narrowing to int,
    // which turned large sizes into negative block sizes/counts.
    constexpr std::size_t max_threads_per_block{256u};
    const std::size_t n_threads_per_block{
        std::min(max_threads_per_block, array_sizes)};
    // Ceil-division written so that it cannot overflow (array_sizes > 0).
    const std::size_t n_blocks{1u + (array_sizes - 1u) / n_threads_per_block};

    // dim3 components are unsigned int. Saturate rather than truncate, so an
    // oversized request is rejected by CUDA's launch validation (grid x
    // dimension is limited to 2^31 - 1 blocks) instead of silently running
    // a smaller grid.
    const unsigned int grid_size{static_cast<unsigned int>(
        std::min<std::size_t>(n_blocks, 0xFFFFFFFFu))};
    const unsigned int block_size{
        static_cast<unsigned int>(n_threads_per_block)};

    cuda_test_kernel<functor_t><<<grid_size, block_size>>>(
        array_sizes, std::forward<Args>(args)...);

    // Launch-configuration errors are visible right after the launch;
    // faults inside the kernel only surface at the next synchronization.
    DETRAY_CUDA_ERROR_CHECK(cudaGetLastError());
    DETRAY_CUDA_ERROR_CHECK(cudaDeviceSynchronize());
}

}  // namespace detray::test::cuda