File indexing completed on 2025-01-18 09:27:22
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015 #ifndef ABSL_RANDOM_DISCRETE_DISTRIBUTION_H_
0016 #define ABSL_RANDOM_DISCRETE_DISTRIBUTION_H_
0017
0018 #include <cassert>
0019 #include <cmath>
0020 #include <istream>
0021 #include <limits>
0022 #include <numeric>
0023 #include <type_traits>
0024 #include <utility>
0025 #include <vector>
0026
0027 #include "absl/random/bernoulli_distribution.h"
0028 #include "absl/random/internal/iostream_state_saver.h"
0029 #include "absl/random/uniform_int_distribution.h"
0030
0031 namespace absl {
0032 ABSL_NAMESPACE_BEGIN
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
// absl::discrete_distribution
//
// Produces random integers in [0, n], where index i is chosen with
// probability proportional to the i-th weight supplied at construction.
// Sampling (see the out-of-line operator()) draws a bucket uniformly and then
// either keeps it or returns that bucket's alternate index, using a table
// built from the weights by random_internal::InitDiscreteDistribution.
template <typename IntType = int>
class discrete_distribution {
 public:
  using result_type = IntType;

  // The parameter set of a discrete_distribution: the weights (p_) plus the
  // per-bucket sampling table derived from them (q_).
  class param_type {
   public:
    using distribution_type = discrete_distribution;

    // A single outcome (index 0) with weight 1.0.
    param_type() { init(); }

    // Weights copied from [begin, end). An empty range behaves like the
    // default constructor.
    template <typename InputIterator>
    explicit param_type(InputIterator begin, InputIterator end)
        : p_(begin, end) {
      init();
    }

    // Weights taken from an initializer list. An empty list behaves like the
    // default constructor.
    explicit param_type(std::initializer_list<double> weights) : p_(weights) {
      init();
    }

    // nw weights, the i-th being fw evaluated at the midpoint of the i-th of
    // nw equal-width subintervals of [xmin, xmax]. When nw > 0, xmax must be
    // greater than xmin (checked by assert). nw == 0 behaves like the default
    // constructor.
    template <class UnaryOperation>
    explicit param_type(size_t nw, double xmin, double xmax,
                        UnaryOperation fw) {
      if (nw > 0) {
        p_.reserve(nw);
        const double delta = (xmax - xmin) / static_cast<double>(nw);
        assert(delta > 0);
        const double t = delta * 0.5;  // offset from a subinterval's left edge
        for (size_t i = 0; i < nw; ++i) {
          // Explicit size_t -> double conversion (previously implicit and
          // flagged by conversion warnings); evaluation order is unchanged.
          p_.push_back(fw(xmin + static_cast<double>(i) * delta + t));
        }
      }
      init();
    }

    // The stored weights, as left by init() (which hands them to
    // random_internal::InitDiscreteDistribution by pointer).
    const std::vector<double>& probabilities() const { return p_; }

    // Largest index this parameter set can produce (number of weights - 1).
    // p_ is never empty after construction, so this does not underflow.
    size_t n() const { return p_.size() - 1; }

    // Equality is decided by the weights alone; q_ is derived from them.
    friend bool operator==(const param_type& a, const param_type& b) {
      return a.probabilities() == b.probabilities();
    }

    friend bool operator!=(const param_type& a, const param_type& b) {
      return !(a == b);
    }

   private:
    friend class discrete_distribution;

    // Finishes construction: fills in defaults for empty input and builds q_.
    void init();

    std::vector<double> p_;
    // Per-bucket (keep probability, alternate index) pairs read by
    // discrete_distribution::operator().
    std::vector<std::pair<double, size_t>> q_;

    static_assert(std::is_integral<result_type>::value,
                  "Class-template absl::discrete_distribution<> must be "
                  "parameterized using an integral type.");
  };

  discrete_distribution() : param_() {}

  explicit discrete_distribution(const param_type& p) : param_(p) {}

  template <typename InputIterator>
  explicit discrete_distribution(InputIterator begin, InputIterator end)
      : param_(begin, end) {}

  explicit discrete_distribution(std::initializer_list<double> weights)
      : param_(weights) {}

  template <class UnaryOperation>
  explicit discrete_distribution(size_t nw, double xmin, double xmax,
                                 UnaryOperation fw)
      : param_(nw, xmin, xmax, std::move(fw)) {}

  // No-op: the only state held is the parameter set.
  void reset() {}

  // Draws a sample using this distribution's own parameters.
  template <typename URBG>
  result_type operator()(URBG& g) {
    return (*this)(g, param_);
  }

  // Draws a sample using the parameters `p` instead of param().
  template <typename URBG>
  result_type operator()(URBG& g, const param_type& p);

  const param_type& param() const { return param_; }
  void param(const param_type& p) { param_ = p; }

  // The names are parenthesized so that function-like min/max macros (e.g.
  // from <windows.h>) do not expand here.
  result_type(min)() const { return 0; }
  result_type(max)() const {
    return static_cast<result_type>(param_.n());
  }

  const std::vector<double>& probabilities() const {
    return param_.probabilities();
  }

  friend bool operator==(const discrete_distribution& a,
                         const discrete_distribution& b) {
    return a.param_ == b.param_;
  }
  friend bool operator!=(const discrete_distribution& a,
                         const discrete_distribution& b) {
    return a.param_ != b.param_;
  }

 private:
  param_type param_;
};
0165
0166
0167
0168
0169
namespace random_internal {

// Builds the sampling table consumed by discrete_distribution.
//
// Each returned pair is read by discrete_distribution::operator() as
// (probability of keeping the uniformly chosen bucket, index to return when
// the bucket is not kept). The weights are passed by pointer, so the
// implementation may rewrite them in place (presumably normalizing them —
// confirm against the definition). param_type::init() only calls this with a
// non-empty vector.
std::vector<std::pair<double, size_t>> InitDiscreteDistribution(
    std::vector<double>* probabilities);

}  // namespace random_internal
0181
0182 template <typename IntType>
0183 void discrete_distribution<IntType>::param_type::init() {
0184 if (p_.empty()) {
0185 p_.push_back(1.0);
0186 q_.emplace_back(1.0, 0);
0187 } else {
0188 assert(n() <= (std::numeric_limits<IntType>::max)());
0189 q_ = random_internal::InitDiscreteDistribution(&p_);
0190 }
0191 }
0192
0193 template <typename IntType>
0194 template <typename URBG>
0195 typename discrete_distribution<IntType>::result_type
0196 discrete_distribution<IntType>::operator()(
0197 URBG& g,
0198 const param_type& p) {
0199 const auto idx = absl::uniform_int_distribution<result_type>(0, p.n())(g);
0200 const auto& q = p.q_[idx];
0201 const bool selected = absl::bernoulli_distribution(q.first)(g);
0202 return selected ? idx : static_cast<result_type>(q.second);
0203 }
0204
0205 template <typename CharT, typename Traits, typename IntType>
0206 std::basic_ostream<CharT, Traits>& operator<<(
0207 std::basic_ostream<CharT, Traits>& os,
0208 const discrete_distribution<IntType>& x) {
0209 auto saver = random_internal::make_ostream_state_saver(os);
0210 const auto& probabilities = x.param().probabilities();
0211 os << probabilities.size();
0212
0213 os.precision(random_internal::stream_precision_helper<double>::kPrecision);
0214 for (const auto& p : probabilities) {
0215 os << os.fill() << p;
0216 }
0217 return os;
0218 }
0219
0220 template <typename CharT, typename Traits, typename IntType>
0221 std::basic_istream<CharT, Traits>& operator>>(
0222 std::basic_istream<CharT, Traits>& is,
0223 discrete_distribution<IntType>& x) {
0224 using param_type = typename discrete_distribution<IntType>::param_type;
0225 auto saver = random_internal::make_istream_state_saver(is);
0226
0227 size_t n;
0228 std::vector<double> p;
0229
0230 is >> n;
0231 if (is.fail()) return is;
0232 if (n > 0) {
0233 p.reserve(n);
0234 for (IntType i = 0; i < n && !is.fail(); ++i) {
0235 auto tmp = random_internal::read_floating_point<double>(is);
0236 if (is.fail()) return is;
0237 p.push_back(tmp);
0238 }
0239 }
0240 x.param(param_type(p.begin(), p.end()));
0241 return is;
0242 }
0243
0244 ABSL_NAMESPACE_END
0245 }
0246
0247 #endif