File indexing completed on 2024-11-15 09:36:43
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #ifndef EIGEN_CONJUGATE_GRADIENT_H
0011 #define EIGEN_CONJUGATE_GRADIENT_H
0012
0013 namespace Eigen {
0014
0015 namespace internal {
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 template<typename MatrixType, typename Rhs, typename Dest, typename Preconditioner>
0027 EIGEN_DONT_INLINE
0028 void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
0029 const Preconditioner& precond, Index& iters,
0030 typename Dest::RealScalar& tol_error)
0031 {
0032 using std::sqrt;
0033 using std::abs;
0034 typedef typename Dest::RealScalar RealScalar;
0035 typedef typename Dest::Scalar Scalar;
0036 typedef Matrix<Scalar,Dynamic,1> VectorType;
0037
0038 RealScalar tol = tol_error;
0039 Index maxIters = iters;
0040
0041 Index n = mat.cols();
0042
0043 VectorType residual = rhs - mat * x;
0044
0045 RealScalar rhsNorm2 = rhs.squaredNorm();
0046 if(rhsNorm2 == 0)
0047 {
0048 x.setZero();
0049 iters = 0;
0050 tol_error = 0;
0051 return;
0052 }
0053 const RealScalar considerAsZero = (std::numeric_limits<RealScalar>::min)();
0054 RealScalar threshold = numext::maxi(RealScalar(tol*tol*rhsNorm2),considerAsZero);
0055 RealScalar residualNorm2 = residual.squaredNorm();
0056 if (residualNorm2 < threshold)
0057 {
0058 iters = 0;
0059 tol_error = sqrt(residualNorm2 / rhsNorm2);
0060 return;
0061 }
0062
0063 VectorType p(n);
0064 p = precond.solve(residual);
0065
0066 VectorType z(n), tmp(n);
0067 RealScalar absNew = numext::real(residual.dot(p));
0068 Index i = 0;
0069 while(i < maxIters)
0070 {
0071 tmp.noalias() = mat * p;
0072
0073 Scalar alpha = absNew / p.dot(tmp);
0074 x += alpha * p;
0075 residual -= alpha * tmp;
0076
0077 residualNorm2 = residual.squaredNorm();
0078 if(residualNorm2 < threshold)
0079 break;
0080
0081 z = precond.solve(residual);
0082
0083 RealScalar absOld = absNew;
0084 absNew = numext::real(residual.dot(z));
0085 RealScalar beta = absNew / absOld;
0086 p = z + beta * p;
0087 i++;
0088 }
0089 tol_error = sqrt(residualNorm2 / rhsNorm2);
0090 iters = i;
0091 }
0092
0093 }
0094
0095 template< typename _MatrixType, int _UpLo=Lower,
0096 typename _Preconditioner = DiagonalPreconditioner<typename _MatrixType::Scalar> >
0097 class ConjugateGradient;
0098
0099 namespace internal {
0100
0101 template< typename _MatrixType, int _UpLo, typename _Preconditioner>
0102 struct traits<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> >
0103 {
0104 typedef _MatrixType MatrixType;
0105 typedef _Preconditioner Preconditioner;
0106 };
0107
0108 }
0109
0110
0111
0112
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122
0123
0124
0125
0126
0127
0128
0129
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139
0140
0141
0142
0143
0144
0145
0146
0147
0148
0149
0150
0151
0152
0153
0154
0155
0156
0157 template< typename _MatrixType, int _UpLo, typename _Preconditioner>
0158 class ConjugateGradient : public IterativeSolverBase<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> >
0159 {
0160 typedef IterativeSolverBase<ConjugateGradient> Base;
0161 using Base::matrix;
0162 using Base::m_error;
0163 using Base::m_iterations;
0164 using Base::m_info;
0165 using Base::m_isInitialized;
0166 public:
0167 typedef _MatrixType MatrixType;
0168 typedef typename MatrixType::Scalar Scalar;
0169 typedef typename MatrixType::RealScalar RealScalar;
0170 typedef _Preconditioner Preconditioner;
0171
0172 enum {
0173 UpLo = _UpLo
0174 };
0175
0176 public:
0177
0178
0179 ConjugateGradient() : Base() {}
0180
0181
0182
0183
0184
0185
0186
0187
0188
0189
0190
0191 template<typename MatrixDerived>
0192 explicit ConjugateGradient(const EigenBase<MatrixDerived>& A) : Base(A.derived()) {}
0193
0194 ~ConjugateGradient() {}
0195
0196
0197 template<typename Rhs,typename Dest>
0198 void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
0199 {
0200 typedef typename Base::MatrixWrapper MatrixWrapper;
0201 typedef typename Base::ActualMatrixType ActualMatrixType;
0202 enum {
0203 TransposeInput = (!MatrixWrapper::MatrixFree)
0204 && (UpLo==(Lower|Upper))
0205 && (!MatrixType::IsRowMajor)
0206 && (!NumTraits<Scalar>::IsComplex)
0207 };
0208 typedef typename internal::conditional<TransposeInput,Transpose<const ActualMatrixType>, ActualMatrixType const&>::type RowMajorWrapper;
0209 EIGEN_STATIC_ASSERT(EIGEN_IMPLIES(MatrixWrapper::MatrixFree,UpLo==(Lower|Upper)),MATRIX_FREE_CONJUGATE_GRADIENT_IS_COMPATIBLE_WITH_UPPER_UNION_LOWER_MODE_ONLY);
0210 typedef typename internal::conditional<UpLo==(Lower|Upper),
0211 RowMajorWrapper,
0212 typename MatrixWrapper::template ConstSelfAdjointViewReturnType<UpLo>::Type
0213 >::type SelfAdjointWrapper;
0214
0215 m_iterations = Base::maxIterations();
0216 m_error = Base::m_tolerance;
0217
0218 RowMajorWrapper row_mat(matrix());
0219 internal::conjugate_gradient(SelfAdjointWrapper(row_mat), b, x, Base::m_preconditioner, m_iterations, m_error);
0220 m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
0221 }
0222
0223 protected:
0224
0225 };
0226
0227 }
0228
0229 #endif