aboutsummaryrefslogtreecommitdiffstats
path: root/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineMulDivRem.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineMulDivRem.cpp174
1 files changed, 138 insertions, 36 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 8c48dce..c48e3c9 100644
--- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -12,7 +12,7 @@
//
//===----------------------------------------------------------------------===//
-#include "InstCombine.h"
+#include "InstCombineInternal.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
@@ -46,10 +46,10 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC,
// (PowerOfTwo >>u B) --> isExact since shifting out the result would make it
// inexact. Similarly for <<.
if (BinaryOperator *I = dyn_cast<BinaryOperator>(V))
- if (I->isLogicalShift() && isKnownToBeAPowerOfTwo(I->getOperand(0), false,
- 0, IC.getAssumptionTracker(),
- CxtI,
- IC.getDominatorTree())) {
+ if (I->isLogicalShift() &&
+ isKnownToBeAPowerOfTwo(I->getOperand(0), false, 0,
+ IC.getAssumptionCache(), CxtI,
+ IC.getDominatorTree())) {
// We know that this is an exact/nuw shift and that the input is a
// non-zero context as well.
if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) {
@@ -123,6 +123,48 @@ static Constant *getLogBase2Vector(ConstantDataVector *CV) {
return ConstantVector::get(Elts);
}
+/// \brief Return true if we can prove that the signed multiply cannot wrap:
+/// (mul LHS, RHS) === (mul nsw LHS, RHS)
+bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS,
+ Instruction *CxtI) {
+ // Multiplying n * m significant bits yields a result of at most n + m
+ // significant bits. If the total number of significant bits does not
+ // exceed the result bit width (minus 1), there is no overflow.
+ // This means if we have enough leading sign bits in the operands
+ // we can guarantee that the result does not overflow.
+ // Ref: "Hacker's Delight" by Henry Warren
+ unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
+
+ // Note that underestimating the number of sign bits gives a more
+ // conservative answer: it can only turn a 'true' result into 'false'.
+ unsigned SignBits = ComputeNumSignBits(LHS, 0, CxtI) +
+ ComputeNumSignBits(RHS, 0, CxtI);
+
+ // First handle the easy case: if the operands together have more than
+ // BitWidth + 1 sign bits there's definitely no overflow.
+ if (SignBits > BitWidth + 1)
+ return true;
+
+ // There are two ambiguous cases where there can be no overflow:
+ // SignBits == BitWidth + 1 and
+ // SignBits == BitWidth
+ // The second case is difficult to check, therefore we only handle the
+ // first case; every other input falls through to 'return false'.
+ if (SignBits == BitWidth + 1) {
+ // It overflows only when both arguments are negative and the true
+ // product is exactly 2^(BitWidth-1), which wraps to the minimum value.
+ // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
+ // For simplicity we just check if ComputeSignBit reports at least one
+ bool LHSNonNegative, LHSNegative;
+ bool RHSNonNegative, RHSNegative;
+ ComputeSignBit(LHS, LHSNonNegative, LHSNegative, /*Depth=*/0, CxtI);
+ ComputeSignBit(RHS, RHSNonNegative, RHSNegative, /*Depth=*/0, CxtI);
+ if (LHSNonNegative || RHSNonNegative)
+ return true;
+ }
+ return false;
+}
+
Instruction *InstCombiner::visitMul(BinaryOperator &I) {
bool Changed = SimplifyAssociativeOrCommutative(I);
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -130,14 +172,19 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return ReplaceInstUsesWith(I, V);
- if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AT))
+ if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AC))
return ReplaceInstUsesWith(I, V);
if (Value *V = SimplifyUsingDistributiveLaws(I))
return ReplaceInstUsesWith(I, V);
- if (match(Op1, m_AllOnes())) // X * -1 == 0 - X
- return BinaryOperator::CreateNeg(Op0, I.getName());
+ // X * -1 == 0 - X
+ if (match(Op1, m_AllOnes())) {
+ BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName());
+ if (I.hasNoSignedWrap())
+ BO->setHasNoSignedWrap();
+ return BO;
+ }
// Also allow combining multiply instructions on vectors.
{
@@ -146,9 +193,18 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
const APInt *IVal;
if (match(&I, m_Mul(m_Shl(m_Value(NewOp), m_Constant(C2)),
m_Constant(C1))) &&
- match(C1, m_APInt(IVal)))
- // ((X << C1)*C2) == (X * (C2 << C1))
- return BinaryOperator::CreateMul(NewOp, ConstantExpr::getShl(C1, C2));
+ match(C1, m_APInt(IVal))) {
+ // ((X << C2)*C1) == (X * (C1 << C2))
+ Constant *Shl = ConstantExpr::getShl(C1, C2);
+ BinaryOperator *Mul = cast<BinaryOperator>(I.getOperand(0));
+ BinaryOperator *BO = BinaryOperator::CreateMul(NewOp, Shl);
+ if (I.hasNoUnsignedWrap() && Mul->hasNoUnsignedWrap())
+ BO->setHasNoUnsignedWrap();
+ if (I.hasNoSignedWrap() && Mul->hasNoSignedWrap() &&
+ Shl->isNotMinSignedValue())
+ BO->setHasNoSignedWrap();
+ return BO;
+ }
if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) {
Constant *NewCst = nullptr;
@@ -165,6 +221,8 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
if (I.hasNoUnsignedWrap())
Shl->setHasNoUnsignedWrap();
+ if (I.hasNoSignedWrap() && NewCst->isNotMinSignedValue())
+ Shl->setHasNoSignedWrap();
return Shl;
}
@@ -221,9 +279,16 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
}
}
- if (Value *Op0v = dyn_castNegVal(Op0)) // -X * -Y = X*Y
- if (Value *Op1v = dyn_castNegVal(Op1))
- return BinaryOperator::CreateMul(Op0v, Op1v);
+ if (Value *Op0v = dyn_castNegVal(Op0)) { // -X * -Y = X*Y
+ if (Value *Op1v = dyn_castNegVal(Op1)) {
+ BinaryOperator *BO = BinaryOperator::CreateMul(Op0v, Op1v);
+ if (I.hasNoSignedWrap() &&
+ match(Op0, m_NSWSub(m_Value(), m_Value())) &&
+ match(Op1, m_NSWSub(m_Value(), m_Value())))
+ BO->setHasNoSignedWrap();
+ return BO;
+ }
+ }
// (X / Y) * Y = X - (X % Y)
// (X / Y) * -Y = (X % Y) - X
@@ -272,10 +337,22 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
// (1 << Y)*X --> X << Y
{
Value *Y;
- if (match(Op0, m_Shl(m_One(), m_Value(Y))))
- return BinaryOperator::CreateShl(Op1, Y);
- if (match(Op1, m_Shl(m_One(), m_Value(Y))))
- return BinaryOperator::CreateShl(Op0, Y);
+ BinaryOperator *BO = nullptr;
+ bool ShlNSW = false;
+ if (match(Op0, m_Shl(m_One(), m_Value(Y)))) {
+ BO = BinaryOperator::CreateShl(Op1, Y);
+ ShlNSW = cast<ShlOperator>(Op0)->hasNoSignedWrap();
+ } else if (match(Op1, m_Shl(m_One(), m_Value(Y)))) {
+ BO = BinaryOperator::CreateShl(Op0, Y);
+ ShlNSW = cast<ShlOperator>(Op1)->hasNoSignedWrap();
+ }
+ if (BO) {
+ if (I.hasNoUnsignedWrap())
+ BO->setHasNoUnsignedWrap();
+ if (I.hasNoSignedWrap() && ShlNSW)
+ BO->setHasNoSignedWrap();
+ return BO;
+ }
}
// If one of the operands of the multiply is a cast from a boolean value, then
@@ -298,6 +375,18 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
}
}
+ if (!I.hasNoSignedWrap() && WillNotOverflowSignedMul(Op0, Op1, &I)) {
+ Changed = true;
+ I.setHasNoSignedWrap(true);
+ }
+
+ if (!I.hasNoUnsignedWrap() &&
+ computeOverflowForUnsignedMul(Op0, Op1, &I) ==
+ OverflowResult::NeverOverflows) {
+ Changed = true;
+ I.setHasNoUnsignedWrap(true);
+ }
+
return Changed ? &I : nullptr;
}
@@ -441,8 +530,8 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
if (isa<Constant>(Op0))
std::swap(Op0, Op1);
- if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI,
- DT, AT))
+ if (Value *V =
+ SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC))
return ReplaceInstUsesWith(I, V);
bool AllowReassociate = I.hasUnsafeAlgebra();
@@ -946,7 +1035,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return ReplaceInstUsesWith(I, V);
- if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AT))
+ if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AC))
return ReplaceInstUsesWith(I, V);
// Handle the integer div common cases
@@ -961,9 +1050,14 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
match(Op1, m_APInt(C2))) {
bool Overflow;
APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow);
- if (!Overflow)
- return BinaryOperator::CreateUDiv(
+ if (!Overflow) {
+ bool IsExact = I.isExact() && match(Op0, m_Exact(m_Value()));
+ BinaryOperator *BO = BinaryOperator::CreateUDiv(
X, ConstantInt::get(X->getType(), C2ShlC1));
+ if (IsExact)
+ BO->setIsExact();
+ return BO;
+ }
}
}
@@ -1014,7 +1108,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return ReplaceInstUsesWith(I, V);
- if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AT))
+ if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AC))
return ReplaceInstUsesWith(I, V);
// Handle the integer div common cases
@@ -1041,10 +1135,12 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
return new ZExtInst(Builder->CreateICmpEQ(Op0, Op1), I.getType());
// -X/C --> X/-C provided the negation doesn't overflow.
- if (SubOperator *Sub = dyn_cast<SubOperator>(Op0))
- if (match(Sub->getOperand(0), m_Zero()) && Sub->hasNoSignedWrap())
- return BinaryOperator::CreateSDiv(Sub->getOperand(1),
- ConstantExpr::getNeg(RHS));
+ Value *X;
+ if (match(Op0, m_NSWSub(m_Zero(), m_Value(X)))) {
+ auto *BO = BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(RHS));
+ BO->setIsExact(I.isExact());
+ return BO;
+ }
}
// If the sign bits of both operands are zero (i.e. we can prove they are
@@ -1054,15 +1150,19 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
if (MaskedValueIsZero(Op0, Mask, 0, &I)) {
if (MaskedValueIsZero(Op1, Mask, 0, &I)) {
// X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set
- return BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
+ auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
+ BO->setIsExact(I.isExact());
+ return BO;
}
- if (match(Op1, m_Shl(m_Power2(), m_Value()))) {
+ if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, AC, &I, DT)) {
// X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y)
// Safe because the only negative value (1 << Y) can take on is
// INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have
// the sign bit set.
- return BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
+ auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
+ BO->setIsExact(I.isExact());
+ return BO;
}
}
}
@@ -1106,7 +1206,8 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return ReplaceInstUsesWith(I, V);
- if (Value *V = SimplifyFDivInst(Op0, Op1, DL, TLI, DT, AT))
+ if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(),
+ DL, TLI, DT, AC))
return ReplaceInstUsesWith(I, V);
if (isa<Constant>(Op0))
@@ -1271,7 +1372,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return ReplaceInstUsesWith(I, V);
- if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AT))
+ if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AC))
return ReplaceInstUsesWith(I, V);
if (Instruction *common = commonIRemTransforms(I))
@@ -1284,7 +1385,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) {
I.getType());
// X urem Y -> X and Y-1, where Y is a power of 2,
- if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true, 0, AT, &I, DT)) {
+ if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, AC, &I, DT)) {
Constant *N1 = Constant::getAllOnesValue(I.getType());
Value *Add = Builder->CreateAdd(Op1, N1);
return BinaryOperator::CreateAnd(Op0, Add);
@@ -1306,7 +1407,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return ReplaceInstUsesWith(I, V);
- if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AT))
+ if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AC))
return ReplaceInstUsesWith(I, V);
// Handle the integer rem common cases
@@ -1381,7 +1482,8 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return ReplaceInstUsesWith(I, V);
- if (Value *V = SimplifyFRemInst(Op0, Op1, DL, TLI, DT, AT))
+ if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(),
+ DL, TLI, DT, AC))
return ReplaceInstUsesWith(I, V);
// Handle cases involving: rem X, (select Cond, Y, Z)