aboutsummaryrefslogtreecommitdiffstats
path: root/lib/Transforms/Utils
diff options
context:
space:
mode:
authorNick Lewycky <nicholas@mxc.ca>2012-01-25 09:43:14 +0000
committerNick Lewycky <nicholas@mxc.ca>2012-01-25 09:43:14 +0000
commit6c00c6a181934afc8bc23db3033e46a3ff93602b (patch)
treeb7351f45458205630ac60e1969d40855399b41eb /lib/Transforms/Utils
parent6977e79526a6342b9408c358bca21c3b0c7e61d5 (diff)
downloadexternal_llvm-6c00c6a181934afc8bc23db3033e46a3ff93602b.zip
external_llvm-6c00c6a181934afc8bc23db3033e46a3ff93602b.tar.gz
external_llvm-6c00c6a181934afc8bc23db3033e46a3ff93602b.tar.bz2
Gracefully degrade precision in branch probability numbers.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@148946 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Transforms/Utils')
-rw-r--r--lib/Transforms/Utils/SimplifyCFG.cpp89
1 files changed, 72 insertions, 17 deletions
diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp
index 326bd7a..d5ae609 100644
--- a/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1466,6 +1466,29 @@ static bool ExtractBranchMetadata(BranchInst *BI,
return true;
}
+/// MultiplyAndLosePrecision - Multiplies A and B, then returns the result. In
+/// the event of overflow, logically-shifts all four inputs right until the
+/// multiply fits.
+static APInt MultiplyAndLosePrecision(APInt &A, APInt &B, APInt &C, APInt &D,
+ unsigned &BitsLost) {
+ BitsLost = 0;
+ bool Overflow = false;
+ APInt Result = A.umul_ov(B, Overflow);
+ if (Overflow) {
+ APInt MaxB = APInt::getMaxValue(A.getBitWidth()).udiv(A);
+ do {
+ B = B.lshr(1);
+ ++BitsLost;
+ } while (B.ugt(MaxB));
+ A = A.lshr(BitsLost);
+ C = C.lshr(BitsLost);
+ D = D.lshr(BitsLost);
+ Result = A * B;
+ }
+ return Result;
+}
+
+
/// FoldBranchToCommonDest - If this basic block is simple enough, and if a
/// predecessor branches to us and one of our successors, fold the block into
/// the predecessor and use logical operations to pick the right destination.
@@ -1665,32 +1688,64 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI) {
// we get:
// (a*c)% = A*C, (b+(a*d))% = A*D+B*C+B*D.
- bool Overflow1 = false, Overflow2 = false, Overflow3 = false;
- bool Overflow4 = false, Overflow5 = false, Overflow6 = false;
- APInt ProbTrue = A.umul_ov(C, Overflow1);
+ // In the event of overflow, we want to drop the LSB of the input
+ // probabilities.
+ unsigned BitsLost;
- APInt Tmp1 = A.umul_ov(D, Overflow2);
- APInt Tmp2 = B.umul_ov(C, Overflow3);
- APInt Tmp3 = B.umul_ov(D, Overflow4);
- APInt Tmp4 = Tmp1.uadd_ov(Tmp2, Overflow5);
- APInt ProbFalse = Tmp4.uadd_ov(Tmp3, Overflow6);
+ // Ignore overflow result on ProbTrue.
+ APInt ProbTrue = MultiplyAndLosePrecision(A, C, B, D, BitsLost);
- APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse);
- ProbTrue = ProbTrue.udiv(GCD);
- ProbFalse = ProbFalse.udiv(GCD);
+ APInt Tmp1 = MultiplyAndLosePrecision(B, D, A, C, BitsLost);
+ if (BitsLost) {
+ ProbTrue = ProbTrue.lshr(BitsLost*2);
+ }
+
+ APInt Tmp2 = MultiplyAndLosePrecision(A, D, C, B, BitsLost);
+ if (BitsLost) {
+ ProbTrue = ProbTrue.lshr(BitsLost*2);
+ Tmp1 = Tmp1.lshr(BitsLost*2);
+ }
+
+ APInt Tmp3 = MultiplyAndLosePrecision(B, C, A, D, BitsLost);
+ if (BitsLost) {
+ ProbTrue = ProbTrue.lshr(BitsLost*2);
+ Tmp1 = Tmp1.lshr(BitsLost*2);
+ Tmp2 = Tmp2.lshr(BitsLost*2);
+ }
+
+ bool Overflow1 = false, Overflow2 = false;
+ APInt Tmp4 = Tmp2.uadd_ov(Tmp3, Overflow1);
+ APInt ProbFalse = Tmp4.uadd_ov(Tmp1, Overflow2);
+
+ if (Overflow1 || Overflow2) {
+ ProbTrue = ProbTrue.lshr(1);
+ Tmp1 = Tmp1.lshr(1);
+ Tmp2 = Tmp2.lshr(1);
+ Tmp3 = Tmp3.lshr(1);
+ Tmp4 = Tmp2 + Tmp3;
+ ProbFalse = Tmp4 + Tmp1;
+ }
+
+ // The sum of branch weights must fit in 32-bits.
+ if (ProbTrue.isNegative() && ProbFalse.isNegative()) {
+ ProbTrue = ProbTrue.lshr(1);
+ ProbFalse = ProbFalse.lshr(1);
+ }
+
+ if (ProbTrue != ProbFalse) {
+ // Normalize the result.
+ APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse);
+ ProbTrue = ProbTrue.udiv(GCD);
+ ProbFalse = ProbFalse.udiv(GCD);
- if (Overflow1 || Overflow2 || Overflow3 || Overflow4 || Overflow5 ||
- Overflow6) {
- DEBUG(dbgs() << "Overflow recomputing branch weight on: " << *PBI
- << "when merging with: " << *BI);
- PBI->setMetadata(LLVMContext::MD_prof, NULL);
- } else {
LLVMContext &Context = BI->getContext();
Value *Ops[3];
Ops[0] = BI->getMetadata(LLVMContext::MD_prof)->getOperand(0);
Ops[1] = ConstantInt::get(Context, ProbTrue);
Ops[2] = ConstantInt::get(Context, ProbFalse);
PBI->setMetadata(LLVMContext::MD_prof, MDNode::get(Context, Ops));
+ } else {
+ PBI->setMetadata(LLVMContext::MD_prof, NULL);
}
} else {
PBI->setMetadata(LLVMContext::MD_prof, NULL);