diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineMulDivRem.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 174 |
1 files changed, 138 insertions, 36 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 8c48dce..c48e3c9 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -#include "InstCombine.h" +#include "InstCombineInternal.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" @@ -46,10 +46,10 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it // inexact. Similarly for <<. if (BinaryOperator *I = dyn_cast<BinaryOperator>(V)) - if (I->isLogicalShift() && isKnownToBeAPowerOfTwo(I->getOperand(0), false, - 0, IC.getAssumptionTracker(), - CxtI, - IC.getDominatorTree())) { + if (I->isLogicalShift() && + isKnownToBeAPowerOfTwo(I->getOperand(0), false, 0, + IC.getAssumptionCache(), CxtI, + IC.getDominatorTree())) { // We know that this is an exact/nuw shift and that the input is a // non-zero context as well. if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { @@ -123,6 +123,48 @@ static Constant *getLogBase2Vector(ConstantDataVector *CV) { return ConstantVector::get(Elts); } +/// \brief Return true if we can prove that: +/// (mul LHS, RHS) === (mul nsw LHS, RHS) +bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS, + Instruction *CxtI) { + // Multiplying n * m significant bits yields a result of n + m significant + // bits. If the total number of significant bits does not exceed the + // result bit width (minus 1), there is no overflow. + // This means if we have enough leading sign bits in the operands + // we can guarantee that the result does not overflow. + // Ref: "Hacker's Delight" by Henry Warren + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + + // Note that underestimating the number of sign bits gives a more + // conservative answer. + unsigned SignBits = ComputeNumSignBits(LHS, 0, CxtI) + + ComputeNumSignBits(RHS, 0, CxtI); + + // First handle the easy case: if we have enough sign bits there's + // definitely no overflow. + if (SignBits > BitWidth + 1) + return true; + + // There are two ambiguous cases where there can be no overflow: + // SignBits == BitWidth + 1 and + // SignBits == BitWidth + // The second case is difficult to check, therefore we only handle the + // first case. + if (SignBits == BitWidth + 1) { + // It overflows only when both arguments are negative and the true + // product is exactly the minimum negative number. + // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 + // For simplicity we just check if at least one side is not negative. + bool LHSNonNegative, LHSNegative; + bool RHSNonNegative, RHSNegative; + ComputeSignBit(LHS, LHSNonNegative, LHSNegative, /*Depth=*/0, CxtI); + ComputeSignBit(RHS, RHSNonNegative, RHSNegative, /*Depth=*/0, CxtI); + if (LHSNonNegative || RHSNonNegative) + return true; + } + return false; +} + Instruction *InstCombiner::visitMul(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -130,14 +172,19 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AT)) + if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (Value *V = SimplifyUsingDistributiveLaws(I)) return ReplaceInstUsesWith(I, V); - if (match(Op1, m_AllOnes())) // X * -1 == 0 - X - return BinaryOperator::CreateNeg(Op0, I.getName()); + // X * -1 == 0 - X + if (match(Op1, m_AllOnes())) { + BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName()); + if (I.hasNoSignedWrap()) + BO->setHasNoSignedWrap(); + return BO; + } // Also allow combining multiply instructions on vectors. { @@ -146,9 +193,18 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { const APInt *IVal; if (match(&I, m_Mul(m_Shl(m_Value(NewOp), m_Constant(C2)), m_Constant(C1))) && - match(C1, m_APInt(IVal))) - // ((X << C1)*C2) == (X * (C2 << C1)) - return BinaryOperator::CreateMul(NewOp, ConstantExpr::getShl(C1, C2)); + match(C1, m_APInt(IVal))) { + // ((X << C2)*C1) == (X * (C1 << C2)) + Constant *Shl = ConstantExpr::getShl(C1, C2); + BinaryOperator *Mul = cast<BinaryOperator>(I.getOperand(0)); + BinaryOperator *BO = BinaryOperator::CreateMul(NewOp, Shl); + if (I.hasNoUnsignedWrap() && Mul->hasNoUnsignedWrap()) + BO->setHasNoUnsignedWrap(); + if (I.hasNoSignedWrap() && Mul->hasNoSignedWrap() && + Shl->isNotMinSignedValue()) + BO->setHasNoSignedWrap(); + return BO; + } if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) { Constant *NewCst = nullptr; @@ -165,6 +221,8 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (I.hasNoUnsignedWrap()) Shl->setHasNoUnsignedWrap(); + if (I.hasNoSignedWrap() && NewCst->isNotMinSignedValue()) + Shl->setHasNoSignedWrap(); return Shl; } @@ -221,9 +279,16 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } - if (Value *Op0v = dyn_castNegVal(Op0)) // -X * -Y = X*Y - if (Value *Op1v = dyn_castNegVal(Op1)) - return BinaryOperator::CreateMul(Op0v, Op1v); + if (Value *Op0v = dyn_castNegVal(Op0)) { // -X * -Y = X*Y + if (Value *Op1v = dyn_castNegVal(Op1)) { + BinaryOperator *BO = BinaryOperator::CreateMul(Op0v, Op1v); + if (I.hasNoSignedWrap() && + match(Op0, m_NSWSub(m_Value(), m_Value())) && + match(Op1, m_NSWSub(m_Value(), m_Value()))) + BO->setHasNoSignedWrap(); + return BO; + } + } // (X / Y) * Y = X - (X % Y) // (X / Y) * -Y = (X % Y) - X @@ -272,10 +337,22 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { // (1 << Y)*X --> X << Y { Value *Y; - if (match(Op0, m_Shl(m_One(), m_Value(Y)))) - return BinaryOperator::CreateShl(Op1, Y); - if (match(Op1, m_Shl(m_One(), m_Value(Y)))) - return BinaryOperator::CreateShl(Op0, Y); + BinaryOperator *BO = nullptr; + bool ShlNSW = false; + if (match(Op0, m_Shl(m_One(), m_Value(Y)))) { + BO = BinaryOperator::CreateShl(Op1, Y); + ShlNSW = cast<ShlOperator>(Op0)->hasNoSignedWrap(); + } else if (match(Op1, m_Shl(m_One(), m_Value(Y)))) { + BO = BinaryOperator::CreateShl(Op0, Y); + ShlNSW = cast<ShlOperator>(Op1)->hasNoSignedWrap(); + } + if (BO) { + if (I.hasNoUnsignedWrap()) + BO->setHasNoUnsignedWrap(); + if (I.hasNoSignedWrap() && ShlNSW) + BO->setHasNoSignedWrap(); + return BO; + } } // If one of the operands of the multiply is a cast from a boolean value, then @@ -298,6 +375,18 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + if (!I.hasNoSignedWrap() && WillNotOverflowSignedMul(Op0, Op1, &I)) { + Changed = true; + I.setHasNoSignedWrap(true); + } + + if (!I.hasNoUnsignedWrap() && + computeOverflowForUnsignedMul(Op0, Op1, &I) == + OverflowResult::NeverOverflows) { + Changed = true; + I.setHasNoUnsignedWrap(true); + } + return Changed ? &I : nullptr; } @@ -441,8 +530,8 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (isa<Constant>(Op0)) std::swap(Op0, Op1); - if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, - DT, AT)) + if (Value *V = + SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); bool AllowReassociate = I.hasUnsafeAlgebra(); @@ -946,7 +1035,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AT)) + if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Handle the integer div common cases @@ -961,9 +1050,14 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { match(Op1, m_APInt(C2))) { bool Overflow; APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow); - if (!Overflow) - return BinaryOperator::CreateUDiv( + if (!Overflow) { + bool IsExact = I.isExact() && match(Op0, m_Exact(m_Value())); + BinaryOperator *BO = BinaryOperator::CreateUDiv( X, ConstantInt::get(X->getType(), C2ShlC1)); + if (IsExact) + BO->setIsExact(); + return BO; + } } } @@ -1014,7 +1108,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AT)) + if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Handle the integer div common cases @@ -1041,10 +1135,12 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { return new ZExtInst(Builder->CreateICmpEQ(Op0, Op1), I.getType()); // -X/C --> X/-C provided the negation doesn't overflow. - if (SubOperator *Sub = dyn_cast<SubOperator>(Op0)) - if (match(Sub->getOperand(0), m_Zero()) && Sub->hasNoSignedWrap()) - return BinaryOperator::CreateSDiv(Sub->getOperand(1), - ConstantExpr::getNeg(RHS)); + Value *X; + if (match(Op0, m_NSWSub(m_Zero(), m_Value(X)))) { + auto *BO = BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(RHS)); + BO->setIsExact(I.isExact()); + return BO; + } } // If the sign bits of both operands are zero (i.e. we can prove they are @@ -1054,15 +1150,19 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { if (MaskedValueIsZero(Op0, Mask, 0, &I)) { if (MaskedValueIsZero(Op1, Mask, 0, &I)) { // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set - return BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + BO->setIsExact(I.isExact()); + return BO; } - if (match(Op1, m_Shl(m_Power2(), m_Value()))) { + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, AC, &I, DT)) { // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y) // Safe because the only negative value (1 << Y) can take on is // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have // the sign bit set. - return BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + BO->setIsExact(I.isExact()); + return BO; } } } @@ -1106,7 +1206,8 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFDivInst(Op0, Op1, DL, TLI, DT, AT)) + if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(), + DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (isa<Constant>(Op0)) @@ -1271,7 +1372,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AT)) + if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (Instruction *common = commonIRemTransforms(I)) @@ -1284,7 +1385,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { I.getType()); // X urem Y -> X and Y-1, where Y is a power of 2, - if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true, 0, AT, &I, DT)) { + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, AC, &I, DT)) { Constant *N1 = Constant::getAllOnesValue(I.getType()); Value *Add = Builder->CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); @@ -1306,7 +1407,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AT)) + if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Handle the integer rem common cases @@ -1381,7 +1482,8 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFRemInst(Op0, Op1, DL, TLI, DT, AT)) + if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(), + DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) |