diff options
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineSelect.cpp | 105 | ||||
-rw-r--r-- | test/Transforms/InstCombine/select.ll | 82 |
2 files changed, 152 insertions, 35 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 03a030d..c2caedf 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -285,48 +285,83 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, // place here, so make sure the select is the only user. if (ICI->hasOneUse()) if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) { + // X < MIN ? T : F --> F + if ((Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) + && CI->isMinValue(Pred == ICmpInst::ICMP_SLT)) + return ReplaceInstUsesWith(SI, FalseVal); + // X > MAX ? T : F --> F + else if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT) + && CI->isMaxValue(Pred == ICmpInst::ICMP_SGT)) + return ReplaceInstUsesWith(SI, FalseVal); switch (Pred) { default: break; case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_SLT: { - // X < MIN ? T : F --> F - if (CI->isMinValue(Pred == ICmpInst::ICMP_SLT)) - return ReplaceInstUsesWith(SI, FalseVal); - // X < C ? X : C-1 --> X > C-1 ? C-1 : X - Constant *AdjustedRHS = - ConstantInt::get(CI->getContext(), CI->getValue()-1); - if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || - (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) { - Pred = ICmpInst::getSwappedPredicate(Pred); - CmpRHS = AdjustedRHS; - std::swap(FalseVal, TrueVal); - ICI->setPredicate(Pred); - ICI->setOperand(1, CmpRHS); - SI.setOperand(1, TrueVal); - SI.setOperand(2, FalseVal); - Changed = true; - } - break; - } + case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_SGT: { - // X > MAX ? T : F --> F - if (CI->isMaxValue(Pred == ICmpInst::ICMP_SGT)) - return ReplaceInstUsesWith(SI, FalseVal); + Constant *AdjustedRHS; + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT) + AdjustedRHS = ConstantInt::get(CI->getContext(), CI->getValue() + 1); + else // (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) + AdjustedRHS = ConstantInt::get(CI->getContext(), CI->getValue() - 1); + // X > C ? X : C+1 --> X < C+1 ? C+1 : X - Constant *AdjustedRHS = - ConstantInt::get(CI->getContext(), CI->getValue()+1); + // X < C ? X : C-1 --> X > C-1 ? C-1 : X if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || - (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) { - Pred = ICmpInst::getSwappedPredicate(Pred); - CmpRHS = AdjustedRHS; - std::swap(FalseVal, TrueVal); - ICI->setPredicate(Pred); - ICI->setOperand(1, CmpRHS); - SI.setOperand(1, TrueVal); - SI.setOperand(2, FalseVal); - Changed = true; - } + (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) + ; // Nothing to do here. Values match without any sign/zero extension. + + // Types do not match. Instead of calculating this with mixed types + // promote all to the larger type. This enables scalar evolution to + // analyze this expression. + else if (CmpRHS->getType()->getScalarSizeInBits() + < SI.getType()->getScalarSizeInBits()) { + Constant *sextRHS = ConstantExpr::getSExt(AdjustedRHS, + SI.getType()); + + // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X + // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X + // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X + // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X + if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && + sextRHS == FalseVal) { + CmpLHS = TrueVal; + AdjustedRHS = sextRHS; + } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) && + sextRHS == TrueVal) { + CmpLHS = FalseVal; + AdjustedRHS = sextRHS; + } else if (ICI->isUnsigned()) { + Constant *zextRHS = ConstantExpr::getZExt(AdjustedRHS, + SI.getType()); + // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X + // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X + // zext + signed compare cannot be changed: + // 0xff <s 0x00, but 0x00ff >s 0x0000 + if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && + zextRHS == FalseVal) { + CmpLHS = TrueVal; + AdjustedRHS = zextRHS; + } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) && + zextRHS == TrueVal) { + CmpLHS = FalseVal; + AdjustedRHS = zextRHS; + } else + break; + } else + break; + } else + break; + + Pred = ICmpInst::getSwappedPredicate(Pred); + CmpRHS = AdjustedRHS; + std::swap(FalseVal, TrueVal); + ICI->setPredicate(Pred); + ICI->setOperand(0, CmpLHS); + ICI->setOperand(1, CmpRHS); + SI.setOperand(1, TrueVal); + SI.setOperand(2, FalseVal); + Changed = true; break; } } diff --git a/test/Transforms/InstCombine/select.ll b/test/Transforms/InstCombine/select.ll index fd3937d..c9b880d 100644 --- a/test/Transforms/InstCombine/select.ll +++ b/test/Transforms/InstCombine/select.ll @@ -596,3 +596,85 @@ define i32 @test42(i32 %x, i32 %y) { ; CHECK-NEXT: %c = add i32 %b, %y ; CHECK-NEXT: ret i32 %c } + +define i64 @test43(i32 %a) nounwind { + %a_ext = sext i32 %a to i64 + %is_a_nonnegative = icmp sgt i32 %a, -1 + %max = select i1 %is_a_nonnegative, i64 %a_ext, i64 0 + ret i64 %max +; CHECK: @test43 +; CHECK-NEXT: %a_ext = sext i32 %a to i64 +; CHECK-NEXT: %is_a_nonnegative = icmp slt i64 %a_ext, 0 +; CHECK-NEXT: %max = select i1 %is_a_nonnegative, i64 0, i64 %a_ext +; CHECK-NEXT: ret i64 %max +} + +define i64 @test44(i32 %a) nounwind { + %a_ext = sext i32 %a to i64 + %is_a_nonpositive = icmp slt i32 %a, 1 + %min = select i1 %is_a_nonpositive, i64 %a_ext, i64 0 + ret i64 %min +; CHECK: @test44 +; CHECK-NEXT: %a_ext = sext i32 %a to i64 +; CHECK-NEXT: %is_a_nonpositive = icmp sgt i64 %a_ext, 0 +; CHECK-NEXT: %min = select i1 %is_a_nonpositive, i64 0, i64 %a_ext +; CHECK-NEXT: ret i64 %min +} +define i64 @test45(i32 %a) nounwind { + %a_ext = zext i32 %a to i64 + %is_a_nonnegative = icmp ugt i32 %a, 2 + %max = select i1 %is_a_nonnegative, i64 %a_ext, i64 3 + ret i64 %max +; CHECK: @test45 +; CHECK-NEXT: %a_ext = zext i32 %a to i64 +; CHECK-NEXT: %is_a_nonnegative = icmp ult i64 %a_ext, 3 +; CHECK-NEXT: %max = select i1 %is_a_nonnegative, i64 3, i64 %a_ext +; CHECK-NEXT: ret i64 %max +} + +define i64 @test46(i32 %a) nounwind { + %a_ext = zext i32 %a to i64 + %is_a_nonpositive = icmp ult i32 %a, 3 + %min = select i1 %is_a_nonpositive, i64 %a_ext, i64 2 + ret i64 %min +; CHECK: @test46 +; CHECK-NEXT: %a_ext = zext i32 %a to i64 +; CHECK-NEXT: %is_a_nonpositive = icmp ugt i64 %a_ext, 2 +; CHECK-NEXT: %min = select i1 %is_a_nonpositive, i64 2, i64 %a_ext +; CHECK-NEXT: ret i64 %min +} +define i64 @test47(i32 %a) nounwind { + %a_ext = sext i32 %a to i64 + %is_a_nonnegative = icmp ugt i32 %a, 2 + %max = select i1 %is_a_nonnegative, i64 %a_ext, i64 3 + ret i64 %max +; CHECK: @test47 +; CHECK-NEXT: %a_ext = sext i32 %a to i64 +; CHECK-NEXT: %is_a_nonnegative = icmp ult i64 %a_ext, 3 +; CHECK-NEXT: %max = select i1 %is_a_nonnegative, i64 3, i64 %a_ext +; CHECK-NEXT: ret i64 %max +} + +define i64 @test48(i32 %a) nounwind { + %a_ext = sext i32 %a to i64 + %is_a_nonpositive = icmp ult i32 %a, 3 + %min = select i1 %is_a_nonpositive, i64 %a_ext, i64 2 + ret i64 %min +; CHECK: @test48 +; CHECK-NEXT: %a_ext = sext i32 %a to i64 +; CHECK-NEXT: %is_a_nonpositive = icmp ugt i64 %a_ext, 2 +; CHECK-NEXT: %min = select i1 %is_a_nonpositive, i64 2, i64 %a_ext +; CHECK-NEXT: ret i64 %min +} + +define i64 @test49(i32 %a) nounwind { + %a_ext = sext i32 %a to i64 + %is_a_nonpositive = icmp ult i32 %a, 3 + %min = select i1 %is_a_nonpositive, i64 2, i64 %a_ext + ret i64 %min +; CHECK: @test49 +; CHECK-NEXT: %a_ext = sext i32 %a to i64 +; CHECK-NEXT: %is_a_nonpositive = icmp ugt i64 %a_ext, 2 +; CHECK-NEXT: %min = select i1 %is_a_nonpositive, i64 %a_ext, i64 2 +; CHECK-NEXT: ret i64 %min +} |