aboutsummaryrefslogtreecommitdiffstats
path: root/lib/Transforms/InstCombine/InstCombineShifts.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineShifts.cpp91
1 files changed, 55 insertions, 36 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 8273dfd..cc6665c 100644
--- a/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -19,6 +19,8 @@
using namespace llvm;
using namespace PatternMatch;
+#define DEBUG_TYPE "instcombine"
+
Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
assert(I.getOperand(1)->getType() == I.getOperand(0)->getType());
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -33,7 +35,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
if (Instruction *R = FoldOpIntoSelect(I, SI))
return R;
- if (ConstantInt *CUI = dyn_cast<ConstantInt>(Op1))
+ if (Constant *CUI = dyn_cast<Constant>(Op1))
if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
return Res;
@@ -50,7 +52,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
return &I;
}
- return 0;
+ return nullptr;
}
/// CanEvaluateShifted - See if we can compute the specified value, but shifted
@@ -78,7 +80,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
// if the needed bits are already zero in the input. This allows us to reuse
// the value which means that we don't care if the shift has multiple uses.
// TODO: Handle opposite shift by exact value.
- ConstantInt *CI = 0;
+ ConstantInt *CI = nullptr;
if ((isLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) ||
(!isLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) {
if (CI->getZExtValue() == NumBits) {
@@ -115,7 +117,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
case Instruction::Shl: {
// We can often fold the shift into shifts-by-a-constant.
CI = dyn_cast<ConstantInt>(I->getOperand(1));
- if (CI == 0) return false;
+ if (!CI) return false;
// We can always fold shl(c1)+shl(c2) -> shl(c1+c2).
if (isLeftShift) return true;
@@ -139,7 +141,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
case Instruction::LShr: {
// We can often fold the shift into shifts-by-a-constant.
CI = dyn_cast<ConstantInt>(I->getOperand(1));
- if (CI == 0) return false;
+ if (!CI) return false;
// We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2).
if (!isLeftShift) return true;
@@ -309,37 +311,38 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
-Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
+Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
BinaryOperator &I) {
bool isLeftShift = I.getOpcode() == Instruction::Shl;
+ ConstantInt *COp1 = nullptr;
+ if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(Op1))
+ COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
+ else if (ConstantVector *CV = dyn_cast<ConstantVector>(Op1))
+ COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
+ else
+ COp1 = dyn_cast<ConstantInt>(Op1);
+
+ if (!COp1)
+ return nullptr;
// See if we can propagate this shift into the input, this covers the trivial
// cast of lshr(shl(x,c1),c2) as well as other more complex cases.
if (I.getOpcode() != Instruction::AShr &&
- CanEvaluateShifted(Op0, Op1->getZExtValue(), isLeftShift, *this)) {
+ CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this)) {
DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression"
" to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n");
return ReplaceInstUsesWith(I,
- GetShiftedValue(Op0, Op1->getZExtValue(), isLeftShift, *this));
+ GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this));
}
-
// See if we can simplify any instructions used by the instruction whose sole
// purpose is to compute bits we don't care about.
uint32_t TypeBits = Op0->getType()->getScalarSizeInBits();
- // shl i32 X, 32 = 0 and srl i8 Y, 9 = 0, ... just don't eliminate
- // a signed shift.
- //
- if (Op1->uge(TypeBits)) {
- if (I.getOpcode() != Instruction::AShr)
- return ReplaceInstUsesWith(I, Constant::getNullValue(Op0->getType()));
- // ashr i32 X, 32 --> ashr i32 X, 31
- I.setOperand(1, ConstantInt::get(I.getType(), TypeBits-1));
- return &I;
- }
+ assert(!COp1->uge(TypeBits) &&
+ "Shift over the type width should have been removed already");
// ((X*C1) << C2) == (X * (C1 << C2))
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0))
@@ -367,7 +370,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
if (TrOp && I.isLogicalShift() && TrOp->isShift() &&
isa<ConstantInt>(TrOp->getOperand(1))) {
// Okay, we'll do this xform. Make the shift of shift.
- Constant *ShAmt = ConstantExpr::getZExt(Op1, TrOp->getType());
+ Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType());
// (shift2 (shift1 & 0x00FF), c2)
Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName());
@@ -384,10 +387,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
// shift. We know that it is a logical shift by a constant, so adjust the
// mask as appropriate.
if (I.getOpcode() == Instruction::Shl)
- MaskV <<= Op1->getZExtValue();
+ MaskV <<= COp1->getZExtValue();
else {
assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift");
- MaskV = MaskV.lshr(Op1->getZExtValue());
+ MaskV = MaskV.lshr(COp1->getZExtValue());
}
// shift1 & 0x00FF
@@ -421,9 +424,13 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
// (X + (Y << C))
Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1,
Op0BO->getOperand(1)->getName());
- uint32_t Op1Val = Op1->getLimitedValue(TypeBits);
- return BinaryOperator::CreateAnd(X, ConstantInt::get(I.getContext(),
- APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val)));
+ uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
+
+ APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
+ Constant *Mask = ConstantInt::get(I.getContext(), Bits);
+ if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
+ Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
+ return BinaryOperator::CreateAnd(X, Mask);
}
// Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C))
@@ -453,9 +460,13 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
// (X + (Y << C))
Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS,
Op0BO->getOperand(0)->getName());
- uint32_t Op1Val = Op1->getLimitedValue(TypeBits);
- return BinaryOperator::CreateAnd(X, ConstantInt::get(I.getContext(),
- APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val)));
+ uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
+
+ APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
+ Constant *Mask = ConstantInt::get(I.getContext(), Bits);
+ if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
+ Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
+ return BinaryOperator::CreateAnd(X, Mask);
}
// Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C)
@@ -523,7 +534,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
// Find out if this is a shift of a shift by a constant.
BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0);
if (ShiftOp && !ShiftOp->isShift())
- ShiftOp = 0;
+ ShiftOp = nullptr;
if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) {
@@ -541,9 +552,9 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1));
uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits);
- uint32_t ShiftAmt2 = Op1->getLimitedValue(TypeBits);
+ uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits);
assert(ShiftAmt2 != 0 && "Should have been simplified earlier");
- if (ShiftAmt1 == 0) return 0; // Will be simplified in the future.
+ if (ShiftAmt1 == 0) return nullptr; // Will be simplified in the future.
Value *X = ShiftOp->getOperand(0);
IntegerType *Ty = cast<IntegerType>(I.getType());
@@ -671,10 +682,13 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
}
}
}
- return 0;
+ return nullptr;
}
Instruction *InstCombiner::visitShl(BinaryOperator &I) {
+ if (Value *V = SimplifyVectorOp(I))
+ return ReplaceInstUsesWith(I, V);
+
if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1),
I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
DL))
@@ -709,10 +723,13 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
match(I.getOperand(1), m_Constant(C2)))
return BinaryOperator::CreateShl(ConstantExpr::getShl(C1, C2), A);
- return 0;
+ return nullptr;
}
Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
+ if (Value *V = SimplifyVectorOp(I))
+ return ReplaceInstUsesWith(I, V);
+
if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1),
I.isExact(), DL))
return ReplaceInstUsesWith(I, V);
@@ -749,10 +766,13 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
}
}
- return 0;
+ return nullptr;
}
Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
+ if (Value *V = SimplifyVectorOp(I))
+ return ReplaceInstUsesWith(I, V);
+
if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1),
I.isExact(), DL))
return ReplaceInstUsesWith(I, V);
@@ -805,6 +825,5 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
if (NumSignBits == Op0->getType()->getScalarSizeInBits())
return ReplaceInstUsesWith(I, Op0);
- return 0;
+ return nullptr;
}
-