File indexing completed on 2026-05-27 07:24:05
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011
0012 #include "detray/definitions/algebra.hpp"
0013 #include "detray/definitions/containers.hpp"
0014 #include "detray/definitions/detail/qualifiers.hpp"
0015 #include "detray/definitions/math.hpp"
0016 #include "detray/definitions/units.hpp"
0017 #include "detray/navigation/intersection/intersection.hpp"
0018 #include "detray/navigation/intersection/intersection_config.hpp"
0019 #include "detray/utils/invalid_values.hpp"
0020 #include "detray/utils/logging.hpp"
0021
0022
0023 #include <algorithm>
0024 #include <iostream>
0025 #include <limits>
0026 #include <sstream>
0027 #include <stdexcept>
0028 #include <string>
0029
0030 namespace detray {
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043 template <concepts::scalar scalar_t, typename function_t>
0044 DETRAY_HOST_DEVICE inline bool expand_bracket(const scalar_t a,
0045 const scalar_t b, function_t &f,
0046 darray<scalar_t, 2> &bracket,
0047 const scalar_t k = 1.f) {
0048 if (a == b) {
0049 throw std::invalid_argument(
0050 "Root bracketing: Not a valid start interval [" + std::to_string(a) +
0051 ", " + std::to_string(b) + "]");
0052 }
0053
0054 scalar_t lower{a > b ? b : a};
0055 scalar_t upper{a > b ? a : b};
0056
0057
0058 scalar_t f_l{f(lower)};
0059 scalar_t f_u{f(upper)};
0060 std::size_t n_tries{0u};
0061
0062
0063 const auto check_bracket = [a, b, &bracket](std::size_t n, scalar_t fl,
0064 scalar_t fu, scalar_t l,
0065 scalar_t u) {
0066 if ((n == 1000u) || !std::isfinite(fl) || !std::isfinite(fu) ||
0067 !std::isfinite(l) || !std::isfinite(u)) {
0068 DETRAY_VERBOSE_HOST("Could not bracket a root (a="
0069 << l << ", b=" << u << ", f(a)=" << fl
0070 << ", f(b)=" << fu
0071 << ", root might not exist). Running "
0072 "Newton-Raphson without bisection.");
0073
0074 bracket = {a, b};
0075 return false;
0076 }
0077 return true;
0078 };
0079
0080
0081 while (!math::signbit(f_l * f_u)) {
0082
0083
0084 if (!check_bracket(n_tries, f_l, f_u, lower, upper)) {
0085 return false;
0086 }
0087 scalar_t d{k * (upper - lower)};
0088
0089 if (math::fabs(f_l) < math::fabs(f_u)) {
0090 lower -= d;
0091 f_l = f(lower);
0092 } else {
0093 upper += d;
0094 f_u = f(upper);
0095 }
0096 ++n_tries;
0097 }
0098
0099 if (!check_bracket(n_tries, f_l, f_u, lower, upper)) {
0100 return false;
0101 } else {
0102 bracket = {lower, upper};
0103 return true;
0104 }
0105 }
0106
0107
0108
0109
0110
0111
0112
0113
0114
0115
0116
0117
0118 template <typename scalar_t, typename function_t>
0119 DETRAY_HOST_DEVICE inline std::pair<scalar_t, scalar_t> newton_raphson(
0120 function_t &evaluate_func, scalar_t s,
0121 const scalar_t convergence_tolerance = 1.f * unit<scalar_t>::um,
0122 const std::size_t max_n_tries = 1000u,
0123 const scalar_t max_path = 5.f * unit<scalar_t>::m) {
0124 constexpr scalar_t inv{detail::invalid_value<scalar_t>()};
0125 constexpr scalar_t epsilon{std::numeric_limits<scalar_t>::epsilon()};
0126
0127 if (math::fabs(s) >= max_path) {
0128 DETRAY_VERBOSE_HOST("Initial path estimate outside search area: s=" << s);
0129 }
0130 if (math::fabs(s) >= inv) {
0131 const std::string err_msg{"Initial path estimate invalid"};
0132 DETRAY_FATAL_HOST(err_msg);
0133 throw std::invalid_argument(err_msg);
0134 }
0135
0136
0137 scalar_t s_prev{0.f};
0138 std::size_t n_tries{0u};
0139 auto [f_s, df_s] = evaluate_func(s);
0140
0141 while (math::fabs(s - s_prev) > convergence_tolerance) {
0142
0143 if (math::fabs(f_s) < convergence_tolerance) {
0144 return std::make_pair(s, epsilon);
0145 }
0146
0147
0148 if (math::fabs(df_s) == 0.f) {
0149 DETRAY_VERBOSE_HOST(
0150 "Newton step encountered invalid derivative "
0151 "- skipping");
0152 return std::make_pair(inv, inv);
0153 }
0154
0155
0156 s_prev = s;
0157 s -= f_s / df_s;
0158
0159
0160 std::tie(f_s, df_s) = evaluate_func(s);
0161
0162 ++n_tries;
0163
0164
0165 if (n_tries >= max_n_tries) {
0166 DETRAY_VERBOSE_HOST("Helix intersector did not converge after "
0167 << n_tries << " steps - skipping");
0168 return std::make_pair(inv, inv);
0169 }
0170 }
0171
0172 return std::make_pair(s, math::fabs(s - s_prev));
0173 }
0174
0175
0176
0177
0178
0179
0180
0181
0182
0183
0184
0185
0186 template <concepts::scalar scalar_t, typename function_t>
0187 DETRAY_HOST_DEVICE inline std::pair<scalar_t, scalar_t> newton_raphson_safe(
0188 function_t &evaluate_func, scalar_t s,
0189 const scalar_t convergence_tolerance = 1.f * unit<scalar_t>::um,
0190 const std::size_t max_n_tries = 1000u,
0191 const scalar_t max_path = 5.f * unit<scalar_t>::m) {
0192 constexpr scalar_t inv{detail::invalid_value<scalar_t>()};
0193 constexpr scalar_t epsilon{std::numeric_limits<scalar_t>::epsilon()};
0194
0195
0196 auto f = [&evaluate_func](const scalar_t x) {
0197 auto [f_x, df_x] = evaluate_func(x);
0198
0199 return f_x;
0200 };
0201
0202 if (math::fabs(s) >= max_path) {
0203 DETRAY_VERBOSE_HOST("Initial path estimate outside search area: s=" << s);
0204 }
0205 if (math::fabs(s) >= inv) {
0206 const std::string err_msg{"Initial path estimate invalid"};
0207 DETRAY_ERROR_HOST(err_msg);
0208 throw std::invalid_argument(err_msg);
0209 }
0210
0211
0212 scalar_t a{math::fabs(s) == 0.f ? -0.2f : 0.8f * s};
0213 scalar_t b{math::fabs(s) == 0.f ? 0.2f : 1.2f * s};
0214 darray<scalar_t, 2> br{};
0215 bool is_bracketed = expand_bracket(a, b, f, br);
0216
0217
0218 s = is_bracketed ? 0.5f * (br[1] + br[0]) : s;
0219
0220 if (!is_bracketed) {
0221 DETRAY_VERBOSE_HOST("Bracketing failed for initial path estimate: s=" << s);
0222 } else {
0223
0224 [[maybe_unused]] auto [f_a, df_a] = evaluate_func(br[0]);
0225 [[maybe_unused]] auto [f_b, df_b] = evaluate_func(br[1]);
0226
0227
0228 if (!math::signbit(f_a * f_b)) {
0229 throw std::runtime_error(
0230 "Incorrect bracket around root: No sign change!");
0231 }
0232
0233
0234
0235 if ((math::fabs(br[0]) >= inv) || (math::fabs(br[1]) >= inv)) {
0236 throw std::runtime_error(
0237 "Incorrect bracket around root: Boundary reached inf!");
0238 }
0239
0240
0241 bool bracket_outside_tol{math::fabs(s) > max_path &&
0242 math::fabs(br[0]) >= max_path &&
0243 math::fabs(br[1]) >= max_path};
0244 if (bracket_outside_tol) {
0245 DETRAY_VERBOSE_HOST("Root outside maximum search area (s = "
0246 << s << ", a: " << br[0] << ", b: " << br[1]
0247 << ") - skipping");
0248 return std::make_pair(inv, inv);
0249 }
0250
0251
0252 if (math::fabs(f_a) < convergence_tolerance) {
0253 return std::make_pair(a, epsilon);
0254 }
0255 if (math::fabs(f_b) < convergence_tolerance) {
0256 return std::make_pair(b, epsilon);
0257 }
0258
0259
0260
0261 bool is_lower_a{math::signbit(f_a)};
0262 a = br[is_lower_a ? 0u : 1u];
0263 b = br[is_lower_a ? 1u : 0u];
0264 }
0265
0266
0267 scalar_t s_prev{0.f};
0268 std::size_t n_tries{0u};
0269 auto [f_s, df_s] = evaluate_func(s);
0270 if (math::fabs(f_s) < convergence_tolerance) {
0271 return std::make_pair(s, epsilon);
0272 }
0273 if (math::signbit(f_s)) {
0274 a = s;
0275 } else {
0276 b = s;
0277 }
0278
0279 while (math::fabs(s - s_prev) > convergence_tolerance) {
0280
0281 bool bracket_escape{true};
0282 scalar_t s_newton{0.f};
0283 if (math::fabs(df_s) != 0.f) {
0284 s_newton = s - f_s / df_s;
0285 bracket_escape = math::signbit((s_newton - a) * (b - s_newton));
0286 }
0287
0288
0289
0290
0291
0292
0293
0294 bool slow_convergence{true};
0295
0296 if (const scalar_t ds_bisection{0.5f * (a + b) - s};
0297 is_bracketed &&
0298 (math::fabs(ds_bisection) < 10.f * unit<scalar_t>::mm)) {
0299 slow_convergence =
0300 (2.f * math::fabs(f_s) > math::fabs(df_s * ds_bisection + f_s));
0301 }
0302
0303 s_prev = s;
0304
0305
0306 if (is_bracketed &&
0307 (bracket_escape || slow_convergence || math::fabs(df_s) == 0.f)) {
0308
0309 s = 0.5f * (a + b);
0310 } else {
0311
0312 if (!is_bracketed && math::fabs(df_s) == 0.f) {
0313 DETRAY_VERBOSE_HOST("Newton step encountered invalid derivative at s="
0314 << s << " after " << n_tries
0315 << " steps - skipping");
0316
0317 return std::make_pair(inv, inv);
0318 }
0319
0320 s = s_newton;
0321 }
0322
0323
0324 std::tie(f_s, df_s) = evaluate_func(s);
0325 if (is_bracketed && math::signbit(f_s)) {
0326 a = s;
0327 } else {
0328 b = s;
0329 }
0330
0331
0332 if (math::fabs(s) > max_path && math::fabs(s_prev) > max_path &&
0333 ((a < -max_path && b < -max_path) || (a > max_path && b > max_path))) {
0334 DETRAY_VERBOSE_HOST("WARNING: Root finding left the search space at (s = "
0335 << s << ", a: " << a << ", b: " << b << ") after "
0336 << n_tries << " steps - skipping");
0337
0338 return std::make_pair(inv, inv);
0339 }
0340
0341 ++n_tries;
0342
0343
0344 if (n_tries >= max_n_tries) {
0345
0346 if (is_bracketed) {
0347 std::stringstream err_str{};
0348 err_str << "Helix intersector did not find root for s=" << s << " in ["
0349 << a << ", " << b << "]";
0350
0351 DETRAY_FATAL_HOST(err_str.str());
0352 throw std::runtime_error(err_str.str());
0353 } else {
0354 DETRAY_VERBOSE_HOST("Helix intersector did not converge after "
0355 << n_tries
0356 << " steps unbracketed search - skipping");
0357 }
0358 return std::make_pair(inv, inv);
0359 }
0360 }
0361
0362 return std::make_pair(s, math::fabs(s - s_prev));
0363 }
0364
0365
0366
0367
0368
0369
0370
0371
0372
0373
0374 template <typename intersection_t, concepts::algebra algebra_t,
0375 typename surface_descr_t, typename mask_t, typename trajectory_t,
0376 concepts::transform3D transform3_t, concepts::scalar scalar_t>
0377 DETRAY_HOST_DEVICE constexpr void resolve_mask(
0378 intersection_t &is, const trajectory_t &traj,
0379 const intersection_point_err<algebra_t> &ip, const surface_descr_t sf_desc,
0380 const mask_t &mask, const transform3_t &trf,
0381 const intersection::config &intr_cfg,
0382 const scalar_t = 0.f) {
0383 assert((intr_cfg.min_mask_tolerance == intr_cfg.max_mask_tolerance) &&
0384 "Helix intersectors use only one mask tolerance value");
0385
0386
0387 if (!detail::is_invalid_value(ip.path)) {
0388 is.set_path(ip.path);
0389 is.set_local(
0390 mask_t::to_local_frame3D(trf, traj.pos(ip.path), traj.dir(ip.path)));
0391
0392 const scalar_t cos_incidence_angle = vector::dot(
0393 mask_t::get_local_frame().normal(trf, is.local()), traj.dir(ip.path));
0394
0395 scalar_t tol{sf_desc.is_portal() ? 0.f : intr_cfg.min_mask_tolerance};
0396
0397 if (tol >= detail::invalid_value<float>()) {
0398
0399 const scalar_t sin_inc2{
0400 math::fabs(1.f - cos_incidence_angle * cos_incidence_angle)};
0401
0402 tol = math::fabs(ip.path_err * math::sqrt(sin_inc2));
0403 }
0404
0405 assert(!math::signbit(tol));
0406 assert(tol < 1000.f * unit<scalar_t>::mm);
0407
0408 is.set_status(mask.is_inside(is.local(), tol)
0409 ? intersection::status::e_inside
0410 : intersection::status::e_outside);
0411 is.set_surface(sf_desc);
0412 is.set_direction(!math::signbit(ip.path));
0413 is.set_volume_link(mask.volume_link());
0414 } else {
0415
0416 is.set_status(intersection::status::e_outside);
0417 }
0418 }
0419
0420 }