Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:43:02

0001 //
0002 //  Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
0003 //
0004 //  Distributed under the Boost Software License, Version 1.0. (See
0005 //  accompanying file LICENSE_1_0.txt or copy at
0006 //  http://www.boost.org/LICENSE_1_0.txt)
0007 //
0008 //  The authors gratefully acknowledge the support of
0009 //  Fraunhofer IOSB, Ettlingen, Germany
0010 //
0011 
0012 #ifndef _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
0013 #define _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
0014 
0015 #include <type_traits>
0016 #include <stdexcept>
0017 
0018 
namespace boost::numeric::ublas {

// Forward declaration of the extents class template. Only its name is needed
// here (as a parameter type of all_extents_equal); the definition lives elsewhere.
template<typename size_t_type>
class basic_extents;

// Forward declaration of the tensor class template so that the helper overloads
// in this header can name tensor<...> without pulling in its full definition.
template<typename value_t, typename format_t, typename storage_t>
class tensor;

} // namespace boost::numeric::ublas
0028 
namespace boost::numeric::ublas::detail {

// Forward declarations of the expression-template building blocks; their
// definitions are provided by other headers of the tensor module.

// Expression node with a single operand E combined through OP.
template<typename tensor_t, typename operand_t, typename op_t>
struct unary_tensor_expression;

// Expression node with a left operand EL and a right operand ER combined through OP.
template<typename tensor_t, typename left_t, typename right_t, typename op_t>
struct binary_tensor_expression;

// Base of all tensor expressions; the second parameter is the derived expression type.
template<typename tensor_t, typename derived_t>
struct tensor_expression;

} // namespace boost::numeric::ublas::detail
0041 
0042 namespace boost::numeric::ublas::detail {
0043 
0044 template<class T, class E>
0045 struct has_tensor_types
0046 { static constexpr bool value = false; };
0047 
0048 template<class T>
0049 struct has_tensor_types<T,T>
0050 { static constexpr bool value = true; };
0051 
0052 template<class T, class D>
0053 struct has_tensor_types<T, tensor_expression<T,D>>
0054 { static constexpr bool value = std::is_same<T,D>::value || has_tensor_types<T,D>::value; };
0055 
0056 
0057 template<class T, class EL, class ER, class OP>
0058 struct has_tensor_types<T, binary_tensor_expression<T,EL,ER,OP>>
0059 { static constexpr bool value = std::is_same<T,EL>::value || std::is_same<T,ER>::value || has_tensor_types<T,EL>::value || has_tensor_types<T,ER>::value;  };
0060 
0061 template<class T, class E, class OP>
0062 struct has_tensor_types<T, unary_tensor_expression<T,E,OP>>
0063 { static constexpr bool value = std::is_same<T,E>::value || has_tensor_types<T,E>::value; };
0064 
0065 } // namespace boost::numeric::ublas::detail
0066 
0067 
0068 namespace boost::numeric::ublas::detail {
0069 
0070 
0071 
0072 
0073 
0074 /** @brief Retrieves extents of the tensor
0075  *
0076 */
0077 template<class T, class F, class A>
0078 auto retrieve_extents(tensor<T,F,A> const& t)
0079 {
0080     return t.extents();
0081 }
0082 
0083 /** @brief Retrieves extents of the tensor expression
0084  *
0085  * @note tensor expression must be a binary tree with at least one tensor type
0086  *
0087  * @returns extents of the child expression if it is a tensor or extents of one child of its child.
0088 */
0089 template<class T, class D>
0090 auto retrieve_extents(tensor_expression<T,D> const& expr)
0091 {
0092     static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
0093                   "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
0094 
0095     auto const& cast_expr = static_cast<D const&>(expr);
0096 
0097     if constexpr ( std::is_same<T,D>::value )
0098         return cast_expr.extents();
0099     else
0100     return retrieve_extents(cast_expr);
0101 }
0102 
0103 /** @brief Retrieves extents of the binary tensor expression
0104  *
0105  * @note tensor expression must be a binary tree with at least one tensor type
0106  *
0107  * @returns extents of the (left and if necessary then right) child expression if it is a tensor or extents of a child of its (left and if necessary then right) child.
0108 */
0109 template<class T, class EL, class ER, class OP>
0110 auto retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
0111 {
0112     static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
0113                   "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
0114 
0115     if constexpr ( std::is_same<T,EL>::value )
0116         return expr.el.extents();
0117 
0118     if constexpr ( std::is_same<T,ER>::value )
0119         return expr.er.extents();
0120 
0121     else if constexpr ( detail::has_tensor_types<T,EL>::value )
0122         return retrieve_extents(expr.el);
0123 
0124     else if constexpr ( detail::has_tensor_types<T,ER>::value  )
0125         return retrieve_extents(expr.er);
0126 }
0127 
0128 /** @brief Retrieves extents of the binary tensor expression
0129  *
0130  * @note tensor expression must be a binary tree with at least one tensor type
0131  *
0132  * @returns extents of the child expression if it is a tensor or extents of a child of its child.
0133 */
0134 template<class T, class E, class OP>
0135 auto retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
0136 {
0137 
0138     static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
0139                   "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
0140 
0141     if constexpr ( std::is_same<T,E>::value )
0142         return expr.e.extents();
0143 
0144     else if constexpr ( detail::has_tensor_types<T,E>::value  )
0145         return retrieve_extents(expr.e);
0146 }
0147 
0148 } // namespace boost::numeric::ublas::detail
0149 
0150 
0151 ///////////////
0152 
0153 namespace boost::numeric::ublas::detail {
0154 
0155 template<class T, class F, class A, class S>
0156 auto all_extents_equal(tensor<T,F,A> const& t, basic_extents<S> const& extents)
0157 {
0158     return extents == t.extents();
0159 }
0160 
0161 template<class T, class D, class S>
0162 auto all_extents_equal(tensor_expression<T,D> const& expr, basic_extents<S> const& extents)
0163 {
0164     static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
0165                   "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
0166     auto const& cast_expr = static_cast<D const&>(expr);
0167 
0168 
0169     if constexpr ( std::is_same<T,D>::value )
0170         if( extents != cast_expr.extents() )
0171         return false;
0172 
0173     if constexpr ( detail::has_tensor_types<T,D>::value )
0174         if ( !all_extents_equal(cast_expr, extents))
0175         return false;
0176 
0177     return true;
0178 
0179 }
0180 
0181 template<class T, class EL, class ER, class OP, class S>
0182 auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, basic_extents<S> const& extents)
0183 {
0184     static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
0185                   "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
0186 
0187     if constexpr ( std::is_same<T,EL>::value )
0188         if(extents !=  expr.el.extents())
0189         return false;
0190 
0191     if constexpr ( std::is_same<T,ER>::value )
0192         if(extents != expr.er.extents())
0193         return false;
0194 
0195     if constexpr ( detail::has_tensor_types<T,EL>::value )
0196         if(!all_extents_equal(expr.el, extents))
0197         return false;
0198 
0199     if constexpr ( detail::has_tensor_types<T,ER>::value )
0200         if(!all_extents_equal(expr.er, extents))
0201         return false;
0202 
0203     return true;
0204 }
0205 
0206 
0207 template<class T, class E, class OP, class S>
0208 auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, basic_extents<S> const& extents)
0209 {
0210 
0211     static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
0212                   "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
0213 
0214     if constexpr ( std::is_same<T,E>::value )
0215         if(extents != expr.e.extents())
0216         return false;
0217 
0218     if constexpr ( detail::has_tensor_types<T,E>::value )
0219         if(!all_extents_equal(expr.e, extents))
0220         return false;
0221 
0222     return true;
0223 }
0224 
0225 } // namespace boost::numeric::ublas::detail
0226 
0227 
0228 namespace boost::numeric::ublas::detail {
0229 
0230 
0231 /** @brief Evaluates expression for a tensor
0232  *
0233  * Assigns the results of the expression to the tensor.
0234  *
0235  * \note Checks if shape of the tensor matches those of all tensors within the expression.
0236 */
0237 template<class tensor_type, class derived_type>
0238 void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr)
0239 {
0240     if constexpr (detail::has_tensor_types<tensor_type, tensor_expression<tensor_type,derived_type> >::value )
0241         if(!detail::all_extents_equal(expr, lhs.extents() ))
0242         throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
0243 
0244 #pragma omp parallel for
0245     for(auto i = 0u; i < lhs.size(); ++i)
0246         lhs(i) = expr()(i);
0247 }
0248 
0249 /** @brief Evaluates expression for a tensor
0250  *
0251  * Applies a unary function to the results of the expressions before the assignment.
0252  * Usually applied needed for unary operators such as A += C;
0253  *
0254  * \note Checks if shape of the tensor matches those of all tensors within the expression.
0255 */
0256 template<class tensor_type, class derived_type, class unary_fn>
0257 void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr, unary_fn const fn)
0258 {
0259 
0260     if constexpr (detail::has_tensor_types< tensor_type, tensor_expression<tensor_type,derived_type> >::value )
0261         if(!detail::all_extents_equal( expr, lhs.extents() ))
0262         throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
0263 
0264 #pragma omp parallel for
0265     for(auto i = 0u; i < lhs.size(); ++i)
0266         fn(lhs(i), expr()(i));
0267 }
0268 
0269 
0270 
0271 /** @brief Evaluates expression for a tensor
0272  *
0273  * Applies a unary function to the results of the expressions before the assignment.
0274  * Usually applied needed for unary operators such as A += C;
0275  *
0276  * \note Checks if shape of the tensor matches those of all tensors within the expression.
0277 */
0278 template<class tensor_type, class unary_fn>
0279 void eval(tensor_type& lhs, unary_fn const fn)
0280 {
0281 #pragma omp parallel for
0282     for(auto i = 0u; i < lhs.size(); ++i)
0283         fn(lhs(i));
0284 }
0285 
0286 
0287 }
0288 #endif