aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lib/Transforms/InstCombine/InstCombineSelect.cpp105
-rw-r--r--test/Transforms/InstCombine/select.ll82
2 files changed, 152 insertions, 35 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 03a030d..c2caedf 100644
--- a/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -285,48 +285,83 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
// place here, so make sure the select is the only user.
if (ICI->hasOneUse())
if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) {
+ // X < MIN ? T : F --> F
+ if ((Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT)
+ && CI->isMinValue(Pred == ICmpInst::ICMP_SLT))
+ return ReplaceInstUsesWith(SI, FalseVal);
+ // X > MAX ? T : F --> F
+ else if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT)
+ && CI->isMaxValue(Pred == ICmpInst::ICMP_SGT))
+ return ReplaceInstUsesWith(SI, FalseVal);
switch (Pred) {
default: break;
case ICmpInst::ICMP_ULT:
- case ICmpInst::ICMP_SLT: {
- // X < MIN ? T : F --> F
- if (CI->isMinValue(Pred == ICmpInst::ICMP_SLT))
- return ReplaceInstUsesWith(SI, FalseVal);
- // X < C ? X : C-1 --> X > C-1 ? C-1 : X
- Constant *AdjustedRHS =
- ConstantInt::get(CI->getContext(), CI->getValue()-1);
- if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) ||
- (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) {
- Pred = ICmpInst::getSwappedPredicate(Pred);
- CmpRHS = AdjustedRHS;
- std::swap(FalseVal, TrueVal);
- ICI->setPredicate(Pred);
- ICI->setOperand(1, CmpRHS);
- SI.setOperand(1, TrueVal);
- SI.setOperand(2, FalseVal);
- Changed = true;
- }
- break;
- }
+ case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_SGT: {
- // X > MAX ? T : F --> F
- if (CI->isMaxValue(Pred == ICmpInst::ICMP_SGT))
- return ReplaceInstUsesWith(SI, FalseVal);
+ Constant *AdjustedRHS;
+ if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT)
+ AdjustedRHS = ConstantInt::get(CI->getContext(), CI->getValue() + 1);
+ else // (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT)
+ AdjustedRHS = ConstantInt::get(CI->getContext(), CI->getValue() - 1);
+
// X > C ? X : C+1 --> X < C+1 ? C+1 : X
- Constant *AdjustedRHS =
- ConstantInt::get(CI->getContext(), CI->getValue()+1);
+ // X < C ? X : C-1 --> X > C-1 ? C-1 : X
if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) ||
- (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) {
- Pred = ICmpInst::getSwappedPredicate(Pred);
- CmpRHS = AdjustedRHS;
- std::swap(FalseVal, TrueVal);
- ICI->setPredicate(Pred);
- ICI->setOperand(1, CmpRHS);
- SI.setOperand(1, TrueVal);
- SI.setOperand(2, FalseVal);
- Changed = true;
- }
+ (CmpLHS == FalseVal && AdjustedRHS == TrueVal))
+ ; // Nothing to do here. Values match without any sign/zero extension.
+
+ // Types do not match. Instead of calculating this with mixed types
+ // promote all to the larger type. This enables scalar evolution to
+ // analyze this expression.
+ else if (CmpRHS->getType()->getScalarSizeInBits()
+ < SI.getType()->getScalarSizeInBits()) {
+ Constant *sextRHS = ConstantExpr::getSExt(AdjustedRHS,
+ SI.getType());
+
+ // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X
+ // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X
+ // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X
+ // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X
+ if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) &&
+ sextRHS == FalseVal) {
+ CmpLHS = TrueVal;
+ AdjustedRHS = sextRHS;
+ } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) &&
+ sextRHS == TrueVal) {
+ CmpLHS = FalseVal;
+ AdjustedRHS = sextRHS;
+ } else if (ICI->isUnsigned()) {
+ Constant *zextRHS = ConstantExpr::getZExt(AdjustedRHS,
+ SI.getType());
+ // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X
+ // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X
+ // zext + signed compare cannot be changed:
+ // 0xff <s 0x00, but 0x00ff >s 0x0000
+ if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) &&
+ zextRHS == FalseVal) {
+ CmpLHS = TrueVal;
+ AdjustedRHS = zextRHS;
+ } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) &&
+ zextRHS == TrueVal) {
+ CmpLHS = FalseVal;
+ AdjustedRHS = zextRHS;
+ } else
+ break;
+ } else
+ break;
+ } else
+ break;
+
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ CmpRHS = AdjustedRHS;
+ std::swap(FalseVal, TrueVal);
+ ICI->setPredicate(Pred);
+ ICI->setOperand(0, CmpLHS);
+ ICI->setOperand(1, CmpRHS);
+ SI.setOperand(1, TrueVal);
+ SI.setOperand(2, FalseVal);
+ Changed = true;
break;
}
}
diff --git a/test/Transforms/InstCombine/select.ll b/test/Transforms/InstCombine/select.ll
index fd3937d..c9b880d 100644
--- a/test/Transforms/InstCombine/select.ll
+++ b/test/Transforms/InstCombine/select.ll
@@ -596,3 +596,85 @@ define i32 @test42(i32 %x, i32 %y) {
; CHECK-NEXT: %c = add i32 %b, %y
; CHECK-NEXT: ret i32 %c
}
+
+define i64 @test43(i32 %a) nounwind {
+ %a_ext = sext i32 %a to i64
+ %is_a_nonnegative = icmp sgt i32 %a, -1
+ %max = select i1 %is_a_nonnegative, i64 %a_ext, i64 0
+ ret i64 %max
+; CHECK: @test43
+; CHECK-NEXT: %a_ext = sext i32 %a to i64
+; CHECK-NEXT: %is_a_nonnegative = icmp slt i64 %a_ext, 0
+; CHECK-NEXT: %max = select i1 %is_a_nonnegative, i64 0, i64 %a_ext
+; CHECK-NEXT: ret i64 %max
+}
+
+define i64 @test44(i32 %a) nounwind {
+ %a_ext = sext i32 %a to i64
+ %is_a_nonpositive = icmp slt i32 %a, 1
+ %min = select i1 %is_a_nonpositive, i64 %a_ext, i64 0
+ ret i64 %min
+; CHECK: @test44
+; CHECK-NEXT: %a_ext = sext i32 %a to i64
+; CHECK-NEXT: %is_a_nonpositive = icmp sgt i64 %a_ext, 0
+; CHECK-NEXT: %min = select i1 %is_a_nonpositive, i64 0, i64 %a_ext
+; CHECK-NEXT: ret i64 %min
+}
+define i64 @test45(i32 %a) nounwind {
+ %a_ext = zext i32 %a to i64
+ %is_a_nonnegative = icmp ugt i32 %a, 2
+ %max = select i1 %is_a_nonnegative, i64 %a_ext, i64 3
+ ret i64 %max
+; CHECK: @test45
+; CHECK-NEXT: %a_ext = zext i32 %a to i64
+; CHECK-NEXT: %is_a_nonnegative = icmp ult i64 %a_ext, 3
+; CHECK-NEXT: %max = select i1 %is_a_nonnegative, i64 3, i64 %a_ext
+; CHECK-NEXT: ret i64 %max
+}
+
+define i64 @test46(i32 %a) nounwind {
+ %a_ext = zext i32 %a to i64
+ %is_a_nonpositive = icmp ult i32 %a, 3
+ %min = select i1 %is_a_nonpositive, i64 %a_ext, i64 2
+ ret i64 %min
+; CHECK: @test46
+; CHECK-NEXT: %a_ext = zext i32 %a to i64
+; CHECK-NEXT: %is_a_nonpositive = icmp ugt i64 %a_ext, 2
+; CHECK-NEXT: %min = select i1 %is_a_nonpositive, i64 2, i64 %a_ext
+; CHECK-NEXT: ret i64 %min
+}
+define i64 @test47(i32 %a) nounwind {
+ %a_ext = sext i32 %a to i64
+ %is_a_nonnegative = icmp ugt i32 %a, 2
+ %max = select i1 %is_a_nonnegative, i64 %a_ext, i64 3
+ ret i64 %max
+; CHECK: @test47
+; CHECK-NEXT: %a_ext = sext i32 %a to i64
+; CHECK-NEXT: %is_a_nonnegative = icmp ult i64 %a_ext, 3
+; CHECK-NEXT: %max = select i1 %is_a_nonnegative, i64 3, i64 %a_ext
+; CHECK-NEXT: ret i64 %max
+}
+
+define i64 @test48(i32 %a) nounwind {
+ %a_ext = sext i32 %a to i64
+ %is_a_nonpositive = icmp ult i32 %a, 3
+ %min = select i1 %is_a_nonpositive, i64 %a_ext, i64 2
+ ret i64 %min
+; CHECK: @test48
+; CHECK-NEXT: %a_ext = sext i32 %a to i64
+; CHECK-NEXT: %is_a_nonpositive = icmp ugt i64 %a_ext, 2
+; CHECK-NEXT: %min = select i1 %is_a_nonpositive, i64 2, i64 %a_ext
+; CHECK-NEXT: ret i64 %min
+}
+
+define i64 @test49(i32 %a) nounwind {
+ %a_ext = sext i32 %a to i64
+ %is_a_nonpositive = icmp ult i32 %a, 3
+ %min = select i1 %is_a_nonpositive, i64 2, i64 %a_ext
+ ret i64 %min
+; CHECK: @test49
+; CHECK-NEXT: %a_ext = sext i32 %a to i64
+; CHECK-NEXT: %is_a_nonpositive = icmp ugt i64 %a_ext, 2
+; CHECK-NEXT: %min = select i1 %is_a_nonpositive, i64 %a_ext, i64 2
+; CHECK-NEXT: ret i64 %min
+}