Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-19 09:06:20

0001 // This file is part of Eigen, a lightweight C++ template library
0002 // for linear algebra.
0003 //
0004 // Copyright (C) 2011-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
0005 //
0006 // This Source Code Form is subject to the terms of the Mozilla
0007 // Public License v. 2.0. If a copy of the MPL was not distributed
0008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
0009 
0010 #ifndef EIGEN_ITERATIVE_SOLVER_BASE_H
0011 #define EIGEN_ITERATIVE_SOLVER_BASE_H
0012 
0013 namespace RivetEigen {
0014 
0015 namespace internal {
0016 
/** \internal
  * Compile-time test of whether an lvalue of type \a MatrixType can be passed as a
  * \c Ref<const MatrixType>. \c value is true when overload resolution picks the
  * \c Ref<> overload of \c test() below, and false when only the catch-all overload is viable.
  */
template<typename MatrixType>
struct is_ref_compatible_impl
{
private:
  // Catch-all parameter type: its constructor templates accept any lvalue
  // (cv-qualified or not), so the fallback test() overload is always viable.
  template <typename T0>
  struct any_conversion
  {
    template <typename T> any_conversion(const volatile T&);
    template <typename T> any_conversion(T&);
  };
  // Result types of different sizes: sizeof() on the call expression tells which overload won.
  struct yes {int a[1];};
  struct no  {int a[2];};

  // Preferred overload: its int parameter matches the literal 0 exactly, which ranks above
  // the ellipsis of the fallback. Whether it is viable at all presumably depends on Ref's
  // converting constructor rejecting incompatible types via SFINAE (not visible here - TODO confirm).
  template<typename T>
  static yes test(const Ref<const T>&, int);
  // Fallback overload, always viable thanks to any_conversion.
  template<typename T>
  static no  test(any_conversion<T>, ...);

public:
  // Only named inside sizeof() (an unevaluated operand), so it needs no out-of-class definition.
  static MatrixType ms_from;
  enum { value = sizeof(test<MatrixType>(ms_from, 0))==sizeof(yes) };
};
0039 
0040 template<typename MatrixType>
0041 struct is_ref_compatible
0042 {
0043   enum { value = is_ref_compatible_impl<typename remove_all<MatrixType>::type>::value };
0044 };
0045 
0046 template<typename MatrixType, bool MatrixFree = !internal::is_ref_compatible<MatrixType>::value>
0047 class generic_matrix_wrapper;
0048 
0049 // We have an explicit matrix at hand, compatible with Ref<>
0050 template<typename MatrixType>
0051 class generic_matrix_wrapper<MatrixType,false>
0052 {
0053 public:
0054   typedef Ref<const MatrixType> ActualMatrixType;
0055   template<int UpLo> struct ConstSelfAdjointViewReturnType {
0056     typedef typename ActualMatrixType::template ConstSelfAdjointViewReturnType<UpLo>::Type Type;
0057   };
0058 
0059   enum {
0060     MatrixFree = false
0061   };
0062 
0063   generic_matrix_wrapper()
0064     : m_dummy(0,0), m_matrix(m_dummy)
0065   {}
0066 
0067   template<typename InputType>
0068   generic_matrix_wrapper(const InputType &mat)
0069     : m_matrix(mat)
0070   {}
0071 
0072   const ActualMatrixType& matrix() const
0073   {
0074     return m_matrix;
0075   }
0076 
0077   template<typename MatrixDerived>
0078   void grab(const EigenBase<MatrixDerived> &mat)
0079   {
0080     m_matrix.~Ref<const MatrixType>();
0081     ::new (&m_matrix) Ref<const MatrixType>(mat.derived());
0082   }
0083 
0084   void grab(const Ref<const MatrixType> &mat)
0085   {
0086     if(&(mat.derived()) != &m_matrix)
0087     {
0088       m_matrix.~Ref<const MatrixType>();
0089       ::new (&m_matrix) Ref<const MatrixType>(mat);
0090     }
0091   }
0092 
0093 protected:
0094   MatrixType m_dummy; // used to default initialize the Ref<> object
0095   ActualMatrixType m_matrix;
0096 };
0097 
0098 // MatrixType is not compatible with Ref<> -> matrix-free wrapper
0099 template<typename MatrixType>
0100 class generic_matrix_wrapper<MatrixType,true>
0101 {
0102 public:
0103   typedef MatrixType ActualMatrixType;
0104   template<int UpLo> struct ConstSelfAdjointViewReturnType
0105   {
0106     typedef ActualMatrixType Type;
0107   };
0108 
0109   enum {
0110     MatrixFree = true
0111   };
0112 
0113   generic_matrix_wrapper()
0114     : mp_matrix(0)
0115   {}
0116 
0117   generic_matrix_wrapper(const MatrixType &mat)
0118     : mp_matrix(&mat)
0119   {}
0120 
0121   const ActualMatrixType& matrix() const
0122   {
0123     return *mp_matrix;
0124   }
0125 
0126   void grab(const MatrixType &mat)
0127   {
0128     mp_matrix = &mat;
0129   }
0130 
0131 protected:
0132   const ActualMatrixType *mp_matrix;
0133 };
0134 
0135 }
0136 
0137 /** \ingroup IterativeLinearSolvers_Module
0138   * \brief Base class for linear iterative solvers
0139   *
0140   * \sa class SimplicialCholesky, DiagonalPreconditioner, IdentityPreconditioner
0141   */
0142 template< typename Derived>
0143 class IterativeSolverBase : public SparseSolverBase<Derived>
0144 {
0145 protected:
0146   typedef SparseSolverBase<Derived> Base;
0147   using Base::m_isInitialized;
0148 
0149 public:
0150   typedef typename internal::traits<Derived>::MatrixType MatrixType;
0151   typedef typename internal::traits<Derived>::Preconditioner Preconditioner;
0152   typedef typename MatrixType::Scalar Scalar;
0153   typedef typename MatrixType::StorageIndex StorageIndex;
0154   typedef typename MatrixType::RealScalar RealScalar;
0155 
0156   enum {
0157     ColsAtCompileTime = MatrixType::ColsAtCompileTime,
0158     MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
0159   };
0160 
0161 public:
0162 
0163   using Base::derived;
0164 
0165   /** Default constructor. */
0166   IterativeSolverBase()
0167   {
0168     init();
0169   }
0170 
0171   /** Initialize the solver with matrix \a A for further \c Ax=b solving.
0172     *
0173     * This constructor is a shortcut for the default constructor followed
0174     * by a call to compute().
0175     *
0176     * \warning this class stores a reference to the matrix A as well as some
0177     * precomputed values that depend on it. Therefore, if \a A is changed
0178     * this class becomes invalid. Call compute() to update it with the new
0179     * matrix A, or modify a copy of A.
0180     */
0181   template<typename MatrixDerived>
0182   explicit IterativeSolverBase(const EigenBase<MatrixDerived>& A)
0183     : m_matrixWrapper(A.derived())
0184   {
0185     init();
0186     compute(matrix());
0187   }
0188 
0189   ~IterativeSolverBase() {}
0190 
0191   /** Initializes the iterative solver for the sparsity pattern of the matrix \a A for further solving \c Ax=b problems.
0192     *
0193     * Currently, this function mostly calls analyzePattern on the preconditioner. In the future
0194     * we might, for instance, implement column reordering for faster matrix vector products.
0195     */
0196   template<typename MatrixDerived>
0197   Derived& analyzePattern(const EigenBase<MatrixDerived>& A)
0198   {
0199     grab(A.derived());
0200     m_preconditioner.analyzePattern(matrix());
0201     m_isInitialized = true;
0202     m_analysisIsOk = true;
0203     m_info = m_preconditioner.info();
0204     return derived();
0205   }
0206 
0207   /** Initializes the iterative solver with the numerical values of the matrix \a A for further solving \c Ax=b problems.
0208     *
0209     * Currently, this function mostly calls factorize on the preconditioner.
0210     *
0211     * \warning this class stores a reference to the matrix A as well as some
0212     * precomputed values that depend on it. Therefore, if \a A is changed
0213     * this class becomes invalid. Call compute() to update it with the new
0214     * matrix A, or modify a copy of A.
0215     */
0216   template<typename MatrixDerived>
0217   Derived& factorize(const EigenBase<MatrixDerived>& A)
0218   {
0219     eigen_assert(m_analysisIsOk && "You must first call analyzePattern()");
0220     grab(A.derived());
0221     m_preconditioner.factorize(matrix());
0222     m_factorizationIsOk = true;
0223     m_info = m_preconditioner.info();
0224     return derived();
0225   }
0226 
0227   /** Initializes the iterative solver with the matrix \a A for further solving \c Ax=b problems.
0228     *
0229     * Currently, this function mostly initializes/computes the preconditioner. In the future
0230     * we might, for instance, implement column reordering for faster matrix vector products.
0231     *
0232     * \warning this class stores a reference to the matrix A as well as some
0233     * precomputed values that depend on it. Therefore, if \a A is changed
0234     * this class becomes invalid. Call compute() to update it with the new
0235     * matrix A, or modify a copy of A.
0236     */
0237   template<typename MatrixDerived>
0238   Derived& compute(const EigenBase<MatrixDerived>& A)
0239   {
0240     grab(A.derived());
0241     m_preconditioner.compute(matrix());
0242     m_isInitialized = true;
0243     m_analysisIsOk = true;
0244     m_factorizationIsOk = true;
0245     m_info = m_preconditioner.info();
0246     return derived();
0247   }
0248 
0249   /** \internal */
0250   EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return matrix().rows(); }
0251 
0252   /** \internal */
0253   EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return matrix().cols(); }
0254 
0255   /** \returns the tolerance threshold used by the stopping criteria.
0256     * \sa setTolerance()
0257     */
0258   RealScalar tolerance() const { return m_tolerance; }
0259 
0260   /** Sets the tolerance threshold used by the stopping criteria.
0261     *
0262     * This value is used as an upper bound to the relative residual error: |Ax-b|/|b|.
0263     * The default value is the machine precision given by NumTraits<Scalar>::epsilon()
0264     */
0265   Derived& setTolerance(const RealScalar& tolerance)
0266   {
0267     m_tolerance = tolerance;
0268     return derived();
0269   }
0270 
0271   /** \returns a read-write reference to the preconditioner for custom configuration. */
0272   Preconditioner& preconditioner() { return m_preconditioner; }
0273 
0274   /** \returns a read-only reference to the preconditioner. */
0275   const Preconditioner& preconditioner() const { return m_preconditioner; }
0276 
0277   /** \returns the max number of iterations.
0278     * It is either the value set by setMaxIterations or, by default,
0279     * twice the number of columns of the matrix.
0280     */
0281   Index maxIterations() const
0282   {
0283     return (m_maxIterations<0) ? 2*matrix().cols() : m_maxIterations;
0284   }
0285 
0286   /** Sets the max number of iterations.
0287     * Default is twice the number of columns of the matrix.
0288     */
0289   Derived& setMaxIterations(Index maxIters)
0290   {
0291     m_maxIterations = maxIters;
0292     return derived();
0293   }
0294 
0295   /** \returns the number of iterations performed during the last solve */
0296   Index iterations() const
0297   {
0298     eigen_assert(m_isInitialized && "ConjugateGradient is not initialized.");
0299     return m_iterations;
0300   }
0301 
0302   /** \returns the tolerance error reached during the last solve.
0303     * It is a close approximation of the true relative residual error |Ax-b|/|b|.
0304     */
0305   RealScalar error() const
0306   {
0307     eigen_assert(m_isInitialized && "ConjugateGradient is not initialized.");
0308     return m_error;
0309   }
0310 
0311   /** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A
0312     * and \a x0 as an initial solution.
0313     *
0314     * \sa solve(), compute()
0315     */
0316   template<typename Rhs,typename Guess>
0317   inline const SolveWithGuess<Derived, Rhs, Guess>
0318   solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const
0319   {
0320     eigen_assert(m_isInitialized && "Solver is not initialized.");
0321     eigen_assert(derived().rows()==b.rows() && "solve(): invalid number of rows of the right hand side matrix b");
0322     return SolveWithGuess<Derived, Rhs, Guess>(derived(), b.derived(), x0);
0323   }
0324 
0325   /** \returns Success if the iterations converged, and NoConvergence otherwise. */
0326   ComputationInfo info() const
0327   {
0328     eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
0329     return m_info;
0330   }
0331 
0332   /** \internal */
0333   template<typename Rhs, typename DestDerived>
0334   void _solve_with_guess_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const
0335   {
0336     eigen_assert(rows()==b.rows());
0337 
0338     Index rhsCols = b.cols();
0339     Index size = b.rows();
0340     DestDerived& dest(aDest.derived());
0341     typedef typename DestDerived::Scalar DestScalar;
0342     RivetEigen::Matrix<DestScalar,Dynamic,1> tb(size);
0343     RivetEigen::Matrix<DestScalar,Dynamic,1> tx(cols());
0344     // We do not directly fill dest because sparse expressions have to be free of aliasing issue.
0345     // For non square least-square problems, b and dest might not have the same size whereas they might alias each-other.
0346     typename DestDerived::PlainObject tmp(cols(),rhsCols);
0347     ComputationInfo global_info = Success;
0348     for(Index k=0; k<rhsCols; ++k)
0349     {
0350       tb = b.col(k);
0351       tx = dest.col(k);
0352       derived()._solve_vector_with_guess_impl(tb,tx);
0353       tmp.col(k) = tx.sparseView(0);
0354 
0355       // The call to _solve_vector_with_guess_impl updates m_info, so if it failed for a previous column
0356       // we need to restore it to the worst value.
0357       if(m_info==NumericalIssue)
0358         global_info = NumericalIssue;
0359       else if(m_info==NoConvergence)
0360         global_info = NoConvergence;
0361     }
0362     m_info = global_info;
0363     dest.swap(tmp);
0364   }
0365 
0366   template<typename Rhs, typename DestDerived>
0367   typename internal::enable_if<Rhs::ColsAtCompileTime!=1 && DestDerived::ColsAtCompileTime!=1>::type
0368   _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &aDest) const
0369   {
0370     eigen_assert(rows()==b.rows());
0371 
0372     Index rhsCols = b.cols();
0373     DestDerived& dest(aDest.derived());
0374     ComputationInfo global_info = Success;
0375     for(Index k=0; k<rhsCols; ++k)
0376     {
0377       typename DestDerived::ColXpr xk(dest,k);
0378       typename Rhs::ConstColXpr bk(b,k);
0379       derived()._solve_vector_with_guess_impl(bk,xk);
0380 
0381       // The call to _solve_vector_with_guess updates m_info, so if it failed for a previous column
0382       // we need to restore it to the worst value.
0383       if(m_info==NumericalIssue)
0384         global_info = NumericalIssue;
0385       else if(m_info==NoConvergence)
0386         global_info = NoConvergence;
0387     }
0388     m_info = global_info;
0389   }
0390 
0391   template<typename Rhs, typename DestDerived>
0392   typename internal::enable_if<Rhs::ColsAtCompileTime==1 || DestDerived::ColsAtCompileTime==1>::type
0393   _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &dest) const
0394   {
0395     derived()._solve_vector_with_guess_impl(b,dest.derived());
0396   }
0397 
0398   /** \internal default initial guess = 0 */
0399   template<typename Rhs,typename Dest>
0400   void _solve_impl(const Rhs& b, Dest& x) const
0401   {
0402     x.setZero();
0403     derived()._solve_with_guess_impl(b,x);
0404   }
0405 
0406 protected:
0407   void init()
0408   {
0409     m_isInitialized = false;
0410     m_analysisIsOk = false;
0411     m_factorizationIsOk = false;
0412     m_maxIterations = -1;
0413     m_tolerance = NumTraits<Scalar>::epsilon();
0414   }
0415 
0416   typedef internal::generic_matrix_wrapper<MatrixType> MatrixWrapper;
0417   typedef typename MatrixWrapper::ActualMatrixType ActualMatrixType;
0418 
0419   const ActualMatrixType& matrix() const
0420   {
0421     return m_matrixWrapper.matrix();
0422   }
0423 
0424   template<typename InputType>
0425   void grab(const InputType &A)
0426   {
0427     m_matrixWrapper.grab(A);
0428   }
0429 
0430   MatrixWrapper m_matrixWrapper;
0431   Preconditioner m_preconditioner;
0432 
0433   Index m_maxIterations;
0434   RealScalar m_tolerance;
0435 
0436   mutable RealScalar m_error;
0437   mutable Index m_iterations;
0438   mutable ComputationInfo m_info;
0439   mutable bool m_analysisIsOk, m_factorizationIsOk;
0440 };
0441 
0442 } // end namespace RivetEigen
0443 
0444 #endif // EIGEN_ITERATIVE_SOLVER_BASE_H