File indexing completed on 2026-05-27 07:24:05
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011
0012 #include "detray/definitions/algebra.hpp"
0013 #include "detray/definitions/containers.hpp"
0014 #include "detray/definitions/detail/qualifiers.hpp"
0015 #include "detray/definitions/math.hpp"
0016 #include "detray/utils/invalid_values.hpp"
0017
0018
0019 #include <limits>
0020
0021 namespace detray::detail {
0022
0023 template <concepts::scalar scalar_t, typename = void>
0024 class quadratic_equation {};
0025
0026
0027
0028
0029
0030
0031 template <concepts::scalar scalar_t>
0032 requires std::is_arithmetic_v<scalar_t>
0033 class quadratic_equation<scalar_t> {
0034 public:
0035 quadratic_equation() = delete;
0036
0037
0038
0039
0040
0041
0042 DETRAY_HOST_DEVICE
0043 constexpr quadratic_equation(
0044 const scalar_t a, const scalar_t b, const scalar_t c,
0045 const scalar_t tolerance = std::numeric_limits<scalar_t>::epsilon()) {
0046
0047 if (math::fabs(a) <= tolerance) {
0048 if (math::fabs(b) <= tolerance) {
0049 m_solutions = 0;
0050 } else {
0051 m_solutions = 1;
0052 m_values[0] = -c / b;
0053 }
0054 } else {
0055 const scalar_t discriminant{b * b - 4.f * a * c};
0056
0057 if (discriminant > tolerance) {
0058 m_solutions = 2;
0059 const scalar_t q{-0.5f *
0060 (b + math::copysign(math::sqrt(discriminant), b))};
0061 m_values = {q / a, c / q};
0062
0063 if (m_values[0] > m_values[1]) {
0064 m_values = {m_values[1], m_values[0]};
0065 }
0066 }
0067
0068 else if (discriminant >= 0.f) {
0069 m_solutions = 1;
0070 m_values[0] = -0.5f * b / a;
0071 }
0072
0073
0074 }
0075 }
0076
0077
0078
0079 constexpr int solutions() const { return m_solutions; }
0080 constexpr scalar_t smaller() const { return m_values[0]; }
0081 constexpr scalar_t larger() const { return m_values[1]; }
0082
0083
0084 private:
0085
0086 int m_solutions{0};
0087
0088 darray<scalar_t, 2> m_values{detail::invalid_value<scalar_t>(),
0089 detail::invalid_value<scalar_t>()};
0090 };
0091
0092
0093
0094
0095
0096
0097 template <concepts::scalar scalar_t>
0098 requires(!std::is_arithmetic_v<scalar_t>)
0099 class quadratic_equation<scalar_t> {
0100 public:
0101 quadratic_equation() = delete;
0102
0103 DETRAY_HOST_DEVICE
0104 constexpr quadratic_equation(const scalar_t &a, const scalar_t &b,
0105 const scalar_t &c,
0106 const scalar_t &tolerance = 1e-6f) {
0107
0108 auto one_sol = (math::fabs(a) <= tolerance);
0109 m_solutions(one_sol) = 1.f;
0110 m_values[0] = -c / b;
0111
0112
0113 if (detray::detail::all_of(one_sol)) {
0114 return;
0115 }
0116
0117 const scalar_t discriminant = b * b - (4.f * a) * c;
0118
0119 const auto two_sol = (discriminant > tolerance);
0120 one_sol = !two_sol && (discriminant >= 0.f);
0121
0122
0123 if (detray::detail::any_of(two_sol)) {
0124 m_solutions = 2.f;
0125 m_solutions.setZeroInverted(two_sol);
0126
0127 const scalar_t q =
0128 -0.5f * (b + math::copysign(math::sqrt(discriminant), b));
0129
0130 scalar_t first = q / a;
0131 scalar_t second = c / q;
0132 first.setZeroInverted(two_sol);
0133 second.setZeroInverted(two_sol);
0134
0135
0136 const auto do_swap = (second < first);
0137 if (detray::detail::all_of(do_swap)) {
0138 m_values = {second, first};
0139 } else if (detray::detail::none_of(do_swap)) {
0140 m_values = {first, second};
0141 } else {
0142 const auto tmp = second;
0143 second(do_swap) = first;
0144 first(do_swap) = tmp;
0145 m_values = {first, second};
0146 }
0147 }
0148
0149
0150 if (detray::detail::any_of(one_sol)) {
0151 scalar_t sol = 1.f;
0152 scalar_t result = -0.5f * b / a;
0153 sol.setZeroInverted(one_sol);
0154 result.setZeroInverted(one_sol);
0155
0156 m_solutions += sol;
0157 m_values[0] += result;
0158 }
0159
0160
0161 }
0162
0163
0164
0165 constexpr const auto &solutions() const { return m_solutions; }
0166 constexpr const scalar_t &smaller() const { return m_values[0]; }
0167 constexpr const scalar_t &larger() const { return m_values[1]; }
0168
0169
0170 private:
0171
0172
0173 scalar_t m_solutions = 0.f;
0174
0175 darray<scalar_t, 2> m_values{static_cast<scalar_t>(0.f),
0176 static_cast<scalar_t>(0.f)};
0177 };
0178
0179 template <typename S>
0180 quadratic_equation(const S a, const S &b, const S &c, const S &tolerance)
0181 -> quadratic_equation<S>;
0182
0183 }