File indexing completed on 2026-05-10 08:44:28
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
0014 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
0015
0016 #include "llvm/Support/DataTypes.h"
0017 #include <algorithm>
0018 #include <cassert>
0019 #include <iterator>
0020 #include <numeric>
0021
0022 namespace llvm {
0023
0024 class raw_ostream;
0025
0026
0027
0028
0029
/// A probability stored as a 32-bit fixed-point numerator over the fixed
/// denominator D = 2^31.
///
/// The numerator value UINT32_MAX is reserved as a sentinel meaning "unknown
/// probability". Unknown probabilities may be compared for (in)equality but
/// the arithmetic and ordering operators assert that neither operand is
/// unknown.
class BranchProbability {
  // Numerator of the fixed-point value N / D, or UnknownN.
  uint32_t N;

  // Fixed denominator shared by every instance: 2^31.
  static constexpr uint32_t D = 1u << 31;
  // Sentinel numerator marking an unknown probability.
  static constexpr uint32_t UnknownN = UINT32_MAX;

  // Builds a probability directly from a raw numerator. Private: outside code
  // goes through the named factories below.
  explicit BranchProbability(uint32_t n) : N(n) {}

public:
  /// Default-constructs an unknown probability.
  BranchProbability() : N(UnknownN) {}

  /// Constructs from a Numerator/Denominator pair. Defined out of line;
  /// presumably rescales the fraction onto the fixed denominator D.
  BranchProbability(uint32_t Numerator, uint32_t Denominator);

  /// True when the numerator is exactly zero.
  bool isZero() const { return N == 0; }
  /// True when this object holds the "unknown" sentinel.
  bool isUnknown() const { return N == UnknownN; }

  /// Probability 0 (numerator 0).
  static BranchProbability getZero() { return BranchProbability(0); }
  /// Probability 1 (numerator D).
  static BranchProbability getOne() { return BranchProbability(D); }
  /// The unknown probability (numerator UnknownN).
  static BranchProbability getUnknown() { return BranchProbability(UnknownN); }

  /// Wraps a raw numerator without any validation or rescaling.
  static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }

  /// 64-bit counterpart of the two-argument constructor. Defined out of line.
  static BranchProbability getBranchProbability(uint64_t Numerator,
                                                uint64_t Denominator);

  /// Normalizes the probabilities in [Begin, End) in place. Declared here,
  /// defined after the class.
  template <class ProbabilityIter>
  static void normalizeProbabilities(ProbabilityIter Begin,
                                     ProbabilityIter End);

  /// Raw numerator (UnknownN for an unknown probability).
  uint32_t getNumerator() const { return N; }
  /// The fixed denominator D.
  static uint32_t getDenominator() { return D; }

  /// Complement, computed as D - N. No check is made for the unknown
  /// sentinel, so callers must only use this on known probabilities.
  BranchProbability getCompl() const { return BranchProbability(D - N); }

  /// Prints this probability to OS. Defined out of line.
  /// The elaborated type specifier is used because only a reference to the
  /// stream class is needed; it names the same class as a prior forward
  /// declaration in the enclosing namespace.
  class raw_ostream &print(class raw_ostream &OS) const;

  /// Debug dump. Defined out of line.
  void dump() const;

  /// Defined out of line; presumably scales Num by this probability (N / D).
  /// NOTE(review): rounding/overflow behavior is not visible here - confirm
  /// against the implementation file.
  uint64_t scale(uint64_t Num) const;

  /// Defined out of line; presumably scales Num by the inverse of this
  /// probability (D / N). NOTE(review): confirm against the implementation.
  uint64_t scaleByInverse(uint64_t Num) const;

  /// Adds RHS, saturating at probability one (numerator D).
  BranchProbability &operator+=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Compare in 64 bits so the sum itself cannot wrap before the clamp.
    N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N;
    return *this;
  }

  /// Subtracts RHS, saturating at zero.
  BranchProbability &operator-=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    N = N < RHS.N ? 0 : N - RHS.N;
    return *this;
  }

  /// Multiplies by RHS, rounding the fixed-point product to nearest.
  BranchProbability &operator*=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
    return *this;
  }

  /// Multiplies by an integer factor, saturating at probability one.
  BranchProbability &operator*=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    N = (uint64_t(N) * RHS > D) ? D : N * RHS;
    return *this;
  }

  /// Divides by RHS, rounding to nearest and saturating at probability one.
  ///
  /// The quotient is formed in 64 bits: when this probability is larger than
  /// RHS it exceeds D (and may not even fit in 32 bits), so it is clamped to
  /// D rather than being silently truncated. RHS must be non-zero.
  BranchProbability &operator/=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    assert(RHS.N > 0 && "The divider cannot be zero.");
    const uint64_t Quotient =
        (static_cast<uint64_t>(N) * D + RHS.N / 2) / RHS.N;
    N = Quotient > D ? D : static_cast<uint32_t>(Quotient);
    return *this;
  }

  /// Divides the numerator by a non-zero integer (truncating).
  BranchProbability &operator/=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    assert(RHS > 0 && "The divider cannot be zero.");
    N /= RHS;
    return *this;
  }

  // Binary arithmetic operators: copy the left operand and apply the
  // matching compound assignment, so they share its saturation and asserts.

  BranchProbability operator+(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob += RHS;
    return Prob;
  }

  BranchProbability operator-(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob -= RHS;
    return Prob;
  }

  BranchProbability operator*(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob *= RHS;
    return Prob;
  }

  BranchProbability operator*(uint32_t RHS) const {
    BranchProbability Prob(*this);
    Prob *= RHS;
    return Prob;
  }

  BranchProbability operator/(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob /= RHS;
    return Prob;
  }

  BranchProbability operator/(uint32_t RHS) const {
    BranchProbability Prob(*this);
    Prob /= RHS;
    return Prob;
  }

  /// Equality compares raw numerators, so two unknown probabilities compare
  /// equal and no assertion is raised for unknown operands.
  bool operator==(BranchProbability RHS) const { return N == RHS.N; }
  bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }

  // Ordering operators: each asserts that both operands are known. All of
  // them are expressed in terms of operator<.

  bool operator<(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return N < RHS.N;
  }

  bool operator>(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return RHS < *this;
  }

  bool operator<=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return !(RHS < *this);
  }

  bool operator>=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return !(*this < RHS);
  }
};
0199
0200 inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
0201 return Prob.print(OS);
0202 }
0203
0204 template <class ProbabilityIter>
0205 void BranchProbability::normalizeProbabilities(ProbabilityIter Begin,
0206 ProbabilityIter End) {
0207 if (Begin == End)
0208 return;
0209
0210 unsigned UnknownProbCount = 0;
0211 uint64_t Sum = std::accumulate(Begin, End, uint64_t(0),
0212 [&](uint64_t S, const BranchProbability &BP) {
0213 if (!BP.isUnknown())
0214 return S + BP.N;
0215 UnknownProbCount++;
0216 return S;
0217 });
0218
0219 if (UnknownProbCount > 0) {
0220 BranchProbability ProbForUnknown = BranchProbability::getZero();
0221
0222
0223
0224 if (Sum < BranchProbability::getDenominator())
0225 ProbForUnknown = BranchProbability::getRaw(
0226 (BranchProbability::getDenominator() - Sum) / UnknownProbCount);
0227
0228 std::replace_if(Begin, End,
0229 [](const BranchProbability &BP) { return BP.isUnknown(); },
0230 ProbForUnknown);
0231
0232 if (Sum <= BranchProbability::getDenominator())
0233 return;
0234 }
0235
0236 if (Sum == 0) {
0237 BranchProbability BP(1, std::distance(Begin, End));
0238 std::fill(Begin, End, BP);
0239 return;
0240 }
0241
0242 for (auto I = Begin; I != End; ++I)
0243 I->N = (I->N * uint64_t(D) + Sum / 2) / Sum;
0244 }
0245
0246 }
0247
0248 #endif