Back to home page

EIC code displayed by LXR

 
 

    


Warning, file /include/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h was not indexed or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).

0001 // This file is part of Eigen, a lightweight C++ template library
0002 // for linear algebra.
0003 //
0004 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
0005 //
0006 // This Source Code Form is subject to the terms of the Mozilla
0007 // Public License v. 2.0. If a copy of the MPL was not distributed
0008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
0009 
0010 #ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
0011 #define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
0012 
0013 namespace Eigen {
0014 
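/** \class TensorBroadcastingOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor broadcasting class.
  *
  * Tiles an expression along each of its dimensions: output dimension i has
  * extent input_dims[i] * broadcast[i], and the rank is unchanged.
  */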
namespace internal {
// Expression traits of TensorBroadcastingOp. Everything (scalar type, storage
// kind, index type, rank, layout, pointer type) is forwarded from the
// broadcast argument XprType; in particular the rank is unchanged.
template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
  typedef typename XprTraits::PointerType PointerType;
};

// Evaluating a broadcasting expression yields the (const) expression itself,
// qualified with EIGEN_DEVICE_REF (presumably a reference on most devices).
template<typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
{
  typedef const TensorBroadcastingOp<Broadcast, XprType> EIGEN_DEVICE_REF type;
};

// When nested inside another expression, a broadcasting op is held by value.
template<typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
{
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};

// is_input_scalar<Dims>::value is true when Dims is a compile-time dimension
// list describing exactly one coefficient: Sizes<> or a Sizes<...> whose
// total size is 1. The evaluator uses it to short-circuit coeff()/packet() to
// m_impl.coeff(0). Dimension types without a specialization here (e.g.
// runtime-sized dimensions) fall back to the primary template, i.e. false.
template <typename Dims>
struct is_input_scalar {
  static const bool value = false;
};
template <>
struct is_input_scalar<Sizes<> > {
  static const bool value = true;
};
#ifndef EIGEN_EMULATE_CXX11_META_H
// Variadic form; only compiled when the C++11 metaprogramming emulation
// header has not been included.
template <typename std::ptrdiff_t... Indices>
struct is_input_scalar<Sizes<Indices...> > {
  static const bool value = (Sizes<Indices...>::total_size == 1);
};
#endif

}  // end namespace internal
0065 
0066 
0067 
0068 template<typename Broadcast, typename XprType>
0069 class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
0070 {
0071   public:
0072   typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
0073   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
0074   typedef typename XprType::CoeffReturnType CoeffReturnType;
0075   typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
0076   typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
0077   typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;
0078 
0079   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
0080       : m_xpr(expr), m_broadcast(broadcast) {}
0081 
0082     EIGEN_DEVICE_FUNC
0083     const Broadcast& broadcast() const { return m_broadcast; }
0084 
0085     EIGEN_DEVICE_FUNC
0086     const typename internal::remove_all<typename XprType::Nested>::type&
0087     expression() const { return m_xpr; }
0088 
0089   protected:
0090     typename XprType::Nested m_xpr;
0091     const Broadcast m_broadcast;
0092 };
0093 
0094 
0095 // Eval as rvalue
0096 template<typename Broadcast, typename ArgType, typename Device>
0097 struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
0098 {
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  // Broadcasting keeps the rank, so the output has the argument's rank.
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
  protected: //  all the non-static fields must have the same access control, otherwise the TensorEvaluator won't be standard layout;
  // Fast-path flags computed by the constructor:
  //  - isCopy: every broadcast factor is 1, so output == input.
  //  - oneByN: input_dims[0] == 1 and no dimension other than 0 is broadcast.
  //  - nByOne: input_dims[0] != 1, input_dims[NumDims-1] == 1 and no dimension
  //    other than NumDims-1 is broadcast.
  //  If neither is set by those tests, both are set for an input shaped
  //  [1, N..., 1] (NumDims > 2) whose middle dimensions are not broadcast.
  bool isCopy, nByOne, oneByN;
  public:
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    IsAligned         = TensorEvaluator<ArgType, Device>::IsAligned,
    PacketAccess      = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess       = TensorEvaluator<ArgType, Device>::BlockAccess,
    PreferBlockAccess = true,
    Layout            = TensorEvaluator<ArgType, Device>::Layout,
    // The result is never exposed as a contiguous buffer (data() returns NULL).
    RawAccess         = false
  };

  typedef typename internal::remove_const<Scalar>::type ScalarNoConst;

  // We do block based broadcasting using a trick with 2x tensor rank and 0
  // strides. See block method implementation for details.
  typedef DSizes<Index, 2 * NumDims> BroadcastDimensions;

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
  typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;

  typedef typename TensorEvaluator<const ArgType, Device>::TensorBlock
      ArgTensorBlock;

  // Broadcast blocks are always materialized into a buffer of ScalarNoConst.
  typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumDims,
                                                     Layout, Index>
      TensorBlock;
  //===--------------------------------------------------------------------===//
0140 
0141   EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
0142       : isCopy(false), nByOne(false), oneByN(false),
0143         m_device(device), m_broadcast(op.broadcast()), m_impl(op.expression(), device)
0144   {
0145 
0146     // The broadcasting op doesn't change the rank of the tensor. One can't broadcast a scalar
0147     // and store the result in a scalar. Instead one should reshape the scalar into a a N-D
0148     // tensor with N >= 1 of 1 element first and then broadcast.
0149     EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
0150     const InputDimensions& input_dims = m_impl.dimensions();
0151     isCopy = true;
0152     for (int i = 0; i < NumDims; ++i) {
0153       eigen_assert(input_dims[i] > 0);
0154       m_dimensions[i] = input_dims[i] * m_broadcast[i];
0155       if (m_broadcast[i] != 1) {
0156         isCopy = false;
0157       }
0158     }
0159 
0160     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
0161       m_inputStrides[0] = 1;
0162       m_outputStrides[0] = 1;
0163       for (int i = 1; i < NumDims; ++i) {
0164         m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
0165         m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
0166       }
0167     } else {
0168       m_inputStrides[NumDims-1] = 1;
0169       m_outputStrides[NumDims-1] = 1;
0170       for (int i = NumDims-2; i >= 0; --i) {
0171         m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
0172         m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
0173       }
0174     }
0175 
0176     if (input_dims[0] == 1) {
0177       oneByN = true;
0178       for (int i = 1; i < NumDims; ++i) {
0179         if (m_broadcast[i] != 1) {
0180           oneByN = false;
0181           break;
0182         }
0183       }
0184     } else if (input_dims[NumDims-1] == 1) {
0185       nByOne = true;
0186       for (int i = 0; i < NumDims-1; ++i) {
0187         if (m_broadcast[i] != 1) {
0188           nByOne = false;
0189           break;
0190         }
0191       }
0192     }
0193 
0194     // Handle special format like NCHW, its input shape is '[1, N..., 1]' and
0195     // broadcast shape is '[N, 1..., N]'
0196     if (!oneByN && !nByOne) {
0197       if (input_dims[0] == 1 && input_dims[NumDims-1] == 1 && NumDims > 2) {
0198         nByOne = true;
0199         oneByN = true;
0200         for (int i = 1; i < NumDims-1; ++i) {
0201           if (m_broadcast[i] != 1) {
0202             nByOne = false;
0203             oneByN = false;
0204             break;
0205           }
0206         }
0207       }
0208     }
0209   }
0210 
0211   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
0212 
0213   EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
0214     m_impl.evalSubExprsIfNeeded(NULL);
0215     return true;
0216   }
0217 
0218 #ifdef EIGEN_USE_THREADS
0219   template <typename EvalSubExprsCallback>
0220   EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
0221       EvaluatorPointerType, EvalSubExprsCallback done) {
0222     m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
0223   }
0224 #endif  // EIGEN_USE_THREADS
0225 
0226   EIGEN_STRONG_INLINE void cleanup() {
0227     m_impl.cleanup();
0228   }
0229 
0230   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
0231   {
0232     if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
0233       return m_impl.coeff(0);
0234     }
0235 
0236     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
0237       if (isCopy) {
0238         return m_impl.coeff(index);
0239       } else {
0240         return coeffColMajor(index);
0241       }
0242     } else {
0243       if (isCopy) {
0244         return m_impl.coeff(index);
0245       } else {
0246         return coeffRowMajor(index);
0247       }
0248     }
0249   }
0250 
0251   // TODO: attempt to speed this up. The integer divisions and modulo are slow
0252   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexColMajor(Index index) const {
0253     Index inputIndex = 0;
0254     EIGEN_UNROLL_LOOP
0255     for (int i = NumDims - 1; i > 0; --i) {
0256       const Index idx = index / m_outputStrides[i];
0257       if (internal::index_statically_eq<Broadcast>(i, 1)) {
0258         eigen_assert(idx < m_impl.dimensions()[i]);
0259         inputIndex += idx * m_inputStrides[i];
0260       } else {
0261         if (internal::index_statically_eq<InputDimensions>(i, 1)) {
0262           eigen_assert(idx % m_impl.dimensions()[i] == 0);
0263         } else {
0264           inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
0265         }
0266       }
0267       index -= idx * m_outputStrides[i];
0268     }
0269     if (internal::index_statically_eq<Broadcast>(0, 1)) {
0270       eigen_assert(index < m_impl.dimensions()[0]);
0271       inputIndex += index;
0272     } else {
0273       if (internal::index_statically_eq<InputDimensions>(0, 1)) {
0274         eigen_assert(index % m_impl.dimensions()[0] == 0);
0275       } else {
0276         inputIndex += (index % m_impl.dimensions()[0]);
0277       }
0278     }
0279     return inputIndex;
0280   }
0281 
0282   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
0283   {
0284     return m_impl.coeff(indexColMajor(index));
0285   }
0286 
0287   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexRowMajor(Index index) const {
0288     Index inputIndex = 0;
0289     EIGEN_UNROLL_LOOP
0290     for (int i = 0; i < NumDims - 1; ++i) {
0291       const Index idx = index / m_outputStrides[i];
0292       if (internal::index_statically_eq<Broadcast>(i, 1)) {
0293         eigen_assert(idx < m_impl.dimensions()[i]);
0294         inputIndex += idx * m_inputStrides[i];
0295       } else {
0296         if (internal::index_statically_eq<InputDimensions>(i, 1)) {
0297           eigen_assert(idx % m_impl.dimensions()[i] == 0);
0298         } else {
0299           inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
0300         }
0301       }
0302       index -= idx * m_outputStrides[i];
0303     }
0304     if (internal::index_statically_eq<Broadcast>(NumDims - 1, 1)) {
0305       eigen_assert(index < m_impl.dimensions()[NumDims - 1]);
0306       inputIndex += index;
0307     } else {
0308       if (internal::index_statically_eq<InputDimensions>(NumDims - 1, 1)) {
0309         eigen_assert(index % m_impl.dimensions()[NumDims - 1] == 0);
0310       } else {
0311         inputIndex += (index % m_impl.dimensions()[NumDims - 1]);
0312       }
0313     }
0314     return inputIndex;
0315   }
0316 
0317   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
0318   {
0319     return m_impl.coeff(indexRowMajor(index));
0320   }
0321 
0322   template<int LoadMode>
0323   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
0324   {
0325     if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
0326       return internal::pset1<PacketReturnType>(m_impl.coeff(0));
0327     }
0328 
0329     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
0330       if (isCopy) {
0331         #ifdef EIGEN_GPU_COMPILE_PHASE
0332         // See PR 437: on NVIDIA P100 and K20m we observed a x3-4 speed up by enforcing
0333         // unaligned loads here. The reason is unclear though.
0334         return m_impl.template packet<Unaligned>(index);
0335         #else
0336         return m_impl.template packet<LoadMode>(index);
0337         #endif
0338       } else if (oneByN && !nByOne) {
0339         return packetNByOne<LoadMode>(index);
0340       } else if (!oneByN && nByOne) {
0341         return packetOneByN<LoadMode>(index);
0342       } else if (oneByN && nByOne) {
0343         return packetOneByNByOne<LoadMode>(index);
0344       } else {
0345         return packetColMajor<LoadMode>(index);
0346       }
0347     } else {
0348       if (isCopy) {
0349         #ifdef EIGEN_GPU_COMPILE_PHASE
0350         // See above.
0351         return m_impl.template packet<Unaligned>(index);
0352         #else
0353         return m_impl.template packet<LoadMode>(index);
0354         #endif
0355       } else if (oneByN && !nByOne) {
0356         return packetOneByN<LoadMode>(index);
0357       } else if (!oneByN && nByOne) {
0358         return packetNByOne<LoadMode>(index);
0359       } else if (oneByN && nByOne) {
0360         return packetOneByNByOne<LoadMode>(index);
0361       } else {
0362         return packetRowMajor<LoadMode>(index);
0363       }
0364     }
0365   }
0366 
  // Packet path used by packet() when both oneByN and nByOne are set: the
  // input has extent 1 in its first and last dimensions and only those two
  // dimensions are broadcast. Within one period of the outermost output
  // dimension, each input coefficient is repeated over a run of
  // m_outputStrides[endDim] consecutive output coefficients.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne
  (Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
    Index startDim, endDim;
    Index inputIndex, outputOffset, batchedIndex;

    // startDim: outermost dimension. endDim: the dimension just outside the
    // innermost one, whose output stride equals the innermost output extent.
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      startDim = NumDims - 1;
      endDim = 1;
    } else {
      startDim = 0;
      endDim = NumDims - 2;
    }

    // Position inside the current period of the outermost dimension, then
    // split it into (input coefficient, offset inside that coefficient's run).
    batchedIndex = index % m_outputStrides[startDim];
    inputIndex   = batchedIndex / m_outputStrides[endDim];
    outputOffset = batchedIndex % m_outputStrides[endDim];

    if (outputOffset + PacketSize <= m_outputStrides[endDim]) {
      // The whole packet lies inside one run: splat a single coefficient.
      values[0] = m_impl.coeff(inputIndex);
      return internal::pload1<PacketReturnType>(values);
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
        if (outputOffset + cur < m_outputStrides[endDim]) {
          values[i] = m_impl.coeff(inputIndex);
        } else {
          // Crossed into the next run: move to the next input coefficient and
          // wrap to 0 after m_inputStrides[startDim] coefficients (the input
          // size, since the outermost input extent is 1 in this case).
          ++inputIndex;
          inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex);
          values[i] = m_impl.coeff(inputIndex);
          // Restart run bookkeeping; `cur` becomes 1 at the next iteration.
          outputOffset = 0;
          cur = 0;
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
0409 
0410   template<int LoadMode>
0411   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const
0412   {
0413     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
0414     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
0415 
0416     Index dim, inputIndex;
0417 
0418     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
0419       dim = NumDims - 1;
0420     } else {
0421       dim = 0;
0422     }
0423 
0424     inputIndex = index % m_inputStrides[dim];
0425     if (inputIndex + PacketSize <= m_inputStrides[dim]) {
0426       return m_impl.template packet<Unaligned>(inputIndex);
0427     } else {
0428       EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
0429       EIGEN_UNROLL_LOOP
0430       for (int i = 0; i < PacketSize; ++i) {
0431         if (inputIndex > m_inputStrides[dim]-1) {
0432           inputIndex = 0;
0433         }
0434         values[i] = m_impl.coeff(inputIndex++);
0435       }
0436       return internal::pload<PacketReturnType>(values);
0437     }
0438   }
0439 
0440   template<int LoadMode>
0441   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetNByOne(Index index) const
0442   {
0443     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
0444     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
0445 
0446     EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
0447     Index dim, inputIndex, outputOffset;
0448 
0449     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
0450       dim = 1;
0451     } else {
0452       dim = NumDims - 2;
0453     }
0454 
0455     inputIndex   = index / m_outputStrides[dim];
0456     outputOffset = index % m_outputStrides[dim];
0457     if (outputOffset + PacketSize <= m_outputStrides[dim]) {
0458       values[0] = m_impl.coeff(inputIndex);
0459       return internal::pload1<PacketReturnType>(values);
0460     } else {
0461       EIGEN_UNROLL_LOOP
0462       for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
0463         if (outputOffset + cur < m_outputStrides[dim]) {
0464           values[i] = m_impl.coeff(inputIndex);
0465         } else {
0466           values[i] = m_impl.coeff(++inputIndex);
0467           outputOffset = 0;
0468           cur = 0;
0469         }
0470       }
0471       return internal::pload<PacketReturnType>(values);
0472     }
0473   }
0474 
0475   // Ignore the LoadMode and always use unaligned loads since we can't guarantee
0476   // the alignment at compile time.
0477   template<int LoadMode>
0478   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
0479   {
0480     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
0481     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
0482 
0483     const Index originalIndex = index;
0484 
0485     Index inputIndex = 0;
0486     EIGEN_UNROLL_LOOP
0487     for (int i = NumDims - 1; i > 0; --i) {
0488       const Index idx = index / m_outputStrides[i];
0489       if (internal::index_statically_eq<Broadcast>(i, 1)) {
0490         eigen_assert(idx < m_impl.dimensions()[i]);
0491         inputIndex += idx * m_inputStrides[i];
0492       } else {
0493         if (internal::index_statically_eq<InputDimensions>(i, 1)) {
0494           eigen_assert(idx % m_impl.dimensions()[i] == 0);
0495         } else {
0496           inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
0497         }
0498       }
0499       index -= idx * m_outputStrides[i];
0500     }
0501     Index innermostLoc;
0502     if (internal::index_statically_eq<Broadcast>(0, 1)) {
0503       eigen_assert(index < m_impl.dimensions()[0]);
0504       innermostLoc = index;
0505     } else {
0506       if (internal::index_statically_eq<InputDimensions>(0, 1)) {
0507         eigen_assert(index % m_impl.dimensions()[0] == 0);
0508         innermostLoc = 0;
0509       } else {
0510         innermostLoc = index % m_impl.dimensions()[0];
0511       }
0512     }
0513     inputIndex += innermostLoc;
0514 
0515     // Todo: this could be extended to the second dimension if we're not
0516     // broadcasting alongside the first dimension, and so on.
0517     if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
0518       return m_impl.template packet<Unaligned>(inputIndex);
0519     } else {
0520       EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
0521       values[0] = m_impl.coeff(inputIndex);
0522       EIGEN_UNROLL_LOOP
0523       for (int i = 1; i < PacketSize; ++i) {
0524         if (innermostLoc + i < m_impl.dimensions()[0]) {
0525           values[i] = m_impl.coeff(inputIndex+i);
0526         } else {
0527           values[i] = coeffColMajor(originalIndex+i);
0528         }
0529       }
0530       PacketReturnType rslt = internal::pload<PacketReturnType>(values);
0531       return rslt;
0532     }
0533   }
0534 
0535   template<int LoadMode>
0536   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
0537   {
0538     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
0539     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
0540 
0541     const Index originalIndex = index;
0542 
0543     Index inputIndex = 0;
0544     EIGEN_UNROLL_LOOP
0545     for (int i = 0; i < NumDims - 1; ++i) {
0546       const Index idx = index / m_outputStrides[i];
0547       if (internal::index_statically_eq<Broadcast>(i, 1)) {
0548         eigen_assert(idx < m_impl.dimensions()[i]);
0549         inputIndex += idx * m_inputStrides[i];
0550       } else {
0551         if (internal::index_statically_eq<InputDimensions>(i, 1)) {
0552           eigen_assert(idx % m_impl.dimensions()[i] == 0);
0553         } else {
0554           inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
0555         }
0556       }
0557       index -= idx * m_outputStrides[i];
0558     }
0559     Index innermostLoc;
0560     if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
0561       eigen_assert(index < m_impl.dimensions()[NumDims-1]);
0562       innermostLoc = index;
0563     } else {
0564       if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
0565         eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
0566         innermostLoc = 0;
0567       } else {
0568         innermostLoc = index % m_impl.dimensions()[NumDims-1];
0569       }
0570     }
0571     inputIndex += innermostLoc;
0572 
0573     // Todo: this could be extended to the second dimension if we're not
0574     // broadcasting alongside the first dimension, and so on.
0575     if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims-1]) {
0576       return m_impl.template packet<Unaligned>(inputIndex);
0577     } else {
0578       EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
0579       values[0] = m_impl.coeff(inputIndex);
0580       EIGEN_UNROLL_LOOP
0581       for (int i = 1; i < PacketSize; ++i) {
0582         if (innermostLoc + i < m_impl.dimensions()[NumDims-1]) {
0583           values[i] = m_impl.coeff(inputIndex+i);
0584         } else {
0585           values[i] = coeffRowMajor(originalIndex+i);
0586         }
0587       }
0588       PacketReturnType rslt = internal::pload<PacketReturnType>(values);
0589       return rslt;
0590     }
0591   }
0592 
0593   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
0594   costPerCoeff(bool vectorized) const {
0595     double compute_cost = TensorOpCost::AddCost<Index>();
0596     if (!isCopy && NumDims > 0) {
0597       EIGEN_UNROLL_LOOP
0598       for (int i = NumDims - 1; i > 0; --i) {
0599         compute_cost += TensorOpCost::DivCost<Index>();
0600         if (internal::index_statically_eq<Broadcast>(i, 1)) {
0601           compute_cost +=
0602               TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
0603         } else {
0604           if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
0605             compute_cost += TensorOpCost::MulCost<Index>() +
0606                             TensorOpCost::ModCost<Index>() +
0607                             TensorOpCost::AddCost<Index>();
0608           }
0609         }
0610         compute_cost +=
0611             TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
0612       }
0613     }
0614     return m_impl.costPerCoeff(vectorized) +
0615            TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
0616   }
0617 
0618   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
0619   internal::TensorBlockResourceRequirements getResourceRequirements() const {
0620     // TODO(wuke): Targeting L1 size is 30% faster than targeting L{-1} on large
0621     // tensors. But this might need further tuning.
0622     const size_t target_size = m_device.firstLevelCacheSize();
0623     return internal::TensorBlockResourceRequirements::merge(
0624         m_impl.getResourceRequirements(),
0625         internal::TensorBlockResourceRequirements::skewed<Scalar>(target_size));
0626   }
0627 
  // Materializes the broadcast result for the output block described by
  // `desc` into scratch-backed storage.
  //
  // The block is filled one slice at a time: BroadcastBlockAlongBcastDim
  // broadcasts along `bcast_dim` for the current position of the dimensions
  // outer to it, and its return value is added to `num_output_coeffs` (the next
  // write position), i.e. presumably the number of coefficients it wrote. An
  // odometer over the outer dimensions (`it`) then advances `output_offset`.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock
  block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
          bool /*root_of_expr_ast*/ = false) const {
    BlockBroadcastingParams params = blockBroadcastingParams(desc);

    // An empty inner extent or bcast extent means there is nothing to produce.
    if (params.inner_dim_size == 0 || params.bcast_dim_size == 0) {
      return emptyBlock();
    }

    // Prepare storage for the materialized broadcasting result.
    const typename TensorBlock::Storage block_storage =
        TensorBlock::prepareStorage(desc, scratch);
    ScalarNoConst* materialized_output = block_storage.data();

    // We potentially will need to materialize input blocks. The buffer and its
    // size are passed by pointer so BroadcastBlockAlongBcastDim can reuse them
    // across calls (NOTE(review): allocation happens in that helper, not here).
    size_t materialized_input_size = 0;
    ScalarNoConst* materialized_input = NULL;

    // Initialize block broadcasting iterator state for outer dimensions (outer
    // with regard to bcast dimension). Dimensions in this array are always in
    // inner_most -> outer_most order (col major layout).
    array<BlockBroadcastingIteratorState, NumDims> it;
    int idx = 0;  // number of entries of `it` in use

    for (int i = params.inner_dim_count + 1; i < NumDims; ++i) {
      const Index dim = IsColMajor ? i : NumDims - 1 - i;
      it[idx].size = params.output_dims[dim];
      it[idx].count = 0;
      it[idx].output_stride = m_outputStrides[dim];
      it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);
      idx++;
    }

    // Write output into the beginning of `materialized_output`.
    Index output_offset = 0;

    // We will fill output block by broadcasting along the bcast dim, and
    // iterating over outer dimension.
    const Index output_size = NumDims == 0 ? 1 : params.output_dims.TotalSize();

    for (Index num_output_coeffs = 0; num_output_coeffs < output_size;) {
      ScalarNoConst* bcast_output = materialized_output + num_output_coeffs;
      // Offset of this slice in the full broadcast output.
      Index bcast_offset = desc.offset() + output_offset;

      // Broadcast along the bcast dimension.
      num_output_coeffs += BroadcastBlockAlongBcastDim(
          params, bcast_offset, scratch, bcast_output, &materialized_input,
          &materialized_input_size);

      // Switch to the next outer dimension: odometer increment, rewinding a
      // dimension by its span when it wraps and carrying into the next one.
      for (int j = 0; j < idx; ++j) {
        if (++it[j].count < it[j].size) {
          output_offset += it[j].output_stride;
          break;
        }
        it[j].count = 0;
        output_offset -= it[j].output_span;
      }
    }

    return block_storage.AsTensorMaterializedBlock();
  }
0690 
0691   EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
0692 
0693   const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
0694 
0695   Broadcast functor() const { return m_broadcast; }
0696 #ifdef EIGEN_USE_SYCL
0697   // binding placeholder accessors to a command group handler for SYCL
0698   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
0699       cl::sycl::handler& cgh) const {
0700     m_impl.bind(cgh);
0701   }
0702 #endif
 private:
  // Compile-time layout test; block helpers use it to map loop positions
  // (always counted inner-most -> outer-most) to dimension indices.
  static const bool IsColMajor =
      static_cast<int>(Layout) == static_cast<int>(ColMajor);
0706 
  // We build general-case block broadcasting on top of a broadcasting
  // primitive that only broadcasts the inner dimension(s) along the first
  // dimension (in inner-most -> outer-most order) whose block size is smaller
  // than the size of the broadcast expression; that dimension is called
  // `bcast_dim`.
  //
  // Example:
  //            dim:  0  1  2   (ColMajor)
  //    output size: [9, 3, 6]
  //     block size: [9, 2, 6]
  //
  // We compute the broadcast block by iterating over the dimensions outer to
  // `bcast_dim` (only dimension `2` in this example) and computing broadcasts
  // along `bcast_dim` (dimension `1` in this example).
0719 
  // BlockBroadcastingParams holds precomputed parameters for broadcasting a
  // single block along the broadcasting dimension. Sizes and strides along the
  // `bcast_dim` might be invalid, they will be adjusted later in
  // `BroadcastBlockAlongBcastDim`.
  struct BlockBroadcastingParams {
    Dimensions input_dims;      // input expression dimensions
    Dimensions output_dims;     // output block sizes
    Dimensions output_strides;  // output block strides

    int inner_dim_count;   // count inner dimensions matching in size
    int bcast_dim;         // broadcasting dimension index
    Index bcast_dim_size;  // broadcasting dimension size (block extent along bcast_dim)
    Index inner_dim_size;  // inner dimensions size (product of the matching extents)

    // Block sizes and strides for the input block: the matching inner
    // dimensions keep the input extent, and every dimension from `bcast_dim`
    // outward is `1`.
    Dimensions input_block_sizes;
    Dimensions input_block_strides;

    // Block sizes and strides for blocks with extra dimensions and strides `0`.
    BroadcastDimensions bcast_block_sizes;
    BroadcastDimensions bcast_block_strides;
    BroadcastDimensions bcast_input_strides;
  };
0744 
  // Odometer state for one outer dimension iterated by block().
  struct BlockBroadcastingIteratorState {
    Index size;           // block extent along this dimension
    Index count;          // current position, in [0, size)
    Index output_stride;  // output offset step when `count` is incremented
    Index output_span;    // output_stride * (size - 1): rewind amount on wrap-around
  };
0751 
0752   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlockBroadcastingParams
0753   blockBroadcastingParams(TensorBlockDesc& desc) const {
0754     BlockBroadcastingParams params;
0755 
0756     params.input_dims = Dimensions(m_impl.dimensions());
0757 
0758     // Output block sizes and strides.
0759     params.output_dims = desc.dimensions();
0760     params.output_strides = internal::strides<Layout>(params.output_dims);
0761 
0762     // Find the broadcasting dimension (first dimension with output size smaller
0763     // that the input size).
0764     params.bcast_dim = 0;
0765     params.bcast_dim_size = 1;
0766     params.inner_dim_size = 1;
0767 
0768     // Count the number of inner dimensions that have the same size in the block
0769     // and in the broadcast expression.
0770     params.inner_dim_count = 0;
0771 
0772     for (int i = 0; i < NumDims; ++i) {
0773       const int dim = IsColMajor ? i : NumDims - i - 1;
0774 
0775       if (params.output_dims[dim] == m_dimensions[dim]) {
0776         params.inner_dim_size *= params.output_dims[dim];
0777         ++params.inner_dim_count;
0778         continue;
0779       }
0780 
0781       // First non-matching dimension is the broadcasting dimension.
0782       eigen_assert(params.output_dims[dim] < m_dimensions[dim]);
0783       params.bcast_dim = dim;
0784       params.bcast_dim_size = params.output_dims[dim];
0785       break;
0786     }
0787 
0788     // Calculate the input block size for looking into the input.
0789     for (int i = 0; i < params.inner_dim_count; ++i) {
0790       const int dim = IsColMajor ? i : NumDims - i - 1;
0791       params.input_block_sizes[dim] = params.input_dims[dim];
0792     }
0793     for (int i = params.inner_dim_count; i < NumDims; ++i) {
0794       const int dim = IsColMajor ? i : NumDims - i - 1;
0795       params.input_block_sizes[dim] = 1;
0796     }
0797     params.input_block_strides =
0798         internal::strides<Layout>(params.input_block_sizes);
0799 
0800     // Broadcast with the 0-stride trick: Create 1 extra dim for each
0801     // broadcast, set the input stride to 0.
0802     //
0803     // When ColMajor:
0804     //
0805     // - bcast_block_sizes:
0806     //   [d_0, b_0, d_1, b_1, ...]
0807     //
0808     // - bcast_block_strides:
0809     //   [output_block_strides[0], output_block_strides[0] * d_0,
0810     //    output_block_strides[1], output_block_strides[1] * d_1,
0811     //   ...]
0812     //
0813     // - bcast_input_strides:
0814     //   [input_block_strides[0], 0,
0815     //    input_block_strides[1], 0,
0816     //   ...].
0817     //
0818     for (int i = 0; i < params.inner_dim_count; ++i) {
0819       const int dim = IsColMajor ? i : NumDims - i - 1;
0820 
0821       const int copy_dim = IsColMajor ? 2 * i : 2 * NumDims - 2 * i - 1;
0822       const int broadcast_dim = IsColMajor ? copy_dim + 1 : copy_dim - 1;
0823 
0824       params.bcast_block_sizes[copy_dim] = params.input_dims[dim];
0825       params.bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
0826       params.bcast_block_strides[copy_dim] = params.output_strides[dim];
0827       params.bcast_block_strides[broadcast_dim] =
0828           params.output_strides[dim] * params.input_dims[dim];
0829       params.bcast_input_strides[copy_dim] = params.input_block_strides[dim];
0830       params.bcast_input_strides[broadcast_dim] = 0;
0831     }
0832 
0833     for (int i = 2 * params.inner_dim_count; i < 2 * NumDims; ++i) {
0834       const int dim = IsColMajor ? i : 2 * NumDims - i - 1;
0835       params.bcast_block_sizes[dim] = 1;
0836       params.bcast_block_strides[dim] = 0;
0837       params.bcast_input_strides[dim] = 0;
0838     }
0839 
0840     return params;
0841   }
0842 
0843   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock emptyBlock() const {
0844     DSizes<Index, NumDims> dimensions;
0845     for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
0846     return TensorBlock(internal::TensorBlockKind::kView, NULL, dimensions);
0847   }
0848 
0849   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlockAlongBcastDim(
0850       BlockBroadcastingParams params, Index bcast_offset,
0851       TensorBlockScratch& scratch, ScalarNoConst* materialized_output,
0852       ScalarNoConst** materialized_input,
0853       size_t* materialized_input_size) const {
0854     if (params.bcast_dim_size == 1) {
0855       // We just need one block read using the ready-set values above.
0856       return BroadcastBlock(
0857           params.input_block_sizes, params.input_block_strides,
0858           params.bcast_block_sizes, params.bcast_block_strides,
0859           params.bcast_input_strides, bcast_offset, 0, scratch,
0860           materialized_output, materialized_input, materialized_input_size);
0861 
0862     } else if (params.input_dims[params.bcast_dim] == 1) {
0863       // Broadcast bcast dimension (< NumDims) by bcast_dim_size.
0864       const int broadcast_bcast_dim =
0865           IsColMajor ? 2 * params.inner_dim_count + 1
0866                      : 2 * NumDims - 2 * params.inner_dim_count - 2;
0867 
0868       params.bcast_block_sizes[broadcast_bcast_dim] = params.bcast_dim_size;
0869       params.bcast_input_strides[broadcast_bcast_dim] = 0;
0870       params.bcast_block_strides[broadcast_bcast_dim] =
0871           params.output_strides[params.bcast_dim];
0872 
0873       return BroadcastBlock(
0874           params.input_block_sizes, params.input_block_strides,
0875           params.bcast_block_sizes, params.bcast_block_strides,
0876           params.bcast_input_strides, bcast_offset, 0, scratch,
0877           materialized_output, materialized_input, materialized_input_size);
0878 
0879     } else {
0880       // Keep track of the total number of the coefficients written to the
0881       // output block.
0882       Index num_output_coeffs = 0;
0883 
0884       // The general case. Let's denote the output block as
0885       //
0886       //   x[..., a:a+bcast_dim_size, :, ..., :]
0887       //
0888       // where a:a+bcast_dim_size is a slice on the bcast_dim dimension
0889       // (< NumDims). We need to split the a:a+bcast_dim_size into possibly 3
0890       // sub-blocks:
0891       //
0892       // (1) a:b, where b is the smallest multiple of
0893       //     input_dims[bcast_dim_start] in [a, a+bcast_dim_size].
0894       //
0895       // (2) b:c, where c is the largest multiple of input_dims[bcast_dim_start]
0896       //     in [a, a+bcast_dim_size].
0897       //
0898       // (3) c:a+bcast_dim_size .
0899       //
0900       // Or, when b and c do not exist, we just need to process the whole block
0901       // together.
0902 
0903       // Find a.
0904       const Index bcast_dim_left_index =
0905           bcast_offset / m_outputStrides[params.bcast_dim];
0906 
0907       // Find b and c.
0908       const Index input_bcast_dim_size = params.input_dims[params.bcast_dim];
0909 
0910       // First multiple after a. This is b when <= bcast_dim_left_index +
0911       // bcast_dim_size.
0912       const Index first_multiple =
0913           divup<Index>(bcast_dim_left_index, input_bcast_dim_size) *
0914           input_bcast_dim_size;
0915 
0916       if (first_multiple <= bcast_dim_left_index + params.bcast_dim_size) {
0917         // b exists, so does c. Find it.
0918         const Index last_multiple =
0919             (bcast_dim_left_index + params.bcast_dim_size) /
0920             input_bcast_dim_size * input_bcast_dim_size;
0921         const int copy_bcast_dim =
0922             IsColMajor ? 2 * params.inner_dim_count
0923                        : 2 * NumDims - 2 * params.inner_dim_count - 1;
0924         const int broadcast_bcast_dim =
0925             IsColMajor ? 2 * params.inner_dim_count + 1
0926                        : 2 * NumDims - 2 * params.inner_dim_count - 2;
0927 
0928         if (first_multiple > bcast_dim_left_index) {
0929           const Index head_size = first_multiple - bcast_dim_left_index;
0930           params.input_block_sizes[params.bcast_dim] = head_size;
0931           params.bcast_block_sizes[copy_bcast_dim] = head_size;
0932           params.bcast_input_strides[copy_bcast_dim] =
0933               params.input_block_strides[params.bcast_dim];
0934           params.bcast_block_strides[copy_bcast_dim] =
0935               params.output_strides[params.bcast_dim];
0936           params.bcast_block_sizes[broadcast_bcast_dim] = 1;
0937           params.bcast_input_strides[broadcast_bcast_dim] = 0;
0938           params.bcast_block_strides[broadcast_bcast_dim] =
0939               params.output_strides[params.bcast_dim] *
0940               params.input_dims[params.bcast_dim];
0941 
0942           num_output_coeffs += BroadcastBlock(
0943               params.input_block_sizes, params.input_block_strides,
0944               params.bcast_block_sizes, params.bcast_block_strides,
0945               params.bcast_input_strides, bcast_offset, 0, scratch,
0946               materialized_output, materialized_input, materialized_input_size);
0947         }
0948         if (first_multiple < last_multiple) {
0949           params.input_block_sizes[params.bcast_dim] = input_bcast_dim_size;
0950           params.bcast_block_sizes[copy_bcast_dim] = input_bcast_dim_size;
0951           params.bcast_input_strides[copy_bcast_dim] =
0952               params.input_block_strides[params.bcast_dim];
0953           params.bcast_block_strides[copy_bcast_dim] =
0954               params.output_strides[params.bcast_dim];
0955           params.bcast_block_sizes[broadcast_bcast_dim] =
0956               (last_multiple - first_multiple) / input_bcast_dim_size;
0957           params.bcast_input_strides[broadcast_bcast_dim] = 0;
0958           params.bcast_block_strides[broadcast_bcast_dim] =
0959               params.output_strides[params.bcast_dim] *
0960               params.input_dims[params.bcast_dim];
0961           const Index offset = (first_multiple - bcast_dim_left_index) *
0962                                m_outputStrides[params.bcast_dim];
0963 
0964           num_output_coeffs += BroadcastBlock(
0965               params.input_block_sizes, params.input_block_strides,
0966               params.bcast_block_sizes, params.bcast_block_strides,
0967               params.bcast_input_strides, bcast_offset, offset, scratch,
0968               materialized_output, materialized_input, materialized_input_size);
0969         }
0970         if (last_multiple < bcast_dim_left_index + params.bcast_dim_size) {
0971           const Index tail_size =
0972               bcast_dim_left_index + params.bcast_dim_size - last_multiple;
0973           params.input_block_sizes[params.bcast_dim] = tail_size;
0974           params.bcast_block_sizes[copy_bcast_dim] = tail_size;
0975           params.bcast_input_strides[copy_bcast_dim] =
0976               params.input_block_strides[params.bcast_dim];
0977           params.bcast_block_strides[copy_bcast_dim] =
0978               params.output_strides[params.bcast_dim];
0979           params.bcast_block_sizes[broadcast_bcast_dim] = 1;
0980           params.bcast_input_strides[broadcast_bcast_dim] = 0;
0981           params.bcast_block_strides[broadcast_bcast_dim] =
0982               params.output_strides[params.bcast_dim] *
0983               params.input_dims[params.bcast_dim];
0984           const Index offset = (last_multiple - bcast_dim_left_index) *
0985                                m_outputStrides[params.bcast_dim];
0986 
0987           num_output_coeffs += BroadcastBlock(
0988               params.input_block_sizes, params.input_block_strides,
0989               params.bcast_block_sizes, params.bcast_block_strides,
0990               params.bcast_input_strides, bcast_offset, offset, scratch,
0991               materialized_output, materialized_input, materialized_input_size);
0992         }
0993       } else {
0994         // b and c do not exist.
0995         const int copy_bcast_dim =
0996             IsColMajor ? 2 * params.inner_dim_count
0997                        : 2 * NumDims - 2 * params.inner_dim_count - 1;
0998         params.input_block_sizes[params.bcast_dim] = params.bcast_dim_size;
0999         params.bcast_block_sizes[copy_bcast_dim] = params.bcast_dim_size;
1000         params.bcast_input_strides[copy_bcast_dim] =
1001             params.input_block_strides[params.bcast_dim];
1002         params.bcast_block_strides[copy_bcast_dim] =
1003             params.output_strides[params.bcast_dim];
1004 
1005         num_output_coeffs += BroadcastBlock(
1006             params.input_block_sizes, params.input_block_strides,
1007             params.bcast_block_sizes, params.bcast_block_strides,
1008             params.bcast_input_strides, bcast_offset, 0, scratch,
1009             materialized_output, materialized_input, materialized_input_size);
1010       }
1011 
1012       return num_output_coeffs;
1013     }
1014   }
1015 
1016   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlock(
1017       const Dimensions& input_block_sizes,
1018       const Dimensions& input_block_strides,
1019       const BroadcastDimensions& bcast_block_sizes,
1020       const BroadcastDimensions& bcast_block_strides,
1021       const BroadcastDimensions& bcast_input_strides, Index bcast_offset,
1022       Index offset, TensorBlockScratch& scratch,
1023       ScalarNoConst* materialized_output, ScalarNoConst** materialized_input,
1024       size_t* materialized_input_size) const {
1025     // ---------------------------------------------------------------------- //
1026     // Tensor block descriptor for reading block from the input.
1027     const Index input_offset = bcast_offset + offset;
1028     TensorBlockDesc input_desc(
1029         IsColMajor ? indexColMajor(input_offset) : indexRowMajor(input_offset),
1030         input_block_sizes);
1031 
1032     ArgTensorBlock input_block = m_impl.block(input_desc, scratch);
1033 
1034     // ---------------------------------------------------------------------- //
1035     // Materialize input block into a temporary memory buffer only if it's not
1036     // already available in the arg block.
1037     const ScalarNoConst* input_buffer = NULL;
1038 
1039     if (input_block.data() != NULL) {
1040       // Input block already has raw data, there is no need to materialize it.
1041       input_buffer = input_block.data();
1042 
1043     } else {
1044       // Otherwise we have to do block assignment into a temporary buffer.
1045 
1046       // Maybe reuse previously allocated buffer, or allocate a new one with a
1047       // scratch allocator.
1048       const size_t input_total_size = input_block_sizes.TotalSize();
1049       if (*materialized_input == NULL ||
1050           *materialized_input_size < input_total_size) {
1051         *materialized_input_size = input_total_size;
1052         void* mem = scratch.allocate(*materialized_input_size * sizeof(Scalar));
1053         *materialized_input = static_cast<ScalarNoConst*>(mem);
1054       }
1055 
1056       typedef internal::TensorBlockAssignment<
1057           ScalarNoConst, NumDims, typename ArgTensorBlock::XprType, Index>
1058           TensorBlockAssignment;
1059 
1060       TensorBlockAssignment::Run(
1061           TensorBlockAssignment::target(input_block_sizes, input_block_strides,
1062                                         *materialized_input),
1063           input_block.expr());
1064 
1065       input_buffer = *materialized_input;
1066     }
1067 
1068     // ---------------------------------------------------------------------- //
1069     // Copy data from materialized input block to the materialized output, using
1070     // given broadcast strides (strides with zeroes).
1071     typedef internal::TensorBlockIO<ScalarNoConst, Index, 2 * NumDims, Layout>
1072         TensorBlockIO;
1073 
1074     typename TensorBlockIO::Src src(bcast_input_strides, input_buffer);
1075     typename TensorBlockIO::Dst dst(bcast_block_sizes, bcast_block_strides,
1076                                       materialized_output + offset);
1077 
1078     return TensorBlockIO::Copy(dst, src);
1079   }
1080 
protected:
  // Device the evaluator runs on (not referenced in this part of the file).
  const Device EIGEN_DEVICE_REF m_device;
  // Per-dimension broadcast factors; blockBroadcastingParams() uses them as
  // the repeat counts of the inner dimensions' broadcast slots.
  const typename internal::remove_reference<Broadcast>::type m_broadcast;
  // Full dimensions of the broadcast expression; output block sizes are
  // compared against these to find the inner dimensions and `bcast_dim`.
  Dimensions m_dimensions;
  // Strides of the broadcast expression; BroadcastBlockAlongBcastDim divides
  // a linear offset by m_outputStrides[bcast_dim] to locate a slice, and
  // scales segment starts by it.
  array<Index, NumDims> m_outputStrides;
  // NOTE(review): not referenced in this part of the file; presumably the
  // strides of the argument expression — confirm.
  array<Index, NumDims> m_inputStrides;
  // Evaluator of the argument expression, queried for its dimensions and for
  // input blocks.
  TensorEvaluator<ArgType, Device> m_impl;
1088 };
1089 
1090 
1091 } // end namespace Eigen
1092 
1093 #endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H