Warning: file /include/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h was not indexed,
or was modified since the last indexation (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
0013 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
0014
0015 #if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)
0016
0017 namespace Eigen {
0018
// Computes one 64x64 tile of output(m, n) = sum_k lhs(m, k) * rhs(k, n),
// marching over the contraction dimension in 64-wide slabs that are staged
// through shared memory.
//
// The indexing below assumes a (8, 8, 8) thread block (512 threads, matching
// the __launch_bounds__(512) on the wrapper kernel), with blockIdx.x selecting
// the m tile and blockIdx.y the n tile.  lhs_shmem and rhs_shmem must each
// hold 72 * 64 Scalars; the stride of 72 rather than 64 used throughout
// appears intended to stagger accesses across shared-memory banks
// (NOTE(review): intent inferred from the 72-strides below -- confirm against
// upstream Eigen history).
//
// When needs_edge_check is false, the caller guarantees the whole 64x64
// output tile lies inside the m/n bounds, so only k-edge checks remain; when
// true, every load and store is bounds-checked.
template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
                               const Index m_size, const Index n_size, const Index k_size) {

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  // Top-left corner of this block's 64x64 output tile.
  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // Per-thread prefetch registers: each thread fetches 8 lhs and 8 rhs
  // elements per k-slab before committing them to shared memory.
  Scalar lhs_pf0;
  Scalar lhs_pf1;
  Scalar lhs_pf2;
  Scalar lhs_pf3;
  Scalar lhs_pf4;
  Scalar lhs_pf5;
  Scalar lhs_pf6;
  Scalar lhs_pf7;

  Scalar rhs_pf0;
  Scalar rhs_pf1;
  Scalar rhs_pf2;
  Scalar rhs_pf3;
  Scalar rhs_pf4;
  Scalar rhs_pf5;
  Scalar rhs_pf6;
  Scalar rhs_pf7;

  // Shared-memory destinations for the 8 prefetched values of each operand;
  // consecutive values land 576 (= 8 * 72) elements apart.
  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;

  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;

  // Row (within the m range for lhs, the k range for rhs) loaded by this
  // thread: 64 distinct rows across the (x, y) thread coordinates.
  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;

  // Fills the lhs_pf*/rhs_pf* registers with the current 64-wide k-slab.
  // Every register is zeroed first so out-of-range elements contribute
  // nothing to the accumulation; the else-if ladders peel the ragged final
  // slab of the k (lhs) and n (rhs) dimensions element by element.
#define prefetchIntoRegisters(base_k) \
  { \
    lhs_pf0 = conv(0); \
    lhs_pf1 = conv(0); \
    lhs_pf2 = conv(0); \
    lhs_pf3 = conv(0); \
    lhs_pf4 = conv(0); \
    lhs_pf5 = conv(0); \
    lhs_pf6 = conv(0); \
    lhs_pf7 = conv(0); \
    \
    rhs_pf0 = conv(0); \
    rhs_pf1 = conv(0); \
    rhs_pf2 = conv(0); \
    rhs_pf3 = conv(0); \
    rhs_pf4 = conv(0); \
    rhs_pf5 = conv(0); \
    rhs_pf6 = conv(0); \
    rhs_pf7 = conv(0); \
    \
    if (!needs_edge_check || lhs_vert < m_size) { \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
      \
      if (!needs_edge_check || lhs_horiz_7 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
      } else if (lhs_horiz_6 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
      } else if (lhs_horiz_5 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
      } else if (lhs_horiz_4 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
      } else if (lhs_horiz_3 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
      } else if (lhs_horiz_2 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      } else if (lhs_horiz_1 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      } else if (lhs_horiz_0 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      } \
    } \
    \
    const Index rhs_vert = base_k + load_idx_vert; \
    if (!needs_edge_check || rhs_vert < k_size) { \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
      \
      if (rhs_horiz_7 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
      } else if (rhs_horiz_6 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
      } else if (rhs_horiz_5 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
      } else if (rhs_horiz_4 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
      } else if (rhs_horiz_3 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
      } else if (rhs_horiz_2 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      } else if (rhs_horiz_1 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      } else if (rhs_horiz_0 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      } \
    } \
  }

  // Commits the prefetched registers to shared memory.  The dummy parameter
  // exists only so the macro can be invoked with empty parentheses.
#define writeRegToShmem(_) \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
  \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
  \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
  \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
  \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
  \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
  \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
  \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7;

  // res(i, j) names this thread's accumulator for partial product (i, j);
  // initResultRow zeroes one row of the 8x8 accumulator grid.
#define res(i, j) _res_##i##j
#define initResultRow(i) \
  Scalar res(i, 0) = conv(0); \
  Scalar res(i, 1) = conv(0); \
  Scalar res(i, 2) = conv(0); \
  Scalar res(i, 3) = conv(0); \
  Scalar res(i, 4) = conv(0); \
  Scalar res(i, 5) = conv(0); \
  Scalar res(i, 6) = conv(0); \
  Scalar res(i, 7) = conv(0);

  // Functor producing Scalar values from int literals: conv(0) == Scalar(0).
  internal::scalar_cast_op<int, Scalar> conv;
  initResultRow(0);
  initResultRow(1);
  initResultRow(2);
  initResultRow(3);
  initResultRow(4);
  initResultRow(5);
  initResultRow(6);
  initResultRow(7);
#undef initResultRow

  // March over the contraction dimension in 64-wide slabs.
  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    // Wait until every thread is done reading the previous slab before the
    // shared buffers are overwritten.
    __syncthreads();

    prefetchIntoRegisters(base_k);
    writeRegToShmem();

#undef prefetchIntoRegisters
#undef writeRegToShmem

    // Make the freshly written slabs visible to all threads.
    __syncthreads();

    // Registers caching one lhs column / rhs row of the current micro-tile.
#define lcol(i) _lcol##i
    Scalar lcol(0);
    Scalar lcol(1);
    Scalar lcol(2);
    Scalar lcol(3);
    Scalar lcol(4);
    Scalar lcol(5);
    Scalar lcol(6);
    Scalar lcol(7);

#define rrow(j) _rrow##j
    Scalar rrow(0);
    Scalar rrow(1);
    Scalar rrow(2);
    Scalar rrow(3);
    Scalar rrow(4);
    Scalar rrow(5);
    Scalar rrow(6);
    Scalar rrow(7);

    // Per-thread base pointers into the shared slabs; lhs_element/rhs_element
    // then address (row, col) pairs with the padded 72-stride layout.
    const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]

    // Loads the 8 lhs values of column j and the 8 rhs values of row i.
#define loadData(i, j) \
    lcol(0) = lhs_element(0, j); \
    rrow(0) = rhs_element(i, 0); \
    lcol(1) = lhs_element(1, j); \
    rrow(1) = rhs_element(i, 1); \
    lcol(2) = lhs_element(2, j); \
    rrow(2) = rhs_element(i, 2); \
    lcol(3) = lhs_element(3, j); \
    rrow(3) = rhs_element(i, 3); \
    lcol(4) = lhs_element(4, j); \
    rrow(4) = rhs_element(i, 4); \
    lcol(5) = lhs_element(5, j); \
    rrow(5) = rhs_element(i, 5); \
    lcol(6) = lhs_element(6, j); \
    rrow(6) = rhs_element(i, 6); \
    lcol(7) = lhs_element(7, j); \
    rrow(7) = rhs_element(i, 7);

    // Rank-1 update: accumulates column j of the 8x8 result grid.
#define computeCol(j) \
    res(0, j) += lcol(0) * rrow(j); \
    res(1, j) += lcol(1) * rrow(j); \
    res(2, j) += lcol(2) * rrow(j); \
    res(3, j) += lcol(3) * rrow(j); \
    res(4, j) += lcol(4) * rrow(j); \
    res(5, j) += lcol(5) * rrow(j); \
    res(6, j) += lcol(6) * rrow(j); \
    res(7, j) += lcol(7) * rrow(j);

#define computePass(i) \
    loadData(i, i); \
    \
    computeCol(0); \
    computeCol(1); \
    computeCol(2); \
    computeCol(3); \
    computeCol(4); \
    computeCol(5); \
    computeCol(6); \
    computeCol(7);

    // Accumulate this slab's contribution to all 64 partial products.
    computePass(0);
    computePass(1);
    computePass(2);
    computePass(3);
    computePass(4);
    computePass(5);
    computePass(6);
    computePass(7);

#undef lcol
#undef rrow
#undef lhs_element
#undef rhs_element
#undef loadData
#undef computeCol
#undef computePass
  }

  // Each res(i, j) currently holds only this thread's share of the sum.
  // Three XOR-shuffle butterfly steps (masks 1, 2, 4) accumulate the partials
  // held by the 8 lanes that differ only in threadIdx.x; afterwards all of
  // those lanes hold the complete value.  Legacy mask-less __shfl_xor is kept
  // for HIP and pre-CUDA-9 toolchains, which lack the *_sync variants.
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
#else
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
#endif

#define reduceRow(i, mask) \
  shuffleInc(i, 0, mask); \
  shuffleInc(i, 1, mask); \
  shuffleInc(i, 2, mask); \
  shuffleInc(i, 3, mask); \
  shuffleInc(i, 4, mask); \
  shuffleInc(i, 5, mask); \
  shuffleInc(i, 6, mask); \
  shuffleInc(i, 7, mask);

#define reduceMatrix(mask) \
  reduceRow(0, mask); \
  reduceRow(1, mask); \
  reduceRow(2, mask); \
  reduceRow(3, mask); \
  reduceRow(4, mask); \
  reduceRow(5, mask); \
  reduceRow(6, mask); \
  reduceRow(7, mask);

  reduceMatrix(1);
  reduceMatrix(2);
  reduceMatrix(4);

#undef shuffleInc
#undef reduceRow
#undef reduceMatrix

  // The operand slabs are no longer needed: wait for all readers, then reuse
  // lhs_shmem to stage the reduced results so they can be written to global
  // memory with a different thread-to-element mapping.
  __syncthreads();

#define writeResultShmem(i, j) \
  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j);

#define writeRow(i) \
  writeResultShmem(i, 0); \
  writeResultShmem(i, 1); \
  writeResultShmem(i, 2); \
  writeResultShmem(i, 3); \
  writeResultShmem(i, 4); \
  writeResultShmem(i, 5); \
  writeResultShmem(i, 6); \
  writeResultShmem(i, 7);

  // Only lane x == 0 of each reduction group stores; after the shuffles the
  // other 7 lanes hold identical copies of the reduced values.
  if (threadIdx.x == 0) {
    writeRow(0);
    writeRow(1);
    writeRow(2);
    writeRow(3);
    writeRow(4);
    writeRow(5);
    writeRow(6);
    writeRow(7);
  }
#undef writeResultShmem
#undef writeRow

  // Number of in-range rows/columns this thread may write back, clamped to 8.
  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);

  if (threadIdx.x < max_i_write) {
    if (max_j_write == 8) {
      // Fast path: the whole row of 8 outputs is in range.
      Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
      Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
      Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
      Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
      Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
      Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
      Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
      Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];

      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
    } else {
#pragma unroll 7
      for (int j = 0; j < max_j_write; j++) {
        Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
        output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
      }
    }
  }
#undef res
}
0501
0502
// Entry-point kernel: dispatches each 64x64 output tile either to the fast
// path (no m/n bounds checks) for interior tiles, or to the fully checked
// path for tiles that straddle the matrix edges.
//
// Must be launched with a 512-thread block (the internal kernel's indexing
// expects blockDim = (8, 8, 8)) and gridDim = (ceil(m/64), ceil(n/64)).
template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(512, 1)
#else
__launch_bounds__(512)
#endif
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  // 72-column stride padding (72 * 64 rather than 64 * 64) matches the
  // layout assumed by EigenContractionKernelInternal.
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    // Whole 64x64 tile is in range: skip per-element edge checks.
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}
0529
0530
// Single-precision contraction kernel body computing a 64x64 output tile
// with float4 vectorized loads, consuming the k dimension in 16-wide slabs.
//
// The indexing below assumes a 16x16 thread block.  Each thread accumulates
// a 4-row x 4-column patch of the output in results[0..3], where results[i]
// holds column (horiz_base + i) of the patch.  lhs_shmem2 must have extent
// [32][16] float2 and rhs_shmem2 [64][8] float2, matching the swizzled
// stores below.  CHECK_LHS_BOUNDARY / CHECK_RHS_BOUNDARY enable m- and
// n/k-edge checks respectively.
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][16],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {

  float4 lhs_pf0, rhs_pf0;

  // Per-thread accumulators, zero-initialized.
  float4 results[4];
  for (int i = 0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }

  // Loads 4 consecutive lhs rows at (row..row+3, col) into reg.  With
  // boundary checks enabled, the m edge is peeled into scalar loads;
  // elements that stay unloaded keep the zero set by pset1 each iteration.
#define prefetch_lhs(reg, row, col) \
    if (!CHECK_LHS_BOUNDARY) { \
      if (col < k_size) { \
        reg = lhs.template loadPacket<float4,Unaligned>(row, col); \
      } \
    } else { \
      if (col < k_size) { \
        if (row + 3 < m_size) { \
          reg = lhs.template loadPacket<float4,Unaligned>(row, col); \
        } else if (row + 2 < m_size) { \
          reg.x = lhs(row + 0, col); \
          reg.y = lhs(row + 1, col); \
          reg.z = lhs(row + 2, col); \
        } else if (row + 1 < m_size) { \
          reg.x = lhs(row + 0, col); \
          reg.y = lhs(row + 1, col); \
        } else if (row < m_size) { \
          reg.x = lhs(row + 0, col); \
        } \
      } \
    }

  // First of the 4 output rows owned by this thread.
  Index lhs_vert = base_m + threadIdx.x * 4;

  for (Index k = 0; k < k_size; k += 16) {
    // Zero-fill so edge-clipped loads contribute nothing to the sums.
    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    Index lhs_horiz = threadIdx.y + k;
    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)

    Index rhs_vert = k + (threadIdx.x % 4) * 4;
    Index rhs_horiz0 = (threadIdx.x >> 2) + threadIdx.y * 4 + base_n;

    if (!CHECK_RHS_BOUNDARY) {
      // Only the k edge can clip the packet; peel it scalar by scalar.
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      // Both the n edge (whole column skipped) and the k edge apply.
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }

    // Exchange half of each rhs packet with the lane 4 positions away (XOR
    // shuffle) so the two float2 stores below land in the swizzled
    // shared-memory layout consumed by the compute loop.
    float x1, x2;
    if ((threadIdx.x % 8) < 4) {
      x1 = rhs_pf0.y;
      x2 = rhs_pf0.w;
    } else {
      x1 = rhs_pf0.x;
      x2 = rhs_pf0.z;
    }
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
    // Legacy mask-less shuffles for HIP / pre-CUDA-9 toolchains.
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
#else
    x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
    x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
#endif
    if ((threadIdx.x % 8) < 4) {
      rhs_pf0.y = x1;
      rhs_pf0.w = x2;
    } else {
      rhs_pf0.x = x1;
      rhs_pf0.z = x2;
    }

    rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2][threadIdx.x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2 + 32][threadIdx.x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);

    lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y + 16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);

    // Rank-1 update of this thread's 4x4 patch for one k value: fl1/fl2
    // carry the 4 lhs rows, fr1/fr2 the 4 rhs columns.
#define add_vals(fl1, fl2, fr1, fr2) \
    results[0].x += fl1.x * fr1.x; \
    results[0].y += fl1.y * fr1.x; \
    results[0].z += fl2.x * fr1.x; \
    results[0].w += fl2.y * fr1.x; \
    \
    results[1].x += fl1.x * fr1.y; \
    results[1].y += fl1.y * fr1.y; \
    results[1].z += fl2.x * fr1.y; \
    results[1].w += fl2.y * fr1.y; \
    \
    results[2].x += fl1.x * fr2.x; \
    results[2].y += fl1.y * fr2.x; \
    results[2].z += fl2.x * fr2.x; \
    results[2].w += fl2.y * fr2.x; \
    \
    results[3].x += fl1.x * fr2.y; \
    results[3].y += fl1.y * fr2.y; \
    results[3].z += fl2.x * fr2.y; \
    results[3].w += fl2.y * fr2.y;

    // All shared-memory stores above must be visible before any thread
    // starts reading them.
    __syncthreads();

#pragma unroll
    for (int koff = 0; koff < 16; koff++) {
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      int start_feature = threadIdx.y * 4;
      float2 fr1 = rhs_shmem2[(start_feature >> 1) + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
      float2 fr2 = rhs_shmem2[(start_feature >> 1) + 1 + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];

      add_vals(fl1, fl2, fr1, fr2)
    }
    // Keep faster threads from overwriting shared memory in the next
    // iteration while others are still reading it.
    __syncthreads();
  }

#undef prefetch_lhs
#undef add_vals

  // Write this thread's 4x4 patch back to global memory, peeling edge cases
  // according to which boundary checks are enabled.
  Index horiz_base = threadIdx.y * 4 + base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    // Only the m (row) edge can clip; peel rows.
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // Only the n (column) edge can clip; guard each column.
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // Both edges can clip; guard every element individually.
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}
0769
0770
0771 template<typename Index, typename LhsMapper,
0772 typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
0773 bool CHECK_RHS_BOUNDARY>
0774 __device__ __forceinline__ void
0775 EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
0776 const OutputMapper output, float2 lhs_shmem2[][32],
0777 float2 rhs_shmem2[][8], const Index m_size,
0778 const Index n_size, const Index k_size,
0779 const Index base_m, const Index base_n) {
0780
0781
0782 float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
0783 float4 rhs_pf0, rhs_pf1;
0784
0785 float4 results[8];
0786 for (int i=0; i < 8; i++) {
0787 results[i].x = results[i].y = results[i].z = results[i].w = 0;
0788 }
0789
0790 Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32;
0791 for (Index k = 0; k < k_size; k += 32) {
0792 lhs_pf0 = internal::pset1<float4>(0);
0793 lhs_pf1 = internal::pset1<float4>(0);
0794 lhs_pf2 = internal::pset1<float4>(0);
0795 lhs_pf3 = internal::pset1<float4>(0);
0796
0797 rhs_pf0 = internal::pset1<float4>(0);
0798 rhs_pf1 = internal::pset1<float4>(0);
0799
0800 if (!CHECK_LHS_BOUNDARY) {
0801 if ((threadIdx.y/4+k+24) < k_size) {
0802 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
0803 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
0804 lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
0805 lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
0806 } else if ((threadIdx.y/4+k+16) < k_size) {
0807 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
0808 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
0809 lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
0810 } else if ((threadIdx.y/4+k+8) < k_size) {
0811 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
0812 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
0813 } else if ((threadIdx.y/4+k) < k_size) {
0814 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
0815 }
0816 } else {
0817
0818 if (lhs_vert + 3 < m_size) {
0819 if ((threadIdx.y/4+k+24) < k_size) {
0820 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
0821 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
0822 lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
0823 lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
0824 } else if ((threadIdx.y/4+k+16) < k_size) {
0825 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
0826 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
0827 lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
0828 } else if ((threadIdx.y/4+k+8) < k_size) {
0829 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
0830 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
0831 } else if ((threadIdx.y/4+k) < k_size) {
0832 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
0833 }
0834 } else if (lhs_vert + 2 < m_size) {
0835 if ((threadIdx.y/4+k+24) < k_size) {
0836 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
0837 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
0838 lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
0839 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
0840 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
0841 lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
0842 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
0843 lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
0844 lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
0845 lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
0846 lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
0847 lhs_pf3.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
0848 } else if ((threadIdx.y/4+k+16) < k_size) {
0849 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
0850 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
0851 lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
0852 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
0853 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
0854 lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
0855 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
0856 lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
0857 lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
0858 } else if ((threadIdx.y/4+k+8) < k_size) {
0859 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
0860 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
0861 lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
0862 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
0863 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
0864 lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
0865 } else if ((threadIdx.y/4+k) < k_size) {
0866 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
0867 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
0868 lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
0869 }
0870 } else if (lhs_vert + 1 < m_size) {
0871 if ((threadIdx.y/4+k+24) < k_size) {
0872 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
0873 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
0874 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
0875 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
0876 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
0877 lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
0878 lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
0879 lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
0880 } else if ((threadIdx.y/4+k+16) < k_size) {
0881 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
0882 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
0883 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
0884 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
0885 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
0886 lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
0887 } else if ((threadIdx.y/4+k+8) < k_size) {
0888 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
0889 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
0890 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
0891 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
0892 } else if ((threadIdx.y/4+k) < k_size) {
0893 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
0894 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
0895 }
0896 } else if (lhs_vert < m_size) {
0897 if ((threadIdx.y/4+k+24) < k_size) {
0898 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
0899 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
0900 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
0901 lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
0902 } else if ((threadIdx.y/4+k+16) < k_size) {
0903 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
0904 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
0905 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
0906 } else if ((threadIdx.y/4+k+8) < k_size) {
0907 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
0908 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
0909 } else if ((threadIdx.y/4+k) < k_size) {
0910 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
0911 }
0912 }
0913 }
0914 __syncthreads();
0915 Index rhs_vert = k+threadIdx.x*4;
0916 Index rhs_horiz0 = threadIdx.y*2+base_n;
0917 Index rhs_horiz1 = threadIdx.y*2+1+base_n;
0918 if (!CHECK_RHS_BOUNDARY) {
0919 if ((rhs_vert + 3) < k_size) {
0920
0921 rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
0922 rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
0923 } else if (rhs_vert + 2 < k_size) {
0924
0925 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
0926 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
0927 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
0928 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
0929 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
0930 rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
0931 } else if (rhs_vert + 1 < k_size) {
0932 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
0933 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
0934 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
0935 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
0936 } else if (rhs_vert < k_size) {
0937 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
0938 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
0939 }
0940 } else {
0941 if (rhs_horiz1 < n_size) {
0942 if ((rhs_vert + 3) < k_size) {
0943
0944 rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
0945 rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
0946 } else if (rhs_vert + 2 < k_size) {
0947
0948 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
0949 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
0950 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
0951 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
0952 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
0953 rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
0954 } else if (k+threadIdx.x*4 + 1 < k_size) {
0955 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
0956 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
0957 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
0958 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
0959 } else if (k+threadIdx.x*4 < k_size) {
0960 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
0961 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
0962 }
0963 } else if (rhs_horiz0 < n_size) {
0964 if ((rhs_vert + 3) < k_size) {
0965
0966 rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
0967 } else if ((rhs_vert + 2) < k_size) {
0968
0969 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
0970 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
0971 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
0972 } else if ((rhs_vert + 1) < k_size) {
0973 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
0974 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
0975 } else if (rhs_vert < k_size) {
0976 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
0977 }
0978 }
0979 }
0980 __syncthreads();
0981
0982
0983
0984
0985
0986 rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
0987
0988
0989
0990 rhs_shmem2[threadIdx.y+32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
0991
0992
0993 rhs_shmem2[threadIdx.y+64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
0994
0995
0996 rhs_shmem2[threadIdx.y+96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
0997
0998
0999
1000
1001
1002
1003
1004
1005
1006 #define add_vals(a_feat1, a_feat2, f1, f2, f3, f4)\
1007 results[0].x += a_feat1.x * f1.x;\
1008 results[1].x += a_feat1.x * f1.y;\
1009 results[2].x += a_feat1.x * f2.x;\
1010 results[3].x += a_feat1.x * f2.y;\
1011 results[4].x += a_feat1.x * f3.x;\
1012 results[5].x += a_feat1.x * f3.y;\
1013 results[6].x += a_feat1.x * f4.x;\
1014 results[7].x += a_feat1.x * f4.y;\
1015 \
1016 results[0].y += a_feat1.y * f1.x;\
1017 results[1].y += a_feat1.y * f1.y;\
1018 results[2].y += a_feat1.y * f2.x;\
1019 results[3].y += a_feat1.y * f2.y;\
1020 results[4].y += a_feat1.y * f3.x;\
1021 results[5].y += a_feat1.y * f3.y;\
1022 results[6].y += a_feat1.y * f4.x;\
1023 results[7].y += a_feat1.y * f4.y;\
1024 \
1025 results[0].z += a_feat2.x * f1.x;\
1026 results[1].z += a_feat2.x * f1.y;\
1027 results[2].z += a_feat2.x * f2.x;\
1028 results[3].z += a_feat2.x * f2.y;\
1029 results[4].z += a_feat2.x * f3.x;\
1030 results[5].z += a_feat2.x * f3.y;\
1031 results[6].z += a_feat2.x * f4.x;\
1032 results[7].z += a_feat2.x * f4.y;\
1033 \
1034 results[0].w += a_feat2.y * f1.x;\
1035 results[1].w += a_feat2.y * f1.y;\
1036 results[2].w += a_feat2.y * f2.x;\
1037 results[3].w += a_feat2.y * f2.y;\
1038 results[4].w += a_feat2.y * f3.x;\
1039 results[5].w += a_feat2.y * f3.y;\
1040 results[6].w += a_feat2.y * f4.x;\
1041 results[7].w += a_feat2.y * f4.y;\
1042
1043 lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
1044 lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
1045 lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
1046 lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);
1047
1048 lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
1049 lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
1050 lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
1051 lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);
1052
1053 __syncthreads();
1054
1055
1056 #pragma unroll
1057 for (int koff = 0; koff < 32; koff ++) {
1058 float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
1059 float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];
1060
1061
1062 int start_feature = (threadIdx.y / 4) * 8;
1063
1064 float2 br1 = rhs_shmem2[start_feature/2 + (koff % 4) * 32][koff/4];
1065 float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
1066 float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
1067 float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];
1068
1069 add_vals(a3, a4, br1, br2, br3, br4)
1070 }
1071 __syncthreads();
1072 }
1073
1074 __syncthreads();
1075 Index horiz_base = (threadIdx.y/4)*8+base_n;
1076 if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
1077 for (int i = 0; i < 8; i++) {
1078 output(lhs_vert, horiz_base + i) = results[i].x;
1079 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1080 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1081 output(lhs_vert + 3, horiz_base + i) = results[i].w;
1082 }
1083 } else if (!CHECK_RHS_BOUNDARY) {
1084 if (lhs_vert + 3 < m_size) {
1085 for (int i = 0; i < 8; i++) {
1086 output(lhs_vert, horiz_base + i) = results[i].x;
1087 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1088 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1089 output(lhs_vert + 3, horiz_base + i) = results[i].w;
1090 }
1091 } else if (lhs_vert + 2 < m_size) {
1092 for (int i = 0; i < 8; i++) {
1093 output(lhs_vert, horiz_base + i) = results[i].x;
1094 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1095 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1096 }
1097 } else if (lhs_vert + 1 < m_size) {
1098 for (int i = 0; i < 8; i++) {
1099 output(lhs_vert, horiz_base + i) = results[i].x;
1100 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1101 }
1102 } else if (lhs_vert < m_size) {
1103 for (int i = 0; i < 8; i++) {
1104 output(lhs_vert, horiz_base + i) = results[i].x;
1105 }
1106 }
1107 } else if (!CHECK_LHS_BOUNDARY) {
1108
1109 for (int i = 0; i < 8; i++) {
1110 if (horiz_base + i < n_size) {
1111 output(lhs_vert, horiz_base + i) = results[i].x;
1112 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1113 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1114 output(lhs_vert + 3, horiz_base + i) = results[i].w;
1115 }
1116 }
1117 } else {
1118
1119 for (int i = 0; i < 8; i++) {
1120 if (horiz_base + i < n_size) {
1121 if (lhs_vert < m_size)
1122 output(lhs_vert, horiz_base + i) = results[i].x;
1123 if (lhs_vert + 1 < m_size)
1124 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1125 if (lhs_vert + 2 < m_size)
1126 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1127 if (lhs_vert + 3 < m_size)
1128 output(lhs_vert + 3, horiz_base + i) = results[i].w;
1129 }
1130 }
1131 }
1132 }
1133
1134
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  // Dispatch wrapper: selects the float-specialized contraction worker with
  // compile-time boundary-check flags so interior blocks pay no bounds-check cost.
  // Shared-memory staging buffers, reinterpreted below as 2D tiles.
  __shared__ float2 lhs_shmem[64*32];
  __shared__ float2 rhs_shmem[128*8];

  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];

  // Each block owns a 128-row by 64-column tile of the output
  // (base_m strides by 128 per block in x, base_n by 64 per block in y).
  const Index base_m = 128 * static_cast<Index>(blockIdx.x);
  const Index base_n = 64 * static_cast<Index>(blockIdx.y);

  // Edge tiles that extend past the matrix need run-time bounds checks.
  const bool partial_lhs = (base_m + 127) >= m_size;
  const bool partial_rhs = (base_n + 63) >= n_size;

  // Template flags are <..., CHECK_LHS_BOUNDARY, CHECK_RHS_BOUNDARY>.
  if (partial_rhs) {
    if (partial_lhs) {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (partial_lhs) {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}
1181
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  // Dispatch wrapper for the 16x16-thread float contraction worker; picks the
  // instantiation with compile-time boundary-check flags so full interior
  // tiles skip all bounds checks.
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  // Each block owns a 64x64 tile of the output
  // (both base offsets stride by 64 per block).
  const Index base_m = 64 * static_cast<Index>(blockIdx.x);
  const Index base_n = 64 * static_cast<Index>(blockIdx.y);

  // A tile is "partial" when its far edge would fall outside the matrix.
  const bool partial_lhs = !(base_m + 63 < m_size);
  const bool partial_rhs = !(base_n + 63 < n_size);

  // Template flags are <..., CHECK_LHS_BOUNDARY, CHECK_RHS_BOUNDARY>.
  if (partial_lhs) {
    if (partial_rhs) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (partial_rhs) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}
1216
1217
// GPU (CUDA/HIP) evaluator for tensor contraction expressions. Reduces the
// contraction to a GEMM-like kernel launch: builds lhs/rhs input mappers,
// zero-initializes the output buffer, then dispatches to a generic kernel or
// (for float x float) to specialized float kernels via LaunchKernels.
template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {

  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // For RowMajor layout the roles of lhs/rhs are swapped so the kernels can
  // always operate as if the data were ColMajor.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  // Ranks of the two operands and the number of contracted dimensions.
  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  // Stride arrays for the contracted and non-contracted dimensions of each side.
  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  // Rank of the result: each contraction pair removes one dim from each side.
  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  // typedefs needed in evalTo
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;

  TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device)
  {
    // The GPU path has no hook to run a custom output kernel after the
    // contraction, so only the no-op kernel is accepted.
    EIGEN_STATIC_ASSERT( (internal::is_same<OutputKernelType, const NoOpOutputKernel>::value),
                          GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS);
  }

  // Evaluates both operand sub-expressions, then the contraction itself.
  // Returns false when the caller-provided buffer was filled directly,
  // true when a temporary result buffer was allocated (in m_result).
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }

  // Dispatches to the evalTyped instantiation matching the runtime layout
  // properties of the operands (inner-dim contiguity / reordering), which are
  // only known after the evaluator base has analyzed the expression.
  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

  // Generic launcher: one 8x8x8 thread block per 64x64 output tile, running
  // the scalar-agnostic EigenContractionKernel.
  template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
    }
  };

  // float x float specialization: picks between the 16x16-thread kernel
  // (64x64 tiles, better for smaller problems) and the 8x32-thread kernel
  // (128x64 tiles) based on problem size.
  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      if (m < 768 || n < 768) {
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      }
    }
  };

  // Builds the input/output mappers for the (m x k) * (k x n) view of the
  // contraction, zeroes the destination, and launches the kernel.
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    EIGEN_UNUSED_VARIABLE(k)

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, 4,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // The kernels stage float2/float4 data in shared memory; eight-byte bank
    // mode avoids bank conflicts for those accesses.
#if defined(EIGEN_USE_HIP)
    setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
#else
    setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
#endif

    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
  }
};
1409
1410 }
1411
1412 #endif
1413 #endif