File indexing completed on 2025-01-18 09:57:05
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #ifndef EIGEN_MATRIX_SQUARE_ROOT
0011 #define EIGEN_MATRIX_SQUARE_ROOT
0012
0013 namespace Eigen {
0014
0015 namespace internal {
0016
0017
0018
0019 template <typename MatrixType, typename ResultType>
0020 void matrix_sqrt_quasi_triangular_2x2_diagonal_block(const MatrixType& T, Index i, ResultType& sqrtT)
0021 {
0022
0023
0024 typedef typename traits<MatrixType>::Scalar Scalar;
0025 Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
0026 EigenSolver<Matrix<Scalar,2,2> > es(block);
0027 sqrtT.template block<2,2>(i,i)
0028 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
0029 }
0030
0031
0032
0033
0034 template <typename MatrixType, typename ResultType>
0035 void matrix_sqrt_quasi_triangular_1x1_off_diagonal_block(const MatrixType& T, Index i, Index j, ResultType& sqrtT)
0036 {
0037 typedef typename traits<MatrixType>::Scalar Scalar;
0038 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
0039 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
0040 }
0041
0042
0043 template <typename MatrixType, typename ResultType>
0044 void matrix_sqrt_quasi_triangular_1x2_off_diagonal_block(const MatrixType& T, Index i, Index j, ResultType& sqrtT)
0045 {
0046 typedef typename traits<MatrixType>::Scalar Scalar;
0047 Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
0048 if (j-i > 1)
0049 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
0050 Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
0051 A += sqrtT.template block<2,2>(j,j).transpose();
0052 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
0053 }
0054
0055
0056 template <typename MatrixType, typename ResultType>
0057 void matrix_sqrt_quasi_triangular_2x1_off_diagonal_block(const MatrixType& T, Index i, Index j, ResultType& sqrtT)
0058 {
0059 typedef typename traits<MatrixType>::Scalar Scalar;
0060 Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
0061 if (j-i > 2)
0062 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
0063 Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
0064 A += sqrtT.template block<2,2>(i,i);
0065 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
0066 }
0067
0068
0069 template <typename MatrixType>
0070 void matrix_sqrt_quasi_triangular_solve_auxiliary_equation(MatrixType& X, const MatrixType& A, const MatrixType& B, const MatrixType& C)
0071 {
0072 typedef typename traits<MatrixType>::Scalar Scalar;
0073 Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
0074 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
0075 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
0076 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
0077 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
0078 coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
0079 coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
0080 coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
0081 coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
0082 coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
0083 coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
0084 coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
0085 coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
0086
0087 Matrix<Scalar,4,1> rhs;
0088 rhs.coeffRef(0) = C.coeff(0,0);
0089 rhs.coeffRef(1) = C.coeff(0,1);
0090 rhs.coeffRef(2) = C.coeff(1,0);
0091 rhs.coeffRef(3) = C.coeff(1,1);
0092
0093 Matrix<Scalar,4,1> result;
0094 result = coeffMatrix.fullPivLu().solve(rhs);
0095
0096 X.coeffRef(0,0) = result.coeff(0);
0097 X.coeffRef(0,1) = result.coeff(1);
0098 X.coeffRef(1,0) = result.coeff(2);
0099 X.coeffRef(1,1) = result.coeff(3);
0100 }
0101
0102
0103 template <typename MatrixType, typename ResultType>
0104 void matrix_sqrt_quasi_triangular_2x2_off_diagonal_block(const MatrixType& T, Index i, Index j, ResultType& sqrtT)
0105 {
0106 typedef typename traits<MatrixType>::Scalar Scalar;
0107 Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
0108 Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
0109 Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
0110 if (j-i > 2)
0111 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
0112 Matrix<Scalar,2,2> X;
0113 matrix_sqrt_quasi_triangular_solve_auxiliary_equation(X, A, B, C);
0114 sqrtT.template block<2,2>(i,j) = X;
0115 }
0116
0117
0118
0119 template <typename MatrixType, typename ResultType>
0120 void matrix_sqrt_quasi_triangular_diagonal(const MatrixType& T, ResultType& sqrtT)
0121 {
0122 using std::sqrt;
0123 const Index size = T.rows();
0124 for (Index i = 0; i < size; i++) {
0125 if (i == size - 1 || T.coeff(i+1, i) == 0) {
0126 eigen_assert(T(i,i) >= 0);
0127 sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i));
0128 }
0129 else {
0130 matrix_sqrt_quasi_triangular_2x2_diagonal_block(T, i, sqrtT);
0131 ++i;
0132 }
0133 }
0134 }
0135
0136
0137
0138 template <typename MatrixType, typename ResultType>
0139 void matrix_sqrt_quasi_triangular_off_diagonal(const MatrixType& T, ResultType& sqrtT)
0140 {
0141 const Index size = T.rows();
0142 for (Index j = 1; j < size; j++) {
0143 if (T.coeff(j, j-1) != 0)
0144 continue;
0145 for (Index i = j-1; i >= 0; i--) {
0146 if (i > 0 && T.coeff(i, i-1) != 0)
0147 continue;
0148 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
0149 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
0150 if (iBlockIs2x2 && jBlockIs2x2)
0151 matrix_sqrt_quasi_triangular_2x2_off_diagonal_block(T, i, j, sqrtT);
0152 else if (iBlockIs2x2 && !jBlockIs2x2)
0153 matrix_sqrt_quasi_triangular_2x1_off_diagonal_block(T, i, j, sqrtT);
0154 else if (!iBlockIs2x2 && jBlockIs2x2)
0155 matrix_sqrt_quasi_triangular_1x2_off_diagonal_block(T, i, j, sqrtT);
0156 else if (!iBlockIs2x2 && !jBlockIs2x2)
0157 matrix_sqrt_quasi_triangular_1x1_off_diagonal_block(T, i, j, sqrtT);
0158 }
0159 }
0160 }
0161
0162 }
0163
0164
0165
0166
0167
0168
0169
0170
0171
0172
0173
0174
0175
0176
0177
0178
0179 template <typename MatrixType, typename ResultType>
0180 void matrix_sqrt_quasi_triangular(const MatrixType &arg, ResultType &result)
0181 {
0182 eigen_assert(arg.rows() == arg.cols());
0183 result.resize(arg.rows(), arg.cols());
0184 internal::matrix_sqrt_quasi_triangular_diagonal(arg, result);
0185 internal::matrix_sqrt_quasi_triangular_off_diagonal(arg, result);
0186 }
0187
0188
0189
0190
0191
0192
0193
0194
0195
0196
0197
0198
0199
0200
0201
0202
0203 template <typename MatrixType, typename ResultType>
0204 void matrix_sqrt_triangular(const MatrixType &arg, ResultType &result)
0205 {
0206 using std::sqrt;
0207 typedef typename MatrixType::Scalar Scalar;
0208
0209 eigen_assert(arg.rows() == arg.cols());
0210
0211
0212
0213 result.resize(arg.rows(), arg.cols());
0214 for (Index i = 0; i < arg.rows(); i++) {
0215 result.coeffRef(i,i) = sqrt(arg.coeff(i,i));
0216 }
0217 for (Index j = 1; j < arg.cols(); j++) {
0218 for (Index i = j-1; i >= 0; i--) {
0219
0220 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
0221
0222 result.coeffRef(i,j) = (arg.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
0223 }
0224 }
0225 }
0226
0227
0228 namespace internal {
0229
0230
0231
0232
0233
0234
0235
0236
0237 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
0238 struct matrix_sqrt_compute
0239 {
0240
0241
0242
0243
0244
0245
0246
0247 template <typename ResultType> static void run(const MatrixType &arg, ResultType &result);
0248 };
0249
0250
0251
0252
0253 template <typename MatrixType>
0254 struct matrix_sqrt_compute<MatrixType, 0>
0255 {
0256 typedef typename MatrixType::PlainObject PlainType;
0257 template <typename ResultType>
0258 static void run(const MatrixType &arg, ResultType &result)
0259 {
0260 eigen_assert(arg.rows() == arg.cols());
0261
0262
0263 const RealSchur<PlainType> schurOfA(arg);
0264 const PlainType& T = schurOfA.matrixT();
0265 const PlainType& U = schurOfA.matrixU();
0266
0267
0268 PlainType sqrtT = PlainType::Zero(arg.rows(), arg.cols());
0269 matrix_sqrt_quasi_triangular(T, sqrtT);
0270
0271
0272 result = U * sqrtT * U.adjoint();
0273 }
0274 };
0275
0276
0277
0278
0279 template <typename MatrixType>
0280 struct matrix_sqrt_compute<MatrixType, 1>
0281 {
0282 typedef typename MatrixType::PlainObject PlainType;
0283 template <typename ResultType>
0284 static void run(const MatrixType &arg, ResultType &result)
0285 {
0286 eigen_assert(arg.rows() == arg.cols());
0287
0288
0289 const ComplexSchur<PlainType> schurOfA(arg);
0290 const PlainType& T = schurOfA.matrixT();
0291 const PlainType& U = schurOfA.matrixU();
0292
0293
0294 PlainType sqrtT;
0295 matrix_sqrt_triangular(T, sqrtT);
0296
0297
0298 result = U * (sqrtT.template triangularView<Upper>() * U.adjoint());
0299 }
0300 };
0301
0302 }
0303
0304
0305
0306
0307
0308
0309
0310
0311
0312
0313
0314
0315
0316 template<typename Derived> class MatrixSquareRootReturnValue
0317 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
0318 {
0319 protected:
0320 typedef typename internal::ref_selector<Derived>::type DerivedNested;
0321
0322 public:
0323
0324
0325
0326
0327
0328 explicit MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
0329
0330
0331
0332
0333
0334
0335 template <typename ResultType>
0336 inline void evalTo(ResultType& result) const
0337 {
0338 typedef typename internal::nested_eval<Derived, 10>::type DerivedEvalType;
0339 typedef typename internal::remove_all<DerivedEvalType>::type DerivedEvalTypeClean;
0340 DerivedEvalType tmp(m_src);
0341 internal::matrix_sqrt_compute<DerivedEvalTypeClean>::run(tmp, result);
0342 }
0343
0344 Index rows() const { return m_src.rows(); }
0345 Index cols() const { return m_src.cols(); }
0346
0347 protected:
0348 const DerivedNested m_src;
0349 };
0350
0351 namespace internal {
0352 template<typename Derived>
0353 struct traits<MatrixSquareRootReturnValue<Derived> >
0354 {
0355 typedef typename Derived::PlainObject ReturnType;
0356 };
0357 }
0358
0359 template <typename Derived>
0360 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
0361 {
0362 eigen_assert(rows() == cols());
0363 return MatrixSquareRootReturnValue<Derived>(derived());
0364 }
0365
0366 }
0367
0368 #endif