diff options
Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | lib/Analysis/ScalarEvolution.cpp | 93 |
1 files changed, 71 insertions, 22 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index a5100d0..a78bbea 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -78,6 +78,8 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/Streams.h" #include "llvm/ADT/Statistic.h" +//TMP: +#include "llvm/Support/Debug.h" #include <ostream> #include <algorithm> #include <cmath> @@ -2461,6 +2463,53 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { return UnknownValue; } +/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the +/// following equation: +/// +/// A * X = B (mod N) +/// +/// where N = 2^BW and BW is the common bit width of A and B. The signedness of +/// A and B isn't important. +/// +/// If the equation does not have a solution, SCEVCouldNotCompute is returned. +static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B, + ScalarEvolution &SE) { + uint32_t BW = A.getBitWidth(); + assert(BW == B.getBitWidth() && "Bit widths must be the same."); + assert(A != 0 && "A must be non-zero."); + + // 1. D = gcd(A, N) + // + // The gcd of A and N may have only one prime factor: 2. The number of + // trailing zeros in A is its multiplicity + uint32_t Mult2 = A.countTrailingZeros(); + // D = 2^Mult2 + + // 2. Check if B is divisible by D. + // + // B is divisible by D if and only if the multiplicity of prime factor 2 for B + // is not less than multiplicity of this prime factor for D. + if (B.countTrailingZeros() < Mult2) + return new SCEVCouldNotCompute(); + + // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic + // modulo (N / D). + // + // (N / D) may need BW+1 bits in its representation. Hence, we'll use this + // bit width during computations. + APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D + APInt Mod(BW + 1, 0); + Mod.set(BW - Mult2); // Mod = N / D + APInt I = AD.multiplicativeInverse(Mod); + + // 4. Compute the minimum unsigned root of the equation: + // I * (B / D) mod (N / D) + APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod); + + // The result is guaranteed to be less than 2^BW so we may truncate it to BW + // bits. + return SE.getConstant(Result.trunc(BW)); +} /// SolveQuadraticEquation - Find the roots of the quadratic equation for the /// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which @@ -2533,36 +2582,36 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { return UnknownValue; if (AddRec->isAffine()) { - // If this is an affine expression the execution count of this branch is - // equal to: + // If this is an affine expression, the execution count of this branch is + // the minimum unsigned root of the following equation: + // + // Start + Step*N = 0 (mod 2^BW) // - // (0 - Start/Step) iff Start % Step == 0 + // equivalent to: // + // Step*N = -Start (mod 2^BW) + // + // where BW is the common bit width of Start and Step. + // Get the initial value for the loop. SCEVHandle Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); if (isa<SCEVCouldNotCompute>(Start)) return UnknownValue; - SCEVHandle Step = AddRec->getOperand(1); - Step = getSCEVAtScope(Step, L->getParentLoop()); + SCEVHandle Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); - // Figure out if Start % Step == 0. - // FIXME: We should add DivExpr and RemExpr operations to our AST. if (SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) { - if (StepC->getValue()->equalsInt(1)) // N % 1 == 0 - return SE.getNegativeSCEV(Start); // 0 - Start/1 == -Start - if (StepC->getValue()->isAllOnesValue()) // N % -1 == 0 - return Start; // 0 - Start/-1 == Start - - // Check to see if Start is divisible by SC with no remainder. - if (SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) { - ConstantInt *StartCC = StartC->getValue(); - Constant *StartNegC = ConstantExpr::getNeg(StartCC); - Constant *Rem = ConstantExpr::getURem(StartNegC, StepC->getValue()); - if (Rem->isNullValue()) { - Constant *Result = ConstantExpr::getUDiv(StartNegC,StepC->getValue()); - return SE.getUnknown(Result); - } - } + // For now we handle only constant steps. + + // First, handle unitary steps. + if (StepC->getValue()->equalsInt(1)) // 1*N = -Start (mod 2^BW), so: + return SE.getNegativeSCEV(Start); // N = -Start (as unsigned) + if (StepC->getValue()->isAllOnesValue()) // -1*N = -Start (mod 2^BW), so: + return Start; // N = Start (as unsigned) + + // Then, try to solve the above equation provided that Start is constant. + if (SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) + return SolveLinEquationWithOverflow(StepC->getValue()->getValue(), + -StartC->getValue()->getValue(),SE); } } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) { // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of |