Andrew Scull | 5e1ddfa | 2018-08-14 10:06:54 +0100 | [diff] [blame] | 1 | //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===// |
| 2 | // |
Andrew Walbran | 16937d0 | 2019-10-22 13:54:20 +0100 | [diff] [blame^] | 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
Andrew Scull | 5e1ddfa | 2018-08-14 10:06:54 +0100 | [diff] [blame] | 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Definition of BranchProbability shared by IR and Machine Instructions. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H |
| 14 | #define LLVM_SUPPORT_BRANCHPROBABILITY_H |
| 15 | |
| 16 | #include "llvm/Support/DataTypes.h" |
| 17 | #include <algorithm> |
| 18 | #include <cassert> |
| 19 | #include <climits> |
| 20 | #include <numeric> |
| 21 | |
| 22 | namespace llvm { |
| 23 | |
class raw_ostream;

// This class represents Branch Probability as a non-negative fraction that is
// no greater than 1. It uses a fixed-point-like implementation, in which the
// denominator is always a constant value (here we use 1<<31 for maximum
// precision). The numerator UINT32_MAX is reserved as a sentinel meaning
// "unknown"; arithmetic and ordering comparisons on an unknown value are
// rejected by assertions.
class BranchProbability {
  // Numerator
  uint32_t N;

  // Denominator, which is a constant value.
  static const uint32_t D = 1u << 31;
  // Sentinel numerator for an unknown probability. It is larger than D, so it
  // is never the numerator of a valid probability.
  static const uint32_t UnknownN = UINT32_MAX;

  // Construct a BranchProbability with only numerator assuming the denominator
  // is 1<<31. For internal use only.
  explicit BranchProbability(uint32_t n) : N(n) {}

public:
  // A default-constructed probability is unknown.
  BranchProbability() : N(UnknownN) {}
  // Construct Numerator / Denominator (implemented out of line).
  BranchProbability(uint32_t Numerator, uint32_t Denominator);

  bool isZero() const { return N == 0; }
  bool isUnknown() const { return N == UnknownN; }

  static BranchProbability getZero() { return BranchProbability(0); }
  static BranchProbability getOne() { return BranchProbability(D); }
  static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
  // Create a BranchProbability object with the given numerator and 1<<31
  // as denominator.
  static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
  // Create a BranchProbability object from 64-bit integers.
  static BranchProbability getBranchProbability(uint64_t Numerator,
                                                uint64_t Denominator);

  // Normalize given probabilities so that the sum of them becomes
  // approximately one.
  template <class ProbabilityIter>
  static void normalizeProbabilities(ProbabilityIter Begin,
                                     ProbabilityIter End);

  uint32_t getNumerator() const { return N; }
  static uint32_t getDenominator() { return D; }

  // Return (1 - Probability).
  BranchProbability getCompl() const {
    // Without this check an unknown numerator would silently wrap around to
    // D + 1, i.e. a "probability" greater than one.
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    return BranchProbability(D - N);
  }

  raw_ostream &print(raw_ostream &OS) const;

  void dump() const;

  /// Scale a large integer.
  ///
  /// Scales \c Num. Guarantees full precision. Returns the floor of the
  /// result.
  ///
  /// \return \c Num times \c this.
  uint64_t scale(uint64_t Num) const;

  /// Scale a large integer by the inverse.
  ///
  /// Scales \c Num by the inverse of \c this. Guarantees full precision.
  /// Returns the floor of the result.
  ///
  /// \return \c Num divided by \c this.
  uint64_t scaleByInverse(uint64_t Num) const;

  BranchProbability &operator+=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate the result in case of overflow.
    N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N;
    return *this;
  }

  BranchProbability &operator-=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate the result in case of underflow.
    N = N < RHS.N ? 0 : N - RHS.N;
    return *this;
  }

  BranchProbability &operator*=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Product of two fixed-point values, rounded to nearest; computed in 64
    // bits because N * RHS.N can reach D * D.
    N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
    return *this;
  }

  BranchProbability &operator*=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate at one in case of overflow; the 32-bit product is only taken
    // once it is known to be <= D, so it cannot wrap.
    N = (uint64_t(N) * RHS > D) ? D : N * RHS;
    return *this;
  }

  BranchProbability &operator/=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    assert(RHS > 0 && "The divider cannot be zero.");
    N /= RHS;
    return *this;
  }

  BranchProbability operator+(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    return Prob += RHS;
  }

  BranchProbability operator-(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    return Prob -= RHS;
  }

  BranchProbability operator*(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    return Prob *= RHS;
  }

  BranchProbability operator*(uint32_t RHS) const {
    BranchProbability Prob(*this);
    return Prob *= RHS;
  }

  BranchProbability operator/(uint32_t RHS) const {
    BranchProbability Prob(*this);
    return Prob /= RHS;
  }

  // Equality compares raw numerators, so two unknown values compare equal.
  bool operator==(BranchProbability RHS) const { return N == RHS.N; }
  bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }

  bool operator<(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return N < RHS.N;
  }

  bool operator>(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return RHS < *this;
  }

  bool operator<=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return !(RHS < *this);
  }

  bool operator>=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return !(*this < RHS);
  }
};
| 181 | |
| 182 | inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) { |
| 183 | return Prob.print(OS); |
| 184 | } |
| 185 | |
| 186 | template <class ProbabilityIter> |
| 187 | void BranchProbability::normalizeProbabilities(ProbabilityIter Begin, |
| 188 | ProbabilityIter End) { |
| 189 | if (Begin == End) |
| 190 | return; |
| 191 | |
| 192 | unsigned UnknownProbCount = 0; |
| 193 | uint64_t Sum = std::accumulate(Begin, End, uint64_t(0), |
| 194 | [&](uint64_t S, const BranchProbability &BP) { |
| 195 | if (!BP.isUnknown()) |
| 196 | return S + BP.N; |
| 197 | UnknownProbCount++; |
| 198 | return S; |
| 199 | }); |
| 200 | |
| 201 | if (UnknownProbCount > 0) { |
| 202 | BranchProbability ProbForUnknown = BranchProbability::getZero(); |
| 203 | // If the sum of all known probabilities is less than one, evenly distribute |
| 204 | // the complement of sum to unknown probabilities. Otherwise, set unknown |
| 205 | // probabilities to zeros and continue to normalize known probabilities. |
| 206 | if (Sum < BranchProbability::getDenominator()) |
| 207 | ProbForUnknown = BranchProbability::getRaw( |
| 208 | (BranchProbability::getDenominator() - Sum) / UnknownProbCount); |
| 209 | |
| 210 | std::replace_if(Begin, End, |
| 211 | [](const BranchProbability &BP) { return BP.isUnknown(); }, |
| 212 | ProbForUnknown); |
| 213 | |
| 214 | if (Sum <= BranchProbability::getDenominator()) |
| 215 | return; |
| 216 | } |
| 217 | |
| 218 | if (Sum == 0) { |
| 219 | BranchProbability BP(1, std::distance(Begin, End)); |
| 220 | std::fill(Begin, End, BP); |
| 221 | return; |
| 222 | } |
| 223 | |
| 224 | for (auto I = Begin; I != End; ++I) |
| 225 | I->N = (I->N * uint64_t(D) + Sum / 2) / Sum; |
| 226 | } |
| 227 | |
| 228 | } |
| 229 | |
| 230 | #endif |