File indexing completed on 2025-01-18 09:43:02
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012 #ifndef _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
0013 #define _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
0014
0015 #include <type_traits>
0016 #include <stdexcept>
0017
0018
namespace boost::numeric::ublas {

// Forward declarations only: the functions in this header name these class
// templates in their signatures, but the definitions are provided elsewhere.

template<class SizeT>
class basic_extents;

template<class ValueT, class FormatT, class StorageT>
class tensor;

} // namespace boost::numeric::ublas
0028
namespace boost::numeric::ublas::detail {

// Forward declarations of the expression-template node types that the
// has_tensor_types trait below pattern-matches on.

template<class Tensor, class Derived>
struct tensor_expression;

template<class Tensor, class Operand, class Op>
struct unary_tensor_expression;

template<class Tensor, class Left, class Right, class Op>
struct binary_tensor_expression;

/// @brief Compile-time predicate: does the expression type @p Expr reference
///        the tensor type @p Tensor anywhere in its expression tree?
///
/// Primary template: any type not matched by a specialization below
/// (for example a scalar operand) contains no tensor.
template<class Tensor, class Expr>
struct has_tensor_types
{
    static constexpr bool value = false;
};

/// The tensor type itself trivially contains a tensor.
template<class Tensor>
struct has_tensor_types<Tensor, Tensor>
{
    static constexpr bool value = true;
};

/// Generic expression node: look through to the derived expression type.
template<class Tensor, class Derived>
struct has_tensor_types<Tensor, tensor_expression<Tensor, Derived>>
{
    static constexpr bool value =
        has_tensor_types<Tensor, Derived>::value
        || std::is_same_v<Tensor, Derived>;
};

/// Binary node: true if either operand is, or contains, the tensor type.
template<class Tensor, class Left, class Right, class Op>
struct has_tensor_types<Tensor, binary_tensor_expression<Tensor, Left, Right, Op>>
{
    static constexpr bool value =
        std::is_same_v<Tensor, Left>  || has_tensor_types<Tensor, Left>::value
        || std::is_same_v<Tensor, Right> || has_tensor_types<Tensor, Right>::value;
};

/// Unary node: true if the single operand is, or contains, the tensor type.
template<class Tensor, class Operand, class Op>
struct has_tensor_types<Tensor, unary_tensor_expression<Tensor, Operand, Op>>
{
    static constexpr bool value =
        has_tensor_types<Tensor, Operand>::value
        || std::is_same_v<Tensor, Operand>;
};

} // namespace boost::numeric::ublas::detail
0066
0067
0068 namespace boost::numeric::ublas::detail {
0069
0070
0071
0072
0073
0074
0075
0076
0077 template<class T, class F, class A>
0078 auto retrieve_extents(tensor<T,F,A> const& t)
0079 {
0080 return t.extents();
0081 }
0082
0083
0084
0085
0086
0087
0088
0089 template<class T, class D>
0090 auto retrieve_extents(tensor_expression<T,D> const& expr)
0091 {
0092 static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
0093 "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
0094
0095 auto const& cast_expr = static_cast<D const&>(expr);
0096
0097 if constexpr ( std::is_same<T,D>::value )
0098 return cast_expr.extents();
0099 else
0100 return retrieve_extents(cast_expr);
0101 }
0102
0103
0104
0105
0106
0107
0108
0109 template<class T, class EL, class ER, class OP>
0110 auto retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
0111 {
0112 static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
0113 "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
0114
0115 if constexpr ( std::is_same<T,EL>::value )
0116 return expr.el.extents();
0117
0118 if constexpr ( std::is_same<T,ER>::value )
0119 return expr.er.extents();
0120
0121 else if constexpr ( detail::has_tensor_types<T,EL>::value )
0122 return retrieve_extents(expr.el);
0123
0124 else if constexpr ( detail::has_tensor_types<T,ER>::value )
0125 return retrieve_extents(expr.er);
0126 }
0127
0128
0129
0130
0131
0132
0133
0134 template<class T, class E, class OP>
0135 auto retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
0136 {
0137
0138 static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
0139 "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
0140
0141 if constexpr ( std::is_same<T,E>::value )
0142 return expr.e.extents();
0143
0144 else if constexpr ( detail::has_tensor_types<T,E>::value )
0145 return retrieve_extents(expr.e);
0146 }
0147
0148 }
0149
0150
0151
0152
0153 namespace boost::numeric::ublas::detail {
0154
0155 template<class T, class F, class A, class S>
0156 auto all_extents_equal(tensor<T,F,A> const& t, basic_extents<S> const& extents)
0157 {
0158 return extents == t.extents();
0159 }
0160
0161 template<class T, class D, class S>
0162 auto all_extents_equal(tensor_expression<T,D> const& expr, basic_extents<S> const& extents)
0163 {
0164 static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
0165 "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
0166 auto const& cast_expr = static_cast<D const&>(expr);
0167
0168
0169 if constexpr ( std::is_same<T,D>::value )
0170 if( extents != cast_expr.extents() )
0171 return false;
0172
0173 if constexpr ( detail::has_tensor_types<T,D>::value )
0174 if ( !all_extents_equal(cast_expr, extents))
0175 return false;
0176
0177 return true;
0178
0179 }
0180
0181 template<class T, class EL, class ER, class OP, class S>
0182 auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, basic_extents<S> const& extents)
0183 {
0184 static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
0185 "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
0186
0187 if constexpr ( std::is_same<T,EL>::value )
0188 if(extents != expr.el.extents())
0189 return false;
0190
0191 if constexpr ( std::is_same<T,ER>::value )
0192 if(extents != expr.er.extents())
0193 return false;
0194
0195 if constexpr ( detail::has_tensor_types<T,EL>::value )
0196 if(!all_extents_equal(expr.el, extents))
0197 return false;
0198
0199 if constexpr ( detail::has_tensor_types<T,ER>::value )
0200 if(!all_extents_equal(expr.er, extents))
0201 return false;
0202
0203 return true;
0204 }
0205
0206
0207 template<class T, class E, class OP, class S>
0208 auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, basic_extents<S> const& extents)
0209 {
0210
0211 static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
0212 "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
0213
0214 if constexpr ( std::is_same<T,E>::value )
0215 if(extents != expr.e.extents())
0216 return false;
0217
0218 if constexpr ( detail::has_tensor_types<T,E>::value )
0219 if(!all_extents_equal(expr.e, extents))
0220 return false;
0221
0222 return true;
0223 }
0224
0225 }
0226
0227
0228 namespace boost::numeric::ublas::detail {
0229
0230
0231
0232
0233
0234
0235
0236
0237 template<class tensor_type, class derived_type>
0238 void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr)
0239 {
0240 if constexpr (detail::has_tensor_types<tensor_type, tensor_expression<tensor_type,derived_type> >::value )
0241 if(!detail::all_extents_equal(expr, lhs.extents() ))
0242 throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
0243
0244 #pragma omp parallel for
0245 for(auto i = 0u; i < lhs.size(); ++i)
0246 lhs(i) = expr()(i);
0247 }
0248
0249
0250
0251
0252
0253
0254
0255
0256 template<class tensor_type, class derived_type, class unary_fn>
0257 void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr, unary_fn const fn)
0258 {
0259
0260 if constexpr (detail::has_tensor_types< tensor_type, tensor_expression<tensor_type,derived_type> >::value )
0261 if(!detail::all_extents_equal( expr, lhs.extents() ))
0262 throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
0263
0264 #pragma omp parallel for
0265 for(auto i = 0u; i < lhs.size(); ++i)
0266 fn(lhs(i), expr()(i));
0267 }
0268
0269
0270
0271
0272
0273
0274
0275
0276
0277
0278 template<class tensor_type, class unary_fn>
0279 void eval(tensor_type& lhs, unary_fn const fn)
0280 {
0281 #pragma omp parallel for
0282 for(auto i = 0u; i < lhs.size(); ++i)
0283 fn(lhs(i));
0284 }
0285
0286
0287 }
0288 #endif