diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineAndOrXor.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 216 |
1 files changed, 180 insertions, 36 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 55ebced..863eeaf 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -11,7 +11,7 @@ // //===----------------------------------------------------------------------===// -#include "InstCombine.h" +#include "InstCombineInternal.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Intrinsics.h" @@ -22,30 +22,12 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -/// isFreeToInvert - Return true if the specified value is free to invert (apply -/// ~ to). This happens in cases where the ~ can be eliminated. -static inline bool isFreeToInvert(Value *V) { - // ~(~(X)) -> X. - if (BinaryOperator::isNot(V)) - return true; - - // Constants can be considered to be not'ed values. - if (isa<ConstantInt>(V)) - return true; - - // Compares can be inverted if they have a single use. - if (CmpInst *CI = dyn_cast<CmpInst>(V)) - return CI->hasOneUse(); - - return false; -} - static inline Value *dyn_castNotVal(Value *V) { // If this is not(not(x)) don't return that this is a not: we want the two // not's to be folded first. if (BinaryOperator::isNot(V)) { Value *Operand = BinaryOperator::getNotArgument(V); - if (!isFreeToInvert(Operand)) + if (!IsFreeToInvert(Operand, Operand->hasOneUse())) return Operand; } @@ -117,6 +99,61 @@ static Value *getFCmpValue(bool isordered, unsigned code, return Builder->CreateFCmp(Pred, LHS, RHS); } +/// \brief Transform BITWISE_OP(BSWAP(A),BSWAP(B)) to BSWAP(BITWISE_OP(A, B)) +/// \param I Binary operator to transform. +/// \return Pointer to node that must replace the original binary operator, or +/// null pointer if no transformation was made. +Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) { + IntegerType *ITy = dyn_cast<IntegerType>(I.getType()); + + // Can't do vectors. + if (I.getType()->isVectorTy()) return nullptr; + + // Can only do bitwise ops. + unsigned Op = I.getOpcode(); + if (Op != Instruction::And && Op != Instruction::Or && + Op != Instruction::Xor) + return nullptr; + + Value *OldLHS = I.getOperand(0); + Value *OldRHS = I.getOperand(1); + ConstantInt *ConstLHS = dyn_cast<ConstantInt>(OldLHS); + ConstantInt *ConstRHS = dyn_cast<ConstantInt>(OldRHS); + IntrinsicInst *IntrLHS = dyn_cast<IntrinsicInst>(OldLHS); + IntrinsicInst *IntrRHS = dyn_cast<IntrinsicInst>(OldRHS); + bool IsBswapLHS = (IntrLHS && IntrLHS->getIntrinsicID() == Intrinsic::bswap); + bool IsBswapRHS = (IntrRHS && IntrRHS->getIntrinsicID() == Intrinsic::bswap); + + if (!IsBswapLHS && !IsBswapRHS) + return nullptr; + + if (!IsBswapLHS && !ConstLHS) + return nullptr; + + if (!IsBswapRHS && !ConstRHS) + return nullptr; + + /// OP( BSWAP(x), BSWAP(y) ) -> BSWAP( OP(x, y) ) + /// OP( BSWAP(x), CONSTANT ) -> BSWAP( OP(x, BSWAP(CONSTANT) ) ) + Value *NewLHS = IsBswapLHS ? IntrLHS->getOperand(0) : + Builder->getInt(ConstLHS->getValue().byteSwap()); + + Value *NewRHS = IsBswapRHS ? IntrRHS->getOperand(0) : + Builder->getInt(ConstRHS->getValue().byteSwap()); + + Value *BinOp = nullptr; + if (Op == Instruction::And) + BinOp = Builder->CreateAnd(NewLHS, NewRHS); + else if (Op == Instruction::Or) + BinOp = Builder->CreateOr(NewLHS, NewRHS); + else //if (Op == Instruction::Xor) + BinOp = Builder->CreateXor(NewLHS, NewRHS); + + Module *M = I.getParent()->getParent()->getParent(); + Function *F = Intrinsic::getDeclaration(M, Intrinsic::bswap, ITy); + return Builder->CreateCall(F, BinOp); +} + // OptAndOp - This handles expressions of the form ((val OP C1) & C2). Where // the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. Op is // guaranteed to be a binary operator. @@ -785,6 +822,62 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, return nullptr; } +/// Try to fold a signed range checked with lower bound 0 to an unsigned icmp. +/// Example: (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n +/// If \p Inverted is true then the check is for the inverted range, e.g. +/// (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n +Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, + bool Inverted) { + // Check the lower range comparison, e.g. x >= 0 + // InstCombine already ensured that if there is a constant it's on the RHS. + ConstantInt *RangeStart = dyn_cast<ConstantInt>(Cmp0->getOperand(1)); + if (!RangeStart) + return nullptr; + + ICmpInst::Predicate Pred0 = (Inverted ? Cmp0->getInversePredicate() : + Cmp0->getPredicate()); + + // Accept x > -1 or x >= 0 (after potentially inverting the predicate). + if (!((Pred0 == ICmpInst::ICMP_SGT && RangeStart->isMinusOne()) || + (Pred0 == ICmpInst::ICMP_SGE && RangeStart->isZero()))) + return nullptr; + + ICmpInst::Predicate Pred1 = (Inverted ? Cmp1->getInversePredicate() : + Cmp1->getPredicate()); + + Value *Input = Cmp0->getOperand(0); + Value *RangeEnd; + if (Cmp1->getOperand(0) == Input) { + // For the upper range compare we have: icmp x, n + RangeEnd = Cmp1->getOperand(1); + } else if (Cmp1->getOperand(1) == Input) { + // For the upper range compare we have: icmp n, x + RangeEnd = Cmp1->getOperand(0); + Pred1 = ICmpInst::getSwappedPredicate(Pred1); + } else { + return nullptr; + } + + // Check the upper range comparison, e.g. x < n + ICmpInst::Predicate NewPred; + switch (Pred1) { + case ICmpInst::ICMP_SLT: NewPred = ICmpInst::ICMP_ULT; break; + case ICmpInst::ICMP_SLE: NewPred = ICmpInst::ICMP_ULE; break; + default: return nullptr; + } + + // This simplification is only valid if the upper range is not negative. + bool IsNegative, IsNotNegative; + ComputeSignBit(RangeEnd, IsNotNegative, IsNegative, /*Depth=*/0, Cmp1); + if (!IsNotNegative) + return nullptr; + + if (Inverted) + NewPred = ICmpInst::getInversePredicate(NewPred); + + return Builder->CreateICmp(NewPred, Input, RangeEnd); +} + /// FoldAndOfICmps - Fold (icmp)&(icmp) if possible. Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); @@ -807,6 +900,14 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, true, Builder)) return V; + // E.g. (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n + if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/false)) + return V; + + // E.g. (icmp slt x, n) & (icmp sge x, 0) --> icmp ult x, n + if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/false)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0); ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1)); @@ -1108,7 +1209,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyAndInst(Op0, Op1, DL, TLI, DT, AT)) + if (Value *V = SimplifyAndInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // (A|B)&(A|C) -> A|(B&C) etc @@ -1120,6 +1221,9 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (SimplifyDemandedInstructionBits(I)) return &I; + if (Value *V = SimplifyBSwap(I)) + return ReplaceInstUsesWith(I, V); + if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) { const APInt &AndRHSMask = AndRHS->getValue(); @@ -1605,15 +1709,15 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Value *Mask = nullptr; Value *Masked = nullptr; if (LAnd->getOperand(0) == RAnd->getOperand(0) && - isKnownToBeAPowerOfTwo(LAnd->getOperand(1), false, 0, AT, CxtI, DT) && - isKnownToBeAPowerOfTwo(RAnd->getOperand(1), false, 0, AT, CxtI, DT)) { + isKnownToBeAPowerOfTwo(LAnd->getOperand(1), false, 0, AC, CxtI, DT) && + isKnownToBeAPowerOfTwo(RAnd->getOperand(1), false, 0, AC, CxtI, DT)) { Mask = Builder->CreateOr(LAnd->getOperand(1), RAnd->getOperand(1)); Masked = Builder->CreateAnd(LAnd->getOperand(0), Mask); } else if (LAnd->getOperand(1) == RAnd->getOperand(1) && - isKnownToBeAPowerOfTwo(LAnd->getOperand(0), - false, 0, AT, CxtI, DT) && - isKnownToBeAPowerOfTwo(RAnd->getOperand(0), - false, 0, AT, CxtI, DT)) { + isKnownToBeAPowerOfTwo(LAnd->getOperand(0), false, 0, AC, CxtI, + DT) && + isKnownToBeAPowerOfTwo(RAnd->getOperand(0), false, 0, AC, CxtI, + DT)) { Mask = Builder->CreateOr(LAnd->getOperand(0), RAnd->getOperand(0)); Masked = Builder->CreateAnd(LAnd->getOperand(1), Mask); } @@ -1724,6 +1828,14 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Builder->CreateAdd(B, ConstantInt::getSigned(B->getType(), -1)), A); } + // E.g. (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n + if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/true)) + return V; + + // E.g. (icmp sgt x, n) | (icmp slt x, 0) --> icmp ugt x, n + if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/true)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). if (!LHSCst || !RHSCst) return nullptr; @@ -2033,7 +2145,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyOrInst(Op0, Op1, DL, TLI, DT, AT)) + if (Value *V = SimplifyOrInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // (A&B)|(A&C) -> A&(B|C) etc @@ -2045,6 +2157,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (SimplifyDemandedInstructionBits(I)) return &I; + if (Value *V = SimplifyBSwap(I)) + return ReplaceInstUsesWith(I, V); + if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { ConstantInt *C1 = nullptr; Value *X = nullptr; // (X & C1) | C2 --> (X | C2) & (C1|C2) @@ -2305,11 +2420,34 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (SwappedForXor) std::swap(Op0, Op1); - if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) - if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) + { + ICmpInst *LHS = dyn_cast<ICmpInst>(Op0); + ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); + if (LHS && RHS) if (Value *Res = FoldOrOfICmps(LHS, RHS, &I)) return ReplaceInstUsesWith(I, Res); + // TODO: Make this recursive; it's a little tricky because an arbitrary + // number of 'or' instructions might have to be created. + Value *X, *Y; + if (LHS && match(Op1, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { + if (auto *Cmp = dyn_cast<ICmpInst>(X)) + if (Value *Res = FoldOrOfICmps(LHS, Cmp, &I)) + return ReplaceInstUsesWith(I, Builder->CreateOr(Res, Y)); + if (auto *Cmp = dyn_cast<ICmpInst>(Y)) + if (Value *Res = FoldOrOfICmps(LHS, Cmp, &I)) + return ReplaceInstUsesWith(I, Builder->CreateOr(Res, X)); + } + if (RHS && match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { + if (auto *Cmp = dyn_cast<ICmpInst>(X)) + if (Value *Res = FoldOrOfICmps(Cmp, RHS, &I)) + return ReplaceInstUsesWith(I, Builder->CreateOr(Res, Y)); + if (auto *Cmp = dyn_cast<ICmpInst>(Y)) + if (Value *Res = FoldOrOfICmps(Cmp, RHS, &I)) + return ReplaceInstUsesWith(I, Builder->CreateOr(Res, X)); + } + } + // (fcmp uno x, c) | (fcmp uno y, c) -> (fcmp uno x, y) if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) @@ -2394,7 +2532,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyXorInst(Op0, Op1, DL, TLI, DT, AT)) + if (Value *V = SimplifyXorInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // (A&B)^(A&C) -> A&(B^C) etc @@ -2406,6 +2544,9 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (SimplifyDemandedInstructionBits(I)) return &I; + if (Value *V = SimplifyBSwap(I)) + return ReplaceInstUsesWith(I, V); + // Is this a ~ operation? if (Value *NotOp = dyn_castNotVal(&I)) { if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(NotOp)) { @@ -2426,8 +2567,10 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // ~(X & Y) --> (~X | ~Y) - De Morgan's Law // ~(X | Y) === (~X & ~Y) - De Morgan's Law - if (isFreeToInvert(Op0I->getOperand(0)) && - isFreeToInvert(Op0I->getOperand(1))) { + if (IsFreeToInvert(Op0I->getOperand(0), + Op0I->getOperand(0)->hasOneUse()) && + IsFreeToInvert(Op0I->getOperand(1), + Op0I->getOperand(1)->hasOneUse())) { Value *NotX = Builder->CreateNot(Op0I->getOperand(0), "notlhs"); Value *NotY = @@ -2445,15 +2588,16 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } - - if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { - if (RHS->isOne() && Op0->hasOneUse()) + if (Constant *RHS = dyn_cast<Constant>(Op1)) { + if (RHS->isAllOnesValue() && Op0->hasOneUse()) // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B if (CmpInst *CI = dyn_cast<CmpInst>(Op0)) return CmpInst::Create(CI->getOpcode(), CI->getInversePredicate(), CI->getOperand(0), CI->getOperand(1)); + } + if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { // fold (xor(zext(cmp)), 1) and (xor(sext(cmp)), -1) to ext(!cmp). if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) { if (CmpInst *CI = dyn_cast<CmpInst>(Op0C->getOperand(0))) { |