Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:24:13

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 // Project include(s)
0012 #include "detray/propagator/line_stepper.hpp"
0013 #include "detray/propagator/rk_stepper.hpp"
0014 #include "detray/tracks/ray.hpp"
0015 #include "detray/tracks/tracks.hpp"
0016 #include "detray/utils/logging.hpp"
0017 
0018 // Detray test include(s)
0019 #include "detray/test/common/bfield.hpp"
0020 #include "detray/test/framework/fixture_base.hpp"
0021 #include "detray/test/framework/whiteboard.hpp"
0022 #include "detray/test/validation/detector_scan_utils.hpp"
0023 #include "detray/test/validation/detector_scanner.hpp"
0024 #include "detray/test/validation/material_validation_utils.hpp"
0025 #include "detray/test/validation/navigation_validation_config.hpp"
0026 #include "detray/test/validation/navigation_validation_utils.hpp"
0027 
0028 // Vecmem include(s)
0029 #include <vecmem/memory/host_memory_resource.hpp>
0030 
// System include(s)
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <filesystem>
#include <ios>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0035 
0036 namespace detray::test {
0037 
0038 /// @brief Test class that runs the navigation check on a given detector.
0039 ///
0040 /// @note The lifetime of the detector needs to be guaranteed.
0041 template <typename detector_t, template <typename> class scan_type>
0042 class navigation_validation : public test::fixture_base<> {
0043   using algebra_t = typename detector_t::algebra_type;
0044   using scalar_t = dscalar<algebra_t>;
0045   using vector3_t = dvector3D<algebra_t>;
0046   using free_track_parameters_t = free_track_parameters<algebra_t>;
0047   using trajectory_type = typename scan_type<algebra_t>::trajectory_type;
0048   using truth_trace_t = typename scan_type<
0049       algebra_t>::template intersection_trace_type<detector_t>;
0050 
0051   /// Switch between rays and helices
0052   static constexpr auto k_use_rays{
0053       std::is_same_v<detail::ray<algebra_t>, trajectory_type>};
0054 
0055  public:
0056   using fixture_type = test::fixture_base<>;
0057   using config = navigation_validation_config<algebra_t>;
0058 
0059   explicit navigation_validation(
0060       const detector_t &det, const typename detector_t::name_map &names,
0061       const config &cfg = {}, std::shared_ptr<test::whiteboard> wb = nullptr,
0062       const typename detector_t::geometry_context gctx = {})
0063       : m_cfg{cfg},
0064         m_gctx{gctx},
0065         m_det{det},
0066         m_names{names},
0067         m_whiteboard{std::move(wb)} {
0068     if (!m_whiteboard) {
0069       throw std::invalid_argument("No white board was passed to " +
0070                                   m_cfg.name() + " test");
0071     }
0072   }
0073 
0074   /// Run the check
0075   void TestBody() override {
0076     using namespace detray;
0077     using namespace navigation;
0078 
0079     using intersection_t =
0080         typename truth_trace_t::value_type::intersection_type;
0081 
0082     // Runge-Kutta stepper
0083     using hom_bfield_t = bfield::const_field_t<scalar_t>;
0084     using bfield_t =
0085         std::conditional_t<k_use_rays, navigation_validator::empty_bfield,
0086                            hom_bfield_t>;
0087     using rk_stepper_t =
0088         rk_stepper<typename hom_bfield_t::view_t, algebra_t,
0089                    unconstrained_step<scalar_t>, stepper_rk_policy<scalar_t>,
0090                    stepping::print_inspector>;
0091     using line_stepper_t = line_stepper<algebra_t, unconstrained_step<scalar_t>,
0092                                         stepper_default_policy<scalar_t>,
0093                                         stepping::print_inspector>;
0094     using stepper_t =
0095         std::conditional_t<k_use_rays, line_stepper_t, rk_stepper_t>;
0096 
0097     bfield_t b_field{};
0098     if constexpr (!k_use_rays) {
0099       b_field = create_const_field<scalar_t>(m_cfg.B_vector());
0100     }
0101 
0102     // Use ray or helix
0103     const std::string det_name{m_det.name(m_names)};
0104     const std::string truth_data_name{k_use_rays ? det_name + "_ray_scan"
0105                                                  : det_name + "_helix_scan"};
0106 
0107     // Collect some statistics
0108     std::size_t n_tracks{0u};
0109     std::size_t n_matching_error{0u};
0110     std::size_t n_fatal{0u};
0111     // Total number of encountered surfaces
0112     navigation_validator::surface_stats n_surfaces{};
0113     // Missed by navigator
0114     navigation_validator::surface_stats n_miss_nav{};
0115     // Missed by truth finder
0116     navigation_validator::surface_stats n_miss_truth{};
0117 
0118     DETRAY_INFO_HOST("Fetching data from white board...");
0119     if (!m_whiteboard->exists(truth_data_name)) {
0120       throw std::runtime_error(
0121           "White board is empty! Please run detector scan first");
0122     }
0123     auto &truth_traces =
0124         m_whiteboard->template get<std::vector<truth_trace_t>>(truth_data_name);
0125     ASSERT_EQ(m_cfg.n_tracks(), truth_traces.size());
0126 
0127     DETRAY_INFO_HOST("Running navigation validation on: " << det_name
0128                                                           << "...\n");
0129 
0130     std::string momentum_str{""};
0131     const std::string prefix{k_use_rays ? det_name + "_ray_"
0132                                         : det_name + "_helix_"};
0133 
0134     const auto data_path{
0135         std::filesystem::path{m_cfg.track_param_file()}.parent_path()};
0136 
0137     // Create an output file path
0138     auto make_path = [&data_path, &prefix, &momentum_str](
0139                          const std::string &name,
0140                          const std::string &extension = ".csv") {
0141       return data_path / (prefix + name + momentum_str + extension);
0142     };
0143 
0144     std::ios_base::openmode io_mode = std::ios::trunc | std::ios::out;
0145     const std::string debug_file_name{
0146         make_path("navigation_validation", ".txt")};
0147     detray::io::file_handle debug_file{debug_file_name, io_mode};
0148 
0149     // Keep a record of track positions and material along the track
0150     dvector<dvector<navigation::detail::candidate_record<intersection_t>>>
0151         recorded_traces{};
0152     dvector<material_validator::material_record<scalar_t>> mat_records{};
0153     std::vector<std::pair<trajectory_type, std::vector<intersection_t>>>
0154         missed_intersections{};
0155 
0156     scalar_t min_pT{std::numeric_limits<scalar_t>::max()};
0157     scalar_t max_pT{-std::numeric_limits<scalar_t>::max()};
0158     for (auto &truth_trace : truth_traces) {
0159       if (n_tracks >= m_cfg.n_tracks()) {
0160         break;
0161       }
0162 
0163       // Follow the test trajectory with a track and check, if we find
0164       // the same volumes and distances along the way
0165       const auto &start = truth_trace.front();
0166       const auto &track = start.track_param;
0167       assert(!track.is_invalid());
0168       trajectory_type test_traj = get_parametrized_trajectory(track);
0169 
0170       const scalar q = start.charge;
0171       const scalar pT{q == 0.f ? 1.f * unit<scalar>::GeV : track.pT(q)};
0172       const scalar p{q == 0.f ? 1.f * unit<scalar>::GeV : track.p(q)};
0173 
0174       // If the momentum is unknown, 1 GeV is the safest option to keep
0175       // the direction vector normalized
0176       if (detray::detail::is_invalid_value(m_cfg.p_range()[0])) {
0177         min_pT = std::min(min_pT, pT);
0178         max_pT = std::max(max_pT, pT);
0179       } else {
0180         min_pT = m_cfg.p_range()[0];
0181         max_pT = m_cfg.p_range()[1];
0182       }
0183       assert(min_pT > 0.f);
0184       assert(max_pT > 0.f);
0185       assert(min_pT < std::numeric_limits<scalar_t>::max());
0186       assert(max_pT < std::numeric_limits<scalar_t>::max());
0187 
0188       // Run the propagation
0189       auto [success, obj_tracer, step_trace, mat_record, mat_trace, nav_printer,
0190             step_printer] =
0191           navigation_validator::record_propagation<stepper_t>(
0192               m_gctx, &m_host_mr, m_det, m_cfg.propagation(), track,
0193               m_cfg.ptc_hypothesis(), b_field);
0194 
0195       if (success) {
0196         assert(!obj_tracer.object_trace.empty());
0197         // The navigator does not record the initial track position:
0198         // add it as a dummy record
0199         obj_tracer.object_trace.insert(
0200             obj_tracer.object_trace.begin(),
0201             {track.pos(), track.dir(), start.intersection});
0202 
0203         // Adjust the track charge, which is unknown to the navigation
0204         for (auto &record : obj_tracer.object_trace) {
0205           record.charge = q;
0206           record.p_mag = p;
0207         }
0208 
0209         auto [result, n_missed_nav, n_missed_truth, n_error, missed_inters] =
0210             navigation_validator::compare_traces(
0211                 m_cfg, truth_trace, obj_tracer.object_trace, test_traj,
0212                 n_tracks, &(*debug_file));
0213 
0214         missed_intersections.push_back(
0215             std::make_pair(test_traj, std::move(missed_inters)));
0216 
0217         // Update statistics
0218         success = success && result;
0219         n_miss_nav += n_missed_nav;
0220         n_miss_truth += n_missed_truth;
0221         n_matching_error += n_error;
0222 
0223       } else {
0224         // Propagation did not succeed
0225         ++n_fatal;
0226 
0227         std::vector<intersection_t> missed_inters{};
0228         missed_intersections.push_back(
0229             std::make_pair(test_traj, missed_inters));
0230       }
0231 
0232       if (!success) {
0233         // Write debug info to file
0234         *debug_file << "TEST TRACK " << n_tracks << ":\n\n"
0235                     << "NAVIGATOR\n\n"
0236                     << nav_printer.to_string() << "\nSTEPPER\n\n"
0237                     << step_printer.to_string();
0238 
0239         detector_scanner::display_error(
0240             m_gctx, m_det, m_names, m_cfg.name(), test_traj, truth_trace,
0241             m_cfg.svg_style(), n_tracks, m_cfg.n_tracks(),
0242             obj_tracer.object_trace);
0243       }
0244 
0245       recorded_traces.push_back(std::move(obj_tracer.object_trace));
0246       mat_records.push_back(mat_record);
0247 
0248       EXPECT_TRUE(success)
0249           << "\nDETRAY INFO (HOST): Wrote navigation debugging data in: "
0250           << debug_file_name;
0251 
0252       ++n_tracks;
0253 
0254       // After dummy records insertion, traces should have the same size
0255       ASSERT_EQ(truth_trace.size(), recorded_traces.back().size());
0256 
0257       // Count the number of different surface types on this trace
0258       navigation_validator::surface_stats n_truth{};
0259       navigation_validator::surface_stats n_nav{};
0260       for (std::size_t i = 0; i < truth_trace.size(); ++i) {
0261         const auto truth_desc = truth_trace[i].intersection.surface();
0262         const auto rec_desc = recorded_traces.back()[i].intersection.surface();
0263 
0264         // Exclude dummy records for missing surfaces
0265         if (!truth_desc.identifier().is_invalid()) {
0266           n_truth.count(truth_desc);
0267         }
0268         if (!rec_desc.identifier().is_invalid()) {
0269           n_nav.count(rec_desc);
0270         }
0271       }
0272 
0273       // Take max count, since either trace might have skipped surfaces
0274       const std::size_t n_portals{
0275           math::max(n_truth.n_portals, n_nav.n_portals)};
0276       const std::size_t n_sensitives{
0277           math::max(n_truth.n_sensitives, n_nav.n_sensitives)};
0278       const std::size_t n_passives{
0279           math::max(n_truth.n_passives, n_nav.n_passives)};
0280       const std::size_t n{n_portals + n_sensitives + n_passives};
0281 
0282       // Cannot have less surfaces than truth intersections after matching
0283       // (Don't count first entry, which records the initial track params)
0284       ASSERT_TRUE(n >= (truth_trace.size() - 1u));
0285 
0286       n_surfaces.n_portals += n_portals;
0287       n_surfaces.n_sensitives += n_sensitives;
0288       n_surfaces.n_passives += n_passives;
0289     }
0290 
0291     // Calculate and display the result
0292     navigation_validator::print_efficiency(n_tracks, n_surfaces, n_miss_nav,
0293                                            n_miss_truth, n_fatal,
0294                                            n_matching_error);
0295 
0296     // Print track positions for plotting
0297     if constexpr (!k_use_rays) {
0298       momentum_str =
0299           "_" +
0300           std::to_string(std::floor(10. * static_cast<double>(min_pT)) / 10.) +
0301           "_" +
0302           std::to_string(std::ceil(10. * static_cast<double>(max_pT)) / 10.) +
0303           "_GeV";
0304     }
0305 
0306     const auto truth_trk_path{make_path("truth_track_params")};
0307     const auto trk_path{make_path("navigation_track_params")};
0308     const auto truth_intr_path{make_path("truth_intersections")};
0309     const auto intr_path{make_path("navigation_intersections")};
0310     const auto mat_path{make_path("accumulated_material")};
0311     const auto missed_path{make_path("missed_intersections_dists")};
0312 
0313     // Write the distance of the missed intersection local position
0314     // to the surface boundaries to file for plotting
0315     navigation_validator::write_dist_to_boundary(
0316         m_det, m_names, missed_path.string(), missed_intersections);
0317     detector_scanner::write_tracks(truth_trk_path.string(), truth_traces);
0318     navigation_validator::write_tracks(trk_path.string(), recorded_traces);
0319     detector_scanner::write_intersections(truth_intr_path.string(),
0320                                           truth_traces);
0321     detector_scanner::write_intersections(intr_path.string(), recorded_traces);
0322     material_validator::write_material(mat_path.string(), mat_records);
0323 
0324     DETRAY_INFO_HOST("Wrote distance to boundary of missed intersections in: "
0325                      << missed_path);
0326     DETRAY_INFO_HOST("Wrote truth track states in: " << truth_trk_path);
0327     DETRAY_INFO_HOST("Wrote recorded track states in: " << trk_path);
0328     DETRAY_INFO_HOST(
0329         "Wrote recorded truth intersections in: " << truth_intr_path);
0330     DETRAY_INFO_HOST("Wrote recorded track intersections in: " << intr_path);
0331     DETRAY_INFO_HOST("Wrote accumulated material in: " << mat_path);
0332   }
0333 
0334  private:
0335   /// @returns either the helix or ray corresponding to the input track
0336   /// parameters @param track
0337   trajectory_type get_parametrized_trajectory(
0338       const free_track_parameters_t &track) {
0339     std::unique_ptr<trajectory_type> test_traj{nullptr};
0340     if constexpr (k_use_rays) {
0341       test_traj = std::make_unique<trajectory_type>(track);
0342     } else {
0343       test_traj = std::make_unique<trajectory_type>(track, m_cfg.B_vector());
0344     }
0345     return *(test_traj.release());
0346   }
0347 
0348   /// The configuration of this test
0349   config m_cfg;
0350   /// The geometry context to check
0351   typename detector_t::geometry_context m_gctx{};
0352   /// Vecmem memory resource for the host allocations
0353   vecmem::host_memory_resource m_host_mr{};
0354   /// The detector to be checked
0355   const detector_t &m_det;
0356   /// Volume names
0357   const typename detector_t::name_map &m_names;
0358   /// Whiteboard to pin data
0359   std::shared_ptr<test::whiteboard> m_whiteboard{nullptr};
0360 };
0361 
0362 template <typename detector_t>
0363 using straight_line_navigation =
0364     navigation_validation<detector_t, detray::ray_scan>;
0365 
0366 template <typename detector_t>
0367 using helix_navigation = navigation_validation<detector_t, detray::helix_scan>;
0368 
0369 }  // namespace detray::test