Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:24:15

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 // Project include(s)
0012 #include "detray/definitions/algebra.hpp"
0013 #include "detray/geometry/surface.hpp"
0014 #include "detray/navigation/detail/intersection_kernel.hpp"
0015 #include "detray/navigation/intersection/intersection.hpp"
0016 #include "detray/navigation/intersector.hpp"
0017 #include "detray/tracks/free_track_parameters.hpp"
0018 #include "detray/tracks/trajectories.hpp"
0019 
0020 // Detray IO include(s)
0021 #include "detray/io/csv/intersection2D.hpp"
0022 #include "detray/io/csv/track_parameters.hpp"
0023 
0024 // System include(s)
#include <algorithm>
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0029 
0030 namespace detray {
0031 
0032 /// Record of a surface intersection along a track
0033 template <typename detector_t>
0034 struct intersection_record {
0035   using algebra_t = typename detector_t::algebra_type;
0036   using scalar_t = dscalar<algebra_t>;
0037   using track_parameter_type = free_track_parameters<algebra_t>;
0038   using intersection_type =
0039       intersection2D<typename detector_t::surface_type, algebra_t,
0040                      intersection::contains_pos>;
0041 
0042   /// The charge associated with the track parameters
0043   scalar_t charge{};
0044   /// Current global track parameters
0045   track_parameter_type track_param{
0046       {0.f, 0.f, 0.f}, 0.f, {0.f, 0.f, 1.f}, detail::invalid_value<scalar_t>()};
0047   /// Index of the volume the intersection was found in
0048   dindex vol_idx{};
0049   /// The intersection result, including the surface descriptor
0050   intersection_type intersection{};
0051 };
0052 
0053 /// @brief struct that holds functionality to shoot a parametrized particle
0054 /// trajectory through a detector.
0055 ///
0056 /// Records intersections with every detector surface along the trajectory.
0057 template <typename trajectory_t>
0058 struct brute_force_scan {
0059   template <typename D>
0060   using intersection_trace_type = std::vector<intersection_record<D>>;
0061   using trajectory_type = trajectory_t;
0062 
0063   template <typename detector_t>
0064   inline auto operator()(const typename detector_t::geometry_context ctx,
0065                          const detector_t &detector, const trajectory_t &traj,
0066                          const typename detector_t::scalar_type mask_tol = 0.f,
0067                          const typename detector_t::scalar_type p =
0068                              1.f *
0069                              unit<typename detector_t::scalar_type>::GeV) {
0070     using algebra_t = typename detector_t::algebra_type;
0071     using scalar_t = dscalar<algebra_t>;
0072     using sf_desc_t = typename detector_t::surface_type;
0073     using nav_link_t = typename detector_t::surface_type::navigation_link;
0074 
0075     using intersection_t =
0076         typename intersection_record<detector_t>::intersection_type;
0077     using intersection_kernel_t = detail::intersection_initialize<intersector>;
0078 
0079     constexpr scalar_t external_mask_tol{0.f};
0080     const intersection::config intr_cfg{
0081         .min_mask_tolerance = static_cast<float>(mask_tol),
0082         .max_mask_tolerance = static_cast<float>(mask_tol),
0083         .mask_tolerance_scalor = 0.f,
0084         .overstep_tolerance = 0.f};
0085 
0086     intersection_trace_type<detector_t> intersection_trace;
0087 
0088     const auto &trf_store = detector.transform_store();
0089 
0090     assert(p > 0.f);
0091     const scalar_t q{p * traj.qop()};
0092 
0093     std::vector<intersection_t> intersections{};
0094     intersections.reserve(100u);
0095 
0096     // Loop over all surfaces in the detector
0097     for (const sf_desc_t &sf_desc : detector.surfaces()) {
0098       // Retrieve candidate(s) from the surface
0099       const auto sf = geometry::surface{detector, sf_desc};
0100       sf.template visit_mask<intersection_kernel_t>(
0101           intersections, traj, sf_desc, trf_store, ctx, intr_cfg,
0102           external_mask_tol);
0103 
0104       // Candidate is invalid if it lies in the opposite direction
0105       for (auto &sfi : intersections) {
0106         if (sfi.is_along()) {
0107           sfi.surface() = sf_desc;
0108           // Record the intersection
0109           intersection_trace.push_back(
0110               {q,
0111                {traj.pos(sfi.path()), 0.f, p * traj.dir(sfi.path()), q},
0112                sf.volume(),
0113                sfi});
0114         }
0115       }
0116       intersections.clear();
0117     }
0118 
0119     // Need to have at least an exit portal
0120     if (intersection_trace.empty()) {
0121       std::stringstream stream;
0122       stream << "No intersections found for traj: " << traj << std::endl;
0123       throw std::runtime_error(stream.str());
0124     }
0125 
0126     // Save initial track position as dummy intersection record
0127     const auto &first_record = intersection_trace.front();
0128     intersection_t start_intersection{};
0129     // Must not be invalid, since it will otherwise throw the navigation
0130     // validation off
0131     start_intersection.set_surface(first_record.intersection.surface());
0132     start_intersection.surface().set_id(surface_id::e_passive);
0133     start_intersection.surface().set_index(0);
0134     start_intersection.surface()
0135         .material()
0136         .set_id(detector_t::material::id::e_none)
0137         .set_index(dindex_invalid);
0138     start_intersection.set_path(0.f);
0139     start_intersection.set_local({0.f, 0.f, 0.f});
0140     start_intersection.set_volume_link(
0141         static_cast<nav_link_t>(first_record.vol_idx));
0142 
0143     intersection_trace.insert(
0144         intersection_trace.begin(),
0145         intersection_record<detector_t>{q,
0146                                         {traj.pos(), 0.f, p * traj.dir(), q},
0147                                         first_record.vol_idx,
0148                                         start_intersection});
0149 
0150     return intersection_trace;
0151   }
0152 };
0153 
0154 template <concepts::algebra algebra_t>
0155 using ray_scan = brute_force_scan<detail::ray<algebra_t>>;
0156 
0157 template <concepts::algebra algebra_t>
0158 using helix_scan = brute_force_scan<detail::helix<algebra_t>>;
0159 
0160 /// Run a scan on detector object by shooting test particles through it
0161 namespace detector_scanner {
0162 
0163 template <template <typename> class scan_type, typename detector_t,
0164           typename trajectory_t, typename... Args>
0165 inline auto run(const typename detector_t::geometry_context gctx,
0166                 const detector_t &detector, const trajectory_t &traj,
0167                 Args &&...args) {
0168   using algebra_t = typename detector_t::algebra_type;
0169   using nav_link_t = typename detector_t::surface_type::navigation_link;
0170 
0171   auto intersection_record =
0172       scan_type<algebra_t>{}(gctx, detector, traj, std::forward<Args>(args)...);
0173 
0174   using record_t = typename decltype(intersection_record)::value_type;
0175 
0176   // HACK: For whatever reason, std::stable_sort really dislikes custom
0177   // aligned types like the ones in Eigen and Fastor, so we have to sort
0178   // by indices and then reconstruct the sorted intersection record.
0179   auto sort_path = [&](const record_t &a, const record_t &b) -> bool {
0180     return (a.intersection < b.intersection);
0181   };
0182 
0183   std::ranges::stable_sort(intersection_record, sort_path);
0184 
0185   // Make sure the intersection record terminates at world portals
0186   auto is_world_exit = [](const record_t &r) {
0187     return r.intersection.volume_link() ==
0188            detray::detail::invalid_value<nav_link_t>();
0189   };
0190 
0191   if (auto it = std::ranges::find_if(intersection_record, is_world_exit);
0192       it != intersection_record.end()) {
0193     auto n{static_cast<std::size_t>(it - intersection_record.begin())};
0194     intersection_record.resize(n + 1u);
0195   }
0196 
0197   return intersection_record;
0198 }
0199 
0200 /// Write the @param intersection_traces to file
0201 template <typename detector_t>
0202 inline auto write_intersections(
0203     const std::string &intersection_file_name,
0204     const std::vector<std::vector<intersection_record<detector_t>>>
0205         &intersection_traces) {
0206   using record_t = intersection_record<detector_t>;
0207   using intersection_t = typename record_t::intersection_type;
0208 
0209   std::vector<std::vector<intersection_t>> intersections{};
0210 
0211   // Split data
0212   for (const auto &trace : intersection_traces) {
0213     auto &intrs = intersections.emplace_back();
0214     intrs.reserve(trace.size());
0215 
0216     for (const auto &record : trace) {
0217       intrs.push_back(record.intersection);
0218     }
0219   }
0220 
0221   // Write to file
0222   io::csv::write_intersection2D(intersection_file_name, intersections);
0223 }
0224 
0225 /// Write the @param intersection_traces to file
0226 template <typename record_t>
0227 inline auto write_intersections(
0228     const std::string &intersection_file_name,
0229     const dvector<dvector<record_t>> &intersection_traces) {
0230   using intersection_t = typename record_t::intersection_type;
0231 
0232   std::vector<std::vector<intersection_t>> intersections{};
0233 
0234   // Split data
0235   for (const auto &trace : intersection_traces) {
0236     auto &intrs = intersections.emplace_back();
0237     intrs.reserve(trace.size());
0238 
0239     for (const auto &record : trace) {
0240       intrs.push_back(record.intersection);
0241     }
0242   }
0243 
0244   // Write to file
0245   io::csv::write_intersection2D(intersection_file_name, intersections);
0246 }
0247 
0248 /// Write the @param intersection_traces to file
0249 template <typename detector_t>
0250 inline auto write_tracks(
0251     const std::string &track_param_file_name,
0252     const std::vector<std::vector<intersection_record<detector_t>>>
0253         &intersection_traces) {
0254   using scalar_t = dscalar<typename detector_t::algebra_type>;
0255   using record_t = intersection_record<detector_t>;
0256   using track_param_t = typename record_t::track_parameter_type;
0257 
0258   std::vector<std::vector<std::pair<scalar_t, track_param_t>>> track_params{};
0259 
0260   // Split data
0261   for (const auto &trace : intersection_traces) {
0262     track_params.push_back({});
0263     track_params.back().reserve(trace.size());
0264 
0265     for (const auto &record : trace) {
0266       track_params.back().emplace_back(record.charge, record.track_param);
0267     }
0268   }
0269 
0270   // Write to file
0271   io::csv::write_free_track_params(track_param_file_name, track_params);
0272 }
0273 
0274 /// Read the @param intersection_record from file
0275 template <typename detector_t>
0276 inline auto read(const std::string &intersection_file_name,
0277                  const std::string &track_param_file_name,
0278                  std::vector<std::vector<intersection_record<detector_t>>>
0279                      &intersection_traces) {
0280   // Read from file
0281   auto intersections_per_track =
0282       io::csv::read_intersection2D<detector_t>(intersection_file_name);
0283   auto track_params_per_track =
0284       io::csv::read_free_track_params<detector_t>(track_param_file_name);
0285 
0286   if (intersections_per_track.size() != track_params_per_track.size()) {
0287     throw std::invalid_argument(
0288         "Detector scanner: intersection and track parameters "
0289         "collections "
0290         "have different size");
0291   }
0292 
0293   // Interleave data
0294   for (dindex trk_idx = 0u; trk_idx < intersections_per_track.size();
0295        ++trk_idx) {
0296     const auto &intersections = intersections_per_track[trk_idx];
0297     const auto &track_params = track_params_per_track[trk_idx];
0298 
0299     // Check track id
0300     if (intersections.size() != track_params.size()) {
0301       throw std::invalid_argument(
0302           "Detector scanner: Found different number of intersections "
0303           "and "
0304           "track parameters for track no." +
0305           std::to_string(trk_idx));
0306     }
0307 
0308     // Check for empty input traces
0309     if (intersections.empty()) {
0310       throw std::invalid_argument("Detector scanner: Found empty trace no." +
0311                                   std::to_string(trk_idx));
0312     }
0313 
0314     // Add new trace
0315     if (intersection_traces.size() <= trk_idx) {
0316       intersection_traces.push_back({});
0317     }
0318 
0319     // Add records to trace
0320     for (dindex i = 0u; i < intersections.size(); ++i) {
0321       intersection_traces[trk_idx].push_back(intersection_record<detector_t>{
0322           track_params[i].first, track_params[i].second,
0323           intersections[i].surface().volume(), intersections[i]});
0324     }
0325   }
0326 }
0327 
0328 }  // namespace detector_scanner
0329 
0330 }  // namespace detray