File indexing completed on 2025-01-18 09:56:13
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 #ifndef EIGEN_GENERAL_PRODUCT_H
0012 #define EIGEN_GENERAL_PRODUCT_H
0013
0014 namespace Eigen {
0015
// Size-category tags for one dimension of a product. In this file they are
// produced by internal::product_size_category<>::value and used as the
// template arguments of the internal::product_type_selector<> specializations.
// The remaining category is the plain literal 1 (dimension exactly 1), so
// neither enumerator may take the value 1.
0016 enum {
0017 Large = 2,
0018 Small = 3
0019 };
0020
0021
0022
0023
0024
0025
// Provides a default of 20 for EIGEN_GEMM_TO_COEFFBASED_THRESHOLD unless the
// macro has already been defined before this point. Nothing in this section
// reads the macro.
// NOTE(review): the name suggests a size threshold for falling back from the
// GEMM path to the coefficient-based path -- confirm at the use site.
0026 #ifndef EIGEN_GEMM_TO_COEFFBASED_THRESHOLD
0027
0028 #define EIGEN_GEMM_TO_COEFFBASED_THRESHOLD 20
0029 #endif
0030
0031 namespace internal {
0032
// Forward declaration. Maps a (Rows, Cols, Depth) triple of size categories
// (1, Small or Large) to a product implementation tag exposed as the nested
// enumerator `ret`. Only specializations are defined; they follow product_type.
0033 template<int Rows, int Cols, int Depth> struct product_type_selector;
0034
// Classifies one dimension of a product from its compile-time size `Size`
// and its compile-time maximum `MaxSize`. The result `value` is:
//   Large  if is_large is non-zero,
//   1      if not large and Size == 1,
//   Small  otherwise.
0035 template<int Size, int MaxSize> struct product_size_category
0036 {
0037 enum {
// Outside a GPU compile phase a dimension is large when MaxSize is Dynamic,
// when Size reaches EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD, or when Size is
// Dynamic and MaxSize reaches that threshold.
0038 #ifndef EIGEN_GPU_COMPILE_PHASE
0039 is_large = MaxSize == Dynamic ||
0040 Size >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD ||
0041 (Size==Dynamic && MaxSize>=EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD),
0042 #else
// When EIGEN_GPU_COMPILE_PHASE is defined nothing is ever classified as Large.
0043 is_large = 0,
0044 #endif
0045 value = is_large ? Large
0046 : Size == 1 ? 1
0047 : Small
0048 };
0049 };
0050
// Compile-time choice of the product implementation for the pair (Lhs, Rhs).
//
// Reads the compile-time dimensions of both operands from traits<>, turns
// rows, cols and depth into size categories with product_size_category, and
// looks the resulting triple up in product_type_selector. The selected tag is
// published twice, as `value` and as `ret` (both equal selector::ret).
0051 template<typename Lhs, typename Rhs> struct product_type
0052 {
// Operand types after remove_all<>; these are the types queried in traits<>.
0053 typedef typename remove_all<Lhs>::type _Lhs;
0054 typedef typename remove_all<Rhs>::type _Rhs;
0055 enum {
// Rows come from the left operand, columns from the right operand.
0056 MaxRows = traits<_Lhs>::MaxRowsAtCompileTime,
0057 Rows = traits<_Lhs>::RowsAtCompileTime,
0058 MaxCols = traits<_Rhs>::MaxColsAtCompileTime,
0059 Cols = traits<_Rhs>::ColsAtCompileTime,
// Depth is the shared inner dimension: Lhs columns combined with Rhs rows
// through EIGEN_SIZE_MIN_PREFER_FIXED.
0060 MaxDepth = EIGEN_SIZE_MIN_PREFER_FIXED(traits<_Lhs>::MaxColsAtCompileTime,
0061 traits<_Rhs>::MaxRowsAtCompileTime),
0062 Depth = EIGEN_SIZE_MIN_PREFER_FIXED(traits<_Lhs>::ColsAtCompileTime,
0063 traits<_Rhs>::RowsAtCompileTime)
0064 };
0065
0066
0067
0068 private:
// Size category (1, Small or Large) of each of the three dimensions.
0069 enum {
0070 rows_select = product_size_category<Rows,MaxRows>::value,
0071 cols_select = product_size_category<Cols,MaxCols>::value,
0072 depth_select = product_size_category<Depth,MaxDepth>::value
0073 };
0074 typedef product_type_selector<rows_select, cols_select, depth_select> selector;
0075
0076 public:
// Two names for the same selected tag.
0077 enum {
0078 value = selector::ret,
0079 ret = selector::ret
0080 };
// debug() exists only when EIGEN_DEBUG_PRODUCT is defined. It hands the
// dimensions, their categories and the selected value to EIGEN_DEBUG_VAR.
0081 #ifdef EIGEN_DEBUG_PRODUCT
0082 static void debug()
0083 {
0084 EIGEN_DEBUG_VAR(Rows);
0085 EIGEN_DEBUG_VAR(Cols);
0086 EIGEN_DEBUG_VAR(Depth);
0087 EIGEN_DEBUG_VAR(rows_select);
0088 EIGEN_DEBUG_VAR(cols_select);
0089 EIGEN_DEBUG_VAR(depth_select);
0090 EIGEN_DEBUG_VAR(value);
0091 }
0092 #endif
0093 };
0094
0095
0096
0097
0098
// Lookup table product_type_selector<rows, cols, depth>::ret, where each
// argument is a size category: 1, Small or Large.
//
// Depth == 1: every combination except <Large,Large,1> is covered by a more
// specialized entry below, so for category arguments the generic <M,N,1>
// entry (OuterProduct) is reached only by <Large,Large,1>.
0099 template<int M, int N> struct product_type_selector<M,N,1> { enum { ret = OuterProduct }; };
0100 template<int M> struct product_type_selector<M, 1, 1> { enum { ret = LazyCoeffBasedProductMode }; };
0101 template<int N> struct product_type_selector<1, N, 1> { enum { ret = LazyCoeffBasedProductMode }; };
// Rows == 1 and Cols == 1: inner product, whatever the depth category.
0102 template<int Depth> struct product_type_selector<1, 1, Depth> { enum { ret = InnerProduct }; };
// Full specialization required: <1,1,1> matches the three partial
// specializations <M,1,1>, <1,N,1> and <1,1,Depth>, none of which is more
// specialized than the others, so the lookup would otherwise be ambiguous.
0103 template<> struct product_type_selector<1, 1, 1> { enum { ret = InnerProduct }; };
// One dimension equal to 1 or all dimensions Small, with Small depth:
// coefficient-based product.
0104 template<> struct product_type_selector<Small,1, Small> { enum { ret = CoeffBasedProductMode }; };
0105 template<> struct product_type_selector<1, Small,Small> { enum { ret = CoeffBasedProductMode }; };
0106 template<> struct product_type_selector<Small,Small,Small> { enum { ret = CoeffBasedProductMode }; };
// Depth == 1 with at least one Small outer dimension: lazy coefficient-based
// product. These are more specialized than the generic <M,N,1> entry.
0107 template<> struct product_type_selector<Small, Small, 1> { enum { ret = LazyCoeffBasedProductMode }; };
0108 template<> struct product_type_selector<Small, Large, 1> { enum { ret = LazyCoeffBasedProductMode }; };
0109 template<> struct product_type_selector<Large, Small, 1> { enum { ret = LazyCoeffBasedProductMode }; };
// Single row (1 x Depth times Depth x Cols): GemvProduct only when both the
// column count and the depth are Large, coefficient-based otherwise.
0110 template<> struct product_type_selector<1, Large,Small> { enum { ret = CoeffBasedProductMode }; };
0111 template<> struct product_type_selector<1, Large,Large> { enum { ret = GemvProduct }; };
0112 template<> struct product_type_selector<1, Small,Large> { enum { ret = CoeffBasedProductMode }; };
// Single column (Rows x Depth times Depth x 1): GemvProduct only when both
// the row count and the depth are Large, coefficient-based otherwise.
0113 template<> struct product_type_selector<Large,1, Small> { enum { ret = CoeffBasedProductMode }; };
0114 template<> struct product_type_selector<Large,1, Large> { enum { ret = GemvProduct }; };
0115 template<> struct product_type_selector<Small,1, Large> { enum { ret = CoeffBasedProductMode }; };
// Matrix-matrix with Large depth: always GemmProduct.
0116 template<> struct product_type_selector<Small,Small,Large> { enum { ret = GemmProduct }; };
0117 template<> struct product_type_selector<Large,Small,Large> { enum { ret = GemmProduct }; };
0118 template<> struct product_type_selector<Small,Large,Large> { enum { ret = GemmProduct }; };
0119 template<> struct product_type_selector<Large,Large,Large> { enum { ret = GemmProduct }; };
// Matrix-matrix with Small depth: GemmProduct only when both outer
// dimensions are Large, coefficient-based otherwise.
0120 template<> struct product_type_selector<Large,Small,Small> { enum { ret = CoeffBasedProductMode }; };
0121 template<> struct product_type_selector<Small,Large,Small> { enum { ret = CoeffBasedProductMode }; };
0122 template<> struct product_type_selector<Large,Large,Small> { enum { ret = GemmProduct }; };
0123
0124 }
0125
0126
0127
0128
0129
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139
0140
0141
0142
0143
0144
0145
0146
0147
0148
0149
0150
0151
0152 namespace internal {
0153
// Forward declaration of the dense matrix*vector dispatcher. The
// specializations in this file are keyed on Side (OnTheLeft / OnTheRight),
// StorageOrder (ColMajor / RowMajor) and BlasCompatible (true / false); each
// provides a static run(lhs, rhs, dest, alpha).
0154 template<int Side, int StorageOrder, bool BlasCompatible>
0155 struct gemv_dense_selector;
0156
0157 }
0158
0159 namespace internal {
0160
// Forward declaration of an optional scratch buffer of Scalar elements used by
// the gemv_dense_selector specializations in this file. Every specialization
// exposes data(); only the <Scalar,Size,MaxSize,true> one embeds storage.
0161 template<typename Scalar,int Size,int MaxSize,bool Cond> struct gemv_static_vector_if;
0162
// Cond == false: no buffer is embedded. data() trips eigen_internal_assert
// with the message "should never be called" and returns a null pointer (0).
0163 template<typename Scalar,int Size,int MaxSize>
0164 struct gemv_static_vector_if<Scalar,Size,MaxSize,false>
0165 {
0166 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar* data() { eigen_internal_assert(false && "should never be called"); return 0; }
0167 };
0168
// Cond == true but MaxSize == Dynamic: no buffer is embedded and data()
// returns a null pointer (0), without asserting.
// NOTE(review): the callers in this file pass the result to
// ei_declare_aligned_stack_constructed_variable; presumably a null pointer
// makes that macro supply the storage itself -- confirm in its definition.
0169 template<typename Scalar,int Size>
0170 struct gemv_static_vector_if<Scalar,Size,Dynamic,true>
0171 {
0172 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar* data() { return 0; }
0173 };
0174
// Cond == true with a non-Dynamic MaxSize: embeds a plain_array member holding
// EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize) Scalar elements and returns a
// pointer into it from data(). Two layouts exist, chosen by
// EIGEN_MAX_STATIC_ALIGN_BYTES.
0175 template<typename Scalar,int Size,int MaxSize>
0176 struct gemv_static_vector_if<Scalar,Size,MaxSize,true>
0177 {
0178 enum {
0179 ForceAlignment = internal::packet_traits<Scalar>::Vectorizable,
0180 PacketSize = internal::packet_traits<Scalar>::size
0181 };
// Layout 1 (EIGEN_MAX_STATIC_ALIGN_BYTES != 0): plain_array receives
// EIGEN_PLAIN_ENUM_MIN(AlignedMax,PacketSize) as its fourth template argument
// and data() returns the start of the array unchanged.
// NOTE(review): presumably that fourth argument is the requested alignment --
// confirm against plain_array's declaration.
0182 #if EIGEN_MAX_STATIC_ALIGN_BYTES!=0
0183 internal::plain_array<Scalar,EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize),0,EIGEN_PLAIN_ENUM_MIN(AlignedMax,PacketSize)> m_data;
0184 EIGEN_STRONG_INLINE Scalar* data() { return m_data.array; }
0185 #else
0186
0187
// Layout 2: no alignment argument is passed. When ForceAlignment is set the
// array is enlarged by EIGEN_MAX_ALIGN_BYTES extra Scalar elements so that
// data() can move the returned pointer forward to an aligned address.
0188 internal::plain_array<Scalar,EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize)+(ForceAlignment?EIGEN_MAX_ALIGN_BYTES:0),0> m_data;
0189 EIGEN_STRONG_INLINE Scalar* data() {
// Clears the low address bits and then adds EIGEN_MAX_ALIGN_BYTES, so the
// result is always past the start of the array by at most
// EIGEN_MAX_ALIGN_BYTES bytes. The padding above is counted in elements, so
// it is never smaller than this byte shift.
// Assumes EIGEN_MAX_ALIGN_BYTES is a power of two (mask arithmetic) -- TODO confirm.
0190 return ForceAlignment
0191 ? reinterpret_cast<Scalar*>((internal::UIntPtr(m_data.array) & ~(std::size_t(EIGEN_MAX_ALIGN_BYTES-1))) + EIGEN_MAX_ALIGN_BYTES)
0192 : m_data.array;
0193 }
0194 #endif
0195 };
0196
0197
// Side == OnTheLeft: reduces the product to the OnTheRight case by
// transposition, using (lhs * rhs)^T == rhs^T * lhs^T.
//
// run() wraps dest in a Transpose view, swaps the operands while transposing
// both, flips the storage order (RowMajor <-> ColMajor) and forwards alpha and
// BlasCompatible unchanged to the OnTheRight selector.
0198 template<int StorageOrder, bool BlasCompatible>
0199 struct gemv_dense_selector<OnTheLeft,StorageOrder,BlasCompatible>
0200 {
0201 template<typename Lhs, typename Rhs, typename Dest>
0202 static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
0203 {
// Writes through destT reach dest itself; destT is a view constructed from dest.
0204 Transpose<Dest> destT(dest);
0205 enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor };
0206 gemv_dense_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
0207 ::run(rhs.transpose(), lhs.transpose(), destT, alpha);
0208 }
0209 };
0210
// OnTheRight, column-major lhs, BLAS-compatible operands.
//
// run() forwards the product to general_matrix_vector_product. The kernel is
// always given a destination increment of 1, so it either writes straight
// into dest or into a contiguous temporary that is copied back afterwards.
//
// NOTE(review): the copy-in / copy-back logic below implies that the kernel
// accumulates alpha * lhs * rhs onto the destination buffer rather than
// overwriting it -- confirm against general_matrix_vector_product.
0211 template<> struct gemv_dense_selector<OnTheRight,ColMajor,true>
0212 {
0213 template<typename Lhs, typename Rhs, typename Dest>
0214 static inline void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
0215 {
0216 typedef typename Lhs::Scalar LhsScalar;
0217 typedef typename Rhs::Scalar RhsScalar;
0218 typedef typename Dest::Scalar ResScalar;
0219 typedef typename Dest::RealScalar RealScalar;
0220
0221 typedef internal::blas_traits<Lhs> LhsBlasTraits;
0222 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
0223 typedef internal::blas_traits<Rhs> RhsBlasTraits;
0224 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
0225
// Dynamic-size column-vector view used to read and write the temporary buffer.
0226 typedef Map<Matrix<ResScalar,Dynamic,1>, EIGEN_PLAIN_ENUM_MIN(AlignedMax,internal::packet_traits<ResScalar>::size)> MappedDest;
0227
// Operands as returned by blas_traits<>::extract; their data() and strides
// are what the kernel receives.
0228 ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
0229 ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
0230
// alpha combined with lhs and rhs by combine_scalar_factors.
0231 ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
0232
0233
// Dest itself when it is a compile-time vector, otherwise its column type.
0234 typedef typename conditional<Dest::IsVectorAtCompileTime, Dest, typename Dest::ColXpr>::type ActualDest;
0235
0236 enum {
0237
0238
// EvalToDestAtCompileTime: the destination's inner stride is known to be 1.
// ComplexByReal: complex lhs scalar with a non-complex rhs scalar.
// MightCannotUseDest: a temporary may be needed (either of the two conditions
// above rules out the direct path) and the destination's maximum size is not 0.
0239 EvalToDestAtCompileTime = (ActualDest::InnerStrideAtCompileTime==1),
0240 ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
0241 MightCannotUseDest = ((!EvalToDestAtCompileTime) || ComplexByReal) && (ActualDest::MaxSizeAtCompileTime!=0)
0242 };
0243
0244 typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
0245 typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
// The kernel takes its scaling factor as an RhsScalar; get_factor converts
// actualAlpha from ResScalar.
0246 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
0247
// Direct path: MightCannotUseDest is a compile-time constant, so this branch
// is fixed per instantiation. The kernel works on dest.data() with increment 1.
0248 if(!MightCannotUseDest)
0249 {
0250
0251
0252 general_matrix_vector_product
0253 <Index,LhsScalar,LhsMapper,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
0254 actualLhs.rows(), actualLhs.cols(),
0255 LhsMapper(actualLhs.data(), actualLhs.outerStride()),
0256 RhsMapper(actualRhs.data(), actualRhs.innerStride()),
0257 dest.data(), 1,
0258 compatibleAlpha);
0259 }
0260 else
0261 {
// Optional embedded scratch storage; see gemv_static_vector_if.
0262 gemv_static_vector_if<ResScalar,ActualDest::SizeAtCompileTime,ActualDest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
0263
// alpha can be handed to the kernel as an RhsScalar only if the product is
// not complex-by-real or alpha has a zero imaginary part.
0264 const bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
// Run-time decision: write directly into dest only when it is contiguous
// and alpha is compatible.
0265 const bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
0266
// Declares actualDestPtr, offering dest.data() or the scratch buffer as storage.
// NOTE(review): presumably the macro reuses the pointer passed as its last
// argument when it is non-null and otherwise provides dest.size() elements
// itself -- confirm in the macro definition.
0267 ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
0268 evalToDest ? dest.data() : static_dest.data());
0269
0270 if(!evalToDest)
0271 {
// Optional instrumentation hook; `size` is declared only for the plugin macro.
0272 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
0273 Index size = dest.size();
0274 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
0275 #endif
0276 if(!alphaIsCompatible)
0277 {
// Incompatible alpha: start from a zeroed temporary and run the kernel with
// a factor of 1; the real alpha is applied when copying back below.
0278 MappedDest(actualDestPtr, dest.size()).setZero();
0279 compatibleAlpha = RhsScalar(1);
0280 }
0281 else
// Compatible alpha: seed the temporary with the current contents of dest.
0282 MappedDest(actualDestPtr, dest.size()) = dest;
0283 }
0284
0285 general_matrix_vector_product
0286 <Index,LhsScalar,LhsMapper,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
0287 actualLhs.rows(), actualLhs.cols(),
0288 LhsMapper(actualLhs.data(), actualLhs.outerStride()),
0289 RhsMapper(actualRhs.data(), actualRhs.innerStride()),
0290 actualDestPtr, 1,
0291 compatibleAlpha);
0292
// Copy back from the temporary, mirroring how it was seeded above.
0293 if (!evalToDest)
0294 {
0295 if(!alphaIsCompatible)
// Temporary holds the unscaled product: scale by the full alpha and add to dest.
0296 dest.matrix() += actualAlpha * MappedDest(actualDestPtr, dest.size());
0297 else
// Temporary was seeded with dest, so it replaces dest.
0298 dest = MappedDest(actualDestPtr, dest.size());
0299 }
0300 }
0301 }
0302 };
0303
// OnTheRight, row-major lhs, BLAS-compatible operands.
//
// run() forwards the product to general_matrix_vector_product. The kernel is
// always given an rhs increment of 1, so an rhs whose inner stride is not
// known to be 1 is first copied into a contiguous temporary. The destination
// is used in place, with its own inner stride.
0304 template<> struct gemv_dense_selector<OnTheRight,RowMajor,true>
0305 {
0306 template<typename Lhs, typename Rhs, typename Dest>
0307 static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
0308 {
0309 typedef typename Lhs::Scalar LhsScalar;
0310 typedef typename Rhs::Scalar RhsScalar;
0311 typedef typename Dest::Scalar ResScalar;
0312
0313 typedef internal::blas_traits<Lhs> LhsBlasTraits;
0314 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
0315 typedef internal::blas_traits<Rhs> RhsBlasTraits;
0316 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
0317 typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
0318
// Operands as returned by blas_traits<>::extract, held as const.
0319 typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
0320 typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
0321
// alpha combined with lhs and rhs by combine_scalar_factors.
0322 ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
0323
0324 enum {
0325
0326
// The rhs can be read in place when its inner stride is known to be 1 at
// compile time, or when its maximum size is 0.
0327 DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1 || ActualRhsTypeCleaned::MaxSizeAtCompileTime==0
0328 };
0329
// Optional embedded scratch storage; see gemv_static_vector_if.
0330 gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
0331
// Declares actualRhsPtr, offering the rhs data or the scratch buffer as
// storage. The const_cast'ed rhs pointer is offered only when DirectlyUseRhs
// holds, and the only write through actualRhsPtr in this function is in the
// !DirectlyUseRhs branch below.
// NOTE(review): presumably the macro reuses a non-null pointer and otherwise
// provides actualRhs.size() elements itself -- confirm in the macro definition.
0332 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
0333 DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
0334
0335 if(!DirectlyUseRhs)
0336 {
// Optional instrumentation hook; `size` is declared only for the plugin macro.
0337 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
0338 Index size = actualRhs.size();
0339 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
0340 #endif
// Copy the rhs into the contiguous temporary.
0341 Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
0342 }
0343
0344 typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
0345 typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
// The rhs is passed with increment 1; dest is passed with the inner stride
// of its first column. actualAlpha is passed as a ResScalar.
0346 general_matrix_vector_product
0347 <Index,LhsScalar,LhsMapper,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
0348 actualLhs.rows(), actualLhs.cols(),
0349 LhsMapper(actualLhs.data(), actualLhs.outerStride()),
0350 RhsMapper(actualRhsPtr, 1),
0351 dest.data(), dest.col(0).innerStride(),
0352 actualAlpha);
0353 }
0354 };
0355
// OnTheRight, column-major lhs, operands that are not BLAS-compatible.
//
// Fallback written with expressions only: accumulates
//   dest += (alpha * rhs(k)) * lhs.col(k)   for k = 0 .. rhs.rows()-1,
// i.e. dest += alpha * lhs * rhs built one lhs column at a time.
0356 template<> struct gemv_dense_selector<OnTheRight,ColMajor,false>
0357 {
0358 template<typename Lhs, typename Rhs, typename Dest>
0359 static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
0360 {
// Compile-time guard: instantiation fails if nested_eval<Lhs,1>::Evaluate is
// set, so lhs is used as passed in.
0361 EIGEN_STATIC_ASSERT((!nested_eval<Lhs,1>::Evaluate),EIGEN_INTERNAL_COMPILATION_ERROR_OR_YOU_MADE_A_PROGRAMMING_MISTAKE);
0362
// rhs goes through nested_eval with an access count of 1 before its
// coefficients are read in the loop.
0363 typename nested_eval<Rhs,1>::type actual_rhs(rhs);
0364 const Index size = rhs.rows();
0365 for(Index k=0; k<size; ++k)
0366 dest += (alpha*actual_rhs.coeff(k)) * lhs.col(k);
0367 }
0368 };
0369
// OnTheRight, row-major lhs, operands that are not BLAS-compatible.
//
// Fallback written with expressions only: for every row i of dest,
//   dest(i) += alpha * sum( lhs.row(i) .* rhs^T ),
// i.e. one coefficient-wise product and sum per output coefficient.
0370 template<> struct gemv_dense_selector<OnTheRight,RowMajor,false>
0371 {
0372 template<typename Lhs, typename Rhs, typename Dest>
0373 static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
0374 {
// Compile-time guard: instantiation fails if nested_eval<Lhs,1>::Evaluate is
// set, so lhs is used as passed in.
0375 EIGEN_STATIC_ASSERT((!nested_eval<Lhs,1>::Evaluate),EIGEN_INTERNAL_COMPILATION_ERROR_OR_YOU_MADE_A_PROGRAMMING_MISTAKE);
// rhs is read once per row of lhs, hence the Lhs::RowsAtCompileTime access
// count passed to nested_eval.
0376 typename nested_eval<Rhs,Lhs::RowsAtCompileTime>::type actual_rhs(rhs);
0377 const Index rows = dest.rows();
0378 for(Index i=0; i<rows; ++i)
0379 dest.coeffRef(i) += alpha * (lhs.row(i).cwiseProduct(actual_rhs.transpose())).sum();
0380 }
0381 };
0382
0383 }
0384
0385
0386
0387
0388
0389
0390
0391
0392
0393
0394
// Matrix product operator: returns a Product<Derived, OtherDerived> object
// built from derived() and other.derived(). The function body itself performs
// no arithmetic; it only validates the operand sizes at compile time.
//
// `other` is the right-hand operand.
//
// Compile-time checks: the product is valid when the inner dimensions are
// equal or when either of them is Dynamic. For an invalid product exactly one
// of three diagnostics is selected by the asserts below.
0395 template<typename Derived>
0396 template<typename OtherDerived>
0397 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
0398 const Product<Derived, OtherDerived>
0399 MatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
0400 {
0401
0402
0403
0404
0405 enum {
// Inner dimensions agree, or at least one of them is only known at run time.
0406 ProductIsValid = Derived::ColsAtCompileTime==Dynamic
0407 || OtherDerived::RowsAtCompileTime==Dynamic
0408 || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
// Both operands are compile-time vectors.
0409 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
// Result of EIGEN_PREDICATE_SAME_MATRIX_SIZE for the two operand types.
0410 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived)
0411 };
0412
0413
0414
// Invalid product of two same-sized vectors: the message points to the
// explicit dot / coefficient-wise functions.
0415 EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
0416 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
// Invalid product of two same-sized non-vector operands: the message points
// to the explicit coefficient-wise function.
0417 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
0418 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
// Any other invalid product (sizes differ).
0419 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
// Optional diagnostics: reports the selection made by internal::product_type.
0420 #ifdef EIGEN_DEBUG_PRODUCT
0421 internal::product_type<Derived,OtherDerived>::debug();
0422 #endif
0423
0424 return Product<Derived, OtherDerived>(derived(), other.derived());
0425 }
0426
0427
0428
0429
0430
0431
0432
0433
0434
0435
0436
0437
// Same as operator* except that the returned expression carries the
// LazyProduct option: returns Product<Derived,OtherDerived,LazyProduct> built
// from derived() and other.derived(). The function body performs no
// arithmetic; it applies the same three compile-time size checks as operator*
// and has no EIGEN_DEBUG_PRODUCT hook.
//
// `other` is the right-hand operand.
0438 template<typename Derived>
0439 template<typename OtherDerived>
0440 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
0441 const Product<Derived,OtherDerived,LazyProduct>
0442 MatrixBase<Derived>::lazyProduct(const MatrixBase<OtherDerived> &other) const
0443 {
0444 enum {
// Inner dimensions agree, or at least one of them is only known at run time.
0445 ProductIsValid = Derived::ColsAtCompileTime==Dynamic
0446 || OtherDerived::RowsAtCompileTime==Dynamic
0447 || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
// Both operands are compile-time vectors.
0448 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
// Result of EIGEN_PREDICATE_SAME_MATRIX_SIZE for the two operand types.
0449 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived)
0450 };
0451
0452
0453
// Invalid product of two same-sized vectors.
0454 EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
0455 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
// Invalid product of two same-sized non-vector operands.
0456 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
0457 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
// Any other invalid product (sizes differ).
0458 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
0459
0460 return Product<Derived,OtherDerived,LazyProduct>(derived(), other.derived());
0461 }
0462
0463 }
0464
0465 #endif