File indexing completed on 2026-05-27 07:24:13
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011
#include "detray/propagator/line_stepper.hpp"
#include "detray/propagator/rk_stepper.hpp"
#include "detray/tracks/ray.hpp"
#include "detray/tracks/tracks.hpp"
#include "detray/utils/logging.hpp"

#include "detray/test/common/bfield.hpp"
#include "detray/test/framework/fixture_base.hpp"
#include "detray/test/framework/whiteboard.hpp"
#include "detray/test/validation/detector_scan_utils.hpp"
#include "detray/test/validation/detector_scanner.hpp"
#include "detray/test/validation/material_validation_utils.hpp"
#include "detray/test/validation/navigation_validation_config.hpp"
#include "detray/test/validation/navigation_validation_utils.hpp"

#include <vecmem/memory/host_memory_resource.hpp>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <filesystem>
#include <ios>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0035
0036 namespace detray::test {
0037
0038
0039
0040
0041 template <typename detector_t, template <typename> class scan_type>
0042 class navigation_validation : public test::fixture_base<> {
0043 using algebra_t = typename detector_t::algebra_type;
0044 using scalar_t = dscalar<algebra_t>;
0045 using vector3_t = dvector3D<algebra_t>;
0046 using free_track_parameters_t = free_track_parameters<algebra_t>;
0047 using trajectory_type = typename scan_type<algebra_t>::trajectory_type;
0048 using truth_trace_t = typename scan_type<
0049 algebra_t>::template intersection_trace_type<detector_t>;
0050
0051
0052 static constexpr auto k_use_rays{
0053 std::is_same_v<detail::ray<algebra_t>, trajectory_type>};
0054
0055 public:
0056 using fixture_type = test::fixture_base<>;
0057 using config = navigation_validation_config<algebra_t>;
0058
0059 explicit navigation_validation(
0060 const detector_t &det, const typename detector_t::name_map &names,
0061 const config &cfg = {}, std::shared_ptr<test::whiteboard> wb = nullptr,
0062 const typename detector_t::geometry_context gctx = {})
0063 : m_cfg{cfg},
0064 m_gctx{gctx},
0065 m_det{det},
0066 m_names{names},
0067 m_whiteboard{std::move(wb)} {
0068 if (!m_whiteboard) {
0069 throw std::invalid_argument("No white board was passed to " +
0070 m_cfg.name() + " test");
0071 }
0072 }
0073
0074
0075 void TestBody() override {
0076 using namespace detray;
0077 using namespace navigation;
0078
0079 using intersection_t =
0080 typename truth_trace_t::value_type::intersection_type;
0081
0082
0083 using hom_bfield_t = bfield::const_field_t<scalar_t>;
0084 using bfield_t =
0085 std::conditional_t<k_use_rays, navigation_validator::empty_bfield,
0086 hom_bfield_t>;
0087 using rk_stepper_t =
0088 rk_stepper<typename hom_bfield_t::view_t, algebra_t,
0089 unconstrained_step<scalar_t>, stepper_rk_policy<scalar_t>,
0090 stepping::print_inspector>;
0091 using line_stepper_t = line_stepper<algebra_t, unconstrained_step<scalar_t>,
0092 stepper_default_policy<scalar_t>,
0093 stepping::print_inspector>;
0094 using stepper_t =
0095 std::conditional_t<k_use_rays, line_stepper_t, rk_stepper_t>;
0096
0097 bfield_t b_field{};
0098 if constexpr (!k_use_rays) {
0099 b_field = create_const_field<scalar_t>(m_cfg.B_vector());
0100 }
0101
0102
0103 const std::string det_name{m_det.name(m_names)};
0104 const std::string truth_data_name{k_use_rays ? det_name + "_ray_scan"
0105 : det_name + "_helix_scan"};
0106
0107
0108 std::size_t n_tracks{0u};
0109 std::size_t n_matching_error{0u};
0110 std::size_t n_fatal{0u};
0111
0112 navigation_validator::surface_stats n_surfaces{};
0113
0114 navigation_validator::surface_stats n_miss_nav{};
0115
0116 navigation_validator::surface_stats n_miss_truth{};
0117
0118 DETRAY_INFO_HOST("Fetching data from white board...");
0119 if (!m_whiteboard->exists(truth_data_name)) {
0120 throw std::runtime_error(
0121 "White board is empty! Please run detector scan first");
0122 }
0123 auto &truth_traces =
0124 m_whiteboard->template get<std::vector<truth_trace_t>>(truth_data_name);
0125 ASSERT_EQ(m_cfg.n_tracks(), truth_traces.size());
0126
0127 DETRAY_INFO_HOST("Running navigation validation on: " << det_name
0128 << "...\n");
0129
0130 std::string momentum_str{""};
0131 const std::string prefix{k_use_rays ? det_name + "_ray_"
0132 : det_name + "_helix_"};
0133
0134 const auto data_path{
0135 std::filesystem::path{m_cfg.track_param_file()}.parent_path()};
0136
0137
0138 auto make_path = [&data_path, &prefix, &momentum_str](
0139 const std::string &name,
0140 const std::string &extension = ".csv") {
0141 return data_path / (prefix + name + momentum_str + extension);
0142 };
0143
0144 std::ios_base::openmode io_mode = std::ios::trunc | std::ios::out;
0145 const std::string debug_file_name{
0146 make_path("navigation_validation", ".txt")};
0147 detray::io::file_handle debug_file{debug_file_name, io_mode};
0148
0149
0150 dvector<dvector<navigation::detail::candidate_record<intersection_t>>>
0151 recorded_traces{};
0152 dvector<material_validator::material_record<scalar_t>> mat_records{};
0153 std::vector<std::pair<trajectory_type, std::vector<intersection_t>>>
0154 missed_intersections{};
0155
0156 scalar_t min_pT{std::numeric_limits<scalar_t>::max()};
0157 scalar_t max_pT{-std::numeric_limits<scalar_t>::max()};
0158 for (auto &truth_trace : truth_traces) {
0159 if (n_tracks >= m_cfg.n_tracks()) {
0160 break;
0161 }
0162
0163
0164
0165 const auto &start = truth_trace.front();
0166 const auto &track = start.track_param;
0167 assert(!track.is_invalid());
0168 trajectory_type test_traj = get_parametrized_trajectory(track);
0169
0170 const scalar q = start.charge;
0171 const scalar pT{q == 0.f ? 1.f * unit<scalar>::GeV : track.pT(q)};
0172 const scalar p{q == 0.f ? 1.f * unit<scalar>::GeV : track.p(q)};
0173
0174
0175
0176 if (detray::detail::is_invalid_value(m_cfg.p_range()[0])) {
0177 min_pT = std::min(min_pT, pT);
0178 max_pT = std::max(max_pT, pT);
0179 } else {
0180 min_pT = m_cfg.p_range()[0];
0181 max_pT = m_cfg.p_range()[1];
0182 }
0183 assert(min_pT > 0.f);
0184 assert(max_pT > 0.f);
0185 assert(min_pT < std::numeric_limits<scalar_t>::max());
0186 assert(max_pT < std::numeric_limits<scalar_t>::max());
0187
0188
0189 auto [success, obj_tracer, step_trace, mat_record, mat_trace, nav_printer,
0190 step_printer] =
0191 navigation_validator::record_propagation<stepper_t>(
0192 m_gctx, &m_host_mr, m_det, m_cfg.propagation(), track,
0193 m_cfg.ptc_hypothesis(), b_field);
0194
0195 if (success) {
0196 assert(!obj_tracer.object_trace.empty());
0197
0198
0199 obj_tracer.object_trace.insert(
0200 obj_tracer.object_trace.begin(),
0201 {track.pos(), track.dir(), start.intersection});
0202
0203
0204 for (auto &record : obj_tracer.object_trace) {
0205 record.charge = q;
0206 record.p_mag = p;
0207 }
0208
0209 auto [result, n_missed_nav, n_missed_truth, n_error, missed_inters] =
0210 navigation_validator::compare_traces(
0211 m_cfg, truth_trace, obj_tracer.object_trace, test_traj,
0212 n_tracks, &(*debug_file));
0213
0214 missed_intersections.push_back(
0215 std::make_pair(test_traj, std::move(missed_inters)));
0216
0217
0218 success = success && result;
0219 n_miss_nav += n_missed_nav;
0220 n_miss_truth += n_missed_truth;
0221 n_matching_error += n_error;
0222
0223 } else {
0224
0225 ++n_fatal;
0226
0227 std::vector<intersection_t> missed_inters{};
0228 missed_intersections.push_back(
0229 std::make_pair(test_traj, missed_inters));
0230 }
0231
0232 if (!success) {
0233
0234 *debug_file << "TEST TRACK " << n_tracks << ":\n\n"
0235 << "NAVIGATOR\n\n"
0236 << nav_printer.to_string() << "\nSTEPPER\n\n"
0237 << step_printer.to_string();
0238
0239 detector_scanner::display_error(
0240 m_gctx, m_det, m_names, m_cfg.name(), test_traj, truth_trace,
0241 m_cfg.svg_style(), n_tracks, m_cfg.n_tracks(),
0242 obj_tracer.object_trace);
0243 }
0244
0245 recorded_traces.push_back(std::move(obj_tracer.object_trace));
0246 mat_records.push_back(mat_record);
0247
0248 EXPECT_TRUE(success)
0249 << "\nDETRAY INFO (HOST): Wrote navigation debugging data in: "
0250 << debug_file_name;
0251
0252 ++n_tracks;
0253
0254
0255 ASSERT_EQ(truth_trace.size(), recorded_traces.back().size());
0256
0257
0258 navigation_validator::surface_stats n_truth{};
0259 navigation_validator::surface_stats n_nav{};
0260 for (std::size_t i = 0; i < truth_trace.size(); ++i) {
0261 const auto truth_desc = truth_trace[i].intersection.surface();
0262 const auto rec_desc = recorded_traces.back()[i].intersection.surface();
0263
0264
0265 if (!truth_desc.identifier().is_invalid()) {
0266 n_truth.count(truth_desc);
0267 }
0268 if (!rec_desc.identifier().is_invalid()) {
0269 n_nav.count(rec_desc);
0270 }
0271 }
0272
0273
0274 const std::size_t n_portals{
0275 math::max(n_truth.n_portals, n_nav.n_portals)};
0276 const std::size_t n_sensitives{
0277 math::max(n_truth.n_sensitives, n_nav.n_sensitives)};
0278 const std::size_t n_passives{
0279 math::max(n_truth.n_passives, n_nav.n_passives)};
0280 const std::size_t n{n_portals + n_sensitives + n_passives};
0281
0282
0283
0284 ASSERT_TRUE(n >= (truth_trace.size() - 1u));
0285
0286 n_surfaces.n_portals += n_portals;
0287 n_surfaces.n_sensitives += n_sensitives;
0288 n_surfaces.n_passives += n_passives;
0289 }
0290
0291
0292 navigation_validator::print_efficiency(n_tracks, n_surfaces, n_miss_nav,
0293 n_miss_truth, n_fatal,
0294 n_matching_error);
0295
0296
0297 if constexpr (!k_use_rays) {
0298 momentum_str =
0299 "_" +
0300 std::to_string(std::floor(10. * static_cast<double>(min_pT)) / 10.) +
0301 "_" +
0302 std::to_string(std::ceil(10. * static_cast<double>(max_pT)) / 10.) +
0303 "_GeV";
0304 }
0305
0306 const auto truth_trk_path{make_path("truth_track_params")};
0307 const auto trk_path{make_path("navigation_track_params")};
0308 const auto truth_intr_path{make_path("truth_intersections")};
0309 const auto intr_path{make_path("navigation_intersections")};
0310 const auto mat_path{make_path("accumulated_material")};
0311 const auto missed_path{make_path("missed_intersections_dists")};
0312
0313
0314
0315 navigation_validator::write_dist_to_boundary(
0316 m_det, m_names, missed_path.string(), missed_intersections);
0317 detector_scanner::write_tracks(truth_trk_path.string(), truth_traces);
0318 navigation_validator::write_tracks(trk_path.string(), recorded_traces);
0319 detector_scanner::write_intersections(truth_intr_path.string(),
0320 truth_traces);
0321 detector_scanner::write_intersections(intr_path.string(), recorded_traces);
0322 material_validator::write_material(mat_path.string(), mat_records);
0323
0324 DETRAY_INFO_HOST("Wrote distance to boundary of missed intersections in: "
0325 << missed_path);
0326 DETRAY_INFO_HOST("Wrote truth track states in: " << truth_trk_path);
0327 DETRAY_INFO_HOST("Wrote recorded track states in: " << trk_path);
0328 DETRAY_INFO_HOST(
0329 "Wrote recorded truth intersections in: " << truth_intr_path);
0330 DETRAY_INFO_HOST("Wrote recorded track intersections in: " << intr_path);
0331 DETRAY_INFO_HOST("Wrote accumulated material in: " << mat_path);
0332 }
0333
0334 private:
0335
0336
0337 trajectory_type get_parametrized_trajectory(
0338 const free_track_parameters_t &track) {
0339 std::unique_ptr<trajectory_type> test_traj{nullptr};
0340 if constexpr (k_use_rays) {
0341 test_traj = std::make_unique<trajectory_type>(track);
0342 } else {
0343 test_traj = std::make_unique<trajectory_type>(track, m_cfg.B_vector());
0344 }
0345 return *(test_traj.release());
0346 }
0347
0348
0349 config m_cfg;
0350
0351 typename detector_t::geometry_context m_gctx{};
0352
0353 vecmem::host_memory_resource m_host_mr{};
0354
0355 const detector_t &m_det;
0356
0357 const typename detector_t::name_map &m_names;
0358
0359 std::shared_ptr<test::whiteboard> m_whiteboard{nullptr};
0360 };
0361
0362 template <typename detector_t>
0363 using straight_line_navigation =
0364 navigation_validation<detector_t, detray::ray_scan>;
0365
0366 template <typename detector_t>
0367 using helix_navigation = navigation_validation<detector_t, detray::helix_scan>;
0368
0369 }