File indexing completed on 2025-01-18 09:51:09
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013 #ifndef BOOST_RANDOM_DETAIL_POLYNOMIAL_HPP
0014 #define BOOST_RANDOM_DETAIL_POLYNOMIAL_HPP
0015
0016 #include <cstddef>
0017 #include <limits>
0018 #include <vector>
0019 #include <algorithm>
0020 #include <boost/assert.hpp>
0021 #include <boost/cstdint.hpp>
0022
0023 namespace boost {
0024 namespace random {
0025 namespace detail {
0026
// Low-level arithmetic on polynomials over GF(2).
//
// A polynomial is stored as an array of digit_t words in little-endian
// order: bit k of word w holds the coefficient of x^(w*bits + k), where
// bits == std::numeric_limits<digit_t>::digits.  Addition is XOR, so
// "add" and "subtract" are the same operation throughout.
class polynomial_ops {
public:
    typedef unsigned long digit_t;

    // output[i] = lhs[i] + rhs[i] for i in [0, size).
    // Works element by element, so output may be the same array as an input.
    static void add(std::size_t size, const digit_t * lhs,
                    const digit_t * rhs, digit_t * output)
    {
        for(std::size_t i = 0; i < size; ++i) {
            output[i] = lhs[i] ^ rhs[i];
        }
    }

    // output += lhs * x^shift, where lhs has size digits and shift is
    // smaller than the number of bits in a digit.
    //
    // Bits pushed out of the top digit of lhs are carried into
    // output[size].  That extra digit is only accessed when the carry is
    // non-zero, so a caller whose shifted value fits in size digits does
    // not need to provide storage for it.
    static void add_shifted_inplace(std::size_t size, const digit_t * lhs,
                                    digit_t * output, std::size_t shift)
    {
        if(shift == 0) {
            // Shifting by the full digit width below would be undefined.
            add(size, lhs, output, output);
            return;
        }
        const std::size_t bits = std::numeric_limits<digit_t>::digits;
        digit_t prev = 0;
        for(std::size_t i = 0; i < size; ++i) {
            digit_t tmp = lhs[i];
            output[i] ^= (tmp << shift) | (prev >> (bits-shift));
            prev = tmp;
        }
        const digit_t carry = prev >> (bits-shift);
        if(carry != 0) {
            output[size] ^= carry;
        }
    }

    // Schoolbook multiplication: output = lhs * rhs.
    // Both operands have size digits; output receives 2*size digits and
    // is cleared first, so it must not overlap the operands.
    static void multiply_simple(std::size_t size, const digit_t * lhs,
                                const digit_t * rhs, digit_t * output)
    {
        std::size_t bits = std::numeric_limits<digit_t>::digits;
        for(std::size_t i = 0; i < 2*size; ++i) {
            output[i] = 0;
        }
        for(std::size_t i = 0; i < size; ++i) {
            for(std::size_t j = 0; j < bits; ++j) {
                // For every set bit of lhs add the matching shift of rhs.
                if((lhs[i] & (digit_t(1) << j)) != 0) {
                    add_shifted_inplace(size, rhs, output + i, j);
                }
            }
        }
    }

    // Karatsuba multiplication: output = lhs * rhs.
    // Both operands have size digits; output receives 2*size digits.
    // Operands shorter than 64 digits are handed to multiply_simple.
    static void multiply_karatsuba(std::size_t size,
                                   const digit_t * lhs, const digit_t * rhs,
                                   digit_t * output)
    {
        if(size < 64) {
            multiply_simple(size, lhs, rhs, output);
            return;
        }
        // Split each operand into a low part of cutoff digits and a high
        // part of size - cutoff digits (one digit longer when size is odd).
        std::size_t cutoff = size/2;
        // low * low  -> output[0, 2*cutoff)
        multiply_karatsuba(cutoff, lhs, rhs, output);
        // high * high -> output[2*cutoff, 2*size)
        multiply_karatsuba(size - cutoff, lhs + cutoff, rhs + cutoff,
                           output + cutoff*2);
        std::vector<digit_t> local1(size - cutoff);
        std::vector<digit_t> local2(size - cutoff);
        // local1 = lhs_low + lhs_high, local2 = rhs_low + rhs_high.  When
        // size is odd the high part's extra top digit has no partner in
        // the low part and is copied unchanged.
        add(cutoff, lhs, lhs + cutoff, &local1[0]);
        if(size & 1) local1[cutoff] = lhs[size - 1];
        add(cutoff, rhs + cutoff, rhs, &local2[0]);
        if(size & 1) local2[cutoff] = rhs[size - 1];
        std::vector<digit_t> local3((size - cutoff) * 2);
        multiply_karatsuba(size - cutoff, &local1[0], &local2[0], &local3[0]);
        // Adding both outer products to the product of the sums leaves
        // the middle term in local3 ...
        add(cutoff * 2, output, &local3[0], &local3[0]);
        add((size - cutoff) * 2, output + cutoff*2, &local3[0], &local3[0]);
        // ... which belongs at an offset of cutoff digits.
        add((size - cutoff) * 2, output + cutoff, &local3[0], output + cutoff);
    }

    // output += lhs * rhs for two operands of size digits each.
    // output must hold 2*size digits.  A size of zero is a no-op.
    static void multiply_add_karatsuba(std::size_t size,
                                       const digit_t * lhs, const digit_t * rhs,
                                       digit_t * output)
    {
        if(size == 0) {
            // Nothing to add, and &buf[0] is not valid on an empty vector.
            return;
        }
        std::vector<digit_t> buf(size * 2);
        multiply_karatsuba(size, lhs, rhs, &buf[0]);
        add(size * 2, &buf[0], output, output);
    }

    // output = lhs * rhs for operands of possibly different lengths.
    // output must hold lhs_size + rhs_size digits.
    static void multiply(const digit_t * lhs, std::size_t lhs_size,
                         const digit_t * rhs, std::size_t rhs_size,
                         digit_t * output)
    {
        std::fill_n(output, lhs_size + rhs_size, digit_t(0));
        multiply_add(lhs, lhs_size, rhs, rhs_size, output);
    }

    // output += lhs * rhs for operands of possibly different lengths.
    // output must hold lhs_size + rhs_size digits.  If either operand has
    // no digits the product is zero and output is left untouched.
    static void multiply_add(const digit_t * lhs, std::size_t lhs_size,
                             const digit_t * rhs, std::size_t rhs_size,
                             digit_t * output)
    {
        // Repeatedly multiply the shorter operand by an equally long
        // slice of the longer one, then move past that slice.  The
        // rhs_size test matters only for an initially empty rhs: without
        // it lhs_size would never shrink and the loop would not end.
        while(lhs_size != 0 && rhs_size != 0) {
            if(lhs_size < rhs_size) {
                std::swap(lhs, rhs);
                std::swap(lhs_size, rhs_size);
            }

            multiply_add_karatsuba(rhs_size, lhs, rhs, output);

            lhs += rhs_size;
            lhs_size -= rhs_size;
            output += rhs_size;
        }
    }

    // Copies bits [low, high) of x into out so that bit low of x becomes
    // bit 0 of out[0].  Writes one digit per started group of bits; the
    // unused upper bits of a final partial digit are cleared.
    static void copy_bits(const digit_t * x, std::size_t low, std::size_t high,
                          digit_t * out)
    {
        const std::size_t bits = std::numeric_limits<digit_t>::digits;
        // Skip whole digits below the range so that low < bits afterwards.
        std::size_t offset = low/bits;
        x += offset;
        low -= offset*bits;
        high -= offset*bits;
        std::size_t n = (high-low)/bits;
        if(low == 0) {
            for(std::size_t i = 0; i < n; ++i) {
                out[i] = x[i];
            }
        } else {
            // Each output digit straddles two input digits.
            for(std::size_t i = 0; i < n; ++i) {
                out[i] = (x[i] >> low) | (x[i+1] << (bits-low));
            }
        }
        if((high-low)%bits) {
            digit_t low_mask = (digit_t(1) << ((high-low)%bits)) - 1;
            digit_t result = (x[n] >> low);
            // Only read x[n+1] when the range really extends into it.
            if(low != 0 && (n+1)*bits < high) {
                result |= (x[n+1] << (bits-low));
            }
            out[n] = (result & low_mask);
        }
    }

    // Multiplies val (size digits) by x^shift in place.
    // Requires 0 < shift < bits.  Bits shifted out of the top digit are
    // dropped.
    static void shift_left(digit_t * val, std::size_t size, std::size_t shift)
    {
        const std::size_t bits = std::numeric_limits<digit_t>::digits;
        BOOST_ASSERT(shift > 0);
        BOOST_ASSERT(shift < bits);
        digit_t prev = 0;
        for(std::size_t i = 0; i < size; ++i) {
            digit_t tmp = val[i];
            val[i] = (prev >> (bits - shift)) | (val[i] << shift);
            prev = tmp;
        }
    }

    // Squares the polynomial held in the low half of val: bit k moves to
    // bit 2k, with zeros in between.  The upper half of val is shifted
    // out by the first iteration and does not contribute.
    static digit_t sqr(digit_t val) {
        const std::size_t bits = std::numeric_limits<digit_t>::digits;
        digit_t mask = (digit_t(1) << bits/2) - 1;
        // Each pass moves the upper half of every group of i/2 bits up by
        // i/2 positions, then halves the group size.
        for(std::size_t i = bits; i > 1; i /= 2) {
            val = ((val & ~mask) << i/2) | (val & mask);
            mask = mask & (mask >> i/4);
            mask = mask | (mask << i/2);
        }
        return val;
    }

    // Squares val in place.  The size input digits expand to 2*size
    // digits, so val must have room for 2*size digits.
    static void sqr(digit_t * val, std::size_t size)
    {
        const std::size_t bits = std::numeric_limits<digit_t>::digits;
        digit_t mask = (digit_t(1) << bits/2) - 1;
        // Work from the top digit down: digit k expands into digits 2k
        // and 2k+1, which never overwrite a digit that is still unread.
        for(std::size_t i = 0; i < size; ++i) {
            digit_t x = val[size - i - 1];
            val[(size - i - 1) * 2] = sqr(x & mask);
            val[(size - i - 1) * 2 + 1] = sqr(x >> bits/2);
        }
    }

    // Reduction modulo a fixed divisor, organised around the positions of
    // the divisor's set bits.
    struct sparse_mod {
        // divisor has divisor_bits significant bits; its top bit
        // (divisor_bits - 1) must be set.
        sparse_mod(const digit_t * divisor, std::size_t divisor_bits)
        {
            const std::size_t bits = std::numeric_limits<digit_t>::digits;
            _remainder_bits = divisor_bits - 1;
            for(std::size_t i = 0; i < divisor_bits; ++i) {
                if(divisor[i/bits] & (digit_t(1) << i%bits)) {
                    _bit_indices.push_back(i);
                }
            }
            BOOST_ASSERT(!_bit_indices.empty());
            BOOST_ASSERT(_bit_indices.back() == divisor_bits - 1);
            // The leading bit is handled separately in operator().
            _bit_indices.pop_back();
            if(_bit_indices.empty()) {
                _block_bits = divisor_bits;
                _lower_bits = 0;
            } else {
                // Gap between the leading bit and the next set bit.
                _block_bits = divisor_bits - _bit_indices.back() - 1;
                _lower_bits = _bit_indices.back() + 1;
            }

            _partial_quotient.resize((_block_bits + bits - 1)/bits);
        }
        // Reduces dividend (dividend_bits significant bits) in place.  On
        // return every bit at position _remainder_bits or above is clear.
        void operator()(digit_t * dividend, std::size_t dividend_bits)
        {
            const std::size_t bits = std::numeric_limits<digit_t>::digits;
            while(dividend_bits > _remainder_bits) {
                // Peel off up to _block_bits bits from the top.
                std::size_t block_start = (std::max)(dividend_bits - _block_bits, _remainder_bits);
                std::size_t block_size = (dividend_bits - block_start + bits - 1) / bits;
                copy_bits(dividend, block_start, dividend_bits, &_partial_quotient[0]);
                // Add the block times each lower term of the divisor ...
                for(std::size_t i = 0; i < _bit_indices.size(); ++i) {
                    std::size_t pos = _bit_indices[i] + block_start - _remainder_bits;
                    add_shifted_inplace(block_size, &_partial_quotient[0], dividend + pos/bits, pos%bits);
                }
                // ... and add the block to itself, which clears it.
                add_shifted_inplace(block_size, &_partial_quotient[0], dividend + block_start/bits, block_start%bits);
                dividend_bits = block_start;
            }
        }
        std::vector<digit_t> _partial_quotient;
        std::size_t _remainder_bits;
        std::size_t _block_bits;
        std::size_t _lower_bits;
        std::vector<std::size_t> _bit_indices;
    };

    // Computes x^exponent modulo mod and stores it in out.
    //
    // mod has mod_bits significant bits (mod_bits > 0, top bit set).
    // out is also used as scratch space for squaring and must hold
    // 2*n digits, where n is the number of digits needed for mod_bits
    // bits; the result occupies the first n digits.
    static void mod_pow_x(boost::uintmax_t exponent, const digit_t * mod, std::size_t mod_bits, digit_t * out)
    {
        const std::size_t bits = std::numeric_limits<digit_t>::digits;
        BOOST_ASSERT(mod_bits != 0);
        const std::size_t n = (mod_bits + bits - 1) / bits;
        const std::size_t highbit = mod_bits - 1;
        if(mod_bits == 1) {
            // The modulus is the constant 1, so every residue is zero.
            std::fill_n(out, n, digit_t(0));
            return;
        }
        if(exponent == 0) {
            out[0] = 1;
            std::fill_n(out + 1, n - 1, digit_t(0));
            return;
        }
        // Locate the most significant set bit of the exponent.
        boost::uintmax_t i = std::numeric_limits<boost::uintmax_t>::digits - 1;
        while(((boost::uintmax_t(1) << i) & exponent) == 0) {
            --i;
        }
        // Start from x, which accounts for that leading exponent bit.
        out[0] = 2;
        std::fill_n(out + 1, n - 1, digit_t(0));
        // x is only unreduced when the modulus has degree 1; in that case
        // nothing below would reduce it if no further bits follow.
        if(out[highbit / bits] & (digit_t(1) << highbit%bits))
            add(n, out, mod, out);
        sparse_mod m(mod, mod_bits);
        // Square for every remaining exponent bit and multiply by x
        // whenever that bit is set.
        while(i--) {
            sqr(out, n);
            m(out, 2 * mod_bits - 1);
            if((boost::uintmax_t(1) << i) & exponent) {
                shift_left(out, n, 1);
                // A single subtraction is enough after multiplying by x.
                if(out[highbit / bits] & (digit_t(1) << highbit%bits))
                    add(n, out, mod, out);
            }
        }
    }
};
0277
0278 class polynomial
0279 {
0280 typedef polynomial_ops::digit_t digit_t;
0281 public:
0282 polynomial() : _size(0) {}
0283 class reference {
0284 public:
0285 reference(digit_t &value, int idx)
0286 : _value(value), _idx(idx) {}
0287 operator bool() const { return (_value & (digit_t(1) << _idx)) != 0; }
0288 reference& operator=(bool b)
0289 {
0290 if(b) {
0291 _value |= (digit_t(1) << _idx);
0292 } else {
0293 _value &= ~(digit_t(1) << _idx);
0294 }
0295 return *this;
0296 }
0297 reference &operator^=(bool b)
0298 {
0299 _value ^= (digit_t(b) << _idx);
0300 return *this;
0301 }
0302
0303 reference &operator=(const reference &other)
0304 {
0305 return *this = static_cast<bool>(other);
0306 }
0307 private:
0308 digit_t &_value;
0309 int _idx;
0310 };
0311 reference operator[](std::size_t i)
0312 {
0313 static const std::size_t bits = std::numeric_limits<digit_t>::digits;
0314 ensure_bit(i);
0315 return reference(_storage[i/bits], i%bits);
0316 }
0317 bool operator[](std::size_t i) const
0318 {
0319 static const std::size_t bits = std::numeric_limits<digit_t>::digits;
0320 if(i < size())
0321 return (_storage[i/bits] & (digit_t(1) << (i%bits))) != 0;
0322 else
0323 return false;
0324 }
0325 std::size_t size() const
0326 {
0327 return _size;
0328 }
0329 void resize(std::size_t n)
0330 {
0331 static const std::size_t bits = std::numeric_limits<digit_t>::digits;
0332 _storage.resize((n + bits - 1)/bits);
0333
0334 if(n%bits) {
0335 _storage.back() &= ((digit_t(1) << (n%bits)) - 1);
0336 }
0337 _size = n;
0338 }
0339 friend polynomial operator*(const polynomial &lhs, const polynomial &rhs);
0340 friend polynomial mod_pow_x(boost::uintmax_t exponent, polynomial mod);
0341 private:
0342 std::vector<polynomial_ops::digit_t> _storage;
0343 std::size_t _size;
0344 void ensure_bit(std::size_t i)
0345 {
0346 if(i >= size()) {
0347 resize(i + 1);
0348 }
0349 }
0350 void normalize()
0351 {
0352 while(size() && (*this)[size() - 1] == 0)
0353 resize(size() - 1);
0354 }
0355 };
0356
0357 inline polynomial operator*(const polynomial &lhs, const polynomial &rhs)
0358 {
0359 polynomial result;
0360 result._storage.resize(lhs._storage.size() + rhs._storage.size());
0361 polynomial_ops::multiply(&lhs._storage[0], lhs._storage.size(),
0362 &rhs._storage[0], rhs._storage.size(),
0363 &result._storage[0]);
0364 result._size = lhs._size + rhs._size;
0365 return result;
0366 }
0367
0368 inline polynomial mod_pow_x(boost::uintmax_t exponent, polynomial mod)
0369 {
0370 polynomial result;
0371 mod.normalize();
0372 std::size_t mod_size = mod.size();
0373 result._storage.resize(mod._storage.size() * 2);
0374 result._size = mod.size() * 2;
0375 polynomial_ops::mod_pow_x(exponent, &mod._storage[0], mod_size, &result._storage[0]);
0376 result.resize(mod.size() - 1);
0377 return result;
0378 }
0379
0380 }
0381 }
0382 }
0383
0384 #endif