0001 // This file is part of Eigen, a lightweight C++ template library
0002 // for linear algebra.
0003 //
0004 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
0005 //
0006 // This Source Code Form is subject to the terms of the Mozilla
0007 // Public License v. 2.0. If a copy of the MPL was not distributed
0008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
0009 
0010 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
0011 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
0012 
0013 // evaluator for thread pool device
0014 #ifdef EIGEN_USE_THREADS
0015 
0016 namespace Eigen {
0017 
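// Illustrative usage sketch (not part of this file): a tensor contraction is
// routed to this evaluator when the expression is assigned on a
// ThreadPoolDevice. The pool size and tensor shapes below are arbitrary
// example values.
//
//   Eigen::ThreadPool pool(8);
//   Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/8);
//   Eigen::Tensor<float, 2> a(512, 256), b(256, 128), c(512, 128);
//   a.setRandom(); b.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
//   c.device(device) = a.contract(b, dims);  // evaluated by this evaluator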
0018 template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
0019 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> :
0020     public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> > {
0021 
0022   typedef ThreadPoolDevice Device;
0023 
0024   typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
0025   typedef TensorContractionEvaluatorBase<Self> Base;
0026 
0027   typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
0028   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
0029   typedef typename XprType::Index Index;
0030   typedef typename XprType::CoeffReturnType CoeffReturnType;
0031   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
0032 
0033   enum {
0034     Layout = TensorEvaluator<LeftArgType, Device>::Layout,
0035   };
0036 
0037   // Most of the code assumes that both input tensors are ColMajor. If the
0038   // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
0039   // If we want to compute A * B = C, where A is LHS and B is RHS, the code
0040   // will pretend B is LHS and A is RHS.
0041   typedef typename internal::conditional<
0042     static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
0043   typedef typename internal::conditional<
0044     static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
0045 
0046   static const int LDims =
0047       internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
0048   static const int RDims =
0049       internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
0050   static const int ContractDims = internal::array_size<Indices>::value;
0051 
0052   typedef array<Index, LDims> left_dim_mapper_t;
0053   typedef array<Index, RDims> right_dim_mapper_t;
0054 
0055   typedef array<Index, ContractDims> contract_t;
0056   typedef array<Index, LDims - ContractDims> left_nocontract_t;
0057   typedef array<Index, RDims - ContractDims> right_nocontract_t;
0058 
0059   static const int NumDims = LDims + RDims - 2 * ContractDims;
0060 
0061   typedef DSizes<Index, NumDims> Dimensions;
0062 
0063   // typedefs needed in evalTo
0064   typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
0065   typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
0066   typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
0067 
0068   typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
0069   typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
0070 
0071   TensorEvaluator(const XprType& op, const Device& device) :
0072       Base(op, device) {}
0073 
0074   template <int Alignment>
0075   void evalProduct(Scalar* buffer) const {
0076     evalProductImpl<NoCallback, Alignment>(buffer, NoCallback());
0077   }
0078 
0079   template <typename EvalToCallback, int Alignment>
0080   void evalProductAsync(Scalar* buffer, EvalToCallback done) const {
0081     evalProductImpl<EvalToCallback, Alignment>(buffer, std::move(done));
0082   }
0083 
0084   template <typename DoneCallback, int Alignment>
0085   void evalProductImpl(Scalar* buffer, DoneCallback done) const {
0086     // This function computes a lot of heuristics in multiple steps, and it
0087     // also has multiple exit points. To keep it sane, readable and all in one
0088     // place, the sync/async execution decision is made at runtime at the very end.
0089     //
0090     // (1) In sync mode we allocate Context on the stack, submit computations
0091     //     to the device thread pool, and block on a barrier until it is
0092     //     completed.
0093     //
0094     // (2) In async mode we allocate Context on the heap, and after all tasks
0095     //     are finished, we call the provided done callback, and delete the
0096     //     context from the heap.
0097     //
0098     // (*) EvalParallelContext & EvalShardedByInnerDimContext own all the state
0099     // and temporary buffers required for executing the tensor contraction.
0100     // They are responsible for cleaning it up after the contraction is done.
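    // For example, the synchronous evalProduct() above instantiates this method
    // with DoneCallback = NoCallback and passes NoCallback(), while the
    // asynchronous evalProductAsync() forwards the user-provided done callback;
    // IsEvalInSyncMode below keys off that type difference.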
0101     static const bool IsEvalInSyncMode =
0102         std::is_same<DoneCallback, NoCallback>::value;
0103 
0104     const Index m = this->m_i_size;
0105     const Index n = this->m_j_size;
0106     const Index k = this->m_k_size;
0107     if (m == 0 || n == 0 || k == 0) return;
0108 
0109     // Compute a set of algorithm parameters:
0110     // - kernel block sizes (bm, bn, bk)
0111     // - task grain sizes (number of kernels executed per task: gm, gn)
0112     // - number of threads
0113     // - sharding by row/column
0114     // - parallel packing or first lhs then rhs
0115     // and some derived parameters:
0116     // - number of tasks (nm, nn, nk)
0117     // - number of kernels (nm0, nn0)
0118     // Unfortunately, all these parameters are tightly interdependent.
0119     // So in some cases we first compute approximate values, then compute other
0120     // values based on these approximations and then refine the approximations.
0121 
0122     // There are lots of heuristics here. There is some reasoning behind them,
0123     // but ultimately they are just tuned on contraction benchmarks for
0124     // different input configurations, thread counts and instruction sets.
0125     // So feel free to question any of them.
0126 
0127     // Compute whether we want to shard by row or by column.
0128     // This is a first approximation; it will be refined later. Since we don't
0129     // know the number of threads yet, we use 2, because what we are most
0130     // interested in at this point is whether it makes sense to use
0131     // parallelization at all or not.
0132     bool shard_by_col = shardByCol(m, n, 2);
0133 
0134     // First approximation of kernel blocking sizes.
0135     // Again, we don't know the number of threads yet, so we use 2.
0136     Index bm, bn, bk;
0137     if (shard_by_col) {
0138       internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
0139                                           internal::ShardByCol>
0140           blocking(k, m, n, 2);
0141       bm = blocking.mc();
0142       bn = blocking.nc();
0143       bk = blocking.kc();
0144     } else {
0145       internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
0146                                           internal::ShardByRow>
0147           blocking(k, m, n, 2);
0148       bm = blocking.mc();
0149       bn = blocking.nc();
0150       bk = blocking.kc();
0151     }
0152 
0153     // Compute optimal number of threads.
0154     // Note: we use bk instead of k here because we are interested in the amount
0155     // of _parallelizable_ computation, and computation is not parallelizable
0156     // across the k dimension.
0157     const TensorOpCost cost =
0158         contractionCost(m, n, bm, bn, bk, shard_by_col, false);
0159     int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
0160         static_cast<double>(n) * m, cost, this->m_device.numThreads());
0161     int num_threads_by_k = numThreadsInnerDim(m, n, k);
0162     if (shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) {
0163       // We are in the scenario where it is more effective to shard by the
0164       // inner dimension.
0165       if (IsEvalInSyncMode) {
0166         EvalShardedByInnerDimContext<DoneCallback> ctx(
0167             this, num_threads_by_k, buffer, m, n, k, std::move(done));
0168         ctx.template run<Alignment>();
0169       } else {
0170         auto* ctx = new EvalShardedByInnerDimContext<DoneCallback>(
0171             this, num_threads_by_k, buffer, m, n, k, std::move(done));
0172         ctx->template runAsync<Alignment>();
0173       }
0174 
0175       return;
0176     }
0177 
0178     // TODO(dvyukov): this is a stop-gap to prevent regressions while the cost
0179     // model is not tuned. Remove this when the cost model is tuned.
0180     if (n == 1) num_threads = 1;
0181 
0182     if (num_threads == 1) {
0183       TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
0184                                   Unaligned, (buffer));
0185       if (!IsEvalInSyncMode) done();
0186       return;
0187     }
0188 
0189     // Now that we know number of threads, recalculate sharding and blocking.
0190     shard_by_col = shardByCol(m, n, num_threads);
0191     if (shard_by_col) {
0192       internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
0193                                           internal::ShardByCol>
0194           blocking(k, m, n, num_threads);
0195       bm = blocking.mc();
0196       bn = blocking.nc();
0197       bk = blocking.kc();
0198     } else {
0199       internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
0200                                           internal::ShardByRow>
0201           blocking(k, m, n, num_threads);
0202       bm = blocking.mc();
0203       bn = blocking.nc();
0204       bk = blocking.kc();
0205     }
0206 
0207     // Number of kernels for each dimension.
0208     Index nm0 = divup(m, bm);
0209     Index nn0 = divup(n, bn);
0210     Index nk = divup(k, bk);
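    // Illustrative example (numbers are arbitrary, not tuned values): with
    // m = 1000, n = 500, k = 2048 and blocking bm = 64, bn = 48, bk = 256 this
    // gives nm0 = divup(1000, 64) = 16, nn0 = divup(500, 48) = 11 and
    // nk = divup(2048, 256) = 8.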
0211 
0212     // Calculate task grain size (number of kernels executed per task).
0213     // This task size coarsening serves two purposes:
0214     // 1. It reduces per-task overheads including synchronization overheads.
0215     // 2. It allows us to use caches better (reuse the same packed rhs in several
0216     // consecutive kernels).
0217     Index gm = 1;
0218     Index gn = 1;
0219     // If we are sharding by column, then we prefer to reduce rows first.
0220     if (shard_by_col) {
0221       gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
0222       gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
0223     } else {
0224       gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
0225       gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
0226     }
0227     // Number of tasks in each dimension.
0228     Index nm = divup(nm0, gm);
0229     Index nn = divup(nn0, gn);
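    // Continuing the illustrative numbers above: with nm0 = 16 and a grain size
    // gm = 2 we get nm = divup(16, 2) = 8 tasks along m, so each task covers 2
    // consecutive kernels and can reuse the same packed rhs.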
0230 
0231     // If there is enough concurrency in the sharding dimension, we choose not
0232     // to parallelize by the other dimension, and execute all kernels in sync
0233     // mode. This reduces parallelism from the nm x nn down to nn
0234     // (shard_by_col==true) or nm (shard_by_col==false).
0235     const Index sharding_dim_tasks = shard_by_col ? nn : nm;
0236     const int num_worker_threads = this->m_device.numThreadsInPool();
0237 
0238     // With a small number of threads we want to make sure that we do not reduce
0239     // parallelism too much. With a large number of threads we trade maximum
0240     // parallelism for better memory locality.
0241     const float oversharding_factor =
0242         num_worker_threads <= 4  ? 8.0 :
0243         num_worker_threads <= 8  ? 4.0 :
0244         num_worker_threads <= 16 ? 2.0 :
0245         num_worker_threads <= 32 ? 1.0 :
0246         num_worker_threads <= 64 ? 0.8 : /* num_worker_threads > 64 */ 0.6;
0247 
0248     const bool parallelize_by_sharding_dim_only =
0249         sharding_dim_tasks >= oversharding_factor * num_worker_threads;
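    // Example (illustrative): with 16 worker threads the factor is 2.0, so we
    // parallelize by the sharding dimension only when it has at least
    // 2.0 * 16 = 32 tasks; with 64 threads the threshold is 0.8 * 64 = 51.2 tasks.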
0250 
0251     // Last but not least, decide whether we want to issue both lhs and rhs
0252     // packing in parallel; or issue lhs packing first, and then issue rhs
0253     // packing when lhs packing completes (for !shard_by_col lhs and rhs are
0254     // swapped). Parallel packing allows more parallelism (for both packing and
0255     // kernels), while sequential packing provides better locality (once
0256     // a thread finishes rhs packing it proceeds to kernels with that rhs).
0257     // First, we are interested in parallel packing if there are few tasks.
0258     bool parallel_pack = num_threads >= nm * nn;
0259     // Also do parallel packing if all data fits into L2$.
0260     if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <=
0261         l2CacheSize() * num_threads)
0262       parallel_pack = true;
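    // Example (illustrative, assuming a 1 MiB L2 per core): for float inputs
    // with m = n = 1024 and bk = 256 the packed data is
    // 2 * 1024 * 256 * 4 bytes = 2 MiB, which is below the 8 MiB budget of
    // 8 threads, so parallel packing would be enabled.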
0263     // But don't do it if we will use each rhs only once. Locality seems to be
0264     // more important in this case.
0265     if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
0266     // Also don't get in the way of parallelize_by_sharding_dim_only
0267     // optimization.
0268     if (parallelize_by_sharding_dim_only) parallel_pack = false;
0269 
0270     // TODO(ezhulenev): With if constexpr we don't need SyncEvalParallelContext.
0271     if (IsEvalInSyncMode) {
0272 #define CONTEXT_ARGS                                                        \
0273   (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
0274    nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only,      \
0275    NoCallback())                                                            \
0276       .run()
0277       TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment,
0278                                   CONTEXT_ARGS);
0279 #undef CONTEXT_ARGS
0280 
0281     } else {
0282 #define CONTEXT_ARGS                                                        \
0283   (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
0284    nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only,      \
0285    std::move(done))
0286       TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback,
0287                                         Alignment, CONTEXT_ARGS, run());
0288 #undef CONTEXT_ARGS
0289     }
0290   }
0291 
0292   // ------------------------------------------------------------------------ //
0293 
0294   // Dummy struct to represent an empty DoneCallback.
0295 
0296   struct NoCallback {
0297     void operator()() {
0298       eigen_assert(false && "NoCallback should never be called");
0299     }
0300   };
0301 
0302   // ------------------------------------------------------------------------ //
0303 
0304   template <typename DoneCallback, typename Context>
0305   class EvalParallelNotification;
0306 
0307   // Synchronous evaluation notification that blocks caller thread in Wait().
0308   template <typename Context>
0309   class EvalParallelNotification<NoCallback, Context> {
0310    public:
0311     EvalParallelNotification(Context*, NoCallback) {}
0312     void Notify() { done_.Notify(); }
0313     void Wait() { done_.Wait(); }
0314    private:
0315     Eigen::Notification done_;
0316   };
0317 
0318   // Asynchronous evaluation notification that does not block in Wait().
0319   template <typename DoneCallback, typename Context>
0320   class EvalParallelNotification {
0321    public:
0322     EvalParallelNotification(Context* ctx, DoneCallback done)
0323         : ctx_(ctx), done_(std::move(done)) {}
0324 
0325     void Notify() {
0326       // Make a copy of the done callback, because it will be destroyed when we
0327       // delete the context on the next line (EvalParallelNotification is a
0328       // data member of the EvalParallelContext class).
0329       DoneCallback done_copy = std::move(done_);
0330 
0331       // Delete parallel evaluation context.
0332       delete ctx_;
0333 
0334       // Now safely call the done callback.
0335       done_copy();
0336     }
0337 
0338     void Wait() {}
0339 
0340    private:
0341     Context* ctx_;
0342     DoneCallback done_;
0343   };
0344 
0345   // Context orchestrates sync/async parallel contraction evaluation. When it is
0346   // executed in asynchronous mode, it owns all the shared state that might be
0347   // accessible by block packing and kernel tasks.
0348 
0349   template <typename DoneCallback, bool lhs_inner_dim_contiguous,
0350             bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered,
0351             int Alignment>
0352   class EvalParallelContext {
0353    public:
0354     typedef internal::TensorContractionInputMapper<
0355         LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
0356         contract_t, internal::packet_traits<LhsScalar>::size,
0357         lhs_inner_dim_contiguous, false, Unaligned>
0358         LhsMapper;
0359     typedef internal::TensorContractionInputMapper<
0360         RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
0361         contract_t, internal::packet_traits<RhsScalar>::size,
0362         rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
0363         RhsMapper;
0364 
0365     typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
0366 
0367     typedef internal::TensorContractionKernel<
0368         Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
0369         TensorContractionKernel;
0370 
0371     typedef typename TensorContractionKernel::LhsBlock LhsBlock;
0372     typedef typename TensorContractionKernel::RhsBlock RhsBlock;
0373     typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
0374 
0375     EvalParallelContext(const Self* self, int num_threads, Scalar* buffer,
0376                         Index tm, Index tn, Index tk, Index bm, Index bn,
0377                         Index bk, Index nm, Index nn, Index nk, Index gm,
0378                         Index gn, Index nm0, Index nn0, bool shard_by_col,
0379                         bool parallel_pack,
0380                         bool parallelize_by_sharding_dim_only,
0381                         DoneCallback done)
0382         : created_by_thread_id_(std::this_thread::get_id()),
0383           done_(this, std::move(done)),
0384           device_(self->m_device),
0385           lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
0386                self->m_i_strides, self->m_left_contracting_strides,
0387                self->m_k_strides),
0388           rhs_(self->m_rightImpl, self->m_right_nocontract_strides,
0389                self->m_j_strides, self->m_right_contracting_strides,
0390                self->m_k_strides),
0391           buffer_(buffer),
0392           output_(buffer, tm),
0393           output_kernel_(self->m_output_kernel),
0394           tensor_contraction_params_(self->m_tensor_contraction_params),
0395           num_threads_(num_threads),
0396           shard_by_col_(shard_by_col),
0397           parallel_pack_(parallel_pack),
0398           parallelize_by_sharding_dim_only_(parallelize_by_sharding_dim_only),
0399           m_(tm),
0400           n_(tn),
0401           k_(tk),
0402           bm_(bm),
0403           bn_(bn),
0404           bk_(bk),
0405           nm_(nm),
0406           nn_(nn),
0407           nk_(nk),
0408           gm_(gm),
0409           gn_(gn),
0410           nm0_(nm0),
0411           nn0_(nn0),
0412           kernel_(m_, k_, n_, bm_, bk_, bn_),
0413           num_thread_local_allocations_(0),
0414           // We reserve 2X more capacity for thread local values than the
0415           // number of threads in the pool, to efficiently handle task stealing
0416           // by threads that are not managed by the pool.
0417           thread_local_capacity(2 * (parallelize_by_sharding_dim_only_
0418                                          ? device_.numThreadsInPool()
0419                                          : 0)),
0420           // We will use only one of the Lhs/Rhs thread local storages (depending
0421           // on shard_by_col), and only when we parallelize by the sharding dim ONLY.
0422           lhs_thread_local_blocks_(shard_by_col_ ? 0 : thread_local_capacity,
0423                                    {*this}, {*this}),
0424           rhs_thread_local_blocks_(shard_by_col_ ? thread_local_capacity : 0,
0425                                    {*this}, {*this}) {
0426       // These two options are mutually exclusive.
0427       eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
0428 
0429       for (Index x = 0; x < P; x++) {
0430         // Normal number of notifications for k slice switch is
0431         // nm_ + nn_ + nm_ * nn_. However, the first P - 1 slices will receive only
0432         // nm_ + nn_ notifications, because they will not receive notifications
0433         // from preceding kernels.
0434         state_switch_[x] =
0435             x == 0
0436                 ? 1
0437                 : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) +
0438                       (x == P - 1 ? nm_ * nn_ : 0);
0439         state_packing_ready_[x] =
0440             parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
0441         state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
0442         for (Index m = 0; m < nm_; m++) {
0443           state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
0444           // Kernels generally receive 3 notifications (previous kernel + 2
0445           // packing), but the first slice won't get notifications from previous
0446           // kernels.
0447           for (Index n = 0; n < nn_; n++)
0448             state_kernel_[x][m][n].store(
0449                 (x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1),
0450                 std::memory_order_relaxed);
0451         }
0452       }
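      // Example of the counters above (illustrative): with nm_ = 4, nn_ = 3,
      // sequential packing and shard_by_col_ = true, state_switch_ is
      // {1, 3, 3 + 4 * 3 = 15}, and every state_kernel_[x][m][n] starts at 2
      // (at 1 for the first slice, x == 0).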
0453 
0454       // Allocate memory for packed rhs/lhs matrices.
0455       packed_mem_ = kernel_.allocateSlices(            //
0456           device_,                                     //
0457           /*num_lhs=*/nm0_,                            //
0458           /*num_rhs=*/nn0_,                            //
0459           /*num_slices=*/std::min<Index>(nk_, P - 1),  //
0460           packed_lhs_, packed_rhs_);
0461 
0462       if (parallelize_by_sharding_dim_only_) {
0463         const int num_worker_threads = device_.numThreadsInPool();
0464 
0465         if (shard_by_col) {
0466           can_use_thread_local_packed_ = new std::atomic<bool>[nn_];
0467           for (int i = 0; i < nn_; ++i)
0468             can_use_thread_local_packed_[i].store(true,
0469                                                   std::memory_order_relaxed);
0470 
0471           Index num_blocks = num_worker_threads * gn_;
0472           thread_local_pre_alocated_mem_ = kernel_.allocateSlices(  //
0473               device_,                                              //
0474               /*num_lhs=*/0,                                        //
0475               /*num_rhs=*/num_blocks,                               //
0476               /*num_slices=*/1,                                     //
0477               /*lhs_blocks=*/nullptr, &rhs_thread_local_pre_allocated_);
0478 
0479         } else {
0480           can_use_thread_local_packed_ = new std::atomic<bool>[nm_];
0481           for (int i = 0; i < nm_; ++i)
0482             can_use_thread_local_packed_[i].store(true,
0483                                                   std::memory_order_relaxed);
0484 
0485           Index num_blocks = num_worker_threads * gm_;
0486           thread_local_pre_alocated_mem_ = kernel_.allocateSlices(  //
0487               device_,                                              //
0488               /*num_lhs=*/num_blocks,                               //
0489               /*num_rhs=*/0,                                        //
0490               /*num_slices=*/1, &lhs_thread_local_pre_allocated_,   //
0491               /*rhs_blocks=*/nullptr);
0492         }
0493       }
0494     }
0495 
0496     ~EvalParallelContext() {
0497       for (Index x = 0; x < P; x++) {
0498         for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
0499         delete[] state_kernel_[x];
0500       }
0501       kernel_.deallocate(device_, packed_mem_);
0502       if (parallelize_by_sharding_dim_only_) {
0503         kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
0504         delete[] can_use_thread_local_packed_;
0505       }
0506     }
0507 
0508     void run() {
0509       // Kick off packing of the first slice.
0510       signal_switch(0, 1);
0511 
0512       // Wait for overall completion.
0513       //
0514       // If parallel evaluation is executed in async mode, this is a no-op, and
0515       // Wait() will return immediately. In synchronous mode it will block the
0516       // caller thread until it receives a notification from the last task.
0517       //
0518       // In async mode, the last task, when completed, will call the done
0519       // callback from the same thread and delete this context.
0520       //
0521       // TODO(dvyukov): This wait can lead to deadlock if contraction is
0522       // evaluated in synchronous mode. If nthreads contractions are
0523       // concurrently submitted from worker threads, this wait will block all
0524       // worker threads and the system will deadlock.
0525       done_.Wait();
0526     }
0527 
0528    private:
0529     std::thread::id created_by_thread_id_;
0530 
0531     // This notification is specialized on the type of DoneCallback and can be
0532     // blocking or non-blocking.
0533     EvalParallelNotification<DoneCallback, EvalParallelContext> done_;
0534 
0535     const Device& device_;
0536     LhsMapper lhs_;
0537     RhsMapper rhs_;
0538     Scalar* const buffer_;
0539     OutputMapper output_;
0540     OutputKernelType output_kernel_;
0541     TensorContractionParams tensor_contraction_params_;
0542     const int num_threads_;
0543     const bool shard_by_col_;
0544     const bool parallel_pack_;
0545     const bool parallelize_by_sharding_dim_only_;
0546     // Matrix sizes.
0547     const Index m_;
0548     const Index n_;
0549     const Index k_;
0550     // Block sizes.
0551     const Index bm_;
0552     const Index bn_;
0553     const Index bk_;
0554     // Number of tasks.
0555     const Index nm_;
0556     const Index nn_;
0557     const Index nk_;
0558     // Task grain sizes (number of kernels executed per task).
0559     const Index gm_;
0560     const Index gn_;
0561     // Number of blocks (this is different from nm_/nn_ because of task size
0562     // coarsening).
0563     const Index nm0_;
0564     const Index nn0_;
0565     // Tensor contraction kernel.
0566     TensorContractionKernel kernel_;
0567 
0568     // Parallelization strategy.
0569     //
0570     // Blocks related to the same k block can run in parallel because they write
0571     // to different output blocks. So we parallelize within k slices, this
0572     // gives us parallelism level of m x n. Before we can start any kernels
0573     // related to k-th slice, we need to issue m lhs packing tasks and n rhs
0574     // packing tasks.
0575     //
0576     // However, there is a bottleneck when we are finishing kernels for k-th
0577     // slice (at the very end there is only 1 runnable kernel). To mitigate this
0578     // bottleneck we allow kernels from k-th and k+1-th slices to run in
0579     // parallel. Note that (m, n, k) and (m, n, k+1) kernels write to the same
0580     // output block, so they must not run in parallel.
0581     //
0582     // This gives us the following dependency graph.
0583     // On each k slice we have m x n kernel tasks, m lhs packing tasks and n rhs
0584     // packing tasks.
0585     // Kernel (m, n, k) can start when:
0586     //  - kernel (m, n, k-1) has finished
0587     //  - lhs packing (m, k) has finished
0588     //  - rhs packing (n, k) has finished
0589     // Lhs/rhs packing can start when:
0590     //  - all k-1 packing has finished (artificially imposed to limit amount of
0591     //  parallel packing)
0592     //
0593     // On top of that we limit runnable tasks to two consecutive k slices.
0594     // This is done to limit amount of memory we need for packed lhs/rhs
0595     // (for each k slice we need m*bk + n*bk memory in packed_lhs_/packed_rhs_).
0596     //
0597     // state_switch_ tracks when we are ready to switch to the next k slice.
0598     // state_kernel_[m][n] tracks when we are ready to kick off kernel (m, n).
0599     // These variables roll over 3 consecutive k slices: the first two we are
0600     // actively executing, plus one to track completion of kernels in the second
0601     // slice.
0602     static const Index P = 3;
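    // Concrete instance of the rules above (illustrative): kernel (2, 1, 5) may
    // start only after kernel (2, 1, 4), lhs packing (2, 5) and rhs packing
    // (1, 5) have all finished; on completion it signals kernel (2, 1, 6) and
    // the switch for slice k = 7.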
0603 
0604     // Handle to the allocated temporary storage for Lhs/Rhs blocks.
0605     BlockMemHandle packed_mem_;
0606     std::vector<LhsBlock> packed_lhs_[P - 1];
0607     std::vector<RhsBlock> packed_rhs_[P - 1];
0608 
0609     // If we choose to parallelize only by the sharding dimension, each thread
0610     // will have its own "thread local" (not C++ thread local storage) memory
0611     // for packed_lhs or packed_rhs (for shard_by_col = false or true). This memory
0612     // can't be passed to a kernel that might execute on a different thread.
0613     //
0614     // In practice when we are ready to pack memory for the sharding dimension
0615     // (rhs if shard_by_col==true) of the K-th slice, all kernels for the K-1
0616     // slice have already completed (99% of the time), so we can pack data into
0617     // the thread local storage and guarantee that all the kernels will be executed
0618     // immediately in the same thread. This significantly increases L1 cache hit
0619     // ratio and reduces pressure on the memory bus.
0620     //
0621     // It's still possible that a kernel for the K-th slice will be ready before
0622     // completion of the K-1 kernels, so we have to allocate "global" packed_lhs_
0623     // and packed_rhs_ to allow kernels to be executed later on a thread
0624     // different from the thread that was used for packing.
0625 
0626     // Handle for pre-allocated thread local memory buffers.
0627     BlockMemHandle thread_local_pre_alocated_mem_;
0628 
0629     // Only one of these will be initialized depending on shard_by_col value
0630     // (the size will be `num_worker_threads * num_grains_in_the_sharding_dim`).
0631     std::vector<LhsBlock> lhs_thread_local_pre_allocated_;
0632     std::vector<RhsBlock> rhs_thread_local_pre_allocated_;
0633 
0634     // How many thread local blocks were already allocated.
0635     std::atomic<int> num_thread_local_allocations_;
0636     const int thread_local_capacity;
0637 
0638     // We will use the pre-allocated Lhs/Rhs blocks defined above if the number of
0639     // unique threads in the system is less than or equal to the number of threads
0640     // in the thread pool. We will fall back on dynamic memory allocation after that.
0641 
0642     // ThreadLocalBlocks is a container for Lhs or Rhs thread local buffers. Its
0643     // size is equal to the grain size in Lhs/Rhs sharding dimension.
0644     template <typename BlockType>
0645     class ThreadLocalBlocks {
0646      public:
0647       ThreadLocalBlocks() = default;
0648 
0649       ThreadLocalBlocks(BlockType* base, size_t grain_size)
0650           : is_pre_allocated_(true),
0651             thread_local_pre_allocated_base_(base),
0652             grain_size_(grain_size) {}
0653 
0654       ThreadLocalBlocks(BlockMemHandle mem_handle,
0655                         std::vector<BlockType> blocks)
0656           : is_pre_allocated_(false),
0657             mem_handle_(std::move(mem_handle)),
0658             blocks_(std::move(blocks)) {}
0659 
0660       BlockType& block(int grain_index) {
0661         eigen_assert(grain_index >= 0);
0662         eigen_assert(static_cast<size_t>(grain_index) < size());
0663         return is_pre_allocated_ ? thread_local_pre_allocated_base_[grain_index]
0664                                  : blocks_[grain_index];
0665       }
0666 
0667       void Release(EvalParallelContext& ctx) const {
0668         if (!is_pre_allocated_) {
0669           ctx.kernel_.deallocate(ctx.device_, mem_handle_);
0670         }
0671       }
0672 
0673       size_t size() const {
0674         return is_pre_allocated_ ? grain_size_ : blocks_.size();
0675       }
0676 
0677      private:
0678       bool is_pre_allocated_;
0679 
0680       // Reuse pre-allocated thread local buffers.
0681       BlockType* thread_local_pre_allocated_base_ = nullptr;
0682       size_t grain_size_ = 0;
0683 
0684       // These will be initialized only if `is_pre_allocated == false`.
0685       BlockMemHandle mem_handle_{};
0686       std::vector<BlockType> blocks_;
0687     };
0688 
0689     // ThreadLocalBlocksInitialize callable does custom thread local blocks
0690     // initialization, and will reuse pre-allocated buffers if possible, or will
0691     // dynamically allocate new memory.
0692     //
0693     // Lhs/Rhs blocks might be of the same type, so we have to pass explicitly
0694     // which side we plan to do block allocation for.
0695     template <typename BlockType, bool is_rhs>
0696     class ThreadLocalBlocksInitialize {
0697       static constexpr bool kIsLhs =
0698           !is_rhs && std::is_same<BlockType, LhsBlock>::value;
0699       static const bool kIsRhs =
0700           is_rhs && std::is_same<BlockType, RhsBlock>::value;
0701       static_assert(kIsLhs || kIsRhs, "Unknown block type");
0702 
0703       using Blocks = ThreadLocalBlocks<BlockType>;
0704 
0705      public:
0706       ThreadLocalBlocksInitialize(EvalParallelContext& ctx)
0707           : ctx_(ctx),
0708             num_worker_threads_(ctx_.device_.numThreadsInPool()) {}
0709 
0710       void operator()(Blocks& blocks) {
0711         const int n = ctx_.num_thread_local_allocations_.fetch_add(
0712             1, std::memory_order_relaxed);
0713 
0714         if (n >= num_worker_threads_) {
0715           ThreadLocalBlocksAllocator<is_rhs>::allocate(ctx_, blocks);
0716         } else {
0717           ThreadLocalBlocksAllocator<is_rhs>::reuse(ctx_, n, blocks);
0718         }
0719       }
0720 
0721      private:
0722       // NOTE(ezhulenev): Without 'if constexpr' we have to put calls to
0723       // TensorContractionKernel::allocateSlices into template specializations.
0724       // Also explicit specializations are not allowed at class scope in C++03;
0725       // the EvalCtx type parameter is just a workaround for that limitation.
0726       template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
0727       struct ThreadLocalBlocksAllocator;
0728 
0729       template <typename EvalCtx>
0730       struct ThreadLocalBlocksAllocator</*pack_rhs=*/true, EvalCtx> {
0731         static void allocate(EvalCtx& ctx, Blocks& blocks) {
0732           std::vector<RhsBlock> rhs_blocks;
0733           BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
0734               ctx.device_,
0735               /*num_lhs=*/0,
0736               /*num_rhs=*/ctx.gn_,
0737               /*num_slices=*/1,
0738               /*lhs_blocks=*/nullptr, /*rhs_blocks=*/&rhs_blocks);
0739 
0740           blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle),
0741                                                std::move(rhs_blocks));
0742         }
0743 
0744         static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
0745           RhsBlock* ptr = &ctx.rhs_thread_local_pre_allocated_[ctx.gn_ * index];
0746           blocks = ThreadLocalBlocks<RhsBlock>(ptr, ctx.gn_);
0747         }
0748       };
0749 
0750       template <typename EvalCtx>
0751       struct ThreadLocalBlocksAllocator</*pack_rhs=*/false, EvalCtx> {
0752         static void allocate(EvalCtx& ctx, Blocks& blocks) {
0753           std::vector<LhsBlock> lhs_blocks;
0754           BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
0755               ctx.device_,
0756               /*num_lhs=*/ctx.gm_,
0757               /*num_rhs=*/0,
0758               /*num_slices=*/1,
0759               /*lhs_blocks=*/&lhs_blocks, /*rhs_blocks=*/nullptr);
0760 
0761           blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle),
0762                                                std::move(lhs_blocks));
0763         }
0764 
0765         static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
0766           LhsBlock* ptr = &ctx.lhs_thread_local_pre_allocated_[ctx.gm_ * index];
0767           blocks = ThreadLocalBlocks<LhsBlock>(ptr, ctx.gm_);
0768         }
0769       };
0770 
0771       EvalParallelContext& ctx_;
0772       const int num_worker_threads_;
0773     };
0774 
0775     template <typename BlockType>
0776     class ThreadLocalBlocksRelease {
0777      public:
0778       using Blocks = ThreadLocalBlocks<BlockType>;
0779       ThreadLocalBlocksRelease(EvalParallelContext& ctx) : ctx_(ctx) {}
0780       void operator()(Blocks& blocks) { blocks.Release(ctx_); }
0781 
0782      private:
0783       EvalParallelContext& ctx_;
0784     };
0785 
0786     // ThreadLocalBlocks initialization callables.
0787     using ThreadLocalLhsInit =
0788         ThreadLocalBlocksInitialize<LhsBlock, /*is_rhs=*/false>;
0789     using ThreadLocalRhsInit =
0790         ThreadLocalBlocksInitialize<RhsBlock, /*is_rhs=*/true>;
0791 
0792     // ThreadLocalBlocks release callables.
0793     using ThreadLocalLhsRelease = ThreadLocalBlocksRelease<LhsBlock>;
0794     using ThreadLocalRhsRelease = ThreadLocalBlocksRelease<RhsBlock>;
0795 
0796     // Thread local containers for Lhs/Rhs block packs. In practice only one of
0797     // them will be used, depending on the shard_by_col value.
0798     Eigen::ThreadLocal<ThreadLocalBlocks<LhsBlock>, ThreadLocalLhsInit,
0799                        ThreadLocalLhsRelease>
0800         lhs_thread_local_blocks_;
0801     Eigen::ThreadLocal<ThreadLocalBlocks<RhsBlock>, ThreadLocalRhsInit,
0802                        ThreadLocalRhsRelease>
0803         rhs_thread_local_blocks_;
0804 
0805     // After a particular shard of the K-th slice has missed its thread local
0806     // execution opportunity (the K-1 slice didn't complete kernel execution in
0807     // time), we can no longer schedule the K+1 and following slices in thread
0808     // local mode, because there is no longer a guarantee that previous kernels
0809     // were executed sequentially in the same thread (the array size is nn_ or nm_).
0810     std::atomic<bool>* can_use_thread_local_packed_;
0811 
0812     std::atomic<uint8_t>** state_kernel_[P];
0813     // state_switch_ is frequently modified by worker threads, while other
0814     // fields are read-only after the constructor. Let's move it to a separate cache
0815     // line to reduce cache-coherency traffic.
0816     char pad_[128];
0817     std::atomic<Index> state_packing_ready_[P];
0818     std::atomic<Index> state_switch_[P];
0819 
0820     LhsBlock& packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
0821       if (use_thread_local) {
0822         eigen_assert(!shard_by_col_);
0823         ThreadLocalBlocks<LhsBlock>& blocks = lhs_thread_local_blocks_.local();
0824 
0825         Index grain_index = m1 - m * gm_;
0826         return blocks.block(internal::convert_index<int>(grain_index)); // FIXME better make ThreadLocalBlocks use Eigen::Index?
0827       } else {
0828         return packed_lhs_[k % (P - 1)][m1];
0829       }
0830     }
0831 
0832     RhsBlock& packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
0833       if (use_thread_local) {
0834         eigen_assert(shard_by_col_);
0835         ThreadLocalBlocks<RhsBlock>& blocks = rhs_thread_local_blocks_.local();
0836 
0837         Index grain_index = n1 - n * gn_;
0838         return blocks.block(internal::convert_index<int>(grain_index)); // FIXME better make ThreadLocalBlocks use Eigen::Index?
0839       } else {
0840         return packed_rhs_[k % (P - 1)][n1];
0841       }
0842     }
0843 
0844     // In the following two methods (pack_lhs and pack_rhs), if we know for sure
0845     // that we'll be able to immediately call a kernel with the packed data, and do
0846     // not need to submit it to the thread pool, we can use thread local memory for
0847     // the packed data.
0848     //
0849     // We can only reliably check it if we are running all kernels in sync mode
0850     // (parallelize only by sharding dim). If kernel for m==0 (n==0) is ready to
0851     // run, it's guaranteed that all kernels with larger values of m (n) are
0852     // also ready, because we execute them in the same order for all K slices.
0853 
0854     void pack_lhs(Index m, Index k) {
0855       bool use_thread_local = false;
0856 
0857       if (parallelize_by_sharding_dim_only_ && !shard_by_col_ &&
0858           can_use_thread_local_packed_[m].load(std::memory_order_relaxed)) {
0859         if (state_kernel_[k % P][m][0].load(std::memory_order_relaxed) == 1) {
0860           use_thread_local = true;
0861         } else {
0862           // If we can't guarantee that all kernels in the `k` slice will be
0863           // executed sequentially in the current thread, it's no longer safe to use
0864           // thread local memory in the following slices along the k dimension.
0865           eigen_assert(k > 0);
0866           can_use_thread_local_packed_[m].store(false,
0867                                                 std::memory_order_relaxed);
0868         }
0869       }
0870 
0871       const Index mend = m * gm_ + gm(m);
0872       for (Index m1 = m * gm_; m1 < mend; m1++)
0873         kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local),
0874                         lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));
0875 
0876       if (!parallel_pack_ && shard_by_col_) {
0877         assert(!use_thread_local);
0878         signal_packing(k);
0879       } else {
0880         signal_switch(k + 1);
0881         for (Index n = nn_ - 1; n >= 0; n--) {
0882           bool sync = parallelize_by_sharding_dim_only_ || n == 0;
0883           signal_kernel(m, n, k, sync, use_thread_local);
0884         }
0885       }
0886     }
0887 
0888     void pack_rhs(Index n, Index k) {
0889       bool use_thread_local = false;
0890 
0891       if (parallelize_by_sharding_dim_only_ && shard_by_col_ &&
0892           can_use_thread_local_packed_[n].load(std::memory_order_relaxed)) {
0893         if (state_kernel_[k % P][0][n].load(std::memory_order_relaxed) == 1) {
0894           use_thread_local = true;
0895         } else {
0896           // If we can't guarantee that all kernels in the `k` slice will be
0897           // executed sequentially in the current thread, it's no longer safe to use
0898           // thread local memory in the following slices along the k dimension.
0899           eigen_assert(k > 0);
0900           can_use_thread_local_packed_[n].store(false,
0901                                                 std::memory_order_relaxed);
0902         }
0903       }
0904 
0905       const Index nend = n * gn_ + gn(n);
0906       for (Index n1 = n * gn_; n1 < nend; n1++) {
0907         if (!TensorContractionKernel::HasBeta && k == 0) {
0908           // Zero the output memory in parallel, only if contraction kernel does
0909           // not support `beta`. Otherwise we will pass beta 0.0 to the first
0910           // call to the `TensorContractionKernel::invoke()`.
0911           //
0912           // On a 10000x2x10000 matmul, zeroing can easily take half of the time.
0913           // Zero the (bn x m) row. Safe to do here because all kernels that will
0914           // write to this memory depend on completion of this task. Note: don't
0915           // call device_.memset() here, because device_.memset() blocks on a thread
0916           // pool worker thread, which can lead to underutilization and deadlocks.
0917           memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
0918         }
0919         kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
0920                         rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
0921       }
0922 
0923       if (parallel_pack_ || shard_by_col_) {
0924         signal_switch(k + 1);
0925         for (Index m = nm_ - 1; m >= 0; m--) {
0926           bool sync = parallelize_by_sharding_dim_only_ || m == 0;
0927           signal_kernel(m, n, k, sync, use_thread_local);
0928         }
0929       } else {
0930         assert(!use_thread_local);
0931         signal_packing(k);
0932       }
0933     }
0934 
0935     void kernel(Index m, Index n, Index k, bool use_thread_local) {
0936       // Note: order of iteration matters here. Iteration over m is innermost
0937       // because we want to reuse the same packed rhs in consecutive tasks
0938       // (rhs fits into L2$ while lhs only into L3$).
0939       const Index nend = n * gn_ + gn(n);
0940       const Index mend = m * gm_ + gm(m);
0941 
0942       // NOTE: output = alpha * LHS * RHS + beta * output.
0943       const Scalar alpha = Scalar(1);
0944       const Scalar beta =
0945           (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);
0946 
0947       if (shard_by_col_) {
0948         for (Index n1 = n * gn_; n1 < nend; n1++) {
0949           for (Index m1 = m * gm_; m1 < mend; m1++) {
0950             const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
0951             kernel_.invoke(
0952                 output_mapper,
0953                 packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
0954                 packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
0955                 bk(k), bn(n1), alpha, beta);
0956 
0957             // We are done with the last task for the [m1, n1] block.
0958             if (k + 1 == nk_) {
0959               output_kernel_(output_mapper, tensor_contraction_params_,
0960                              m1 * bm_, n1 * bn_, bm(m1), bn(n1));
0961             }
0962           }
0963         }
0964       } else {
0965         for (Index m1 = m * gm_; m1 < mend; m1++)
0966           for (Index n1 = n * gn_; n1 < nend; n1++) {
0967             const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
0968             kernel_.invoke(
0969                 output_mapper,
0970                 packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
0971                 packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
0972                 bk(k), bn(n1), alpha, beta);
0973 
0974             // We are done with the last task for the [m1, n1] block.
0975             if (k + 1 == nk_) {
0976               output_kernel_(output_mapper, tensor_contraction_params_,
0977                              m1 * bm_, n1 * bn_, bm(m1), bn(n1));
0978             }
0979           }
0980       }
0981       signal_kernel(m, n, k + 1, /*sync=*/false, /*use_thread_local=*/false);
0982       signal_switch(k + 2);
0983     }
0984 
0985     void signal_packing(Index k) {
0986       eigen_assert(!parallel_pack_);
0987       Index s = state_packing_ready_[k % P].fetch_sub(1);
0988       eigen_assert(s > 0);
0989       if (s != 1) return;
0990       state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
0991       enqueue_packing(k, shard_by_col_);
0992     }
0993 
0994     void signal_kernel(Index m, Index n, Index k, bool sync,
0995                        bool use_thread_local) {
0996       std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
0997       Index s = state->load();
0998       eigen_assert(s > 0);
0999       if (s != 1 && state->fetch_sub(1) != 1) {
1000         eigen_assert(!use_thread_local);
1001         return;
1002       }
1003       state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
1004       if (sync) {
1005         kernel(m, n, k, use_thread_local);
1006       } else {
1007         eigen_assert(!use_thread_local);
1008         device_.enqueueNoNotification(
1009             [=]() { kernel(m, n, k, use_thread_local); });
1010       }
1011     }
1012 
1013     void signal_switch(Index k, Index v = 1) {
1014       Index s = state_switch_[k % P].fetch_sub(v);
1015       eigen_assert(s >= v);
1016       if (s != v) return;
1017 
1018       // Ready to switch to the next k slice.
1019       // Reset counter for the next iteration.
1020       state_switch_[k % P] =
1021           (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) +
1022           nm_ * nn_;
1023       if (k < nk_) {
1024         // Issue lhs/rhs packing. Their completion will in turn kick off
1025         // kernels.
1026         if (parallel_pack_) {
1027           enqueue_packing(k, !shard_by_col_);
1028           enqueue_packing(k, shard_by_col_);
1029         } else if (shard_by_col_) {
1030           enqueue_packing(k, false);
1031         } else {
1032           enqueue_packing(k, true);
1033         }
1034 
1035         // Termination handling.
1036         // Because kernel completion signals k + 2 switch, we need to finish nk
1037         // + 2 slices without issuing any tasks on nk + 1 slice. So here we
1038         // pretend that all nk + 1 packing tasks just finish instantly; so that
1039         // nk + 2 switch only waits for completion of nk kernels.
1040       } else if (k == nk_) {
1041         signal_switch(k + 1,
1042                       parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
1043       } else {
1044         done_.Notify();
1045       }
1046     }
1047 
1048     // Enqueue all rhs/lhs packing for k-th slice.
1049     void enqueue_packing(Index k, bool rhs) {
1050       enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs);
1051     }
1052 
1053     void enqueue_packing_helper(Index start, Index end, Index k, bool rhs) {
1054       if (end - start == 1) {
1055         if (rhs)
1056           pack_rhs(start, k);
1057         else
1058           pack_lhs(start, k);
1059       } else {
1060         while (end - start > 1) {
1061           Index mid = (start + end) / 2;
1062           device_.enqueueNoNotification(
1063               [=]() { enqueue_packing_helper(mid, end, k, rhs); });
1064           end = mid;
1065         }
1066 
1067         // Decide if we want to run the first packing task (start == 0) in
1068         // async mode when we parallelize only by the sharding dim:
1069         // (1) pack_lhs and pack_rhs call signal_switch before completing
1070         //     all calls to signal_kernel, which in sync mode might lead
1071         //     to the execution of the first kernel of the k+1 slice, before
1072         //     completing a call to the last kernel of the k slice.
1073         // (2) all pack tasks for sharded dim must be executed in a thread
1074         //     pool to get pre-allocated thread local buffers.
1075         bool pack_async =
1076           (start == 0) &&
1077           (parallelize_by_sharding_dim_only_ && shard_by_col_ == rhs) &&
1078           (k > 0 || std::this_thread::get_id() == created_by_thread_id_);
1079 
1080         if (pack_async) {
1081           device_.enqueueNoNotification(
1082               [=]() { enqueue_packing_helper(start, end, k, rhs); });
1083         } else {
1084           enqueue_packing_helper(start, end, k, rhs);
1085         }
1086       }
1087     }
1088 
1089     // Block sizes with accounting for potentially incomplete last block.
1090     Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
1091     Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
1092     Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }
1093     // Task grain sizes accounting for potentially incomplete last task.
1094     Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
1095     Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }
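    // Example (illustrative): with m_ = 1000, bm_ = 64 and nm0_ = 16 the last
    // block is bm(15) = 1000 + 64 - 64 * 16 = 40, while bm(0) ... bm(14) = 64.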
1096 
1097     EvalParallelContext(const EvalParallelContext&) = delete;
1098     void operator=(const EvalParallelContext&) = delete;
1099   };
1100 
1101   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
1102             bool rhs_inner_dim_reordered, int Alignment>
1103   using SyncEvalParallelContext =
1104       EvalParallelContext<NoCallback, lhs_inner_dim_contiguous,
1105                           rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
1106                           Alignment>;
1107 
1108   // ------------------------------------------------------------------------ //
1109 
1110   // EvalShardedByInnerDimContext orchestrates sync/async contraction
1111   // evaluation, when we shard by inner dimension. When it is executed in
1112   // asynchronous mode, it owns all the shared state that might be accessible by
1113   // block processing tasks.
1114 
1115   template <typename DoneCallback>
1116   struct EvalShardedByInnerDimContext {
1117     EvalShardedByInnerDimContext(const Self* self, int num_threads,
1118                                  Scalar* result_buffer,
1119                                  Index m_size, Index n_size, Index k_size,
1120                                  DoneCallback done_callback)
1121         : evaluator(self),
1122           m_lhs_inner_dim_contiguous(evaluator->m_lhs_inner_dim_contiguous),
1123           m_rhs_inner_dim_contiguous(evaluator->m_rhs_inner_dim_contiguous),
1124           m_rhs_inner_dim_reordered(evaluator->m_rhs_inner_dim_reordered),
1125           result(result_buffer),
1126           m(m_size),
1127           n(n_size),
1128           k(k_size),
1129           done(std::move(done_callback)),
1130           buffer_size_bytes(m * n * sizeof(Scalar)),
1131           block_size(blockSize(k, num_threads)),
1132           num_blocks(divup<Index>(k, block_size)),
1133           num_pending_blocks(internal::convert_index<int>(num_blocks)),
1134           l0_ranges(divup<Index>(num_blocks, l0_size)),
1135           l0_state(l0_ranges),
1136           block_buffers(num_blocks) {
1137       // Keep count of pending gemm tasks for each l0 range.
1138       for (int i = 0; i < l0_ranges; ++i) {
1139         const Index num_pending_tasks = actualRangeSize(l0_ranges, l0_size, i);
1140         l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks));
1141       }
1142 
1143       // Allocate temporary buffers for each block.
1144       for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
1145         Scalar* buf = block_idx == 0
1146                           ? result
1147                           : static_cast<Scalar*>(evaluator->m_device.allocate(
1148                                 buffer_size_bytes));
1149         block_buffers.emplace_back(buf);
1150       }
1151     }
1152 
1153     ~EvalShardedByInnerDimContext() {
1154       for (Index i = 1; i < num_blocks; ++i) {
1155         evaluator->m_device.deallocate(block_buffers[i]);
1156       }
1157     }
1158 
1159     template <int Alignment>
1160     void run() {
1161       Barrier barrier(internal::convert_index<int>(num_blocks));
1162       eval<Alignment>(barrier, 0, num_blocks);
1163       barrier.Wait();
1164 
1165       // Aggregate partial sums from l0 ranges.
1166       aggregateL0Blocks<Alignment>();
1167 
1168       // Apply output kernel.
1169       applyOutputKernel();
1170     }
1171 
1172     template <int Alignment>
1173     void runAsync() {
1174       evalAsync<Alignment>(0, num_blocks);
1175     }
1176 
1177    private:
1178     // The underlying GEMM kernel assumes that k is a multiple of
1179     // the packet size and subtle breakage occurs if this is violated.
1180     static const Index packet_size = internal::packet_traits<RhsScalar>::size;
1181 
1182     const Self* evaluator;  // TensorContraction evaluator
1183 
1184     // These fields are required by the TENSOR_CONTRACTION_DISPATCH macro.
1185     bool m_lhs_inner_dim_contiguous;
1186     bool m_rhs_inner_dim_contiguous;
1187     bool m_rhs_inner_dim_reordered;
1188 
1189     Scalar* result;
1190 
1191     Index m;
1192     Index n;
1193     Index k;
1194 
1195     DoneCallback done;
1196 
1197     // ----------------------------------------------------------------------//
1198     // Algorithm parameters.
1199 
1200     // We will compute partial results into the buffers of this size.
1201     Index buffer_size_bytes;
1202 
1203     Index block_size;
1204     Index num_blocks;
1205 
1206     // Keep track of pending tasks when evaluate in async mode.
1207     std::atomic<int> num_pending_blocks;
1208 
1209     // We compute partial gemm results in parallel, and to get the final result
1210     // we need to add them all together. For a large number of threads (>= 48)
1211     // this adds a very expensive sequential step at the end.
1212     //
1213     // We split [0, num_blocks) into small ranges, and when a task for a block
1214     // finishes its partial gemm computation, it checks whether it was the last
1215     // gemm in the range; if so, it adds up all blocks of the range.
1216     //
1217     // After all tasks are done, we only need to add these pre-aggregated blocks.
1218 
1219     // For now we use just a single level of ranges to compute pre-aggregated
1220     // partial sums, but in general we can use more layers to compute tree
1221     // aggregation in parallel and reduce the size of the sequential step.
1222     //
1223     // TODO(ezhulenev): Add multilevel tree aggregation? It probably makes
1224     // sense only if the number of threads is >= ~128.
1225     static const Index l0_size = 4;
1226     Index l0_ranges;
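
    // Illustrative example (values assumed for exposition, not taken from the
    // code): with num_blocks = 10 and l0_size = 4 we get
    // l0_ranges = divup(10, 4) = 3, covering blocks [0, 4), [4, 8) and [8, 10).
    // The last range is incomplete and holds only 2 blocks, which is what
    // actualRangeSize() returns for range_idx = 2 (10 + 4 - 4 * 3 = 2).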
1227 
1228     // Keep count of pending gemm tasks for each l0 range.
1229     MaxSizeVector<std::atomic<int>> l0_state;  // [0, l0_ranges)
1230 
1231     // Buffers allocated for each temporary block computation.
1232     MaxSizeVector<Scalar*> block_buffers;  // [0, num_blocks)
1233 
1234     template <int Alignment>
1235     void processBlock(Index block_idx, Index begin, Index end) {
1236       Scalar* buf = block_buffers[block_idx];
1237 
1238       TENSOR_CONTRACTION_DISPATCH(
1239           evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
1240           (buf, begin, end,
1241            /*num_threads=*/internal::convert_index<int>(num_blocks)));
1242 
1243       // Check if it was the last task in l0 range.
1244       const Index l0_index = block_idx / l0_size;
1245       const int v = l0_state[l0_index].fetch_sub(1);
1246       eigen_assert(v >= 1);
1247 
1248       // If we processed the last block of the range, we can aggregate all
1249       // partial results into the first block of the range.
1250       if (v == 1) {
1251         const Index rng_size = actualRangeSize(l0_ranges, l0_size, l0_index);
1252         const Index dst_block_idx = l0_index * l0_size;
1253 
1254         if (rng_size == l0_size) {
1255           addAllToBuffer<Alignment>(
1256               m * n,
1257               /*src_buf0=*/block_buffers[dst_block_idx + 1],
1258               /*src_buf1=*/block_buffers[dst_block_idx + 2],
1259               /*src_buf2=*/block_buffers[dst_block_idx + 3],
1260               /*dst_buf= */ block_buffers[dst_block_idx]);
1261         } else {
1262           // Aggregate blocks of potentially incomplete last range.
1263           for (int i = 1; i < rng_size; ++i) {
1264             addToBuffer<Alignment>(m * n,
1265                                    /*src_buf=*/block_buffers[dst_block_idx + i],
1266                                    /*dst_buf=*/block_buffers[dst_block_idx]);
1267           }
1268         }
1269       }
1270     }
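
    // A worked illustration of the aggregation above (indices assumed for the
    // example, with num_blocks >= 8 so that range 1 is complete): for
    // l0_size = 4 and block_idx = 6 we get l0_index = 1 and dst_block_idx = 4;
    // if this task observes v == 1 (it finished the last pending block of
    // range 1), it adds block_buffers[5], [6] and [7] into block_buffers[4].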
1271 
1272     // Aggregate partial sums from l0 ranges.
1273     template <int Alignment>
1274     void aggregateL0Blocks() const {
1275       Index l0_index = 1;
1276 
1277       for (; l0_index + 2 < l0_ranges; l0_index += 3) {
1278         addAllToBuffer<Alignment>(
1279             m * n,
1280             /*src_buf0=*/block_buffers[(l0_index + 0) * l0_size],
1281             /*src_buf1=*/block_buffers[(l0_index + 1) * l0_size],
1282             /*src_buf2=*/block_buffers[(l0_index + 2) * l0_size],
1283             /*dst_buf= */ block_buffers[0]);
1284       }
1285 
1286       for (; l0_index < l0_ranges; ++l0_index) {
1287         addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
1288                                block_buffers[0]);
1289       }
1290     }
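
    // Continuing the assumed example of num_blocks = 10 and l0_size = 4
    // (l0_ranges = 3): the triple-add loop above does not fire (1 + 2 < 3 is
    // false), and the second loop adds block_buffers[4] and block_buffers[8]
    // (the first blocks of ranges 1 and 2) into block_buffers[0], which
    // aliases `result`.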
1291 
1292     void applyOutputKernel() const {
1293       typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
1294       evaluator->m_output_kernel(
1295           OutputMapper(result, m), evaluator->m_tensor_contraction_params,
1296           static_cast<Eigen::Index>(0), static_cast<Eigen::Index>(0), m, n);
1297     }
1298 
1299     // Compute block size, accounting for a potentially incomplete last block.
1300     Index actualBlockSize(Index block_idx) const {
1301       return block_idx + 1 < num_blocks
1302                  ? block_size
1303                  : k + block_size - block_size * num_blocks;
1304     }
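
    // Assumed numeric example: for k = 100 and block_size = 32 we have
    // num_blocks = divup(100, 32) = 4; blocks 0..2 are of size 32, and the
    // last block gets 100 + 32 - 32 * 4 = 4.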
1305 
1306     // Compute range size, accounting for a potentially incomplete last range.
1307     Index actualRangeSize(Index num_ranges, Index range_size,
1308                           Index range_idx) const {
1309       eigen_assert(range_idx < num_ranges);
1310       return range_idx + 1 < num_ranges
1311                  ? range_size
1312                  : num_blocks + range_size - range_size * num_ranges;
1313     }
1314 
1315     template <int Alignment>
1316     EIGEN_STRONG_INLINE static void addToBuffer(size_t n, const Scalar* src_buf,
1317                                                 Scalar* tgt_buf) {
1318       const int output_packet_size =
1319           internal::unpacket_traits<PacketReturnType>::size;
1320       size_t i = 0;
1321       const size_t num_packets = n / output_packet_size;
1322       for (; i < output_packet_size * num_packets; i += output_packet_size) {
1323         const PacketReturnType src_val =
1324             internal::pload<PacketReturnType>(src_buf + i);
1325         const PacketReturnType tgt_val =
1326             internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
1327         const PacketReturnType sum = internal::padd(src_val, tgt_val);
1328         internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i,
1329                                                                sum);
1330       }
1331       for (; i < n; ++i) {
1332         tgt_buf[i] += src_buf[i];
1333       }
1334     }
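
    // For example (sizes assumed): with n = 10 and output_packet_size = 4 the
    // packet loop covers i = 0 and i = 4 (eight elements), and the scalar tail
    // loop handles i = 8 and i = 9.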
1335 
1336     template <int Alignment>
1337     EIGEN_STRONG_INLINE static void addAllToBuffer(size_t n,
1338                                                    const Scalar* src_buf0,
1339                                                    const Scalar* src_buf1,
1340                                                    const Scalar* src_buf2,
1341                                                    Scalar* dst_buf) {
1342       using ::Eigen::internal::padd;
1343       using ::Eigen::internal::pload;
1344       using ::Eigen::internal::ploadt;
1345       using ::Eigen::internal::pstoret;
1346 
1347       const int output_packet_size =
1348           internal::unpacket_traits<PacketReturnType>::size;
1349 
1350       size_t i = 0;
1351       const size_t num_packets = n / output_packet_size;
1352       for (; i < output_packet_size * num_packets; i += output_packet_size) {
1353         const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
1354         const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
1355         const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);
1356 
1357         const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
1358         const auto sum =
1359             padd(padd(dst_val, src_val0), padd(src_val1, src_val2));
1360 
1361         pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
1362       }
1363       for (; i < n; ++i) {
1364         dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
1365       }
1366     }
1367 
1368     template <int Alignment>
1369     void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
1370       while (end_block_idx - start_block_idx > 1) {
1371         Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
1372         evaluator->m_device.enqueueNoNotification(
1373             [this, &barrier, mid_block_idx, end_block_idx]() {
1374               eval<Alignment>(barrier, mid_block_idx, end_block_idx);
1375             });
1376         end_block_idx = mid_block_idx;
1377       }
1378 
1379       Index block_idx = start_block_idx;
1380       Index block_start = block_idx * block_size;
1381       Index block_end = block_start + actualBlockSize(block_idx);
1382 
1383       processBlock<Alignment>(block_idx, block_start, block_end);
1384       barrier.Notify();
1385     }
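
    // A sketch of how the recursive split above unfolds (block count assumed):
    // for [0, 8) this call enqueues eval on [4, 8), then [2, 4), then [1, 2),
    // and finally processes block 0 itself; each enqueued call splits its own
    // sub-range the same way, so every block is processed exactly once and the
    // barrier receives num_blocks notifications.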
1386 
1387     template <int Alignment>
1388     void evalAsync(Index start_block_idx, Index end_block_idx) {
1389       while (end_block_idx - start_block_idx > 1) {
1390         Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
1391         evaluator->m_device.enqueueNoNotification(
1392             [this, mid_block_idx, end_block_idx]() {
1393               evalAsync<Alignment>(mid_block_idx, end_block_idx);
1394             });
1395         end_block_idx = mid_block_idx;
1396       }
1397 
1398       Index block_idx = start_block_idx;
1399 
1400       Index block_start = block_idx * block_size;
1401       Index block_end = block_start + actualBlockSize(block_idx);
1402 
1403       processBlock<Alignment>(block_idx, block_start, block_end);
1404 
1405       int v = num_pending_blocks.fetch_sub(1);
1406       eigen_assert(v >= 1);
1407 
1408       if (v == 1) {
1409         // Aggregate partial sums from l0 ranges.
1410         aggregateL0Blocks<Alignment>();
1411 
1412         // Apply output kernel.
1413         applyOutputKernel();
1414 
1415         // NOTE: If we call the `done` callback before deleting this context,
1416         // it might deallocate the Self* pointer captured by the context, and
1417         // we'll fail in the destructor trying to deallocate temporary buffers.
1418 
1419         // Move the done callback out of the context before it is destroyed.
1420         DoneCallback done_copy = std::move(done);
1421 
1422         // We are confident that we are the last one to touch the context.
1423         delete this;
1424 
1425         // Now safely call the done callback.
1426         done_copy();
1427       }
1428     }
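
    // A hedged usage sketch (assumed, based on the self-deleting pattern above;
    // template arguments omitted and not a verbatim call site from this file):
    //
    //   auto* ctx = new EvalShardedByInnerDimContext(this, num_threads, buffer,
    //                                                m, n, k, std::move(done));
    //   ctx->template runAsync<Alignment>();  // ctx deletes itself when done.
    //
    // The context must therefore be heap allocated, and the caller must not
    // touch it after runAsync() returns.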
1429 
1430     // The cost model doesn't capture well the cost associated with constructing
1431     // tensor contraction mappers and computing loop bounds in gemm_pack_lhs
1432     // and gemm_pack_rhs, so we specify a minimum desired block size.
1433     static Index blockSize(Index k, int num_threads) {
1434       const auto round_up = [=](Index index) -> Index {
1435         const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
1436         return divup<Index>(index, kmultiple) * kmultiple;
1437       };
1438 
1439       const Index target_block_size = round_up(divup<Index>(k, num_threads));
1440       const Index desired_min_block_size = 12 * packet_size;
1441 
1442       return numext::mini<Index>(
1443           k, numext::maxi<Index>(desired_min_block_size, target_block_size));
1444     }
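
    // Assumed numeric example: with packet_size = 4, k = 1000 and
    // num_threads = 8 we get kmultiple = 8, divup(1000, 8) = 125,
    // round_up(125) = 128 and desired_min_block_size = 48, so
    // blockSize = mini(1000, maxi(48, 128)) = 128.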
1445 
1446     EvalShardedByInnerDimContext(const EvalShardedByInnerDimContext&) = delete;
1447     void operator=(const EvalShardedByInnerDimContext&) = delete;
1448   };
1449 
1450   // ------------------------------------------------------------------------ //
1451 
1452   // Below are the functions used by the evalProductImpl heuristics, trying to
1453   // select optimal parameters for the parallelization algorithm.
1454 
1455   // Decide whether we want to shard m x n contraction by columns or by rows.
1456   static bool shardByCol(Index m, Index n, Index num_threads) {
1457     // Note: we are comparing both n and m against Traits::nr; it is not
1458     // a mistake. We are trying to figure out how both n and m will fit into
1459     // the main sharding dimension.
1460 
1461     // Sharding by column is the default
1462     // ... unless there is enough data for vectorization over rows
1463     if (m / num_threads >= Traits::nr &&
1464         // and not enough data for vectorization over columns
1465         (n / num_threads < Traits::nr ||
1466          // ... or barely enough data for vectorization over columns,
1467          // but it is not evenly divisible across threads
1468          (n / num_threads < 4 * Traits::nr &&
1469           (n % (num_threads * Traits::nr)) != 0 &&
1470           // ... and it is evenly divisible across threads for rows
1471           ((m % (num_threads * Traits::nr)) == 0 ||
1472            // ... or it is not evenly divisible in either dimension, but
1473            // there is much more data over rows so that corner effects are
1474            // mitigated.
1475            (m / n >= 6)))))
1476       return false;
1477     // Also shard by rows if the matrix is just substantially elongated over
1478     // the rows dimension.
1479     if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
1480     return true;
1481   }
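
  // Illustrative example (assuming Traits::nr = 4, which depends on the scalar
  // type and architecture): for m = 4096, n = 64 and num_threads = 8 the first
  // check does not return false (n / num_threads = 8 >= Traits::nr and n is
  // evenly divisible by num_threads * Traits::nr), but the second one does
  // (n / num_threads = 8 < 16 * Traits::nr and m > 32 * n), so we shard by
  // rows.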
1482 
1483   Index coarsenM(Index m, Index n, Index bm, Index bn, Index bk, Index gn,
1484                  int num_threads, bool shard_by_col) const {
1485     Index gm = 1;
1486     Index gm1 = 1;
1487     Index nm0 = divup(m, bm);
1488     Index nm1 = nm0;
1489     for (;;) {
1490       // Find the next candidate for the m grain size. It needs to result in a
1491       // different number of blocks. E.g. if we have 10 kernels, we want to try
1492       // 5 and 10, but not 6, 7, 8 and 9.
1493       while (gm1 <= nm0 && nm1 == divup(nm0, gm1)) gm1++;
1494       if (gm1 > nm0) break;
1495       // Check the candidate.
1496       int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads,
1497                            shard_by_col);
1498       if (res < 0) break;
1499       nm1 = divup(nm0, gm1);
1500       if (res == 0) continue;
1501       // Commit new grain size.
1502       gm = gm1;
1503     }
1504     return gm;
1505   }
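
  // Continuing the "10 kernels" example from the comment above: with nm0 = 10,
  // and assuming checkGrain() never returns -1, the candidate grain sizes
  // evaluated are gm1 = 2, 3, 4, 5 and 10; gm1 = 6..9 are skipped because
  // divup(10, gm1) equals divup(10, 5) = 2, i.e. they yield the same number of
  // blocks as an already-tried grain.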
1506 
1507   Index coarsenN(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
1508                  int num_threads, bool shard_by_col) const {
1509     Index gn = 1;
1510     Index gn1 = 1;
1511     Index nn0 = divup(n, bn);
1512     Index nn1 = nn0;
1513     for (;;) {
1514       while (gn1 <= nn0 && nn1 == divup(nn0, gn1)) gn1++;
1515       if (gn1 > nn0) break;
1516       int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads,
1517                            shard_by_col);
1518       if (res < 0) break;
1519       nn1 = divup(nn0, gn1);
1520       if (res == 0) continue;
1521       gn = gn1;
1522     }
1523     return gn;
1524   }
1525 
1526   // checkGrain checks whether grain (gm, gn) is suitable and is better than
1527   // (oldgm, oldgn).
1528   int checkGrain(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
1529                  Index gn, Index oldgm, Index oldgn, int num_threads,
1530                  bool shard_by_col) const {
1531     const TensorOpCost cost =
1532         contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
1533     double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(
1534         static_cast<double>(bm) * gm * bn * gn, cost);
1535     // If the task is too small, then we accept it regardless of anything
1536     // else; otherwise synchronization overheads would dominate.
1537     if (taskSize < 1) return 1;
1538     // If it is too large, then we reject it and all larger tasks.
1539     if (taskSize > 2) return -1;
1540     // Now we are in a presumably good task size range.
1541     // The main deciding factor here is parallelism. Consider that we have 12
1542     // kernels and 4 threads. Grains of 2, 3 and 4 all yield good task sizes.
1543     // But 2/4 yield 6/3 tasks, which gives us parallelism of 0.75 (at most 3/4
1544     // of cores will be busy). While grain size 3 gives us 4 tasks, which gives
1545     // us parallelism of 1 (we can load all cores).
1546     Index nm0 = divup(m, bm);
1547     Index nn0 = divup(n, bn);
1548     Index new_tasks = divup(nm0, gm) * divup(nn0, gn);
1549     double new_parallelism = static_cast<double>(new_tasks) /
1550                              (divup<int>(new_tasks, num_threads) * num_threads);
1551     Index old_tasks = divup(nm0, oldgm) * divup(nn0, oldgn);
1552     double old_parallelism = static_cast<double>(old_tasks) /
1553                              (divup<int>(old_tasks, num_threads) * num_threads);
1554     if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
1555     return 0;
1556   }
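
  // Working the comment's numbers through the parallelism formula above: with
  // 12 kernels and 4 threads, 6 tasks give 6 / (divup(6, 4) * 4) = 0.75,
  // 4 tasks give 4 / 4 = 1.0, and 3 tasks give 3 / (divup(3, 4) * 4) = 0.75,
  // which is why the grain yielding 4 tasks is preferred.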
1557 
1558   TensorOpCost contractionCost(Index m, Index n, Index bm, Index bn, Index bk,
1559                                bool shard_by_col, bool prepacked) const {
1560     const int packed_size = std::min<int>(PacketType<LhsScalar, Device>::size,
1561                                           PacketType<RhsScalar, Device>::size);
1562     const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
1563     const double kd = static_cast<double>(bk);
1564     double compute_bandwidth = computeBandwidth(false, bm, bn, bk);
1565     // Computations.
1566     TensorOpCost cost = TensorOpCost(0, 0, kd * compute_bandwidth, true, packed_size);
1567     // Output stores.
1568     cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
1569     if (prepacked) {
1570       // Packing and kernels are executed in different tasks. When we calculate
1571       // the task grain size we look only at the kernel cost, assuming that the
1572       // kernel is more expensive than packing.
1573       return cost;
1574     }
1575     // Lhs/rhs loads + computations.
1576     TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
1577     TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
1578     // Lhs packing memory cost does not contribute considerably to overall
1579     // execution time because lhs is prefetched early and accessed sequentially.
1580     if (shard_by_col)
1581       lhsCost.dropMemoryCost();
1582     else
1583       rhsCost.dropMemoryCost();
1584     return cost + lhsCost + rhsCost;
1585   }
1586 
1587   // Decide whether we want to shard m x k x n contraction over the inner
1588   // (contraction) dimension (k).
1589   static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
1590                               int num_threads_by_k) {
1591     std::ptrdiff_t bufsize = m * n * sizeof(Scalar);
1592     bool shard_by_k = false;
1593     if (n == 1 ||                // If mat*vec or...
1594         num_threads_by_k < 2 ||  // running single threaded or...
1595         num_threads_by_k <
1596             num_threads ||  // sharding by k gives less parallelism or...
1597         bufsize > l3CacheSize() / num_threads_by_k ||  // need more buffer space
1598         // than L3 cache or...
1599         k / num_threads_by_k < 2 * Traits::nr) {  // k per thread is tiny.
1600       shard_by_k = false;
1601     } else if (numext::maxi(m, n) / num_threads <
1602                    Traits::nr ||  // both other dimensions are tiny or...
1603                // k per thread is not small and...
1604                (k / num_threads_by_k > 8 * Traits::nr &&
1605                 // one of the outer dimensions is tiny or sharding by k offers
1606                 // more parallelism.
1607                 (numext::mini(m, n) < 2 * Traits::nr ||
1608                  num_threads_by_k > num_threads))) {
1609       shard_by_k = true;
1610     }
1611     return shard_by_k;
1612   }
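
  // Illustrative example (assuming Traits::nr = 4 and that the m * n buffer
  // fits into the per-thread share of the L3 cache): with m = n = 64,
  // k = 100000, num_threads = 8 and num_threads_by_k = 16, none of the
  // rejection conditions hold, k / num_threads_by_k = 6250 > 8 * Traits::nr,
  // and num_threads_by_k > num_threads, so we shard by k.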
1613 
1614   TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {
1615     // Compute cost.
1616     const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
1617     TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n, true, output_packet_size);
1618     // Output stores.
1619     cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
1620     TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * m;
1621     TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * n;
1622     // Since the inner gemm kernel is always sharded by column, the lhs
1623     // load cost is negligible.
1624     lhsCost.dropMemoryCost();
1625     return cost + lhsCost + rhsCost;
1626   }
1627 
1628   int numThreadsInnerDim(Index m, Index n, Index k) const {
1629     const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
1630     TensorOpCost cost = contractionCostPerInnerDim(m, n, k);
1631     double total_parallel_cost =
1632         TensorCostModel<ThreadPoolDevice>::totalCost(k, cost);
1633     // Cost of reduction step accumulating the m*n per-thread buffers into the
1634     // result.
1635     double reduction_cost = TensorCostModel<ThreadPoolDevice>::totalCost(
1636         m * n, TensorOpCost(2, 1, 1, true, output_packet_size));
1637     int num_threads = 1;
1638     double min_cost = total_parallel_cost;
1639     double kPerThreadOverHead = 3000;
1640     double kFixedOverHead = 100000;
1641     for (int nt = 2; nt <= this->m_device.numThreads(); nt += 2) {
1642       double sequential_cost =
1643           kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
1644       double parallel_cost = total_parallel_cost / nt + sequential_cost;
1645       if (parallel_cost < min_cost) {
1646         num_threads = nt;
1647         min_cost = parallel_cost;
1648       }
1649     }
1650     return num_threads;
1651   }
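
  // The loop above picks the even nt that minimizes
  // total_parallel_cost / nt + kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead).
  // With assumed values total_parallel_cost = 4e9 and reduction_cost = 1e5,
  // nt = 2 costs about 2.0003e9 while nt = 8 costs about 5.01e8, so larger
  // thread counts win until the per-thread overhead term starts to dominate.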
1652 
1653   double computeBandwidth(bool shard_by_col, Index bm, Index bn,
1654                           Index bk) const {
1655     // Peak VFMA bandwidth is 0.5. However, if we do not have enough data for
1656     // vectorization, the bandwidth drops. The 4.0 and 2.0 bandwidth values were
1657     // determined experimentally.
1658     double computeBandwidth =
1659         bk == 1 ? 4.0
1660                 : (shard_by_col ? bn : bm) < Traits::nr ||
1661                           (shard_by_col ? bm : bn) < Traits::mr
1662                       ? 2.0
1663                       : 0.5;
1664 #ifndef EIGEN_VECTORIZE_FMA
1665     // Bandwidth of all of VFMA/MULPS/ADDPS is 0.5 on the latest Intel
1666     // processors. However, for MULPS/ADDPS we have a dependent sequence of
1667     // two such instructions, so the overall bandwidth is 1.0.
1669     if (computeBandwidth == 0.5) computeBandwidth = 1.0;
1670 #endif
1671     return computeBandwidth;
1672   }
1673 
1674 };
1675 
1676 } // end namespace Eigen
1677 
1678 #endif  // EIGEN_USE_THREADS
1679 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H