File indexing completed on 2025-01-18 09:27:22
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015 #ifndef ABSL_RANDOM_GAUSSIAN_DISTRIBUTION_H_
0016 #define ABSL_RANDOM_GAUSSIAN_DISTRIBUTION_H_
0017
0018
0019
0020
0021
0022
0023
0024
0025 #include <cmath>
0026 #include <cstdint>
0027 #include <istream>
0028 #include <limits>
0029 #include <type_traits>
0030
0031 #include "absl/base/config.h"
0032 #include "absl/random/internal/fast_uniform_bits.h"
0033 #include "absl/random/internal/generate_real.h"
0034 #include "absl/random/internal/iostream_state_saver.h"
0035
0036 namespace absl {
0037 ABSL_NAMESPACE_BEGIN
0038 namespace random_internal {
0039
0040
0041
0042
0043
0044
0045
0046
0047 class ABSL_DLL gaussian_distribution_base {
0048 public:
0049 template <typename URBG>
0050 inline double zignor(URBG& g);
0051
0052 private:
0053 friend class TableGenerator;
0054
0055 template <typename URBG>
0056 inline double zignor_fallback(URBG& g,
0057 bool neg);
0058
0059
0060 static constexpr double kR = 3.442619855899;
0061 static constexpr double kRInv = 0.29047645161474317;
0062 static constexpr double kV = 9.91256303526217e-3;
0063 static constexpr uint64_t kMask = 0x07f;
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073
0074 struct Tables {
0075 double x[kMask + 2];
0076 double f[kMask + 2];
0077 };
0078 static const Tables zg_;
0079 random_internal::FastUniformBits<uint64_t> fast_u64_;
0080 };
0081
0082 }
0083
0084
0085
0086 template <typename RealType = double>
0087 class gaussian_distribution : random_internal::gaussian_distribution_base {
0088 public:
0089 using result_type = RealType;
0090
0091 class param_type {
0092 public:
0093 using distribution_type = gaussian_distribution;
0094
0095 explicit param_type(result_type mean = 0, result_type stddev = 1)
0096 : mean_(mean), stddev_(stddev) {}
0097
0098
0099
0100 result_type mean() const { return mean_; }
0101
0102
0103 result_type stddev() const { return stddev_; }
0104
0105 friend bool operator==(const param_type& a, const param_type& b) {
0106 return a.mean_ == b.mean_ && a.stddev_ == b.stddev_;
0107 }
0108
0109 friend bool operator!=(const param_type& a, const param_type& b) {
0110 return !(a == b);
0111 }
0112
0113 private:
0114 result_type mean_;
0115 result_type stddev_;
0116
0117 static_assert(
0118 std::is_floating_point<RealType>::value,
0119 "Class-template absl::gaussian_distribution<> must be parameterized "
0120 "using a floating-point type.");
0121 };
0122
0123 gaussian_distribution() : gaussian_distribution(0) {}
0124
0125 explicit gaussian_distribution(result_type mean, result_type stddev = 1)
0126 : param_(mean, stddev) {}
0127
0128 explicit gaussian_distribution(const param_type& p) : param_(p) {}
0129
0130 void reset() {}
0131
0132
0133 template <typename URBG>
0134 result_type operator()(URBG& g) {
0135 return (*this)(g, param_);
0136 }
0137
0138 template <typename URBG>
0139 result_type operator()(URBG& g,
0140 const param_type& p);
0141
0142 param_type param() const { return param_; }
0143 void param(const param_type& p) { param_ = p; }
0144
0145 result_type(min)() const {
0146 return -std::numeric_limits<result_type>::infinity();
0147 }
0148 result_type(max)() const {
0149 return std::numeric_limits<result_type>::infinity();
0150 }
0151
0152 result_type mean() const { return param_.mean(); }
0153 result_type stddev() const { return param_.stddev(); }
0154
0155 friend bool operator==(const gaussian_distribution& a,
0156 const gaussian_distribution& b) {
0157 return a.param_ == b.param_;
0158 }
0159 friend bool operator!=(const gaussian_distribution& a,
0160 const gaussian_distribution& b) {
0161 return a.param_ != b.param_;
0162 }
0163
0164 private:
0165 param_type param_;
0166 };
0167
0168
0169
0170
0171
0172 template <typename RealType>
0173 template <typename URBG>
0174 typename gaussian_distribution<RealType>::result_type
0175 gaussian_distribution<RealType>::operator()(
0176 URBG& g,
0177 const param_type& p) {
0178 return p.mean() + p.stddev() * static_cast<result_type>(zignor(g));
0179 }
0180
0181 template <typename CharT, typename Traits, typename RealType>
0182 std::basic_ostream<CharT, Traits>& operator<<(
0183 std::basic_ostream<CharT, Traits>& os,
0184 const gaussian_distribution<RealType>& x) {
0185 auto saver = random_internal::make_ostream_state_saver(os);
0186 os.precision(random_internal::stream_precision_helper<RealType>::kPrecision);
0187 os << x.mean() << os.fill() << x.stddev();
0188 return os;
0189 }
0190
0191 template <typename CharT, typename Traits, typename RealType>
0192 std::basic_istream<CharT, Traits>& operator>>(
0193 std::basic_istream<CharT, Traits>& is,
0194 gaussian_distribution<RealType>& x) {
0195 using result_type = typename gaussian_distribution<RealType>::result_type;
0196 using param_type = typename gaussian_distribution<RealType>::param_type;
0197
0198 auto saver = random_internal::make_istream_state_saver(is);
0199 auto mean = random_internal::read_floating_point<result_type>(is);
0200 if (is.fail()) return is;
0201 auto stddev = random_internal::read_floating_point<result_type>(is);
0202 if (!is.fail()) {
0203 x.param(param_type(mean, stddev));
0204 }
0205 return is;
0206 }
0207
0208 namespace random_internal {
0209
0210 template <typename URBG>
0211 inline double gaussian_distribution_base::zignor_fallback(URBG& g, bool neg) {
0212 using random_internal::GeneratePositiveTag;
0213 using random_internal::GenerateRealFromBits;
0214
0215
0216 double x, y;
0217 do {
0218
0219 x = kRInv *
0220 std::log(GenerateRealFromBits<double, GeneratePositiveTag, false>(
0221 fast_u64_(g)));
0222 y = -std::log(
0223 GenerateRealFromBits<double, GeneratePositiveTag, false>(fast_u64_(g)));
0224 } while ((y + y) < (x * x));
0225 return neg ? (x - kR) : (kR - x);
0226 }
0227
0228 template <typename URBG>
0229 inline double gaussian_distribution_base::zignor(
0230 URBG& g) {
0231 using random_internal::GeneratePositiveTag;
0232 using random_internal::GenerateRealFromBits;
0233 using random_internal::GenerateSignedTag;
0234
0235 while (true) {
0236
0237
0238
0239
0240 uint64_t bits = fast_u64_(g);
0241 int i = static_cast<int>(bits & kMask);
0242 double j = GenerateRealFromBits<double, GenerateSignedTag, false>(
0243 bits);
0244 const double x = j * zg_.x[i];
0245
0246
0247
0248
0249 if (std::abs(x) < zg_.x[i + 1]) {
0250 return x;
0251 }
0252
0253
0254 if (i == 0) {
0255
0256 return zignor_fallback(g, j < 0);
0257 }
0258
0259
0260 double v = GenerateRealFromBits<double, GeneratePositiveTag, false>(
0261 fast_u64_(g));
0262 if ((zg_.f[i + 1] + v * (zg_.f[i] - zg_.f[i + 1])) <
0263 std::exp(-0.5 * x * x)) {
0264 return x;
0265 }
0266
0267
0268 }
0269 }
0270
0271 }
0272 ABSL_NAMESPACE_END
0273 }
0274
0275 #endif