Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:24:05

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 // Project include(s).
0012 #include "detray/definitions/algebra.hpp"
0013 #include "detray/definitions/containers.hpp"
0014 #include "detray/definitions/detail/qualifiers.hpp"
0015 #include "detray/definitions/math.hpp"
0016 #include "detray/definitions/units.hpp"
0017 #include "detray/navigation/intersection/intersection.hpp"
0018 #include "detray/navigation/intersection/intersection_config.hpp"
0019 #include "detray/utils/invalid_values.hpp"
0020 #include "detray/utils/logging.hpp"
0021 
// System include(s).
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
0029 
0030 namespace detray {
0031 
0032 /// @brief Try to find a bracket around a root
0033 ///
0034 /// @param [in] a lower initial boundary
0035 /// @param [in] b upper initial boundary
0036 /// @param [in] f function for which to find the root
0037 /// @param [out] bracket bracket around the root
0038 /// @param [in] k scale factor with which to widen the bracket at every step
0039 ///
0040 /// @see Numerical Recepes pp. 445
0041 ///
0042 /// @return whether a bracket was found
0043 template <concepts::scalar scalar_t, typename function_t>
0044 DETRAY_HOST_DEVICE inline bool expand_bracket(const scalar_t a,
0045                                               const scalar_t b, function_t &f,
0046                                               darray<scalar_t, 2> &bracket,
0047                                               const scalar_t k = 1.f) {
0048   if (a == b) {
0049     throw std::invalid_argument(
0050         "Root bracketing: Not a valid start interval [" + std::to_string(a) +
0051         ", " + std::to_string(b) + "]");
0052   }
0053 
0054   scalar_t lower{a > b ? b : a};
0055   scalar_t upper{a > b ? a : b};
0056 
0057   // Sample function points at interval
0058   scalar_t f_l{f(lower)};
0059   scalar_t f_u{f(upper)};
0060   std::size_t n_tries{0u};
0061 
0062   /// Check if the bracket has become invalid
0063   const auto check_bracket = [a, b, &bracket](std::size_t n, scalar_t fl,
0064                                               scalar_t fu, scalar_t l,
0065                                               scalar_t u) {
0066     if ((n == 1000u) || !std::isfinite(fl) || !std::isfinite(fu) ||
0067         !std::isfinite(l) || !std::isfinite(u)) {
0068       DETRAY_VERBOSE_HOST("Could not bracket a root (a="
0069                           << l << ", b=" << u << ", f(a)=" << fl
0070                           << ", f(b)=" << fu
0071                           << ", root might not exist). Running "
0072                              "Newton-Raphson without bisection.");
0073       // Reset value
0074       bracket = {a, b};
0075       return false;
0076     }
0077     return true;
0078   };
0079 
0080   // If there is no sign change in interval, we don't know if there is a root
0081   while (!math::signbit(f_l * f_u)) {
0082     // No interval could be found to bracket the root
0083     // Might be correct, if there is not root
0084     if (!check_bracket(n_tries, f_l, f_u, lower, upper)) {
0085       return false;
0086     }
0087     scalar_t d{k * (upper - lower)};
0088     // Make interval larger in the direction where the function is smaller
0089     if (math::fabs(f_l) < math::fabs(f_u)) {
0090       lower -= d;
0091       f_l = f(lower);
0092     } else {
0093       upper += d;
0094       f_u = f(upper);
0095     }
0096     ++n_tries;
0097   }
0098 
0099   if (!check_bracket(n_tries, f_l, f_u, lower, upper)) {
0100     return false;
0101   } else {
0102     bracket = {lower, upper};
0103     return true;
0104   }
0105 }
0106 
0107 /// @brief Find a root using the Newton-Raphson algorithm
0108 ///
0109 /// @param evaluate_func evaluate the function and its derivative
0110 /// @param s initial guess for the root
0111 /// @param convergence_tolerance max distance between from root before finished
0112 /// @param max_n_tries max number of Newton-Bisection step to try
0113 /// @param max_path don't consider root if it is too far away
0114 ///
0115 /// @see Numerical Recepes pp. 445
0116 ///
0117 /// @return pathlength to root and the last step size
0118 template <typename scalar_t, typename function_t>
0119 DETRAY_HOST_DEVICE inline std::pair<scalar_t, scalar_t> newton_raphson(
0120     function_t &evaluate_func, scalar_t s,
0121     const scalar_t convergence_tolerance = 1.f * unit<scalar_t>::um,
0122     const std::size_t max_n_tries = 1000u,
0123     const scalar_t max_path = 5.f * unit<scalar_t>::m) {
0124   constexpr scalar_t inv{detail::invalid_value<scalar_t>()};
0125   constexpr scalar_t epsilon{std::numeric_limits<scalar_t>::epsilon()};
0126 
0127   if (math::fabs(s) >= max_path) {
0128     DETRAY_VERBOSE_HOST("Initial path estimate outside search area: s=" << s);
0129   }
0130   if (math::fabs(s) >= inv) {
0131     const std::string err_msg{"Initial path estimate invalid"};
0132     DETRAY_FATAL_HOST(err_msg);
0133     throw std::invalid_argument(err_msg);
0134   }
0135 
0136   // Run the iteration on s
0137   scalar_t s_prev{0.f};
0138   std::size_t n_tries{0u};
0139   auto [f_s, df_s] = evaluate_func(s);
0140 
0141   while (math::fabs(s - s_prev) > convergence_tolerance) {
0142     // Root already found?
0143     if (math::fabs(f_s) < convergence_tolerance) {
0144       return std::make_pair(s, epsilon);
0145     }
0146 
0147     // No intersection can be found if dividing by zero
0148     if (math::fabs(df_s) == 0.f) {
0149       DETRAY_VERBOSE_HOST(
0150           "Newton step encountered invalid derivative "
0151           "- skipping");
0152       return std::make_pair(inv, inv);
0153     }
0154 
0155     // Newton step
0156     s_prev = s;
0157     s -= f_s / df_s;
0158 
0159     // Update function evaluation
0160     std::tie(f_s, df_s) = evaluate_func(s);
0161 
0162     ++n_tries;
0163 
0164     // No intersection found within max number of trials
0165     if (n_tries >= max_n_tries) {
0166       DETRAY_VERBOSE_HOST("Helix intersector did not converge after "
0167                           << n_tries << " steps - skipping");
0168       return std::make_pair(inv, inv);
0169     }
0170   }
0171   // Final pathlengt to root and latest step size
0172   return std::make_pair(s, math::fabs(s - s_prev));
0173 }
0174 
0175 /// @brief Find a root using the Newton-Raphson and Bisection algorithms
0176 ///
0177 /// @param evaluate_func evaluate the function and its derivative
0178 /// @param s initial guess for the root
0179 /// @param convergence_tolerance max distance between from root before finished
0180 /// @param max_n_tries max number of Newton-Bisection step to try
0181 /// @param max_path don't consider root if it is too far away
0182 ///
0183 /// @see Numerical Recepes pp. 445
0184 ///
0185 /// @return pathlength to root and the last step size
0186 template <concepts::scalar scalar_t, typename function_t>
0187 DETRAY_HOST_DEVICE inline std::pair<scalar_t, scalar_t> newton_raphson_safe(
0188     function_t &evaluate_func, scalar_t s,
0189     const scalar_t convergence_tolerance = 1.f * unit<scalar_t>::um,
0190     const std::size_t max_n_tries = 1000u,
0191     const scalar_t max_path = 5.f * unit<scalar_t>::m) {
0192   constexpr scalar_t inv{detail::invalid_value<scalar_t>()};
0193   constexpr scalar_t epsilon{std::numeric_limits<scalar_t>::epsilon()};
0194 
0195   // Evaluate the test function at point 'x'
0196   auto f = [&evaluate_func](const scalar_t x) {
0197     auto [f_x, df_x] = evaluate_func(x);
0198 
0199     return f_x;
0200   };
0201 
0202   if (math::fabs(s) >= max_path) {
0203     DETRAY_VERBOSE_HOST("Initial path estimate outside search area: s=" << s);
0204   }
0205   if (math::fabs(s) >= inv) {
0206     const std::string err_msg{"Initial path estimate invalid"};
0207     DETRAY_ERROR_HOST(err_msg);
0208     throw std::invalid_argument(err_msg);
0209   }
0210 
0211   // Initial bracket (test a certain range around 's')
0212   scalar_t a{math::fabs(s) == 0.f ? -0.2f : 0.8f * s};
0213   scalar_t b{math::fabs(s) == 0.f ? 0.2f : 1.2f * s};
0214   darray<scalar_t, 2> br{};
0215   bool is_bracketed = expand_bracket(a, b, f, br);
0216 
0217   // Update initial guess on the root after bracketing
0218   s = is_bracketed ? 0.5f * (br[1] + br[0]) : s;
0219 
0220   if (!is_bracketed) {
0221     DETRAY_VERBOSE_HOST("Bracketing failed for initial path estimate: s=" << s);
0222   } else {
0223     // Check bracket
0224     [[maybe_unused]] auto [f_a, df_a] = evaluate_func(br[0]);
0225     [[maybe_unused]] auto [f_b, df_b] = evaluate_func(br[1]);
0226 
0227     // Bracket is not guaranteed to contain a root
0228     if (!math::signbit(f_a * f_b)) {
0229       throw std::runtime_error(
0230           "Incorrect bracket around root: No sign change!");
0231     }
0232 
0233     // No bisection algorithm possible if one bracket boundary is inf
0234     // (is already checked in bracketing alg)
0235     if ((math::fabs(br[0]) >= inv) || (math::fabs(br[1]) >= inv)) {
0236       throw std::runtime_error(
0237           "Incorrect bracket around root: Boundary reached inf!");
0238     }
0239 
0240     // Root is not within the maximal pathlength
0241     bool bracket_outside_tol{math::fabs(s) > max_path &&
0242                              math::fabs(br[0]) >= max_path &&
0243                              math::fabs(br[1]) >= max_path};
0244     if (bracket_outside_tol) {
0245       DETRAY_VERBOSE_HOST("Root outside maximum search area (s = "
0246                           << s << ", a: " << br[0] << ", b: " << br[1]
0247                           << ") - skipping");
0248       return std::make_pair(inv, inv);
0249     }
0250 
0251     // Root already found?
0252     if (math::fabs(f_a) < convergence_tolerance) {
0253       return std::make_pair(a, epsilon);
0254     }
0255     if (math::fabs(f_b) < convergence_tolerance) {
0256       return std::make_pair(b, epsilon);
0257     }
0258 
0259     // Make 'a' the boundary for the negative function value -> easier to
0260     // update
0261     bool is_lower_a{math::signbit(f_a)};
0262     a = br[is_lower_a ? 0u : 1u];
0263     b = br[is_lower_a ? 1u : 0u];
0264   }
0265 
0266   // Run the iteration on s
0267   scalar_t s_prev{0.f};
0268   std::size_t n_tries{0u};
0269   auto [f_s, df_s] = evaluate_func(s);
0270   if (math::fabs(f_s) < convergence_tolerance) {
0271     return std::make_pair(s, epsilon);
0272   }
0273   if (math::signbit(f_s)) {
0274     a = s;
0275   } else {
0276     b = s;
0277   }
0278 
0279   while (math::fabs(s - s_prev) > convergence_tolerance) {
0280     // Does Newton step escape bracket?
0281     bool bracket_escape{true};
0282     scalar_t s_newton{0.f};
0283     if (math::fabs(df_s) != 0.f) {
0284       s_newton = s - f_s / df_s;
0285       bracket_escape = math::signbit((s_newton - a) * (b - s_newton));
0286     }
0287 
0288     // This criterion from Numerical Recipes seems to work, but why?
0289     /*const bool slow_convergence{math::fabs(2.f * f_s) >
0290                                 math::fabs((s_prev - s) * df_s)};*/
0291 
0292     // Take a bisection step if it converges faster than Newton
0293     // |f(next_newton_s)| > |f(next_bisection_s)|
0294     bool slow_convergence{true};
0295     // The criterion is only well defined if the step lengths are small
0296     if (const scalar_t ds_bisection{0.5f * (a + b) - s};
0297         is_bracketed &&
0298         (math::fabs(ds_bisection) < 10.f * unit<scalar_t>::mm)) {
0299       slow_convergence =
0300           (2.f * math::fabs(f_s) > math::fabs(df_s * ds_bisection + f_s));
0301     }
0302 
0303     s_prev = s;
0304 
0305     // Run bisection if Newton-Raphson would be poor
0306     if (is_bracketed &&
0307         (bracket_escape || slow_convergence || math::fabs(df_s) == 0.f)) {
0308       // Test the function sign in the middle of the interval
0309       s = 0.5f * (a + b);
0310     } else {
0311       // No intersection can be found if dividing by zero
0312       if (!is_bracketed && math::fabs(df_s) == 0.f) {
0313         DETRAY_VERBOSE_HOST("Newton step encountered invalid derivative at s="
0314                             << s << " after " << n_tries
0315                             << " steps - skipping");
0316 
0317         return std::make_pair(inv, inv);
0318       }
0319 
0320       s = s_newton;
0321     }
0322 
0323     // Update function and bracket
0324     std::tie(f_s, df_s) = evaluate_func(s);
0325     if (is_bracketed && math::signbit(f_s)) {
0326       a = s;
0327     } else {
0328       b = s;
0329     }
0330 
0331     // Converges to a point outside the search space - early stop
0332     if (math::fabs(s) > max_path && math::fabs(s_prev) > max_path &&
0333         ((a < -max_path && b < -max_path) || (a > max_path && b > max_path))) {
0334       DETRAY_VERBOSE_HOST("WARNING: Root finding left the search space at (s = "
0335                           << s << ", a: " << a << ", b: " << b << ") after "
0336                           << n_tries << " steps - skipping");
0337 
0338       return std::make_pair(inv, inv);
0339     }
0340 
0341     ++n_tries;
0342 
0343     // No intersection found within max number of trials
0344     if (n_tries >= max_n_tries) {
0345       // Should have found the root
0346       if (is_bracketed) {
0347         std::stringstream err_str{};
0348         err_str << "Helix intersector did not find root for s=" << s << " in ["
0349                 << a << ", " << b << "]";
0350 
0351         DETRAY_FATAL_HOST(err_str.str());
0352         throw std::runtime_error(err_str.str());
0353       } else {
0354         DETRAY_VERBOSE_HOST("Helix intersector did not converge after "
0355                             << n_tries
0356                             << " steps unbracketed search - skipping");
0357       }
0358       return std::make_pair(inv, inv);
0359     }
0360   }
0361   // Final pathlengt to root and latest step size
0362   return std::make_pair(s, math::fabs(s - s_prev));
0363 }
0364 
0365 /// @brief Fill an intersection with the result of the root finding
0366 ///
0367 /// @param [out] sfi the surface intersection
0368 /// @param [in] traj the test trajectory that intersects the surface
0369 /// @param [in] s path length to the root
0370 /// @param [in] ds approximation error for the root
0371 /// @param [in] mask the mask of the surface
0372 /// @param [in] trf the transform of the surface
0373 /// @param [in] mask_tolerance minimal and maximal mask tolerance
0374 template <typename intersection_t, concepts::algebra algebra_t,
0375           typename surface_descr_t, typename mask_t, typename trajectory_t,
0376           concepts::transform3D transform3_t, concepts::scalar scalar_t>
0377 DETRAY_HOST_DEVICE constexpr void resolve_mask(
0378     intersection_t &is, const trajectory_t &traj,
0379     const intersection_point_err<algebra_t> &ip, const surface_descr_t sf_desc,
0380     const mask_t &mask, const transform3_t &trf,
0381     const intersection::config &intr_cfg,
0382     const scalar_t /*external_mask_tol*/ = 0.f) {
0383   assert((intr_cfg.min_mask_tolerance == intr_cfg.max_mask_tolerance) &&
0384          "Helix intersectors use only one mask tolerance value");
0385 
0386   // Build intersection struct from test trajectory, if the distance is valid
0387   if (!detail::is_invalid_value(ip.path)) {
0388     is.set_path(ip.path);
0389     is.set_local(
0390         mask_t::to_local_frame3D(trf, traj.pos(ip.path), traj.dir(ip.path)));
0391 
0392     const scalar_t cos_incidence_angle = vector::dot(
0393         mask_t::get_local_frame().normal(trf, is.local()), traj.dir(ip.path));
0394 
0395     scalar_t tol{sf_desc.is_portal() ? 0.f : intr_cfg.min_mask_tolerance};
0396     // If tolerance is inf, use tolerance estimation (intr_cfg is 'float'!)
0397     if (tol >= detail::invalid_value<float>()) {
0398       // Due to floating point errors this can be negative if cos ~ 1
0399       const scalar_t sin_inc2{
0400           math::fabs(1.f - cos_incidence_angle * cos_incidence_angle)};
0401 
0402       tol = math::fabs(ip.path_err * math::sqrt(sin_inc2));
0403     }
0404     // Make sure the tol. has been estimated/configured in a sensible way
0405     assert(!math::signbit(tol));
0406     assert(tol < 1000.f * unit<scalar_t>::mm);
0407 
0408     is.set_status(mask.is_inside(is.local(), tol)
0409                       ? intersection::status::e_inside
0410                       : intersection::status::e_outside);
0411     is.set_surface(sf_desc);
0412     is.set_direction(!math::signbit(ip.path));
0413     is.set_volume_link(mask.volume_link());
0414   } else {
0415     // Not a valid intersection
0416     is.set_status(intersection::status::e_outside);
0417   }
0418 }
0419 
0420 }  // namespace detray