File indexing completed on 2026-05-27 07:24:13
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "detray/definitions/detail/cuda_definitions.hpp"
0010 #include "detray/propagator/actors.hpp"
0011 #include "detray/propagator/line_stepper.hpp"
0012 #include "material_validation.hpp"
0013
0014 namespace detray::cuda {
0015
/// Propagate one track per thread through the detector and record the
/// material it traverses.
///
/// Launch layout: 1D grid of 1D blocks. The thread with flat global index
/// @c i handles @c tracks[i]. Threads whose index is past the end of the track
/// collection return immediately, so the grid may be larger than the number
/// of tracks.
///
/// @param det_data detector view, turned into a device detector in-kernel
/// @param cfg propagation configuration (passed to the kernel by value)
/// @param tracks_view input free track parameters, one per thread
/// @param mat_records_view output: one material record per track; must have
///        the same size as @p tracks_view (checked with a device assert)
/// @param mat_steps_view output: jagged collection, inner vector @c i holds
///        the material steps recorded for track @c i
template <typename detector_t>
__global__ void material_validation_kernel(
    typename detector_t::view_type det_data, const propagation::config cfg,
    vecmem::data::vector_view<
        free_track_parameters<typename detector_t::algebra_type>>
        tracks_view,
    vecmem::data::vector_view<
        material_validator::material_record<typename detector_t::scalar_type>>
        mat_records_view,
    vecmem::data::jagged_vector_view<
        material_validator::material_params<typename detector_t::scalar_type>>
        mat_steps_view) {

    // Device-side detector: same metadata, device container types
    using detector_device_t =
        detector<typename detector_t::metadata, device_container_types>;
    using algebra_t = typename detector_device_t::algebra_type;
    using scalar_t = dscalar<algebra_t>;

    // Straight-line stepping with the caching navigator
    using stepper_t = line_stepper<algebra_t>;
    using navigator_t = caching_navigator<detector_device_t>;

    // Actor chain: a path-limit aborter, followed by the parameter updater,
    // which is parameterized with the material interactor and the tracer
    using material_tracer_t =
        material_validator::material_tracer<scalar_t, vecmem::device_vector>;
    using pathlimit_aborter_t = actor::pathlimit_aborter<scalar_t>;
    using actor_chain_t = actor_chain<
        pathlimit_aborter_t,
        actor::parameter_updater<algebra_t,
                                 actor::pointwise_material_interactor<algebra_t>,
                                 material_tracer_t>>;
    using propagator_t = propagator<stepper_t, navigator_t, actor_chain_t>;

    vecmem::device_vector<free_track_parameters<algebra_t>> tracks(tracks_view);

    // Thread -> track mapping. The CUDA built-ins are unsigned, so keep the
    // index unsigned to avoid a signed/unsigned comparison with size().
    const unsigned int trk_id = threadIdx.x + blockIdx.x * blockDim.x;
    if (trk_id >= tracks.size()) {
        // Tail thread of the last block: nothing to do
        return;
    }

    detector_device_t det(det_data);

    vecmem::device_vector<typename material_tracer_t::material_record_type>
        mat_records(mat_records_view);
    vecmem::jagged_device_vector<typename material_tracer_t::material_params_type>
        mat_steps(mat_steps_view);

    // One output record per input track is required; check the precondition
    // before propagating instead of after all the work is done
    assert(mat_records.size() == tracks.size());

    propagator_t p{cfg};

    // Actor states, listed in the same order as the actors in actor_chain_t
    typename pathlimit_aborter_t::state aborter_state{cfg.stepping.path_limit};
    actor::parameter_updater_state<algebra_t> updater_state{cfg};
    typename actor::pointwise_material_interactor<algebra_t>::state
        interactor_state{};
    // The tracer state is bound to this track's inner vector of mat_steps
    typename material_tracer_t::state mat_tracer_state{mat_steps.at(trk_id)};

    auto actor_states = ::detray::tie(aborter_state, updater_state,
                                      interactor_state, mat_tracer_state);

    // Propagation state for this thread's track, using a default-constructed
    // navigator state view
    typename navigator_t::state::view_type nav_view{};
    typename propagator_t::state propagation(tracks[trk_id], det, nav_view);

    p.propagate(propagation, actor_states);

    // Store the material accumulated along this track
    mat_records.at(trk_id) = mat_tracer_state.get_material_record();
}
0082
0083
/// Host-side launcher for @c material_validation_kernel.
///
/// Starts one thread per track in blocks of two warps and sizes the grid with
/// a ceiling division so that exactly enough blocks cover all tracks. For an
/// empty track collection no kernel is launched, because a zero-sized grid is
/// an invalid launch configuration. In every case it afterwards checks for
/// launch errors and synchronizes the device, so the output views are
/// complete (or an error was reported) when this function returns.
///
/// @param det_view detector view
/// @param cfg propagation configuration (copied into the kernel)
/// @param tracks_view input tracks
/// @param mat_records_view output material records, one per track
/// @param mat_steps_view output per-track material steps
template <typename detector_t>
void material_validation_device(
    typename detector_t::view_type det_view, const propagation::config &cfg,
    vecmem::data::vector_view<
        free_track_parameters<typename detector_t::algebra_type>> &tracks_view,
    vecmem::data::vector_view<
        material_validator::material_record<typename detector_t::scalar_type>>
        &mat_records_view,
    vecmem::data::jagged_vector_view<
        material_validator::material_params<typename detector_t::scalar_type>>
        &mat_steps_view) {

    // Block size: two warps per block
    constexpr unsigned int threads_per_block = 2u * WARP_SIZE;

    const auto n_tracks = tracks_view.size();

    if (n_tracks > 0u) {
        // Ceiling division without the overflow risk of (n + d - 1) / d.
        // The previous 'n / d + 1' launched a surplus block whenever n was a
        // multiple of the block size.
        const auto n_blocks = static_cast<unsigned int>(
            n_tracks / threads_per_block +
            ((n_tracks % threads_per_block) != 0u ? 1u : 0u));

        material_validation_kernel<detector_t>
            <<<n_blocks, threads_per_block>>>(det_view, cfg, tracks_view,
                                              mat_records_view, mat_steps_view);
    }

    // Catch launch-configuration errors, then wait for the kernel so that
    // asynchronous execution errors are reported here as well
    DETRAY_CUDA_ERROR_CHECK(cudaGetLastError());
    DETRAY_CUDA_ERROR_CHECK(cudaDeviceSynchronize());
}
0106
0107
// Explicit instantiation of the host launcher material_validation_device for
// detector<MD>, where MD is a detector metadata type. Instantiating the
// launcher also instantiates material_validation_kernel<detector<MD>>, which
// it launches. The macro expands to a complete declaration (it supplies the
// trailing semicolon itself), so invocations take no ';'.
#define DECLARE_MATERIAL_VALIDATION(MD)                                      \
                                                                             \
    template void material_validation_device<detector<MD>>(                  \
        typename detector<MD>::view_type,                                    \
        const propagation::config &,                                         \
        vecmem::data::vector_view<free_track_parameters<                     \
            typename detector<MD>::algebra_type>> &,                         \
        vecmem::data::vector_view<material_validator::material_record<       \
            typename detector<MD>::scalar_type>> &,                          \
        vecmem::data::jagged_vector_view<                                    \
            material_validator::material_params<                             \
                typename detector<MD>::scalar_type>> &);

// Detector metadata types for which the launcher is compiled here
DECLARE_MATERIAL_VALIDATION(test::default_metadata)
DECLARE_MATERIAL_VALIDATION(test::toy_metadata)
DECLARE_MATERIAL_VALIDATION(test::default_telescope_metadata)
DECLARE_MATERIAL_VALIDATION(test::wire_chamber_metadata)
0123
0124 }