Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:43:02

0001 //
0002 //  Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
0003 //
0004 //  Distributed under the Boost Software License, Version 1.0. (See
0005 //  accompanying file LICENSE_1_0.txt or copy at
0006 //  http://www.boost.org/LICENSE_1_0.txt)
0007 //
0008 //  The authors gratefully acknowledge the support of
0009 //  Fraunhofer IOSB, Ettlingen, Germany
0010 //
0011 
0012 
0013 #ifndef BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
0014 #define BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
0015 
#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

#include <cassert>
0024 
0025 namespace boost {
0026 namespace numeric {
0027 namespace ublas {
0028 
0029 
0030 /** @brief Template class for storing tensor extents with runtime variable size.
0031  *
0032  * Proxy template class of std::vector<int_type>.
0033  *
0034  */
0035 template<class int_type>
0036 class basic_extents
0037 {
0038     static_assert( std::numeric_limits<typename std::vector<int_type>::value_type>::is_integer, "Static error in basic_layout: type must be of type integer.");
0039     static_assert(!std::numeric_limits<typename std::vector<int_type>::value_type>::is_signed,  "Static error in basic_layout: type must be of type unsigned integer.");
0040 
0041 public:
0042     using base_type = std::vector<int_type>;
0043     using value_type = typename base_type::value_type;
0044     using const_reference = typename base_type::const_reference;
0045     using reference = typename base_type::reference;
0046     using size_type = typename base_type::size_type;
0047     using const_pointer = typename base_type::const_pointer;
0048     using const_iterator = typename base_type::const_iterator;
0049 
0050 
0051     /** @brief Default constructs basic_extents
0052      *
0053      * @code auto ex = basic_extents<unsigned>{};
0054      */
0055     constexpr explicit basic_extents()
0056       : _base{}
0057     {
0058     }
0059 
0060     /** @brief Copy constructs basic_extents from a one-dimensional container
0061      *
0062      * @code auto ex = basic_extents<unsigned>(  std::vector<unsigned>(3u,3u) );
0063      *
0064      * @note checks if size > 1 and all elements > 0
0065      *
0066      * @param b one-dimensional std::vector<int_type> container
0067      */
0068     explicit basic_extents(base_type const& b)
0069       : _base(b)
0070     {
0071         if (!this->valid()){
0072             throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
0073         }
0074     }
0075 
0076     /** @brief Move constructs basic_extents from a one-dimensional container
0077      *
0078      * @code auto ex = basic_extents<unsigned>(  std::vector<unsigned>(3u,3u) );
0079      *
0080      * @note checks if size > 1 and all elements > 0
0081      *
0082      * @param b one-dimensional container of type std::vector<int_type>
0083      */
0084     explicit basic_extents(base_type && b)
0085       : _base(std::move(b))
0086     {
0087         if (!this->valid()){
0088             throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
0089         }
0090     }
0091 
0092     /** @brief Constructs basic_extents from an initializer list
0093      *
0094      * @code auto ex = basic_extents<unsigned>{3,2,4};
0095      *
0096      * @note checks if size > 1 and all elements > 0
0097      *
0098      * @param l one-dimensional list of type std::initializer<int_type>
0099      */
0100     basic_extents(std::initializer_list<value_type> l)
0101       : basic_extents( base_type(std::move(l)) )
0102     {
0103     }
0104 
0105     /** @brief Constructs basic_extents from a range specified by two iterators
0106      *
0107      * @code auto ex = basic_extents<unsigned>(a.begin(), a.end());
0108      *
0109      * @note checks if size > 1 and all elements > 0
0110      *
0111      * @param first iterator pointing to the first element
0112      * @param last iterator pointing to the next position after the last element
0113      */
0114     basic_extents(const_iterator first, const_iterator last)
0115       : basic_extents ( base_type( first,last ) )
0116     {
0117     }
0118 
0119     /** @brief Copy constructs basic_extents */
0120     basic_extents(basic_extents const& l )
0121       : _base(l._base)
0122     {
0123     }
0124 
0125     /** @brief Move constructs basic_extents */
0126     basic_extents(basic_extents && l ) noexcept
0127       : _base(std::move(l._base))
0128     {
0129     }
0130 
0131     ~basic_extents() = default;
0132 
0133     basic_extents& operator=(basic_extents other) noexcept
0134     {
0135         swap (*this, other);
0136         return *this;
0137     }
0138 
0139     friend void swap(basic_extents& lhs, basic_extents& rhs) {
0140         std::swap(lhs._base   , rhs._base   );
0141     }
0142 
0143 
0144 
0145     /** @brief Returns true if this has a scalar shape
0146      *
0147      * @returns true if (1,1,[1,...,1])
0148     */
0149     bool is_scalar() const
0150     {
0151         return
0152             _base.size() != 0 &&
0153             std::all_of(_base.begin(), _base.end(),
0154                         [](const_reference a){ return a == 1;});
0155     }
0156 
0157     /** @brief Returns true if this has a vector shape
0158      *
0159      * @returns true if (1,n,[1,...,1]) or (n,1,[1,...,1]) with n > 1
0160     */
0161     bool is_vector() const
0162     {
0163         if(_base.size() == 0){
0164             return false;
0165         }
0166 
0167         if(_base.size() == 1){
0168             return _base.at(0) > 1;
0169         }
0170 
0171         auto greater_one = [](const_reference a){ return a >  1;};
0172         auto equal_one   = [](const_reference a){ return a == 1;};
0173 
0174         return
0175             std::any_of(_base.begin(),   _base.begin()+2, greater_one) &&
0176             std::any_of(_base.begin(),   _base.begin()+2, equal_one  ) &&
0177             std::all_of(_base.begin()+2, _base.end(),     equal_one);
0178     }
0179 
0180     /** @brief Returns true if this has a matrix shape
0181      *
0182      * @returns true if (m,n,[1,...,1]) with m > 1 and n > 1
0183     */
0184     bool is_matrix() const
0185     {
0186         if(_base.size() < 2){
0187             return false;
0188         }
0189 
0190         auto greater_one = [](const_reference a){ return a >  1;};
0191         auto equal_one   = [](const_reference a){ return a == 1;};
0192 
0193         return
0194             std::all_of(_base.begin(),   _base.begin()+2, greater_one) &&
0195             std::all_of(_base.begin()+2, _base.end(),     equal_one  );
0196     }
0197 
0198     /** @brief Returns true if this is has a tensor shape
0199      *
0200      * @returns true if !empty() && !is_scalar() && !is_vector() && !is_matrix()
0201     */
0202     bool is_tensor() const
0203     {
0204         if(_base.size() < 3){
0205             return false;
0206         }
0207 
0208         auto greater_one = [](const_reference a){ return a > 1;};
0209 
0210         return std::any_of(_base.begin()+2, _base.end(), greater_one);
0211     }
0212 
0213     const_pointer data() const
0214     {
0215         return this->_base.data();
0216     }
0217 
0218     const_reference operator[] (size_type p) const
0219     {
0220         return this->_base[p];
0221     }
0222 
0223     const_reference at (size_type p) const
0224     {
0225         return this->_base.at(p);
0226     }
0227 
0228     reference operator[] (size_type p)
0229     {
0230         return this->_base[p];
0231     }
0232 
0233     reference at (size_type p)
0234     {
0235         return this->_base.at(p);
0236     }
0237 
0238 
0239     bool empty() const
0240     {
0241         return this->_base.empty();
0242     }
0243 
0244     size_type size() const
0245     {
0246         return this->_base.size();
0247     }
0248 
0249     /** @brief Returns true if size > 1 and all elements > 0 */
0250     bool valid() const
0251     {
0252         return
0253             this->size() > 1 &&
0254             std::none_of(_base.begin(), _base.end(),
0255                          [](const_reference a){ return a == value_type(0); });
0256     }
0257 
0258     /** @brief Returns the number of elements a tensor holds with this */
0259     size_type product() const
0260     {
0261         if(_base.empty()){
0262             return 0;
0263         }
0264 
0265         return std::accumulate(_base.begin(), _base.end(), 1ul, std::multiplies<>());
0266 
0267     }
0268 
0269 
0270     /** @brief Eliminates singleton dimensions when size > 2
0271      *
0272      * squeeze {  1,1} -> {  1,1}
0273      * squeeze {  2,1} -> {  2,1}
0274      * squeeze {  1,2} -> {  1,2}
0275      *
0276      * squeeze {1,2,3} -> {  2,3}
0277      * squeeze {2,1,3} -> {  2,3}
0278      * squeeze {1,3,1} -> {  3,1}
0279      *
0280     */
0281     basic_extents squeeze() const
0282     {
0283         if(this->size() <= 2){
0284             return *this;
0285         }
0286 
0287         auto new_extent = basic_extents{};
0288         auto insert_iter = std::back_insert_iterator<typename basic_extents::base_type>(new_extent._base);
0289         std::remove_copy(this->_base.begin(), this->_base.end(), insert_iter ,value_type{1});
0290         return new_extent;
0291 
0292     }
0293 
0294     void clear()
0295     {
0296         this->_base.clear();
0297     }
0298 
0299     bool operator == (basic_extents const& b) const
0300     {
0301         return _base == b._base;
0302     }
0303 
0304     bool operator != (basic_extents const& b) const
0305     {
0306         return !( _base == b._base );
0307     }
0308 
0309     const_iterator
0310     begin() const
0311     {
0312         return _base.begin();
0313     }
0314 
0315     const_iterator
0316     end() const
0317     {
0318         return _base.end();
0319     }
0320 
0321     base_type const& base() const { return _base; }
0322 
0323 private:
0324 
0325     base_type _base;
0326 
0327 };
0328 
0329 using shape = basic_extents<std::size_t>;
0330 
0331 } // namespace ublas
0332 } // namespace numeric
0333 } // namespace boost
0334 
0335 #endif