File indexing completed on 2025-04-19 09:06:20
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #ifndef EIGEN_ITERATIVE_SOLVER_BASE_H
0011 #define EIGEN_ITERATIVE_SOLVER_BASE_H
0012
0013 namespace RivetEigen {
0014
0015 namespace internal {
0016
0017 template<typename MatrixType>
0018 struct is_ref_compatible_impl
0019 {
0020 private:
0021 template <typename T0>
0022 struct any_conversion
0023 {
0024 template <typename T> any_conversion(const volatile T&);
0025 template <typename T> any_conversion(T&);
0026 };
0027 struct yes {int a[1];};
0028 struct no {int a[2];};
0029
0030 template<typename T>
0031 static yes test(const Ref<const T>&, int);
0032 template<typename T>
0033 static no test(any_conversion<T>, ...);
0034
0035 public:
0036 static MatrixType ms_from;
0037 enum { value = sizeof(test<MatrixType>(ms_from, 0))==sizeof(yes) };
0038 };
0039
0040 template<typename MatrixType>
0041 struct is_ref_compatible
0042 {
0043 enum { value = is_ref_compatible_impl<typename remove_all<MatrixType>::type>::value };
0044 };
0045
/** \internal
  * Storage policy for the matrix an iterative solver operates on.
  * \c MatrixFree defaults to true when \c MatrixType cannot be bound to
  * \c Ref<const MatrixType> (see is_ref_compatible); the two partial
  * specializations (false = Ref-based, true = pointer-based) are the only
  * definitions — this primary template is declared but never defined.
  */
template<typename MatrixType, bool MatrixFree = !internal::is_ref_compatible<MatrixType>::value>
class generic_matrix_wrapper;
0048
0049
/** \internal
  * Ref-based storage: used when \c MatrixType can be bound to
  * \c Ref<const MatrixType>. The solver's matrix is accessed through that Ref.
  */
template<typename MatrixType>
class generic_matrix_wrapper<MatrixType,false>
{
public:
  typedef Ref<const MatrixType> ActualMatrixType;
  // Forward the self-adjoint view type of the wrapped Ref.
  template<int UpLo> struct ConstSelfAdjointViewReturnType {
    typedef typename ActualMatrixType::template ConstSelfAdjointViewReturnType<UpLo>::Type Type;
  };

  enum {
    MatrixFree = false
  };

  // Binds m_matrix to an empty 0x0 m_dummy so matrix() is always usable.
  // Relies on m_dummy being declared before m_matrix (members are
  // initialized in declaration order).
  generic_matrix_wrapper()
    : m_dummy(0,0), m_matrix(m_dummy)
  {}

  template<typename InputType>
  generic_matrix_wrapper(const InputType &mat)
    : m_matrix(mat)
  {}

  const ActualMatrixType& matrix() const
  {
    return m_matrix;
  }

  // Rebinds m_matrix to a new matrix. A Ref cannot be re-seated by
  // assignment, so the old Ref is destroyed in place and a new one is
  // constructed in the same storage with placement new.
  // NOTE(review): if the Ref constructor throws, m_matrix is left destroyed
  // and will be destroyed again by ~generic_matrix_wrapper — confirm Ref
  // construction cannot throw for the input types used here.
  template<typename MatrixDerived>
  void grab(const EigenBase<MatrixDerived> &mat)
  {
    m_matrix.~Ref<const MatrixType>();
    ::new (&m_matrix) Ref<const MatrixType>(mat.derived());
  }

  // Same as above for an existing Ref, but skips rebinding when handed this
  // wrapper's own m_matrix (e.g. IterativeSolverBase's constructor calls
  // compute(matrix())); destroying it first would leave mat dangling.
  void grab(const Ref<const MatrixType> &mat)
  {
    if(&(mat.derived()) != &m_matrix)
    {
      m_matrix.~Ref<const MatrixType>();
      ::new (&m_matrix) Ref<const MatrixType>(mat);
    }
  }

protected:
  // Fallback target for the default constructor; must precede m_matrix.
  MatrixType m_dummy;
  // NOTE(review): the implicit copy constructor copies this Ref as-is, so a
  // copied wrapper may still view the source object's storage (e.g. its
  // m_dummy) — avoid relying on copies outliving the original.
  ActualMatrixType m_matrix;
};
0097
0098
0099 template<typename MatrixType>
0100 class generic_matrix_wrapper<MatrixType,true>
0101 {
0102 public:
0103 typedef MatrixType ActualMatrixType;
0104 template<int UpLo> struct ConstSelfAdjointViewReturnType
0105 {
0106 typedef ActualMatrixType Type;
0107 };
0108
0109 enum {
0110 MatrixFree = true
0111 };
0112
0113 generic_matrix_wrapper()
0114 : mp_matrix(0)
0115 {}
0116
0117 generic_matrix_wrapper(const MatrixType &mat)
0118 : mp_matrix(&mat)
0119 {}
0120
0121 const ActualMatrixType& matrix() const
0122 {
0123 return *mp_matrix;
0124 }
0125
0126 void grab(const MatrixType &mat)
0127 {
0128 mp_matrix = &mat;
0129 }
0130
0131 protected:
0132 const ActualMatrixType *mp_matrix;
0133 };
0134
0135 }
0136
0137
0138
0139
0140
0141
0142 template< typename Derived>
0143 class IterativeSolverBase : public SparseSolverBase<Derived>
0144 {
0145 protected:
0146 typedef SparseSolverBase<Derived> Base;
0147 using Base::m_isInitialized;
0148
0149 public:
0150 typedef typename internal::traits<Derived>::MatrixType MatrixType;
0151 typedef typename internal::traits<Derived>::Preconditioner Preconditioner;
0152 typedef typename MatrixType::Scalar Scalar;
0153 typedef typename MatrixType::StorageIndex StorageIndex;
0154 typedef typename MatrixType::RealScalar RealScalar;
0155
0156 enum {
0157 ColsAtCompileTime = MatrixType::ColsAtCompileTime,
0158 MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
0159 };
0160
0161 public:
0162
0163 using Base::derived;
0164
0165
0166 IterativeSolverBase()
0167 {
0168 init();
0169 }
0170
0171
0172
0173
0174
0175
0176
0177
0178
0179
0180
0181 template<typename MatrixDerived>
0182 explicit IterativeSolverBase(const EigenBase<MatrixDerived>& A)
0183 : m_matrixWrapper(A.derived())
0184 {
0185 init();
0186 compute(matrix());
0187 }
0188
0189 ~IterativeSolverBase() {}
0190
0191
0192
0193
0194
0195
0196 template<typename MatrixDerived>
0197 Derived& analyzePattern(const EigenBase<MatrixDerived>& A)
0198 {
0199 grab(A.derived());
0200 m_preconditioner.analyzePattern(matrix());
0201 m_isInitialized = true;
0202 m_analysisIsOk = true;
0203 m_info = m_preconditioner.info();
0204 return derived();
0205 }
0206
0207
0208
0209
0210
0211
0212
0213
0214
0215
0216 template<typename MatrixDerived>
0217 Derived& factorize(const EigenBase<MatrixDerived>& A)
0218 {
0219 eigen_assert(m_analysisIsOk && "You must first call analyzePattern()");
0220 grab(A.derived());
0221 m_preconditioner.factorize(matrix());
0222 m_factorizationIsOk = true;
0223 m_info = m_preconditioner.info();
0224 return derived();
0225 }
0226
0227
0228
0229
0230
0231
0232
0233
0234
0235
0236
0237 template<typename MatrixDerived>
0238 Derived& compute(const EigenBase<MatrixDerived>& A)
0239 {
0240 grab(A.derived());
0241 m_preconditioner.compute(matrix());
0242 m_isInitialized = true;
0243 m_analysisIsOk = true;
0244 m_factorizationIsOk = true;
0245 m_info = m_preconditioner.info();
0246 return derived();
0247 }
0248
0249
0250 EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return matrix().rows(); }
0251
0252
0253 EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return matrix().cols(); }
0254
0255
0256
0257
0258 RealScalar tolerance() const { return m_tolerance; }
0259
0260
0261
0262
0263
0264
0265 Derived& setTolerance(const RealScalar& tolerance)
0266 {
0267 m_tolerance = tolerance;
0268 return derived();
0269 }
0270
0271
0272 Preconditioner& preconditioner() { return m_preconditioner; }
0273
0274
0275 const Preconditioner& preconditioner() const { return m_preconditioner; }
0276
0277
0278
0279
0280
0281 Index maxIterations() const
0282 {
0283 return (m_maxIterations<0) ? 2*matrix().cols() : m_maxIterations;
0284 }
0285
0286
0287
0288
0289 Derived& setMaxIterations(Index maxIters)
0290 {
0291 m_maxIterations = maxIters;
0292 return derived();
0293 }
0294
0295
0296 Index iterations() const
0297 {
0298 eigen_assert(m_isInitialized && "ConjugateGradient is not initialized.");
0299 return m_iterations;
0300 }
0301
0302
0303
0304
0305 RealScalar error() const
0306 {
0307 eigen_assert(m_isInitialized && "ConjugateGradient is not initialized.");
0308 return m_error;
0309 }
0310
0311
0312
0313
0314
0315
0316 template<typename Rhs,typename Guess>
0317 inline const SolveWithGuess<Derived, Rhs, Guess>
0318 solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const
0319 {
0320 eigen_assert(m_isInitialized && "Solver is not initialized.");
0321 eigen_assert(derived().rows()==b.rows() && "solve(): invalid number of rows of the right hand side matrix b");
0322 return SolveWithGuess<Derived, Rhs, Guess>(derived(), b.derived(), x0);
0323 }
0324
0325
0326 ComputationInfo info() const
0327 {
0328 eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
0329 return m_info;
0330 }
0331
0332
0333 template<typename Rhs, typename DestDerived>
0334 void _solve_with_guess_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const
0335 {
0336 eigen_assert(rows()==b.rows());
0337
0338 Index rhsCols = b.cols();
0339 Index size = b.rows();
0340 DestDerived& dest(aDest.derived());
0341 typedef typename DestDerived::Scalar DestScalar;
0342 RivetEigen::Matrix<DestScalar,Dynamic,1> tb(size);
0343 RivetEigen::Matrix<DestScalar,Dynamic,1> tx(cols());
0344
0345
0346 typename DestDerived::PlainObject tmp(cols(),rhsCols);
0347 ComputationInfo global_info = Success;
0348 for(Index k=0; k<rhsCols; ++k)
0349 {
0350 tb = b.col(k);
0351 tx = dest.col(k);
0352 derived()._solve_vector_with_guess_impl(tb,tx);
0353 tmp.col(k) = tx.sparseView(0);
0354
0355
0356
0357 if(m_info==NumericalIssue)
0358 global_info = NumericalIssue;
0359 else if(m_info==NoConvergence)
0360 global_info = NoConvergence;
0361 }
0362 m_info = global_info;
0363 dest.swap(tmp);
0364 }
0365
0366 template<typename Rhs, typename DestDerived>
0367 typename internal::enable_if<Rhs::ColsAtCompileTime!=1 && DestDerived::ColsAtCompileTime!=1>::type
0368 _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &aDest) const
0369 {
0370 eigen_assert(rows()==b.rows());
0371
0372 Index rhsCols = b.cols();
0373 DestDerived& dest(aDest.derived());
0374 ComputationInfo global_info = Success;
0375 for(Index k=0; k<rhsCols; ++k)
0376 {
0377 typename DestDerived::ColXpr xk(dest,k);
0378 typename Rhs::ConstColXpr bk(b,k);
0379 derived()._solve_vector_with_guess_impl(bk,xk);
0380
0381
0382
0383 if(m_info==NumericalIssue)
0384 global_info = NumericalIssue;
0385 else if(m_info==NoConvergence)
0386 global_info = NoConvergence;
0387 }
0388 m_info = global_info;
0389 }
0390
0391 template<typename Rhs, typename DestDerived>
0392 typename internal::enable_if<Rhs::ColsAtCompileTime==1 || DestDerived::ColsAtCompileTime==1>::type
0393 _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &dest) const
0394 {
0395 derived()._solve_vector_with_guess_impl(b,dest.derived());
0396 }
0397
0398
0399 template<typename Rhs,typename Dest>
0400 void _solve_impl(const Rhs& b, Dest& x) const
0401 {
0402 x.setZero();
0403 derived()._solve_with_guess_impl(b,x);
0404 }
0405
0406 protected:
0407 void init()
0408 {
0409 m_isInitialized = false;
0410 m_analysisIsOk = false;
0411 m_factorizationIsOk = false;
0412 m_maxIterations = -1;
0413 m_tolerance = NumTraits<Scalar>::epsilon();
0414 }
0415
0416 typedef internal::generic_matrix_wrapper<MatrixType> MatrixWrapper;
0417 typedef typename MatrixWrapper::ActualMatrixType ActualMatrixType;
0418
0419 const ActualMatrixType& matrix() const
0420 {
0421 return m_matrixWrapper.matrix();
0422 }
0423
0424 template<typename InputType>
0425 void grab(const InputType &A)
0426 {
0427 m_matrixWrapper.grab(A);
0428 }
0429
0430 MatrixWrapper m_matrixWrapper;
0431 Preconditioner m_preconditioner;
0432
0433 Index m_maxIterations;
0434 RealScalar m_tolerance;
0435
0436 mutable RealScalar m_error;
0437 mutable Index m_iterations;
0438 mutable ComputationInfo m_info;
0439 mutable bool m_analysisIsOk, m_factorizationIsOk;
0440 };
0441
0442 }
0443
0444 #endif