File indexing completed on 2025-12-16 10:14:17
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #ifndef EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
0011 #define EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
0012
0013 namespace Eigen {
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027 template <typename ExpressionType, typename DeviceType> class TensorDevice {
0028 public:
0029 TensorDevice(const DeviceType& device, ExpressionType& expression) : m_device(device), m_expression(expression) {}
0030
0031 EIGEN_DEFAULT_COPY_CONSTRUCTOR(TensorDevice)
0032
0033 template<typename OtherDerived>
0034 EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) {
0035 typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
0036 Assign assign(m_expression, other);
0037 internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
0038 return *this;
0039 }
0040
0041 template<typename OtherDerived>
0042 EIGEN_STRONG_INLINE TensorDevice& operator+=(const OtherDerived& other) {
0043 typedef typename OtherDerived::Scalar Scalar;
0044 typedef TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const ExpressionType, const OtherDerived> Sum;
0045 Sum sum(m_expression, other);
0046 typedef TensorAssignOp<ExpressionType, const Sum> Assign;
0047 Assign assign(m_expression, sum);
0048 internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
0049 return *this;
0050 }
0051
0052 template<typename OtherDerived>
0053 EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
0054 typedef typename OtherDerived::Scalar Scalar;
0055 typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
0056 Difference difference(m_expression, other);
0057 typedef TensorAssignOp<ExpressionType, const Difference> Assign;
0058 Assign assign(m_expression, difference);
0059 internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
0060 return *this;
0061 }
0062
0063 protected:
0064 const DeviceType& m_device;
0065 ExpressionType& m_expression;
0066 };
0067
0068
0069
0070
0071
0072
0073
0074
0075
0076
0077
0078
0079
0080
0081
0082 template <typename ExpressionType, typename DeviceType, typename DoneCallback>
0083 class TensorAsyncDevice {
0084 public:
0085 TensorAsyncDevice(const DeviceType& device, ExpressionType& expression,
0086 DoneCallback done)
0087 : m_device(device), m_expression(expression), m_done(std::move(done)) {}
0088
0089 template <typename OtherDerived>
0090 EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) {
0091 typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
0092 typedef internal::TensorExecutor<const Assign, DeviceType> Executor;
0093
0094 Assign assign(m_expression, other);
0095 Executor::run(assign, m_device);
0096 m_done();
0097
0098 return *this;
0099 }
0100
0101 protected:
0102 const DeviceType& m_device;
0103 ExpressionType& m_expression;
0104 DoneCallback m_done;
0105 };
0106
0107
0108 #ifdef EIGEN_USE_THREADS
0109 template <typename ExpressionType, typename DoneCallback>
0110 class TensorAsyncDevice<ExpressionType, ThreadPoolDevice, DoneCallback> {
0111 public:
0112 TensorAsyncDevice(const ThreadPoolDevice& device, ExpressionType& expression,
0113 DoneCallback done)
0114 : m_device(device), m_expression(expression), m_done(std::move(done)) {}
0115
0116 template <typename OtherDerived>
0117 EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) {
0118 typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
0119 typedef internal::TensorAsyncExecutor<const Assign, ThreadPoolDevice, DoneCallback> Executor;
0120
0121
0122 Assign assign(m_expression, other);
0123 Executor::runAsync(assign, m_device, std::move(m_done));
0124
0125 return *this;
0126 }
0127
0128 protected:
0129 const ThreadPoolDevice& m_device;
0130 ExpressionType& m_expression;
0131 DoneCallback m_done;
0132 };
0133 #endif
0134
0135 }
0136
0137 #endif