File indexing completed on 2025-01-18 09:43:02
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013 #ifndef BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
0014 #define BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
0015
#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

#include <cassert>
0024
0025 namespace boost {
0026 namespace numeric {
0027 namespace ublas {
0028
0029
0030
0031
0032
0033
0034
/** @brief Stores the extents (lengths of every mode) of a tensor.
 *
 * A valid extents object has at least two modes and no zero extent (see valid()).
 * The constructors that take a shape tuple enforce this and throw
 * std::length_error otherwise; the default constructor produces empty extents.
 *
 * @tparam int_type unsigned integer type used to store each extent
 */
template<class int_type>
class basic_extents
{
	static_assert( std::numeric_limits<typename std::vector<int_type>::value_type>::is_integer, "Static error in basic_extents: type must be of type integer.");
	static_assert(!std::numeric_limits<typename std::vector<int_type>::value_type>::is_signed, "Static error in basic_extents: type must be of type unsigned integer.");

public:
	using base_type       = std::vector<int_type>;
	using value_type      = typename base_type::value_type;
	using const_reference = typename base_type::const_reference;
	using reference       = typename base_type::reference;
	using size_type       = typename base_type::size_type;
	using const_pointer   = typename base_type::const_pointer;
	using const_iterator  = typename base_type::const_iterator;

	/** @brief Constructs empty extents (size() == 0). Empty extents are not valid(). */
	constexpr explicit basic_extents()
		: _base{}
	{
	}

	/** @brief Constructs extents from a copy of a shape tuple.
	 *
	 * @param b shape tuple, one extent per mode
	 * @throws std::length_error if b has fewer than two modes or contains a zero extent
	 */
	explicit basic_extents(base_type const& b)
		: _base(b)
	{
		if (!this->valid()){
			throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
		}
	}

	/** @brief Constructs extents by taking ownership of a shape tuple.
	 *
	 * @param b shape tuple, one extent per mode (moved from)
	 * @throws std::length_error if b has fewer than two modes or contains a zero extent
	 */
	explicit basic_extents(base_type && b)
		: _base(std::move(b))
	{
		if (!this->valid()){
			throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
		}
	}

	/** @brief Constructs extents from a braced list, e.g. basic_extents{2,3,4}.
	 *
	 * @throws std::length_error if the list has fewer than two modes or contains a zero
	 */
	basic_extents(std::initializer_list<value_type> l)
		: basic_extents( base_type(l) )
	{
	}

	/** @brief Constructs extents from the range [first, last) of another shape tuple.
	 *
	 * @throws std::length_error if the range has fewer than two modes or contains a zero
	 */
	basic_extents(const_iterator first, const_iterator last)
		: basic_extents ( base_type( first,last ) )
	{
	}

	/** @brief Copy constructor (no re-validation; the source is copied verbatim). */
	basic_extents(basic_extents const& l )
		: _base(l._base)
	{
	}

	/** @brief Move constructor. */
	basic_extents(basic_extents && l ) noexcept
		: _base(std::move(l._base))
	{
	}

	~basic_extents() = default;

	/** @brief Copy/move assignment via copy-and-swap; the by-value parameter makes the copy. */
	basic_extents& operator=(basic_extents other) noexcept
	{
		swap (*this, other);
		return *this;
	}

	/** @brief Exchanges the contents of two extents objects; never throws. */
	friend void swap(basic_extents& lhs, basic_extents& rhs) noexcept
	{
		lhs._base.swap(rhs._base);
	}

	/** @brief Returns true if there is at least one mode and every extent equals 1. */
	bool is_scalar() const
	{
		return
			_base.size() != 0 &&
			std::all_of(_base.begin(), _base.end(),
			            [](const_reference a){ return a == 1;});
	}

	/** @brief Returns true if the extents describe a vector.
	 *
	 * Rank 1: the single extent is greater than 1. Rank >= 2: among the first two
	 * extents at least one is greater than 1 and at least one equals 1, and every
	 * remaining extent equals 1.
	 */
	bool is_vector() const
	{
		if(_base.size() == 0){
			return false;
		}

		if(_base.size() == 1){
			return _base.at(0) > 1;
		}

		auto greater_one = [](const_reference a){ return a >  1;};
		auto equal_one   = [](const_reference a){ return a == 1;};

		return
			std::any_of(_base.begin(),   _base.begin()+2, greater_one) &&
			std::any_of(_base.begin(),   _base.begin()+2, equal_one  ) &&
			std::all_of(_base.begin()+2, _base.end(),     equal_one  );
	}

	/** @brief Returns true if there are at least two modes, the first two extents are
	 *  greater than 1 and every remaining extent equals 1.
	 */
	bool is_matrix() const
	{
		if(_base.size() < 2){
			return false;
		}

		auto greater_one = [](const_reference a){ return a >  1;};
		auto equal_one   = [](const_reference a){ return a == 1;};

		return
			std::all_of(_base.begin(),   _base.begin()+2, greater_one) &&
			std::all_of(_base.begin()+2, _base.end(),     equal_one  );
	}

	/** @brief Returns true if there are at least three modes and some extent beyond
	 *  the second is greater than 1.
	 */
	bool is_tensor() const
	{
		if(_base.size() < 3){
			return false;
		}

		auto greater_one = [](const_reference a){ return a > 1;};

		return std::any_of(_base.begin()+2, _base.end(), greater_one);
	}

	/** @brief Returns a pointer to the contiguous extent storage. */
	const_pointer data() const
	{
		return this->_base.data();
	}

	/** @brief Unchecked read access to the extent of mode p. */
	const_reference operator[] (size_type p) const
	{
		return this->_base[p];
	}

	/** @brief Bounds-checked read access; throws std::out_of_range if p >= size(). */
	const_reference at (size_type p) const
	{
		return this->_base.at(p);
	}

	/** @brief Unchecked write access to the extent of mode p (does not re-validate). */
	reference operator[] (size_type p)
	{
		return this->_base[p];
	}

	/** @brief Bounds-checked write access; throws std::out_of_range if p >= size(). */
	reference at (size_type p)
	{
		return this->_base.at(p);
	}

	/** @brief Returns true if there are no modes. */
	bool empty() const
	{
		return this->_base.empty();
	}

	/** @brief Returns the number of modes (the rank). */
	size_type size() const
	{
		return this->_base.size();
	}

	/** @brief Returns true if there are at least two modes and no extent is zero. */
	bool valid() const
	{
		return
			this->size() > 1 &&
			std::none_of(_base.begin(), _base.end(),
			             [](const_reference a){ return a == value_type(0); });
	}

	/** @brief Returns the product of all extents (the element count), or 0 if empty.
	 *
	 * Overflow is not detected.
	 */
	size_type product() const
	{
		if(_base.empty()){
			return 0;
		}

		// Seed with size_type so the running product is kept in size_type; an
		// `unsigned long` seed is only 32 bits on LLP64 platforms (64-bit Windows)
		// and would truncate large products.
		return std::accumulate(_base.begin(), _base.end(), size_type{1}, std::multiplies<>());
	}

	/** @brief Removes singleton (extent 1) modes when there are more than two modes.
	 *
	 * Extents with two or fewer modes are returned unchanged. The result never has
	 * fewer than two modes: if removing singletons would leave fewer, leading 1s
	 * are re-inserted, so the result keeps the same product() and stays valid()
	 * whenever the input was.
	 *
	 * squeeze {1,2,3} -> {2,3}
	 * squeeze {2,1,3} -> {2,3}
	 * squeeze {1,3,1} -> {1,3}
	 * squeeze {1,1,1} -> {1,1}
	 */
	basic_extents squeeze() const
	{
		if(this->size() <= 2){
			return *this;
		}

		auto new_extent = basic_extents{};
		new_extent._base.reserve(this->size());
		std::remove_copy(_base.begin(), _base.end(), std::back_inserter(new_extent._base), value_type{1});

		// Without padding, {1,1,1} would squeeze to {} (product 0 instead of 1) and
		// {1,3,1} to the rank-1 {3}; neither would satisfy valid().
		if(new_extent._base.size() < 2){
			new_extent._base.insert(new_extent._base.begin(), 2 - new_extent._base.size(), value_type{1});
		}
		return new_extent;
	}

	/** @brief Removes all modes, leaving empty extents. */
	void clear()
	{
		this->_base.clear();
	}

	/** @brief Returns true if both objects hold the same extents in the same order. */
	bool operator == (basic_extents const& b) const
	{
		return _base == b._base;
	}

	/** @brief Negation of operator==. */
	bool operator != (basic_extents const& b) const
	{
		return !( *this == b );
	}

	/** @brief Iterator to the first extent. */
	const_iterator
	begin() const
	{
		return _base.begin();
	}

	/** @brief Iterator one past the last extent. */
	const_iterator
	end() const
	{
		return _base.end();
	}

	/** @brief Read-only access to the underlying shape tuple. */
	base_type const& base() const { return _base; }

private:

	base_type _base;

};

/** @brief Extents type using std::size_t for every mode. */
using shape = basic_extents<std::size_t>;
0330
0331 }
0332 }
0333 }
0334
0335 #endif