File indexing completed on 2026-05-27 07:24:25
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #include "navigator_cuda_kernel.hpp"
0011
0012
0013 #include "detray/test/common/build_toy_detector.hpp"
0014
0015
0016 #include <vecmem/memory/cuda/device_memory_resource.hpp>
0017 #include <vecmem/memory/cuda/managed_memory_resource.hpp>
0018 #include <vecmem/utils/cuda/copy.hpp>
0019
0020
0021 #include <gtest/gtest.h>
0022
0023 using namespace detray;
0024
0025 TEST(navigator_cuda, navigator) {
0026
0027 vecmem::cuda::copy copy;
0028
0029
0030 vecmem::cuda::managed_memory_resource mng_mr;
0031 vecmem::cuda::device_memory_resource dev_mr;
0032
0033
0034 auto [det, names] = build_toy_detector<test_algebra>(mng_mr);
0035
0036 propagation::config prop_cfg{};
0037 navigation::config& nav_cfg = prop_cfg.navigation;
0038 stepping::config& step_cfg = prop_cfg.stepping;
0039
0040
0041 navigator_host_t nav;
0042 nav_cfg.search_window = {3u, 3u};
0043
0044
0045 vecmem::vector<free_track_parameters<test_algebra>> tracks_host(&mng_mr);
0046 vecmem::vector<free_track_parameters<test_algebra>> tracks_device(&mng_mr);
0047
0048
0049 const scalar p_mag{10.f * unit<scalar>::GeV};
0050
0051
0052 for (auto track :
0053 uniform_track_generator<free_track_parameters<test_algebra>>(
0054 phi_steps, theta_steps, p_mag)) {
0055 tracks_host.push_back(track);
0056 tracks_device.push_back(track);
0057 }
0058
0059
0060 vecmem::jagged_vector<dindex> volume_records_host(theta_steps * phi_steps,
0061 &mng_mr);
0062 vecmem::jagged_vector<point3> position_records_host(theta_steps * phi_steps,
0063 &mng_mr);
0064
0065 for (unsigned int i = 0u; i < theta_steps * phi_steps; i++) {
0066 auto& track = tracks_host[i];
0067 stepper_t stepper;
0068
0069
0070 prop_state<navigator_host_t::state> propagation{
0071 stepper_t::state{track}, navigator_host_t::state(det)};
0072
0073 navigator_host_t::state& navigation = propagation.navigation();
0074 stepper_t::state& stepping = propagation.stepping();
0075 const auto& ctx = propagation.context();
0076
0077
0078 nav.init(stepping(), navigation, nav_cfg, ctx);
0079 bool heartbeat = navigation.is_alive();
0080 bool do_reset{true};
0081
0082 while (heartbeat) {
0083 heartbeat =
0084 heartbeat && stepper.step(navigation(), stepping, step_cfg, do_reset);
0085
0086 navigation.set_high_trust();
0087
0088 do_reset = nav.update(stepping(), navigation, nav_cfg, prop_cfg.context);
0089 do_reset = do_reset || navigation.is_on_surface();
0090 heartbeat = heartbeat && navigation.is_alive();
0091
0092
0093 volume_records_host[i].push_back(navigation.volume());
0094 position_records_host[i].push_back(stepping().pos());
0095 }
0096 }
0097
0098
0099 vecmem::jagged_vector<dindex> volume_records_device(&mng_mr);
0100 vecmem::jagged_vector<point3> position_records_device(&mng_mr);
0101
0102
0103 std::vector<std::size_t> capacities;
0104
0105 for (unsigned int i = 0u; i < theta_steps * phi_steps; i++) {
0106 capacities.push_back(volume_records_host[i].size());
0107 }
0108
0109 vecmem::data::jagged_vector_buffer<dindex> volume_records_buffer(
0110 capacities, dev_mr, &mng_mr, vecmem::data::buffer_type::resizable);
0111 copy.setup(volume_records_buffer)->wait();
0112
0113 vecmem::data::jagged_vector_buffer<point3> position_records_buffer(
0114 capacities, dev_mr, &mng_mr, vecmem::data::buffer_type::resizable);
0115 copy.setup(position_records_buffer)->wait();
0116
0117
0118 auto det_data = detray::get_data(det);
0119
0120
0121 auto tracks_data = vecmem::get_data(tracks_device);
0122
0123
0124 navigator_test(det_data, prop_cfg, tracks_data, volume_records_buffer,
0125 position_records_buffer);
0126
0127
0128 copy(volume_records_buffer, volume_records_device)->wait();
0129 copy(position_records_buffer, position_records_device)->wait();
0130
0131 for (unsigned int i = 0u; i < volume_records_host.size(); i++) {
0132 EXPECT_EQ(volume_records_host[i].size(), volume_records_device[i].size());
0133
0134 for (unsigned int j = 0u; j < volume_records_host[i].size(); j++) {
0135 EXPECT_EQ(volume_records_host[i][j], volume_records_device[i][j]);
0136
0137 auto& pos_host = position_records_host[i][j];
0138 auto& pos_device = position_records_device[i][j];
0139
0140 EXPECT_NEAR(pos_host[0], pos_device[0], pos_diff_tolerance);
0141 EXPECT_NEAR(pos_host[1], pos_device[1], pos_diff_tolerance);
0142 EXPECT_NEAR(pos_host[2], pos_device[2], pos_diff_tolerance);
0143 }
0144 }
0145 }