#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H

#ifdef EIGEN_USE_THREADS

namespace Eigen {

template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> > {

  typedef ThreadPoolDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // The evaluation code below assumes ColMajor storage. If the inputs are
  // RowMajor, the LHS and RHS are swapped so that the same kernels apply.
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
  typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) {}

  template <int Alignment>
  void evalProduct(Scalar* buffer) const {
    evalProductImpl<NoCallback, Alignment>(buffer, NoCallback());
  }

  template <typename EvalToCallback, int Alignment>
  void evalProductAsync(Scalar* buffer, EvalToCallback done) const {
    evalProductImpl<EvalToCallback, Alignment>(buffer, std::move(done));
  }

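  // Both entry points funnel into evalProductImpl. The synchronous path
  // (evalProduct) passes the sentinel NoCallback type, which makes the
  // implementation block until the contraction is finished; the asynchronous
  // path (evalProductAsync) passes a real callback, work is scheduled on the
  // device's thread pool, and `done` is invoked once the contraction has
  // completed.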
  template <typename DoneCallback, int Alignment>
  void evalProductImpl(Scalar* buffer, DoneCallback done) const {
    // The heuristics below are computed in several dependent steps, and the
    // sync/async decision is made at the very end:
    //  - in sync mode (DoneCallback == NoCallback) contexts live on the stack
    //    and we block until evaluation completes;
    //  - in async mode contexts are heap allocated, delete themselves when
    //    the last task finishes, and then invoke the done callback.
    static const bool IsEvalInSyncMode =
        std::is_same<DoneCallback, NoCallback>::value;

    const Index m = this->m_i_size;
    const Index n = this->m_j_size;
    const Index k = this->m_k_size;
    if (m == 0 || n == 0 || k == 0) return;

    // Compute the set of algorithm parameters:
    //  - kernel block sizes (bm, bn, bk)
    //  - task grain sizes (number of kernels executed per task: gm, gn)
    //  - number of threads
    //  - sharding by row/column
    //  - parallel packing or first lhs then rhs
    // and some derived parameters:
    //  - number of tasks (nm, nn, nk)
    //  - number of kernels (nm0, nn0)
    // These parameters are tightly interdependent, so some of them are first
    // approximated and later refined.

    // First approximation of the sharding direction (refined below once the
    // number of threads is known).
    bool shard_by_col = shardByCol(m, n, 2);

    // First approximation of the kernel blocking sizes.
    Index bm, bn, bk;
    if (shard_by_col) {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          internal::ShardByCol>
          blocking(k, m, n, 2);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    } else {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          internal::ShardByRow>
          blocking(k, m, n, 2);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    }

    // Estimate the optimal number of threads. The cost is computed per output
    // coefficient of a single bk-deep slice, because computations along k are
    // not parallelizable in this scheme.
    const TensorOpCost cost =
        contractionCost(m, n, bm, bn, bk, shard_by_col, false);
    int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
        static_cast<double>(n) * m, cost, this->m_device.numThreads());
    int num_threads_by_k = numThreadsInnerDim(m, n, k);
    if (shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) {
      // It is more efficient to shard over the inner (contraction) dimension.
      if (IsEvalInSyncMode) {
        EvalShardedByInnerDimContext<DoneCallback> ctx(
            this, num_threads_by_k, buffer, m, n, k, std::move(done));
        ctx.template run<Alignment>();
      } else {
        auto* ctx = new EvalShardedByInnerDimContext<DoneCallback>(
            this, num_threads_by_k, buffer, m, n, k, std::move(done));
        ctx->template runAsync<Alignment>();
      }

      return;
    }

    // Single-column outputs are always evaluated sequentially.
    if (n == 1) num_threads = 1;

    if (num_threads == 1) {
      TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
                                  Unaligned, (buffer));
      if (!IsEvalInSyncMode) done();
      return;
    }

    // Now that the number of threads is known, recompute sharding and
    // blocking.
    shard_by_col = shardByCol(m, n, num_threads);
    if (shard_by_col) {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          internal::ShardByCol>
          blocking(k, m, n, num_threads);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    } else {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          internal::ShardByRow>
          blocking(k, m, n, num_threads);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    }

    // Number of kernels for each dimension.
    Index nm0 = divup(m, bm);
    Index nn0 = divup(n, bn);
    Index nk = divup(k, bk);

    // Calculate the task grain size (number of kernels executed per task).
    // Coarsening reduces per-task and synchronization overheads and lets
    // consecutive kernels reuse the same packed block.
    Index gm = 1;
    Index gn = 1;
    // If we are sharding by column, then we prefer to reduce rows first.
    if (shard_by_col) {
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
    } else {
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
    }
    // Number of tasks in each dimension.
    Index nm = divup(nm0, gm);
    Index nn = divup(nn0, gn);

    // If there is enough concurrency in the sharding dimension alone, do not
    // parallelize over the other dimension: kernels are then executed inside
    // the packing tasks themselves, which improves locality.
    const Index sharding_dim_tasks = shard_by_col ? nn : nm;
    const int num_worker_threads = this->m_device.numThreadsInPool();

    // With a small pool we want to preserve parallelism; with a large pool we
    // trade some parallelism for better memory locality.
    const float oversharding_factor =
        num_worker_threads <= 4  ? 8.0 :
        num_worker_threads <= 8  ? 4.0 :
        num_worker_threads <= 16 ? 2.0 :
        num_worker_threads <= 32 ? 1.0 :
        num_worker_threads <= 64 ? 0.8 : 0.6;

    const bool parallelize_by_sharding_dim_only =
        sharding_dim_tasks >= oversharding_factor * num_worker_threads;

    // Decide whether lhs and rhs packing for a slice is issued in parallel,
    // or packing of one operand is issued only after the other completes.
    // Parallel packing exposes more parallelism, sequential packing gives
    // better locality. Prefer parallel packing when there are few tasks.
    bool parallel_pack = num_threads >= nm * nn;
    // Also pack in parallel if all data fits into L2 cache.
    if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <=
        l2CacheSize() * num_threads)
      parallel_pack = true;
    // But not if each packed block of the sharding-dimension operand is used
    // by only one task: locality wins in that case.
    if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
    // Parallel packing and parallelize_by_sharding_dim_only are mutually
    // exclusive.
    if (parallelize_by_sharding_dim_only) parallel_pack = false;

    if (IsEvalInSyncMode) {
#define CONTEXT_ARGS                                                        \
  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
   nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only,      \
   NoCallback())                                                            \
      .run()
      TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment,
                                  CONTEXT_ARGS);
#undef CONTEXT_ARGS

    } else {
#define CONTEXT_ARGS                                                        \
  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
   nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only,      \
   std::move(done))
      TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback,
                                        Alignment, CONTEXT_ARGS, run());
#undef CONTEXT_ARGS
    }
  }

  // Dummy callback type used when the contraction is evaluated synchronously;
  // it must never actually be invoked.
  struct NoCallback {
    void operator()() {
      eigen_assert(false && "NoCallback should never be called");
    }
  };

  // Depending on the DoneCallback type, the notification either blocks the
  // caller (sync mode) or destroys the context and invokes the callback
  // (async mode).
  template <typename DoneCallback, typename Context>
  class EvalParallelNotification;

  // Synchronous evaluation: Wait() blocks until Notify() is called by the
  // last finished task.
  template <typename Context>
  class EvalParallelNotification<NoCallback, Context> {
   public:
    EvalParallelNotification(Context*, NoCallback) {}
    void Notify() { done_.Notify(); }
    void Wait() { done_.Wait(); }
   private:
    Eigen::Notification done_;
  };

  // Asynchronous evaluation: the context is deleted and the done callback is
  // invoked when the last task finishes; Wait() returns immediately.
  template <typename DoneCallback, typename Context>
  class EvalParallelNotification {
   public:
    EvalParallelNotification(Context* ctx, DoneCallback done)
        : ctx_(ctx), done_(std::move(done)) {}

    void Notify() {
      // Make a local copy of the callback first: deleting the context below
      // also destroys this notification object and the callback it stores.
      DoneCallback done_copy = std::move(done_);

      // Delete the parallel evaluation context.
      delete ctx_;

      // Now it is safe to invoke the callback.
      done_copy();
    }

    void Wait() {}

   private:
    Context* ctx_;
    DoneCallback done_;
  };


  // Orchestrates a parallel contraction evaluation: packs lhs/rhs blocks,
  // runs GEBP kernels for each (m, n, k) block triple, and owns all the
  // synchronization state shared by the packing and kernel tasks.
  template <typename DoneCallback, bool lhs_inner_dim_contiguous,
            bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered,
            int Alignment>
  class EvalParallelContext {
   public:
    typedef internal::TensorContractionInputMapper<
        LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
        contract_t, internal::packet_traits<LhsScalar>::size,
        lhs_inner_dim_contiguous, false, Unaligned>
        LhsMapper;
    typedef internal::TensorContractionInputMapper<
        RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
        contract_t, internal::packet_traits<RhsScalar>::size,
        rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
        RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    typedef internal::TensorContractionKernel<
        Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
        TensorContractionKernel;

    typedef typename TensorContractionKernel::LhsBlock LhsBlock;
    typedef typename TensorContractionKernel::RhsBlock RhsBlock;
    typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;

    EvalParallelContext(const Self* self, int num_threads, Scalar* buffer,
                        Index tm, Index tn, Index tk, Index bm, Index bn,
                        Index bk, Index nm, Index nn, Index nk, Index gm,
                        Index gn, Index nm0, Index nn0, bool shard_by_col,
                        bool parallel_pack,
                        bool parallelize_by_sharding_dim_only,
                        DoneCallback done)
        : created_by_thread_id_(std::this_thread::get_id()),
          done_(this, std::move(done)),
          device_(self->m_device),
          lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
               self->m_i_strides, self->m_left_contracting_strides,
               self->m_k_strides),
          rhs_(self->m_rightImpl, self->m_right_nocontract_strides,
               self->m_j_strides, self->m_right_contracting_strides,
               self->m_k_strides),
          buffer_(buffer),
          output_(buffer, tm),
          output_kernel_(self->m_output_kernel),
          tensor_contraction_params_(self->m_tensor_contraction_params),
          num_threads_(num_threads),
          shard_by_col_(shard_by_col),
          parallel_pack_(parallel_pack),
          parallelize_by_sharding_dim_only_(parallelize_by_sharding_dim_only),
          m_(tm),
          n_(tn),
          k_(tk),
          bm_(bm),
          bn_(bn),
          bk_(bk),
          nm_(nm),
          nn_(nn),
          nk_(nk),
          gm_(gm),
          gn_(gn),
          nm0_(nm0),
          nn0_(nn0),
          kernel_(m_, k_, n_, bm_, bk_, bn_),
          num_thread_local_allocations_(0),
          // Reserve 2x the pool size so that threads outside the pool that
          // pick up contraction tasks can also get thread-local blocks.
          thread_local_capacity(2 * (parallelize_by_sharding_dim_only_
                                         ? device_.numThreadsInPool()
                                         : 0)),
          // Only one of the lhs/rhs thread-local block stores is used,
          // depending on the sharding direction.
          lhs_thread_local_blocks_(shard_by_col_ ? 0 : thread_local_capacity,
                                   {*this}, {*this}),
          rhs_thread_local_blocks_(shard_by_col_ ? thread_local_capacity : 0,
                                   {*this}, {*this}) {
      // These two modes are mutually exclusive.
      eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));

      // Initialize the per-slice synchronization counters (see the comment
      // above the state members below for the protocol).
      for (Index x = 0; x < P; x++) {
        state_switch_[x] =
            x == 0
                ? 1
                : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) +
                      (x == P - 1 ? nm_ * nn_ : 0);
        state_packing_ready_[x] =
            parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
        state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
        for (Index m = 0; m < nm_; m++) {
          state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
          // A kernel waits for its packed blocks (one or two packing
          // notifications, depending on parallel_pack_) and, except in the
          // first slice, for the same kernel in the previous slice.
          for (Index n = 0; n < nn_; n++)
            state_kernel_[x][m][n].store(
                (x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1),
                std::memory_order_relaxed);
        }
      }

      // Allocate memory for the packed lhs/rhs blocks of all in-flight slices.
      packed_mem_ = kernel_.allocateSlices(device_, nm0_, nn0_,
                                           std::min<Index>(nk_, P - 1),
                                           packed_lhs_, packed_rhs_);

      if (parallelize_by_sharding_dim_only_) {
        const int num_worker_threads = device_.numThreadsInPool();

        if (shard_by_col) {
          can_use_thread_local_packed_ = new std::atomic<bool>[nn_];
          for (int i = 0; i < nn_; ++i)
            can_use_thread_local_packed_[i].store(true,
                                                  std::memory_order_relaxed);

          Index num_blocks = num_worker_threads * gn_;
          thread_local_pre_alocated_mem_ = kernel_.allocateSlices(
              device_, 0, num_blocks, 1, nullptr,
              &rhs_thread_local_pre_allocated_);

        } else {
          can_use_thread_local_packed_ = new std::atomic<bool>[nm_];
          for (int i = 0; i < nm_; ++i)
            can_use_thread_local_packed_[i].store(true,
                                                  std::memory_order_relaxed);

          Index num_blocks = num_worker_threads * gm_;
          thread_local_pre_alocated_mem_ = kernel_.allocateSlices(
              device_, num_blocks, 0, 1, &lhs_thread_local_pre_allocated_,
              nullptr);
        }
      }
    }

    ~EvalParallelContext() {
      for (Index x = 0; x < P; x++) {
        for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
        delete[] state_kernel_[x];
      }
      kernel_.deallocate(device_, packed_mem_);
      if (parallelize_by_sharding_dim_only_) {
        kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
        delete[] can_use_thread_local_packed_;
      }
    }

    void run() {
      // Kick off packing of the first k slice.
      signal_switch(0, 1);

      // In sync mode this blocks until all tasks are done; in async mode it
      // returns immediately and the done callback is invoked (and the context
      // deleted) when the last task calls done_.Notify().
      done_.Wait();
    }

   private:
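    // Lifetime note: in async mode this context is heap allocated and owns
    // every piece of state the packing and kernel tasks touch (input mappers,
    // packed block memory, synchronization counters). Nothing below may be
    // accessed after done_.Notify() runs, because Notify() deletes the
    // context before invoking the user callback.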
    std::thread::id created_by_thread_id_;

    // Notifies the caller (sync mode) or deletes this context and calls the
    // done callback (async mode) once evaluation is complete.
    EvalParallelNotification<DoneCallback, EvalParallelContext> done_;

    const Device& device_;
    LhsMapper lhs_;
    RhsMapper rhs_;
    Scalar* const buffer_;
    OutputMapper output_;
    OutputKernelType output_kernel_;
    TensorContractionParams tensor_contraction_params_;
    const int num_threads_;
    const bool shard_by_col_;
    const bool parallel_pack_;
    const bool parallelize_by_sharding_dim_only_;

    // Problem sizes.
    const Index m_;
    const Index n_;
    const Index k_;
    // Block sizes.
    const Index bm_;
    const Index bn_;
    const Index bk_;
    // Number of tasks in each dimension.
    const Index nm_;
    const Index nn_;
    const Index nk_;
    // Task grain sizes (number of kernels executed per task).
    const Index gm_;
    const Index gn_;
    // Number of kernels (differs from nm_/nn_ because of task coarsening).
    const Index nm0_;
    const Index nn0_;

    TensorContractionKernel kernel_;

    // The contraction is computed as a sum over nk_ slices of the k
    // dimension. Packing and kernels of up to two consecutive slices can be
    // in flight at the same time, so per-slice state lives in circular
    // buffers: synchronization counters are indexed by (k % P) and packed
    // blocks by (k % (P - 1)).
    static const Index P = 3;

    // Handle to the temporary storage for the shared packed lhs/rhs blocks.
    BlockMemHandle packed_mem_;
    std::vector<LhsBlock> packed_lhs_[P - 1];
    std::vector<RhsBlock> packed_rhs_[P - 1];

    // If we parallelize only by the sharding dimension, each packing task
    // also runs, in the same thread, all kernels that consume its block. The
    // block can then be packed into per-thread ("thread local") storage so
    // the data stays hot in that thread's cache. Thread-local packing is
    // abandoned for a row (or column) as soon as a kernel from the previous
    // slice is observed to be still pending (see pack_lhs/pack_rhs).
    //
    // Pre-allocated thread-local memory; used only when
    // parallelize_by_sharding_dim_only_ == true.
    BlockMemHandle thread_local_pre_alocated_mem_;

    // Only one of these vectors is non-empty, depending on shard_by_col_.
    std::vector<LhsBlock> lhs_thread_local_pre_allocated_;
    std::vector<RhsBlock> rhs_thread_local_pre_allocated_;

    // Number of thread-local block initializations requested so far; the
    // first numThreadsInPool() of them reuse the pre-allocated memory.
    std::atomic<int> num_thread_local_allocations_;
    const int thread_local_capacity;

    // A per-thread set of lhs or rhs blocks: either a view into the
    // pre-allocated pool memory (grain_size_ blocks starting at a base
    // pointer) or separately allocated blocks that must be released through
    // the kernel when the thread-local storage is destroyed.
    template <typename BlockType>
    class ThreadLocalBlocks {
     public:
      ThreadLocalBlocks() = default;

      ThreadLocalBlocks(BlockType* base, size_t grain_size)
          : is_pre_allocated_(true),
            thread_local_pre_allocated_base_(base),
            grain_size_(grain_size) {}

      ThreadLocalBlocks(BlockMemHandle mem_handle,
                        std::vector<BlockType> blocks)
          : is_pre_allocated_(false),
            mem_handle_(std::move(mem_handle)),
            blocks_(std::move(blocks)) {}

      BlockType& block(int grain_index) {
        eigen_assert(grain_index >= 0);
        eigen_assert(static_cast<size_t>(grain_index) < size());
        return is_pre_allocated_ ? thread_local_pre_allocated_base_[grain_index]
                                 : blocks_[grain_index];
      }

      void Release(EvalParallelContext& ctx) const {
        if (!is_pre_allocated_) {
          ctx.kernel_.deallocate(ctx.device_, mem_handle_);
        }
      }

      size_t size() const {
        return is_pre_allocated_ ? grain_size_ : blocks_.size();
      }

     private:
      bool is_pre_allocated_;

      // Pre-allocated blocks (is_pre_allocated_ == true).
      BlockType* thread_local_pre_allocated_base_ = nullptr;
      size_t grain_size_ = 0;

      // Dynamically allocated blocks (is_pre_allocated_ == false).
      BlockMemHandle mem_handle_{};
      std::vector<BlockType> blocks_;
    };

    // Initializes ThreadLocalBlocks for the calling thread: the first
    // numThreadsInPool() initializations reuse the pre-allocated memory, any
    // later ones (e.g. a thread outside the pool picking up work) allocate
    // their own blocks through the kernel.
    template <typename BlockType, bool is_rhs>
    class ThreadLocalBlocksInitialize {
      static constexpr bool kIsLhs =
          !is_rhs && std::is_same<BlockType, LhsBlock>::value;
      static const bool kIsRhs =
          is_rhs && std::is_same<BlockType, RhsBlock>::value;
      static_assert(kIsLhs || kIsRhs, "Unknown block type");

      using Blocks = ThreadLocalBlocks<BlockType>;

     public:
      ThreadLocalBlocksInitialize(EvalParallelContext& ctx)
          : ctx_(ctx),
            num_worker_threads_(ctx_.device_.numThreadsInPool()) {}

      void operator()(Blocks& blocks) {
        const int n = ctx_.num_thread_local_allocations_.fetch_add(
            1, std::memory_order_relaxed);

        if (n >= num_worker_threads_) {
          ThreadLocalBlocksAllocator<is_rhs>::allocate(ctx_, blocks);
        } else {
          ThreadLocalBlocksAllocator<is_rhs>::reuse(ctx_, n, blocks);
        }
      }

     private:
      // The lhs/rhs cases are dispatched through partial specializations:
      // explicit (full) specializations are not allowed at class scope, so
      // the extra EvalCtx template parameter keeps these partial.
      template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
      struct ThreadLocalBlocksAllocator;

      template <typename EvalCtx>
      struct ThreadLocalBlocksAllocator</*pack_rhs=*/true, EvalCtx> {
        static void allocate(EvalCtx& ctx, Blocks& blocks) {
          std::vector<RhsBlock> rhs_blocks;
          BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
              ctx.device_, 0, ctx.gn_, 1, nullptr, &rhs_blocks);

          blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle),
                                               std::move(rhs_blocks));
        }

        static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
          RhsBlock* ptr = &ctx.rhs_thread_local_pre_allocated_[ctx.gn_ * index];
          blocks = ThreadLocalBlocks<RhsBlock>(ptr, ctx.gn_);
        }
      };

      template <typename EvalCtx>
      struct ThreadLocalBlocksAllocator</*pack_rhs=*/false, EvalCtx> {
        static void allocate(EvalCtx& ctx, Blocks& blocks) {
          std::vector<LhsBlock> lhs_blocks;
          BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
              ctx.device_, ctx.gm_, 0, 1, &lhs_blocks, nullptr);

          blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle),
                                               std::move(lhs_blocks));
        }

        static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
          LhsBlock* ptr = &ctx.lhs_thread_local_pre_allocated_[ctx.gm_ * index];
          blocks = ThreadLocalBlocks<LhsBlock>(ptr, ctx.gm_);
        }
      };

      EvalParallelContext& ctx_;
      const int num_worker_threads_;
    };

    // Releases thread-local blocks when the ThreadLocal container is
    // destroyed (only dynamically allocated blocks actually free memory).
    template <typename BlockType>
    class ThreadLocalBlocksRelease {
     public:
      using Blocks = ThreadLocalBlocks<BlockType>;
      ThreadLocalBlocksRelease(EvalParallelContext& ctx) : ctx_(ctx) {}
      void operator()(Blocks& blocks) { blocks.Release(ctx_); }

     private:
      EvalParallelContext& ctx_;
    };

    // ThreadLocalBlocks initialization callables.
    using ThreadLocalLhsInit =
        ThreadLocalBlocksInitialize<LhsBlock, /*is_rhs=*/false>;
    using ThreadLocalRhsInit =
        ThreadLocalBlocksInitialize<RhsBlock, /*is_rhs=*/true>;

    // ThreadLocalBlocks release callables.
    using ThreadLocalLhsRelease = ThreadLocalBlocksRelease<LhsBlock>;
    using ThreadLocalRhsRelease = ThreadLocalBlocksRelease<RhsBlock>;

    // Thread-local packed blocks; only the side that corresponds to the
    // sharding dimension is ever used (see the constructor).
    Eigen::ThreadLocal<ThreadLocalBlocks<LhsBlock>, ThreadLocalLhsInit,
                       ThreadLocalLhsRelease>
        lhs_thread_local_blocks_;
    Eigen::ThreadLocal<ThreadLocalBlocks<RhsBlock>, ThreadLocalRhsInit,
                       ThreadLocalRhsRelease>
        rhs_thread_local_blocks_;

    // One flag per row (or column) of tasks: once a kernel from the previous
    // k slice is observed to be still pending, thread-local packing is
    // disabled for that row/column for all remaining slices.
    std::atomic<bool>* can_use_thread_local_packed_;

    // Synchronization state, one entry per in-flight k slice (indexed k % P):
    //  - state_kernel_[x][m][n] counts notifications kernel (m, n) of slice x
    //    still needs (packed blocks ready, previous-slice kernel done) before
    //    it can run;
    //  - state_packing_ready_[x] counts packing tasks that must finish before
    //    packing of the other operand is enqueued (sequential packing only);
    //  - state_switch_[x] counts outstanding packing tasks of slice x - 1 and
    //    kernels of slice x - 2; when it reaches zero, packing for slice x is
    //    enqueued.
    std::atomic<uint8_t>** state_kernel_[P];

    // state_switch_ and state_packing_ready_ are modified by worker threads
    // throughout the computation, while most fields above are read-only after
    // construction; the padding keeps them on a separate cache line to reduce
    // cache-coherency traffic.
    char pad_[128];
    std::atomic<Index> state_packing_ready_[P];
    std::atomic<Index> state_switch_[P];

    LhsBlock& packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
      if (use_thread_local) {
        eigen_assert(!shard_by_col_);
        ThreadLocalBlocks<LhsBlock>& blocks = lhs_thread_local_blocks_.local();

        Index grain_index = m1 - m * gm_;
        return blocks.block(internal::convert_index<int>(grain_index));
      } else {
        return packed_lhs_[k % (P - 1)][m1];
      }
    }

    RhsBlock& packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
      if (use_thread_local) {
        eigen_assert(shard_by_col_);
        ThreadLocalBlocks<RhsBlock>& blocks = rhs_thread_local_blocks_.local();

        Index grain_index = n1 - n * gn_;
        return blocks.block(internal::convert_index<int>(grain_index));
      } else {
        return packed_rhs_[k % (P - 1)][n1];
      }
    }

    // In parallelize_by_sharding_dim_only_ mode a block may be packed into
    // thread-local storage, but only while it is guaranteed that every kernel
    // consuming it will run synchronously in the same packing task (and thus
    // in the same thread). If a kernel from the previous slice has not
    // finished yet, that guarantee no longer holds and the shared packed
    // storage is used for this row (or column) from now on.
    void pack_lhs(Index m, Index k) {
      bool use_thread_local = false;

      if (parallelize_by_sharding_dim_only_ && !shard_by_col_ &&
          can_use_thread_local_packed_[m].load(std::memory_order_relaxed)) {
        if (state_kernel_[k % P][m][0].load(std::memory_order_relaxed) == 1) {
          use_thread_local = true;
        } else {
          // If we can't guarantee that all kernels in slice `k` will be
          // executed sequentially in the current thread, it is no longer safe
          // to use thread-local memory in the following slices along the k
          // dimension.
          eigen_assert(k > 0);
          can_use_thread_local_packed_[m].store(false,
                                                std::memory_order_relaxed);
        }
      }

      const Index mend = m * gm_ + gm(m);
      for (Index m1 = m * gm_; m1 < mend; m1++)
        kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local),
                        lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));

      if (!parallel_pack_ && shard_by_col_) {
        assert(!use_thread_local);
        signal_packing(k);
      } else {
        signal_switch(k + 1);
        for (Index n = nn_ - 1; n >= 0; n--) {
          bool sync = parallelize_by_sharding_dim_only_ || n == 0;
          signal_kernel(m, n, k, sync, use_thread_local);
        }
      }
    }

    void pack_rhs(Index n, Index k) {
      bool use_thread_local = false;

      if (parallelize_by_sharding_dim_only_ && shard_by_col_ &&
          can_use_thread_local_packed_[n].load(std::memory_order_relaxed)) {
        if (state_kernel_[k % P][0][n].load(std::memory_order_relaxed) == 1) {
          use_thread_local = true;
        } else {
          // If we can't guarantee that all kernels in slice `k` will be
          // executed sequentially in the current thread, it is no longer safe
          // to use thread-local memory in the following slices along the k
          // dimension.
          eigen_assert(k > 0);
          can_use_thread_local_packed_[n].store(false,
                                                std::memory_order_relaxed);
        }
      }

      const Index nend = n * gn_ + gn(n);
      for (Index n1 = n * gn_; n1 < nend; n1++) {
        if (!TensorContractionKernel::HasBeta && k == 0) {
          // Zero the output memory in parallel, but only once: when packing
          // the rhs for the first k slice. A GEBP kernel without beta support
          // always accumulates into the output buffer, so the buffer must
          // start from zeros.
          memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
        }
        kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
                        rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
      }

      if (parallel_pack_ || shard_by_col_) {
        signal_switch(k + 1);
        for (Index m = nm_ - 1; m >= 0; m--) {
          bool sync = parallelize_by_sharding_dim_only_ || m == 0;
          signal_kernel(m, n, k, sync, use_thread_local);
        }
      } else {
        assert(!use_thread_local);
        signal_packing(k);
      }
    }

    void kernel(Index m, Index n, Index k, bool use_thread_local) {
      // Loop order matters: the operand of the sharding dimension stays fixed
      // in the outer loop, so its packed block is reused by all kernels of
      // the inner loop.
      const Index nend = n * gn_ + gn(n);
      const Index mend = m * gm_ + gm(m);

      // On the first k slice (and only if the kernel supports beta) overwrite
      // the output; afterwards accumulate into it.
      const Scalar alpha = Scalar(1);
      const Scalar beta =
          (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);

      if (shard_by_col_) {
        for (Index n1 = n * gn_; n1 < nend; n1++) {
          for (Index m1 = m * gm_; m1 < mend; m1++) {
            const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
            kernel_.invoke(
                output_mapper,
                packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
                bk(k), bn(n1), alpha, beta);

            // We are done with the last task for the [m1, n1] block.
            if (k + 1 == nk_) {
              output_kernel_(output_mapper, tensor_contraction_params_,
                             m1 * bm_, n1 * bn_, bm(m1), bn(n1));
            }
          }
        }
      } else {
        for (Index m1 = m * gm_; m1 < mend; m1++)
          for (Index n1 = n * gn_; n1 < nend; n1++) {
            const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
            kernel_.invoke(
                output_mapper,
                packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
                bk(k), bn(n1), alpha, beta);

            // We are done with the last task for the [m1, n1] block.
            if (k + 1 == nk_) {
              output_kernel_(output_mapper, tensor_contraction_params_,
                             m1 * bm_, n1 * bn_, bm(m1), bn(n1));
            }
          }
      }
      signal_kernel(m, n, k + 1, /*sync=*/false, /*use_thread_local=*/false);
      signal_switch(k + 2);
    }
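
    // Per-slice task flow (sequential packing, shard_by_col shown):
    //   signal_switch(k)  -> enqueue_packing of lhs blocks for slice k
    //   pack_lhs(m, k)    -> signal_packing(k); the last lhs pack enqueues
    //                        rhs packing for the same slice
    //   pack_rhs(n, k)    -> signal_switch(k + 1) and signal_kernel(m, n, k)
    //                        for every m
    //   kernel(m, n, k)   -> signal_kernel(m, n, k + 1) and
    //                        signal_switch(k + 2)
    // With parallel packing both operands are packed right away and every
    // kernel waits for its two packed blocks instead.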

    void signal_packing(Index k) {
      eigen_assert(!parallel_pack_);
      Index s = state_packing_ready_[k % P].fetch_sub(1);
      eigen_assert(s > 0);
      if (s != 1) return;
      state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
      enqueue_packing(k, shard_by_col_);
    }

    void signal_kernel(Index m, Index n, Index k, bool sync,
                       bool use_thread_local) {
      std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
      Index s = state->load();
      eigen_assert(s > 0);
      if (s != 1 && state->fetch_sub(1) != 1) {
        eigen_assert(!use_thread_local);
        return;
      }
      state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
      if (sync) {
        kernel(m, n, k, use_thread_local);
      } else {
        eigen_assert(!use_thread_local);
        device_.enqueueNoNotification(
            [=]() { kernel(m, n, k, use_thread_local); });
      }
    }

    void signal_switch(Index k, Index v = 1) {
      Index s = state_switch_[k % P].fetch_sub(v);
      eigen_assert(s >= v);
      if (s != v) return;

      // Ready to switch to the next k slice: reset the counter for the slice
      // that will reuse this state entry.
      state_switch_[k % P] =
          (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) +
          nm_ * nn_;
      if (k < nk_) {
        // Issue lhs/rhs packing for the new slice. Packing tasks in turn kick
        // off the kernels that consume the packed blocks.
        if (parallel_pack_) {
          enqueue_packing(k, !shard_by_col_);
          enqueue_packing(k, shard_by_col_);
        } else if (shard_by_col_) {
          enqueue_packing(k, false);
        } else {
          enqueue_packing(k, true);
        }
      } else if (k == nk_) {
        // There is no packing for slice nk_; forward the notifications it
        // would have produced so that the final switch (k == nk_ + 1) can
        // still fire.
        signal_switch(k + 1,
                      parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
      } else {
        // All slices are done: notify the caller (sync mode) or delete this
        // context and invoke the done callback (async mode).
        done_.Notify();
      }
    }

    // Enqueue all packing tasks for the given slice of the lhs (rhs == false)
    // or rhs (rhs == true) operand, recursively splitting the index range.
    void enqueue_packing(Index k, bool rhs) {
      enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs);
    }

    void enqueue_packing_helper(Index start, Index end, Index k, bool rhs) {
      if (end - start == 1) {
        if (rhs)
          pack_rhs(start, k);
        else
          pack_lhs(start, k);
      } else {
        while (end - start > 1) {
          Index mid = (start + end) / 2;
          device_.enqueueNoNotification(
              [=]() { enqueue_packing_helper(mid, end, k, rhs); });
          end = mid;
        }

        // Decide whether the remaining first block is packed right here or
        // offloaded to the pool. Only the very first block (start == 0) is
        // ever offloaded, and only in parallelize_by_sharding_dim_only_ mode
        // when packing the sharding-dimension operand; the k == 0 packing
        // issued from a thread other than the context creator runs inline.
        bool pack_async =
            (start == 0) &&
            (parallelize_by_sharding_dim_only_ && shard_by_col_ == rhs) &&
            (k > 0 || std::this_thread::get_id() == created_by_thread_id_);

        if (pack_async) {
          device_.enqueueNoNotification(
              [=]() { enqueue_packing_helper(start, end, k, rhs); });
        } else {
          enqueue_packing_helper(start, end, k, rhs);
        }
      }
    }

    // Block sizes with accounting for potentially incomplete last block.
    Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
    Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
    Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }

    // Task grain sizes with accounting for potentially incomplete last task.
    Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
    Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }

    EvalParallelContext(const EvalParallelContext&) = delete;
    void operator=(const EvalParallelContext&) = delete;
  };

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  using SyncEvalParallelContext =
      EvalParallelContext<NoCallback, lhs_inner_dim_contiguous,
                          rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
                          Alignment>;
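
  // Synchronous evaluation reuses EvalParallelContext with the NoCallback
  // sentinel: done_ then wraps an Eigen::Notification that run() blocks on,
  // and the context itself lives on the caller's stack.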

  // EvalShardedByInnerDimContext orchestrates sync/async evaluation when the
  // contraction is sharded over the inner (k) dimension: each task computes a
  // partial m x n GEMM over its own k block into a private buffer, and the
  // partial results are then summed into the final output.
  template <typename DoneCallback>
  struct EvalShardedByInnerDimContext {
    EvalShardedByInnerDimContext(const Self* self, int num_threads,
                                 Scalar* result_buffer,
                                 Index m_size, Index n_size, Index k_size,
                                 DoneCallback done_callback)
        : evaluator(self),
          m_lhs_inner_dim_contiguous(evaluator->m_lhs_inner_dim_contiguous),
          m_rhs_inner_dim_contiguous(evaluator->m_rhs_inner_dim_contiguous),
          m_rhs_inner_dim_reordered(evaluator->m_rhs_inner_dim_reordered),
          result(result_buffer),
          m(m_size),
          n(n_size),
          k(k_size),
          done(std::move(done_callback)),
          buffer_size_bytes(m * n * sizeof(Scalar)),
          block_size(blockSize(k, num_threads)),
          num_blocks(divup<Index>(k, block_size)),
          num_pending_blocks(internal::convert_index<int>(num_blocks)),
          l0_ranges(divup<Index>(num_blocks, l0_size)),
          l0_state(l0_ranges),
          block_buffers(num_blocks) {
      // Keep count of pending gemm tasks for each l0 range.
      for (int i = 0; i < l0_ranges; ++i) {
        const Index num_pending_tasks = actualRangeSize(l0_ranges, l0_size, i);
        l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks));
      }

      // Allocate temporary buffers for each block (the first block writes
      // directly into the final result buffer).
      for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
        Scalar* buf = block_idx == 0
                          ? result
                          : static_cast<Scalar*>(evaluator->m_device.allocate(
                                buffer_size_bytes));
        block_buffers.emplace_back(buf);
      }
    }

    ~EvalShardedByInnerDimContext() {
      for (Index i = 1; i < num_blocks; ++i) {
        evaluator->m_device.deallocate(block_buffers[i]);
      }
    }

    template <int Alignment>
    void run() {
      Barrier barrier(internal::convert_index<int>(num_blocks));
      eval<Alignment>(barrier, 0, num_blocks);
      barrier.Wait();

      // Aggregate partial sums from l0 ranges.
      aggregateL0Blocks<Alignment>();

      // Apply output kernel.
      applyOutputKernel();
    }

    template <int Alignment>
    void runAsync() {
      evalAsync<Alignment>(0, num_blocks);
    }
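
    // For example, with num_blocks == 10 and l0_size == 4 there are three l0
    // ranges of sizes 4, 4 and 2. When the last block of a range finishes,
    // its partial results are summed into the first buffer of that range, so
    // at the end only block_buffers[0], block_buffers[4] and block_buffers[8]
    // remain to be added into the final result (block_buffers[0] already is
    // the result buffer).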

   private:
    static const Index packet_size = internal::packet_traits<RhsScalar>::size;

    const Self* evaluator;  // TensorContraction evaluator

    // These fields are required by the eval* functions below.
    bool m_lhs_inner_dim_contiguous;
    bool m_rhs_inner_dim_contiguous;
    bool m_rhs_inner_dim_reordered;

    Scalar* result;

    Index m;
    Index n;
    Index k;

    DoneCallback done;

    // Size of a single partial-result buffer (a full m x n output).
    Index buffer_size_bytes;

    Index block_size;
    Index num_blocks;

    // Keeps track of pending tasks when evaluating in async mode.
    std::atomic<int> num_pending_blocks;

    // Partial results are computed in parallel into per-block buffers, and
    // adding them all together at the end would be an expensive sequential
    // step. To avoid that, blocks are grouped into "l0 ranges" of l0_size
    // consecutive blocks: the task that finishes last in a range sums the
    // range's buffers into the range's first buffer, so the final reduction
    // only has to add one buffer per range.
    static const Index l0_size = 4;
    Index l0_ranges;

    // Pending blocks for each l0 range.
    MaxSizeVector<std::atomic<int>> l0_state;

    // Buffers allocated for each block; block_buffers[0] is the result buffer.
    MaxSizeVector<Scalar*> block_buffers;

    template <int Alignment>
    void processBlock(Index block_idx, Index begin, Index end) {
      Scalar* buf = block_buffers[block_idx];

      TENSOR_CONTRACTION_DISPATCH(
          evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
          (buf, begin, end, internal::convert_index<int>(num_blocks)));

      // Check if this was the last task in its l0 range.
      const Index l0_index = block_idx / l0_size;
      const int v = l0_state[l0_index].fetch_sub(1);
      eigen_assert(v >= 1);

      // If it was the last task in the l0 range, aggregate all the partial
      // results of the range into its first block.
      if (v == 1) {
        const Index rng_size = actualRangeSize(l0_ranges, l0_size, l0_index);
        const Index dst_block_idx = l0_index * l0_size;

        if (rng_size == l0_size) {
          addAllToBuffer<Alignment>(
              m * n,
              block_buffers[dst_block_idx + 1],
              block_buffers[dst_block_idx + 2],
              block_buffers[dst_block_idx + 3],
              block_buffers[dst_block_idx]);
        } else {
          // Aggregate a potentially incomplete last range.
          for (int i = 1; i < rng_size; ++i) {
            addToBuffer<Alignment>(m * n,
                                   block_buffers[dst_block_idx + i],
                                   block_buffers[dst_block_idx]);
          }
        }
      }
    }

    // Aggregate the first buffer of every l0 range into the final result
    // buffer (three ranges at a time, then whatever remains).
    template <int Alignment>
    void aggregateL0Blocks() const {
      Index l0_index = 1;

      for (; l0_index + 2 < l0_ranges; l0_index += 3) {
        addAllToBuffer<Alignment>(
            m * n,
            block_buffers[(l0_index + 0) * l0_size],
            block_buffers[(l0_index + 1) * l0_size],
            block_buffers[(l0_index + 2) * l0_size],
            block_buffers[0]);
      }

      for (; l0_index < l0_ranges; ++l0_index) {
        addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
                               block_buffers[0]);
      }
    }

    void applyOutputKernel() const {
      typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
      evaluator->m_output_kernel(
          OutputMapper(result, m), evaluator->m_tensor_contraction_params,
          static_cast<Eigen::Index>(0), static_cast<Eigen::Index>(0), m, n);
    }

    // Compute block size with accounting for a potentially incomplete last
    // block.
    Index actualBlockSize(Index block_idx) const {
      return block_idx + 1 < num_blocks
                 ? block_size
                 : k + block_size - block_size * num_blocks;
    };

    // Compute range size with accounting for a potentially incomplete last
    // range.
    Index actualRangeSize(Index num_ranges, Index range_size,
                          Index range_idx) const {
      eigen_assert(range_idx < num_ranges);
      return range_idx + 1 < num_ranges
                 ? range_size
                 : num_blocks + range_size - range_size * num_ranges;
    };

    template <int Alignment>
    EIGEN_STRONG_INLINE static void addToBuffer(size_t n, const Scalar* src_buf,
                                                Scalar* tgt_buf) {
      const int output_packet_size =
          internal::unpacket_traits<PacketReturnType>::size;
      size_t i = 0;
      const size_t num_packets = n / output_packet_size;
      for (; i < output_packet_size * num_packets; i += output_packet_size) {
        const PacketReturnType src_val =
            internal::pload<PacketReturnType>(src_buf + i);
        const PacketReturnType tgt_val =
            internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
        const PacketReturnType sum = internal::padd(src_val, tgt_val);
        internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i,
                                                               sum);
      }
      for (; i < n; ++i) {
        tgt_buf[i] += src_buf[i];
      }
    }

    template <int Alignment>
    EIGEN_STRONG_INLINE static void addAllToBuffer(size_t n,
                                                   const Scalar* src_buf0,
                                                   const Scalar* src_buf1,
                                                   const Scalar* src_buf2,
                                                   Scalar* dst_buf) {
      using ::Eigen::internal::padd;
      using ::Eigen::internal::pload;
      using ::Eigen::internal::ploadt;
      using ::Eigen::internal::pstoret;

      const int output_packet_size =
          internal::unpacket_traits<PacketReturnType>::size;

      size_t i = 0;
      const size_t num_packets = n / output_packet_size;
      for (; i < output_packet_size * num_packets; i += output_packet_size) {
        const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
        const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
        const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);

        const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
        const auto sum =
            padd(padd(dst_val, src_val0), padd(src_val1, src_val2));

        pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
      }
      for (; i < n; ++i) {
        dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
      }
    }

    template <int Alignment>
    void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
      // Split the remaining range in half and offload the upper half to the
      // pool until a single block is left for this thread.
      while (end_block_idx - start_block_idx > 1) {
        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
        evaluator->m_device.enqueueNoNotification(
            [this, &barrier, mid_block_idx, end_block_idx]() {
              eval<Alignment>(barrier, mid_block_idx, end_block_idx);
            });
        end_block_idx = mid_block_idx;
      }

      Index block_idx = start_block_idx;
      Index block_start = block_idx * block_size;
      Index block_end = block_start + actualBlockSize(block_idx);

      processBlock<Alignment>(block_idx, block_start, block_end);
      barrier.Notify();
    }

    template <int Alignment>
    void evalAsync(Index start_block_idx, Index end_block_idx) {
      while (end_block_idx - start_block_idx > 1) {
        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
        evaluator->m_device.enqueueNoNotification(
            [this, mid_block_idx, end_block_idx]() {
              evalAsync<Alignment>(mid_block_idx, end_block_idx);
            });
        end_block_idx = mid_block_idx;
      }

      Index block_idx = start_block_idx;

      Index block_start = block_idx * block_size;
      Index block_end = block_start + actualBlockSize(block_idx);

      processBlock<Alignment>(block_idx, block_start, block_end);

      int v = num_pending_blocks.fetch_sub(1);
      eigen_assert(v >= 1);

      if (v == 1) {
        // Aggregate partial sums from l0 ranges.
        aggregateL0Blocks<Alignment>();

        // Apply output kernel.
        applyOutputKernel();

        // NOTE: This context will be destroyed before the done callback is
        // invoked: first make a local copy of the callback, then delete the
        // context, and only then call the callback.
        DoneCallback done_copy = std::move(done);

        // Delete this context.
        delete this;

        // Invoke the callback.
        done_copy();
      }
    }

    // Divide k among num_threads, round the block size up to a multiple of
    // the packet size, and clamp it from below so that per-block overhead
    // stays reasonable.
    static Index blockSize(Index k, int num_threads) {
      const auto round_up = [=](Index index) -> Index {
        const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
        return divup<Index>(index, kmultiple) * kmultiple;
      };

      const Index target_block_size = round_up(divup<Index>(k, num_threads));
      const Index desired_min_block_size = 12 * packet_size;

      return numext::mini<Index>(
          k, numext::maxi<Index>(desired_min_block_size, target_block_size));
    }
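
    // Illustrative example: with k == 1000, num_threads == 8 and
    // packet_size == 4, the target block size is divup(1000, 8) == 125,
    // rounded up to a multiple of 8 gives 128; that is larger than the
    // desired minimum of 12 * 4 == 48, so blockSize returns 128 and the
    // contraction is split into divup(1000, 128) == 8 blocks along k.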

    EvalShardedByInnerDimContext(const EvalShardedByInnerDimContext&) = delete;
    void operator=(const EvalShardedByInnerDimContext&) = delete;
  };

  // Decide whether to shard the contraction by columns (the default) or by
  // rows. Note: both n and m are compared against Traits::nr on purpose: we
  // are checking how each of them would fill the main sharding dimension.
  static bool shardByCol(Index m, Index n, Index num_threads) {
    // Prefer sharding by row when there are enough rows per thread...
    if (m / num_threads >= Traits::nr &&
        // ...and too few columns per thread...
        (n / num_threads < Traits::nr ||
         // ...or barely enough columns, which do not divide evenly across
         // threads...
         (n / num_threads < 4 * Traits::nr &&
          (n % (num_threads * Traits::nr)) != 0 &&
          // ...while the rows divide evenly...
          ((m % (num_threads * Traits::nr)) == 0 ||
           // ...or there are far more rows than columns.
           (m / n >= 6)))))
      return false;
    // Also shard by row for strongly skewed matrices with few columns per
    // thread.
    if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
    return true;
  }
1482
1483 Index coarsenM(Index m, Index n, Index bm, Index bn, Index bk, Index gn,
1484 int num_threads, bool shard_by_col) const {
1485 Index gm = 1;
1486 Index gm1 = 1;
1487 Index nm0 = divup(m, bm);
1488 Index nm1 = nm0;
1489 for (;;) {
1490
1491
1492
1493 while (gm1 <= nm0 && nm1 == divup(nm0, gm1)) gm1++;
1494 if (gm1 > nm0) break;
1495
1496 int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads,
1497 shard_by_col);
1498 if (res < 0) break;
1499 nm1 = divup(nm0, gm1);
1500 if (res == 0) continue;
1501
1502 gm = gm1;
1503 }
1504 return gm;
1505 }
1506
1507 Index coarsenN(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
1508 int num_threads, bool shard_by_col) const {
1509 Index gn = 1;
1510 Index gn1 = 1;
1511 Index nn0 = divup(n, bn);
1512 Index nn1 = nn0;
1513 for (;;) {
1514 while (gn1 <= nn0 && nn1 == divup(nn0, gn1)) gn1++;
1515 if (gn1 > nn0) break;
1516 int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads,
1517 shard_by_col);
1518 if (res < 0) break;
1519 nn1 = divup(nn0, gn1);
1520 if (res == 0) continue;
1521 gn = gn1;
1522 }
1523 return gn;
1524 }

  // Checks whether the grain (gm, gn) is acceptable and preferable to
  // (oldgm, oldgn). Returns 1 to accept, 0 to keep searching, -1 to stop
  // (the task would be too large).
  int checkGrain(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
                 Index gn, Index oldgm, Index oldgn, int num_threads,
                 bool shard_by_col) const {
    const TensorOpCost cost =
        contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
    double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(
        static_cast<double>(bm) * gm * bn * gn, cost);
    // If the task is too small, accept the coarser grain regardless of
    // anything else; otherwise synchronization overheads will dominate.
    if (taskSize < 1) return 1;
    // If the task is too large, reject it and all larger grains.
    if (taskSize > 2) return -1;
    // Within the acceptable size range the deciding factor is thread
    // utilization, modeled as new_tasks / (divup(new_tasks, num_threads) *
    // num_threads): prefer the grain that keeps it highest.
    Index nm0 = divup(m, bm);
    Index nn0 = divup(n, bn);
    Index new_tasks = divup(nm0, gm) * divup(nn0, gn);
    double new_parallelism = static_cast<double>(new_tasks) /
                             (divup<int>(new_tasks, num_threads) * num_threads);
    Index old_tasks = divup(nm0, oldgm) * divup(nn0, oldgn);
    double old_parallelism = static_cast<double>(old_tasks) /
                             (divup<int>(old_tasks, num_threads) * num_threads);
    if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
    return 0;
  }

  TensorOpCost contractionCost(Index m, Index n, Index bm, Index bn, Index bk,
                               bool shard_by_col, bool prepacked) const {
    const int packed_size = std::min<int>(PacketType<LhsScalar, Device>::size,
                                          PacketType<RhsScalar, Device>::size);
    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
    const double kd = static_cast<double>(bk);
    double compute_bandwidth = computeBandwidth(false, bm, bn, bk);
    // Computations.
    TensorOpCost cost = TensorOpCost(0, 0, kd * compute_bandwidth, true, packed_size);
    // Output stores.
    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
    if (prepacked) {
      // Packing and kernels are executed in different tasks. When calculating
      // the task grain size we look only at the kernel cost, assuming the
      // kernel is executed on already packed data.
      return cost;
    }
    // Lhs/rhs loads + computations.
    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
    // Drop the memory cost of packing the lhs (when sharding by column) or
    // the rhs; it does not contribute much to the overall execution time.
    if (shard_by_col)
      lhsCost.dropMemoryCost();
    else
      rhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
  }

  // Decide whether to shard the m x k x n contraction over the inner
  // (contraction) dimension k.
  static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
                              int num_threads_by_k) {
    std::ptrdiff_t bufsize = m * n * sizeof(Scalar);
    bool shard_by_k = false;
    if (n == 1 ||                // this is a matrix-vector product, or
        num_threads_by_k < 2 ||  // sharding by k would be single threaded, or
        num_threads_by_k <
            num_threads ||  // sharding by k gives less parallelism, or
        bufsize > l3CacheSize() / num_threads_by_k ||  // the per-thread buffer
                                                       // does not fit in L3, or
        k / num_threads_by_k < 2 * Traits::nr) {  // k per thread is too small.
      shard_by_k = false;
    } else if (numext::maxi(m, n) / num_threads <
                   Traits::nr ||  // both m and n are tiny, or
               // k per thread is large while m or n is small, or sharding by
               // k still gives more parallelism.
               (k / num_threads_by_k > 8 * Traits::nr &&
                (numext::mini(m, n) < 2 * Traits::nr ||
                 num_threads_by_k > num_threads))) {
      shard_by_k = true;
    }
    return shard_by_k;
  }

  TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {
    // Compute cost.
    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
    TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n, true, output_packet_size);
    // Output stores.
    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * m;
    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * n;
    // Drop the memory cost of loading the lhs; it is reused across all n
    // columns of the output.
    lhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
  }

  int numThreadsInnerDim(Index m, Index n, Index k) const {
    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
    TensorOpCost cost = contractionCostPerInnerDim(m, n, k);
    double total_parallel_cost =
        TensorCostModel<ThreadPoolDevice>::totalCost(k, cost);
    // Cost of the reduction step that accumulates the m*n per-thread buffers
    // into the result.
    double reduction_cost = TensorCostModel<ThreadPoolDevice>::totalCost(
        m * n, TensorOpCost(2, 1, 1, true, output_packet_size));
    int num_threads = 1;
    double min_cost = total_parallel_cost;
    double kPerThreadOverHead = 3000;
    double kFixedOverHead = 100000;
    for (int nt = 2; nt <= this->m_device.numThreads(); nt += 2) {
      double sequential_cost =
          kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
      double parallel_cost = total_parallel_cost / nt + sequential_cost;
      if (parallel_cost < min_cost) {
        num_threads = nt;
        min_cost = parallel_cost;
      }
    }
    return num_threads;
  }

  double computeBandwidth(bool shard_by_col, Index bm, Index bn,
                          Index bk) const {
    // Estimated compute cost (in cycles) per scalar multiply-add: large
    // kernels approach the FMA peak of 0.5, while small or degenerate blocks
    // are dominated by latency and bookkeeping.
    double computeBandwidth =
        bk == 1 ? 4.0
                : (shard_by_col ? bn : bm) < Traits::nr ||
                          (shard_by_col ? bm : bn) < Traits::mr
                      ? 2.0
                      : 0.5;
#ifndef EIGEN_VECTORIZE_FMA
    // Without FMA the multiply and the add are issued as two dependent
    // instructions, so the peak throughput is roughly halved.
    if (computeBandwidth == 0.5) computeBandwidth = 1.0;
#endif
    return computeBandwidth;
  }

};

}  // end namespace Eigen

#endif  // EIGEN_USE_THREADS
#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H