File indexing completed on 2025-10-30 08:39:32
0001 
0002 
0003 
0004 
0005 
0006 
0007 
0008 
0009 
0010 
0011 #ifndef EIGEN_GENERAL_PRODUCT_H
0012 #define EIGEN_GENERAL_PRODUCT_H
0013 
0014 namespace Eigen {
0015 // Size categories used by product_size_category and product_type_selector below: Large when is_large holds, Small for any other size except exactly 1 (which is represented by the literal 1).
0016 enum {
0017   Large = 2,
0018   Small = 3
0019 };
0020 
0021 
0022 
0023 
0024 // NOTE(review): presumably the size threshold below which a GEMM-style product falls back to the
0025 // coefficient-based path; the macro is not referenced anywhere in this listing, so confirm at its use site.
0026 #ifndef EIGEN_GEMM_TO_COEFFBASED_THRESHOLD
0027 // Only supplies the default (20) when the macro has not been defined beforehand.
0028 #define EIGEN_GEMM_TO_COEFFBASED_THRESHOLD 20
0029 #endif
0030 
0031 namespace internal {
0032 // Forward declaration; the specializations further down expose the chosen product mode as the enumerator `ret`.
0033 template<int Rows, int Cols, int Depth> struct product_type_selector;
0034 // Classifies one compile-time dimension (Size, bounded by MaxSize) as 1, Small or Large; the result is `value`.
0035 template<int Size, int MaxSize>
0036 struct product_size_category {
0037   enum {
0038 #ifdef EIGEN_GPU_COMPILE_PHASE
0039     // Nothing is ever treated as large during the GPU compile phase.
0040     is_large = 0,
0041 #else
0042     is_large = (MaxSize == Dynamic)
0043             || (Size >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD)
0044             || (Size == Dynamic && MaxSize >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD),
0045 #endif
0046     // A large dimension wins; otherwise a unit dimension is reported as the literal 1.
0047     value = is_large ? Large : (Size == 1 ? 1 : Small)
0048   };
0049 };
0050 // Compile-time choice of the product mode for Lhs * Rhs: rows come from Lhs, columns from Rhs, and the depth from Lhs columns / Rhs rows.
0051 template<typename Lhs, typename Rhs> struct product_type
0052 {
0053   typedef typename remove_all<Lhs>::type _Lhs;
0054   typedef typename remove_all<Rhs>::type _Rhs;
0055   enum {
0056     MaxRows = traits<_Lhs>::MaxRowsAtCompileTime,
0057     Rows    = traits<_Lhs>::RowsAtCompileTime,
0058     MaxCols = traits<_Rhs>::MaxColsAtCompileTime,
0059     Cols    = traits<_Rhs>::ColsAtCompileTime,
0060     MaxDepth = EIGEN_SIZE_MIN_PREFER_FIXED(traits<_Lhs>::MaxColsAtCompileTime,
0061                                            traits<_Rhs>::MaxRowsAtCompileTime),
0062     Depth = EIGEN_SIZE_MIN_PREFER_FIXED(traits<_Lhs>::ColsAtCompileTime,
0063                                         traits<_Rhs>::RowsAtCompileTime)
0064   };
0065 // Each dimension is reduced to a size category and the resulting triple is looked up in product_type_selector.
0066   // The category enumerators and the selector typedef are private; only value/ret form the public result.
0067   // debug() (compiled only when EIGEN_DEBUG_PRODUCT is defined) passes them to EIGEN_DEBUG_VAR by name.
0068 private:
0069   enum {
0070     rows_select = product_size_category<Rows,MaxRows>::value,
0071     cols_select = product_size_category<Cols,MaxCols>::value,
0072     depth_select = product_size_category<Depth,MaxDepth>::value
0073   };
0074   typedef product_type_selector<rows_select, cols_select, depth_select> selector;
0075 // `value` and `ret` are two names for the same selector result.
0076 public:
0077   enum {
0078     value = selector::ret,
0079     ret = selector::ret
0080   };
0081 #ifdef EIGEN_DEBUG_PRODUCT
0082   static void debug()
0083   {
0084       EIGEN_DEBUG_VAR(Rows);
0085       EIGEN_DEBUG_VAR(Cols);
0086       EIGEN_DEBUG_VAR(Depth);
0087       EIGEN_DEBUG_VAR(rows_select);
0088       EIGEN_DEBUG_VAR(cols_select);
0089       EIGEN_DEBUG_VAR(depth_select);
0090       EIGEN_DEBUG_VAR(value);
0091   }
0092 #endif
0093 };
0094 // Lookup table: product_type_selector<rows, cols, depth>::ret, where each argument is a size category (1, Small or Large).
0095 //  - depth == 1: the generic fallback is OuterProduct; a unit row/column side, or any Small side, selects LazyCoeffBasedProductMode instead.
0096 //  - rows == cols == 1: InnerProduct for every depth (the <1,1,1> full specialization settles the overlap between the partial ones).
0097 //  - depth != 1, exactly one unit side: GemvProduct when both the other side and the depth are Large, CoeffBasedProductMode otherwise.
0098 //  - depth != 1, no unit side: GemmProduct when the depth is Large or both sides are Large, CoeffBasedProductMode otherwise.
0099 template<int M, int N>  struct product_type_selector<M,N,1>              { enum { ret = OuterProduct }; };
0100 template<int M>         struct product_type_selector<M, 1, 1>            { enum { ret = LazyCoeffBasedProductMode }; };
0101 template<int N>         struct product_type_selector<1, N, 1>            { enum { ret = LazyCoeffBasedProductMode }; };
0102 template<int Depth>     struct product_type_selector<1,    1,    Depth>  { enum { ret = InnerProduct }; };
0103 template<>              struct product_type_selector<1,    1,    1>      { enum { ret = InnerProduct }; };
0104 template<>              struct product_type_selector<Small,1,    Small>  { enum { ret = CoeffBasedProductMode }; };
0105 template<>              struct product_type_selector<1,    Small,Small>  { enum { ret = CoeffBasedProductMode }; };
0106 template<>              struct product_type_selector<Small,Small,Small>  { enum { ret = CoeffBasedProductMode }; };
0107 template<>              struct product_type_selector<Small, Small, 1>    { enum { ret = LazyCoeffBasedProductMode }; };
0108 template<>              struct product_type_selector<Small, Large, 1>    { enum { ret = LazyCoeffBasedProductMode }; };
0109 template<>              struct product_type_selector<Large, Small, 1>    { enum { ret = LazyCoeffBasedProductMode }; };
0110 template<>              struct product_type_selector<1,    Large,Small>  { enum { ret = CoeffBasedProductMode }; };
0111 template<>              struct product_type_selector<1,    Large,Large>  { enum { ret = GemvProduct }; };
0112 template<>              struct product_type_selector<1,    Small,Large>  { enum { ret = CoeffBasedProductMode }; };
0113 template<>              struct product_type_selector<Large,1,    Small>  { enum { ret = CoeffBasedProductMode }; };
0114 template<>              struct product_type_selector<Large,1,    Large>  { enum { ret = GemvProduct }; };
0115 template<>              struct product_type_selector<Small,1,    Large>  { enum { ret = CoeffBasedProductMode }; };
0116 template<>              struct product_type_selector<Small,Small,Large>  { enum { ret = GemmProduct }; };
0117 template<>              struct product_type_selector<Large,Small,Large>  { enum { ret = GemmProduct }; };
0118 template<>              struct product_type_selector<Small,Large,Large>  { enum { ret = GemmProduct }; };
0119 template<>              struct product_type_selector<Large,Large,Large>  { enum { ret = GemmProduct }; };
0120 template<>              struct product_type_selector<Large,Small,Small>  { enum { ret = CoeffBasedProductMode }; };
0121 template<>              struct product_type_selector<Small,Large,Small>  { enum { ret = CoeffBasedProductMode }; };
0122 template<>              struct product_type_selector<Large,Large,Small>  { enum { ret = GemmProduct }; };
0123 
0124 } 
0125 
0126 
0127 
0128 
0129 
0130 
0131 
0132 
0133 
0134 
0135 
0136 
0137 
0138 
0139 
0140 
0141 
0142 
0143 
0144 
0145 
0146 
0147 
0148 
0149 
0150 
0151 
0152 namespace internal {
0153 // Forward declaration of the matrix-vector product dispatcher; its specializations key on side, storage order and BLAS compatibility.
0154 template<int Side, int StorageOrder, bool BlasCompatible>
0155 struct gemv_dense_selector;
0156 
0157 } // end namespace internal
0158 
0159 namespace internal {
0160 // Helper giving the gemv selectors an optional fixed-size scratch buffer through data(); `Cond` says whether a buffer is wanted at all.
0161 template<typename Scalar,int Size,int MaxSize,bool Cond> struct gemv_static_vector_if;
0162 // Cond == false: no storage. data() is guarded by an internal assertion ("should never be called") and yields a null pointer.
0163 template<typename Scalar,int Size,int MaxSize>
0164 struct gemv_static_vector_if<Scalar,Size,MaxSize,false>
0165 {
0166   EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar* data() { eigen_internal_assert(false && "should never be called"); return 0; }
0167 };
0168 // Cond == true but MaxSize == Dynamic: no static storage either, data() returns a null pointer (callers in this file hand it to ei_declare_aligned_stack_constructed_variable).
0169 template<typename Scalar,int Size>
0170 struct gemv_static_vector_if<Scalar,Size,Dynamic,true>
0171 {
0172   EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar* data() { return 0; }
0173 };
0174 // Cond == true with a compile-time size bound: owns a plain_array of EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize) scalars and returns its address from data().
0175 template<typename Scalar,int Size,int MaxSize>
0176 struct gemv_static_vector_if<Scalar,Size,MaxSize,true>
0177 {
0178   enum {
0179     ForceAlignment  = internal::packet_traits<Scalar>::Vectorizable,
0180     PacketSize      = internal::packet_traits<Scalar>::size
0181   };
0182   #if EIGEN_MAX_STATIC_ALIGN_BYTES!=0
0183   internal::plain_array<Scalar,EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize),0,EIGEN_PLAIN_ENUM_MIN(AlignedMax,PacketSize)> m_data;
0184   EIGEN_STRONG_INLINE Scalar* data() { return m_data.array; }
0185   #else
0186   // No static alignment available (EIGEN_MAX_STATIC_ALIGN_BYTES == 0): over-allocate by EIGEN_MAX_ALIGN_BYTES elements when ForceAlignment is set,
0187   // then return the next EIGEN_MAX_ALIGN_BYTES-aligned address after the start of the array (address rounded down, plus one alignment step).
0188   internal::plain_array<Scalar,EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize)+(ForceAlignment?EIGEN_MAX_ALIGN_BYTES:0),0> m_data;
0189   EIGEN_STRONG_INLINE Scalar* data() {
0190     return ForceAlignment
0191             ? reinterpret_cast<Scalar*>((internal::UIntPtr(m_data.array) & ~(std::size_t(EIGEN_MAX_ALIGN_BYTES-1))) + EIGEN_MAX_ALIGN_BYTES)
0192             : m_data.array;
0193   }
0194   #endif
0195 };
0196 
0197 // Left-side products are reduced to the right-side case: operands are swapped and transposed, and the result is written through a transposed view of dest.
0198 template<int StorageOrder, bool BlasCompatible>
0199 struct gemv_dense_selector<OnTheLeft, StorageOrder, BlasCompatible> {
0200   template<typename Lhs, typename Rhs, typename Dest>
0201   static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
0202   {
0203     // The right-side selector is instantiated with the opposite storage order.
0204     enum { FlippedOrder = (StorageOrder == RowMajor) ? ColMajor : RowMajor };
0205     typedef gemv_dense_selector<OnTheRight, FlippedOrder, BlasCompatible> RightSideSelector;
0206     Transpose<Dest> transposedDest(dest);
0207     RightSideSelector::run(rhs.transpose(), lhs.transpose(), transposedDest, alpha);
0208   }
0209 };
0210 // BLAS-compatible path for a column-major lhs applied on the right: hands lhs, rhs and the combined alpha to general_matrix_vector_product, writing into dest directly when possible and through a temporary otherwise.
0211 template<> struct gemv_dense_selector<OnTheRight,ColMajor,true>
0212 {
0213   template<typename Lhs, typename Rhs, typename Dest>
0214   static inline void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
0215   {
0216     typedef typename Lhs::Scalar   LhsScalar;
0217     typedef typename Rhs::Scalar   RhsScalar;
0218     typedef typename Dest::Scalar  ResScalar;
0219     typedef typename Dest::RealScalar  RealScalar;
0220     // NOTE(review): blas_traits presumably peels scalar factors / conjugation off the operands so their storage can be addressed directly — confirm in its definition.
0221     typedef internal::blas_traits<Lhs> LhsBlasTraits;
0222     typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
0223     typedef internal::blas_traits<Rhs> RhsBlasTraits;
0224     typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
0225   // MappedDest: dynamic column-vector view over a raw ResScalar buffer, used for the temporary below.
0226     typedef Map<Matrix<ResScalar,Dynamic,1>, EIGEN_PLAIN_ENUM_MIN(AlignedMax,internal::packet_traits<ResScalar>::size)> MappedDest;
0227 
0228     ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
0229     ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
0230 // Scalar factor obtained from alpha and both operands via combine_scalar_factors.
0231     ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
0232 
0233     // Work on Dest itself if it is a vector at compile time, otherwise on its column expression type.
0234     typedef typename conditional<Dest::IsVectorAtCompileTime, Dest, typename Dest::ColXpr>::type ActualDest;
0235 
0236     enum {
0237       // EvalToDestAtCompileTime: dest has unit inner stride, matching the result stride of 1 passed to the kernel.
0238       // ComplexByReal: complex lhs scalar with a non-complex rhs scalar. MightCannotUseDest: a temporary may be needed (never for a dest that is empty at compile time).
0239       EvalToDestAtCompileTime = (ActualDest::InnerStrideAtCompileTime==1),
0240       ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
0241       MightCannotUseDest = ((!EvalToDestAtCompileTime) || ComplexByReal) && (ActualDest::MaxSizeAtCompileTime!=0)
0242     };
0243 
0244     typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
0245     typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
0246     RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
0247 
0248     if(!MightCannotUseDest)
0249     {
0250       // Dest is known at compile time to be directly usable:
0251       // run the kernel on dest.data() without any temporary.
0252       general_matrix_vector_product
0253           <Index,LhsScalar,LhsMapper,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
0254           actualLhs.rows(), actualLhs.cols(),
0255           LhsMapper(actualLhs.data(), actualLhs.outerStride()),
0256           RhsMapper(actualRhs.data(), actualRhs.innerStride()),
0257           dest.data(), 1,
0258           compatibleAlpha);
0259     }
0260     else
0261     {
0262       gemv_static_vector_if<ResScalar,ActualDest::SizeAtCompileTime,ActualDest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
0263 // alphaIsCompatible is false only for ComplexByReal with an actualAlpha whose imaginary part is non-zero.
0264       const bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
0265       const bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
0266 // Result buffer: dest itself when evalToDest, else static_dest's storage. NOTE(review): the macro presumably allocates dest.size() scalars when handed a null pointer — confirm in its definition.
0267       ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
0268                                                     evalToDest ? dest.data() : static_dest.data());
0269 // Initialise the temporary: zeros (with alpha replaced by 1 and applied afterwards), or a copy of dest (the kernel presumably accumulates into its result buffer).
0270       if(!evalToDest)
0271       {
0272         #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
0273         Index size = dest.size();
0274         EIGEN_DENSE_STORAGE_CTOR_PLUGIN
0275         #endif
0276         if(!alphaIsCompatible)
0277         {
0278           MappedDest(actualDestPtr, dest.size()).setZero();
0279           compatibleAlpha = RhsScalar(1);
0280         }
0281         else
0282           MappedDest(actualDestPtr, dest.size()) = dest;
0283       }
0284 // Same kernel call as in the direct branch, but writing to actualDestPtr.
0285       general_matrix_vector_product
0286           <Index,LhsScalar,LhsMapper,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
0287           actualLhs.rows(), actualLhs.cols(),
0288           LhsMapper(actualLhs.data(), actualLhs.outerStride()),
0289           RhsMapper(actualRhs.data(), actualRhs.innerStride()),
0290           actualDestPtr, 1,
0291           compatibleAlpha);
0292 // Fold the temporary back into dest: scaled accumulation when alpha was deferred, plain copy otherwise.
0293       if (!evalToDest)
0294       {
0295         if(!alphaIsCompatible)
0296           dest.matrix() += actualAlpha * MappedDest(actualDestPtr, dest.size());
0297         else
0298           dest = MappedDest(actualDestPtr, dest.size());
0299       }
0300     }
0301   }
0302 };
0303 // BLAS-compatible path for a row-major lhs applied on the right: feeds general_matrix_vector_product, copying the rhs into a unit-stride buffer first when it cannot be used directly.
0304 template<> struct gemv_dense_selector<OnTheRight,RowMajor,true>
0305 {
0306   template<typename Lhs, typename Rhs, typename Dest>
0307   static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
0308   {
0309     typedef typename Lhs::Scalar   LhsScalar;
0310     typedef typename Rhs::Scalar   RhsScalar;
0311     typedef typename Dest::Scalar  ResScalar;
0312     // NOTE(review): blas_traits presumably peels scalar factors / conjugation off the operands so their storage can be addressed directly — confirm in its definition.
0313     typedef internal::blas_traits<Lhs> LhsBlasTraits;
0314     typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
0315     typedef internal::blas_traits<Rhs> RhsBlasTraits;
0316     typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
0317     typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
0318 
0319     typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
0320     typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
0321 // Scalar factor obtained from alpha and both operands via combine_scalar_factors.
0322     ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
0323 
0324     enum {
0325       // The rhs data pointer is used as is when its inner stride is 1 at compile time (or when it is empty at compile time);
0326       // otherwise its coefficients are first copied into a temporary buffer.
0327       DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1 || ActualRhsTypeCleaned::MaxSizeAtCompileTime==0
0328     };
0329 // static_rhs is instantiated with Cond = !DirectlyUseRhs, i.e. it only carries storage when a copy is needed.
0330     gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
0331 
0332     ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
0333         DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
0334 // Copy the rhs into the buffer through a Map of its plain object type.
0335     if(!DirectlyUseRhs)
0336     {
0337       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
0338       Index size = actualRhs.size();
0339       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
0340       #endif
0341       Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
0342     }
0343 // The rhs is passed with stride 1; the result stride is dest's inner stride, taken from dest.col(0).
0344     typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
0345     typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
0346     general_matrix_vector_product
0347         <Index,LhsScalar,LhsMapper,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
0348         actualLhs.rows(), actualLhs.cols(),
0349         LhsMapper(actualLhs.data(), actualLhs.outerStride()),
0350         RhsMapper(actualRhsPtr, 1),
0351         dest.data(), dest.col(0).innerStride(), 
0352         actualAlpha);
0353   }
0354 };
0355 // Non-BLAS fallback for a column-major lhs: dest accumulates one scaled lhs column per rhs coefficient.
0356 template<>
0357 struct gemv_dense_selector<OnTheRight, ColMajor, false> {
0358   template<typename Lhs, typename Rhs, typename Dest>
0359   static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
0360   {
0361     // Compile-time guard: the lhs must not be an expression that nested_eval would evaluate.
0362     EIGEN_STATIC_ASSERT((!nested_eval<Lhs,1>::Evaluate),EIGEN_INTERNAL_COMPILATION_ERROR_OR_YOU_MADE_A_PROGRAMMING_MISTAKE);
0363     typename nested_eval<Rhs,1>::type rhsNested(rhs);
0364     const Index depth = rhs.rows();
0365     for(Index col = 0; col < depth; ++col)
0366       dest += (alpha*rhsNested.coeff(col)) * lhs.col(col);
0367   }
0368 };
0369 // Non-BLAS fallback for a row-major lhs: each dest coefficient accumulates alpha times the sum of the coefficient-wise product of an lhs row with the rhs.
0370 template<>
0371 struct gemv_dense_selector<OnTheRight, RowMajor, false> {
0372   template<typename Lhs, typename Rhs, typename Dest>
0373   static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
0374   {
0375     EIGEN_STATIC_ASSERT((!nested_eval<Lhs,1>::Evaluate),EIGEN_INTERNAL_COMPILATION_ERROR_OR_YOU_MADE_A_PROGRAMMING_MISTAKE);
0376     typename nested_eval<Rhs,Lhs::RowsAtCompileTime>::type rhsNested(rhs);
0377     // One pass over the rows of dest; the row count is read once up front.
0378     for(Index r = 0, rowCount = dest.rows(); r < rowCount; ++r)
0379       dest.coeffRef(r) += alpha * (lhs.row(r).cwiseProduct(rhsNested.transpose())).sum();
0380   }
0381 };
0382 
0383 } 
0384 
0385 
0386 
0387 
0388 
0389 
0390 
0391 
0392 
0393 
0394 // Matrix product operator: checks the operand shapes at compile time and returns a Product expression of *this and other.
0395 template<typename Derived>
0396 template<typename OtherDerived>
0397 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
0398 const Product<Derived, OtherDerived>
0399 MatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
0400 {
0401   // Compile-time shape facts used by the assertions below.
0402   enum {
0403     // The inner dimensions agree, or at least one of them is Dynamic.
0404     ProductIsValid = int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime)
0405                    || Derived::ColsAtCompileTime==Dynamic
0406                    || OtherDerived::RowsAtCompileTime==Dynamic,
0407     SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived),
0408     AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime
0409   };
0410   typedef Product<Derived, OtherDerived> ProductXpr;
0411 
0412   // An invalid product is reported with the most specific message available:
0413   // equal-sized vectors first, then equal-sized non-vector operands,
0414   // and finally the generic message for every remaining mismatch.
0415   EIGEN_STATIC_ASSERT(ProductIsValid || !AreVectors || !SameSizes,
0416     INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
0417   EIGEN_STATIC_ASSERT(ProductIsValid || AreVectors || !SameSizes,
0418     INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
0419   EIGEN_STATIC_ASSERT(SameSizes || ProductIsValid, INVALID_MATRIX_PRODUCT)
0420 #ifdef EIGEN_DEBUG_PRODUCT
0421   internal::product_type<Derived,OtherDerived>::debug();
0422 #endif
0423 
0424   return ProductXpr(derived(), other.derived());
0425 }
0426 
0427 
0428 
0429 
0430 
0431 
0432 
0433 
0434 
0435 
0436 
0437 // Same compile-time shape checks as operator*, but returns a Product expression tagged with LazyProduct.
0438 template<typename Derived>
0439 template<typename OtherDerived>
0440 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
0441 const Product<Derived,OtherDerived,LazyProduct>
0442 MatrixBase<Derived>::lazyProduct(const MatrixBase<OtherDerived> &other) const
0443 {
0444   enum {
0445     ProductIsValid = int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime)
0446                    || Derived::ColsAtCompileTime==Dynamic
0447                    || OtherDerived::RowsAtCompileTime==Dynamic,
0448     SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived),
0449     AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime
0450   };
0451   typedef Product<Derived,OtherDerived,LazyProduct> LazyProductXpr;
0452   // An invalid product is reported with the most specific message available:
0453   // equal-sized vectors, then equal-sized non-vector operands, then the generic message.
0454   EIGEN_STATIC_ASSERT(ProductIsValid || !AreVectors || !SameSizes,
0455     INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
0456   EIGEN_STATIC_ASSERT(ProductIsValid || AreVectors || !SameSizes,
0457     INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
0458   EIGEN_STATIC_ASSERT(SameSizes || ProductIsValid, INVALID_MATRIX_PRODUCT)
0459 
0460   return LazyProductXpr(derived(), other.derived());
0461 }
0462 
0463 } 
0464 
0465 #endif