Warning, file /include/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
0011 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
0012
0013
0014 namespace Eigen {
0015 namespace internal {
0016
// Sharding strategies for TensorContractionBlocking's `int ShardingType`
// template parameter. This is deliberately a plain (unscoped) enum:
// - its enumerators must convert implicitly to int, because ShardByCol is used
//   as that parameter's default value;
// - it is compared against the parameter inside the constructor.
enum { ShardByRow = 0, ShardByCol = 1 };
0021
0022
0023
0024 template<typename ResScalar, typename LhsScalar, typename RhsScalar, typename StorageIndex, int ShardingType = ShardByCol>
0025 class TensorContractionBlocking {
0026 public:
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042 #if !defined(EIGEN_HIPCC)
0043 EIGEN_DEVICE_FUNC
0044 #endif
0045 TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n, StorageIndex num_threads = 1) :
0046 kc_(k), mc_(m), nc_(n)
0047 {
0048 if (ShardingType == ShardByCol) {
0049 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, mc_, nc_, num_threads);
0050 }
0051 else {
0052 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads);
0053 }
0054
0055 const int rhs_packet_size = internal::packet_traits<RhsScalar>::size;
0056 kc_ = (rhs_packet_size <= 8 || kc_ <= rhs_packet_size) ?
0057 kc_ : (kc_ / rhs_packet_size) * rhs_packet_size;
0058 }
0059
0060 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }
0061 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; }
0062 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; }
0063
0064 private:
0065 StorageIndex kc_;
0066 StorageIndex mc_;
0067 StorageIndex nc_;
0068 };
0069
0070 }
0071 }
0072
0073 #endif