File indexing completed on 2025-02-22 10:34:44
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 #ifndef EIGEN_SPARSE_TRIANGULARVIEW_H
0012 #define EIGEN_SPARSE_TRIANGULARVIEW_H
0013
0014 namespace Eigen {
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
// Specialization of TriangularViewImpl selected when the third template argument
// is Sparse. It derives from SparseMatrixBase<TriangularView<MatrixType,Mode> > and
// exposes the triangular-solve entry points. Only _solve_impl() is defined in this
// class body; the solveInPlace() overloads are declared here and defined elsewhere.
0025 template<typename MatrixType, unsigned int Mode> class TriangularViewImpl<MatrixType,Mode,Sparse>
0026 : public SparseMatrixBase<TriangularView<MatrixType,Mode> >
0027 {
// Compile-time flags derived from Mode and from the argument's Flags:
//   SkipFirst   - nonzero for a Lower view whose argument lacks RowMajorBit, or an
//                 Upper view whose argument has RowMajorBit.
//   SkipLast    - logical negation of SkipFirst.
//   SkipDiag    - 1 when Mode contains ZeroDiag, otherwise 0.
//   HasUnitDiag - 1 when Mode contains UnitDiag, otherwise 0.
// None of these is referenced again inside this class body as shown.
0028 enum { SkipFirst = ((Mode&Lower) && !(MatrixType::Flags&RowMajorBit))
0029 || ((Mode&Upper) && (MatrixType::Flags&RowMajorBit)),
0030 SkipLast = !SkipFirst,
0031 SkipDiag = (Mode&ZeroDiag) ? 1 : 0,
0032 HasUnitDiag = (Mode&UnitDiag) ? 1 : 0
0033 };
0034
0035 typedef TriangularView<MatrixType,Mode> TriangularViewType;
0036
0037 protected:
0038
// Declared in the protected section with no definition visible here.
// NOTE(review): presumably this hides a solve() inherited from the base class so
// that it cannot be called on a sparse triangular view - confirm.
0039 void solve() const;
0040
0041 typedef SparseMatrixBase<TriangularViewType> Base;
0042 public:
0043
// Macro defined outside this file; presumably injects the usual public typedefs
// for the expression type - TODO confirm against its definition.
0044 EIGEN_SPARSE_PUBLIC_INTERFACE(TriangularViewType)
0045
// Nested-expression type of the argument, plus the same type with the reference
// removed (remove_reference) and with all qualifiers removed (remove_all).
0046 typedef typename MatrixType::Nested MatrixTypeNested;
0047 typedef typename internal::remove_reference<MatrixTypeNested>::type MatrixTypeNestedNonRef;
0048 typedef typename internal::remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned;
0049
// Solves into dst using rhs as the right-hand side.
// rhs is first copied into dst, except when RhsType and DstType are the same type
// and internal::extract_data() returns equal values for both objects; in that case
// the copy is skipped. The result is then computed by solveInPlace(dst).
0050 template<typename RhsType, typename DstType>
0051 EIGEN_DEVICE_FUNC
0052 EIGEN_STRONG_INLINE void _solve_impl(const RhsType &rhs, DstType &dst) const {
0053 if(!(internal::is_same<RhsType,DstType>::value && internal::extract_data(dst) == internal::extract_data(rhs)))
0054 dst = rhs;
0055 this->solveInPlace(dst);
0056 }
0057
0058
// In-place solve taking a MatrixBase expression; `other` is passed by non-const
// reference. Declaration only - the definition is not in this part of the file.
0059 template<typename OtherDerived> void solveInPlace(MatrixBase<OtherDerived>& other) const;
0060
0061
// In-place solve taking a SparseMatrixBase expression. Declaration only - the
// definition is not in this part of the file.
0062 template<typename OtherDerived> void solveInPlace(SparseMatrixBase<OtherDerived>& other) const;
0063
0064 };
0065
0066 namespace internal {
0067
// Iterator-based evaluator for TriangularView<ArgType,Mode>.
// It wraps the evaluator of the nested expression and provides an InnerIterator
// that filters the nested iterator's entries according to Mode and, for UnitDiag
// modes, reports one synthetic entry of value Scalar(1) at inner index == outer.
0068 template<typename ArgType, unsigned int Mode>
0069 struct unary_evaluator<TriangularView<ArgType,Mode>, IteratorBased>
0070 : evaluator_base<TriangularView<ArgType,Mode> >
0071 {
0072 typedef TriangularView<ArgType,Mode> XprType;
0073
0074 protected:
0075
0076 typedef typename XprType::Scalar Scalar;
0077 typedef typename XprType::StorageIndex StorageIndex;
// Inner iterator type of the nested expression's evaluator; base class of InnerIterator below.
0078 typedef typename evaluator<ArgType>::InnerIterator EvalIterator;
0079
// Compile-time flags derived from Mode and ArgType::Flags:
//   SkipFirst   - nonzero for a Lower view whose argument lacks RowMajorBit, or an
//                 Upper view whose argument has RowMajorBit; the iterator then drops
//                 leading entries (inner index below outer) at construction.
//   SkipLast    - logical negation of SkipFirst; not referenced elsewhere in this struct.
//   SkipDiag    - 1 when Mode contains ZeroDiag, otherwise 0.
//   HasUnitDiag - 1 when Mode contains UnitDiag, otherwise 0.
0080 enum { SkipFirst = ((Mode&Lower) && !(ArgType::Flags&RowMajorBit))
0081 || ((Mode&Upper) && (ArgType::Flags&RowMajorBit)),
0082 SkipLast = !SkipFirst,
0083 SkipDiag = (Mode&ZeroDiag) ? 1 : 0,
0084 HasUnitDiag = (Mode&UnitDiag) ? 1 : 0
0085 };
0086
0087 public:
0088
// CoeffReadCost is taken from the nested evaluator; Flags from the view expression.
0089 enum {
0090 CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
0091 Flags = XprType::Flags
0092 };
0093
// Builds the nested evaluator from xpr.nestedExpression() and keeps a reference to
// that same nested expression in m_arg (used by InnerIterator to query innerSize()).
0094 explicit unary_evaluator(const XprType &xpr) : m_argImpl(xpr.nestedExpression()), m_arg(xpr.nestedExpression()) {}
0095
// Forwards the nested evaluator's estimate unchanged; entries filtered out by the
// triangular mode are not subtracted here.
0096 inline Index nonZerosEstimate() const {
0097 return m_argImpl.nonZerosEstimate();
0098 }
0099
// Iterator over one outer vector of the view. It derives from the nested
// iterator and layers the triangular filtering on top of it.
0100 class InnerIterator : public EvalIterator
0101 {
0102 typedef EvalIterator Base;
0103 public:
0104
// Positions the iterator on the first entry of outer vector `outer` that belongs to the view.
// m_containsDiag records whether Base::outer() is smaller than the nested
// expression's innerSize(); it gates the synthetic unit-diagonal entry.
0105 EIGEN_STRONG_INLINE InnerIterator(const unary_evaluator& xprEval, Index outer)
0106 : Base(xprEval.m_argImpl,outer), m_returnOne(false), m_containsDiag(Base::outer()<xprEval.m_arg.innerSize())
0107 {
0108 if(SkipFirst)
0109 {
// Skip leading entries: index <= outer when the stored diagonal must not be
// reported (UnitDiag or ZeroDiag), index < outer otherwise.
0110 while((*this) && ((HasUnitDiag||SkipDiag) ? this->index()<=outer : this->index()<outer))
0111 Base::operator++();
// For UnitDiag the synthetic 1 is reported first, provided the diagonal exists.
0112 if(HasUnitDiag)
0113 m_returnOne = m_containsDiag;
0114 }
// Not SkipFirst: if the nested iterator is already exhausted or already at an
// index >= outer, switch to the synthetic unit entry right away.
0115 else if(HasUnitDiag && ((!Base::operator bool()) || Base::index()>=Base::outer()))
0116 {
// NOTE(review): (!SkipFirst) is always true in this else-branch. The extra advance
// presumably steps over a stored diagonal entry - confirm.
0117 if((!SkipFirst) && Base::operator bool())
0118 Base::operator++();
0119 m_returnOne = m_containsDiag;
0120 }
0121 }
0122
// Advances to the next entry of the view.
// If the synthetic unit entry is the current one, it is simply consumed. Otherwise
// the nested iterator advances; for UnitDiag without SkipFirst the same
// "exhausted or index >= outer" test as in the constructor then arms the synthetic entry.
0123 EIGEN_STRONG_INLINE InnerIterator& operator++()
0124 {
0125 if(HasUnitDiag && m_returnOne)
0126 m_returnOne = false;
0127 else
0128 {
0129 Base::operator++();
0130 if(HasUnitDiag && (!SkipFirst) && ((!Base::operator bool()) || Base::index()>=Base::outer()))
0131 {
// NOTE(review): (!SkipFirst) is already guaranteed by the enclosing condition.
0132 if((!SkipFirst) && Base::operator bool())
0133 Base::operator++();
0134 m_returnOne = m_containsDiag;
0135 }
0136 }
0137 return *this;
0138 }
0139
// True while the iterator designates an entry of the view:
//   - always while the synthetic unit entry is pending;
//   - with SkipFirst, whenever the nested iterator is valid;
//   - otherwise, only while the nested iterator is valid and its index is
//     < outer (SkipDiag) or <= outer (no SkipDiag).
0140 EIGEN_STRONG_INLINE operator bool() const
0141 {
0142 if(HasUnitDiag && m_returnOne)
0143 return true;
0144 if(SkipFirst) return Base::operator bool();
0145 else
0146 {
0147 if (SkipDiag) return (Base::operator bool() && this->index() < this->outer());
0148 else return (Base::operator bool() && this->index() <= this->outer());
0149 }
0150 }
0151
0152
0153
// Inner index of the current entry; the synthetic unit entry reports Base::outer().
0154 inline StorageIndex index() const
0155 {
0156 if(HasUnitDiag && m_returnOne) return internal::convert_index<StorageIndex>(Base::outer());
0157 else return Base::index();
0158 }
// Value of the current entry; the synthetic unit entry reports Scalar(1).
0159 inline Scalar value() const
0160 {
0161 if(HasUnitDiag && m_returnOne) return Scalar(1);
0162 else return Base::value();
0163 }
0164
0165 protected:
// m_returnOne:    the current entry is the synthetic unit-diagonal value.
// m_containsDiag: Base::outer() < innerSize() of the nested expression (set once at construction).
0166 bool m_returnOne;
0167 bool m_containsDiag;
0168 private:
// Declared private with no definition visible here.
// NOTE(review): presumably this blocks write access through the iterator - confirm.
0169 Scalar& valueRef();
0170 };
0171
0172 protected:
// m_argImpl: evaluator of the nested expression.
// m_arg:     reference to the nested expression itself; it is not owned here, so it
//            must outlive this evaluator.
0173 evaluator<ArgType> m_argImpl;
0174 const ArgType& m_arg;
0175 };
0176
0177 }
0178
// Returns a TriangularView<const Derived, Mode> constructed from derived().
// Mode is supplied by the caller as a template argument; the view is returned by value.
0179 template<typename Derived>
0180 template<int Mode>
0181 inline const TriangularView<const Derived, Mode>
0182 SparseMatrixBase<Derived>::triangularView() const
0183 {
0184 return TriangularView<const Derived, Mode>(derived());
0185 }
0186
0187 }
0188
0189 #endif