Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:24:14

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 // Project include(s)
0012 #include "detray/core/detector.hpp"
0013 #include "detray/definitions/pdg_particle.hpp"
0014 #include "detray/propagator/line_stepper.hpp"
0015 #include "detray/propagator/rk_stepper.hpp"
0016 #include "detray/tracks/ray.hpp"
0017 #include "detray/tracks/tracks.hpp"
0018 #include "detray/utils/logging.hpp"
0019 
0020 // Detray test include(s)
0021 #include "detray/test/common/bfield.hpp"
0022 #include "detray/test/framework/fixture_base.hpp"
0023 #include "detray/test/framework/whiteboard.hpp"
0024 #include "detray/test/utils/inspectors.hpp"
0025 #include "detray/test/validation/detector_scan_utils.hpp"
0026 #include "detray/test/validation/detector_scanner.hpp"
0027 #include "detray/test/validation/material_validation_utils.hpp"
0028 #include "detray/test/validation/navigation_validation_config.hpp"
0029 #include "detray/test/validation/navigation_validation_utils.hpp"
0030 
0031 // Vecmem include(s)
0032 #include <vecmem/memory/cuda/device_memory_resource.hpp>
0033 #include <vecmem/memory/host_memory_resource.hpp>
0034 #include <vecmem/memory/memory_resource.hpp>
0035 #include <vecmem/utils/cuda/copy.hpp>
0036 
0037 // System include(s)
0038 #include <memory>
0039 #include <tuple>
0040 
0041 namespace detray::cuda {
0042 
0043 /// Launch the navigation validation kernel
0044 ///
0045 /// @param[in] det_view the detector vecmem view
0046 /// @param[in] cfg the propagation configuration
0047 /// @param[in] field_data the magentic field view (maybe an empty field)
0048 /// @param[in] truth_intersection_traces_view vecemem view of the truth data
0049 /// @param[out] recorded_intersections_view vecemem view of the intersections
0050 ///                                         recorded by the navigator
0051 template <typename bfield_t, typename detector_t,
0052           typename intersection_record_t>
0053 void navigation_validation_device(
0054     typename detector_t::view_type det_view, const propagation::config &cfg,
0055     pdg_particle<typename detector_t::scalar_type> ptc_hypo,
0056     bfield_t field_data,
0057     vecmem::data::jagged_vector_view<const intersection_record_t>
0058         &truth_intersection_traces_view,
0059     vecmem::data::jagged_vector_view<navigation::detail::candidate_record<
0060         typename intersection_record_t::intersection_type>>
0061         &recorded_intersections_view,
0062     vecmem::data::vector_view<
0063         material_validator::material_record<typename detector_t::scalar_type>>
0064         &mat_records_view,
0065     vecmem::data::jagged_vector_view<
0066         material_validator::material_params<typename detector_t::scalar_type>>
0067         &mat_steps_view);
0068 
0069 /// Prepare data for device navigation run
0070 template <typename bfield_t, typename detector_t,
0071           typename intersection_record_t>
0072 inline auto run_navigation_validation(
0073     vecmem::memory_resource *host_mr, vecmem::memory_resource *dev_mr,
0074     const detector_t &det, const propagation::config &cfg,
0075     pdg_particle<typename detector_t::scalar_type> ptc_hypo,
0076     bfield_t field_data,
0077     const std::vector<std::vector<intersection_record_t>>
0078         &truth_intersection_traces) {
0079   using scalar_t = dscalar<typename detector_t::algebra_type>;
0080   using intersection_t = typename intersection_record_t::intersection_type;
0081   using material_record_t = material_validator::material_record<scalar_t>;
0082   using material_params_t = material_validator::material_params<scalar_t>;
0083 
0084   // Helper object for performing memory copies (to CUDA devices)
0085   vecmem::cuda::copy cuda_cpy;
0086 
0087   // Copy the detector to device and get its view
0088   auto det_buffer = detray::get_buffer(det, *dev_mr, cuda_cpy);
0089   auto det_view = detray::get_data(det_buffer);
0090 
0091   // Move truth intersection traces data to device
0092   auto truth_intersection_traces_data =
0093       vecmem::get_data(truth_intersection_traces, host_mr);
0094   auto truth_intersection_traces_buffer =
0095       cuda_cpy.to(truth_intersection_traces_data, *dev_mr, host_mr,
0096                   vecmem::copy::type::host_to_device);
0097   vecmem::data::jagged_vector_view<const intersection_record_t>
0098       truth_intersection_traces_view =
0099           vecmem::get_data(truth_intersection_traces_buffer);
0100 
0101   // Buffer for the intersections recorded by the navigator
0102   std::vector<std::size_t> capacities;
0103   for (const auto &trace : truth_intersection_traces) {
0104     // Increase the capacity, in case the navigator finds more surfaces
0105     // than the truth intersections (usually just one)
0106     capacities.push_back(trace.size() + 10u);
0107   }
0108 
0109   vecmem::data::jagged_vector_buffer<
0110       navigation::detail::candidate_record<intersection_t>>
0111       recorded_intersections_buffer(capacities, *dev_mr, host_mr,
0112                                     vecmem::data::buffer_type::resizable);
0113   cuda_cpy.setup(recorded_intersections_buffer)->wait();
0114   auto recorded_intersections_view =
0115       vecmem::get_data(recorded_intersections_buffer);
0116 
0117   vecmem::data::vector_buffer<material_record_t> mat_records_buffer(
0118       static_cast<unsigned int>(truth_intersection_traces_view.size()), *dev_mr,
0119       vecmem::data::buffer_type::fixed_size);
0120   cuda_cpy.setup(mat_records_buffer)->wait();
0121   auto mat_records_view = vecmem::get_data(mat_records_buffer);
0122 
0123   // Buffer for the material parameters at every step per track
0124   vecmem::data::jagged_vector_buffer<material_params_t> mat_steps_buffer(
0125       capacities, *dev_mr, host_mr, vecmem::data::buffer_type::resizable);
0126   cuda_cpy.setup(mat_steps_buffer)->wait();
0127   auto mat_steps_view = vecmem::get_data(mat_steps_buffer);
0128 
0129   // Run the navigation validation test on device
0130   navigation_validation_device<bfield_t, detector_t, intersection_record_t>(
0131       det_view, cfg, ptc_hypo, field_data, truth_intersection_traces_view,
0132       recorded_intersections_view, mat_records_view, mat_steps_view);
0133 
0134   // Get the results back to the host and pass them on to the checking
0135   vecmem::jagged_vector<navigation::detail::candidate_record<intersection_t>>
0136       recorded_intersections(host_mr);
0137   cuda_cpy(recorded_intersections_buffer, recorded_intersections)->wait();
0138 
0139   vecmem::vector<material_record_t> mat_records(host_mr);
0140   cuda_cpy(mat_records_buffer, mat_records)->wait();
0141 
0142   vecmem::jagged_vector<material_params_t> mat_steps(host_mr);
0143   cuda_cpy(mat_steps_buffer, mat_steps)->wait();
0144 
0145   return std::make_tuple(std::move(recorded_intersections),
0146                          std::move(mat_records), std::move(mat_steps));
0147 }
0148 
0149 /// @brief Test class that runs the navigation validation for a given detector
0150 /// on device.
0151 ///
0152 /// @note The lifetime of the detector needs to be guaranteed outside this class
0153 template <typename detector_t, template <typename> class scan_type>
0154 class navigation_validation : public test::fixture_base<> {
0155   using algebra_t = typename detector_t::algebra_type;
0156   using scalar_t = dscalar<algebra_t>;
0157   using vector3_t = dvector3D<algebra_t>;
0158   using free_track_parameters_t = free_track_parameters<algebra_t>;
0159   using trajectory_type = typename scan_type<algebra_t>::trajectory_type;
0160   using intersection_trace_t = typename scan_type<
0161       algebra_t>::template intersection_trace_type<detector_t>;
0162 
0163   /// Switch between rays and helices
0164   static constexpr auto k_use_rays{
0165       std::is_same_v<detail::ray<algebra_t>, trajectory_type>};
0166 
0167  public:
0168   using fixture_type = test::fixture_base<>;
0169   using config = detray::test::navigation_validation_config<algebra_t>;
0170 
0171   explicit navigation_validation(
0172       const detector_t &det, const typename detector_t::name_map &names,
0173       const config &cfg = {}, std::shared_ptr<test::whiteboard> wb = nullptr,
0174       const typename detector_t::geometry_context gctx = {})
0175       : m_cfg{cfg}, m_gctx{gctx}, m_det{det}, m_names{names}, m_whiteboard{wb} {
0176     if (!m_whiteboard) {
0177       throw std::invalid_argument("No white board was passed to " +
0178                                   m_cfg.name() + " test");
0179     }
0180 
0181     // Use ray or helix
0182     const std::string det_name{m_det.name(m_names)};
0183     m_truth_data_name = k_use_rays ? det_name + "_ray_scan_for_cuda"
0184                                    : det_name + "_helix_scan_for_cuda";
0185 
0186     // Pin the data onto the whiteboard
0187     if (!m_whiteboard->exists(m_truth_data_name) &&
0188         io::file_exists(m_cfg.intersection_file()) &&
0189         io::file_exists(m_cfg.track_param_file())) {
0190       // Name clash: Choose alternative name
0191       if (m_whiteboard->exists(m_truth_data_name)) {
0192         m_truth_data_name = io::alt_file_name(m_truth_data_name);
0193       }
0194 
0195       std::vector<intersection_trace_t> intersection_traces;
0196 
0197       DETRAY_INFO_HOST("Reading data from file...");
0198 
0199       // Fill the intersection traces from file
0200       detray::detector_scanner::read(m_cfg.intersection_file(),
0201                                      m_cfg.track_param_file(),
0202                                      intersection_traces);
0203 
0204       m_whiteboard->add(m_truth_data_name, std::move(intersection_traces));
0205     } else if (m_whiteboard->exists(m_truth_data_name)) {
0206       DETRAY_INFO_HOST("Fetching data from white board...");
0207     } else {
0208       throw std::invalid_argument(
0209           "Navigation validation: Could not find data files");
0210     }
0211 
0212     // Check that data is ready
0213     if (!m_whiteboard->exists(m_truth_data_name)) {
0214       DETRAY_INFO_HOST("Data for navigation check is not on the whiteboard");
0215       throw std::invalid_argument(
0216           "Data for navigation check is not on the whiteboard");
0217     }
0218   }
0219 
0220   /// Run the check
0221   void TestBody() override {
0222     using namespace detray;
0223     using namespace navigation;
0224 
0225     // Runge-Kutta stepper
0226     using hom_bfield_t = bfield::const_field_t<scalar_t>;
0227     using bfield_view_t =
0228         std::conditional_t<k_use_rays, navigation_validator::empty_bfield,
0229                            typename hom_bfield_t::view_t>;
0230     using bfield_t =
0231         std::conditional_t<k_use_rays, navigation_validator::empty_bfield,
0232                            hom_bfield_t>;
0233     using intersection_t =
0234         typename intersection_trace_t::value_type::intersection_type;
0235 
0236     bfield_t b_field{};
0237     if constexpr (!k_use_rays) {
0238       b_field = create_const_field<scalar_t>(m_cfg.B_vector());
0239     }
0240 
0241     // Fetch the truth data
0242     auto &truth_intersection_traces =
0243         m_whiteboard->template get<std::vector<intersection_trace_t>>(
0244             m_truth_data_name);
0245     ASSERT_EQ(m_cfg.n_tracks(), truth_intersection_traces.size());
0246 
0247     DETRAY_INFO_HOST("Running device navigation validation on: "
0248                      << m_det.name(m_names) << "...\n");
0249 
0250     std::string momentum_str{""};
0251     const std::string det_name{m_det.name(m_names)};
0252     const std::string prefix{k_use_rays ? det_name + "_ray_"
0253                                         : det_name + "_helix_"};
0254 
0255     const auto data_path{
0256         std::filesystem::path{m_cfg.track_param_file()}.parent_path()};
0257 
0258     // Create an output file path
0259     auto make_path = [&data_path, &prefix, &momentum_str](
0260                          const std::string &name,
0261                          const std::string &extension = ".csv") {
0262       return data_path / (prefix + name + momentum_str + extension);
0263     };
0264 
0265     std::ios_base::openmode io_mode = std::ios::trunc | std::ios::out;
0266     const std::string debug_file_name{
0267         make_path(prefix + "navigation_validation_cuda", ".txt")};
0268     detray::io::file_handle debug_file{debug_file_name, io_mode};
0269 
0270     // Run the propagation on device and record the navigation data
0271     auto [recorded_intersections, mat_records, mat_steps] =
0272         run_navigation_validation<bfield_view_t>(
0273             &m_host_mr, &m_dev_mr, m_det, m_cfg.propagation(),
0274             m_cfg.ptc_hypothesis(), b_field, truth_intersection_traces);
0275 
0276     // Collect some statistics
0277     std::size_t n_tracks{0u};
0278     std::size_t n_matching_error{0u};
0279     std::size_t n_fatal{0u};
0280     // Total number of encountered surfaces
0281     navigation_validator::surface_stats n_surfaces{};
0282     // Missed by navigator
0283     navigation_validator::surface_stats n_miss_nav{};
0284     // Missed by truth finder
0285     navigation_validator::surface_stats n_miss_truth{};
0286 
0287     std::vector<std::pair<trajectory_type, std::vector<intersection_t>>>
0288         missed_intersections{};
0289 
0290     EXPECT_EQ(recorded_intersections.size(), truth_intersection_traces.size());
0291 
0292     scalar_t min_pT{std::numeric_limits<scalar_t>::max()};
0293     scalar_t max_pT{-std::numeric_limits<scalar_t>::max()};
0294     for (std::size_t i = 0u; i < truth_intersection_traces.size(); ++i) {
0295       auto &truth_trace = truth_intersection_traces[i];
0296       auto &recorded_trace = recorded_intersections[i];
0297 
0298       if (n_tracks >= m_cfg.n_tracks()) {
0299         break;
0300       }
0301 
0302       // Get the original test trajectory (ray or helix)
0303       const auto &start = truth_trace.front();
0304       const auto &trck_param = start.track_param;
0305       trajectory_type test_traj = get_parametrized_trajectory(trck_param);
0306 
0307       const scalar q = start.charge;
0308       const scalar pT{q == 0.f ? 1.f * unit<scalar>::GeV : trck_param.pT(q)};
0309       const scalar p{q == 0.f ? 1.f * unit<scalar>::GeV : trck_param.p(q)};
0310 
0311       if (detray::detail::is_invalid_value(m_cfg.p_range()[0])) {
0312         min_pT = std::min(min_pT, pT);
0313         max_pT = std::max(max_pT, pT);
0314       } else {
0315         min_pT = m_cfg.p_range()[0];
0316         max_pT = m_cfg.p_range()[1];
0317       }
0318 
0319       // Recorded only the start position, which added by default
0320       bool success{true};
0321       if (truth_trace.size() == 1) {
0322         // Propagation did not succeed
0323         success = false;
0324         std::vector<intersection_t> missed_inters{};
0325         missed_intersections.push_back(
0326             std::make_pair(test_traj, missed_inters));
0327 
0328         ++n_fatal;
0329       } else {
0330         // Adjust the track charge, which is unknown to the navigation
0331         for (auto &record : recorded_trace) {
0332           record.charge = q;
0333           record.p_mag = p;
0334         }
0335 
0336         // Compare truth and recorded data elementwise
0337         auto [result, n_missed_nav, n_missed_truth, n_error, missed_inters] =
0338             navigation_validator::compare_traces(m_cfg, truth_trace,
0339                                                  recorded_trace, test_traj,
0340                                                  n_tracks, &(*debug_file));
0341 
0342         missed_intersections.push_back(
0343             std::make_pair(test_traj, std::move(missed_inters)));
0344 
0345         // Update statistics
0346         success = success && result;
0347         n_miss_nav += n_missed_nav;
0348         n_miss_truth += n_missed_truth;
0349         n_matching_error += n_error;
0350       }
0351 
0352       if (!success) {
0353         detector_scanner::display_error(
0354             m_gctx, m_det, m_names, m_cfg.name(), test_traj, truth_trace,
0355             m_cfg.svg_style(), n_tracks, m_cfg.n_tracks(), recorded_trace);
0356       }
0357 
0358       EXPECT_TRUE(success) << "\nINFO: Wrote navigation debugging data in: "
0359                            << debug_file_name;
0360 
0361       ++n_tracks;
0362 
0363       // After dummy records insertion, traces should have the same size
0364       ASSERT_EQ(truth_trace.size(), recorded_trace.size());
0365 
0366       // Count the number of different surface types on this trace
0367       navigation_validator::surface_stats n_truth{};
0368       navigation_validator::surface_stats n_nav{};
0369       for (std::size_t j = 0; j < truth_trace.size(); ++j) {
0370         const auto truth_desc = truth_trace[j].intersection.surface();
0371         const auto rec_desc = recorded_trace[j].intersection.surface();
0372 
0373         // Exclude dummy records for missing surfaces
0374         if (!truth_desc.identifier().is_invalid()) {
0375           n_truth.count(truth_desc);
0376         }
0377         if (!rec_desc.identifier().is_invalid()) {
0378           n_nav.count(rec_desc);
0379         }
0380       }
0381 
0382       // Take max count, since either trace might have skipped surfaces
0383       const std::size_t n_portals{
0384           math::max(n_truth.n_portals, n_nav.n_portals)};
0385       const std::size_t n_sensitives{
0386           math::max(n_truth.n_sensitives, n_nav.n_sensitives)};
0387       const std::size_t n_passives{
0388           math::max(n_truth.n_passives, n_nav.n_passives)};
0389       const std::size_t n{n_portals + n_sensitives + n_passives};
0390 
0391       // Cannot have less surfaces than truth intersections after matching
0392       // (Don't count first entry, which records the initial track params)
0393       ASSERT_TRUE(n >= (truth_trace.size() - 1u));
0394 
0395       n_surfaces.n_portals += n_portals;
0396       n_surfaces.n_sensitives += n_sensitives;
0397       n_surfaces.n_passives += n_passives;
0398     }
0399 
0400     // Calculate and display the result
0401     navigation_validator::print_efficiency(n_tracks, n_surfaces, n_miss_nav,
0402                                            n_miss_truth, n_fatal,
0403                                            n_matching_error);
0404 
0405     // Print track positions for plotting
0406     if constexpr (!k_use_rays) {
0407       momentum_str =
0408           "_" +
0409           std::to_string(std::floor(10. * static_cast<double>(min_pT)) / 10.) +
0410           "_" +
0411           std::to_string(std::ceil(10. * static_cast<double>(max_pT)) / 10.) +
0412           "_GeV";
0413     }
0414 
0415     const auto truth_trk_path{make_path("truth_track_params_cuda")};
0416     const auto trk_path{make_path("navigation_track_params_cuda")};
0417     const auto truth_intr_path{make_path("truth_intersections_cuda")};
0418     const auto intr_path{make_path("navigation_intersections_cuda")};
0419     const auto mat_path{make_path("accumulated_material_cuda")};
0420     const auto missed_path{make_path("missed_intersections_dists_cuda")};
0421 
0422     // Write the distance of the missed intersection local position
0423     // to the surface boundaries to file for plotting
0424     navigation_validator::write_dist_to_boundary(
0425         m_det, m_names, missed_path.string(), missed_intersections);
0426     detector_scanner::write_tracks(truth_trk_path.string(),
0427                                    truth_intersection_traces);
0428     navigation_validator::write_tracks(trk_path.string(),
0429                                        recorded_intersections);
0430     detector_scanner::write_intersections(truth_intr_path.string(),
0431                                           truth_intersection_traces);
0432     detector_scanner::write_intersections(intr_path.string(),
0433                                           recorded_intersections);
0434     material_validator::write_material(mat_path.string(), mat_records);
0435 
0436     DETRAY_INFO_HOST("Wrote distance to boundary of missed intersections in: "
0437                      << missed_path);
0438     DETRAY_INFO_HOST("Wrote track states in: " << trk_path);
0439     DETRAY_INFO_HOST("Wrote truth intersections in: " << truth_intr_path);
0440     DETRAY_INFO_HOST("Wrote track intersections in: " << intr_path);
0441     DETRAY_INFO_HOST("Wrote accumulated material in: " << mat_path);
0442   }
0443 
0444  private:
0445   /// @returns either the helix or ray corresponding to the input track
0446   /// parameters @param track
0447   trajectory_type get_parametrized_trajectory(
0448       const free_track_parameters_t &track) {
0449     std::unique_ptr<trajectory_type> test_traj{nullptr};
0450     if constexpr (k_use_rays) {
0451       test_traj = std::make_unique<trajectory_type>(track);
0452     } else {
0453       test_traj = std::make_unique<trajectory_type>(track, m_cfg.B_vector());
0454     }
0455     return *(test_traj.release());
0456   }
0457 
0458   /// Vecmem memory resource for the host allocations
0459   vecmem::host_memory_resource m_host_mr{};
0460   /// Vecmem memory resource for the device allocations
0461   vecmem::cuda::device_memory_resource m_dev_mr{};
0462   /// The configuration of this test
0463   config m_cfg;
0464   /// Name of the truth data collection
0465   std::string m_truth_data_name{""};
0466   /// The geometry context to check
0467   typename detector_t::geometry_context m_gctx{};
0468   /// The detector to be checked
0469   const detector_t &m_det;
0470   /// Volume names
0471   const typename detector_t::name_map &m_names;
0472   /// Whiteboard to pin data
0473   std::shared_ptr<test::whiteboard> m_whiteboard{nullptr};
0474 };
0475 
0476 template <typename detector_t>
0477 using straight_line_navigation =
0478     detray::cuda::navigation_validation<detector_t, detray::ray_scan>;
0479 
0480 template <typename detector_t>
0481 using helix_navigation =
0482     detray::cuda::navigation_validation<detector_t, detray::helix_scan>;
0483 
0484 }  // namespace detray::cuda