File indexing completed on 2025-01-18 09:57:04
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012 #ifndef KRONECKER_TENSOR_PRODUCT_H
0013 #define KRONECKER_TENSOR_PRODUCT_H
0014
0015 namespace Eigen {
0016
0017
0018
0019
0020
0021
0022
0023
0024 template<typename Derived>
0025 class KroneckerProductBase : public ReturnByValue<Derived>
0026 {
0027 private:
0028 typedef typename internal::traits<Derived> Traits;
0029 typedef typename Traits::Scalar Scalar;
0030
0031 protected:
0032 typedef typename Traits::Lhs Lhs;
0033 typedef typename Traits::Rhs Rhs;
0034
0035 public:
0036
0037 KroneckerProductBase(const Lhs& A, const Rhs& B)
0038 : m_A(A), m_B(B)
0039 {}
0040
0041 inline Index rows() const { return m_A.rows() * m_B.rows(); }
0042 inline Index cols() const { return m_A.cols() * m_B.cols(); }
0043
0044
0045
0046
0047
0048 Scalar coeff(Index row, Index col) const
0049 {
0050 return m_A.coeff(row / m_B.rows(), col / m_B.cols()) *
0051 m_B.coeff(row % m_B.rows(), col % m_B.cols());
0052 }
0053
0054
0055
0056
0057
0058 Scalar coeff(Index i) const
0059 {
0060 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
0061 return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
0062 }
0063
0064 protected:
0065 typename Lhs::Nested m_A;
0066 typename Rhs::Nested m_B;
0067 };
0068
0069
0070
0071
0072
0073
0074
0075
0076
0077
0078
0079
0080
0081 template<typename Lhs, typename Rhs>
0082 class KroneckerProduct : public KroneckerProductBase<KroneckerProduct<Lhs,Rhs> >
0083 {
0084 private:
0085 typedef KroneckerProductBase<KroneckerProduct> Base;
0086 using Base::m_A;
0087 using Base::m_B;
0088
0089 public:
0090
0091 KroneckerProduct(const Lhs& A, const Rhs& B)
0092 : Base(A, B)
0093 {}
0094
0095
0096 template<typename Dest> void evalTo(Dest& dst) const;
0097 };
0098
0099
0100
0101
0102
0103
0104
0105
0106
0107
0108
0109
0110
0111
0112
0113
0114 template<typename Lhs, typename Rhs>
0115 class KroneckerProductSparse : public KroneckerProductBase<KroneckerProductSparse<Lhs,Rhs> >
0116 {
0117 private:
0118 typedef KroneckerProductBase<KroneckerProductSparse> Base;
0119 using Base::m_A;
0120 using Base::m_B;
0121
0122 public:
0123
0124 KroneckerProductSparse(const Lhs& A, const Rhs& B)
0125 : Base(A, B)
0126 {}
0127
0128
0129 template<typename Dest> void evalTo(Dest& dst) const;
0130 };
0131
0132 template<typename Lhs, typename Rhs>
0133 template<typename Dest>
0134 void KroneckerProduct<Lhs,Rhs>::evalTo(Dest& dst) const
0135 {
0136 const int BlockRows = Rhs::RowsAtCompileTime,
0137 BlockCols = Rhs::ColsAtCompileTime;
0138 const Index Br = m_B.rows(),
0139 Bc = m_B.cols();
0140 for (Index i=0; i < m_A.rows(); ++i)
0141 for (Index j=0; j < m_A.cols(); ++j)
0142 Block<Dest,BlockRows,BlockCols>(dst,i*Br,j*Bc,Br,Bc) = m_A.coeff(i,j) * m_B;
0143 }
0144
0145 template<typename Lhs, typename Rhs>
0146 template<typename Dest>
0147 void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
0148 {
0149 Index Br = m_B.rows(), Bc = m_B.cols();
0150 dst.resize(this->rows(), this->cols());
0151 dst.resizeNonZeros(0);
0152
0153
0154 typedef typename internal::nested_eval<Lhs,Dynamic>::type Lhs1;
0155 typedef typename internal::remove_all<Lhs1>::type Lhs1Cleaned;
0156 const Lhs1 lhs1(m_A);
0157 typedef typename internal::nested_eval<Rhs,Dynamic>::type Rhs1;
0158 typedef typename internal::remove_all<Rhs1>::type Rhs1Cleaned;
0159 const Rhs1 rhs1(m_B);
0160
0161
0162 typedef Eigen::InnerIterator<Lhs1Cleaned> LhsInnerIterator;
0163 typedef Eigen::InnerIterator<Rhs1Cleaned> RhsInnerIterator;
0164
0165
0166 {
0167
0168 VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
0169 for (Index kA=0; kA < m_A.outerSize(); ++kA)
0170 for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
0171 nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
0172
0173 VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
0174 for (Index kB=0; kB < m_B.outerSize(); ++kB)
0175 for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
0176 nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
0177
0178 Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
0179 dst.reserve(VectorXi::Map(nnzAB.data(), nnzAB.size()));
0180 }
0181
0182 for (Index kA=0; kA < m_A.outerSize(); ++kA)
0183 {
0184 for (Index kB=0; kB < m_B.outerSize(); ++kB)
0185 {
0186 for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
0187 {
0188 for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
0189 {
0190 Index i = itA.row() * Br + itB.row(),
0191 j = itA.col() * Bc + itB.col();
0192 dst.insert(i,j) = itA.value() * itB.value();
0193 }
0194 }
0195 }
0196 }
0197 }
0198
0199 namespace internal {
0200
0201 template<typename _Lhs, typename _Rhs>
0202 struct traits<KroneckerProduct<_Lhs,_Rhs> >
0203 {
0204 typedef typename remove_all<_Lhs>::type Lhs;
0205 typedef typename remove_all<_Rhs>::type Rhs;
0206 typedef typename ScalarBinaryOpTraits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
0207 typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex;
0208
0209 enum {
0210 Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
0211 Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
0212 MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
0213 MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret
0214 };
0215
0216 typedef Matrix<Scalar,Rows,Cols> ReturnType;
0217 };
0218
0219 template<typename _Lhs, typename _Rhs>
0220 struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
0221 {
0222 typedef MatrixXpr XprKind;
0223 typedef typename remove_all<_Lhs>::type Lhs;
0224 typedef typename remove_all<_Rhs>::type Rhs;
0225 typedef typename ScalarBinaryOpTraits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
0226 typedef typename cwise_promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind, scalar_product_op<typename Lhs::Scalar, typename Rhs::Scalar> >::ret StorageKind;
0227 typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex;
0228
0229 enum {
0230 LhsFlags = Lhs::Flags,
0231 RhsFlags = Rhs::Flags,
0232
0233 RowsAtCompileTime = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
0234 ColsAtCompileTime = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
0235 MaxRowsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
0236 MaxColsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret,
0237
0238 EvalToRowMajor = (int(LhsFlags) & int(RhsFlags) & RowMajorBit),
0239 RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
0240
0241 Flags = ((int(LhsFlags) | int(RhsFlags)) & HereditaryBits & RemovedBits)
0242 | EvalBeforeNestingBit,
0243 CoeffReadCost = HugeCost
0244 };
0245
0246 typedef SparseMatrix<Scalar, 0, StorageIndex> ReturnType;
0247 };
0248
0249 }
0250
0251
0252
0253
0254
0255
0256
0257
0258
0259
0260
0261
0262
0263
0264
0265
0266
0267
0268
0269
0270 template<typename A, typename B>
0271 KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<B>& b)
0272 {
0273 return KroneckerProduct<A, B>(a.derived(), b.derived());
0274 }
0275
0276
0277
0278
0279
0280
0281
0282
0283
0284
0285
0286
0287
0288
0289
0290
0291
0292
0293
0294
0295
0296
0297 template<typename A, typename B>
0298 KroneckerProductSparse<A,B> kroneckerProduct(const EigenBase<A>& a, const EigenBase<B>& b)
0299 {
0300 return KroneckerProductSparse<A,B>(a.derived(), b.derived());
0301 }
0302
0303 }
0304
0305 #endif