about | summary | refs | log | tree | commit | diff | stats
path: root/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- lib/Analysis/ScalarEvolution.cpp 629
1 files changed, 330 insertions, 299 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index 9e4eb11..4e713fb 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -1102,13 +1102,14 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
// trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
- // eliminate all the truncates.
+ // eliminate all the truncates, or we replace other casts with truncates.
if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
SmallVector<const SCEV *, 4> Operands;
bool hasTrunc = false;
for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
- hasTrunc = isa<SCEVTruncateExpr>(S);
+ if (!isa<SCEVCastExpr>(SA->getOperand(i)))
+ hasTrunc = isa<SCEVTruncateExpr>(S);
Operands.push_back(S);
}
if (!hasTrunc)
@@ -1117,13 +1118,14 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
}
// trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
- // eliminate all the truncates.
+ // eliminate all the truncates, or we replace other casts with truncates.
if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
SmallVector<const SCEV *, 4> Operands;
bool hasTrunc = false;
for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
- hasTrunc = isa<SCEVTruncateExpr>(S);
+ if (!isa<SCEVCastExpr>(SM->getOperand(i)))
+ hasTrunc = isa<SCEVTruncateExpr>(S);
Operands.push_back(S);
}
if (!hasTrunc)
@@ -1325,6 +1327,85 @@ static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
(SE->*GetExtendExpr)(PreStart, Ty));
}
+// Try to prove away overflow by looking at "nearby" add recurrences. A
+// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
+// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
+//
+// Formally:
+//
+// {S,+,X} == {S-T,+,X} + T
+// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
+//
+// If ({S-T,+,X} + T) does not overflow ... (1)
+//
+// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
+//
+// If {S-T,+,X} does not overflow ... (2)
+//
+// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
+// == {Ext(S-T)+Ext(T),+,Ext(X)}
+//
+// If (S-T)+T does not overflow ... (3)
+//
+// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
+// == {Ext(S),+,Ext(X)} == LHS
+//
+// Thus, if (1), (2) and (3) are true for some T, then
+// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
+//
+// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
+// does not overflow" restricted to the 0th iteration. Therefore we only need
+// to check for (1) and (2).
+//
+// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
+// is `Delta` (defined below).
+//
+template <typename ExtendOpTy>
+bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
+ const SCEV *Step,
+ const Loop *L) {
+ auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
+ // WrapType is the no-wrap flag matching ExtendOpTy; PreAR must carry it below.
+ // We restrict `Start` to a constant to prevent SCEV from spending too much
+ // time here. It is correct (but more expensive) to continue with a
+ // non-constant `Start` and do a general SCEV subtraction to compute
+ // `PreStart` below.
+ // A non-constant `Start` therefore makes this function return false.
+ const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
+ if (!StartC)
+ return false;
+
+ APInt StartAI = StartC->getValue()->getValue();
+ // Each Delta plays the role of T in the derivation above this function.
+ for (unsigned Delta : {-2, -1, 1, 2}) { // NOTE(review): -2/-1 convert to large unsigned values here; confirm this is intended for wide types
+ const SCEV *PreStart = getConstant(StartAI - Delta);
+ // PreStart is S-T, the start of the neighbouring recurrence {S-T,+,X}.
+ // Give up if we don't already have the add recurrence we need because
+ // actually constructing an add recurrence is relatively expensive.
+ const SCEVAddRecExpr *PreAR = [&]() {
+ FoldingSetNodeID ID;
+ ID.AddInteger(scAddRecExpr);
+ ID.AddPointer(PreStart);
+ ID.AddPointer(Step);
+ ID.AddPointer(L);
+ void *IP = nullptr; // passed to the lookup; not read afterwards
+ return static_cast<SCEVAddRecExpr *>(
+ this->UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+ }();
+ // A null PreAR fails the check below, so this Delta is skipped.
+ if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
+ const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
+ ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
+ const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
+ DeltaS, &Pred, this);
+ if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
+ return true;
+ }
+ }
+ // None of the candidate Deltas proved both (1) and (2).
+ return false;
+}
+
const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
Type *Ty) {
assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
@@ -1473,6 +1554,13 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
}
}
}
+
+ if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
+ getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+ }
}
// The cast wasn't folded; create an explicit cast node.
@@ -1664,6 +1752,13 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
return getAddExpr(Start, getSignExtendExpr(NewAR, Ty));
}
}
+
+ if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
+ getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+ }
}
// The cast wasn't folded; create an explicit cast node.
@@ -3037,39 +3132,23 @@ const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
}
const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
- // If we have DataLayout, we can bypass creating a target-independent
+ // We can bypass creating a target-independent
// constant expression and then folding it back into a ConstantInt.
// This is just a compile-time optimization.
- if (DL)
- return getConstant(IntTy, DL->getTypeAllocSize(AllocTy));
-
- Constant *C = ConstantExpr::getSizeOf(AllocTy);
- if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
- if (Constant *Folded = ConstantFoldConstantExpression(CE, DL, TLI))
- C = Folded;
- Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
- assert(Ty == IntTy && "Effective SCEV type doesn't match");
- return getTruncateOrZeroExtend(getSCEV(C), Ty);
+ return getConstant(IntTy, // alloc size of AllocTy, taken from the DataLayout reached via F's parent
+ F->getParent()->getDataLayout().getTypeAllocSize(AllocTy));
}
const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
StructType *STy,
unsigned FieldNo) {
- // If we have DataLayout, we can bypass creating a target-independent
+ // We can bypass creating a target-independent
// constant expression and then folding it back into a ConstantInt.
// This is just a compile-time optimization.
- if (DL) {
- return getConstant(IntTy,
- DL->getStructLayout(STy)->getElementOffset(FieldNo));
- }
-
- Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo);
- if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
- if (Constant *Folded = ConstantFoldConstantExpression(CE, DL, TLI))
- C = Folded;
-
- Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy));
- return getTruncateOrZeroExtend(getSCEV(C), Ty);
+ return getConstant( // offset of field FieldNo within STy, read from the struct layout
+ IntTy,
+ F->getParent()->getDataLayout().getStructLayout(STy)->getElementOffset(
+ FieldNo));
}
const SCEV *ScalarEvolution::getUnknown(Value *V) {
@@ -3111,19 +3190,7 @@ bool ScalarEvolution::isSCEVable(Type *Ty) const {
/// for which isSCEVable must return true.
uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
assert(isSCEVable(Ty) && "Type is not SCEVable!");
-
- // If we have a DataLayout, use it!
- if (DL)
- return DL->getTypeSizeInBits(Ty);
-
- // Integer types have fixed sizes.
- if (Ty->isIntegerTy())
- return Ty->getPrimitiveSizeInBits();
-
- // The only other support type is pointer. Without DataLayout, conservatively
- // assume pointers are 64-bit.
- assert(Ty->isPointerTy() && "isSCEVable permitted a non-SCEVable type!");
- return 64;
+ return F->getParent()->getDataLayout().getTypeSizeInBits(Ty); // always answered by the DataLayout; the removed 64-bit pointer fallback is gone
}
/// getEffectiveSCEVType - Return a type with the same bitwidth as
@@ -3139,12 +3206,7 @@ Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
// The only other support type is pointer.
assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
-
- if (DL)
- return DL->getIntPtrType(Ty);
-
- // Without DataLayout, conservatively assume pointers are 64-bit.
- return Type::getInt64Ty(getContext());
+ return F->getParent()->getDataLayout().getIntPtrType(Ty);
}
const SCEV *ScalarEvolution::getCouldNotCompute() {
@@ -3531,10 +3593,12 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
// If the increment doesn't overflow, then neither the addrec nor
// the post-increment will overflow.
if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
- if (OBO->hasNoUnsignedWrap())
- Flags = setFlags(Flags, SCEV::FlagNUW);
- if (OBO->hasNoSignedWrap())
- Flags = setFlags(Flags, SCEV::FlagNSW);
+ if (OBO->getOperand(0) == PN) {
+ if (OBO->hasNoUnsignedWrap())
+ Flags = setFlags(Flags, SCEV::FlagNUW);
+ if (OBO->hasNoSignedWrap())
+ Flags = setFlags(Flags, SCEV::FlagNSW);
+ }
} else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
// If the increment is an inbounds GEP, then we know the address
// space cannot be wrapped around. We cannot make any guarantee
@@ -3542,7 +3606,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
// unsigned but we may have a negative index from the base
// pointer. We can guarantee that no unsigned wrap occurs if the
// indices form a positive value.
- if (GEP->isInBounds()) {
+ if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
Flags = setFlags(Flags, SCEV::FlagNW);
const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
@@ -3608,7 +3672,8 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
// PHI's incoming blocks are in a different loop, in which case doing so
// risks breaking LCSSA form. Instcombine would normally zap these, but
// it doesn't have DominatorTree information, so it may miss cases.
- if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AC))
+ if (Value *V =
+ SimplifyInstruction(PN, F->getParent()->getDataLayout(), TLI, DT, AC))
if (LI->replacementPreservesLCSSAForm(PN, V))
return getSCEV(V);
@@ -3740,7 +3805,8 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
// For a SCEVUnknown, ask ValueTracking.
unsigned BitWidth = getTypeSizeInBits(U->getType());
APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
- computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT);
+ computeKnownBits(U->getValue(), Zeros, Ones,
+ F->getParent()->getDataLayout(), 0, AC, nullptr, DT);
return Zeros.countTrailingOnes();
}
@@ -3775,79 +3841,93 @@ static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
return None;
}
-/// getUnsignedRange - Determine the unsigned range for a particular SCEV.
+/// getRange - Determine the range for a particular SCEV. If SignHint is
+/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
+/// with a "cleaner" unsigned (resp. signed) representation.
///
ConstantRange
-ScalarEvolution::getUnsignedRange(const SCEV *S) {
+ScalarEvolution::getRange(const SCEV *S,
+ ScalarEvolution::RangeSignHint SignHint) {
+ DenseMap<const SCEV *, ConstantRange> &Cache =
+ SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
+ : SignedRanges;
+
// See if we've computed this range already.
- DenseMap<const SCEV *, ConstantRange>::iterator I = UnsignedRanges.find(S);
- if (I != UnsignedRanges.end())
+ DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
+ if (I != Cache.end())
return I->second;
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
- return setUnsignedRange(C, ConstantRange(C->getValue()->getValue()));
+ return setRange(C, SignHint, ConstantRange(C->getValue()->getValue()));
unsigned BitWidth = getTypeSizeInBits(S->getType());
ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
- // If the value has known zeros, the maximum unsigned value will have those
- // known zeros as well.
+ // If the value has known zeros, the maximum value will have those known zeros
+ // as well.
uint32_t TZ = GetMinTrailingZeros(S);
- if (TZ != 0)
- ConservativeResult =
- ConstantRange(APInt::getMinValue(BitWidth),
- APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
+ if (TZ != 0) {
+ if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
+ ConservativeResult =
+ ConstantRange(APInt::getMinValue(BitWidth),
+ APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
+ else
+ ConservativeResult = ConstantRange(
+ APInt::getSignedMinValue(BitWidth),
+ APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
+ }
if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
- ConstantRange X = getUnsignedRange(Add->getOperand(0));
+ ConstantRange X = getRange(Add->getOperand(0), SignHint);
for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
- X = X.add(getUnsignedRange(Add->getOperand(i)));
- return setUnsignedRange(Add, ConservativeResult.intersectWith(X));
+ X = X.add(getRange(Add->getOperand(i), SignHint));
+ return setRange(Add, SignHint, ConservativeResult.intersectWith(X));
}
if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
- ConstantRange X = getUnsignedRange(Mul->getOperand(0));
+ ConstantRange X = getRange(Mul->getOperand(0), SignHint);
for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
- X = X.multiply(getUnsignedRange(Mul->getOperand(i)));
- return setUnsignedRange(Mul, ConservativeResult.intersectWith(X));
+ X = X.multiply(getRange(Mul->getOperand(i), SignHint));
+ return setRange(Mul, SignHint, ConservativeResult.intersectWith(X));
}
if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
- ConstantRange X = getUnsignedRange(SMax->getOperand(0));
+ ConstantRange X = getRange(SMax->getOperand(0), SignHint);
for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
- X = X.smax(getUnsignedRange(SMax->getOperand(i)));
- return setUnsignedRange(SMax, ConservativeResult.intersectWith(X));
+ X = X.smax(getRange(SMax->getOperand(i), SignHint));
+ return setRange(SMax, SignHint, ConservativeResult.intersectWith(X));
}
if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
- ConstantRange X = getUnsignedRange(UMax->getOperand(0));
+ ConstantRange X = getRange(UMax->getOperand(0), SignHint);
for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
- X = X.umax(getUnsignedRange(UMax->getOperand(i)));
- return setUnsignedRange(UMax, ConservativeResult.intersectWith(X));
+ X = X.umax(getRange(UMax->getOperand(i), SignHint));
+ return setRange(UMax, SignHint, ConservativeResult.intersectWith(X));
}
if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
- ConstantRange X = getUnsignedRange(UDiv->getLHS());
- ConstantRange Y = getUnsignedRange(UDiv->getRHS());
- return setUnsignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y)));
+ ConstantRange X = getRange(UDiv->getLHS(), SignHint);
+ ConstantRange Y = getRange(UDiv->getRHS(), SignHint);
+ return setRange(UDiv, SignHint,
+ ConservativeResult.intersectWith(X.udiv(Y)));
}
if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
- ConstantRange X = getUnsignedRange(ZExt->getOperand());
- return setUnsignedRange(ZExt,
- ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
+ ConstantRange X = getRange(ZExt->getOperand(), SignHint);
+ return setRange(ZExt, SignHint,
+ ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
}
if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
- ConstantRange X = getUnsignedRange(SExt->getOperand());
- return setUnsignedRange(SExt,
- ConservativeResult.intersectWith(X.signExtend(BitWidth)));
+ ConstantRange X = getRange(SExt->getOperand(), SignHint);
+ return setRange(SExt, SignHint,
+ ConservativeResult.intersectWith(X.signExtend(BitWidth)));
}
if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
- ConstantRange X = getUnsignedRange(Trunc->getOperand());
- return setUnsignedRange(Trunc,
- ConservativeResult.intersectWith(X.truncate(BitWidth)));
+ ConstantRange X = getRange(Trunc->getOperand(), SignHint);
+ return setRange(Trunc, SignHint,
+ ConservativeResult.intersectWith(X.truncate(BitWidth)));
}
if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
@@ -3860,143 +3940,6 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) {
ConservativeResult.intersectWith(
ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0)));
- // TODO: non-affine addrec
- if (AddRec->isAffine()) {
- Type *Ty = AddRec->getType();
- const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
- if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
- getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
- MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
-
- const SCEV *Start = AddRec->getStart();
- const SCEV *Step = AddRec->getStepRecurrence(*this);
-
- ConstantRange StartRange = getUnsignedRange(Start);
- ConstantRange StepRange = getSignedRange(Step);
- ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
- ConstantRange EndRange =
- StartRange.add(MaxBECountRange.multiply(StepRange));
-
- // Check for overflow. This must be done with ConstantRange arithmetic
- // because we could be called from within the ScalarEvolution overflow
- // checking code.
- ConstantRange ExtStartRange = StartRange.zextOrTrunc(BitWidth*2+1);
- ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1);
- ConstantRange ExtMaxBECountRange =
- MaxBECountRange.zextOrTrunc(BitWidth*2+1);
- ConstantRange ExtEndRange = EndRange.zextOrTrunc(BitWidth*2+1);
- if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) !=
- ExtEndRange)
- return setUnsignedRange(AddRec, ConservativeResult);
-
- APInt Min = APIntOps::umin(StartRange.getUnsignedMin(),
- EndRange.getUnsignedMin());
- APInt Max = APIntOps::umax(StartRange.getUnsignedMax(),
- EndRange.getUnsignedMax());
- if (Min.isMinValue() && Max.isMaxValue())
- return setUnsignedRange(AddRec, ConservativeResult);
- return setUnsignedRange(AddRec,
- ConservativeResult.intersectWith(ConstantRange(Min, Max+1)));
- }
- }
-
- return setUnsignedRange(AddRec, ConservativeResult);
- }
-
- if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
- // Check if the IR explicitly contains !range metadata.
- Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
- if (MDRange.hasValue())
- ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue());
-
- // For a SCEVUnknown, ask ValueTracking.
- APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
- computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT);
- if (Ones == ~Zeros + 1)
- return setUnsignedRange(U, ConservativeResult);
- return setUnsignedRange(U,
- ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)));
- }
-
- return setUnsignedRange(S, ConservativeResult);
-}
-
-/// getSignedRange - Determine the signed range for a particular SCEV.
-///
-ConstantRange
-ScalarEvolution::getSignedRange(const SCEV *S) {
- // See if we've computed this range already.
- DenseMap<const SCEV *, ConstantRange>::iterator I = SignedRanges.find(S);
- if (I != SignedRanges.end())
- return I->second;
-
- if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
- return setSignedRange(C, ConstantRange(C->getValue()->getValue()));
-
- unsigned BitWidth = getTypeSizeInBits(S->getType());
- ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
-
- // If the value has known zeros, the maximum signed value will have those
- // known zeros as well.
- uint32_t TZ = GetMinTrailingZeros(S);
- if (TZ != 0)
- ConservativeResult =
- ConstantRange(APInt::getSignedMinValue(BitWidth),
- APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
-
- if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
- ConstantRange X = getSignedRange(Add->getOperand(0));
- for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
- X = X.add(getSignedRange(Add->getOperand(i)));
- return setSignedRange(Add, ConservativeResult.intersectWith(X));
- }
-
- if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
- ConstantRange X = getSignedRange(Mul->getOperand(0));
- for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
- X = X.multiply(getSignedRange(Mul->getOperand(i)));
- return setSignedRange(Mul, ConservativeResult.intersectWith(X));
- }
-
- if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
- ConstantRange X = getSignedRange(SMax->getOperand(0));
- for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
- X = X.smax(getSignedRange(SMax->getOperand(i)));
- return setSignedRange(SMax, ConservativeResult.intersectWith(X));
- }
-
- if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
- ConstantRange X = getSignedRange(UMax->getOperand(0));
- for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
- X = X.umax(getSignedRange(UMax->getOperand(i)));
- return setSignedRange(UMax, ConservativeResult.intersectWith(X));
- }
-
- if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
- ConstantRange X = getSignedRange(UDiv->getLHS());
- ConstantRange Y = getSignedRange(UDiv->getRHS());
- return setSignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y)));
- }
-
- if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
- ConstantRange X = getSignedRange(ZExt->getOperand());
- return setSignedRange(ZExt,
- ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
- }
-
- if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
- ConstantRange X = getSignedRange(SExt->getOperand());
- return setSignedRange(SExt,
- ConservativeResult.intersectWith(X.signExtend(BitWidth)));
- }
-
- if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
- ConstantRange X = getSignedRange(Trunc->getOperand());
- return setSignedRange(Trunc,
- ConservativeResult.intersectWith(X.truncate(BitWidth)));
- }
-
- if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
// If there's no signed wrap, and all the operands have the same sign or
// zero, the value won't ever change sign.
if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) {
@@ -4022,41 +3965,66 @@ ScalarEvolution::getSignedRange(const SCEV *S) {
const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
+
+ // Check for overflow. This must be done with ConstantRange arithmetic
+ // because we could be called from within the ScalarEvolution overflow
+ // checking code.
+
MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
+ ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
+ ConstantRange ZExtMaxBECountRange =
+ MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1);
const SCEV *Start = AddRec->getStart();
const SCEV *Step = AddRec->getStepRecurrence(*this);
+ ConstantRange StepSRange = getSignedRange(Step);
+ ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1);
+
+ ConstantRange StartURange = getUnsignedRange(Start);
+ ConstantRange EndURange =
+ StartURange.add(MaxBECountRange.multiply(StepSRange));
+
+ // Check for unsigned overflow.
+ ConstantRange ZExtStartURange =
+ StartURange.zextOrTrunc(BitWidth * 2 + 1);
+ ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1);
+ if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
+ ZExtEndURange) {
+ APInt Min = APIntOps::umin(StartURange.getUnsignedMin(),
+ EndURange.getUnsignedMin());
+ APInt Max = APIntOps::umax(StartURange.getUnsignedMax(),
+ EndURange.getUnsignedMax());
+ bool IsFullRange = Min.isMinValue() && Max.isMaxValue();
+ if (!IsFullRange)
+ ConservativeResult =
+ ConservativeResult.intersectWith(ConstantRange(Min, Max + 1));
+ }
- ConstantRange StartRange = getSignedRange(Start);
- ConstantRange StepRange = getSignedRange(Step);
- ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
- ConstantRange EndRange =
- StartRange.add(MaxBECountRange.multiply(StepRange));
-
- // Check for overflow. This must be done with ConstantRange arithmetic
- // because we could be called from within the ScalarEvolution overflow
- // checking code.
- ConstantRange ExtStartRange = StartRange.sextOrTrunc(BitWidth*2+1);
- ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1);
- ConstantRange ExtMaxBECountRange =
- MaxBECountRange.zextOrTrunc(BitWidth*2+1);
- ConstantRange ExtEndRange = EndRange.sextOrTrunc(BitWidth*2+1);
- if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) !=
- ExtEndRange)
- return setSignedRange(AddRec, ConservativeResult);
-
- APInt Min = APIntOps::smin(StartRange.getSignedMin(),
- EndRange.getSignedMin());
- APInt Max = APIntOps::smax(StartRange.getSignedMax(),
- EndRange.getSignedMax());
- if (Min.isMinSignedValue() && Max.isMaxSignedValue())
- return setSignedRange(AddRec, ConservativeResult);
- return setSignedRange(AddRec,
- ConservativeResult.intersectWith(ConstantRange(Min, Max+1)));
+ ConstantRange StartSRange = getSignedRange(Start);
+ ConstantRange EndSRange =
+ StartSRange.add(MaxBECountRange.multiply(StepSRange));
+
+ // Check for signed overflow. This must be done with ConstantRange
+ // arithmetic because we could be called from within the ScalarEvolution
+ // overflow checking code.
+ ConstantRange SExtStartSRange =
+ StartSRange.sextOrTrunc(BitWidth * 2 + 1);
+ ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1);
+ if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
+ SExtEndSRange) {
+ APInt Min = APIntOps::smin(StartSRange.getSignedMin(),
+ EndSRange.getSignedMin());
+ APInt Max = APIntOps::smax(StartSRange.getSignedMax(),
+ EndSRange.getSignedMax());
+ bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue();
+ if (!IsFullRange)
+ ConservativeResult =
+ ConservativeResult.intersectWith(ConstantRange(Min, Max + 1));
+ }
}
}
- return setSignedRange(AddRec, ConservativeResult);
+ return setRange(AddRec, SignHint, ConservativeResult);
}
if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
@@ -4065,18 +4033,31 @@ ScalarEvolution::getSignedRange(const SCEV *S) {
if (MDRange.hasValue())
ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue());
- // For a SCEVUnknown, ask ValueTracking.
- if (!U->getValue()->getType()->isIntegerTy() && !DL)
- return setSignedRange(U, ConservativeResult);
- unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AC, nullptr, DT);
- if (NS <= 1)
- return setSignedRange(U, ConservativeResult);
- return setSignedRange(U, ConservativeResult.intersectWith(
- ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
- APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1)));
+ // Split here to avoid paying the compile-time cost of calling both
+ // computeKnownBits and ComputeNumSignBits. This restriction can be lifted
+ // if needed.
+ const DataLayout &DL = F->getParent()->getDataLayout();
+ if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
+ // For a SCEVUnknown, ask ValueTracking.
+ APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
+ computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT);
+ if (Ones != ~Zeros + 1)
+ ConservativeResult =
+ ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1));
+ } else {
+ assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED &&
+ "generalize as needed!");
+ unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AC, nullptr, DT);
+ if (NS > 1)
+ ConservativeResult = ConservativeResult.intersectWith(
+ ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
+ APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1));
+ }
+
+ return setRange(U, SignHint, ConservativeResult);
}
- return setSignedRange(S, ConservativeResult);
+ return setRange(S, SignHint, ConservativeResult);
}
/// createSCEV - We know that there is no SCEV for the specified value.
@@ -4175,8 +4156,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
unsigned TZ = A.countTrailingZeros();
unsigned BitWidth = A.getBitWidth();
APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
- computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, 0, AC,
- nullptr, DT);
+ computeKnownBits(U->getOperand(0), KnownZero, KnownOne,
+ F->getParent()->getDataLayout(), 0, AC, nullptr, DT);
APInt EffectiveMask =
APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
@@ -5327,12 +5308,9 @@ static bool canConstantEvolve(Instruction *I, const Loop *L) {
if (!L->contains(I)) return false;
if (isa<PHINode>(I)) {
- if (L->getHeader() == I->getParent())
- return true;
- else
- // We don't currently keep track of the control flow needed to evaluate
- // PHIs, so we cannot handle PHIs inside of loops.
- return false;
+ // We don't currently keep track of the control flow needed to evaluate
+ // PHIs, so we can only handle PHIs that live in the loop header.
+ return L->getHeader() == I->getParent();
}
// If we won't be able to constant fold this expression even if the operands
@@ -5403,7 +5381,7 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
/// reason, return null.
static Constant *EvaluateExpression(Value *V, const Loop *L,
DenseMap<Instruction *, Constant *> &Vals,
- const DataLayout *DL,
+ const DataLayout &DL,
const TargetLibraryInfo *TLI) {
// Convenient constant check, but redundant for recursive calls.
if (Constant *C = dyn_cast<Constant>(V)) return C;
@@ -5492,6 +5470,7 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
unsigned NumIterations = BEs.getZExtValue(); // must be in range
unsigned IterationNum = 0;
+ const DataLayout &DL = F->getParent()->getDataLayout();
for (; ; ++IterationNum) {
if (IterationNum == NumIterations)
return RetVal = CurrentIterVals[PN]; // Got exit value!
@@ -5499,8 +5478,8 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
// Compute the value of the PHIs for the next iteration.
// EvaluateExpression adds non-phi values to the CurrentIterVals map.
DenseMap<Instruction *, Constant *> NextIterVals;
- Constant *NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL,
- TLI);
+ Constant *NextPHI =
+ EvaluateExpression(BEValue, L, CurrentIterVals, DL, TLI);
if (!NextPHI)
return nullptr; // Couldn't evaluate!
NextIterVals[PN] = NextPHI;
@@ -5576,12 +5555,11 @@ const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L,
// Okay, we find a PHI node that defines the trip count of this loop. Execute
// the loop symbolically to determine when the condition gets a value of
// "ExitWhen".
-
unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
+ const DataLayout &DL = F->getParent()->getDataLayout();
for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
- ConstantInt *CondVal =
- dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, L, CurrentIterVals,
- DL, TLI));
+ ConstantInt *CondVal = dyn_cast_or_null<ConstantInt>(
+ EvaluateExpression(Cond, L, CurrentIterVals, DL, TLI));
// Couldn't symbolically evaluate.
if (!CondVal) return getCouldNotCompute();
@@ -5814,16 +5792,16 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
// Check to see if getSCEVAtScope actually made an improvement.
if (MadeImprovement) {
Constant *C = nullptr;
+ const DataLayout &DL = F->getParent()->getDataLayout();
if (const CmpInst *CI = dyn_cast<CmpInst>(I))
- C = ConstantFoldCompareInstOperands(CI->getPredicate(),
- Operands[0], Operands[1], DL,
- TLI);
+ C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
+ Operands[1], DL, TLI);
else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) {
if (!LI->isVolatile())
C = ConstantFoldLoadFromConstPtr(Operands[0], DL);
} else
- C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
- Operands, DL, TLI);
+ C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands,
+ DL, TLI);
if (!C) return V;
return getSCEV(C);
}
@@ -6105,7 +6083,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
R1->getValue(),
R2->getValue()))) {
- if (CB->getZExtValue() == false)
+ if (!CB->getZExtValue())
std::swap(R1, R2); // R1 is the minimum root now.
// We can only use this value if the chrec ends up with an exact zero
@@ -6815,15 +6793,6 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
if (!ICI) return false;
- // Bail if the ICmp's operands' types are wider than the needed type
- // before attempting to call getSCEV on them. This avoids infinite
- // recursion, since the analysis of widening casts can require loop
- // exit condition information for overflow checking, which would
- // lead back here.
- if (getTypeSizeInBits(LHS->getType()) <
- getTypeSizeInBits(ICI->getOperand(0)->getType()))
- return false;
-
// Now that we found a conditional branch that dominates the loop or controls
// the loop latch. Check to see if it is the comparison we are looking for.
ICmpInst::Predicate FoundPred;
@@ -6835,9 +6804,17 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
- // Balance the types. The case where FoundLHS' type is wider than
- // LHS' type is checked for above.
- if (getTypeSizeInBits(LHS->getType()) >
+ // Balance the types.
+ if (getTypeSizeInBits(LHS->getType()) <
+ getTypeSizeInBits(FoundLHS->getType())) {
+ if (CmpInst::isSigned(Pred)) {
+ LHS = getSignExtendExpr(LHS, FoundLHS->getType());
+ RHS = getSignExtendExpr(RHS, FoundLHS->getType());
+ } else {
+ LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
+ RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
+ }
+ } else if (getTypeSizeInBits(LHS->getType()) >
getTypeSizeInBits(FoundLHS->getType())) {
if (CmpInst::isSigned(FoundPred)) {
FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
@@ -6963,6 +6940,9 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS,
const SCEV *FoundRHS) {
+ if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
+ return true;
+
return isImpliedCondOperandsHelper(Pred, LHS, RHS,
FoundLHS, FoundRHS) ||
// ~x < ~y --> x > y
@@ -7100,6 +7080,47 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
return false;
}
+/// isImpliedCondOperandsViaRanges - helper function for isImpliedCondOperands.
+/// Tries to get cases like "X `sgt` 0 => X - 1 `sgt` -1".
+///
+/// Given the antecedent "`FoundLHS` `Pred` `FoundRHS`", returns true if
+/// constant-range arithmetic proves the consequent "`LHS` `Pred` `RHS`".
+/// Only fires when `RHS` and `FoundRHS` are constants and `LHS` is a
+/// two-operand add of a constant and `FoundLHS`; returns false otherwise.
+bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
+                                                     const SCEV *LHS,
+                                                     const SCEV *RHS,
+                                                     const SCEV *FoundLHS,
+                                                     const SCEV *FoundRHS) {
+  if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
+    // The restriction on `FoundRHS` can be lifted easily -- it exists only to
+    // reduce the compile time impact of this optimization.
+    return false;
+
+  // Add expressions can have more than two operands.  With extra operands
+  // `LHS` is not just `FoundLHS` plus the constant, and the range computed
+  // below would not describe it, so require exactly two operands.
+  const SCEVAddExpr *AddLHS = dyn_cast<SCEVAddExpr>(LHS);
+  if (!AddLHS || AddLHS->getNumOperands() != 2 ||
+      AddLHS->getOperand(1) != FoundLHS ||
+      !isa<SCEVConstant>(AddLHS->getOperand(0)))
+    return false;
+
+  APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getValue()->getValue();
+
+  // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
+  // antecedent "`FoundLHS` `Pred` `FoundRHS`".
+  ConstantRange FoundLHSRange =
+      ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS);
+
+  // Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range
+  // for `LHS`:
+  APInt Addend =
+      cast<SCEVConstant>(AddLHS->getOperand(0))->getValue()->getValue();
+  ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend));
+
+  // We can also compute the range of values for `LHS` that satisfy the
+  // consequent, "`LHS` `Pred` `RHS`":
+  APInt ConstRHS = cast<SCEVConstant>(RHS)->getValue()->getValue();
+  ConstantRange SatisfyingLHSRange =
+      ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS);
+
+  // The antecedent implies the consequent if every value of `LHS` that
+  // satisfies the antecedent also satisfies the consequent.
+  return SatisfyingLHSRange.contains(LHSRange);
+}
+
// Verify if an linear IV with positive stride can overflow when in a
// less-than comparison, knowing the invariant term of the comparison, the
// stride and the knowledge of NSW/NUW flags on the recurrence.
@@ -7428,7 +7449,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
if (ConstantInt *CB =
dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
R1->getValue(), R2->getValue()))) {
- if (CB->getZExtValue() == false)
+ if (!CB->getZExtValue())
std::swap(R1, R2); // R1 is the minimum root now.
// Make sure the root is not off by one. The returned iteration should
@@ -7956,8 +7977,6 @@ bool ScalarEvolution::runOnFunction(Function &F) {
this->F = &F;
AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
- DL = DLP ? &DLP->getDataLayout() : nullptr;
TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
return false;
@@ -8058,6 +8077,12 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const {
OS << " --> ";
const SCEV *SV = SE.getSCEV(&*I);
SV->print(OS);
+ if (!isa<SCEVCouldNotCompute>(SV)) {
+ OS << " U: ";
+ SE.getUnsignedRange(SV).print(OS);
+ OS << " S: ";
+ SE.getSignedRange(SV).print(OS);
+ }
const Loop *L = LI->getLoopFor((*I).getParent());
@@ -8065,6 +8090,12 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const {
if (AtUse != SV) {
OS << " --> ";
AtUse->print(OS);
+ if (!isa<SCEVCouldNotCompute>(AtUse)) {
+ OS << " U: ";
+ SE.getUnsignedRange(AtUse).print(OS);
+ OS << " S: ";
+ SE.getSignedRange(AtUse).print(OS);
+ }
}
if (L) {