|
||||
File indexing completed on 2025-01-30 09:31:48
// Copyright 2017 The Abseil Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_
#define ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_

#include <cassert>
#include <cstdint>
#include <istream>
#include <limits>
#include <ostream>

#include "absl/base/optimization.h"
#include "absl/random/internal/fast_uniform_bits.h"
#include "absl/random/internal/iostream_state_saver.h"

namespace absl {
ABSL_NAMESPACE_BEGIN

// absl::bernoulli_distribution is a drop in replacement for
// std::bernoulli_distribution. It guarantees that (given a perfect
// UniformRandomBitGenerator) the acceptance probability is *exactly* equal to
// the given double.
0033 // 0034 // The implementation assumes that double is IEEE754 0035 class bernoulli_distribution { 0036 public: 0037 using result_type = bool; 0038 0039 class param_type { 0040 public: 0041 using distribution_type = bernoulli_distribution; 0042 0043 explicit param_type(double p = 0.5) : prob_(p) { 0044 assert(p >= 0.0 && p <= 1.0); 0045 } 0046 0047 double p() const { return prob_; } 0048 0049 friend bool operator==(const param_type& p1, const param_type& p2) { 0050 return p1.p() == p2.p(); 0051 } 0052 friend bool operator!=(const param_type& p1, const param_type& p2) { 0053 return p1.p() != p2.p(); 0054 } 0055 0056 private: 0057 double prob_; 0058 }; 0059 0060 bernoulli_distribution() : bernoulli_distribution(0.5) {} 0061 0062 explicit bernoulli_distribution(double p) : param_(p) {} 0063 0064 explicit bernoulli_distribution(param_type p) : param_(p) {} 0065 0066 // no-op 0067 void reset() {} 0068 0069 template <typename URBG> 0070 bool operator()(URBG& g) { // NOLINT(runtime/references) 0071 return Generate(param_.p(), g); 0072 } 0073 0074 template <typename URBG> 0075 bool operator()(URBG& g, // NOLINT(runtime/references) 0076 const param_type& param) { 0077 return Generate(param.p(), g); 0078 } 0079 0080 param_type param() const { return param_; } 0081 void param(const param_type& param) { param_ = param; } 0082 0083 double p() const { return param_.p(); } 0084 0085 result_type(min)() const { return false; } 0086 result_type(max)() const { return true; } 0087 0088 friend bool operator==(const bernoulli_distribution& d1, 0089 const bernoulli_distribution& d2) { 0090 return d1.param_ == d2.param_; 0091 } 0092 0093 friend bool operator!=(const bernoulli_distribution& d1, 0094 const bernoulli_distribution& d2) { 0095 return d1.param_ != d2.param_; 0096 } 0097 0098 private: 0099 static constexpr uint64_t kP32 = static_cast<uint64_t>(1) << 32; 0100 0101 template <typename URBG> 0102 static bool Generate(double p, URBG& g); // NOLINT(runtime/references) 0103 0104 
param_type param_; 0105 }; 0106 0107 template <typename CharT, typename Traits> 0108 std::basic_ostream<CharT, Traits>& operator<<( 0109 std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) 0110 const bernoulli_distribution& x) { 0111 auto saver = random_internal::make_ostream_state_saver(os); 0112 os.precision(random_internal::stream_precision_helper<double>::kPrecision); 0113 os << x.p(); 0114 return os; 0115 } 0116 0117 template <typename CharT, typename Traits> 0118 std::basic_istream<CharT, Traits>& operator>>( 0119 std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) 0120 bernoulli_distribution& x) { // NOLINT(runtime/references) 0121 auto saver = random_internal::make_istream_state_saver(is); 0122 auto p = random_internal::read_floating_point<double>(is); 0123 if (!is.fail()) { 0124 x.param(bernoulli_distribution::param_type(p)); 0125 } 0126 return is; 0127 } 0128 0129 template <typename URBG> 0130 bool bernoulli_distribution::Generate(double p, 0131 URBG& g) { // NOLINT(runtime/references) 0132 random_internal::FastUniformBits<uint32_t> fast_u32; 0133 0134 while (true) { 0135 // There are two aspects of the definition of `c` below that are worth 0136 // commenting on. First, because `p` is in the range [0, 1], `c` is in the 0137 // range [0, 2^32] which does not fit in a uint32_t and therefore requires 0138 // 64 bits. 0139 // 0140 // Second, `c` is constructed by first casting explicitly to a signed 0141 // integer and then casting explicitly to an unsigned integer of the same 0142 // size. This is done because the hardware conversion instructions produce 0143 // signed integers from double; if taken as a uint64_t the conversion would 0144 // be wrong for doubles greater than 2^63 (not relevant in this use-case). 0145 // If converted directly to an unsigned integer, the compiler would end up 0146 // emitting code to handle such large values that are not relevant due to 0147 // the known bounds on `c`. 
To avoid these extra instructions this 0148 // implementation converts first to the signed type and then convert to 0149 // unsigned (which is a no-op). 0150 const uint64_t c = static_cast<uint64_t>(static_cast<int64_t>(p * kP32)); 0151 const uint32_t v = fast_u32(g); 0152 // FAST PATH: this path fails with probability 1/2^32. Note that simply 0153 // returning v <= c would approximate P very well (up to an absolute error 0154 // of 1/2^32); the slow path (taken in that range of possible error, in the 0155 // case of equality) eliminates the remaining error. 0156 if (ABSL_PREDICT_TRUE(v != c)) return v < c; 0157 0158 // It is guaranteed that `q` is strictly less than 1, because if `q` were 0159 // greater than or equal to 1, the same would be true for `p`. Certainly `p` 0160 // cannot be greater than 1, and if `p == 1`, then the fast path would 0161 // necessary have been taken already. 0162 const double q = static_cast<double>(c) / kP32; 0163 0164 // The probability of acceptance on the fast path is `q` and so the 0165 // probability of acceptance here should be `p - q`. 0166 // 0167 // Note that `q` is obtained from `p` via some shifts and conversions, the 0168 // upshot of which is that `q` is simply `p` with some of the 0169 // least-significant bits of its mantissa set to zero. This means that the 0170 // difference `p - q` will not have any rounding errors. To see why, pretend 0171 // that double has 10 bits of resolution and q is obtained from `p` in such 0172 // a way that the 4 least-significant bits of its mantissa are set to zero. 0173 // For example: 0174 // p = 1.1100111011 * 2^-1 0175 // q = 1.1100110000 * 2^-1 0176 // p - q = 1.011 * 2^-8 0177 // The difference `p - q` has exactly the nonzero mantissa bits that were 0178 // "lost" in `q` producing a number which is certainly representable in a 0179 // double. 
0180 const double left = p - q; 0181 0182 // By construction, the probability of being on this slow path is 1/2^32, so 0183 // P(accept in slow path) = P(accept| in slow path) * P(slow path), 0184 // which means the probability of acceptance here is `1 / (left * kP32)`: 0185 const double here = left * kP32; 0186 0187 // The simplest way to compute the result of this trial is to repeat the 0188 // whole algorithm with the new probability. This terminates because even 0189 // given arbitrarily unfriendly "random" bits, each iteration either 0190 // multiplies a tiny probability by 2^32 (if c == 0) or strips off some 0191 // number of nonzero mantissa bits. That process is bounded. 0192 if (here == 0) return false; 0193 p = here; 0194 } 0195 } 0196 0197 ABSL_NAMESPACE_END 0198 } // namespace absl 0199 0200 #endif // ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_
[ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |