Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:24:14

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "detray/definitions/detail/cuda_definitions.hpp"
0010 #include "navigation_validation.hpp"
0011 
0012 namespace detray::cuda {
0013 
/// Device kernel that re-runs the navigation for every truth trace and
/// records the navigation candidates and the material encountered.
///
/// Launch layout: 1D grid, one thread per track. Thread @c i handles entry
/// @c i of every per-track container. Surplus threads (index >= number of
/// truth traces) return immediately, so the grid may be larger than needed.
///
/// @tparam bfield_t @c navigation_validator::empty_bfield selects the line
///         stepper; any other type selects the RK stepper and is converted to
///         the stepper's magnetic field type
/// @tparam detector_t host detector type whose view is passed in
/// @tparam intersection_record_t record type of the truth intersections
///
/// @param det_data detector view, must match the device detector view type
/// @param cfg propagation configuration; the path limit of the aborter is
///        taken from @c cfg.stepping.path_limit
/// @param ptc_hypo particle hypothesis, combined with the initial track
///        parameters via @c update_particle_hypothesis
/// @param field_data magnetic field data (not used for the empty field)
/// @param truth_intersection_traces_view per-track truth intersections; the
///        first entry of a trace provides the initial track parameters.
///        Tracks with an empty trace are skipped.
/// @param recorded_intersections_view per-track output of the navigation
///        candidates; every inner vector is expected to be empty on entry
/// @param mat_records_view per-track output of the accumulated material
/// @param mat_steps_view per-track buffer handed to the material tracer
template <typename bfield_t, typename detector_t,
          typename intersection_record_t>
__global__ void navigation_validation_kernel(
    typename detector_t::view_type det_data, const propagation::config cfg,
    pdg_particle<typename detector_t::scalar_type> ptc_hypo,
    bfield_t field_data,
    vecmem::data::jagged_vector_view<const intersection_record_t>
        truth_intersection_traces_view,
    vecmem::data::jagged_vector_view<navigation::detail::candidate_record<
        typename intersection_record_t::intersection_type>>
        recorded_intersections_view,
    vecmem::data::vector_view<
        material_validator::material_record<typename detector_t::scalar_type>>
        mat_records_view,
    vecmem::data::jagged_vector_view<
        material_validator::material_params<typename detector_t::scalar_type>>
        mat_steps_view) {
  using detector_device_t =
      detector<typename detector_t::metadata, device_container_types>;
  using algebra_t = typename detector_device_t::algebra_type;
  using scalar_t = dscalar<algebra_t>;

  static_assert(std::is_same_v<typename detector_t::view_type,
                               typename detector_device_t::view_type>,
                "Host and device detector view types do not match");

  using hom_bfield_view_t = typename bfield::const_field_t<scalar_t>::view_t;
  using rk_stepper_t = rk_stepper<hom_bfield_view_t, algebra_t>;
  using line_stepper_t = line_stepper<algebra_t>;
  // Use RK-stepper when a non-empty b-field was passed
  static constexpr auto is_no_bfield{
      std::is_same_v<bfield_t, navigation_validator::empty_bfield>};
  using stepper_t =
      std::conditional_t<is_no_bfield, line_stepper_t, rk_stepper_t>;

  // Inspector that records all encountered surfaces
  using intersection_t = typename intersection_record_t::intersection_type;
  using object_tracer_t =
      navigation::object_tracer<intersection_t, vecmem::device_vector,
                                navigation::status::e_on_object,
                                navigation::status::e_on_portal>;
  // Navigation with inspection
  using navigator_t =
      caching_navigator<detector_device_t, navigation::default_cache_size,
                        object_tracer_t, intersection_t>;

  // Propagator with pathlimit aborter and material tracer
  using material_tracer_t =
      material_validator::material_tracer<scalar_t, vecmem::device_vector>;
  using pathlimit_aborter_t = actor::pathlimit_aborter<scalar_t>;
  using actor_chain_t = actor_chain<pathlimit_aborter_t, material_tracer_t>;
  using propagator_t = propagator<stepper_t, navigator_t, actor_chain_t>;

  detector_device_t det(det_data);

  vecmem::jagged_device_vector<const intersection_record_t>
      truth_intersection_traces(truth_intersection_traces_view);
  vecmem::jagged_device_vector<
      navigation::detail::candidate_record<intersection_t>>
      recorded_intersections(recorded_intersections_view);
  vecmem::device_vector<typename material_tracer_t::material_record_type>
      mat_records(mat_records_view);
  vecmem::jagged_device_vector<typename material_tracer_t::material_params_type>
      mat_steps(mat_steps_view);

  // Check the memory setup: every per-track output holds one entry per track
  assert(truth_intersection_traces.size() ==
         recorded_intersections_view.size());
  assert(truth_intersection_traces.size() == mat_records.size());

  // One thread per track; unsigned to match the container size type
  const unsigned int trk_id = threadIdx.x + blockIdx.x * blockDim.x;
  if (trk_id >= truth_intersection_traces.size()) {
    return;
  }

  // The first truth intersection provides the initial track parameters.
  // Without it there is nothing to propagate, and front() would read out of
  // bounds (the assert above does not protect release builds)
  const auto truth_trace = truth_intersection_traces[trk_id];
  if (truth_trace.empty()) {
    return;
  }

  propagator_t p{cfg};

  // Create the actor states
  typename pathlimit_aborter_t::state aborter_state{cfg.stepping.path_limit};
  typename material_tracer_t::state mat_tracer_state{mat_steps.at(trk_id)};
  auto actor_states = ::detray::tie(aborter_state, mat_tracer_state);

  // Get the initial track parameters
  const auto &track = truth_trace.front().track_param;

  // Save the initial intersection, since it is not recorded by the
  // object tracer
  assert(recorded_intersections.at(trk_id).empty());
  recorded_intersections.at(trk_id).push_back(
      {track.pos(), track.dir(), truth_trace.front().intersection});
  // Did the insertion of an element work?
  assert(recorded_intersections.at(trk_id).size() == 1);

  // Run propagation
  if constexpr (is_no_bfield) {
    typename propagator_t::state propagation(
        track, det,
        typename navigator_t::state::view_type{
            recorded_intersections_view.ptr()[trk_id]});
    propagation.set_particle(update_particle_hypothesis(ptc_hypo, track));

    p.propagate(propagation, actor_states);
  } else {
    typename propagator_t::stepper_type::magnetic_field_type bfield_view(
        field_data);
    typename propagator_t::state propagation(
        track, bfield_view, det,
        typename navigator_t::state::view_type{
            recorded_intersections_view.ptr()[trk_id]});
    propagation.set_particle(update_particle_hypothesis(ptc_hypo, track));

    p.propagate(propagation, actor_states);
  }

  // Record the accumulated material
  mat_records.at(trk_id) = mat_tracer_state.get_material_record();
}
0132 
/// Launch the navigation validation kernel with one thread per truth trace
/// and block until it has finished.
///
/// The views are forwarded to @c navigation_validation_kernel. It appends
/// the navigation candidates to @p recorded_intersections_view, writes the
/// accumulated material to @p mat_records_view and hands
/// @p mat_steps_view to the material tracer. Launch and execution errors are
/// checked with @c DETRAY_CUDA_ERROR_CHECK. An empty
/// @p truth_intersection_traces_view launches nothing.
template <typename bfield_t, typename detector_t,
          typename intersection_record_t>
void navigation_validation_device(
    typename detector_t::view_type det_view, const propagation::config &cfg,
    pdg_particle<typename detector_t::scalar_type> ptc_hypo,
    bfield_t field_data,
    vecmem::data::jagged_vector_view<const intersection_record_t>
        &truth_intersection_traces_view,
    vecmem::data::jagged_vector_view<navigation::detail::candidate_record<
        typename intersection_record_t::intersection_type>>
        &recorded_intersections_view,
    vecmem::data::vector_view<
        material_validator::material_record<typename detector_t::scalar_type>>
        &mat_records_view,
    vecmem::data::jagged_vector_view<
        material_validator::material_params<typename detector_t::scalar_type>>
        &mat_steps_view) {
  // One thread per track (truth trace)
  const auto n_tracks = truth_intersection_traces_view.size();

  // Nothing to validate. A launch with zero blocks would be an invalid
  // execution configuration, so return before launching.
  if (n_tracks == 0u) {
    return;
  }

  constexpr unsigned int threads_per_block = 2u * WARP_SIZE;
  // Ceiling division: just enough blocks to cover every track, without the
  // surplus block that "n / threads + 1" adds for exact multiples
  const unsigned int n_blocks = static_cast<unsigned int>(
      (n_tracks + threads_per_block - 1u) / threads_per_block);

  // run the test kernel
  navigation_validation_kernel<bfield_t, detector_t, intersection_record_t>
      <<<n_blocks, threads_per_block>>>(
          det_view, cfg, ptc_hypo, field_data, truth_intersection_traces_view,
          recorded_intersections_view, mat_records_view, mat_steps_view);

  // cuda error check: launch configuration errors, then execution errors
  DETRAY_CUDA_ERROR_CHECK(cudaGetLastError());
  DETRAY_CUDA_ERROR_CHECK(cudaDeviceSynchronize());
}
0164 
/// Explicitly instantiate @c navigation_validation_device for one detector
/// metadata type and one magnetic field type.
///
/// @param METADATA detector metadata type
/// @param BFIELD field type passed to the kernel. It is a single macro
///        argument, so it must not contain a top-level comma.
#define DETRAY_INSTANTIATE_NAVIGATION_VALIDATION(METADATA, BFIELD)            \
  template void navigation_validation_device<                                 \
      BFIELD, detector<METADATA>,                                             \
      detray::intersection_record<detector<METADATA>>>(                       \
      typename detector<METADATA>::view_type, const propagation::config &,    \
      pdg_particle<typename detector<METADATA>::scalar_type>, BFIELD,         \
      vecmem::data::jagged_vector_view<                                       \
          const detray::intersection_record<detector<METADATA>>> &,           \
      vecmem::data::jagged_vector_view<navigation::detail::candidate_record<  \
          typename detray::intersection_record<                               \
              detector<METADATA>>::intersection_type>> &,                     \
      vecmem::data::vector_view<material_validator::material_record<          \
          typename detector<METADATA>::scalar_type>> &,                       \
      vecmem::data::jagged_vector_view<material_validator::material_params<   \
          typename detector<METADATA>::scalar_type>> &);

/// Macro declaring the template instantiations for a detector type. The
/// first instantiation uses the constant-field covfie view; the second uses
/// the empty field, which runs the kernel with the line stepper.
#define DECLARE_NAVIGATION_VALIDATION(METADATA)                               \
  DETRAY_INSTANTIATE_NAVIGATION_VALIDATION(                                   \
      METADATA, covfie::field_view<bfield::const_bknd_t<                      \
                    dscalar<typename METADATA::algebra_type>>>)               \
  DETRAY_INSTANTIATE_NAVIGATION_VALIDATION(                                   \
      METADATA, detray::navigation_validator::empty_bfield)

// Instantiate for all test detector types
DECLARE_NAVIGATION_VALIDATION(test::default_metadata)
DECLARE_NAVIGATION_VALIDATION(test::toy_metadata)
DECLARE_NAVIGATION_VALIDATION(test::default_telescope_metadata)
DECLARE_NAVIGATION_VALIDATION(test::wire_chamber_metadata)
0206 
0207 }  // namespace detray::cuda