Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-10 08:44:28

0001 //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===//
0002 //
0003 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
0004 // See https://llvm.org/LICENSE.txt for license information.
0005 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
0006 //
0007 //===----------------------------------------------------------------------===//
0008 //
0009 // Definition of BranchProbability shared by IR and Machine Instructions.
0010 //
0011 //===----------------------------------------------------------------------===//
0012 
0013 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
0014 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
0015 
0016 #include "llvm/Support/DataTypes.h"
0017 #include <algorithm>
0018 #include <cassert>
0019 #include <iterator>
0020 #include <numeric>
0021 
0022 namespace llvm {
0023 
class raw_ostream;

// This class represents Branch Probability as a non-negative fraction that is
// no greater than 1. It uses a fixed-point-like implementation, in which the
// denominator is always a constant value (here we use 1<<31 for maximum
// precision). The all-ones numerator is reserved as a sentinel meaning
// "unknown": unknown probabilities may be compared for (in)equality, but must
// not take part in arithmetic or ordering comparisons.
class BranchProbability {
  // Numerator
  uint32_t N;

  // Denominator, which is a constant value.
  static constexpr uint32_t D = 1u << 31;
  // Sentinel numerator for an unknown probability. It can never be a valid
  // value because valid numerators are at most D.
  static constexpr uint32_t UnknownN = UINT32_MAX;

  // Construct a BranchProbability with only numerator assuming the denominator
  // is 1<<31. For internal use only.
  explicit BranchProbability(uint32_t n) : N(n) {}

public:
  // Default construction yields the unknown probability.
  BranchProbability() : N(UnknownN) {}
  BranchProbability(uint32_t Numerator, uint32_t Denominator);

  bool isZero() const { return N == 0; }
  bool isUnknown() const { return N == UnknownN; }

  static BranchProbability getZero() { return BranchProbability(0); }
  static BranchProbability getOne() { return BranchProbability(D); }
  static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
  // Create a BranchProbability object with the given numerator and 1<<31
  // as denominator.
  static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
  // Create a BranchProbability object from 64-bit integers.
  static BranchProbability getBranchProbability(uint64_t Numerator,
                                                uint64_t Denominator);

  // Normalize the given probabilities so that their sum becomes approximately
  // one. Unknown entries in [Begin, End) are first filled in from the mass the
  // known entries leave unclaimed (zero if there is none).
  template <class ProbabilityIter>
  static void normalizeProbabilities(ProbabilityIter Begin,
                                     ProbabilityIter End);

  uint32_t getNumerator() const { return N; }
  static uint32_t getDenominator() { return D; }

  // Return (1 - Probability).
  BranchProbability getCompl() const {
    // D - UnknownN would wrap around to a value greater than one.
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    return BranchProbability(D - N);
  }

  raw_ostream &print(raw_ostream &OS) const;

  void dump() const;

  /// Scale a large integer.
  ///
  /// Scales \c Num.  Guarantees full precision.  Returns the floor of the
  /// result.
  ///
  /// \return \c Num times \c this.
  uint64_t scale(uint64_t Num) const;

  /// Scale a large integer by the inverse.
  ///
  /// Scales \c Num by the inverse of \c this.  Guarantees full precision.
  /// Returns the floor of the result.
  ///
  /// \return \c Num divided by \c this.
  uint64_t scaleByInverse(uint64_t Num) const;

  BranchProbability &operator+=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate the result in case of overflow.
    N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N;
    return *this;
  }

  BranchProbability &operator-=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate the result in case of underflow.
    N = N < RHS.N ? 0 : N - RHS.N;
    return *this;
  }

  BranchProbability &operator*=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Fixed-point product, rounded to nearest. Both factors are at most one,
    // so the result cannot exceed one.
    N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
    return *this;
  }

  BranchProbability &operator*=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate the result at one in case of overflow.
    N = (uint64_t(N) * RHS > D) ? D : N * RHS;
    return *this;
  }

  BranchProbability &operator/=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    assert(RHS.N != 0 && "The divider cannot be zero.");
    // Fixed-point quotient, rounded to nearest. A quotient above one would
    // break the class invariant, and above UINT32_MAX it would silently be
    // truncated when narrowed back to 32 bits, so saturate it at one.
    uint64_t Quotient = (static_cast<uint64_t>(N) * D + RHS.N / 2) / RHS.N;
    N = Quotient > D ? D : static_cast<uint32_t>(Quotient);
    return *this;
  }

  BranchProbability &operator/=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    assert(RHS > 0 && "The divider cannot be zero.");
    // Truncating division; the quotient never exceeds N.
    N /= RHS;
    return *this;
  }

  BranchProbability operator+(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob += RHS;
    return Prob;
  }

  BranchProbability operator-(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob -= RHS;
    return Prob;
  }

  BranchProbability operator*(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob *= RHS;
    return Prob;
  }

  BranchProbability operator*(uint32_t RHS) const {
    BranchProbability Prob(*this);
    Prob *= RHS;
    return Prob;
  }

  BranchProbability operator/(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob /= RHS;
    return Prob;
  }

  BranchProbability operator/(uint32_t RHS) const {
    BranchProbability Prob(*this);
    Prob /= RHS;
    return Prob;
  }

  // Equality is defined for unknown probabilities too (raw numerator compare).
  bool operator==(BranchProbability RHS) const { return N == RHS.N; }
  bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }

  bool operator<(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return N < RHS.N;
  }

  bool operator>(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return RHS < *this;
  }

  bool operator<=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return !(RHS < *this);
  }

  bool operator>=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return !(*this < RHS);
  }
};
0199 
0200 inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
0201   return Prob.print(OS);
0202 }
0203 
0204 template <class ProbabilityIter>
0205 void BranchProbability::normalizeProbabilities(ProbabilityIter Begin,
0206                                                ProbabilityIter End) {
0207   if (Begin == End)
0208     return;
0209 
0210   unsigned UnknownProbCount = 0;
0211   uint64_t Sum = std::accumulate(Begin, End, uint64_t(0),
0212                                  [&](uint64_t S, const BranchProbability &BP) {
0213                                    if (!BP.isUnknown())
0214                                      return S + BP.N;
0215                                    UnknownProbCount++;
0216                                    return S;
0217                                  });
0218 
0219   if (UnknownProbCount > 0) {
0220     BranchProbability ProbForUnknown = BranchProbability::getZero();
0221     // If the sum of all known probabilities is less than one, evenly distribute
0222     // the complement of sum to unknown probabilities. Otherwise, set unknown
0223     // probabilities to zeros and continue to normalize known probabilities.
0224     if (Sum < BranchProbability::getDenominator())
0225       ProbForUnknown = BranchProbability::getRaw(
0226           (BranchProbability::getDenominator() - Sum) / UnknownProbCount);
0227 
0228     std::replace_if(Begin, End,
0229                     [](const BranchProbability &BP) { return BP.isUnknown(); },
0230                     ProbForUnknown);
0231 
0232     if (Sum <= BranchProbability::getDenominator())
0233       return;
0234   }
0235 
0236   if (Sum == 0) {
0237     BranchProbability BP(1, std::distance(Begin, End));
0238     std::fill(Begin, End, BP);
0239     return;
0240   }
0241 
0242   for (auto I = Begin; I != End; ++I)
0243     I->N = (I->N * uint64_t(D) + Sum / 2) / Sum;
0244 }
0245 
0246 }
0247 
0248 #endif