File indexing completed on 2025-04-19 09:06:20
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #ifndef EIGEN_LEAST_SQUARE_CONJUGATE_GRADIENT_H
0011 #define EIGEN_LEAST_SQUARE_CONJUGATE_GRADIENT_H
0012
0013 namespace RivetEigen {
0014
0015 namespace internal {
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 template<typename MatrixType, typename Rhs, typename Dest, typename Preconditioner>
0027 EIGEN_DONT_INLINE
0028 void least_square_conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
0029 const Preconditioner& precond, Index& iters,
0030 typename Dest::RealScalar& tol_error)
0031 {
0032 using std::sqrt;
0033 using std::abs;
0034 typedef typename Dest::RealScalar RealScalar;
0035 typedef typename Dest::Scalar Scalar;
0036 typedef Matrix<Scalar,Dynamic,1> VectorType;
0037
0038 RealScalar tol = tol_error;
0039 Index maxIters = iters;
0040
0041 Index m = mat.rows(), n = mat.cols();
0042
0043 VectorType residual = rhs - mat * x;
0044 VectorType normal_residual = mat.adjoint() * residual;
0045
0046 RealScalar rhsNorm2 = (mat.adjoint()*rhs).squaredNorm();
0047 if(rhsNorm2 == 0)
0048 {
0049 x.setZero();
0050 iters = 0;
0051 tol_error = 0;
0052 return;
0053 }
0054 RealScalar threshold = tol*tol*rhsNorm2;
0055 RealScalar residualNorm2 = normal_residual.squaredNorm();
0056 if (residualNorm2 < threshold)
0057 {
0058 iters = 0;
0059 tol_error = sqrt(residualNorm2 / rhsNorm2);
0060 return;
0061 }
0062
0063 VectorType p(n);
0064 p = precond.solve(normal_residual);
0065
0066 VectorType z(n), tmp(m);
0067 RealScalar absNew = numext::real(normal_residual.dot(p));
0068 Index i = 0;
0069 while(i < maxIters)
0070 {
0071 tmp.noalias() = mat * p;
0072
0073 Scalar alpha = absNew / tmp.squaredNorm();
0074 x += alpha * p;
0075 residual -= alpha * tmp;
0076 normal_residual = mat.adjoint() * residual;
0077
0078 residualNorm2 = normal_residual.squaredNorm();
0079 if(residualNorm2 < threshold)
0080 break;
0081
0082 z = precond.solve(normal_residual);
0083
0084 RealScalar absOld = absNew;
0085 absNew = numext::real(normal_residual.dot(z));
0086 RealScalar beta = absNew / absOld;
0087 p = z + beta * p;
0088 i++;
0089 }
0090 tol_error = sqrt(residualNorm2 / rhsNorm2);
0091 iters = i;
0092 }
0093
0094 }
0095
0096 template< typename _MatrixType,
0097 typename _Preconditioner = LeastSquareDiagonalPreconditioner<typename _MatrixType::Scalar> >
0098 class LeastSquaresConjugateGradient;
0099
0100 namespace internal {
0101
0102 template< typename _MatrixType, typename _Preconditioner>
0103 struct traits<LeastSquaresConjugateGradient<_MatrixType,_Preconditioner> >
0104 {
0105 typedef _MatrixType MatrixType;
0106 typedef _Preconditioner Preconditioner;
0107 };
0108
0109 }
0110
0111
0112
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122
0123
0124
0125
0126
0127
0128
0129
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139
0140
0141
0142
0143
0144
0145
0146
0147
0148 template< typename _MatrixType, typename _Preconditioner>
0149 class LeastSquaresConjugateGradient : public IterativeSolverBase<LeastSquaresConjugateGradient<_MatrixType,_Preconditioner> >
0150 {
0151 typedef IterativeSolverBase<LeastSquaresConjugateGradient> Base;
0152 using Base::matrix;
0153 using Base::m_error;
0154 using Base::m_iterations;
0155 using Base::m_info;
0156 using Base::m_isInitialized;
0157 public:
0158 typedef _MatrixType MatrixType;
0159 typedef typename MatrixType::Scalar Scalar;
0160 typedef typename MatrixType::RealScalar RealScalar;
0161 typedef _Preconditioner Preconditioner;
0162
0163 public:
0164
0165
0166 LeastSquaresConjugateGradient() : Base() {}
0167
0168
0169
0170
0171
0172
0173
0174
0175
0176
0177
0178 template<typename MatrixDerived>
0179 explicit LeastSquaresConjugateGradient(const EigenBase<MatrixDerived>& A) : Base(A.derived()) {}
0180
0181 ~LeastSquaresConjugateGradient() {}
0182
0183
0184 template<typename Rhs,typename Dest>
0185 void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
0186 {
0187 m_iterations = Base::maxIterations();
0188 m_error = Base::m_tolerance;
0189
0190 internal::least_square_conjugate_gradient(matrix(), b, x, Base::m_preconditioner, m_iterations, m_error);
0191 m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
0192 }
0193
0194 };
0195
0196 }
0197
0198 #endif