Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:24:05

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 // Project include(s)
0012 #include "detray/definitions/algebra.hpp"
0013 #include "detray/definitions/containers.hpp"
0014 #include "detray/definitions/detail/qualifiers.hpp"
0015 #include "detray/definitions/math.hpp"
0016 #include "detray/utils/invalid_values.hpp"
0017 
0018 // System include(s)
0019 #include <limits>
0020 
0021 namespace detray::detail {
0022 
/// Primary template of the quadratic equation solver.
///
/// It is empty: the solver logic lives in the two partial specializations in
/// this file, which are selected on @c std::is_arithmetic_v<scalar_t>.
///
/// @tparam scalar_t the scalar type of the coefficients and solutions
///
/// NOTE(review): the defaulted second type parameter is not referenced by
/// either specialization visible in this file - presumably a leftover of an
/// earlier SFINAE-based selection; confirm before removing.
0023 template <concepts::scalar scalar_t, typename = void>
0024 class quadratic_equation {};
0025 
0026 /// Class to solve a quadratic equation of type a * x^2 + b * x + c = 0
0027 ///
0028 /// @note If there are no real solutions, the result is undefined
0029 /// @note The solutions are sorted by default. If there is only one solution,
0030 /// the larger value is undefined.
0031 template <concepts::scalar scalar_t>
0032   requires std::is_arithmetic_v<scalar_t>
0033 class quadratic_equation<scalar_t> {
0034  public:
0035   quadratic_equation() = delete;
0036 
0037   /// Solve the quadratic equation with the coefficients @param a, @param b
0038   /// and @param c
0039   ///
0040   /// @param tolerance threshold to compare the discrimant against to decide
0041   ///                  if we have two separate solutions.
0042   DETRAY_HOST_DEVICE
0043   constexpr quadratic_equation(
0044       const scalar_t a, const scalar_t b, const scalar_t c,
0045       const scalar_t tolerance = std::numeric_limits<scalar_t>::epsilon()) {
0046     // linear case
0047     if (math::fabs(a) <= tolerance) {
0048       if (math::fabs(b) <= tolerance) {
0049         m_solutions = 0;
0050       } else {
0051         m_solutions = 1;
0052         m_values[0] = -c / b;
0053       }
0054     } else {
0055       const scalar_t discriminant{b * b - 4.f * a * c};
0056       // If there is more than one solution, then a != 0 and q != 0
0057       if (discriminant > tolerance) {
0058         m_solutions = 2;
0059         const scalar_t q{-0.5f *
0060                          (b + math::copysign(math::sqrt(discriminant), b))};
0061         m_values = {q / a, c / q};
0062         // Sort the two solutions
0063         if (m_values[0] > m_values[1]) {
0064           m_values = {m_values[1], m_values[0]};
0065         }
0066       }
0067       // Only one solution and a != 0
0068       else if (discriminant >= 0.f) {
0069         m_solutions = 1;
0070         m_values[0] = -0.5f * b / a;
0071       }
0072       // discriminant < 0 is not allowed, since all solutions should be
0073       // real
0074     }
0075   }
0076 
0077   /// Getters for the solution(s)
0078   /// @{
0079   constexpr int solutions() const { return m_solutions; }
0080   constexpr scalar_t smaller() const { return m_values[0]; }
0081   constexpr scalar_t larger() const { return m_values[1]; }
0082   /// @}
0083 
0084  private:
0085   /// Number of solutions of the equation
0086   int m_solutions{0};
0087   /// The solutions
0088   darray<scalar_t, 2> m_values{detail::invalid_value<scalar_t>(),
0089                                detail::invalid_value<scalar_t>()};
0090 };
0091 
0092 /// Class to solve a quadratic equation of type a * x^2 + b * x + c = 0
0093 ///
0094 /// @note If there are no real solutions, the result is undefined
0095 /// @note The solutions are sorted by default. If there is only one
0096 /// solution, the larger value is undefined.
0097 template <concepts::scalar scalar_t>
0098   requires(!std::is_arithmetic_v<scalar_t>)
0099 class quadratic_equation<scalar_t> {
0100  public:
0101   quadratic_equation() = delete;
0102 
0103   DETRAY_HOST_DEVICE
0104   constexpr quadratic_equation(const scalar_t &a, const scalar_t &b,
0105                                const scalar_t &c,
0106                                const scalar_t &tolerance = 1e-6f) {
0107     // Linear case
0108     auto one_sol = (math::fabs(a) <= tolerance);
0109     m_solutions(one_sol) = 1.f;
0110     m_values[0] = -c / b;
0111 
0112     // Early exit
0113     if (detray::detail::all_of(one_sol)) {
0114       return;
0115     }
0116 
0117     const scalar_t discriminant = b * b - (4.f * a) * c;
0118 
0119     const auto two_sol = (discriminant > tolerance);
0120     one_sol = !two_sol && (discriminant >= 0.f);
0121 
0122     // If there is more than one solution, then a != 0 and q != 0
0123     if (detray::detail::any_of(two_sol)) {
0124       m_solutions = 2.f;
0125       m_solutions.setZeroInverted(two_sol);
0126 
0127       const scalar_t q =
0128           -0.5f * (b + math::copysign(math::sqrt(discriminant), b));
0129 
0130       scalar_t first = q / a;
0131       scalar_t second = c / q;
0132       first.setZeroInverted(two_sol);
0133       second.setZeroInverted(two_sol);
0134 
0135       // Sort the solutions
0136       const auto do_swap = (second < first);
0137       if (detray::detail::all_of(do_swap)) {
0138         m_values = {second, first};
0139       } else if (detray::detail::none_of(do_swap)) {
0140         m_values = {first, second};
0141       } else {
0142         const auto tmp = second;
0143         second(do_swap) = first;
0144         first(do_swap) = tmp;
0145         m_values = {first, second};
0146       }
0147     }
0148 
0149     // Only one solution and a != 0
0150     if (detray::detail::any_of(one_sol)) {
0151       scalar_t sol = 1.f;
0152       scalar_t result = -0.5f * b / a;
0153       sol.setZeroInverted(one_sol);
0154       result.setZeroInverted(one_sol);
0155 
0156       m_solutions += sol;
0157       m_values[0] += result;
0158     }
0159     // discriminant < 0 is not allowed, since all solutions should
0160     // be real
0161   }
0162 
0163   /// Getters for the solution(s)
0164   /// @{
0165   constexpr const auto &solutions() const { return m_solutions; }
0166   constexpr const scalar_t &smaller() const { return m_values[0]; }
0167   constexpr const scalar_t &larger() const { return m_values[1]; }
0168   /// @}
0169 
0170  private:
0171   /// Number of solutions of the equation (needs to be floating point to
0172   /// apply the masks correctly)
0173   scalar_t m_solutions = 0.f;
0174   /// The solutions
0175   darray<scalar_t, 2> m_values{static_cast<scalar_t>(0.f),
0176                                static_cast<scalar_t>(0.f)};
0177 };
0178 
0179 template <typename S>
0180 quadratic_equation(const S a, const S &b, const S &c, const S &tolerance)
0181     -> quadratic_equation<S>;
0182 
0183 }  // namespace detray::detail