// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
// Copyright (C) 2015 Navdeep Jaitly <ndjaitly@google.com>
// Copyright (C) 2014 Eric Martin <eric@ericmart.in>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H

#if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)

namespace Eigen {

template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
                               const Index m_size, const Index n_size, const Index k_size) {

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // declare and initialize 64 registers for output 8x8 block

  // prefetch registers
  Scalar lhs_pf0;
  Scalar lhs_pf1;
  Scalar lhs_pf2;
  Scalar lhs_pf3;
  Scalar lhs_pf4;
  Scalar lhs_pf5;
  Scalar lhs_pf6;
  Scalar lhs_pf7;

  Scalar rhs_pf0;
  Scalar rhs_pf1;
  Scalar rhs_pf2;
  Scalar rhs_pf3;
  Scalar rhs_pf4;
  Scalar rhs_pf5;
  Scalar rhs_pf6;
  Scalar rhs_pf7;

  // shared memory is formatted
  // (contract idx in block, nocontract idx in block, block idx)
  // where block idx is column major. This transposition limits the number of
  // bank conflicts when reading the LHS. The core idea is that since the contracting
  // index is shared by both sides, then the contracting index should be in threadIdx.x.

  // On the LHS, we pad each row inside of each block with an extra element. This makes
  // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
  // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.

  // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
  // conflicts on writes and also none on reads.
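  // [Editor's note, not part of the original source: a worked instance of the
  // padding argument above, assuming a 4-byte Scalar so that one element maps
  // to one 32-bit shared-memory bank. The eight threads that differ only in
  // threadIdx.x write 9 elements apart, and since 9 is odd their banks
  // (9 * x) % 32 are all distinct:
  //
  //   x      :  0   1   2   3   4   5   6   7
  //   offset :  0   9  18  27  36  45  54  63
  //   bank   :  0   9  18  27   4  13  22  31  ]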

  // storage indices
  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;

  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;

  // in the loading code, the following variables are important:
  // threadIdx.x: the vertical position in an 8x8 block
  // threadIdx.y: the vertical index of the 8x8 block in the grid
  // threadIdx.z: the horizontal position in an 8x8 block
  // k: the horizontal index of the 8x8 block in the grid
  //
  // The k parameter is implicit (it was the loop counter for a loop that went
  // from 0 to 7), but that loop is now unrolled in the code below.
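  // [Editor's note, not part of the original source: a worked instance of the
  // index decoding above. Thread (x=3, y=2, z=5) has
  // load_idx_vert = 3 + 8 * 2 = 19, so it loads LHS row base_m + 19, and its
  // eight LHS columns are base_k + 5 + 8 * n for n = 0..7, i.e.
  // base_k + 5, base_k + 13, ..., base_k + 61.]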

  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;

#define prefetchIntoRegisters(base_k)                           \
  {                                                             \
    lhs_pf0 = conv(0);                                          \
    lhs_pf1 = conv(0);                                          \
    lhs_pf2 = conv(0);                                          \
    lhs_pf3 = conv(0);                                          \
    lhs_pf4 = conv(0);                                          \
    lhs_pf5 = conv(0);                                          \
    lhs_pf6 = conv(0);                                          \
    lhs_pf7 = conv(0);                                          \
                                                                \
    rhs_pf0 = conv(0);                                          \
    rhs_pf1 = conv(0);                                          \
    rhs_pf2 = conv(0);                                          \
    rhs_pf3 = conv(0);                                          \
    rhs_pf4 = conv(0);                                          \
    rhs_pf5 = conv(0);                                          \
    rhs_pf6 = conv(0);                                          \
    rhs_pf7 = conv(0);                                          \
                                                                \
    if (!needs_edge_check || lhs_vert < m_size) {               \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8;   \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8;   \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8;   \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8;   \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8;   \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8;   \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8;   \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8;   \
                                                                \
      if (!needs_edge_check || lhs_horiz_7 < k_size) {          \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7);                   \
      } else if (lhs_horiz_6 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
      } else if (lhs_horiz_5 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
      } else if (lhs_horiz_4 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
      } else if (lhs_horiz_3 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
      } else if (lhs_horiz_2 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
      } else if (lhs_horiz_1 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
      } else if (lhs_horiz_0 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
      }                                                         \
    }                                                           \
                                                                \
    const Index rhs_vert = base_k + load_idx_vert;              \
    if (!needs_edge_check || rhs_vert < k_size) {               \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8;   \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8;   \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8;   \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8;   \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8;   \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8;   \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8;   \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8;   \
                                                                \
      if (rhs_horiz_7 < n_size) {                               \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7);                   \
      } else if (rhs_horiz_6 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
      } else if (rhs_horiz_5 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
      } else if (rhs_horiz_4 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
      } else if (rhs_horiz_3 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
      } else if (rhs_horiz_2 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
      } else if (rhs_horiz_1 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
      } else if (rhs_horiz_0 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
      }                                                         \
    }                                                           \
  }                                                             \

#define writeRegToShmem(_)                      \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0;         \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0;         \
                                                \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1;         \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1;         \
                                                \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2;         \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2;         \
                                                \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3;         \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3;         \
                                                \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4;         \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4;         \
                                                \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5;         \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5;         \
                                                \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6;         \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6;         \
                                                \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7;         \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7;         \

  // declare and initialize result array
#define res(i, j) _res_##i##j
#define initResultRow(i)                        \
  Scalar res(i, 0) = conv(0);                   \
  Scalar res(i, 1) = conv(0);                   \
  Scalar res(i, 2) = conv(0);                   \
  Scalar res(i, 3) = conv(0);                   \
  Scalar res(i, 4) = conv(0);                   \
  Scalar res(i, 5) = conv(0);                   \
  Scalar res(i, 6) = conv(0);                   \
  Scalar res(i, 7) = conv(0);                   \

  internal::scalar_cast_op<int, Scalar> conv;
  initResultRow(0);
  initResultRow(1);
  initResultRow(2);
  initResultRow(3);
  initResultRow(4);
  initResultRow(5);
  initResultRow(6);
  initResultRow(7);
#undef initResultRow

  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    // wait for previous iteration to finish with shmem. Despite common sense,
    // the code is a bit faster with this here than at the bottom of the loop
    __syncthreads();

    prefetchIntoRegisters(base_k);
    writeRegToShmem();

    #undef prefetchIntoRegisters
    #undef writeRegToShmem

    // wait for shared mem packing to be done before starting computation
    __syncthreads();

    // compute 8x8 matrix product by outer product. This involves packing one column
    // of LHS and one row of RHS into registers (takes 16 registers).

#define lcol(i) _lcol##i
    Scalar lcol(0);
    Scalar lcol(1);
    Scalar lcol(2);
    Scalar lcol(3);
    Scalar lcol(4);
    Scalar lcol(5);
    Scalar lcol(6);
    Scalar lcol(7);

#define rrow(j) _rrow##j
    Scalar rrow(0);
    Scalar rrow(1);
    Scalar rrow(2);
    Scalar rrow(3);
    Scalar rrow(4);
    Scalar rrow(5);
    Scalar rrow(6);
    Scalar rrow(7);

    // Now x corresponds to k, y to m, and z to n
    const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
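
    // [Editor's note, not part of the original source: lhs_element(i, j)
    // selects the 72-element padded block at position (i, j) of the 8x8 block
    // grid described above; lhs_block already carries this thread's in-block
    // offset threadIdx.x + 9 * threadIdx.y.]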

#define loadData(i, j)                          \
    lcol(0) = lhs_element(0, j);               \
    rrow(0) = rhs_element(i, 0);               \
    lcol(1) = lhs_element(1, j);               \
    rrow(1) = rhs_element(i, 1);               \
    lcol(2) = lhs_element(2, j);               \
    rrow(2) = rhs_element(i, 2);               \
    lcol(3) = lhs_element(3, j);               \
    rrow(3) = rhs_element(i, 3);               \
    lcol(4) = lhs_element(4, j);               \
    rrow(4) = rhs_element(i, 4);               \
    lcol(5) = lhs_element(5, j);               \
    rrow(5) = rhs_element(i, 5);               \
    lcol(6) = lhs_element(6, j);               \
    rrow(6) = rhs_element(i, 6);               \
    lcol(7) = lhs_element(7, j);               \
    rrow(7) = rhs_element(i, 7);               \

#define computeCol(j)                           \
    res(0, j) += lcol(0) * rrow(j);             \
    res(1, j) += lcol(1) * rrow(j);             \
    res(2, j) += lcol(2) * rrow(j);             \
    res(3, j) += lcol(3) * rrow(j);             \
    res(4, j) += lcol(4) * rrow(j);             \
    res(5, j) += lcol(5) * rrow(j);             \
    res(6, j) += lcol(6) * rrow(j);             \
    res(7, j) += lcol(7) * rrow(j);             \

#define computePass(i)                          \
    loadData(i, i);                             \
                                                \
    computeCol(0);                              \
    computeCol(1);                              \
    computeCol(2);                              \
    computeCol(3);                              \
    computeCol(4);                              \
    computeCol(5);                              \
    computeCol(6);                              \
    computeCol(7);                              \

    computePass(0);
    computePass(1);
    computePass(2);
    computePass(3);
    computePass(4);
    computePass(5);
    computePass(6);
    computePass(7);

#undef lcol
#undef rrow
#undef lhs_element
#undef rhs_element
#undef loadData
#undef computeCol
#undef computePass
  } // end loop over k

  // we've now iterated over all of the large (i.e. width 64) k blocks and
  // accumulated results in registers. At this point thread (x, y, z) holds
  // the partial sums for the contracting indices that are congruent to x
  // mod 8 (recall that x corresponds to k). To compute the final output, we
  // need to reduce the 8 threads over x by summation.
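  // [Editor's note, not part of the original source: with XOR masks 1, 2 and
  // 4, the butterfly below exchanges partial sums between warp lanes whose
  // IDs differ only in their three low bits, which here encode threadIdx.x.
  // E.g. lane 5 (0b101) pairs with lanes 4, 7 and 1 in successive rounds, so
  // after the three rounds every lane holds the full sum over x.]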
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
#else
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
#endif

#define reduceRow(i, mask)                      \
  shuffleInc(i, 0, mask);                       \
  shuffleInc(i, 1, mask);                       \
  shuffleInc(i, 2, mask);                       \
  shuffleInc(i, 3, mask);                       \
  shuffleInc(i, 4, mask);                       \
  shuffleInc(i, 5, mask);                       \
  shuffleInc(i, 6, mask);                       \
  shuffleInc(i, 7, mask);                       \

#define reduceMatrix(mask)                      \
  reduceRow(0, mask);                           \
  reduceRow(1, mask);                           \
  reduceRow(2, mask);                           \
  reduceRow(3, mask);                           \
  reduceRow(4, mask);                           \
  reduceRow(5, mask);                           \
  reduceRow(6, mask);                           \
  reduceRow(7, mask);                           \

  // actually perform the reduction, now each thread of index (_, y, z)
  // contains the correct values in its registers that belong in the output
  // block
  reduceMatrix(1);
  reduceMatrix(2);
  reduceMatrix(4);

#undef shuffleInc
#undef reduceRow
#undef reduceMatrix

  // now we need to copy the 64 values into main memory. We can't split work
  // among threads because all variables are in registers. There are two ways
  // to do this:
  // (1) have 1 thread do 64 writes from registers into global memory
  // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
  //     each do 8 writes into global memory. We can just overwrite the shared
  //     memory from the problem we just solved.
  // (2) is slightly faster than (1) due to less branching and more ILP

  // TODO: won't yield much gain, but could just use currently unused shared mem
  //       and then we won't have to sync
  // wait for shared mem to be out of use
  __syncthreads();

#define writeResultShmem(i, j)                                          \
  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \

#define writeRow(i)                             \
  writeResultShmem(i, 0);                       \
  writeResultShmem(i, 1);                       \
  writeResultShmem(i, 2);                       \
  writeResultShmem(i, 3);                       \
  writeResultShmem(i, 4);                       \
  writeResultShmem(i, 5);                       \
  writeResultShmem(i, 6);                       \
  writeResultShmem(i, 7);                       \

  if (threadIdx.x == 0) {
    writeRow(0);
    writeRow(1);
    writeRow(2);
    writeRow(3);
    writeRow(4);
    writeRow(5);
    writeRow(6);
    writeRow(7);
  }
#undef writeResultShmem
#undef writeRow

  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);

  if (threadIdx.x < max_i_write) {
    if (max_j_write == 8) {
      // TODO: can I trade bank conflicts for coalesced writes?
      Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
      Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
      Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
      Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
      Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
      Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
      Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
      Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];

      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
    } else {
#pragma unroll 7
      for (int j = 0; j < max_j_write; j++) {
        Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
        output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
      }
    }
  }
#undef res
}


template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(512, 1)
#else
__launch_bounds__(512)
#endif
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}
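
// [Editor's note, not part of the original Eigen source: a minimal sketch of
// how a host-side caller might launch the kernel above. The shapes are an
// assumption inferred from this file: each block computes one 64x64 output
// tile, and __launch_bounds__(512) matches an 8x8x8 thread block. `lhs`,
// `rhs`, `out` and `stream` are assumed to be pre-built mapper objects and a
// valid GPU stream.
//
//   const Index m_blocks = (m_size + 63) / 64;
//   const Index n_blocks = (n_size + 63) / 64;
//   const dim3 num_blocks(m_blocks, n_blocks, 1);
//   const dim3 block_size(8, 8, 8);
//   EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>
//       <<<num_blocks, block_size, 0, stream>>>(lhs, rhs, out, m_size, n_size, k_size);
// ]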


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][16],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {

  // prefetch registers
  float4 lhs_pf0, rhs_pf0;

  float4 results[4];
  for (int i=0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }

#define prefetch_lhs(reg, row, col)                                  \
    if (!CHECK_LHS_BOUNDARY) {                                       \
      if (col < k_size) {                                            \
        reg = lhs.template loadPacket<float4,Unaligned>(row, col);   \
      }                                                              \
    } else {                                                         \
      if (col < k_size) {                                            \
        if (row + 3 < m_size) {                                      \
          reg = lhs.template loadPacket<float4,Unaligned>(row, col); \
        } else if (row + 2 < m_size) {                               \
          reg.x = lhs(row + 0, col);                                 \
          reg.y = lhs(row + 1, col);                                 \
          reg.z = lhs(row + 2, col);                                 \
        } else if (row + 1 < m_size) {                               \
          reg.x = lhs(row + 0, col);                                 \
          reg.y = lhs(row + 1, col);                                 \
        } else if (row < m_size) {                                   \
          reg.x = lhs(row + 0, col);                                 \
        }                                                            \
      }                                                              \
    }                                                                \

  Index lhs_vert = base_m+threadIdx.x*4;

  for (Index k = 0; k < k_size; k += 16) {

    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    Index lhs_horiz = threadIdx.y+k;
    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)

    Index rhs_vert = k+(threadIdx.x%4)*4;
    Index rhs_horiz0 = (threadIdx.x>>2)+threadIdx.y*4+base_n;

    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    float x1, x2;
    // the following could be done as a bitwise operation... some day.
    if((threadIdx.x%8) < 4) {
      x1 = rhs_pf0.y;
      x2 = rhs_pf0.w;
    } else {
      x1 = rhs_pf0.x;
      x2 = rhs_pf0.z;
    }
    #if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
    #else
    x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
    x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
    #endif
    if((threadIdx.x%8) < 4) {
      rhs_pf0.y = x1;
      rhs_pf0.w = x2;
    } else {
      rhs_pf0.x = x1;
      rhs_pf0.z = x2;
    }

    // We have 64 features.
    // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
    // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
    // ...
    // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
    // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
    // ...
    rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2][threadIdx.x%8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2+32][threadIdx.x%8] = make_float2(rhs_pf0.z, rhs_pf0.w);

    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // ...
    // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63)
    // ...

    lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y+16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);


#define add_vals(fl1, fl2, fr1, fr2)\
    results[0].x += fl1.x * fr1.x;\
    results[0].y += fl1.y * fr1.x;\
    results[0].z += fl2.x * fr1.x;\
    results[0].w += fl2.y * fr1.x;\
\
    results[1].x += fl1.x * fr1.y;\
    results[1].y += fl1.y * fr1.y;\
    results[1].z += fl2.x * fr1.y;\
    results[1].w += fl2.y * fr1.y;\
\
    results[2].x += fl1.x * fr2.x;\
    results[2].y += fl1.y * fr2.x;\
    results[2].z += fl2.x * fr2.x;\
    results[2].w += fl2.y * fr2.x;\
\
    results[3].x += fl1.x * fr2.y;\
    results[3].y += fl1.y * fr2.y;\
    results[3].z += fl2.x * fr2.y;\
    results[3].w += fl2.y * fr2.y;\

    __syncthreads();

    // Do the multiplies.
    #pragma unroll
    for (int koff = 0; koff < 16; koff ++) {
      // 32 x threads.
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      int start_feature = threadIdx.y * 4;
      float2 fr1 = rhs_shmem2[(start_feature>>1) + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
      float2 fr2 = rhs_shmem2[(start_feature>>1) + 1 + 32*((koff%4)/2)][koff/4 + (koff%2)*4];

      add_vals(fl1, fl2, fr1, fr2)
    }
    __syncthreads();
  }

#undef prefetch_lhs
#undef add_vals

  Index horiz_base = threadIdx.y*4+base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    // CHECK LHS
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK RHS
    /*
    int ncols_rem = fminf(n_size- horiz_base, 4);
    for (int i = 0; i < ncols_rem; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }*/
    for (int i = 0; i < 4; i++) {
      if (horiz_base+i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 4; i++) {
      if (horiz_base+i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void
EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][32],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {

  // prefetch registers
  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
  float4 rhs_pf0, rhs_pf1;

  float4 results[8];
  for (int i=0; i < 8; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }

  Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32;
  for (Index k = 0; k < k_size; k += 32) {
    lhs_pf0 = internal::pset1<float4>(0);
    lhs_pf1 = internal::pset1<float4>(0);
    lhs_pf2 = internal::pset1<float4>(0);
    lhs_pf3 = internal::pset1<float4>(0);

    rhs_pf0 = internal::pset1<float4>(0);
    rhs_pf1 = internal::pset1<float4>(0);

    if (!CHECK_LHS_BOUNDARY) {
      if ((threadIdx.y/4+k+24) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        lhs_pf3 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
      } else if ((threadIdx.y/4+k+16) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
      } else if ((threadIdx.y/4+k+8) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
      } else if ((threadIdx.y/4+k) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
      }
    } else {
      // just CHECK_LHS_BOUNDARY
      if (lhs_vert + 3 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
          lhs_pf3 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 2 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
          lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 1 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
        }
      } else if (lhs_vert < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
        }
      }
    }
    __syncthreads();
    Index rhs_vert = k+threadIdx.x*4;
    Index rhs_horiz0 = threadIdx.y*2+base_n;
    Index rhs_horiz1 = threadIdx.y*2+1+base_n;
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
        rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
      }
    } else {
      if (rhs_horiz1 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
          rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
        } else if (rhs_vert + 2 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
          rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
        } else if (rhs_vert + 1 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        }
      } else if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    __syncthreads();
    // Loaded. Do computation
    // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1.
    // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
    // ..
    // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
    rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
    // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
    // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
    // ..
    rhs_shmem2[threadIdx.y+32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
    // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
    // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
    rhs_shmem2[threadIdx.y+64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
    // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
    // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
    rhs_shmem2[threadIdx.y+96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);

    // LHS.
    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
    // ...
    // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63) .. (126, 127)
    // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63) .. (126, 127)


#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4)\
      results[0].x += a_feat1.x * f1.x;\
      results[1].x += a_feat1.x * f1.y;\
      results[2].x += a_feat1.x * f2.x;\
      results[3].x += a_feat1.x * f2.y;\
      results[4].x += a_feat1.x * f3.x;\
      results[5].x += a_feat1.x * f3.y;\
      results[6].x += a_feat1.x * f4.x;\
      results[7].x += a_feat1.x * f4.y;\
\
      results[0].y += a_feat1.y * f1.x;\
      results[1].y += a_feat1.y * f1.y;\
      results[2].y += a_feat1.y * f2.x;\
      results[3].y += a_feat1.y * f2.y;\
      results[4].y += a_feat1.y * f3.x;\
      results[5].y += a_feat1.y * f3.y;\
      results[6].y += a_feat1.y * f4.x;\
      results[7].y += a_feat1.y * f4.y;\
\
      results[0].z += a_feat2.x * f1.x;\
      results[1].z += a_feat2.x * f1.y;\
      results[2].z += a_feat2.x * f2.x;\
      results[3].z += a_feat2.x * f2.y;\
      results[4].z += a_feat2.x * f3.x;\
      results[5].z += a_feat2.x * f3.y;\
      results[6].z += a_feat2.x * f4.x;\
      results[7].z += a_feat2.x * f4.y;\
\
      results[0].w += a_feat2.y * f1.x;\
      results[1].w += a_feat2.y * f1.y;\
      results[2].w += a_feat2.y * f2.x;\
      results[3].w += a_feat2.y * f2.y;\
      results[4].w += a_feat2.y * f3.x;\
      results[5].w += a_feat2.y * f3.y;\
      results[6].w += a_feat2.y * f4.x;\
      results[7].w += a_feat2.y * f4.y;\

    lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
    lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
    lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);

    lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
    lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
    lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
    lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);

1053     __syncthreads();
1054 
1055     // Do the multiplies.
1056     #pragma unroll
1057     for (int koff = 0; koff < 32; koff ++) {
1058       float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
1059       float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];
1060 
1061       // first feature is at (threadIdx.y/4) * 8 last is at start + 8.
1062       int start_feature = (threadIdx.y / 4) * 8;
1063 
1064       float2 br1 = rhs_shmem2[start_feature/2 +     (koff % 4) * 32][koff/4];
1065       float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
1066       float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
1067       float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];
1068 
1069       add_vals(a3, a4, br1, br2, br3, br4)
1070     }
1071     __syncthreads();
1072   } // end loop over k
1073 
1074   __syncthreads();
1075   Index horiz_base = (threadIdx.y/4)*8+base_n;
1076   if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
1077     for (int i = 0; i < 8; i++) {
1078       output(lhs_vert, horiz_base + i) = results[i].x;
1079       output(lhs_vert + 1, horiz_base + i) = results[i].y;
1080       output(lhs_vert + 2, horiz_base + i) = results[i].z;
1081       output(lhs_vert + 3, horiz_base + i) = results[i].w;
1082     }
1083   } else if (!CHECK_RHS_BOUNDARY) {
1084     if (lhs_vert + 3 < m_size) {
1085       for (int i = 0; i < 8; i++) {
1086         output(lhs_vert, horiz_base + i) = results[i].x;
1087         output(lhs_vert + 1, horiz_base + i) = results[i].y;
1088         output(lhs_vert + 2, horiz_base + i) = results[i].z;
1089         output(lhs_vert + 3, horiz_base + i) = results[i].w;
1090       }
1091     } else if (lhs_vert + 2 < m_size) {
1092       for (int i = 0; i < 8; i++) {
1093         output(lhs_vert, horiz_base + i) = results[i].x;
1094         output(lhs_vert + 1, horiz_base + i) = results[i].y;
1095         output(lhs_vert + 2, horiz_base + i) = results[i].z;
1096       }
1097     } else if (lhs_vert + 1 < m_size) {
1098       for (int i = 0; i < 8; i++) {
1099         output(lhs_vert, horiz_base + i) = results[i].x;
1100         output(lhs_vert + 1, horiz_base + i) = results[i].y;
1101       }
1102     } else if (lhs_vert  < m_size) {
1103       for (int i = 0; i < 8; i++) {
1104         output(lhs_vert, horiz_base + i) = results[i].x;
1105       }
1106     }
  } else if (!CHECK_LHS_BOUNDARY) {
    // Only the RHS (n) boundary needs checking here.
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // Check both boundaries.
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
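
  // Worked example (illustrative only): with m_size == 130 and
  // lhs_vert == 128, only two output rows remain, so the
  // lhs_vert + 1 < m_size branch writes just results[i].x and results[i].y;
  // the .z and .w lanes would fall outside the output and are skipped.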
}


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[64*32];
  __shared__ float2 rhs_shmem[128*8];

  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 128 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  bool check_rhs = (base_n + 63) >= n_size;
  bool check_lhs128 = (base_m + 127) >= m_size;
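
  // Example (illustrative only): with m_size == 200, the block at
  // m_block_idx == 1 has base_m == 128, so base_m + 127 == 255 >= 200 and
  // check_lhs128 is true; that block takes a boundary-checked instantiation
  // below.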

  if (!check_rhs) {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}

template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size) {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
                     lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
                     lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
                     lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
                     lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}


template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {

  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code assumes that both input tensors are ColMajor. If the
  // inputs are RowMajor, we "cheat" by swapping the LHS and RHS:
  // if we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
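  // Concretely, the identity being exploited is that a RowMajor matrix has
  // the memory layout of its ColMajor transpose: computing C = A * B on
  // RowMajor data is the same as computing C^T = B^T * A^T on ColMajor data,
  // so evaluating with B as the LHS and A as the RHS yields the RowMajor
  // result with no extra transposition.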
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;
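
  // Example (illustrative only): contracting a rank-3 LHS with a rank-2 RHS
  // over a single index pair gives LDims == 3, RDims == 2, ContractDims == 1,
  // hence NumDims == 3 + 2 - 2 * 1 == 3 output dimensions.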

  typedef DSizes<Index, NumDims> Dimensions;

  // typedefs needed in evalTo
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;

  TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device)
  {
    EIGEN_STATIC_ASSERT( (internal::is_same<OutputKernelType, const NoOpOutputKernel>::value),
                          GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS);
  }

  // We need to redefine this method to make nvcc happy
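  // Returning false tells the caller the supplied buffer was written
  // directly; returning true means the evaluator allocated m_result and the
  // result must be read from there.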
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }

  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
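
  // The three runtime flags tested above select one of the 2^3 == 8
  // evalTyped instantiations, so the flags become compile-time template
  // parameters of the input mappers used inside the kernels.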

  template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
    }
  };
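
  // Example launch geometry (illustrative only): for m == 1000, n == 500 the
  // generic kernel gets m_blocks == (1000 + 63) / 64 == 16 and
  // n_blocks == (500 + 63) / 64 == 8; each block has 8 * 8 * 8 == 512 threads
  // and computes one 64x64 output tile.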

  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      if (m < 768 || n < 768) {
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      }
    }
  };
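
  // Example (illustrative only): for m == 512, n == 2048 the m < 768 test
  // routes to the 16x16 kernel with (512 + 63) / 64 == 8 by
  // (2048 + 63) / 64 == 32 blocks of 256 threads; for m == n == 2048 the
  // 8x32 variant is used instead, with (2048 + 127) / 128 == 16 blocks
  // along m, each covering a 128x64 output tile.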

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    EIGEN_UNUSED_VARIABLE(k)

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, 4,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

#if defined(EIGEN_USE_HIP)
    setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
#else
    setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
#endif

    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
  }
};
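
// A minimal usage sketch (illustrative, not part of this header): the
// evaluator above is what runs when a contraction expression is assigned on
// a GpuDevice. The buffer names d_a, d_b and d_c are placeholders for
// pre-allocated device memory, and GpuStreamDevice is the stream helper from
// Eigen's tensor device headers.
//
//   Eigen::GpuStreamDevice stream;
//   Eigen::GpuDevice gpu_device(&stream);
//   Eigen::TensorMap<Eigen::Tensor<float, 2> > a(d_a, 128, 64);
//   Eigen::TensorMap<Eigen::Tensor<float, 2> > b(d_b, 64, 32);
//   Eigen::TensorMap<Eigen::Tensor<float, 2> > c(d_c, 128, 32);
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {{ Eigen::IndexPair<int>(1, 0) }};
//   c.device(gpu_device) = a.contract(b, dims);
//
// Because both scalars are float, the float-specialized LaunchKernels above
// is selected, and with m == 128 < 768 the 16x16 kernel handles the launch.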

} // end namespace Eigen

#endif // EIGEN_USE_GPU and EIGEN_GPUCC
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H