Diffstat (limited to 'lib/Transforms/InstCombine')
 lib/Transforms/InstCombine/InstCombine.h                    |  54
 lib/Transforms/InstCombine/InstCombineAddSub.cpp            | 188
 lib/Transforms/InstCombine/InstCombineAndOrXor.cpp          | 372
 lib/Transforms/InstCombine/InstCombineCalls.cpp             | 193
 lib/Transforms/InstCombine/InstCombineCasts.cpp             | 129
 lib/Transforms/InstCombine/InstCombineCompares.cpp          | 227
 lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp   | 192
 lib/Transforms/InstCombine/InstCombineMulDivRem.cpp         | 333
 lib/Transforms/InstCombine/InstCombinePHI.cpp               |  12
 lib/Transforms/InstCombine/InstCombineSelect.cpp            |  45
 lib/Transforms/InstCombine/InstCombineShifts.cpp            |  53
 lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp  |  77
 lib/Transforms/InstCombine/InstCombineWorklist.h            |   4
 lib/Transforms/InstCombine/InstructionCombining.cpp         | 252
14 files changed, 1524 insertions, 607 deletions
diff --git a/lib/Transforms/InstCombine/InstCombine.h b/lib/Transforms/InstCombine/InstCombine.h
index ab4dc1c..d4b252b 100644
--- a/lib/Transforms/InstCombine/InstCombine.h
+++ b/lib/Transforms/InstCombine/InstCombine.h
@@ -7,16 +7,18 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef INSTCOMBINE_INSTCOMBINE_H
-#define INSTCOMBINE_INSTCOMBINE_H
+#ifndef LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINE_H
+#define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINE_H
 
 #include "InstCombineWorklist.h"
+#include "llvm/Analysis/AssumptionTracker.h"
 #include "llvm/Analysis/TargetFolder.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Pass.h"
 #include "llvm/Transforms/Utils/SimplifyLibCalls.h"
 
@@ -25,6 +27,7 @@ namespace llvm {
 class CallSite;
 class DataLayout;
+class DominatorTree;
 class TargetLibraryInfo;
 class DbgDeclareInst;
 class MemIntrinsic;
@@ -71,14 +74,20 @@ static inline Constant *SubOne(Constant *C) {
 class LLVM_LIBRARY_VISIBILITY InstCombineIRInserter
     : public IRBuilderDefaultInserter<true> {
   InstCombineWorklist &Worklist;
+  AssumptionTracker *AT;
 
 public:
-  InstCombineIRInserter(InstCombineWorklist &WL) : Worklist(WL) {}
+  InstCombineIRInserter(InstCombineWorklist &WL, AssumptionTracker *AT)
+      : Worklist(WL), AT(AT) {}
 
   void InsertHelper(Instruction *I, const Twine &Name, BasicBlock *BB,
                     BasicBlock::iterator InsertPt) const {
     IRBuilderDefaultInserter<true>::InsertHelper(I, Name, BB, InsertPt);
     Worklist.Add(I);
+
+    using namespace llvm::PatternMatch;
+    if (match(I, m_Intrinsic<Intrinsic::assume>()))
+      AT->registerAssumption(cast<CallInst>(I));
   }
 };
 
@@ -86,8 +95,10 @@ public:
 class LLVM_LIBRARY_VISIBILITY InstCombiner
     : public FunctionPass,
       public InstVisitor<InstCombiner, Instruction *> {
+  AssumptionTracker *AT;
   const DataLayout *DL;
   TargetLibraryInfo *TLI;
+  DominatorTree *DT; // not required
   bool MadeIRChange;
   LibCallSimplifier *Simplifier;
   bool MinimizeSize;
@@ -114,7 +125,11 @@ public:
 
   void getAnalysisUsage(AnalysisUsage &AU) const override;
 
+  AssumptionTracker *getAssumptionTracker() const { return AT; }
+
   const DataLayout *getDataLayout() const { return DL; }
 
+  DominatorTree *getDominatorTree() const { return DT; }
+
   TargetLibraryInfo *getTargetLibraryInfo() const { return TLI; }
 
@@ -148,10 +163,12 @@ public:
   Value *FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS);
   Value *FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS);
   Instruction *visitAnd(BinaryOperator &I);
-  Value *FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS);
+  Value *FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI);
   Value *FoldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS);
   Instruction *FoldOrWithConstants(BinaryOperator &I, Value *Op, Value *A,
                                    Value *B, Value *C);
+  Instruction *FoldXorWithConstants(BinaryOperator &I, Value *Op, Value *A,
+                                    Value *B, Value *C);
   Instruction *visitOr(BinaryOperator &I);
   Instruction *visitXor(BinaryOperator &I);
   Instruction *visitShl(BinaryOperator &I);
@@ -172,6 +189,10 @@ public:
                             ConstantInt *DivRHS);
   Instruction *FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *DivI,
                               ConstantInt *DivRHS);
+  Instruction *FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A,
+                                 ConstantInt *CI1, ConstantInt *CI2);
+  Instruction *FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A,
+                                 ConstantInt *CI1, ConstantInt *CI2);
   Instruction *FoldICmpAddOpCst(Instruction &ICI, Value *X, ConstantInt *CI,
                                 ICmpInst::Predicate Pred);
   Instruction *FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
@@ -213,6 +234,7 @@ public:
   Instruction *visitStoreInst(StoreInst &SI);
   Instruction *visitBranchInst(BranchInst &BI);
   Instruction *visitSwitchInst(SwitchInst &SI);
+  Instruction *visitReturnInst(ReturnInst &RI);
   Instruction *visitInsertValueInst(InsertValueInst &IV);
   Instruction *visitInsertElementInst(InsertElementInst &IE);
   Instruction *visitExtractElementInst(ExtractElementInst &EI);
@@ -246,8 +268,10 @@ private:
   Instruction *transformZExtICmp(ICmpInst *ICI, Instruction &CI,
                                  bool DoXform = true);
   Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI);
-  bool WillNotOverflowSignedAdd(Value *LHS, Value *RHS);
-  bool WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS);
+  bool WillNotOverflowSignedAdd(Value *LHS, Value *RHS, Instruction *CxtI);
+  bool WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS, Instruction *CxtI);
+  bool WillNotOverflowSignedSub(Value *LHS, Value *RHS, Instruction *CxtI);
+  bool WillNotOverflowUnsignedSub(Value *LHS, Value *RHS, Instruction *CxtI);
   Value *EmitGEPOffset(User *GEP);
   Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN);
   Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask);
@@ -316,16 +340,19 @@ public:
   }
 
   void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne,
-                        unsigned Depth = 0) const {
-    return llvm::computeKnownBits(V, KnownZero, KnownOne, DL, Depth);
+                        unsigned Depth = 0, Instruction *CxtI = nullptr) const {
+    return llvm::computeKnownBits(V, KnownZero, KnownOne, DL, Depth,
+                                  AT, CxtI, DT);
   }
 
   bool MaskedValueIsZero(Value *V, const APInt &Mask,
-                         unsigned Depth = 0) const {
-    return llvm::MaskedValueIsZero(V, Mask, DL, Depth);
+                         unsigned Depth = 0,
+                         Instruction *CxtI = nullptr) const {
+    return llvm::MaskedValueIsZero(V, Mask, DL, Depth, AT, CxtI, DT);
   }
-  unsigned ComputeNumSignBits(Value *Op, unsigned Depth = 0) const {
-    return llvm::ComputeNumSignBits(Op, DL, Depth);
+  unsigned ComputeNumSignBits(Value *Op, unsigned Depth = 0,
+                              Instruction *CxtI = nullptr) const {
+    return llvm::ComputeNumSignBits(Op, DL, Depth, AT, CxtI, DT);
   }
 
 private:
@@ -343,7 +370,8 @@ private:
   /// SimplifyDemandedUseBits - Attempts to replace V with a simpler value
   /// based on the demanded bits.
   Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero,
-                                 APInt &KnownOne, unsigned Depth);
+                                 APInt &KnownOne, unsigned Depth,
+                                 Instruction *CxtI = nullptr);
   bool SimplifyDemandedBits(Use &U, APInt DemandedMask, APInt &KnownZero,
                             APInt &KnownOne, unsigned Depth = 0);
   /// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded
diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 99f0f1f..902b640 100644
--- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -32,7 +32,7 @@ namespace {
   ///
   class FAddendCoef {
   public:
-    // The constructor has to initialize a APFloat, which is uncessary for
+    // The constructor has to initialize a APFloat, which is unnecessary for
    // most addends which have coefficient either 1 or -1. So, the constructor
    // is expensive. In order to avoid the cost of the constructor, we should
    // reuse some instances whenever possible. The pre-created instances
@@ -895,7 +895,8 @@ static bool checkRippleForAdd(const APInt &Op0KnownZero,
 /// This basically requires proving that the add in the original type would not
 /// overflow to change the sign bit or have a carry out.
 /// TODO: Handle this for Vectors.
-bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
+bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,
+                                            Instruction *CxtI) {
   // There are different heuristics we can use for this. Here are some simple
   // ones.
@@ -913,18 +914,19 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
   //
   // Since the carry into the most significant position is always equal to
   // the carry out of the addition, there is no signed overflow.
-  if (ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1)
+  if (ComputeNumSignBits(LHS, 0, CxtI) > 1 &&
+      ComputeNumSignBits(RHS, 0, CxtI) > 1)
     return true;
 
   if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
     int BitWidth = IT->getBitWidth();
     APInt LHSKnownZero(BitWidth, 0);
     APInt LHSKnownOne(BitWidth, 0);
-    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne);
+    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
 
     APInt RHSKnownZero(BitWidth, 0);
     APInt RHSKnownOne(BitWidth, 0);
-    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne);
+    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
 
     // Addition of two 2's compliment numbers having opposite signs will never
     // overflow.
@@ -943,19 +945,69 @@
 /// WillNotOverflowUnsignedAdd - Return true if we can prove that:
 ///    (zext (add LHS, RHS))  ===  (add (zext LHS), (zext RHS))
-bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
+bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS,
+                                              Instruction *CxtI) {
   // There are different heuristics we can use for this. Here is a simple one.
   // If the sign bit of LHS and that of RHS are both zero, no unsigned wrap.
   bool LHSKnownNonNegative, LHSKnownNegative;
   bool RHSKnownNonNegative, RHSKnownNegative;
-  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0);
-  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0);
+  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT);
+  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT);
   if (LHSKnownNonNegative && RHSKnownNonNegative)
     return true;
 
   return false;
 }
 
+/// \brief Return true if we can prove that:
+///    (sub LHS, RHS)  === (sub nsw LHS, RHS)
+/// This basically requires proving that the add in the original type would not
+/// overflow to change the sign bit or have a carry out.
+/// TODO: Handle this for Vectors.
+bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS,
+                                            Instruction *CxtI) {
+  // If LHS and RHS each have at least two sign bits, the subtraction
+  // cannot overflow.
+  if (ComputeNumSignBits(LHS, 0, CxtI) > 1 &&
+      ComputeNumSignBits(RHS, 0, CxtI) > 1)
+    return true;
+
+  if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
+    unsigned BitWidth = IT->getBitWidth();
+    APInt LHSKnownZero(BitWidth, 0);
+    APInt LHSKnownOne(BitWidth, 0);
+    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
+
+    APInt RHSKnownZero(BitWidth, 0);
+    APInt RHSKnownOne(BitWidth, 0);
+    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
+
+    // Subtraction of two 2's complement numbers having identical signs will
+    // never overflow.
+    if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) ||
+        (LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1]))
+      return true;
+
+    // TODO: implement logic similar to checkRippleForAdd
+  }
+  return false;
+}
+
+/// \brief Return true if we can prove that:
+///    (sub LHS, RHS)  === (sub nuw LHS, RHS)
+bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS,
+                                              Instruction *CxtI) {
+  // If the LHS is negative and the RHS is non-negative, no unsigned wrap.
+  bool LHSKnownNonNegative, LHSKnownNegative;
+  bool RHSKnownNonNegative, RHSKnownNegative;
+  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT);
+  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT);
+  if (LHSKnownNegative && RHSKnownNonNegative)
+    return true;
+
+  return false;
+}
+
 // Checks if any operand is negative and we can convert add to sub.
 // This function checks for following negative patterns
 //   ADD(XOR(OR(Z, NOT(C)), C)), 1) == NEG(AND(Z, C))
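The two tests in WillNotOverflowSignedSub — each operand having more than one sign bit, or both operands having the same known sign — are easy to sanity-check exhaustively at a small bit width. The following standalone C++ program (not part of the commit; the 8-bit width is chosen purely for illustration) verifies both claims:

#include <cassert>

int main() {
  for (int lhs = -128; lhs <= 127; ++lhs) {
    for (int rhs = -128; rhs <= 127; ++rhs) {
      int diff = lhs - rhs;                    // exact result, in a wider type
      bool fits = diff >= -128 && diff <= 127; // i.e. no signed i8 overflow

      // ComputeNumSignBits(X) > 1 on i8 means X is in [-64, 63].
      if (lhs >= -64 && lhs <= 63 && rhs >= -64 && rhs <= 63)
        assert(fits);

      // Identical known sign bits: both negative or both non-negative.
      if ((lhs < 0) == (rhs < 0))
        assert(fits);
    }
  }
  return 0;
}

Both assertions hold for every pair, which is why the pass can safely mark such subtractions nsw.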
@@ -1025,7 +1077,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
     return ReplaceInstUsesWith(I, V);
 
   if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(),
-                                 I.hasNoUnsignedWrap(), DL))
+                                 I.hasNoUnsignedWrap(), DL, TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   // (A*B)+(A*C) -> A*(B+C) etc
@@ -1064,7 +1116,7 @@
     if (ExtendAmt) {
       APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt);
-      if (!MaskedValueIsZero(XorLHS, Mask))
+      if (!MaskedValueIsZero(XorLHS, Mask, 0, &I))
         ExtendAmt = 0;
     }
 
@@ -1080,7 +1132,7 @@
       IntegerType *IT = cast<IntegerType>(I.getType());
       APInt LHSKnownOne(IT->getBitWidth(), 0);
       APInt LHSKnownZero(IT->getBitWidth(), 0);
-      computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne);
+      computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne, 0, &I);
       if ((XorRHS->getValue() | LHSKnownZero).isAllOnesValue())
         return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI),
                                          XorLHS);
@@ -1133,11 +1185,11 @@
   if (IntegerType *IT = dyn_cast<IntegerType>(I.getType())) {
     APInt LHSKnownOne(IT->getBitWidth(), 0);
     APInt LHSKnownZero(IT->getBitWidth(), 0);
-    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne);
+    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &I);
     if (LHSKnownZero != 0) {
       APInt RHSKnownOne(IT->getBitWidth(), 0);
       APInt RHSKnownZero(IT->getBitWidth(), 0);
-      computeKnownBits(RHS, RHSKnownZero, RHSKnownOne);
+      computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &I);
 
       // No bits in common -> bitwise or.
       if ((LHSKnownZero|RHSKnownZero).isAllOnesValue())
@@ -1215,7 +1267,7 @@
         ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType());
       if (LHSConv->hasOneUse() &&
          ConstantExpr::getSExt(CI, I.getType()) == RHSC &&
-          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) {
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) {
         // Insert the new, smaller add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               CI, "addconv");
@@ -1231,7 +1283,7 @@
       if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
          (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
           WillNotOverflowSignedAdd(LHSConv->getOperand(0),
-                                   RHSConv->getOperand(0))) {
+                                   RHSConv->getOperand(0), &I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               RHSConv->getOperand(0), "addconv");
@@ -1240,7 +1292,7 @@
     }
   }
 
-  // Check for (x & y) + (x ^ y)
+  // (add (xor A, B) (and A, B)) --> (or A, B)
   {
     Value *A = nullptr, *B = nullptr;
     if (match(RHS, m_Xor(m_Value(A), m_Value(B))) &&
@@ -1254,14 +1306,36 @@
       return BinaryOperator::CreateOr(A, B);
   }
 
+  // (add (or A, B) (and A, B)) --> (add A, B)
+  {
+    Value *A = nullptr, *B = nullptr;
+    if (match(RHS, m_Or(m_Value(A), m_Value(B))) &&
+        (match(LHS, m_And(m_Specific(A), m_Specific(B))) ||
+         match(LHS, m_And(m_Specific(B), m_Specific(A))))) {
+      auto *New = BinaryOperator::CreateAdd(A, B);
+      New->setHasNoSignedWrap(I.hasNoSignedWrap());
+      New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+      return New;
+    }
+
+    if (match(LHS, m_Or(m_Value(A), m_Value(B))) &&
+        (match(RHS, m_And(m_Specific(A), m_Specific(B))) ||
+         match(RHS, m_And(m_Specific(B), m_Specific(A))))) {
+      auto *New = BinaryOperator::CreateAdd(A, B);
+      New->setHasNoSignedWrap(I.hasNoSignedWrap());
+      New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+      return New;
+    }
+  }
+
   // TODO(jingyue): Consider WillNotOverflowSignedAdd and
   // WillNotOverflowUnsignedAdd to reduce the number of invocations of
   // computeKnownBits.
-  if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS)) {
+  if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS, &I)) {
     Changed = true;
     I.setHasNoSignedWrap(true);
   }
-  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedAdd(LHS, RHS)) {
+  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedAdd(LHS, RHS, &I)) {
     Changed = true;
     I.setHasNoUnsignedWrap(true);
   }
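Both visitAdd folds rest on the decomposition A + B = (A ^ B) + 2*(A & B): adding the xor and the and of two values yields their or, and adding the or and the and yields the plain sum. The equalities hold bit-exactly (not just modulo the bit width), which is why the second fold can copy the nsw/nuw flags onto the new add. A quick exhaustive i8 check (standalone, for illustration only):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned a = 0; a < 256; ++a) {
    for (unsigned b = 0; b < 256; ++b) {
      uint8_t A = a, B = b;
      // (add (xor A, B), (and A, B)) --> (or A, B)
      assert(uint8_t((A ^ B) + (A & B)) == uint8_t(A | B));
      // (add (or A, B), (and A, B)) --> (add A, B)
      assert(uint8_t((A | B) + (A & B)) == uint8_t(A + B));
    }
  }
  return 0;
}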
@@ -1276,7 +1350,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
   if (Value *V = SimplifyVectorOp(I))
     return ReplaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL))
+  if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL,
+                                  TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   if (isa<Constant>(RHS)) {
@@ -1318,7 +1393,7 @@
         ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType());
       if (LHSConv->hasOneUse() &&
           ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
-          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) {
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               CI, "addconv");
@@ -1334,7 +1409,7 @@
       if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
          (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
           WillNotOverflowSignedAdd(LHSConv->getOperand(0),
-                                   RHSConv->getOperand(0))) {
+                                   RHSConv->getOperand(0), &I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               RHSConv->getOperand(0),"addconv");
@@ -1356,11 +1431,11 @@
         Z2 = dyn_cast<Constant>(B2); B = B1;
       } else if (match(B1, m_AnyZero()) && match(A2, m_AnyZero())) {
         Z1 = dyn_cast<Constant>(B1); B = B2;
-        Z2 = dyn_cast<Constant>(A2); A = A1; 
+        Z2 = dyn_cast<Constant>(A2); A = A1;
       }
-      
-      if (Z1 && Z2 && 
-          (I.hasNoSignedZeros() || 
+
+      if (Z1 && Z2 &&
+          (I.hasNoSignedZeros() ||
           (Z1->isNegativeZeroValue() && Z2->isNegativeZeroValue()))) {
         return SelectInst::Create(C, A, B);
       }
@@ -1447,7 +1522,6 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS,
   return Builder->CreateIntCast(Result, Ty, true);
 }
 
-
 Instruction *InstCombiner::visitSub(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
@@ -1455,18 +1529,27 @@
     return ReplaceInstUsesWith(I, V);
 
   if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(),
-                                 I.hasNoUnsignedWrap(), DL))
+                                 I.hasNoUnsignedWrap(), DL, TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   // (A*B)-(A*C) -> A*(B-C) etc
   if (Value *V = SimplifyUsingDistributiveLaws(I))
     return ReplaceInstUsesWith(I, V);
 
-  // If this is a 'B = x-(-A)', change to B = x+A.  This preserves NSW/NUW.
+  // If this is a 'B = x-(-A)', change to B = x+A.
   if (Value *V = dyn_castNegVal(Op1)) {
     BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V);
-    Res->setHasNoSignedWrap(I.hasNoSignedWrap());
-    Res->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+
+    if (const auto *BO = dyn_cast<BinaryOperator>(Op1)) {
+      assert(BO->getOpcode() == Instruction::Sub &&
+             "Expected a subtraction operator!");
+      if (BO->hasNoSignedWrap() && I.hasNoSignedWrap())
+        Res->setHasNoSignedWrap(true);
+    } else {
+      if (cast<Constant>(Op1)->isNotMinSignedValue() && I.hasNoSignedWrap())
+        Res->setHasNoSignedWrap(true);
+    }
+
     return Res;
   }
 
@@ -1511,21 +1594,23 @@
     // -(X >>u 31) -> (X >>s 31)
     // -(X >>s 31) -> (X >>u 31)
     if (C->isZero()) {
-      Value *X; ConstantInt *CI;
+      Value *X;
+      ConstantInt *CI;
       if (match(Op1, m_LShr(m_Value(X), m_ConstantInt(CI))) &&
           // Verify we are shifting out everything but the sign bit.
-          CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1)
+          CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1)
         return BinaryOperator::CreateAShr(X, CI);
 
       if (match(Op1, m_AShr(m_Value(X), m_ConstantInt(CI))) &&
          // Verify we are shifting out everything but the sign bit.
-          CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1)
+          CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1)
         return BinaryOperator::CreateLShr(X, CI);
     }
   }
 
-  { Value *Y;
+  {
+    Value *Y;
     // X-(X+Y) == -Y    X-(Y+X) == -Y
     if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) ||
         match(Op1, m_Add(m_Value(Y), m_Specific(Op0))))
@@ -1536,6 +1621,24 @@
       return BinaryOperator::CreateNeg(Y);
   }
 
+  // (sub (or A, B) (xor A, B)) --> (and A, B)
+  {
+    Value *A = nullptr, *B = nullptr;
+    if (match(Op1, m_Xor(m_Value(A), m_Value(B))) &&
+        (match(Op0, m_Or(m_Specific(A), m_Specific(B))) ||
+         match(Op0, m_Or(m_Specific(B), m_Specific(A)))))
+      return BinaryOperator::CreateAnd(A, B);
+  }
+
+  if (Op0->hasOneUse()) {
+    Value *Y = nullptr;
+    // ((X | Y) - X) --> (~X & Y)
+    if (match(Op0, m_Or(m_Value(Y), m_Specific(Op1))) ||
+        match(Op0, m_Or(m_Specific(Op1), m_Value(Y))))
+      return BinaryOperator::CreateAnd(
+          Y, Builder->CreateNot(Op1, Op1->getName() + ".not"));
+  }
+
   if (Op1->hasOneUse()) {
     Value *X = nullptr, *Y = nullptr, *Z = nullptr;
     Constant *C = nullptr;
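The two new subtraction folds are the mirror images of the add folds above: since A | B == (A ^ B) + (A & B), subtracting the xor from the or leaves the and, and subtracting X from X | Y leaves exactly the bits Y contributes outside X. Exhaustively checkable at i8 (illustrative code, not from the commit):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned a = 0; a < 256; ++a) {
    for (unsigned b = 0; b < 256; ++b) {
      uint8_t A = a, B = b;
      // (sub (or A, B), (xor A, B)) --> (and A, B)
      assert(uint8_t((A | B) - (A ^ B)) == uint8_t(A & B));
      // ((X | Y) - X) --> (~X & Y), with X := A and Y := B
      assert(uint8_t((A | B) - A) == uint8_t(~A & B));
    }
  }
  return 0;
}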
@@ -1555,7 +1658,7 @@
     // 0 - (X sdiv C)  -> (X sdiv -C)  provided the negation doesn't overflow.
     if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) &&
         match(Op0, m_Zero()) &&
-        !C->isMinSignedValue())
+        C->isNotMinSignedValue() && !C->isOneValue())
       return BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(C));
 
     // 0 - (X << Y)  -> (-X << Y)   when X is freely negatable.
@@ -1595,7 +1698,17 @@
     return ReplaceInstUsesWith(I, Res);
   }
 
-  return nullptr;
+  bool Changed = false;
+  if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, &I)) {
+    Changed = true;
+    I.setHasNoSignedWrap(true);
+  }
+  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1, &I)) {
+    Changed = true;
+    I.setHasNoUnsignedWrap(true);
+  }
+
+  return Changed ? &I : nullptr;
 }
 
 Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
@@ -1604,7 +1717,8 @@
   if (Value *V = SimplifyVectorOp(I))
     return ReplaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL))
+  if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL,
+                                  TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   if (isa<Constant>(Op0))
diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index b23a606..55ebced 100644
--- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -355,7 +355,7 @@ Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS,
     if (isRunOfOnes(Mask, MB, ME)) {  // begin/end bit of run, inclusive
       uint32_t BitWidth = cast<IntegerType>(RHS->getType())->getBitWidth();
       APInt Mask(APInt::getLowBitsSet(BitWidth, MB-1));
-      if (MaskedValueIsZero(RHS, Mask))
+      if (MaskedValueIsZero(RHS, Mask, 0, &I))
         break;
     }
   }
@@ -614,7 +614,7 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,
   } else if (R1->getType()->isIntegerTy()) {
     if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
       // As before, model no mask as a trivial mask if it'll let us do an
-      // optimisation.
+      // optimization.
       R11 = R1;
       R12 = Constant::getAllOnesValue(R1->getType());
     }
@@ -665,8 +665,8 @@
 /// foldLogOpOfMaskedICmps:
 /// try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
 /// into a single (icmp(A & X) ==/!= Y)
-static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
-                                     llvm::InstCombiner::BuilderTy* Builder) {
+static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
+                                     llvm::InstCombiner::BuilderTy *Builder) {
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
   ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
   unsigned mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS,
@@ -697,26 +697,26 @@ static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
   if (mask & FoldMskICmp_Mask_AllZeroes) {
     // (icmp eq (A & B), 0) & (icmp eq (A & D), 0)
     // -> (icmp eq (A & (B|D)), 0)
-    Value* newOr = Builder->CreateOr(B, D);
-    Value* newAnd = Builder->CreateAnd(A, newOr);
+    Value *newOr = Builder->CreateOr(B, D);
+    Value *newAnd = Builder->CreateAnd(A, newOr);
     // we can't use C as zero, because we might actually handle
     //   (icmp ne (A & B), B) & (icmp ne (A & D), D)
     // with B and D, having a single bit set
-    Value* zero = Constant::getNullValue(A->getType());
+    Value *zero = Constant::getNullValue(A->getType());
     return Builder->CreateICmp(NEWCC, newAnd, zero);
   }
   if (mask & FoldMskICmp_BMask_AllOnes) {
     // (icmp eq (A & B), B) & (icmp eq (A & D), D)
     // -> (icmp eq (A & (B|D)), (B|D))
-    Value* newOr = Builder->CreateOr(B, D);
-    Value* newAnd = Builder->CreateAnd(A, newOr);
+    Value *newOr = Builder->CreateOr(B, D);
+    Value *newAnd = Builder->CreateAnd(A, newOr);
     return Builder->CreateICmp(NEWCC, newAnd, newOr);
   }
   if (mask & FoldMskICmp_AMask_AllOnes) {
     // (icmp eq (A & B), A) & (icmp eq (A & D), A)
     // -> (icmp eq (A & (B&D)), A)
-    Value* newAnd1 = Builder->CreateAnd(B, D);
-    Value* newAnd = Builder->CreateAnd(A, newAnd1);
+    Value *newAnd1 = Builder->CreateAnd(B, D);
+    Value *newAnd = Builder->CreateAnd(A, newAnd1);
     return Builder->CreateICmp(NEWCC, newAnd, A);
   }
 
@@ -766,19 +766,17 @@ static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
     // with B and D, having a single bit set
     ConstantInt *CCst = dyn_cast<ConstantInt>(C);
     if (!CCst) return nullptr;
-    if (LHSCC != NEWCC)
-      CCst = dyn_cast<ConstantInt>( ConstantExpr::getXor(BCst, CCst) );
     ConstantInt *ECst = dyn_cast<ConstantInt>(E);
     if (!ECst) return nullptr;
+    if (LHSCC != NEWCC)
+      CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst));
     if (RHSCC != NEWCC)
-      ECst = dyn_cast<ConstantInt>( ConstantExpr::getXor(DCst, ECst) );
-    ConstantInt* MCst = dyn_cast<ConstantInt>(
-      ConstantExpr::getAnd(ConstantExpr::getAnd(BCst, DCst),
-                           ConstantExpr::getXor(CCst, ECst)) );
+      ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst));
     // if there is a conflict we should actually return a false for the
     // whole construct
-    if (!MCst->isZero())
-      return nullptr;
+    if (((BCst->getValue() & DCst->getValue()) &
+         (CCst->getValue() ^ ECst->getValue())) != 0)
+      return ConstantInt::get(LHS->getType(), !IsAnd);
     Value *newOr1 = Builder->CreateOr(B, D);
     Value *newOr2 = ConstantExpr::getOr(CCst, ECst);
     Value *newAnd = Builder->CreateAnd(A, newOr1);
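The rewritten conflict test asserts that when (B & D) & (C ^ E) is non-zero, the two equalities (A & B) == C and (A & D) == E can never hold at once: any bit set in both masks would force the same bit of A to two different values. The old code bailed out with nullptr; the new code returns the known constant result (false for the and-form, true for the or-form). A small exhaustive check at 4 bits (standalone, illustrative only):

#include <cassert>

int main() {
  for (int b = 0; b < 16; ++b)
    for (int d = 0; d < 16; ++d)
      for (int c = 0; c < 16; ++c)
        for (int e = 0; e < 16; ++e) {
          if (((b & d) & (c ^ e)) == 0)
            continue; // no conflict; the constant fold does not apply
          // With a conflict, no value of A satisfies both equalities.
          for (int a = 0; a < 16; ++a)
            assert(!((a & b) == c && (a & d) == e));
        }
  return 0;
}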
@@ -930,6 +928,8 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
     case ICmpInst::ICMP_ULT:
       if (LHSCst == SubOne(RHSCst)) // (X != 13 & X u< 14) -> X < 13
         return Builder->CreateICmpULT(Val, LHSCst);
+      if (LHSCst->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13
+        return InsertRangeTest(Val, AddOne(LHSCst), RHSCst, false, true);
       break; // (X != 13 & X u< 15) -> no change
     case ICmpInst::ICMP_SLT:
       if (LHSCst == SubOne(RHSCst)) // (X != 13 & X s< 14) -> X < 13
@@ -1101,7 +1101,6 @@ Value *InstCombiner::FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) {
   return nullptr;
 }
 
-
 Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
   bool Changed = SimplifyAssociativeOrCommutative(I);
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -1109,7 +1108,7 @@
   if (Value *V = SimplifyVectorOp(I))
     return ReplaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyAndInst(Op0, Op1, DL))
+  if (Value *V = SimplifyAndInst(Op0, Op1, DL, TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   // (A|B)&(A|C) -> A|(B&C) etc
@@ -1136,14 +1135,14 @@
       if (!Op0I->hasOneUse()) break;
 
       APInt NotAndRHS(~AndRHSMask);
-      if (MaskedValueIsZero(Op0LHS, NotAndRHS)) {
+      if (MaskedValueIsZero(Op0LHS, NotAndRHS, 0, &I)) {
         // Not masking anything out for the LHS, move to RHS.
         Value *NewRHS = Builder->CreateAnd(Op0RHS, AndRHS,
                                            Op0RHS->getName()+".masked");
         return BinaryOperator::Create(Op0I->getOpcode(), Op0LHS, NewRHS);
       }
       if (!isa<Constant>(Op0RHS) &&
-          MaskedValueIsZero(Op0RHS, NotAndRHS)) {
+          MaskedValueIsZero(Op0RHS, NotAndRHS, 0, &I)) {
         // Not masking anything out for the RHS, move to LHS.
         Value *NewLHS = Builder->CreateAnd(Op0LHS, AndRHS,
                                            Op0LHS->getName()+".masked");
@@ -1176,7 +1175,7 @@
         uint32_t Zeros = AndRHSMask.countLeadingZeros();
         APInt Mask = APInt::getLowBitsSet(BitWidth, BitWidth - Zeros);
 
-        if (MaskedValueIsZero(Op0LHS, Mask)) {
+        if (MaskedValueIsZero(Op0LHS, Mask, 0, &I)) {
           Value *NewNeg = Builder->CreateNeg(Op0RHS);
           return BinaryOperator::CreateAnd(NewNeg, AndRHS);
         }
@@ -1283,13 +1282,58 @@
     if (match(Op1, m_Or(m_Not(m_Specific(Op0)), m_Value(A))) ||
         match(Op1, m_Or(m_Value(A), m_Not(m_Specific(Op0)))))
      return BinaryOperator::CreateAnd(A, Op0);
+
+    // (A ^ B) & ((B ^ C) ^ A) -> (A ^ B) & ~C
+    if (match(Op0, m_Xor(m_Value(A), m_Value(B))))
+      if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A))))
+        if (Op1->hasOneUse() || cast<BinaryOperator>(Op1)->hasOneUse())
+          return BinaryOperator::CreateAnd(Op0, Builder->CreateNot(C));
+
+    // ((A ^ C) ^ B) & (B ^ A) -> (B ^ A) & ~C
+    if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B))))
+      if (match(Op1, m_Xor(m_Specific(B), m_Specific(A))))
+        if (Op0->hasOneUse() || cast<BinaryOperator>(Op0)->hasOneUse())
+          return BinaryOperator::CreateAnd(Op1, Builder->CreateNot(C));
+
+    // (A | B) & ((~A) ^ B) -> (A & B)
+    if (match(Op0, m_Or(m_Value(A), m_Value(B))) &&
+        match(Op1, m_Xor(m_Not(m_Specific(A)), m_Specific(B))))
+      return BinaryOperator::CreateAnd(A, B);
+
+    // ((~A) ^ B) & (A | B) -> (A & B)
+    if (match(Op0, m_Xor(m_Not(m_Value(A)), m_Value(B))) &&
+        match(Op1, m_Or(m_Specific(A), m_Specific(B))))
+      return BinaryOperator::CreateAnd(A, B);
   }
 
-  if (ICmpInst *RHS = dyn_cast<ICmpInst>(Op1))
-    if (ICmpInst *LHS = dyn_cast<ICmpInst>(Op0))
+  {
+    ICmpInst *LHS = dyn_cast<ICmpInst>(Op0);
+    ICmpInst *RHS = dyn_cast<ICmpInst>(Op1);
+    if (LHS && RHS)
       if (Value *Res = FoldAndOfICmps(LHS, RHS))
         return ReplaceInstUsesWith(I, Res);
 
+    // TODO: Make this recursive; it's a little tricky because an arbitrary
+    // number of 'and' instructions might have to be created.
+    Value *X, *Y;
+    if (LHS && match(Op1, m_OneUse(m_And(m_Value(X), m_Value(Y))))) {
+      if (auto *Cmp = dyn_cast<ICmpInst>(X))
+        if (Value *Res = FoldAndOfICmps(LHS, Cmp))
+          return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, Y));
+      if (auto *Cmp = dyn_cast<ICmpInst>(Y))
+        if (Value *Res = FoldAndOfICmps(LHS, Cmp))
+          return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, X));
+    }
+    if (RHS && match(Op0, m_OneUse(m_And(m_Value(X), m_Value(Y))))) {
+      if (auto *Cmp = dyn_cast<ICmpInst>(X))
+        if (Value *Res = FoldAndOfICmps(Cmp, RHS))
+          return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, Y));
+      if (auto *Cmp = dyn_cast<ICmpInst>(Y))
+        if (Value *Res = FoldAndOfICmps(Cmp, RHS))
+          return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, X));
+    }
+  }
+
   // If and'ing two fcmp, try combine them into one.
   if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0)))
     if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1)))
@@ -1329,20 +1373,6 @@
     }
   }
 
-  // (X >> Z) & (Y >> Z)  -> (X&Y) >> Z  for all shifts.
-  if (BinaryOperator *SI1 = dyn_cast<BinaryOperator>(Op1)) {
-    if (BinaryOperator *SI0 = dyn_cast<BinaryOperator>(Op0))
-      if (SI0->isShift() && SI0->getOpcode() == SI1->getOpcode() &&
-          SI0->getOperand(1) == SI1->getOperand(1) &&
-          (SI0->hasOneUse() || SI1->hasOneUse())) {
-        Value *NewOp =
-          Builder->CreateAnd(SI0->getOperand(0), SI1->getOperand(0),
-                             SI0->getName());
-        return BinaryOperator::Create(SI1->getOpcode(), NewOp,
-                                      SI1->getOperand(1));
-      }
-  }
-
   {
     Value *X = nullptr;
     bool OpsSwapped = false;
@@ -1554,7 +1584,8 @@ static Instruction *MatchSelectFromAndOr(Value *A, Value *B,
 }
 
 /// FoldOrOfICmps - Fold (icmp)|(icmp) if possible.
-Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
+Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
+                                   Instruction *CxtI) {
   ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
 
   // Fold (iszero(A & K1) | iszero(A & K2)) ->  (A & (K1 | K2)) != (K1 | K2)
@@ -1574,13 +1605,15 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
     Value *Mask = nullptr;
     Value *Masked = nullptr;
     if (LAnd->getOperand(0) == RAnd->getOperand(0) &&
-        isKnownToBeAPowerOfTwo(LAnd->getOperand(1)) &&
-        isKnownToBeAPowerOfTwo(RAnd->getOperand(1))) {
+        isKnownToBeAPowerOfTwo(LAnd->getOperand(1), false, 0, AT, CxtI, DT) &&
+        isKnownToBeAPowerOfTwo(RAnd->getOperand(1), false, 0, AT, CxtI, DT)) {
       Mask = Builder->CreateOr(LAnd->getOperand(1), RAnd->getOperand(1));
       Masked = Builder->CreateAnd(LAnd->getOperand(0), Mask);
     } else if (LAnd->getOperand(1) == RAnd->getOperand(1) &&
-               isKnownToBeAPowerOfTwo(LAnd->getOperand(0)) &&
-               isKnownToBeAPowerOfTwo(RAnd->getOperand(0))) {
+               isKnownToBeAPowerOfTwo(LAnd->getOperand(0),
+                                      false, 0, AT, CxtI, DT) &&
+               isKnownToBeAPowerOfTwo(RAnd->getOperand(0),
+                                      false, 0, AT, CxtI, DT)) {
       Mask = Builder->CreateOr(LAnd->getOperand(0), RAnd->getOperand(0));
       Masked = Builder->CreateAnd(LAnd->getOperand(1), Mask);
     }
@@ -1590,6 +1623,61 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
     }
   }
 
+  // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3)
+  //                   -->  (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3)
+  // The original condition actually refers to the following two ranges:
+  // [MAX_UINT-C1+1, MAX_UINT-C1+1+C3] and [MAX_UINT-C2+1, MAX_UINT-C2+1+C3]
+  // We can fold these two ranges if:
+  // 1) C1 and C2 is unsigned greater than C3.
+  // 2) The two ranges are separated.
+  // 3) C1 ^ C2 is one-bit mask.
+  // 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask.
+  // This implies all values in the two ranges differ by exactly one bit.
+
+  if ((LHSCC == ICmpInst::ICMP_ULT || LHSCC == ICmpInst::ICMP_ULE) &&
+      LHSCC == RHSCC && LHSCst && RHSCst && LHS->hasOneUse() &&
+      RHS->hasOneUse() && LHSCst->getType() == RHSCst->getType() &&
+      LHSCst->getValue() == (RHSCst->getValue())) {
+
+    Value *LAdd = LHS->getOperand(0);
+    Value *RAdd = RHS->getOperand(0);
+
+    Value *LAddOpnd, *RAddOpnd;
+    ConstantInt *LAddCst, *RAddCst;
+    if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddCst))) &&
+        match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddCst))) &&
+        LAddCst->getValue().ugt(LHSCst->getValue()) &&
+        RAddCst->getValue().ugt(LHSCst->getValue())) {
+
+      APInt DiffCst = LAddCst->getValue() ^ RAddCst->getValue();
+      if (LAddOpnd == RAddOpnd && DiffCst.isPowerOf2()) {
+        ConstantInt *MaxAddCst = nullptr;
+        if (LAddCst->getValue().ult(RAddCst->getValue()))
+          MaxAddCst = RAddCst;
+        else
+          MaxAddCst = LAddCst;
+
+        APInt RRangeLow = -RAddCst->getValue();
+        APInt RRangeHigh = RRangeLow + LHSCst->getValue();
+        APInt LRangeLow = -LAddCst->getValue();
+        APInt LRangeHigh = LRangeLow + LHSCst->getValue();
+        APInt LowRangeDiff = RRangeLow ^ LRangeLow;
+        APInt HighRangeDiff = RRangeHigh ^ LRangeHigh;
+        APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow
+                                                   : RRangeLow - LRangeLow;
+
+        if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff &&
+            RangeDiff.ugt(LHSCst->getValue())) {
+          Value *MaskCst = ConstantInt::get(LAddCst->getType(), ~DiffCst);
+
+          Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskCst);
+          Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddCst);
+          return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSCst));
+        }
+      }
+    }
+  }
+
   // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B)
   if (PredicatesFoldable(LHSCC, RHSCC)) {
     if (LHS->getOperand(0) == RHS->getOperand(1) &&
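To make the four conditions concrete, here is one i8 instantiation (the constants C1 = 4, C2 = 12, C3 = 2 are picked for illustration and do not come from the commit): C1 ^ C2 = 8 is a one-bit mask, both constants are unsigned-greater than C3, and the two ranges {0xFC, 0xFD} and {0xF4, 0xF5} differ in exactly bit 3. The folded form masks that bit off and compares once; this standalone check confirms the equivalence for every value of A:

#include <cassert>
#include <cstdint>

int main() {
  const uint8_t C1 = 4, C2 = 12, C3 = 2;
  const uint8_t Mask = uint8_t(~(C1 ^ C2)); // ~(C1 ^ C2) == 0xF7
  for (unsigned v = 0; v < 256; ++v) {
    uint8_t A = v;
    // Original pair of range checks.
    bool orig = uint8_t(A + C1) < C3 || uint8_t(A + C2) < C3;
    // Folded form: ((A & ~(C1 ^ C2)) + max(C1, C2)) u< C3.
    bool folded = uint8_t(uint8_t(A & Mask) + C2) < C3;
    assert(orig == folded);
  }
  return 0;
}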
@@ -1906,6 +1994,38 @@ Instruction *InstCombiner::FoldOrWithConstants(BinaryOperator &I, Value *Op,
   return nullptr;
 }
 
+/// \brief This helper function folds:
+///
+///     ((A | B) & C1) ^ (B & C2)
+///
+/// into:
+///
+///     (A & C1) ^ B
+///
+/// when the XOR of the two constants is "all ones" (-1).
+Instruction *InstCombiner::FoldXorWithConstants(BinaryOperator &I, Value *Op,
+                                                Value *A, Value *B, Value *C) {
+  ConstantInt *CI1 = dyn_cast<ConstantInt>(C);
+  if (!CI1)
+    return nullptr;
+
+  Value *V1 = nullptr;
+  ConstantInt *CI2 = nullptr;
+  if (!match(Op, m_And(m_Value(V1), m_ConstantInt(CI2))))
+    return nullptr;
+
+  APInt Xor = CI1->getValue() ^ CI2->getValue();
+  if (!Xor.isAllOnesValue())
+    return nullptr;
+
+  if (V1 == A || V1 == B) {
+    Value *NewOp = Builder->CreateAnd(V1 == A ? B : A, CI1);
+    return BinaryOperator::CreateXor(NewOp, V1);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombiner::visitOr(BinaryOperator &I) {
   bool Changed = SimplifyAssociativeOrCommutative(I);
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -1913,7 +2033,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
   if (Value *V = SimplifyVectorOp(I))
     return ReplaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyOrInst(Op0, Op1, DL))
+  if (Value *V = SimplifyOrInst(Op0, Op1, DL, TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   // (A&B)|(A&C) -> A&(B|C) etc
@@ -1973,7 +2093,7 @@
   // (X^C)|Y -> (X|Y)^C iff Y&C == 0
   if (Op0->hasOneUse() &&
       match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) &&
-      MaskedValueIsZero(Op1, C1->getValue())) {
+      MaskedValueIsZero(Op1, C1->getValue(), 0, &I)) {
     Value *NOr = Builder->CreateOr(A, Op1);
     NOr->takeName(Op0);
     return BinaryOperator::CreateXor(NOr, C1);
@@ -1982,12 +2102,32 @@
   // Y|(X^C) -> (X|Y)^C iff Y&C == 0
   if (Op1->hasOneUse() &&
       match(Op1, m_Xor(m_Value(A), m_ConstantInt(C1))) &&
-      MaskedValueIsZero(Op0, C1->getValue())) {
+      MaskedValueIsZero(Op0, C1->getValue(), 0, &I)) {
     Value *NOr = Builder->CreateOr(A, Op0);
     NOr->takeName(Op0);
     return BinaryOperator::CreateXor(NOr, C1);
   }
 
+  // ((~A & B) | A) -> (A | B)
+  if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) &&
+      match(Op1, m_Specific(A)))
+    return BinaryOperator::CreateOr(A, B);
+
+  // ((A & B) | ~A) -> (~A | B)
+  if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
+      match(Op1, m_Not(m_Specific(A))))
+    return BinaryOperator::CreateOr(Builder->CreateNot(A), B);
+
+  // (A & (~B)) | (A ^ B) -> (A ^ B)
+  if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) &&
+      match(Op1, m_Xor(m_Specific(A), m_Specific(B))))
+    return BinaryOperator::CreateXor(A, B);
+
+  // (A ^ B) | ( A & (~B)) -> (A ^ B)
+  if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
+      match(Op1, m_And(m_Specific(A), m_Not(m_Specific(B)))))
+    return BinaryOperator::CreateXor(A, B);
+
   // (A & C)|(B & D)
   Value *C = nullptr, *D = nullptr;
   if (match(Op0, m_And(m_Value(A), m_Value(C))) &&
@@ -2000,14 +2140,18 @@
     // ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2)
     // iff (C1&C2) == 0 and (N&~C1) == 0
     if (match(A, m_Or(m_Value(V1), m_Value(V2))) &&
-        ((V1 == B && MaskedValueIsZero(V2, ~C1->getValue())) ||  // (V|N)
-         (V2 == B && MaskedValueIsZero(V1, ~C1->getValue()))))   // (N|V)
+        ((V1 == B &&
+          MaskedValueIsZero(V2, ~C1->getValue(), 0, &I)) || // (V|N)
+         (V2 == B &&
+          MaskedValueIsZero(V1, ~C1->getValue(), 0, &I))))  // (N|V)
       return BinaryOperator::CreateAnd(A,
                                 Builder->getInt(C1->getValue()|C2->getValue()));
 
     // Or commutes, try both ways.
     if (match(B, m_Or(m_Value(V1), m_Value(V2))) &&
-        ((V1 == A && MaskedValueIsZero(V2, ~C2->getValue())) ||  // (V|N)
-         (V2 == A && MaskedValueIsZero(V1, ~C2->getValue()))))   // (N|V)
+        ((V1 == A &&
+          MaskedValueIsZero(V2, ~C2->getValue(), 0, &I)) || // (V|N)
+         (V2 == A &&
+          MaskedValueIsZero(V1, ~C2->getValue(), 0, &I))))  // (N|V)
       return BinaryOperator::CreateAnd(B,
                                 Builder->getInt(C1->getValue()|C2->getValue()));
 
@@ -2068,20 +2212,35 @@
       Instruction *Ret = FoldOrWithConstants(I, Op0, A, V1, D);
       if (Ret) return Ret;
     }
+    // ((A^B)&1)|(B&-2) -> (A&1) ^ B
+    if (match(A, m_Xor(m_Value(V1), m_Specific(B))) ||
+        match(A, m_Xor(m_Specific(B), m_Value(V1)))) {
+      Instruction *Ret = FoldXorWithConstants(I, Op1, V1, B, C);
+      if (Ret) return Ret;
+    }
+    // (B&-2)|((A^B)&1) -> (A&1) ^ B
+    if (match(B, m_Xor(m_Specific(A), m_Value(V1))) ||
+        match(B, m_Xor(m_Value(V1), m_Specific(A)))) {
+      Instruction *Ret = FoldXorWithConstants(I, Op0, A, V1, D);
+      if (Ret) return Ret;
+    }
   }
 
-  // (X >> Z) | (Y >> Z)  -> (X|Y) >> Z  for all shifts.
-  if (BinaryOperator *SI1 = dyn_cast<BinaryOperator>(Op1)) {
-    if (BinaryOperator *SI0 = dyn_cast<BinaryOperator>(Op0))
-      if (SI0->isShift() && SI0->getOpcode() == SI1->getOpcode() &&
-          SI0->getOperand(1) == SI1->getOperand(1) &&
-          (SI0->hasOneUse() || SI1->hasOneUse())) {
-        Value *NewOp = Builder->CreateOr(SI0->getOperand(0), SI1->getOperand(0),
-                                         SI0->getName());
-        return BinaryOperator::Create(SI1->getOpcode(), NewOp,
-                                      SI1->getOperand(1));
-      }
-  }
+  // (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C
+  if (match(Op0, m_Xor(m_Value(A), m_Value(B))))
+    if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A))))
+      if (Op1->hasOneUse() || cast<BinaryOperator>(Op1)->hasOneUse())
+        return BinaryOperator::CreateOr(Op0, C);
+
+  // ((A ^ C) ^ B) | (B ^ A) -> (B ^ A) | C
+  if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B))))
+    if (match(Op1, m_Xor(m_Specific(B), m_Specific(A))))
+      if (Op0->hasOneUse() || cast<BinaryOperator>(Op0)->hasOneUse())
+        return BinaryOperator::CreateOr(Op1, C);
+
+  // ((B | C) & A) | B -> B | (A & C)
+  if (match(Op0, m_And(m_Or(m_Specific(Op1), m_Value(C)), m_Value(A))))
+    return BinaryOperator::CreateOr(Op1, Builder->CreateAnd(A, C));
 
   // (~A | ~B) == (~(A & B)) - De Morgan's Law
   if (Value *Op0NotVal = dyn_castNotVal(Op0))
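The two FoldXorWithConstants call sites implement the pattern named in their comments: with complementary masks (here 1 and -2), the or takes bit 0 from A ^ B and every remaining bit from B, which is exactly (A & 1) ^ B. The identity behind the fold is easy to confirm exhaustively at i8 (standalone, illustrative):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned a = 0; a < 256; ++a) {
    for (unsigned b = 0; b < 256; ++b) {
      uint8_t A = a, B = b;
      // ((A^B)&1)|(B&-2) -> (A&1) ^ B, with -2 == 0xFE at 8 bits
      assert(uint8_t(((A ^ B) & 1) | (B & 0xFE)) == uint8_t((A & 1) ^ B));
    }
  }
  return 0;
}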
@@ -2133,12 +2292,22 @@
     return BinaryOperator::CreateOr(Not, Op0);
   }
 
+  // (A & B) | ((~A) ^ B) -> (~A ^ B)
+  if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
+      match(Op1, m_Xor(m_Not(m_Specific(A)), m_Specific(B))))
+    return BinaryOperator::CreateXor(Builder->CreateNot(A), B);
+
+  // ((~A) ^ B) | (A & B) -> (~A ^ B)
+  if (match(Op0, m_Xor(m_Not(m_Value(A)), m_Value(B))) &&
+      match(Op1, m_And(m_Specific(A), m_Specific(B))))
+    return BinaryOperator::CreateXor(Builder->CreateNot(A), B);
+
   if (SwappedForXor)
     std::swap(Op0, Op1);
 
   if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
     if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
-      if (Value *Res = FoldOrOfICmps(LHS, RHS))
+      if (Value *Res = FoldOrOfICmps(LHS, RHS, &I))
         return ReplaceInstUsesWith(I, Res);
 
   // (fcmp uno x, c) | (fcmp uno y, c)  -> (fcmp uno x, y)
@@ -2169,7 +2338,7 @@
       // cast is otherwise not optimizable.  This happens for vector sexts.
       if (ICmpInst *RHS = dyn_cast<ICmpInst>(Op1COp))
         if (ICmpInst *LHS = dyn_cast<ICmpInst>(Op0COp))
-          if (Value *Res = FoldOrOfICmps(LHS, RHS))
+          if (Value *Res = FoldOrOfICmps(LHS, RHS, &I))
             return CastInst::Create(Op0C->getOpcode(), Res, I.getType());
 
       // If this is or(cast(fcmp), cast(fcmp)), try to fold this even if the
@@ -2225,7 +2394,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
   if (Value *V = SimplifyVectorOp(I))
     return ReplaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyXorInst(Op0, Op1, DL))
+  if (Value *V = SimplifyXorInst(Op0, Op1, DL, TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   // (A&B)^(A&C) -> A&(B^C) etc
@@ -2327,7 +2496,8 @@
         }
       } else if (Op0I->getOpcode() == Instruction::Or) {
         // (X|C1)^C2 -> X^(C1|C2) iff X&~C1 == 0
-        if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue())) {
+        if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue(),
+                              0, &I)) {
           Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHS);
           // Anything in both C1 and C2 is known to be zero, remove it from
           // NewRHS.
@@ -2418,18 +2588,6 @@
     }
   }
 
-  // (X >> Z) ^ (Y >> Z)  -> (X^Y) >> Z  for all shifts.
-  if (Op0I && Op1I && Op0I->isShift() &&
-      Op0I->getOpcode() == Op1I->getOpcode() &&
-      Op0I->getOperand(1) == Op1I->getOperand(1) &&
-      (Op0I->hasOneUse() || Op1I->hasOneUse())) {
-    Value *NewOp =
-      Builder->CreateXor(Op0I->getOperand(0), Op1I->getOperand(0),
-                         Op0I->getName());
-    return BinaryOperator::Create(Op1I->getOpcode(), NewOp,
-                                  Op1I->getOperand(1));
-  }
-
   if (Op0I && Op1I) {
     Value *A, *B, *C, *D;
     // (A & B)^(A | B) -> A ^ B
@@ -2444,8 +2602,62 @@
       if ((A == C && B == D) || (A == D && B == C))
         return BinaryOperator::CreateXor(A, B);
     }
+    // (A | ~B) ^ (~A | B) -> A ^ B
+    if (match(Op0I, m_Or(m_Value(A), m_Not(m_Value(B)))) &&
+        match(Op1I, m_Or(m_Not(m_Specific(A)), m_Specific(B)))) {
+      return BinaryOperator::CreateXor(A, B);
+    }
+    // (~A | B) ^ (A | ~B) -> A ^ B
+    if (match(Op0I, m_Or(m_Not(m_Value(A)), m_Value(B))) &&
+        match(Op1I, m_Or(m_Specific(A), m_Not(m_Specific(B))))) {
+      return BinaryOperator::CreateXor(A, B);
+    }
+    // (A & ~B) ^ (~A & B) -> A ^ B
+    if (match(Op0I, m_And(m_Value(A), m_Not(m_Value(B)))) &&
+        match(Op1I, m_And(m_Not(m_Specific(A)), m_Specific(B)))) {
+      return BinaryOperator::CreateXor(A, B);
+    }
+    // (~A & B) ^ (A & ~B) -> A ^ B
+    if (match(Op0I, m_And(m_Not(m_Value(A)), m_Value(B))) &&
+        match(Op1I, m_And(m_Specific(A), m_Not(m_Specific(B))))) {
+      return BinaryOperator::CreateXor(A, B);
+    }
+    // (A ^ C)^(A | B) -> ((~A) & B) ^ C
+    if (match(Op0I, m_Xor(m_Value(D), m_Value(C))) &&
+        match(Op1I, m_Or(m_Value(A), m_Value(B)))) {
+      if (D == A)
+        return BinaryOperator::CreateXor(
+            Builder->CreateAnd(Builder->CreateNot(A), B), C);
+      if (D == B)
+        return BinaryOperator::CreateXor(
+            Builder->CreateAnd(Builder->CreateNot(B), A), C);
+    }
+    // (A | B)^(A ^ C) -> ((~A) & B) ^ C
+    if (match(Op0I, m_Or(m_Value(A), m_Value(B))) &&
+        match(Op1I, m_Xor(m_Value(D), m_Value(C)))) {
+      if (D == A)
+        return BinaryOperator::CreateXor(
+            Builder->CreateAnd(Builder->CreateNot(A), B), C);
+      if (D == B)
+        return BinaryOperator::CreateXor(
+            Builder->CreateAnd(Builder->CreateNot(B), A), C);
+    }
+    // (A & B) ^ (A ^ B) -> (A | B)
+    if (match(Op0I, m_And(m_Value(A), m_Value(B))) &&
+        match(Op1I, m_Xor(m_Specific(A), m_Specific(B))))
+      return BinaryOperator::CreateOr(A, B);
+    // (A ^ B) ^ (A & B) -> (A | B)
+    if (match(Op0I, m_Xor(m_Value(A), m_Value(B))) &&
+        match(Op1I, m_And(m_Specific(A), m_Specific(B))))
+      return BinaryOperator::CreateOr(A, B);
   }
 
+  Value *A = nullptr, *B = nullptr;
+  // (A & ~B) ^ (~A) -> ~(A & B)
+  if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) &&
+      match(Op1, m_Not(m_Specific(A))))
+    return BinaryOperator::CreateNot(Builder->CreateAnd(A, B));
+
   // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B)
   if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
     if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
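Every one of the new visitXor rewrites is a pure bitwise identity, so all of them can be verified exhaustively at a small width. A standalone i8 check covering the main ones (illustrative code, not part of the commit):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned a = 0; a < 256; ++a) {
    for (unsigned b = 0; b < 256; ++b) {
      uint8_t A = a, B = b;
      // (A | ~B) ^ (~A | B) -> A ^ B
      assert(uint8_t((A | ~B) ^ (~A | B)) == uint8_t(A ^ B));
      // (A & ~B) ^ (~A & B) -> A ^ B
      assert(uint8_t((A & ~B) ^ (~A & B)) == uint8_t(A ^ B));
      // (A & B) ^ (A ^ B) -> (A | B)
      assert(uint8_t((A & B) ^ (A ^ B)) == uint8_t(A | B));
      // (A & ~B) ^ (~A) -> ~(A & B)
      assert(uint8_t((A & ~B) ^ ~A) == uint8_t(~(A & B)));
      for (unsigned c = 0; c < 256; ++c) {
        uint8_t C = c;
        // (A ^ C) ^ (A | B) -> ((~A) & B) ^ C
        assert(uint8_t((A ^ C) ^ (A | B)) == uint8_t((~A & B) ^ C));
      }
    }
  }
  return 0;
}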
diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 658178d..87e49a1 100644
--- a/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -16,6 +16,7 @@
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/IR/CallSite.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Transforms/Utils/BuildLibCalls.h"
 #include "llvm/Transforms/Utils/Local.h"
@@ -58,8 +59,8 @@ static Type *reduceToSingleValueType(Type *T) {
 }
 
 Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
-  unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL);
-  unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL);
+  unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, AT, MI, DT);
+  unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, AT, MI, DT);
   unsigned MinAlign = std::min(DstAlign, SrcAlign);
   unsigned CopyAlign = MI->getAlignment();
 
@@ -154,7 +155,7 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
 }
 
 Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) {
-  unsigned Alignment = getKnownAlignment(MI->getDest(), DL);
+  unsigned Alignment = getKnownAlignment(MI->getDest(), DL, AT, MI, DT);
   if (MI->getAlignment() < Alignment) {
     MI->setAlignment(ConstantInt::get(MI->getAlignmentType(),
                                       Alignment, false));
@@ -322,7 +323,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       uint32_t BitWidth = IT->getBitWidth();
       APInt KnownZero(BitWidth, 0);
       APInt KnownOne(BitWidth, 0);
-      computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne);
+      computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne, 0, II);
       unsigned TrailingZeros = KnownOne.countTrailingZeros();
       APInt Mask(APInt::getLowBitsSet(BitWidth, TrailingZeros));
       if ((Mask & KnownZero) == Mask)
@@ -340,7 +341,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       uint32_t BitWidth = IT->getBitWidth();
       APInt KnownZero(BitWidth, 0);
       APInt KnownOne(BitWidth, 0);
-      computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne);
+      computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne, 0, II);
       unsigned LeadingZeros = KnownOne.countLeadingZeros();
       APInt Mask(APInt::getHighBitsSet(BitWidth, LeadingZeros));
       if ((Mask & KnownZero) == Mask)
@@ -355,14 +356,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     uint32_t BitWidth = IT->getBitWidth();
     APInt LHSKnownZero(BitWidth, 0);
     APInt LHSKnownOne(BitWidth, 0);
-    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne);
+    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, II);
     bool LHSKnownNegative = LHSKnownOne[BitWidth - 1];
     bool LHSKnownPositive = LHSKnownZero[BitWidth - 1];
 
     if (LHSKnownNegative || LHSKnownPositive) {
       APInt RHSKnownZero(BitWidth, 0);
       APInt RHSKnownOne(BitWidth, 0);
-      computeKnownBits(RHS, RHSKnownZero, RHSKnownOne);
+      computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, II);
       bool RHSKnownNegative = RHSKnownOne[BitWidth - 1];
       bool RHSKnownPositive = RHSKnownZero[BitWidth - 1];
       if (LHSKnownNegative && RHSKnownNegative) {
@@ -426,7 +427,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     // can prove that it will never overflow.
     if (II->getIntrinsicID() == Intrinsic::sadd_with_overflow) {
       Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1);
-      if (WillNotOverflowSignedAdd(LHS, RHS)) {
+      if (WillNotOverflowSignedAdd(LHS, RHS, II)) {
         Value *Add = Builder->CreateNSWAdd(LHS, RHS);
         Add->takeName(&CI);
         Constant *V[] = {UndefValue::get(Add->getType()), Builder->getFalse()};
@@ -464,10 +465,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
 
     APInt LHSKnownZero(BitWidth, 0);
     APInt LHSKnownOne(BitWidth, 0);
-    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne);
+    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, II);
     APInt RHSKnownZero(BitWidth, 0);
     APInt RHSKnownOne(BitWidth, 0);
-    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne);
+    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, II);
 
     // Get the largest possible values for each operand.
     APInt LHSMax = ~LHSKnownZero;
@@ -518,30 +519,131 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       }
     }
     break;
+  case Intrinsic::minnum:
+  case Intrinsic::maxnum: {
+    Value *Arg0 = II->getArgOperand(0);
+    Value *Arg1 = II->getArgOperand(1);
+
+    // fmin(x, x) -> x
+    if (Arg0 == Arg1)
+      return ReplaceInstUsesWith(CI, Arg0);
+
+    const ConstantFP *C0 = dyn_cast<ConstantFP>(Arg0);
+    const ConstantFP *C1 = dyn_cast<ConstantFP>(Arg1);
+
+    // Canonicalize constants into the RHS.
+    if (C0 && !C1) {
+      II->setArgOperand(0, Arg1);
+      II->setArgOperand(1, Arg0);
+      return II;
+    }
+
+    // fmin(x, nan) -> x
+    if (C1 && C1->isNaN())
+      return ReplaceInstUsesWith(CI, Arg0);
+
+    // This is the value because if undef were NaN, we would return the other
+    // value and cannot return a NaN unless both operands are.
+    //
+    // fmin(undef, x) -> x
+    if (isa<UndefValue>(Arg0))
+      return ReplaceInstUsesWith(CI, Arg1);
+
+    // fmin(x, undef) -> x
+    if (isa<UndefValue>(Arg1))
+      return ReplaceInstUsesWith(CI, Arg0);
+
+    Value *X = nullptr;
+    Value *Y = nullptr;
+    if (II->getIntrinsicID() == Intrinsic::minnum) {
+      // fmin(x, fmin(x, y)) -> fmin(x, y)
+      // fmin(y, fmin(x, y)) -> fmin(x, y)
+      if (match(Arg1, m_FMin(m_Value(X), m_Value(Y)))) {
+        if (Arg0 == X || Arg0 == Y)
+          return ReplaceInstUsesWith(CI, Arg1);
+      }
+
+      // fmin(fmin(x, y), x) -> fmin(x, y)
+      // fmin(fmin(x, y), y) -> fmin(x, y)
+      if (match(Arg0, m_FMin(m_Value(X), m_Value(Y)))) {
+        if (Arg1 == X || Arg1 == Y)
+          return ReplaceInstUsesWith(CI, Arg0);
+      }
+
+      // TODO: fmin(nnan x, inf) -> x
+      // TODO: fmin(nnan ninf x, flt_max) -> x
+      if (C1 && C1->isInfinity()) {
+        // fmin(x, -inf) -> -inf
+        if (C1->isNegative())
+          return ReplaceInstUsesWith(CI, Arg1);
+      }
+    } else {
+      assert(II->getIntrinsicID() == Intrinsic::maxnum);
+      // fmax(x, fmax(x, y)) -> fmax(x, y)
+      // fmax(y, fmax(x, y)) -> fmax(x, y)
+      if (match(Arg1, m_FMax(m_Value(X), m_Value(Y)))) {
+        if (Arg0 == X || Arg0 == Y)
+          return ReplaceInstUsesWith(CI, Arg1);
+      }
+
+      // fmax(fmax(x, y), x) -> fmax(x, y)
+      // fmax(fmax(x, y), y) -> fmax(x, y)
+      if (match(Arg0, m_FMax(m_Value(X), m_Value(Y)))) {
+        if (Arg1 == X || Arg1 == Y)
+          return ReplaceInstUsesWith(CI, Arg0);
+      }
+
+      // TODO: fmax(nnan x, -inf) -> x
+      // TODO: fmax(nnan ninf x, -flt_max) -> x
+      if (C1 && C1->isInfinity()) {
+        // fmax(x, inf) -> inf
+        if (!C1->isNegative())
+          return ReplaceInstUsesWith(CI, Arg1);
+      }
+    }
+    break;
+  }
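These folds follow the libm fmin/fmax contract that llvm.minnum/llvm.maxnum are defined to match: when exactly one operand is NaN, the other operand is returned, and the operations are idempotent. A small demonstration against the C library functions (standalone, assumes IEEE-754 doubles):

#include <cassert>
#include <cmath>

int main() {
  double x = 1.5, y = -2.5;
  double nan = std::nan("");

  assert(std::fmin(x, x) == x);                 // fmin(x, x) -> x
  assert(std::fmin(x, nan) == x);               // fmin(x, nan) -> x
  assert(std::fmin(x, -INFINITY) == -INFINITY); // fmin(x, -inf) -> -inf
  assert(std::fmax(x, INFINITY) == INFINITY);   // fmax(x, inf) -> inf

  // Reassociation folds: a repeated operand is redundant.
  assert(std::fmin(x, std::fmin(x, y)) == std::fmin(x, y));
  assert(std::fmax(y, std::fmax(x, y)) == std::fmax(x, y));
  return 0;
}

The undef folds rely on the same NaN rule: choosing undef to be NaN forces the result to be the other operand, so the other operand is always a valid result.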
   case Intrinsic::ppc_altivec_lvx:
   case Intrinsic::ppc_altivec_lvxl:
     // Turn PPC lvx -> load if the pointer is known aligned.
-    if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL) >= 16) {
+    if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16,
+                                   DL, AT, II, DT) >= 16) {
       Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0),
                                          PointerType::getUnqual(II->getType()));
       return new LoadInst(Ptr);
     }
     break;
+  case Intrinsic::ppc_vsx_lxvw4x:
+  case Intrinsic::ppc_vsx_lxvd2x: {
+    // Turn PPC VSX loads into normal loads.
+    Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0),
+                                        PointerType::getUnqual(II->getType()));
+    return new LoadInst(Ptr, Twine(""), false, 1);
+  }
   case Intrinsic::ppc_altivec_stvx:
   case Intrinsic::ppc_altivec_stvxl:
     // Turn stvx -> store if the pointer is known aligned.
-    if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL) >= 16) {
+    if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16,
+                                   DL, AT, II, DT) >= 16) {
       Type *OpPtrTy =
         PointerType::getUnqual(II->getArgOperand(0)->getType());
       Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy);
       return new StoreInst(II->getArgOperand(0), Ptr);
     }
     break;
+  case Intrinsic::ppc_vsx_stxvw4x:
+  case Intrinsic::ppc_vsx_stxvd2x: {
+    // Turn PPC VSX stores into normal stores.
+    Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType());
+    Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy);
+    return new StoreInst(II->getArgOperand(0), Ptr, false, 1);
+  }
   case Intrinsic::x86_sse_storeu_ps:
   case Intrinsic::x86_sse2_storeu_pd:
   case Intrinsic::x86_sse2_storeu_dq:
     // Turn X86 storeu -> store if the pointer is known aligned.
-    if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL) >= 16) {
+    if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16,
+                                   DL, AT, II, DT) >= 16) {
       Type *OpPtrTy =
         PointerType::getUnqual(II->getArgOperand(1)->getType());
       Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), OpPtrTy);
@@ -680,7 +782,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
           CI, Builder->CreateShuffleVector(
                   Vec, Undef, ConstantDataVector::get(
-                                  II->getContext(), ArrayRef<uint32_t>(Mask))));
+                                  II->getContext(), makeArrayRef(Mask))));
     } else if (auto Source = dyn_cast<IntrinsicInst>(II->getArgOperand(0))) {
@@ -886,7 +988,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
   case Intrinsic::arm_neon_vst2lane:
   case Intrinsic::arm_neon_vst3lane:
   case Intrinsic::arm_neon_vst4lane: {
-    unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), DL);
+    unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), DL, AT, II, DT);
     unsigned AlignArg = II->getNumArgOperands() - 1;
     ConstantInt *IntrAlign = dyn_cast<ConstantInt>(II->getArgOperand(AlignArg));
     if (IntrAlign && IntrAlign->getZExtValue() < MemAlign) {
@@ -994,6 +1096,55 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       return EraseInstFromFunction(CI);
     break;
   }
+  case Intrinsic::assume: {
+    // Canonicalize assume(a && b) -> assume(a); assume(b);
+    // Note: New assumption intrinsics created here are registered by
+    // the InstCombineIRInserter object.
+    Value *IIOperand = II->getArgOperand(0), *A, *B,
+          *AssumeIntrinsic = II->getCalledValue();
+    if (match(IIOperand, m_And(m_Value(A), m_Value(B)))) {
+      Builder->CreateCall(AssumeIntrinsic, A, II->getName());
+      Builder->CreateCall(AssumeIntrinsic, B, II->getName());
+      return EraseInstFromFunction(*II);
+    }
+    // assume(!(a || b)) -> assume(!a); assume(!b);
+    if (match(IIOperand, m_Not(m_Or(m_Value(A), m_Value(B))))) {
+      Builder->CreateCall(AssumeIntrinsic, Builder->CreateNot(A),
+                          II->getName());
+      Builder->CreateCall(AssumeIntrinsic, Builder->CreateNot(B),
+                          II->getName());
+      return EraseInstFromFunction(*II);
+    }
+
+    // assume( (load addr) != null ) -> add 'nonnull' metadata to load
+    // (if assume is valid at the load)
+    if (ICmpInst* ICmp = dyn_cast<ICmpInst>(IIOperand)) {
+      Value *LHS = ICmp->getOperand(0);
+      Value *RHS = ICmp->getOperand(1);
+      if (ICmpInst::ICMP_NE == ICmp->getPredicate() &&
+          isa<LoadInst>(LHS) &&
+          isa<Constant>(RHS) &&
+          RHS->getType()->isPointerTy() &&
+          cast<Constant>(RHS)->isNullValue()) {
+        LoadInst* LI = cast<LoadInst>(LHS);
+        if (isValidAssumeForContext(II, LI, DL, DT)) {
+          MDNode* MD = MDNode::get(II->getContext(), ArrayRef<Value*>());
+          LI->setMetadata(LLVMContext::MD_nonnull, MD);
+          return EraseInstFromFunction(*II);
+        }
+      }
+      // TODO: apply nonnull return attributes to calls and invokes
+      // TODO: apply range metadata for range check patterns?
+    }
+
+    // If there is a dominating assume with the same condition as this one,
+    // then this one is redundant, and should be removed.
+    APInt KnownZero(1, 0), KnownOne(1, 0);
+    computeKnownBits(IIOperand, KnownZero, KnownOne, 0, II);
+    if (KnownOne.isAllOnesValue())
+      return EraseInstFromFunction(*II);
+
+    break;
+  }
   }
 
   return visitCallSite(II);
@@ -1253,7 +1404,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
     if (!Caller->use_empty() &&
         // void -> non-void is handled specially
         !NewRetTy->isVoidTy())
-      return false;   // Cannot transform this return value.
+      return false; // Cannot transform this return value.
} if (!CallerPAL.isEmpty() && !Caller->use_empty()) { @@ -1472,8 +1623,14 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { if (!Caller->use_empty()) ReplaceInstUsesWith(*Caller, NV); - else if (Caller->hasValueHandle()) - ValueHandleBase::ValueIsRAUWd(Caller, NV); + else if (Caller->hasValueHandle()) { + if (OldRetTy == NV->getType()) + ValueHandleBase::ValueIsRAUWd(Caller, NV); + else + // We cannot call ValueIsRAUWd with a different type, and the + // actual tracked value will disappear. + ValueHandleBase::ValueIsDeleted(Caller); + } EraseInstFromFunction(*Caller); return true; diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index ff083d7..aba77bb 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -335,7 +335,8 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { /// /// This function works on both vectors and scalars. /// -static bool CanEvaluateTruncated(Value *V, Type *Ty) { +static bool CanEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, + Instruction *CxtI) { // We can always evaluate constants in another type. if (isa<Constant>(V)) return true; @@ -364,8 +365,8 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { case Instruction::Or: case Instruction::Xor: // These operators can all arbitrarily be extended or truncated. - return CanEvaluateTruncated(I->getOperand(0), Ty) && - CanEvaluateTruncated(I->getOperand(1), Ty); + return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + CanEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); case Instruction::UDiv: case Instruction::URem: { @@ -374,10 +375,10 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { uint32_t BitWidth = Ty->getScalarSizeInBits(); if (BitWidth < OrigBitWidth) { APInt Mask = APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth); - if (MaskedValueIsZero(I->getOperand(0), Mask) && - MaskedValueIsZero(I->getOperand(1), Mask)) { - return CanEvaluateTruncated(I->getOperand(0), Ty) && - CanEvaluateTruncated(I->getOperand(1), Ty); + if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, CxtI) && + IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, CxtI)) { + return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + CanEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); } } break; @@ -388,7 +389,7 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { uint32_t BitWidth = Ty->getScalarSizeInBits(); if (CI->getLimitedValue(BitWidth) < BitWidth) - return CanEvaluateTruncated(I->getOperand(0), Ty); + return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } break; case Instruction::LShr: @@ -398,10 +399,10 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (MaskedValueIsZero(I->getOperand(0), - APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) && + if (IC.MaskedValueIsZero(I->getOperand(0), + APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth), 0, CxtI) && CI->getLimitedValue(BitWidth) < BitWidth) { - return CanEvaluateTruncated(I->getOperand(0), Ty); + return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } } break; @@ -415,8 +416,8 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { return true; case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); - 
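// Concrete case for the UDiv/URem rule above (illustrative IR):
//   %q = udiv i32 %x, %y
//   %t = trunc i32 %q to i16
// can evaluate the udiv directly in i16 whenever MaskedValueIsZero proves
// the top 16 bits of both %x and %y zero: a quotient of operands that fit
// in i16 also fits in i16 and is unchanged by computing it narrow.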
return CanEvaluateTruncated(SI->getTrueValue(), Ty) && - CanEvaluateTruncated(SI->getFalseValue(), Ty); + return CanEvaluateTruncated(SI->getTrueValue(), Ty, IC, CxtI) && + CanEvaluateTruncated(SI->getFalseValue(), Ty, IC, CxtI); } case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never @@ -424,7 +425,7 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { // instructions with a single use. PHINode *PN = cast<PHINode>(I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (!CanEvaluateTruncated(PN->getIncomingValue(i), Ty)) + if (!CanEvaluateTruncated(PN->getIncomingValue(i), Ty, IC, CxtI)) return false; return true; } @@ -453,7 +454,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // expression tree to something weird like i93 unless the source is also // strange. if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) && - CanEvaluateTruncated(Src, DestTy)) { + CanEvaluateTruncated(Src, DestTy, *this, &CI)) { // If this cast is a truncate, evaluting in a different type always // eliminates the cast, so it is always a win. @@ -553,7 +554,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, // If Op1C some other power of two, convert: uint32_t BitWidth = Op1C->getType()->getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(ICI->getOperand(0), KnownZero, KnownOne); + computeKnownBits(ICI->getOperand(0), KnownZero, KnownOne, 0, &CI); APInt KnownZeroMask(~KnownZero); if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? @@ -601,8 +602,8 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, APInt KnownZeroLHS(BitWidth, 0), KnownOneLHS(BitWidth, 0); APInt KnownZeroRHS(BitWidth, 0), KnownOneRHS(BitWidth, 0); - computeKnownBits(LHS, KnownZeroLHS, KnownOneLHS); - computeKnownBits(RHS, KnownZeroRHS, KnownOneRHS); + computeKnownBits(LHS, KnownZeroLHS, KnownOneLHS, 0, &CI); + computeKnownBits(RHS, KnownZeroRHS, KnownOneRHS, 0, &CI); if (KnownZeroLHS == KnownZeroRHS && KnownOneLHS == KnownOneRHS) { APInt KnownBits = KnownZeroLHS | KnownOneLHS; @@ -651,7 +652,8 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, /// clear the top bits anyway, doing this has no extra cost. /// /// This function works on both vectors and scalars. -static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { +static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, + InstCombiner &IC, Instruction *CxtI) { BitsToClear = 0; if (isa<Constant>(V)) return true; @@ -680,8 +682,8 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { case Instruction::Add: case Instruction::Sub: case Instruction::Mul: - if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear) || - !CanEvaluateZExtd(I->getOperand(1), Ty, Tmp)) + if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI) || + !CanEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI)) return false; // These can all be promoted if neither operand has 'bits to clear'. if (BitsToClear == 0 && Tmp == 0) @@ -695,8 +697,9 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // We use MaskedValueIsZero here for generality, but the case we care // about the most is constant RHS. 
unsigned VSize = V->getType()->getScalarSizeInBits(); - if (MaskedValueIsZero(I->getOperand(1), - APInt::getHighBitsSet(VSize, BitsToClear))) + if (IC.MaskedValueIsZero(I->getOperand(1), + APInt::getHighBitsSet(VSize, BitsToClear), + 0, CxtI)) return true; } @@ -707,7 +710,7 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // We can promote shl(x, cst) if we can promote x. Since shl overwrites the // upper bits we can reduce BitsToClear by the shift amount. if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) { - if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear)) + if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; uint64_t ShiftAmt = Amt->getZExtValue(); BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0; @@ -718,7 +721,7 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // We can promote lshr(x, cst) if we can promote x. This requires the // ultimate 'and' to clear out the high zero bits we're clearing out though. if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) { - if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear)) + if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; BitsToClear += Amt->getZExtValue(); if (BitsToClear > V->getType()->getScalarSizeInBits()) @@ -728,8 +731,8 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // Cannot promote variable LSHR. return false; case Instruction::Select: - if (!CanEvaluateZExtd(I->getOperand(1), Ty, Tmp) || - !CanEvaluateZExtd(I->getOperand(2), Ty, BitsToClear) || + if (!CanEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) || + !CanEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) || // TODO: If important, we could handle the case when the BitsToClear are // known zero in the disagreeing side. Tmp != BitsToClear) @@ -741,10 +744,10 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // get into trouble with cyclic PHIs here because we only consider // instructions with a single use. PHINode *PN = cast<PHINode>(I); - if (!CanEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear)) + if (!CanEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear, IC, CxtI)) return false; for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i) - if (!CanEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp) || + if (!CanEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp, IC, CxtI) || // TODO: If important, we could handle the case when the BitsToClear // are known zero in the disagreeing input. Tmp != BitsToClear) @@ -781,7 +784,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { // strange. unsigned BitsToClear; if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) && - CanEvaluateZExtd(Src, DestTy, BitsToClear)) { + CanEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) { assert(BitsToClear < SrcTy->getScalarSizeInBits() && "Unreasonable BitsToClear"); @@ -796,8 +799,10 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { // If the high bits are already filled with zeros, just replace this // cast with the result. - if (MaskedValueIsZero(Res, APInt::getHighBitsSet(DestBitSize, - DestBitSize-SrcBitsKept))) + if (MaskedValueIsZero(Res, + APInt::getHighBitsSet(DestBitSize, + DestBitSize-SrcBitsKept), + 0, &CI)) return ReplaceInstUsesWith(CI, Res); // We need to emit an AND to clear the high bits. 
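// Bookkeeping sketch for CanEvaluateZExtd above (illustrative numbers):
// BitsToClear counts how many top source-width bits may be wrong after
// evaluating in the wide type. An lshr by C shifts C more suspect bits in,
// so BitsToClear += C (3 becomes 7 for C == 4); a shl by C overwrites the
// top C bits, so BitsToClear shrinks to max(BitsToClear - C, 0) (3 becomes
// 1 for C == 2). visitZExt then ANDs away the top
// DestBitSize - (SrcBitSize - BitsToClear) bits unless MaskedValueIsZero
// already proves them zero, as in the hunk just above.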
@@ -895,6 +900,10 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { Value *Op0 = ICI->getOperand(0), *Op1 = ICI->getOperand(1); ICmpInst::Predicate Pred = ICI->getPredicate(); + // Don't bother if Op1 isn't of vector or integer type. + if (!Op1->getType()->isIntOrIntVectorTy()) + return nullptr; + if (Constant *Op1C = dyn_cast<Constant>(Op1)) { // (x <s 0) ? -1 : 0 -> ashr x, 31 -> all ones if negative // (x >s -1) ? -1 : 0 -> not (ashr x, 31) -> all ones if positive @@ -921,7 +930,7 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { ICI->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){ unsigned BitWidth = Op1C->getType()->getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(Op0, KnownZero, KnownOne); + computeKnownBits(Op0, KnownZero, KnownOne, 0, &CI); APInt KnownZeroMask(~KnownZero); if (KnownZeroMask.isPowerOf2()) { @@ -1072,7 +1081,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { // If the high bits are already filled with sign bit, just replace this // cast with the result. - if (ComputeNumSignBits(Res) > DestBitSize - SrcBitSize) + if (ComputeNumSignBits(Res, 0, &CI) > DestBitSize - SrcBitSize) return ReplaceInstUsesWith(CI, Res); // We need to emit a shl + ashr to do the sign extend. @@ -1264,10 +1273,12 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { LHSOrig = Builder->CreateFPExt(LHSOrig, RHSOrig->getType()); else if (RHSWidth <= SrcWidth) RHSOrig = Builder->CreateFPExt(RHSOrig, LHSOrig->getType()); - Value *ExactResult = Builder->CreateFRem(LHSOrig, RHSOrig); - if (Instruction *RI = dyn_cast<Instruction>(ExactResult)) - RI->copyFastMathFlags(OpI); - return CastInst::CreateFPCast(ExactResult, CI.getType()); + if (LHSOrig != OpI->getOperand(0) || RHSOrig != OpI->getOperand(1)) { + Value *ExactResult = Builder->CreateFRem(LHSOrig, RHSOrig); + if (Instruction *RI = dyn_cast<Instruction>(ExactResult)) + RI->copyFastMathFlags(OpI); + return CastInst::CreateFPCast(ExactResult, CI.getType()); + } } // (fptrunc (fneg x)) -> (fneg (fptrunc x)) @@ -1312,42 +1323,6 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { } } - // Fold (fptrunc (sqrt (fpext x))) -> (sqrtf x) - // Note that we restrict this transformation based on - // TLI->has(LibFunc::sqrtf), even for the sqrt intrinsic, because - // TLI->has(LibFunc::sqrtf) is sufficient to guarantee that the - // single-precision intrinsic can be expanded in the backend. - CallInst *Call = dyn_cast<CallInst>(CI.getOperand(0)); - if (Call && Call->getCalledFunction() && TLI->has(LibFunc::sqrtf) && - (Call->getCalledFunction()->getName() == TLI->getName(LibFunc::sqrt) || - Call->getCalledFunction()->getIntrinsicID() == Intrinsic::sqrt) && - Call->getNumArgOperands() == 1 && - Call->hasOneUse()) { - CastInst *Arg = dyn_cast<CastInst>(Call->getArgOperand(0)); - if (Arg && Arg->getOpcode() == Instruction::FPExt && - CI.getType()->isFloatTy() && - Call->getType()->isDoubleTy() && - Arg->getType()->isDoubleTy() && - Arg->getOperand(0)->getType()->isFloatTy()) { - Function *Callee = Call->getCalledFunction(); - Module *M = CI.getParent()->getParent()->getParent(); - Constant *SqrtfFunc = (Callee->getIntrinsicID() == Intrinsic::sqrt) ? 
- Intrinsic::getDeclaration(M, Intrinsic::sqrt, Builder->getFloatTy()) : - M->getOrInsertFunction("sqrtf", Callee->getAttributes(), - Builder->getFloatTy(), Builder->getFloatTy(), - NULL); - CallInst *ret = CallInst::Create(SqrtfFunc, Arg->getOperand(0), - "sqrtfcall"); - ret->setAttributes(Callee->getAttributes()); - - - // Remove the old Call. With -fmath-errno, it won't get marked readnone. - ReplaceInstUsesWith(*Call, UndefValue::get(Call->getType())); - EraseInstFromFunction(*Call); - return ret; - } - } - return nullptr; } @@ -1909,9 +1884,9 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { } Instruction *InstCombiner::visitAddrSpaceCast(AddrSpaceCastInst &CI) { - // If the destination pointer element type is not the the same as the source's - // do the addrspacecast to the same type, and then the bitcast in the new - // address space. This allows the cast to be exposed to other transforms. + // If the destination pointer element type is not the same as the source's + // first do a bitcast to the destination type, and then the addrspacecast. + // This allows the cast to be exposed to other transforms. Value *Src = CI.getOperand(0); PointerType *SrcTy = cast<PointerType>(Src->getType()->getScalarType()); PointerType *DestTy = cast<PointerType>(CI.getType()->getScalarType()); diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 5e71c5c..399f1c3 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -740,21 +740,6 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, Value *X, ConstantInt *CI, ICmpInst::Predicate Pred) { - // If we have X+0, exit early (simplifying logic below) and let it get folded - // elsewhere. icmp X+0, X -> icmp X, X - if (CI->isZero()) { - bool isTrue = ICmpInst::isTrueWhenEqual(Pred); - return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue)); - } - - // (X+4) == X -> false. - if (Pred == ICmpInst::ICMP_EQ) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); - - // (X+4) != X -> true. - if (Pred == ICmpInst::ICMP_NE) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); - // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. Similarly for all other "or equals" // operators. @@ -1044,6 +1029,111 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, return nullptr; } +/// FoldICmpCstShrCst - Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" -> +/// (icmp eq/ne A, Log2(const2/const1)) -> +/// (icmp eq/ne A, Log2(const2) - Log2(const1)). +Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, + ConstantInt *CI2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getConstant = [&I, this](bool IsTrue) { + if (I.getPredicate() == I.ICMP_NE) + IsTrue = !IsTrue; + return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + }; + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + APInt AP1 = CI1->getValue(); + APInt AP2 = CI2->getValue(); + + // Don't bother doing any work for cases which InstSimplify handles. 
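// Standalone arithmetic check of the shr fold documented above (plain C++
// sketch, assuming 32-bit unsigned int; not part of the patch). With
// const2 = 64 and const1 = 8 the compare folds to (icmp eq A, 3); with
// const1 = 5 it folds to false outright, since 5 is not 64 >> k for any k.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t A = 0; A < 32; ++A) {         // A >= 32 would be poison in IR
    assert(((64u >> A) == 8u) == (A == 3u));  // icmp eq (lshr 64, A), 8
    assert((64u >> A) != 5u);                 // icmp eq (lshr 64, A), 5
  }
  return 0;
}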
+ if (AP2 == 0) + return nullptr; + bool IsAShr = isa<AShrOperator>(Op); + if (IsAShr) { + if (AP2.isAllOnesValue()) + return nullptr; + if (AP2.isNegative() != AP1.isNegative()) + return nullptr; + if (AP2.sgt(AP1)) + return nullptr; + } + + if (!AP1) + // 'A' must be large enough to shift out the highest set bit. + return getICmp(I.ICMP_UGT, A, + ConstantInt::get(A->getType(), AP2.logBase2())); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + // Get the distance between the highest bit that's set. + int Shift; + // Both the constants are negative, take their positive to calculate log. + if (IsAShr && AP1.isNegative()) + // Get the ones' complement of AP2 and AP1 when computing the distance. + Shift = (~AP2).logBase2() - (~AP1).logBase2(); + else + Shift = AP2.logBase2() - AP1.logBase2(); + + if (Shift > 0) { + if (IsAShr ? AP1 == AP2.ashr(Shift) : AP1 == AP2.lshr(Shift)) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + } + // Shifting const2 will never be equal to const1. + return getConstant(false); +} + +/// FoldICmpCstShlCst - Handle "(icmp eq/ne (shl const2, A), const1)" -> +/// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)). +Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, + ConstantInt *CI2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getConstant = [&I, this](bool IsTrue) { + if (I.getPredicate() == I.ICMP_NE) + IsTrue = !IsTrue; + return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + }; + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + APInt AP1 = CI1->getValue(); + APInt AP2 = CI2->getValue(); + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + + unsigned AP2TrailingZeros = AP2.countTrailingZeros(); + + if (!AP1 && AP2TrailingZeros != 0) + return getICmp(I.ICMP_UGE, A, + ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros)); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + // Get the distance between the lowest bits that are set. + int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + + if (Shift > 0 && AP2.shl(Shift) == AP1) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + + // Shifting const2 will never be equal to const1. + return getConstant(false); +} /// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)". /// @@ -1060,7 +1150,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(), SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits(); APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); - computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne); + computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI); // If all the high bits are known, we can do this xform. 
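// Worked instance of the shl fold above: for (icmp eq (shl i32 4, %A), 32),
// countTrailingZeros(32) - countTrailingZeros(4) == 5 - 2 == 3 and
// 4 << 3 == 32, so the compare becomes (icmp eq %A, 3). Compared against 0,
// the result is (icmp uge %A, 32 - countTrailingZeros(4)), i.e.
// (icmp uge %A, 30), since shifting 4 left by 30 or more moves its single
// set bit out the top.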
if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { @@ -1282,6 +1372,48 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, return &ICI; } + // (icmp pred (and (or (lshr X, Y), X), 1), 0) --> + // (icmp pred (and X, (or (shl 1, Y), 1), 0)) + // + // iff pred isn't signed + { + Value *X, *Y, *LShr; + if (!ICI.isSigned() && RHSV == 0) { + if (match(LHSI->getOperand(1), m_One())) { + Constant *One = cast<Constant>(LHSI->getOperand(1)); + Value *Or = LHSI->getOperand(0); + if (match(Or, m_Or(m_Value(LShr), m_Value(X))) && + match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) { + unsigned UsesRemoved = 0; + if (LHSI->hasOneUse()) + ++UsesRemoved; + if (Or->hasOneUse()) + ++UsesRemoved; + if (LShr->hasOneUse()) + ++UsesRemoved; + Value *NewOr = nullptr; + // Compute X & ((1 << Y) | 1) + if (auto *C = dyn_cast<Constant>(Y)) { + if (UsesRemoved >= 1) + NewOr = + ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); + } else { + if (UsesRemoved >= 3) + NewOr = Builder->CreateOr(Builder->CreateShl(One, Y, + LShr->getName(), + /*HasNUW=*/true), + One, Or->getName()); + } + if (NewOr) { + Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName()); + ICI.setOperand(0, NewAnd); + return &ICI; + } + } + } + } + } + // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any // bit set in (X & AndCst) will produce a result greater than RHSV. if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { @@ -1377,16 +1509,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, unsigned RHSLog2 = RHSV.logBase2(); // (1 << X) >= 2147483648 -> X >= 31 -> X == 31 - // (1 << X) > 2147483648 -> X > 31 -> false - // (1 << X) <= 2147483648 -> X <= 31 -> true // (1 << X) < 2147483648 -> X < 31 -> X != 31 if (RHSLog2 == TypeBits-1) { if (Pred == ICmpInst::ICMP_UGE) Pred = ICmpInst::ICMP_EQ; - else if (Pred == ICmpInst::ICMP_UGT) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); - else if (Pred == ICmpInst::ICMP_ULE) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); else if (Pred == ICmpInst::ICMP_ULT) Pred = ICmpInst::ICMP_NE; } @@ -1421,10 +1547,6 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (RHSVIsPowerOf2) return new ICmpInst( Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2())); - - return ReplaceInstUsesWith( - ICI, Pred == ICmpInst::ICMP_EQ ? Builder->getFalse() - : Builder->getTrue()); } } break; @@ -1932,8 +2054,8 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // sign-extended; check for that condition. For example, if CI2 is 2^31 and // the operands of the add are 64 bits wide, we need at least 33 sign bits. 
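// Why the or/lshr rewrite above is sound: bit 0 of ((X >> Y) | X) is
// (bit Y of X) | (bit 0 of X), so an unsigned compare of it against 0 is
// the same as testing (and X, (or (shl 1, Y), 1)) against 0. For example,
// X = 0b1000, Y = 3:
//   ((X >> 3) | X) & 1   == (0b0001 | 0b1000) & 1 == 1
//   X & ((1 << 3) | 1)   == 0b1000 & 0b1001      == 0b1000 (nonzero)
// The UsesRemoved count just keeps the rewrite from increasing the
// instruction count when the and/or/lshr chain has other users.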
unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; - if (IC.ComputeNumSignBits(A) < NeededSignBits || - IC.ComputeNumSignBits(B) < NeededSignBits) + if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || + IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) return nullptr; // In order to replace the original add with a narrower @@ -2038,8 +2160,8 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, Instruction *MulInstr = cast<Instruction>(MulVal); assert(MulInstr->getOpcode() == Instruction::Mul); - Instruction *LHS = cast<Instruction>(MulInstr->getOperand(0)), - *RHS = cast<Instruction>(MulInstr->getOperand(1)); + auto *LHS = cast<ZExtOperator>(MulInstr->getOperand(0)), + *RHS = cast<ZExtOperator>(MulInstr->getOperand(1)); assert(LHS->getOpcode() == Instruction::ZExt); assert(RHS->getOpcode() == Instruction::ZExt); Value *A = LHS->getOperand(0), *B = RHS->getOperand(0); @@ -2341,7 +2463,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Changed = true; } - if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL)) + if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // comparing -val or val with non-zero is the same as just comparing val @@ -2469,6 +2591,21 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Builder->getInt(CI->getValue()-1)); } + if (I.isEquality()) { + ConstantInt *CI2; + if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) || + match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) { + // (icmp eq/ne (ashr/lshr const2, A), const1) + if (Instruction *Inst = FoldICmpCstShrCst(I, Op0, A, CI, CI2)) + return Inst; + } + if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) { + // (icmp eq/ne (shl const2, A), const1) + if (Instruction *Inst = FoldICmpCstShlCst(I, Op0, A, CI, CI2)) + return Inst; + } + } + // If this comparison is a normal comparison, it demands all // bits, if it is a sign bit comparison, it only demands the sign bit. bool UnusedBit; @@ -2878,6 +3015,12 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (BO1 && BO1->getOpcode() == Instruction::Add) C = BO1->getOperand(0), D = BO1->getOperand(1); + // icmp (X+cst) < 0 --> X < -cst + if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) + if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B)) + if (!RHSC->isMinValue(/*isSigned=*/true)) + return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC)); + // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. if ((A == Op1 || B == Op1) && NoOp0WrapProblem) return new ICmpInst(Pred, A == Op1 ? B : A, @@ -3112,7 +3255,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // and (A & ~B) != 0 --> (A & B) == 0 // if A is a power of 2. if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(A) && I.isEquality()) + match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(A, false, + 0, AT, &I, DT) && + I.isEquality()) return new ICmpInst(I.getInversePredicate(), Builder->CreateAnd(A, B), Op1); @@ -3273,6 +3418,22 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } } + // The 'cmpxchg' instruction returns an aggregate containing the old value and + // an i1 which indicates whether or not we successfully did the swap. + // + // Replace comparisons between the old value and the expected value with the + // indicator that 'cmpxchg' returns. + // + // N.B. This transform is only valid when the 'cmpxchg' is not permitted to + // spuriously fail. 
In those cases, the old value may equal the expected + // value but it is possible for the swap to not occur. + if (I.getPredicate() == ICmpInst::ICMP_EQ) + if (auto *EVI = dyn_cast<ExtractValueInst>(Op0)) + if (auto *ACXI = dyn_cast<AtomicCmpXchgInst>(EVI->getAggregateOperand())) + if (EVI->getIndices()[0] == 0 && ACXI->getCompareOperand() == Op1 && + !ACXI->isWeak()) + return ExtractValueInst::Create(ACXI, 1); + { Value *X; ConstantInt *Cst; // icmp X+Cst, X @@ -3502,7 +3663,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL)) + if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // Simplify 'fcmp pred X, X' diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index c10e92a..f3ac44c 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/Loads.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" @@ -268,7 +269,8 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { SmallVector<Instruction *, 4> ToDelete; if (MemTransferInst *Copy = isOnlyCopiedFromConstantGlobal(&AI, ToDelete)) { unsigned SourceAlign = getOrEnforceKnownAlignment(Copy->getSource(), - AI.getAlignment(), DL); + AI.getAlignment(), + DL, AT, &AI, DT); if (AI.getAlignment() <= SourceAlign) { DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); @@ -290,80 +292,112 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { return visitAllocSite(AI); } +/// \brief Helper to combine a load to a new type. +/// +/// This just does the work of combining a load to a new type. It handles +/// metadata, etc., and returns the new instruction. The \c NewTy should be the +/// loaded *value* type. This will convert it to a pointer, cast the operand to +/// that pointer type, load it, etc. +/// +/// Note that this will create all of the instructions with whatever insert +/// point the \c InstCombiner currently is using. +static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewTy) { + Value *Ptr = LI.getPointerOperand(); + unsigned AS = LI.getPointerAddressSpace(); + SmallVector<std::pair<unsigned, MDNode *>, 8> MD; + LI.getAllMetadata(MD); + + LoadInst *NewLoad = IC.Builder->CreateAlignedLoad( + IC.Builder->CreateBitCast(Ptr, NewTy->getPointerTo(AS)), + LI.getAlignment(), LI.getName()); + for (const auto &MDPair : MD) { + unsigned ID = MDPair.first; + MDNode *N = MDPair.second; + // Note, essentially every kind of metadata should be preserved here! This + // routine is supposed to clone a load instruction changing *only its type*. + // The only metadata it makes sense to drop is metadata which is invalidated + // when the pointer type changes. This should essentially never be the case + // in LLVM, but we explicitly switch over only known metadata to be + // conservatively correct. If you are adding metadata to LLVM which pertains + // to loads, you almost certainly want to add it here. 
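// IR sketch of the cmpxchg fold above (syntax approximate; not part of the
// patch):
//
//   %pair = cmpxchg i32* %p, i32 %expected, i32 %new seq_cst seq_cst
//   %old  = extractvalue { i32, i1 } %pair, 0
//   %eq   = icmp eq i32 %old, %expected
// becomes, for a strong cmpxchg only:
//   %eq   = extractvalue { i32, i1 } %pair, 1
//
// A weak cmpxchg may fail spuriously even though %old == %expected, which
// is exactly what the !ACXI->isWeak() guard rules out.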
+ switch (ID) { + case LLVMContext::MD_dbg: + case LLVMContext::MD_tbaa: + case LLVMContext::MD_prof: + case LLVMContext::MD_fpmath: + case LLVMContext::MD_tbaa_struct: + case LLVMContext::MD_invariant_load: + case LLVMContext::MD_alias_scope: + case LLVMContext::MD_noalias: + case LLVMContext::MD_nontemporal: + case LLVMContext::MD_mem_parallel_loop_access: + case LLVMContext::MD_nonnull: + // All of these directly apply. + NewLoad->setMetadata(ID, N); + break; -/// InstCombineLoadCast - Fold 'load (cast P)' -> cast (load P)' when possible. -static Instruction *InstCombineLoadCast(InstCombiner &IC, LoadInst &LI, - const DataLayout *DL) { - User *CI = cast<User>(LI.getOperand(0)); - Value *CastOp = CI->getOperand(0); + case LLVMContext::MD_range: + // FIXME: It would be nice to propagate this in some way, but the type + // conversions make it hard. + break; + } + } + return NewLoad; +} - PointerType *DestTy = cast<PointerType>(CI->getType()); - Type *DestPTy = DestTy->getElementType(); - if (PointerType *SrcTy = dyn_cast<PointerType>(CastOp->getType())) { - - // If the address spaces don't match, don't eliminate the cast. - if (DestTy->getAddressSpace() != SrcTy->getAddressSpace()) - return nullptr; - - Type *SrcPTy = SrcTy->getElementType(); - - if (DestPTy->isIntegerTy() || DestPTy->isPointerTy() || - DestPTy->isVectorTy()) { - // If the source is an array, the code below will not succeed. Check to - // see if a trivial 'gep P, 0, 0' will help matters. Only do this for - // constants. - if (ArrayType *ASrcTy = dyn_cast<ArrayType>(SrcPTy)) - if (Constant *CSrc = dyn_cast<Constant>(CastOp)) - if (ASrcTy->getNumElements() != 0) { - Type *IdxTy = DL - ? DL->getIntPtrType(SrcTy) - : Type::getInt64Ty(SrcTy->getContext()); - Value *Idx = Constant::getNullValue(IdxTy); - Value *Idxs[2] = { Idx, Idx }; - CastOp = ConstantExpr::getGetElementPtr(CSrc, Idxs); - SrcTy = cast<PointerType>(CastOp->getType()); - SrcPTy = SrcTy->getElementType(); - } - - if (IC.getDataLayout() && - (SrcPTy->isIntegerTy() || SrcPTy->isPointerTy() || - SrcPTy->isVectorTy()) && - // Do not allow turning this into a load of an integer, which is then - // casted to a pointer, this pessimizes pointer analysis a lot. - (SrcPTy->isPtrOrPtrVectorTy() == - LI.getType()->isPtrOrPtrVectorTy()) && - IC.getDataLayout()->getTypeSizeInBits(SrcPTy) == - IC.getDataLayout()->getTypeSizeInBits(DestPTy)) { - - // Okay, we are casting from one integer or pointer type to another of - // the same size. Instead of casting the pointer before the load, cast - // the result of the loaded value. - LoadInst *NewLoad = - IC.Builder->CreateLoad(CastOp, LI.isVolatile(), CI->getName()); - NewLoad->setAlignment(LI.getAlignment()); - NewLoad->setAtomic(LI.getOrdering(), LI.getSynchScope()); - // Now cast the result of the load. - PointerType *OldTy = dyn_cast<PointerType>(NewLoad->getType()); - PointerType *NewTy = dyn_cast<PointerType>(LI.getType()); - if (OldTy && NewTy && - OldTy->getAddressSpace() != NewTy->getAddressSpace()) { - return new AddrSpaceCastInst(NewLoad, LI.getType()); - } +/// \brief Combine loads to match the type of value their uses after looking +/// through intervening bitcasts. +/// +/// The core idea here is that if the result of a load is used in an operation, +/// we should load the type most conducive to that operation. For example, when +/// loading an integer and converting that immediately to a pointer, we should +/// instead directly load a pointer. 
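/// For instance (an illustrative sketch of the one fold implemented below):
/// a load whose only use is a value bitcast,
///   %f = load float* %p
///   %i = bitcast float %f to i32
/// becomes a load of the already-cast pointer,
///   %pi = bitcast float* %p to i32*
///   %i  = load i32* %pi
/// with alignment and known-safe metadata carried over by
/// combineLoadToNewType above.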
+/// +/// However, this routine must never change the width of a load or the number of +/// loads as that would introduce a semantic change. This combine is expected to +/// be a semantic no-op which just allows loads to more closely model the types +/// of their consuming operations. +/// +/// Currently, we also refuse to change the precise type used for an atomic load +/// or a volatile load. This is debatable, and might be reasonable to change +/// later. However, it is risky in case some backend or other part of LLVM is +/// relying on the exact type loaded to select appropriate atomic operations. +static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { + // FIXME: We could probably with some care handle both volatile and atomic + // loads here but it isn't clear that this is important. + if (!LI.isSimple()) + return nullptr; - return new BitCastInst(NewLoad, LI.getType()); - } + if (LI.use_empty()) + return nullptr; + + + // Fold away bit casts of the loaded value by loading the desired type. + if (LI.hasOneUse()) + if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) { + LoadInst *NewLoad = combineLoadToNewType(IC, LI, BC->getDestTy()); + BC->replaceAllUsesWith(NewLoad); + IC.EraseInstFromFunction(*BC); + return &LI; } - } + + // FIXME: We should also canonicalize loads of vectors when their elements are + // cast to other types. return nullptr; } Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { Value *Op = LI.getOperand(0); + // Try to canonicalize the loaded type. + if (Instruction *Res = combineLoadToOperationType(*this, LI)) + return Res; + // Attempt to improve the alignment. if (DL) { unsigned KnownAlign = - getOrEnforceKnownAlignment(Op, DL->getPrefTypeAlignment(LI.getType()),DL); + getOrEnforceKnownAlignment(Op, DL->getPrefTypeAlignment(LI.getType()), + DL, AT, &LI, DT); unsigned LoadAlign = LI.getAlignment(); unsigned EffectiveLoadAlign = LoadAlign != 0 ? LoadAlign : DL->getABITypeAlignment(LI.getType()); @@ -374,11 +408,6 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { LI.setAlignment(EffectiveLoadAlign); } - // load (cast X) --> cast (load X) iff safe. - if (isa<CastInst>(Op)) - if (Instruction *Res = InstCombineLoadCast(*this, LI, DL)) - return Res; - // None of the following transforms are legal for volatile/atomic loads. // FIXME: Some of it is okay for atomic loads; needs refactoring. if (!LI.isSimple()) return nullptr; @@ -388,7 +417,8 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // separated by a few arithmetic operations. BasicBlock::iterator BBI = &LI; if (Value *AvailableVal = FindAvailableLoadedValue(Op, LI.getParent(), BBI,6)) - return ReplaceInstUsesWith(LI, AvailableVal); + return ReplaceInstUsesWith( + LI, Builder->CreateBitCast(AvailableVal, LI.getType())); // load(gep null, ...) 
-> unreachable if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) { @@ -417,12 +447,6 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType())); } - // Instcombine load (constantexpr_cast global) -> cast (load global) - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Op)) - if (CE->isCast()) - if (Instruction *Res = InstCombineLoadCast(*this, LI, DL)) - return Res; - if (Op->hasOneUse()) { // Change select and PHI nodes to select values instead of addresses: this // helps alias analysis out a lot, allows many others simplifications, and @@ -473,7 +497,7 @@ static Instruction *InstCombineStoreToCast(InstCombiner &IC, StoreInst &SI) { User *CI = cast<User>(SI.getOperand(1)); Value *CastOp = CI->getOperand(0); - Type *DestPTy = cast<PointerType>(CI->getType())->getElementType(); + Type *DestPTy = CI->getType()->getPointerElementType(); PointerType *SrcTy = dyn_cast<PointerType>(CastOp->getType()); if (!SrcTy) return nullptr; @@ -518,8 +542,7 @@ static Instruction *InstCombineStoreToCast(InstCombiner &IC, StoreInst &SI) { // If the pointers point into different address spaces don't do the // transformation. - if (SrcTy->getAddressSpace() != - cast<PointerType>(CI->getType())->getAddressSpace()) + if (SrcTy->getAddressSpace() != CI->getType()->getPointerAddressSpace()) return nullptr; // If the pointers point to values of different sizes don't do the @@ -602,7 +625,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { if (DL) { unsigned KnownAlign = getOrEnforceKnownAlignment(Ptr, DL->getPrefTypeAlignment(Val->getType()), - DL); + DL, AT, &SI, DT); unsigned StoreAlign = SI.getAlignment(); unsigned EffectiveStoreAlign = StoreAlign != 0 ? StoreAlign : DL->getABITypeAlignment(Val->getType()); @@ -837,12 +860,13 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { InsertNewInstBefore(NewSI, *BBI); NewSI->setDebugLoc(OtherStore->getDebugLoc()); - // If the two stores had the same TBAA tag, preserve it. - if (MDNode *TBAATag = SI.getMetadata(LLVMContext::MD_tbaa)) - if ((TBAATag = MDNode::getMostGenericTBAA(TBAATag, - OtherStore->getMetadata(LLVMContext::MD_tbaa)))) - NewSI->setMetadata(LLVMContext::MD_tbaa, TBAATag); - + // If the two stores had AA tags, merge them. + AAMDNodes AATags; + SI.getAAMetadata(AATags); + if (AATags) { + OtherStore->getAAMetadata(AATags, /* Merge = */ true); + NewSI->setAAMetadata(AATags); + } // Nuke the old stores. EraseInstFromFunction(SI); diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 6c6e7d8..8c48dce 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -25,7 +25,8 @@ using namespace PatternMatch; /// simplifyValueKnownNonZero - The specific integer value is used in a context /// where it is known to be non-zero. If this allows us to simplify the /// computation, do so and return the new operand, otherwise return null. -static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { +static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, + Instruction *CxtI) { // If V has multiple uses, then we would have to do more analysis to determine // if this is safe. For example, the use could be in dynamically unreached // code. @@ -35,22 +36,23 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { // ((1 << A) >>u B) --> (1 << (A-B)) // Because V cannot be zero, we know that B is less than A. 
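// Worked case for the fold above: if ((1 << A) >> B) is known non-zero, the
// single set bit survived the right shift, so B <= A and the value is
// exactly 1 << (A - B). For A == 5, B == 2: (1 << 5) >> 2 == 8 == 1 << 3.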
- Value *A = nullptr, *B = nullptr, *PowerOf2 = nullptr; - if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(PowerOf2), m_Value(A))), - m_Value(B))) && - // The "1" can be any value known to be a power of 2. - isKnownToBeAPowerOfTwo(PowerOf2)) { + Value *A = nullptr, *B = nullptr, *One = nullptr; + if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(One), m_Value(A))), m_Value(B))) && + match(One, m_One())) { A = IC.Builder->CreateSub(A, B); - return IC.Builder->CreateShl(PowerOf2, A); + return IC.Builder->CreateShl(One, A); } // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it // inexact. Similarly for <<. if (BinaryOperator *I = dyn_cast<BinaryOperator>(V)) - if (I->isLogicalShift() && isKnownToBeAPowerOfTwo(I->getOperand(0))) { + if (I->isLogicalShift() && isKnownToBeAPowerOfTwo(I->getOperand(0), false, + 0, IC.getAssumptionTracker(), + CxtI, + IC.getDominatorTree())) { // We know that this is an exact/nuw shift and that the input is a // non-zero context as well. - if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC)) { + if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { I->setOperand(0, V2); MadeChange = true; } @@ -76,25 +78,30 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { /// MultiplyOverflows - True if the multiply can not be expressed in an int /// this size. -static bool MultiplyOverflows(ConstantInt *C1, ConstantInt *C2, bool sign) { - uint32_t W = C1->getBitWidth(); - APInt LHSExt = C1->getValue(), RHSExt = C2->getValue(); - if (sign) { - LHSExt = LHSExt.sext(W * 2); - RHSExt = RHSExt.sext(W * 2); - } else { - LHSExt = LHSExt.zext(W * 2); - RHSExt = RHSExt.zext(W * 2); - } +static bool MultiplyOverflows(const APInt &C1, const APInt &C2, APInt &Product, + bool IsSigned) { + bool Overflow; + if (IsSigned) + Product = C1.smul_ov(C2, Overflow); + else + Product = C1.umul_ov(C2, Overflow); + + return Overflow; +} - APInt MulExt = LHSExt * RHSExt; +/// \brief True if C2 is a multiple of C1. Quotient contains C2/C1. +static bool IsMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, + bool IsSigned) { + assert(C1.getBitWidth() == C2.getBitWidth() && + "Inconsistent width of constants!"); - if (!sign) - return MulExt.ugt(APInt::getLowBitsSet(W * 2, W)); + APInt Remainder(C1.getBitWidth(), /*Val=*/0ULL, IsSigned); + if (IsSigned) + APInt::sdivrem(C1, C2, Quotient, Remainder); + else + APInt::udivrem(C1, C2, Quotient, Remainder); - APInt Min = APInt::getSignedMinValue(W).sext(W * 2); - APInt Max = APInt::getSignedMaxValue(W).sext(W * 2); - return MulExt.slt(Min) || MulExt.sgt(Max); + return Remainder.isMinValue(); } /// \brief A helper routine of InstCombiner::visitMul(). 
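// Note on IsMultiple above: relative to its own parameter names it returns
// true when the second argument evenly divides the first, leaving
// first/second in Quotient; its doc comment is phrased in terms of the call
// sites below, which pass (C2, C1). Plain-number check: IsMultiple(12, 4)
// is true with Quotient == 3, while IsMultiple(4, 12) is false
// (remainder 4).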
@@ -123,7 +130,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyMulInst(Op0, Op1, DL)) + if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (Value *V = SimplifyUsingDistributiveLaws(I)) @@ -155,8 +162,10 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (NewCst) { BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst); - if (I.hasNoSignedWrap()) Shl->setHasNoSignedWrap(); - if (I.hasNoUnsignedWrap()) Shl->setHasNoUnsignedWrap(); + + if (I.hasNoUnsignedWrap()) + Shl->setHasNoUnsignedWrap(); + return Shl; } } @@ -277,9 +286,9 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { APInt Negative2(I.getType()->getPrimitiveSizeInBits(), (uint64_t)-2, true); Value *BoolCast = nullptr, *OtherOp = nullptr; - if (MaskedValueIsZero(Op0, Negative2)) + if (MaskedValueIsZero(Op0, Negative2, 0, &I)) BoolCast = Op0, OtherOp = Op1; - else if (MaskedValueIsZero(Op1, Negative2)) + else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) BoolCast = Op1, OtherOp = Op0; if (BoolCast) { @@ -292,40 +301,32 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { return Changed ? &I : nullptr; } -// -// Detect pattern: -// -// log2(Y*0.5) -// -// And check for corresponding fast math flags -// - +/// Detect pattern log2(Y * 0.5) with corresponding fast math flags. static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) { - - if (!Op->hasOneUse()) - return; - - IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op); - if (!II) - return; - if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra()) - return; - Log2 = II; - - Value *OpLog2Of = II->getArgOperand(0); - if (!OpLog2Of->hasOneUse()) - return; - - Instruction *I = dyn_cast<Instruction>(OpLog2Of); - if (!I) - return; - if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra()) - return; - - if (match(I->getOperand(0), m_SpecificFP(0.5))) - Y = I->getOperand(1); - else if (match(I->getOperand(1), m_SpecificFP(0.5))) - Y = I->getOperand(0); + if (!Op->hasOneUse()) + return; + + IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op); + if (!II) + return; + if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra()) + return; + Log2 = II; + + Value *OpLog2Of = II->getArgOperand(0); + if (!OpLog2Of->hasOneUse()) + return; + + Instruction *I = dyn_cast<Instruction>(OpLog2Of); + if (!I) + return; + if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra()) + return; + + if (match(I->getOperand(0), m_SpecificFP(0.5))) + Y = I->getOperand(1); + else if (match(I->getOperand(1), m_SpecificFP(0.5))) + Y = I->getOperand(0); } static bool isFiniteNonZeroFp(Constant *C) { @@ -440,7 +441,8 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (isa<Constant>(Op0)) std::swap(Op0, Op1); - if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL)) + if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, + DT, AT)) return ReplaceInstUsesWith(I, V); bool AllowReassociate = I.hasUnsafeAlgebra(); @@ -510,10 +512,15 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } } + // sqrt(X) * sqrt(X) -> X + if (AllowReassociate && (Op0 == Op1)) + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) + if (II->getIntrinsicID() == Intrinsic::sqrt) + return ReplaceInstUsesWith(I, II->getOperand(0)); // Under unsafe algebra do: // X * log2(0.5*Y) = X*log2(Y) - X - if (I.hasUnsafeAlgebra()) { + if (AllowReassociate) { Value *OpX = 
nullptr; Value *OpY = nullptr; IntrinsicInst *Log2; @@ -596,36 +603,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } } - // B * (uitofp i1 C) -> select C, B, 0 - if (I.hasNoNaNs() && I.hasNoInfs() && I.hasNoSignedZeros()) { - Value *LHS = Op0, *RHS = Op1; - Value *B, *C; - if (!match(RHS, m_UIToFP(m_Value(C)))) - std::swap(LHS, RHS); - - if (match(RHS, m_UIToFP(m_Value(C))) && - C->getType()->getScalarType()->isIntegerTy(1)) { - B = LHS; - Value *Zero = ConstantFP::getNegativeZero(B->getType()); - return SelectInst::Create(C, B, Zero); - } - } - - // A * (1 - uitofp i1 C) -> select C, 0, A - if (I.hasNoNaNs() && I.hasNoInfs() && I.hasNoSignedZeros()) { - Value *LHS = Op0, *RHS = Op1; - Value *A, *C; - if (!match(RHS, m_FSub(m_FPOne(), m_UIToFP(m_Value(C))))) - std::swap(LHS, RHS); - - if (match(RHS, m_FSub(m_FPOne(), m_UIToFP(m_Value(C)))) && - C->getType()->getScalarType()->isIntegerTy(1)) { - A = LHS; - Value *Zero = ConstantFP::getNegativeZero(A->getType()); - return SelectInst::Create(C, Zero, A); - } - } - if (!isa<Constant>(Op1)) std::swap(Opnd0, Opnd1); else @@ -714,7 +691,7 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // The RHS is known non-zero. - if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this)) { + if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, &I)) { I.setOperand(1, V); return &I; } @@ -724,25 +701,83 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) return &I; - if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { - // (X / C1) / C2 -> X / (C1*C2) - if (Instruction *LHS = dyn_cast<Instruction>(Op0)) - if (Instruction::BinaryOps(LHS->getOpcode()) == I.getOpcode()) - if (ConstantInt *LHSRHS = dyn_cast<ConstantInt>(LHS->getOperand(1))) { - if (MultiplyOverflows(RHS, LHSRHS, - I.getOpcode() == Instruction::SDiv)) - return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); - return BinaryOperator::Create(I.getOpcode(), LHS->getOperand(0), - ConstantExpr::getMul(RHS, LHSRHS)); + if (Instruction *LHS = dyn_cast<Instruction>(Op0)) { + const APInt *C2; + if (match(Op1, m_APInt(C2))) { + Value *X; + const APInt *C1; + bool IsSigned = I.getOpcode() == Instruction::SDiv; + + // (X / C1) / C2 -> X / (C1*C2) + if ((IsSigned && match(LHS, m_SDiv(m_Value(X), m_APInt(C1)))) || + (!IsSigned && match(LHS, m_UDiv(m_Value(X), m_APInt(C1))))) { + APInt Product(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + if (!MultiplyOverflows(*C1, *C2, Product, IsSigned)) + return BinaryOperator::Create(I.getOpcode(), X, + ConstantInt::get(I.getType(), Product)); + } + + if ((IsSigned && match(LHS, m_NSWMul(m_Value(X), m_APInt(C1)))) || + (!IsSigned && match(LHS, m_NUWMul(m_Value(X), m_APInt(C1))))) { + APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + + // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1. + if (IsMultiple(*C2, *C1, Quotient, IsSigned)) { + BinaryOperator *BO = BinaryOperator::Create( + I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient)); + BO->setIsExact(I.isExact()); + return BO; } - if (!RHS->isZero()) { // avoid X udiv 0 - if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + // (X * C1) / C2 -> X * (C1 / C2) if C1 is a multiple of C2. 
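// Numeric instances of the two multiply/divide folds above (a sketch):
//   (X *nsw 4) /s 12  ->  X /s 3    (12 is a multiple of 4)
//   (X *nsw 12) /s 4  ->  X * 3     (12 is a multiple of 4, flags carried)
// The nsw/nuw requirement is what makes cancelling safe: in i8, 64 * 4
// wraps to 0 and 0 /s 12 == 0, yet 64 /s 3 == 21.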
+ if (IsMultiple(*C1, *C2, Quotient, IsSigned)) { + BinaryOperator *BO = BinaryOperator::Create( + Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient)); + BO->setHasNoUnsignedWrap( + !IsSigned && + cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap()); + BO->setHasNoSignedWrap( + cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap()); + return BO; + } + } + + if ((IsSigned && match(LHS, m_NSWShl(m_Value(X), m_APInt(C1))) && + *C1 != C1->getBitWidth() - 1) || + (!IsSigned && match(LHS, m_NUWShl(m_Value(X), m_APInt(C1))))) { + APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + APInt C1Shifted = APInt::getOneBitSet( + C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue())); + + // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of C1. + if (IsMultiple(*C2, C1Shifted, Quotient, IsSigned)) { + BinaryOperator *BO = BinaryOperator::Create( + I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient)); + BO->setIsExact(I.isExact()); + return BO; + } + + // (X << C1) / C2 -> X * (C2 >> C1) if C1 is a multiple of C2. + if (IsMultiple(C1Shifted, *C2, Quotient, IsSigned)) { + BinaryOperator *BO = BinaryOperator::Create( + Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient)); + BO->setHasNoUnsignedWrap( + !IsSigned && + cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap()); + BO->setHasNoSignedWrap( + cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap()); + return BO; + } + } + + if (*C2 != 0) { // avoid X udiv 0 + if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI)) + return R; + if (isa<PHINode>(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } } } @@ -828,7 +863,8 @@ static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, const APInt &C = cast<Constant>(Op1)->getUniqueInteger(); BinaryOperator *LShr = BinaryOperator::CreateLShr( Op0, ConstantInt::get(Op0->getType(), C.logBase2())); - if (I.isExact()) LShr->setIsExact(); + if (I.isExact()) + LShr->setIsExact(); return LShr; } @@ -856,7 +892,8 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, if (ZExtInst *Z = dyn_cast<ZExtInst>(Op1)) N = IC.Builder->CreateZExt(N, Z->getDestTy()); BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N); - if (I.isExact()) LShr->setIsExact(); + if (I.isExact()) + LShr->setIsExact(); return LShr; } @@ -893,10 +930,10 @@ static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I, return 0; if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) - if (size_t LHSIdx = visitUDivOperand(Op0, SI->getOperand(1), I, Actions)) - if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions)) { - Actions.push_back(UDivFoldAction((FoldUDivOperandCb)nullptr, Op1, - LHSIdx-1)); + if (size_t LHSIdx = + visitUDivOperand(Op0, SI->getOperand(1), I, Actions, Depth)) + if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions, Depth)) { + Actions.push_back(UDivFoldAction(nullptr, Op1, LHSIdx - 1)); return Actions.size(); } @@ -909,7 +946,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyUDivInst(Op0, Op1, DL)) + if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // Handle the integer div common cases @@ -917,19 +954,25 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { return Common; // (x lshr C1) udiv C2 --> x udiv (C2 << C1) - if (Constant *C2 = dyn_cast<Constant>(Op1)) { + { Value *X; - Constant *C1; - if (match(Op0, 
m_LShr(m_Value(X), m_Constant(C1)))) - return BinaryOperator::CreateUDiv(X, ConstantExpr::getShl(C2, C1)); + const APInt *C1, *C2; + if (match(Op0, m_LShr(m_Value(X), m_APInt(C1))) && + match(Op1, m_APInt(C2))) { + bool Overflow; + APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow); + if (!Overflow) + return BinaryOperator::CreateUDiv( + X, ConstantInt::get(X->getType(), C2ShlC1)); + } } // (zext A) udiv (zext B) --> zext (A udiv B) if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0)) if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy())) - return new ZExtInst(Builder->CreateUDiv(ZOp0->getOperand(0), ZOp1, "div", - I.isExact()), - I.getType()); + return new ZExtInst( + Builder->CreateUDiv(ZOp0->getOperand(0), ZOp1, "div", I.isExact()), + I.getType()); // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...)))) SmallVector<UDivFoldAction, 6> UDivActions; @@ -971,7 +1014,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifySDivInst(Op0, Op1, DL)) + if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // Handle the integer div common cases @@ -1008,8 +1051,8 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { // unsigned inputs), turn this into a udiv. if (I.getType()->isIntegerTy()) { APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); - if (MaskedValueIsZero(Op0, Mask)) { - if (MaskedValueIsZero(Op1, Mask)) { + if (MaskedValueIsZero(Op0, Mask, 0, &I)) { + if (MaskedValueIsZero(Op1, Mask, 0, &I)) { // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set return BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); } @@ -1034,8 +1077,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { /// If the conversion was successful, the simplified expression "X * 1/C" is /// returned; otherwise, NULL is returned. /// -static Instruction *CvtFDivConstToReciprocal(Value *Dividend, - Constant *Divisor, +static Instruction *CvtFDivConstToReciprocal(Value *Dividend, Constant *Divisor, bool AllowReciprocal) { if (!isa<ConstantFP>(Divisor)) // TODO: handle vectors. return nullptr; @@ -1064,7 +1106,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFDivInst(Op0, Op1, DL)) + if (Value *V = SimplifyFDivInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (isa<Constant>(Op0)) @@ -1195,7 +1237,7 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // The RHS is known non-zero. 
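// Quick standalone check of the lshr/udiv fold above (plain C++ sketch,
// not part of the patch): (x >>u 2) /u 5 == x /u 20 because nested floor
// divisions compose, floor(floor(x/4)/5) == floor(x/20).
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 0; X < 1000000; ++X)
    assert((X >> 2) / 5 == X / 20);
  return 0;
}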
- if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this)) { + if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, &I)) { I.setOperand(1, V); return &I; } @@ -1229,7 +1271,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyURemInst(Op0, Op1, DL)) + if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (Instruction *common = commonIRemTransforms(I)) @@ -1242,7 +1284,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { I.getType()); // X urem Y -> X and Y-1, where Y is a power of 2, - if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true)) { + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true, 0, AT, &I, DT)) { Constant *N1 = Constant::getAllOnesValue(I.getType()); Value *Add = Builder->CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); @@ -1264,28 +1306,29 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifySRemInst(Op0, Op1, DL)) + if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // Handle the integer rem common cases if (Instruction *Common = commonIRemTransforms(I)) return Common; - if (Value *RHSNeg = dyn_castNegVal(Op1)) - if (!isa<Constant>(RHSNeg) || - (isa<ConstantInt>(RHSNeg) && - cast<ConstantInt>(RHSNeg)->getValue().isStrictlyPositive())) { - // X % -Y -> X % Y + { + const APInt *Y; + // X % -Y -> X % Y + if (match(Op1, m_APInt(Y)) && Y->isNegative() && !Y->isMinSignedValue()) { Worklist.AddValue(I.getOperand(1)); - I.setOperand(1, RHSNeg); + I.setOperand(1, ConstantInt::get(I.getType(), -*Y)); return &I; } + } // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a urem. if (I.getType()->isIntegerTy()) { APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); - if (MaskedValueIsZero(Op1, Mask) && MaskedValueIsZero(Op0, Mask)) { + if (MaskedValueIsZero(Op1, Mask, 0, &I) && + MaskedValueIsZero(Op0, Mask, 0, &I)) { // X srem Y -> X urem Y, iff X and Y don't have sign bit set return BinaryOperator::CreateURem(Op0, Op1, I.getName()); } @@ -1338,7 +1381,7 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFRemInst(Op0, Op1, DL)) + if (Value *V = SimplifyFRemInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index 46f7b8a..794263a 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -506,12 +506,12 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { /// DeadPHICycle - Return true if this PHI node is only used by a PHI node cycle /// that is dead. static bool DeadPHICycle(PHINode *PN, - SmallPtrSet<PHINode*, 16> &PotentiallyDeadPHIs) { + SmallPtrSetImpl<PHINode*> &PotentiallyDeadPHIs) { if (PN->use_empty()) return true; if (!PN->hasOneUse()) return false; // Remember this node, and if we find the cycle, return. - if (!PotentiallyDeadPHIs.insert(PN)) + if (!PotentiallyDeadPHIs.insert(PN).second) return true; // Don't scan crazily complex things. @@ -528,9 +528,9 @@ static bool DeadPHICycle(PHINode *PN, /// NonPhiInVal. 
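// The power-of-two urem fold above in numbers: for Y == 8, X %u 8 == X & 7
// for every X. The /*OrZero*/ form stays sound because urem by zero is
// undefined, and the rewritten X & (0 - 1) == X is as valid a result as
// any for that case.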
This happens with mutually cyclic phi nodes like: /// z = some value; x = phi (y, z); y = phi (x, z) static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal, - SmallPtrSet<PHINode*, 16> &ValueEqualPHIs) { + SmallPtrSetImpl<PHINode*> &ValueEqualPHIs) { // See if we already saw this PHI node. - if (!ValueEqualPHIs.insert(PN)) + if (!ValueEqualPHIs.insert(PN).second) return true; // Don't scan crazily complex things. @@ -654,7 +654,7 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // If the user is a PHI, inspect its uses recursively. if (PHINode *UserPN = dyn_cast<PHINode>(UserI)) { - if (PHIsInspected.insert(UserPN)) + if (PHIsInspected.insert(UserPN).second) PHIsToSlice.push_back(UserPN); continue; } @@ -788,7 +788,7 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // PHINode simplification // Instruction *InstCombiner::visitPHINode(PHINode &PN) { - if (Value *V = SimplifyInstruction(&PN, DL, TLI)) + if (Value *V = SimplifyInstruction(&PN, DL, TLI, DT, AT)) return ReplaceInstUsesWith(PN, V); // If all PHI operands are the same operation, pull them through the PHI, diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 06c9e29..079ae34 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -313,7 +313,9 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, /// replaced with RepOp. static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const DataLayout *TD, - const TargetLibraryInfo *TLI) { + const TargetLibraryInfo *TLI, + DominatorTree *DT, + AssumptionTracker *AT) { // Trivial replacement. if (V == Op) return RepOp; @@ -334,10 +336,10 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, if (CmpInst *C = dyn_cast<CmpInst>(I)) { if (C->getOperand(0) == Op) return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD, - TLI); + TLI, DT, AT); if (C->getOperand(1) == Op) return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD, - TLI); + TLI, DT, AT); } // TODO: We could hand off more cases to instsimplify here. @@ -390,11 +392,11 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, /// /// This also tries to turn /// --- Single bit tests: -/// if ((x & C) == 0) x |= C to x |= C -/// if ((x & C) != 0) x ^= C to x &= ~C -/// if ((x & C) == 0) x ^= C to x |= C -/// if ((x & C) != 0) x &= ~C to x &= ~C -/// if ((x & C) == 0) x &= ~C to nothing +/// if ((x & C) == 0) x |= C to x |= C +/// if ((x & C) != 0) x ^= C to x &= ~C +/// if ((x & C) == 0) x ^= C to x |= C +/// if ((x & C) != 0) x &= ~C to x &= ~C +/// if ((x & C) == 0) x &= ~C to nothing static Value *foldSelectICmpAndOr(SelectInst &SI, Value *TrueVal, Value *FalseVal, InstCombiner::BuilderTy *Builder) { @@ -605,18 +607,26 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, // arms of the select. See if substituting this value into the arm and // simplifying the result yields the same value as the other arm. 
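A concrete instance of the substitution rule described above, as a hedged C-level sketch (the function is illustrative, not from the patch):

    // For 'x == 0 ? x + z : z', substituting the known value 0 for x in the
    // true arm simplifies it to z, which equals the false arm, so the whole
    // select collapses to z regardless of which way the comparison goes.
    int substituteExample(int x, int z) {
      return x == 0 ? x + z : z; // instcombine reduces this to: return z;
    }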
if (Pred == ICmpInst::ICMP_EQ) { - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI) == TrueVal + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, + DT, AT) == TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, + DT, AT) == TrueVal) return ReplaceInstUsesWith(SI, FalseVal); - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI) == FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI) == FalseVal + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, + DT, AT) == FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, + DT, AT) == FalseVal) return ReplaceInstUsesWith(SI, FalseVal); } else if (Pred == ICmpInst::ICMP_NE) { - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI) == FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI) == FalseVal + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, + DT, AT) == FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, + DT, AT) == FalseVal) return ReplaceInstUsesWith(SI, TrueVal); - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI) == TrueVal + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, + DT, AT) == TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, + DT, AT) == TrueVal) return ReplaceInstUsesWith(SI, TrueVal); } @@ -825,7 +835,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); - if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, DL)) + if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, TLI, + DT, AT)) return ReplaceInstUsesWith(SI, V); if (SI.getType()->isIntegerTy(1)) { diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 2495747..afa907a 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -68,7 +68,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { /// this succeeds, the GetShiftedValue function will be called to produce the /// value. static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, - InstCombiner &IC) { + InstCombiner &IC, Instruction *CxtI) { // We can always evaluate constants shifted. if (isa<Constant>(V)) return true; @@ -111,8 +111,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, case Instruction::Or: case Instruction::Xor: // Bitwise operators can always be evaluated shifted. - return CanEvaluateShifted(I->getOperand(0), NumBits, isLeftShift, IC) && - CanEvaluateShifted(I->getOperand(1), NumBits, isLeftShift, IC); + return CanEvaluateShifted(I->getOperand(0), NumBits, isLeftShift, IC, I) && + CanEvaluateShifted(I->getOperand(1), NumBits, isLeftShift, IC, I); case Instruction::Shl: { // We can often fold the shift into shifts-by-a-constant. @@ -131,8 +131,9 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // profitable unless we know the and'd out bits are already zero.
if (CI->getZExtValue() > NumBits) { unsigned LowBits = TypeWidth - CI->getZExtValue(); - if (MaskedValueIsZero(I->getOperand(0), - APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits)) + if (IC.MaskedValueIsZero(I->getOperand(0), + APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits, + 0, CxtI)) return true; } @@ -155,8 +156,9 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // profitable unless we know the and'd out bits are already zero. if (CI->getValue().ult(TypeWidth) && CI->getZExtValue() > NumBits) { unsigned LowBits = CI->getZExtValue() - NumBits; - if (MaskedValueIsZero(I->getOperand(0), - APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits)) + if (IC.MaskedValueIsZero(I->getOperand(0), + APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits, + 0, CxtI)) return true; } @@ -164,8 +166,9 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, } case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); - return CanEvaluateShifted(SI->getTrueValue(), NumBits, isLeftShift, IC) && - CanEvaluateShifted(SI->getFalseValue(), NumBits, isLeftShift, IC); + return CanEvaluateShifted(SI->getTrueValue(), NumBits, isLeftShift, + IC, SI) && + CanEvaluateShifted(SI->getFalseValue(), NumBits, isLeftShift, IC, SI); } case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never @@ -173,7 +176,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // instructions with a single use. PHINode *PN = cast<PHINode>(I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (!CanEvaluateShifted(PN->getIncomingValue(i), NumBits, isLeftShift,IC)) + if (!CanEvaluateShifted(PN->getIncomingValue(i), NumBits, isLeftShift, + IC, PN)) return false; return true; } @@ -329,7 +333,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. if (I.getOpcode() != Instruction::AShr && - CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this)) { + CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this, &I)) { DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); @@ -488,7 +492,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, } - // If the operand is an bitwise operator with a constant RHS, and the + // If the operand is a bitwise operator with a constant RHS, and the // shift is the only use, we can pull it out of the shift. if (ConstantInt *Op0C = dyn_cast<ConstantInt>(Op0BO->getOperand(1))) { bool isValid = true; // Valid only for And, Or, Xor @@ -691,7 +695,7 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), - DL)) + DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (Instruction *V = commonShiftTransforms(I)) @@ -703,14 +707,15 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is a NUW shift. if (!I.hasNoUnsignedWrap() && MaskedValueIsZero(I.getOperand(0), - APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt))) { + APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), + 0, &I)) { I.setHasNoUnsignedWrap(); return &I; } // If the shifted out value is all signbits, this is a NSW shift. 
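Both flag inferences here, the nuw case above and the nsw case whose check follows, have simple C-level analogues. A hedged sketch, assuming the usual arithmetic right shift for signed values (function names are illustrative):

    #include <cstdint>

    // The masked operand has its top 8 bits known zero, so shifting left by
    // 8 cannot wrap in the unsigned sense; the shl may be tagged 'nuw'.
    uint32_t shlIsNUW(uint32_t x) { return (x & 0x00FFFFFFu) << 8; }

    // 'x >> 24' leaves at least 25 copies of the sign bit, more than the
    // shift amount of 4, so the left shift preserves the sign ('nsw').
    int32_t shlIsNSW(int32_t x) { return (x >> 24) << 4; }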
if (!I.hasNoSignedWrap() && - ComputeNumSignBits(I.getOperand(0)) > ShAmt) { + ComputeNumSignBits(I.getOperand(0), 0, &I) > ShAmt) { I.setHasNoSignedWrap(); return &I; } @@ -731,7 +736,7 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { return ReplaceInstUsesWith(I, V); if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), - I.isExact(), DL)) + I.isExact(), DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) @@ -760,7 +765,8 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt))){ + MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), + 0, &I)){ I.setIsExact(); return &I; } @@ -774,7 +780,7 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { return ReplaceInstUsesWith(I, V); if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), - I.isExact(), DL)) + I.isExact(), DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) @@ -804,7 +810,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt))){ + MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt), + 0, &I)){ I.setIsExact(); return &I; } @@ -812,13 +819,9 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { // See if we can turn a signed shr into an unsigned shr. if (MaskedValueIsZero(Op0, - APInt::getSignBit(I.getType()->getScalarSizeInBits()))) + APInt::getSignBit(I.getType()->getScalarSizeInBits()), + 0, &I)) return BinaryOperator::CreateLShr(Op0, Op1); - // Arithmetic shifting an all-sign-bit value is a no-op. - unsigned NumSignBits = ComputeNumSignBits(Op0); - if (NumSignBits == Op0->getType()->getScalarSizeInBits()) - return ReplaceInstUsesWith(I, Op0); - return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 1b42d3d..ad6983a 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -43,6 +43,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, // This instruction is producing bits that are not demanded. Shrink the RHS. Demanded &= OpC->getValue(); I->setOperand(OpNo, ConstantInt::get(OpC->getType(), Demanded)); + + // If either 'nsw' or 'nuw' is set and the constant is negative, + // removing *any* bits from the constant could make overflow occur. + // Remove 'nsw' and 'nuw' from the instruction in this case. 
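The hazard the new ShrinkDemandedConstant comment describes can be pinned down with one 8-bit example; a hedged, LLVM-free sketch of the arithmetic:

    #include <cassert>
    #include <cstdint>

    int main() {
      int8_t X = 100;
      // Original: X + 0xFF is X + (-1); no signed 8-bit overflow for X = 100.
      assert(int{X} + int8_t(0xFF) == 99);
      // Shrunk constant: X + 0x7F is X + 127 = 227, which no longer fits in
      // a signed 8-bit result, so 'add nsw' with the shrunk constant would lie.
      assert(int{X} + int8_t(0x7F) == 227);
      return 0;
    }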
+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(I)) { + assert(OBO->getOpcode() == Instruction::Add); + if (OBO->hasNoSignedWrap() || OBO->hasNoUnsignedWrap()) { + if (OpC->getValue().isNegative()) { + cast<BinaryOperator>(OBO)->setHasNoSignedWrap(false); + cast<BinaryOperator>(OBO)->setHasNoUnsignedWrap(false); + } + } + } + return true; } @@ -57,7 +71,7 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { APInt DemandedMask(APInt::getAllOnesValue(BitWidth)); Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, - KnownZero, KnownOne, 0); + KnownZero, KnownOne, 0, &Inst); if (!V) return false; if (V == &Inst) return true; ReplaceInstUsesWith(Inst, V); @@ -71,7 +85,8 @@ bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth) { Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, - KnownZero, KnownOne, Depth); + KnownZero, KnownOne, Depth, + dyn_cast<Instruction>(U.getUser())); if (!NewVal) return false; U = NewVal; return true; @@ -101,7 +116,8 @@ bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask, /// in the context where the specified bits are demanded, but not for all users. Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne, - unsigned Depth) { + unsigned Depth, + Instruction *CxtI) { assert(V != nullptr && "Null pointer of Value???"); assert(Depth <= 6 && "Limit Search Depth"); uint32_t BitWidth = DemandedMask.getBitWidth(); @@ -144,7 +160,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Instruction *I = dyn_cast<Instruction>(V); if (!I) { - computeKnownBits(V, KnownZero, KnownOne, Depth); + computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); return nullptr; // Only analyze instructions. } @@ -158,8 +174,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // this instruction has a simpler value in that context. if (I->getOpcode() == Instruction::And) { // If either the LHS or the RHS are Zero, the result is zero. - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1); + computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1, + CxtI); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1, + CxtI); // If all of the demanded bits are known 1 on one side, return the other. // These bits cannot contribute to the result of the 'and' in this @@ -180,8 +198,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // only bits from X or Y are demanded. // If either the LHS or the RHS are One, the result is One. - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1); + computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1, + CxtI); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1, + CxtI); // If all of the demanded bits are known zero on one side, return the // other. These bits cannot contribute to the result of the 'or' in this @@ -205,8 +225,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // We can simplify (X^Y) -> X or Y in the user's context if we know that // only bits from X or Y are demanded. 
- computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1); + computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1, + CxtI); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1, + CxtI); // If all of the demanded bits are known zero on one side, return the // other. @@ -217,7 +239,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } // Compute the KnownZero/KnownOne bits to simplify things downstream. - computeKnownBits(I, KnownZero, KnownOne, Depth); + computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); return nullptr; } @@ -230,7 +252,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, switch (I->getOpcode()) { default: - computeKnownBits(I, KnownZero, KnownOne, Depth); + computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); break; case Instruction::And: // If either the LHS or the RHS are Zero, the result is zero. @@ -242,6 +264,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + // If the client is only demanding bits that we know, return the known + // constant. + if ((DemandedMask & ((RHSKnownZero | LHSKnownZero)| + (RHSKnownOne & LHSKnownOne))) == DemandedMask) + return Constant::getIntegerValue(VTy, RHSKnownOne & LHSKnownOne); + // If all of the demanded bits are known 1 on one side, return the other. // These bits cannot contribute to the result of the 'and'. if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) == @@ -274,6 +302,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + // If the client is only demanding bits that we know, return the known + // constant. + if ((DemandedMask & ((RHSKnownZero & LHSKnownZero)| + (RHSKnownOne | LHSKnownOne))) == DemandedMask) + return Constant::getIntegerValue(VTy, RHSKnownOne | LHSKnownOne); + // If all of the demanded bits are known zero on one side, return the other. // These bits cannot contribute to the result of the 'or'. if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) == @@ -310,6 +344,18 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + // Output known-0 bits are known if clear or set in both the LHS & RHS. + APInt IKnownZero = (RHSKnownZero & LHSKnownZero) | + (RHSKnownOne & LHSKnownOne); + // Output known-1 bits are known to be set if set in only one of the LHS, RHS. + APInt IKnownOne = (RHSKnownZero & LHSKnownOne) | + (RHSKnownOne & LHSKnownZero); + + // If the client is only demanding bits that we know, return the known + // constant. + if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) + return Constant::getIntegerValue(VTy, IKnownOne); + // If all of the demanded bits are known zero on one side, return the other. // These bits cannot contribute to the result of the 'xor'. if ((DemandedMask & RHSKnownZero) == DemandedMask) @@ -581,7 +627,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Otherwise just hand the sub off to computeKnownBits to fill in // the known zeros and ones.
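The xor formulas added above are worth restating in isolation; a minimal APInt sketch that mirrors the patch's IKnownZero/IKnownOne lines (the helper name is hypothetical):

    #include "llvm/ADT/APInt.h"

    using llvm::APInt;

    // A bit of X ^ Y is known zero when the operand bits are known equal,
    // and known one when they are known to differ.
    static void xorKnownBits(const APInt &LHSKnownZero, const APInt &LHSKnownOne,
                             const APInt &RHSKnownZero, const APInt &RHSKnownOne,
                             APInt &KnownZero, APInt &KnownOne) {
      KnownZero = (RHSKnownZero & LHSKnownZero) | (RHSKnownOne & LHSKnownOne);
      KnownOne = (RHSKnownZero & LHSKnownOne) | (RHSKnownOne & LHSKnownZero);
    }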
- computeKnownBits(V, KnownZero, KnownOne, Depth); + computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known // zero. @@ -752,7 +798,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // remainder is zero. if (DemandedMask.isNegative() && KnownZero.isNonNegative()) { APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1, + CxtI); // If it's known zero, our sign bit is also zero. if (LHSKnownZero.isNegative()) KnownZero.setBit(KnownZero.getBitWidth() - 1); @@ -814,7 +861,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return nullptr; } } - computeKnownBits(V, KnownZero, KnownOne, Depth); + computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); break; } diff --git a/lib/Transforms/InstCombine/InstCombineWorklist.h b/lib/Transforms/InstCombine/InstCombineWorklist.h index 1ab7db3..8d857d0 100644 --- a/lib/Transforms/InstCombine/InstCombineWorklist.h +++ b/lib/Transforms/InstCombine/InstCombineWorklist.h @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#ifndef INSTCOMBINE_WORKLIST_H -#define INSTCOMBINE_WORKLIST_H +#ifndef LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEWORKLIST_H +#define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEWORKLIST_H #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index 08e2446..e4a4fef 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -39,12 +39,16 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/Analysis/AssumptionTracker.h" +#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" @@ -68,11 +72,6 @@ STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumFactor , "Number of factorizations"); STATISTIC(NumReassoc , "Number of reassociations"); -static cl::opt<bool> UnsafeFPShrink("enable-double-float-shrink", cl::Hidden, - cl::init(false), - cl::desc("Enable unsafe double to float " - "shrinking for math lib calls")); - // Initialization Routines void llvm::initializeInstCombine(PassRegistry &Registry) { initializeInstCombinerPass(Registry); @@ -85,12 +84,14 @@ void LLVMInitializeInstCombine(LLVMPassRegistryRef R) { char InstCombiner::ID = 0; INITIALIZE_PASS_BEGIN(InstCombiner, "instcombine", "Combine redundant instructions", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_END(InstCombiner, "instcombine", "Combine redundant instructions", false, false) void InstCombiner::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesCFG(); + AU.addRequired<AssumptionTracker>(); AU.addRequired<TargetLibraryInfo>(); } @@ -390,6 +391,25 @@ static bool RightDistributesOverLeft(Instruction::BinaryOps LOp, 
Instruction::BinaryOps ROp) { if (Instruction::isCommutative(ROp)) return LeftDistributesOverRight(ROp, LOp); + + switch (LOp) { + default: + return false; + // (X >> Z) & (Y >> Z) -> (X&Y) >> Z for all shifts. + // (X >> Z) | (Y >> Z) -> (X|Y) >> Z for all shifts. + // (X >> Z) ^ (Y >> Z) -> (X^Y) >> Z for all shifts. + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + switch (ROp) { + default: + return false; + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + return true; + } + } // TODO: It would be nice to handle division, aka "(X + Y)/Z = X/Z + Y/Z", // but this requires knowing that the addition does not overflow and other // such subtleties. @@ -411,26 +431,37 @@ static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) { } /// This function factors binary ops which can be combined using distributive -/// laws. This also factor SHL as MUL e.g. SHL(X, 2) ==> MUL(X, 4). +/// laws. This function tries to transform 'Op' based on TopLevelOpcode to +/// enable factorization, e.g. for ADD(SHL(X, 2), MUL(X, 5)): when called with +/// TopLevelOpcode == Instruction::Add and Op == SHL(X, 2), it transforms +/// SHL(X, 2) into MUL(X, 4), i.e. returns Instruction::Mul with LHS set to +/// 'X' and RHS set to 4. static Instruction::BinaryOps -getBinOpsForFactorization(BinaryOperator *Op, Value *&LHS, Value *&RHS) { +getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode, + BinaryOperator *Op, Value *&LHS, Value *&RHS) { if (!Op) return Instruction::BinaryOpsEnd; - if (Op->getOpcode() == Instruction::Shl) { - if (Constant *CST = dyn_cast<Constant>(Op->getOperand(1))) { - // The multiplier is really 1 << CST. - RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), CST); - LHS = Op->getOperand(0); - return Instruction::Mul; + LHS = Op->getOperand(0); + RHS = Op->getOperand(1); + + switch (TopLevelOpcode) { + default: + return Op->getOpcode(); + + case Instruction::Add: + case Instruction::Sub: + if (Op->getOpcode() == Instruction::Shl) { + if (Constant *CST = dyn_cast<Constant>(Op->getOperand(1))) { + // The multiplier is really 1 << CST. + RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), CST); + return Instruction::Mul; + } } + return Op->getOpcode(); } // TODO: We can add other conversions e.g. shr => div etc. - - LHS = Op->getOperand(0); - RHS = Op->getOperand(1); - return Op->getOpcode(); } /// This tries to simplify binary operations by factorizing out common terms @@ -529,8 +560,9 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { // Factorization. Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; - Instruction::BinaryOps LHSOpcode = getBinOpsForFactorization(Op0, A, B); - Instruction::BinaryOps RHSOpcode = getBinOpsForFactorization(Op1, C, D); + auto TopLevelOpcode = I.getOpcode(); + auto LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B); + auto RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D); // The instruction has the form "(A op' B) op (C op' D)". Try to factorize // a common term. @@ -552,7 +584,6 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { return V; // Expansion. - Instruction::BinaryOps TopLevelOpcode = I.getOpcode(); if (Op0 && RightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) { // The instruction has the form "(A op' B) op C". See if expanding it out // to "(A op C) op' (B op C)" results in simplifications.
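One practical consequence of threading TopLevelOpcode through getBinOpsForFactorization: shl-by-constant is now rewritten as a multiply only under an add or sub, where the factorization pays off. A hedged source-level illustration:

    // With TopLevelOpcode == Add, SHL(x, 2) is treated as MUL(x, 4), so both
    // operands share the common factor x and fold to a single multiply.
    unsigned factorExample(unsigned x) {
      return x * 5 + (x << 2); // factorized: x * (5 + 4) == x * 9
    }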
@@ -765,13 +796,14 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { // If the incoming non-constant value is in I's block, we will remove one // instruction, but insert another equivalent one, leading to infinite // instcombine. - if (NonConstBB == I.getParent()) + if (isPotentiallyReachable(I.getParent(), NonConstBB, DT, + getAnalysisIfAvailable<LoopInfo>())) return nullptr; } // If there is exactly one non-constant value, we can insert a copy of the // operation in that block. However, if this is a critical edge, we would be - // inserting the computation one some other paths (e.g. inside a loop). Only + // inserting the computation on some other paths (e.g. inside a loop). Only // do this if the pred block is unconditionally branching into the phi block. if (NonConstBB != nullptr) { BranchInst *BI = dyn_cast<BranchInst>(NonConstBB->getTerminator()); @@ -1284,7 +1316,7 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); - if (Value *V = SimplifyGEPInst(Ops, DL)) + if (Value *V = SimplifyGEPInst(Ops, DL, TLI, DT, AT)) return ReplaceInstUsesWith(GEP, V); Value *PtrOp = GEP.getOperand(0); @@ -1478,19 +1510,50 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { GetElementPtrInst::Create(Src->getOperand(0), Indices, GEP.getName()); } - // Canonicalize (gep i8* X, -(ptrtoint Y)) to (sub (ptrtoint X), (ptrtoint Y)) - // The GEP pattern is emitted by the SCEV expander for certain kinds of - // pointer arithmetic. - if (DL && GEP.getNumIndices() == 1 && - match(GEP.getOperand(1), m_Neg(m_PtrToInt(m_Value())))) { + if (DL && GEP.getNumIndices() == 1) { unsigned AS = GEP.getPointerAddressSpace(); - if (GEP.getType() == Builder->getInt8PtrTy(AS) && - GEP.getOperand(1)->getType()->getScalarSizeInBits() == + if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == DL->getPointerSizeInBits(AS)) { - Operator *Index = cast<Operator>(GEP.getOperand(1)); - Value *PtrToInt = Builder->CreatePtrToInt(PtrOp, Index->getType()); - Value *NewSub = Builder->CreateSub(PtrToInt, Index->getOperand(1)); - return CastInst::Create(Instruction::IntToPtr, NewSub, GEP.getType()); + Type *PtrTy = GEP.getPointerOperandType(); + Type *Ty = PtrTy->getPointerElementType(); + uint64_t TyAllocSize = DL->getTypeAllocSize(Ty); + + bool Matched = false; + uint64_t C; + Value *V = nullptr; + if (TyAllocSize == 1) { + V = GEP.getOperand(1); + Matched = true; + } else if (match(GEP.getOperand(1), + m_AShr(m_Value(V), m_ConstantInt(C)))) { + if (TyAllocSize == 1ULL << C) + Matched = true; + } else if (match(GEP.getOperand(1), + m_SDiv(m_Value(V), m_ConstantInt(C)))) { + if (TyAllocSize == C) + Matched = true; + } + + if (Matched) { + // Canonicalize (gep i8* X, -(ptrtoint Y)) + // to (inttoptr (sub (ptrtoint X), (ptrtoint Y))) + // The GEP pattern is emitted by the SCEV expander for certain kinds of + // pointer arithmetic. 
+ if (match(V, m_Neg(m_PtrToInt(m_Value())))) { + Operator *Index = cast<Operator>(V); + Value *PtrToInt = Builder->CreatePtrToInt(PtrOp, Index->getType()); + Value *NewSub = Builder->CreateSub(PtrToInt, Index->getOperand(1)); + return CastInst::Create(Instruction::IntToPtr, NewSub, GEP.getType()); + } + // Canonicalize (gep i8* X, (ptrtoint Y)-(ptrtoint X)) + // to (bitcast Y) + Value *Y; + if (match(V, m_Sub(m_PtrToInt(m_Value(Y)), + m_PtrToInt(m_Specific(GEP.getOperand(0)))))) { + return CastInst::CreatePointerBitCastOrAddrSpaceCast(Y, + GEP.getType()); + } + } } } @@ -1582,9 +1645,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { Builder->CreateGEP(StrippedPtr, Idx, GEP.getName()); // V and GEP are both pointer types --> BitCast - if (StrippedPtrTy->getAddressSpace() == GEP.getPointerAddressSpace()) - return new BitCastInst(NewGEP, GEP.getType()); - return new AddrSpaceCastInst(NewGEP, GEP.getType()); + return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, + GEP.getType()); } // Transform things like: @@ -1616,9 +1678,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { Builder->CreateGEP(StrippedPtr, NewIdx, GEP.getName()); // The NewGEP must be pointer typed, so must the old one -> BitCast - if (StrippedPtrTy->getAddressSpace() == GEP.getPointerAddressSpace()) - return new BitCastInst(NewGEP, GEP.getType()); - return new AddrSpaceCastInst(NewGEP, GEP.getType()); + return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, + GEP.getType()); } } } @@ -1658,9 +1719,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { Builder->CreateInBoundsGEP(StrippedPtr, Off, GEP.getName()) : Builder->CreateGEP(StrippedPtr, Off, GEP.getName()); // The NewGEP must be pointer typed, so must the old one -> BitCast - if (StrippedPtrTy->getAddressSpace() == GEP.getPointerAddressSpace()) - return new BitCastInst(NewGEP, GEP.getType()); - return new AddrSpaceCastInst(NewGEP, GEP.getType()); + return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, + GEP.getType()); } } } @@ -1670,6 +1730,18 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (!DL) return nullptr; + // addrspacecast between types is canonicalized as a bitcast, then an + // addrspacecast. To take advantage of the below bitcast + struct GEP, look + // through the addrspacecast. + if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(PtrOp)) { + // X = bitcast A addrspace(1)* to B addrspace(1)* + // Y = addrspacecast A addrspace(1)* to B addrspace(2)* + // Z = gep Y, <...constant indices...> + // Into an addrspacecasted GEP of the struct. 
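The second canonicalization above, gep X, (ptrtoint Y)-(ptrtoint X) becoming a cast of Y, corresponds to a common source idiom; a hedged C++ analogue (the function is illustrative):

    #include <cstdint>

    // Adding the byte distance between two pointers back onto the base
    // pointer reproduces the other pointer; this is what the new m_Sub
    // pattern recognizes at the IR level.
    char *gepDifferenceExample(char *X, char *Y) {
      return X + (reinterpret_cast<std::intptr_t>(Y) -
                  reinterpret_cast<std::intptr_t>(X));
      // canonicalized to: return Y;
    }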
+ if (BitCastInst *BC = dyn_cast<BitCastInst>(ASC->getOperand(0))) + PtrOp = BC; + } + /// See if we can simplify: /// X = bitcast A* to B* /// Y = gep X, <...constant indices...> @@ -1678,11 +1750,10 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (BitCastInst *BCI = dyn_cast<BitCastInst>(PtrOp)) { Value *Operand = BCI->getOperand(0); PointerType *OpType = cast<PointerType>(Operand->getType()); - unsigned OffsetBits = DL->getPointerTypeSizeInBits(OpType); + unsigned OffsetBits = DL->getPointerTypeSizeInBits(GEP.getType()); APInt Offset(OffsetBits, 0); if (!isa<BitCastInst>(Operand) && - GEP.accumulateConstantOffset(*DL, Offset) && - StrippedPtrTy->getAddressSpace() == GEP.getPointerAddressSpace()) { + GEP.accumulateConstantOffset(*DL, Offset)) { // If this GEP instruction doesn't move the pointer, just replace the GEP // with a bitcast of the real input to the dest type. @@ -1700,6 +1771,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { return &GEP; } } + + if (Operand->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) + return new AddrSpaceCastInst(Operand, GEP.getType()); return new BitCastInst(Operand, GEP.getType()); } @@ -1715,6 +1789,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (NGEP->getType() == GEP.getType()) return ReplaceInstUsesWith(GEP, NGEP); NGEP->takeName(&GEP); + + if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) + return new AddrSpaceCastInst(NGEP, GEP.getType()); return new BitCastInst(NGEP, GEP.getType()); } } @@ -1925,7 +2002,25 @@ Instruction *InstCombiner::visitFree(CallInst &FI) { return nullptr; } +Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) { + if (RI.getNumOperands() == 0) // ret void + return nullptr; + + Value *ResultOp = RI.getOperand(0); + Type *VTy = ResultOp->getType(); + if (!VTy->isIntegerTy()) + return nullptr; + + // There might be assume intrinsics dominating this return that completely + // determine the value. If so, constant fold it. + unsigned BitWidth = VTy->getPrimitiveSizeInBits(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + computeKnownBits(ResultOp, KnownZero, KnownOne, 0, &RI); + if ((KnownZero|KnownOne).isAllOnesValue()) + RI.setOperand(0, Constant::getIntegerValue(VTy, KnownOne)); + return nullptr; +} Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { // Change br (not X), label True, label False to: br X, label False, True @@ -1977,6 +2072,37 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { Value *Cond = SI.getCondition(); + unsigned BitWidth = cast<IntegerType>(Cond->getType())->getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + computeKnownBits(Cond, KnownZero, KnownOne); + unsigned LeadingKnownZeros = KnownZero.countLeadingOnes(); + unsigned LeadingKnownOnes = KnownOne.countLeadingOnes(); + + // Compute the number of leading bits we can ignore. + for (auto &C : SI.cases()) { + LeadingKnownZeros = std::min( + LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros()); + LeadingKnownOnes = std::min( + LeadingKnownOnes, C.getCaseValue()->getValue().countLeadingOnes()); + } + + unsigned NewWidth = BitWidth - std::max(LeadingKnownZeros, LeadingKnownOnes); + + // Truncate the condition operand if the new type is equal to or larger than + // the largest legal integer type. 
We need to be conservative here since + // x86 generates redundant zero-extension instructions if the operand is + // truncated to i8 or i16. + if (BitWidth > NewWidth && NewWidth >= DL->getLargestLegalIntTypeSize()) { + IntegerType *Ty = IntegerType::get(SI.getContext(), NewWidth); + Builder->SetInsertPoint(&SI); + Value *NewCond = Builder->CreateTrunc(SI.getCondition(), Ty, "trunc"); + SI.setCondition(NewCond); + + for (auto &C : SI.cases()) + static_cast<SwitchInst::CaseIt *>(&C)->setValue(ConstantInt::get( + SI.getContext(), C.getCaseValue()->getValue().trunc(NewWidth))); + } + if (Instruction *I = dyn_cast<Instruction>(Cond)) { if (I->getOpcode() == Instruction::Add) if (ConstantInt *AddRHS = dyn_cast<ConstantInt>(I->getOperand(1))) { @@ -2215,7 +2341,7 @@ Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) { // If we already saw this clause, there is no point in having a second // copy of it. - if (AlreadyCaught.insert(TypeInfo)) { + if (AlreadyCaught.insert(TypeInfo).second) { // This catch clause was not already seen. NewClauses.push_back(CatchClause); } else { @@ -2297,7 +2423,7 @@ Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) { continue; // There is no point in having multiple copies of the same typeinfo in // a filter, so only add it if we didn't already. - if (SeenInFilter.insert(TypeInfo)) + if (SeenInFilter.insert(TypeInfo).second) NewFilterElts.push_back(cast<Constant>(Elt)); } // A filter containing a catch-all cannot match anything by definition. @@ -2534,7 +2660,7 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { /// whose condition is a known constant, we only visit the reachable successors. /// static bool AddReachableCodeToWorklist(BasicBlock *BB, - SmallPtrSet<BasicBlock*, 64> &Visited, + SmallPtrSetImpl<BasicBlock*> &Visited, InstCombiner &IC, const DataLayout *DL, const TargetLibraryInfo *TLI) { @@ -2549,7 +2675,8 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, BB = Worklist.pop_back_val(); // We have now visited this block! If we've already been here, ignore it. - if (!Visited.insert(BB)) continue; + if (!Visited.insert(BB).second) + continue; for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E; ) { Instruction *Inst = BBI++; @@ -2730,9 +2857,18 @@ bool InstCombiner::DoOneIteration(Function &F, unsigned Iteration) { // If the user is one of our immediate successors, and if that successor // only has us as a predecessor (we'd have to split the critical edge // otherwise), we can keep going. - if (UserIsSuccessor && UserParent->getSinglePredecessor()) + if (UserIsSuccessor && UserParent->getSinglePredecessor()) { // Okay, the CFG is simple enough, try to sink this instruction.
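Returning to the switch-narrowing logic earlier in this hunk: the width computation is plain arithmetic and can be checked in isolation. A hedged sketch with an i128 condition whose top half is known zero; the 64-bit largest-legal-integer value is an x86-64 assumption:

    #include <algorithm>
    #include <cassert>

    int main() {
      unsigned BitWidth = 128;         // i128 switch condition
      unsigned LeadingKnownZeros = 64; // high half known zero per computeKnownBits
      unsigned LeadingKnownOnes = 0;
      unsigned NewWidth = BitWidth - std::max(LeadingKnownZeros, LeadingKnownOnes);
      // With a 64-bit largest legal integer type, truncating the condition
      // from i128 to i64 passes the guard; narrowing to i8 or i16 would not,
      // which is exactly the conservatism the comment describes.
      assert(BitWidth > NewWidth && NewWidth >= 64);
      return 0;
    }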
- MadeIRChange |= TryToSinkInstruction(I, UserParent); + if (TryToSinkInstruction(I, UserParent)) { + MadeIRChange = true; + // We'll add uses of the sunk instruction below, but since sinking + // can expose opportunities for its *operands*, add them to the + // worklist. + for (Use &U : I->operands()) + if (Instruction *OpI = dyn_cast<Instruction>(U.get())) + Worklist.Add(OpI); + } + } } } @@ -2801,13 +2937,13 @@ bool InstCombiner::DoOneIteration(Function &F, unsigned Iteration) { } namespace { -class InstCombinerLibCallSimplifier : public LibCallSimplifier { +class InstCombinerLibCallSimplifier final : public LibCallSimplifier { InstCombiner *IC; public: InstCombinerLibCallSimplifier(const DataLayout *DL, const TargetLibraryInfo *TLI, InstCombiner *IC) - : LibCallSimplifier(DL, TLI, UnsafeFPShrink) { + : LibCallSimplifier(DL, TLI) { this->IC = IC; } @@ -2823,9 +2959,15 @@ bool InstCombiner::runOnFunction(Function &F) { if (skipOptnoneFunction(F)) return false; + AT = &getAnalysis<AssumptionTracker>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; TLI = &getAnalysis<TargetLibraryInfo>(); + + DominatorTreeWrapperPass *DTWP = + getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + DT = DTWP ? &DTWP->getDomTree() : nullptr; + // Minimizing size? MinimizeSize = F.getAttributes().hasAttribute(AttributeSet::FunctionIndex, Attribute::MinSize); @@ -2834,7 +2976,7 @@ bool InstCombiner::runOnFunction(Function &F) { /// instructions into the worklist when they are created. IRBuilder<true, TargetFolder, InstCombineIRInserter> TheBuilder(F.getContext(), TargetFolder(DL), - InstCombineIRInserter(Worklist)); + InstCombineIRInserter(Worklist, AT)); Builder = &TheBuilder; InstCombinerLibCallSimplifier TheSimplifier(DL, TLI, this);
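Finally, the new visitReturnInst hook is where the AssumptionTracker plumbing pays off: a dominating assume can pin every bit of the returned value. A hedged Clang-level illustration (__builtin_assume lowers to llvm.assume):

    // computeKnownBits, queried at the return with the assume in scope,
    // reports every bit of x as known, so instcombine rewrites the return
    // operand to the constant.
    int returnExample(int x) {
      __builtin_assume(x == 42);
      return x; // folded to: return 42
    }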