File indexing completed on 2026-05-27 07:24:14
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "detray/definitions/detail/cuda_definitions.hpp"
0010 #include "navigation_validation.hpp"
0011
0012 namespace detray::cuda {
0013
/// Kernel that re-navigates every truth track on the device and records the
/// surfaces and material the propagation encounters.
///
/// Work distribution: the global thread index is built from the x components
/// of the block/thread indices only, and each thread handles the track with
/// that index. Threads whose index is not smaller than the number of truth
/// traces return without doing any work.
///
/// @tparam bfield_t magnetic field view type; if it is
///         @c navigation_validator::empty_bfield a line stepper is used,
///         otherwise a Runge-Kutta stepper on the given field
/// @tparam detector_t host-side detector type (only its metadata, scalar and
///         view types are used here)
/// @tparam intersection_record_t type of the truth trace entries
///
/// @param det_data view of the detector data
/// @param cfg propagation configuration (also provides the path limit)
/// @param ptc_hypo particle hypothesis, updated per track before propagating
/// @param field_data magnetic field data (unused without a field)
/// @param truth_intersection_traces_view one truth trace per track; the first
///        entry of each trace provides the initial track parameters
/// @param recorded_intersections_view per-track output of the object tracer;
///        every inner vector has to be empty on entry
/// @param mat_records_view output: accumulated material record per track
/// @param mat_steps_view per-track storage handed to the material tracer
template <typename bfield_t, typename detector_t,
          typename intersection_record_t>
__global__ void navigation_validation_kernel(
    typename detector_t::view_type det_data, const propagation::config cfg,
    pdg_particle<typename detector_t::scalar_type> ptc_hypo,
    bfield_t field_data,
    vecmem::data::jagged_vector_view<const intersection_record_t>
        truth_intersection_traces_view,
    vecmem::data::jagged_vector_view<navigation::detail::candidate_record<
        typename intersection_record_t::intersection_type>>
        recorded_intersections_view,
    vecmem::data::vector_view<
        material_validator::material_record<typename detector_t::scalar_type>>
        mat_records_view,
    vecmem::data::jagged_vector_view<
        material_validator::material_params<typename detector_t::scalar_type>>
        mat_steps_view) {

    // Device-side detector type, rebuilt from the host detector's metadata
    using detector_device_t =
        detector<typename detector_t::metadata, device_container_types>;
    using algebra_t = typename detector_device_t::algebra_type;
    using scalar_t = dscalar<algebra_t>;

    static_assert(std::is_same_v<typename detector_t::view_type,
                                 typename detector_device_t::view_type>,
                  "Host and device detector view types do not match");

    // Stepper selection: straight line without a field, Runge-Kutta otherwise
    using hom_bfield_view_t = typename bfield::const_field_t<scalar_t>::view_t;
    using rk_stepper_t = rk_stepper<hom_bfield_view_t, algebra_t>;
    using line_stepper_t = line_stepper<algebra_t>;

    static constexpr auto is_no_bfield{
        std::is_same_v<bfield_t, navigation_validator::empty_bfield>};
    using stepper_t =
        std::conditional_t<is_no_bfield, line_stepper_t, rk_stepper_t>;

    // Navigator with an object tracer that records into a device vector for
    // the 'on object' and 'on portal' navigation states
    using intersection_t = typename intersection_record_t::intersection_type;
    using object_tracer_t =
        navigation::object_tracer<intersection_t, vecmem::device_vector,
                                  navigation::status::e_on_object,
                                  navigation::status::e_on_portal>;

    using navigator_t =
        caching_navigator<detector_device_t, navigation::default_cache_size,
                          object_tracer_t, intersection_t>;

    // Actors: abort on the path limit and trace the material
    using material_tracer_t =
        material_validator::material_tracer<scalar_t, vecmem::device_vector>;
    using pathlimit_aborter_t = actor::pathlimit_aborter<scalar_t>;
    using actor_chain_t = actor_chain<pathlimit_aborter_t, material_tracer_t>;
    using propagator_t = propagator<stepper_t, navigator_t, actor_chain_t>;

    detector_device_t det(det_data);

    // Device containers on top of the views that were passed to the kernel
    vecmem::jagged_device_vector<const intersection_record_t>
        truth_intersection_traces(truth_intersection_traces_view);
    vecmem::jagged_device_vector<
        navigation::detail::candidate_record<intersection_t>>
        recorded_intersections(recorded_intersections_view);
    vecmem::device_vector<typename material_tracer_t::material_record_type>
        mat_records(mat_records_view);
    vecmem::jagged_device_vector<typename material_tracer_t::material_params_type>
        mat_steps(mat_steps_view);

    // There has to be one output trace per truth trace
    assert(truth_intersection_traces.size() ==
           recorded_intersections_view.size());

    // The CUDA built-in indices are unsigned: keep the track index unsigned
    // as well, so that the bounds check and the container accesses below do
    // not mix signed and unsigned values
    const unsigned int trk_id = blockIdx.x * blockDim.x + threadIdx.x;
    if (trk_id >= truth_intersection_traces.size()) {
        return;
    }

    propagator_t p{cfg};

    // Actor states; the material tracer writes into this track's step vector
    typename pathlimit_aborter_t::state aborter_state{cfg.stepping.path_limit};
    typename material_tracer_t::state mat_tracer_state{mat_steps.at(trk_id)};
    auto actor_states = ::detray::tie(aborter_state, mat_tracer_state);

    // The first truth record provides the initial track parameters, so the
    // truth trace must not be empty
    assert(!truth_intersection_traces[trk_id].empty());
    const auto &track = truth_intersection_traces[trk_id].front().track_param;

    // Seed the (so far empty) recorded trace with the first truth record
    assert(recorded_intersections.at(trk_id).empty());
    recorded_intersections.at(trk_id).push_back(
        {track.pos(), track.dir(),
         truth_intersection_traces[trk_id].front().intersection});

    assert(recorded_intersections.at(trk_id).size() == 1);

    // Run the propagation; the navigation state is constructed from this
    // track's inner vector view of the recorded intersections
    if constexpr (is_no_bfield) {
        typename propagator_t::state propagation(
            track, det,
            typename navigator_t::state::view_type{
                recorded_intersections_view.ptr()[trk_id]});
        propagation.set_particle(update_particle_hypothesis(ptc_hypo, track));

        p.propagate(propagation, actor_states);
    } else {
        typename propagator_t::stepper_type::magnetic_field_type bfield_view(
            field_data);
        typename propagator_t::state propagation(
            track, bfield_view, det,
            typename navigator_t::state::view_type{
                recorded_intersections_view.ptr()[trk_id]});
        propagation.set_particle(update_particle_hypothesis(ptc_hypo, track));

        p.propagate(propagation, actor_states);
    }

    // Hand the material record of this track back to the caller
    assert(truth_intersection_traces.size() == mat_records.size());
    mat_records.at(trk_id) = mat_tracer_state.get_material_record();
}
0132
0133
/// Host-side launcher for @c navigation_validation_kernel.
///
/// Launches one thread per truth trace in a one-dimensional grid, checks for
/// launch errors and then waits for the device to finish, so the output views
/// have been filled when this function returns.
///
/// @param det_view view of the detector data
/// @param cfg propagation configuration
/// @param ptc_hypo particle hypothesis
/// @param field_data magnetic field data (or the empty field placeholder)
/// @param truth_intersection_traces_view one truth trace per track
/// @param recorded_intersections_view per-track output of the navigation
/// @param mat_records_view per-track output material record
/// @param mat_steps_view per-track storage for the material steps
template <typename bfield_t, typename detector_t,
          typename intersection_record_t>
void navigation_validation_device(
    typename detector_t::view_type det_view, const propagation::config &cfg,
    pdg_particle<typename detector_t::scalar_type> ptc_hypo,
    bfield_t field_data,
    vecmem::data::jagged_vector_view<const intersection_record_t>
        &truth_intersection_traces_view,
    vecmem::data::jagged_vector_view<navigation::detail::candidate_record<
        typename intersection_record_t::intersection_type>>
        &recorded_intersections_view,
    vecmem::data::vector_view<
        material_validator::material_record<typename detector_t::scalar_type>>
        &mat_records_view,
    vecmem::data::jagged_vector_view<
        material_validator::material_params<typename detector_t::scalar_type>>
        &mat_steps_view) {

    // Two warps per block
    constexpr unsigned int threads_per_block{2u * WARP_SIZE};

    // Round up, so that every track gets a thread without launching a
    // superfluous block when the track count is a multiple of the block size.
    // A grid of zero blocks is an invalid launch configuration, hence at
    // least one block is launched even if there are no tracks.
    const auto n_tracks = truth_intersection_traces_view.size();
    const auto n_blocks = static_cast<unsigned int>(
        n_tracks == 0u
            ? 1u
            : (n_tracks + threads_per_block - 1u) / threads_per_block);

    // Run the validation on the device
    navigation_validation_kernel<bfield_t, detector_t, intersection_record_t>
        <<<n_blocks, threads_per_block>>>(
            det_view, cfg, ptc_hypo, field_data, truth_intersection_traces_view,
            recorded_intersections_view, mat_records_view, mat_steps_view);

    // Launch errors are reported by cudaGetLastError(), errors during the
    // kernel execution surface at the synchronization
    DETRAY_CUDA_ERROR_CHECK(cudaGetLastError());
    DETRAY_CUDA_ERROR_CHECK(cudaDeviceSynchronize());
}
0164
0165
/// Explicitly instantiates @c navigation_validation_device for the detector
/// that is built from @c METADATA and for the magnetic field type @c BFIELD.
#define DETRAY_INSTANTIATE_NAVIGATION_VALIDATION(METADATA, BFIELD)            \
    template void navigation_validation_device<                                \
        BFIELD, detector<METADATA>,                                            \
        detray::intersection_record<detector<METADATA>>>(                      \
        typename detector<METADATA>::view_type, const propagation::config &,   \
        pdg_particle<typename detector<METADATA>::scalar_type>, BFIELD,        \
        vecmem::data::jagged_vector_view<                                      \
            const detray::intersection_record<detector<METADATA>>> &,          \
        vecmem::data::jagged_vector_view<                                      \
            navigation::detail::candidate_record<                              \
                typename detray::intersection_record<                          \
                    detector<METADATA>>::intersection_type>> &,                \
        vecmem::data::vector_view<material_validator::material_record<         \
            typename detector<METADATA>::scalar_type>> &,                      \
        vecmem::data::jagged_vector_view<material_validator::material_params<  \
            typename detector<METADATA>::scalar_type>> &);

/// Instantiates the device navigation validation for a detector type twice:
/// first with the constant magnetic field view, then without a magnetic field.
#define DECLARE_NAVIGATION_VALIDATION(METADATA)                                \
    DETRAY_INSTANTIATE_NAVIGATION_VALIDATION(                                  \
        METADATA,                                                              \
        covfie::field_view<bfield::const_bknd_t<                               \
            dscalar<typename METADATA::algebra_type>>>)                        \
    DETRAY_INSTANTIATE_NAVIGATION_VALIDATION(                                  \
        METADATA, detray::navigation_validator::empty_bfield)

// Detector types the validation is compiled for
DECLARE_NAVIGATION_VALIDATION(test::default_metadata)
DECLARE_NAVIGATION_VALIDATION(test::toy_metadata)
DECLARE_NAVIGATION_VALIDATION(test::default_telescope_metadata)
DECLARE_NAVIGATION_VALIDATION(test::wire_chamber_metadata)
0206
0207 }