diff options
-rw-r--r-- | include/llvm/Analysis/ScalarEvolution.h | 8 | ||||
-rw-r--r-- | lib/Analysis/ScalarEvolution.cpp | 104 | ||||
-rw-r--r-- | unittests/Analysis/ScalarEvolutionTest.cpp | 166 |
3 files changed, 239 insertions, 39 deletions
diff --git a/include/llvm/Analysis/ScalarEvolution.h b/include/llvm/Analysis/ScalarEvolution.h index f249bcf..10d933e 100644 --- a/include/llvm/Analysis/ScalarEvolution.h +++ b/include/llvm/Analysis/ScalarEvolution.h @@ -588,6 +588,14 @@ namespace llvm { Ops.push_back(RHS); return getMulExpr(Ops, Flags); } + const SCEV *getMulExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2, + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) { + SmallVector<const SCEV *, 3> Ops; + Ops.push_back(Op0); + Ops.push_back(Op1); + Ops.push_back(Op2); + return getMulExpr(Ops, Flags); + } const SCEV *getUDivExpr(const SCEV *LHS, const SCEV *RHS); const SCEV *getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags); diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index f35f116..ff2cf12 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -1812,6 +1812,38 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, return S; } +static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) { + uint64_t k = i*j; + if (j > 1 && k / j != i) Overflow = true; + return k; +} + +/// Compute the result of "n choose k", the binomial coefficient. If an +/// intermediate computation overflows, Overflow will be set and the return will +/// be garbage. Overflow is not cleared on absence of overflow. +static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) { + // We use the multiplicative formula: + // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 . + // At each iteration, we take the n-th term of the numerator and divide by the + // (k-n)th term of the denominator. This division will always produce an + // integral result, and helps reduce the chance of overflow in the + // intermediate computations. However, we can still overflow even when the + // final result would fit. 
+ + if (n == 0 || n == k) return 1; + if (k > n) return 0; + + if (k > n/2) + k = n-k; + + uint64_t r = 1; + for (uint64_t i = 1; i <= k; ++i) { + r = umul_ov(r, n-(i-1), Overflow); + r /= i; + } + return r; +} + /// getMulExpr - Get a canonical multiply expression, or something simpler if /// possible. const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, @@ -1987,53 +2019,61 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, for (unsigned OtherIdx = Idx+1; OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); ++OtherIdx) { - bool Retry = false; if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) { - // {A,+,B}<L> * {C,+,D}<L> --> {A*C,+,A*D + B*C + B*D,+,2*B*D}<L> - // - // {A,+,B} * {C,+,D} = A+It*B * C+It*D = A*C + (A*D + B*C)*It + B*D*It^2 - // Given an equation of the form x + y*It + z*It^2 (above), we want to - // express it in terms of {X,+,Y,+,Z}. - // {X,+,Y,+,Z} = X + Y*It + Z*(It^2 - It)/2. - // Rearranging, X = x, Y = y+z, Z = 2z. + // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L> + // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ + // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z + // ]]],+,...up to x=2n}. + // Note that the arguments to choose() are always integers with values + // known at compile time, never SCEV objects. // - // x = A*C, y = (A*D + B*C), z = B*D. - // Therefore X = A*C, Y = A*D + B*C + B*D and Z = 2*B*D. + // The implementation avoids pointless extra computations when the two + // addrec's are of different length (mathematically, it's equivalent to + // an infinite stream of zeros on the right). 
+ bool OpsModified = false; for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); ++OtherIdx) if (const SCEVAddRecExpr *OtherAddRec = dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx])) if (OtherAddRec->getLoop() == AddRecLoop) { - const SCEV *A = AddRec->getStart(); - const SCEV *B = AddRec->getStepRecurrence(*this); - const SCEV *C = OtherAddRec->getStart(); - const SCEV *D = OtherAddRec->getStepRecurrence(*this); - const SCEV *NewStart = getMulExpr(A, C); - const SCEV *BD = getMulExpr(B, D); - const SCEV *NewStep = getAddExpr(getMulExpr(A, D), - getMulExpr(B, C), BD); - const SCEV *NewSecondOrderStep = - getMulExpr(BD, getConstant(BD->getType(), 2)); - - // This can happen when AddRec or OtherAddRec have >3 operands. - // TODO: support these add-recs. - if (isLoopInvariant(NewStart, AddRecLoop) && - isLoopInvariant(NewStep, AddRecLoop) && - isLoopInvariant(NewSecondOrderStep, AddRecLoop)) { - SmallVector<const SCEV *, 3> AddRecOps; - AddRecOps.push_back(NewStart); - AddRecOps.push_back(NewStep); - AddRecOps.push_back(NewSecondOrderStep); + bool Overflow = false; + Type *Ty = AddRec->getType(); + bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; + SmallVector<const SCEV*, 7> AddRecOps; + for (int x = 0, xe = AddRec->getNumOperands() + + OtherAddRec->getNumOperands() - 1; + x != xe && !Overflow; ++x) { + const SCEV *Term = getConstant(Ty, 0); + for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { + uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); + for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), + ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); + z < ze && !Overflow; ++z) { + uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); + uint64_t Coeff; + if (LargerThan64Bits) + Coeff = umul_ov(Coeff1, Coeff2, Overflow); + else + Coeff = Coeff1*Coeff2; + const SCEV *CoeffTerm = getConstant(Ty, Coeff); + const SCEV *Term1 = AddRec->getOperand(y-z); + const SCEV *Term2 = OtherAddRec->getOperand(z); + Term = getAddExpr(Term, getMulExpr(CoeffTerm, 
Term1,Term2)); + } + } + AddRecOps.push_back(Term); + } + if (!Overflow) { const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), SCEV::FlagAnyWrap); if (Ops.size() == 2) return NewAddRec; Ops[Idx] = AddRec = cast<SCEVAddRecExpr>(NewAddRec); Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; - Retry = true; + OpsModified = true; } } - if (Retry) + if (OpsModified) return getMulExpr(Ops); } } diff --git a/unittests/Analysis/ScalarEvolutionTest.cpp b/unittests/Analysis/ScalarEvolutionTest.cpp index a09cb1c..ea5aeb3 100644 --- a/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/unittests/Analysis/ScalarEvolutionTest.cpp @@ -8,20 +8,35 @@ //===----------------------------------------------------------------------===// #include <llvm/Analysis/ScalarEvolutionExpressions.h> +#include <llvm/Analysis/LoopInfo.h> #include <llvm/GlobalVariable.h> #include <llvm/Constants.h> #include <llvm/LLVMContext.h> #include <llvm/Module.h> #include <llvm/PassManager.h> +#include <llvm/ADT/SmallVector.h> #include "gtest/gtest.h" namespace llvm { namespace { -TEST(ScalarEvolutionsTest, SCEVUnknownRAUW) { +// We use this fixture to ensure that we clean up ScalarEvolution before +// deleting the PassManager. +class ScalarEvolutionsTest : public testing::Test { +protected: + ScalarEvolutionsTest() : M("", Context), SE(*new ScalarEvolution) {} + ~ScalarEvolutionsTest() { + // Manually clean up, since we allocated new SCEV objects after the + // pass was finished. 
+ SE.releaseMemory(); + } LLVMContext Context; - Module M("world", Context); + Module M; + PassManager PM; + ScalarEvolution &SE; +}; +TEST_F(ScalarEvolutionsTest, SCEVUnknownRAUW) { FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), std::vector<Type *>(), false); Function *F = cast<Function>(M.getOrInsertFunction("f", FTy)); @@ -35,8 +50,6 @@ TEST(ScalarEvolutionsTest, SCEVUnknownRAUW) { Value *V2 = new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, Init, "V2"); // Create a ScalarEvolution and "run" it so that it gets initialized. - PassManager PM; - ScalarEvolution &SE = *new ScalarEvolution(); PM.add(&SE); PM.run(M); @@ -72,10 +85,149 @@ TEST(ScalarEvolutionsTest, SCEVUnknownRAUW) { EXPECT_EQ(cast<SCEVUnknown>(M0->getOperand(1))->getValue(), V0); EXPECT_EQ(cast<SCEVUnknown>(M1->getOperand(1))->getValue(), V0); EXPECT_EQ(cast<SCEVUnknown>(M2->getOperand(1))->getValue(), V0); +} + +TEST_F(ScalarEvolutionsTest, SCEVMultiplyAddRecs) { + Type *Ty = Type::getInt32Ty(Context); + SmallVector<Type *, 10> Types; + Types.append(10, Ty); + FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), Types, false); + Function *F = cast<Function>(M.getOrInsertFunction("f", FTy)); + BasicBlock *BB = BasicBlock::Create(Context, "entry", F); + ReturnInst::Create(Context, 0, BB); + + // Create a ScalarEvolution and "run" it so that it gets initialized. + PM.add(&SE); + PM.run(M); + + // It's possible to produce an empty loop through the default constructor, + // but you can't add any blocks to it without a LoopInfo pass. 
+ Loop L; + const_cast<std::vector<BasicBlock*>&>(L.getBlocks()).push_back(BB); + + Function::arg_iterator AI = F->arg_begin(); + SmallVector<const SCEV *, 5> A; + A.push_back(SE.getSCEV(&*AI++)); + A.push_back(SE.getSCEV(&*AI++)); + A.push_back(SE.getSCEV(&*AI++)); + A.push_back(SE.getSCEV(&*AI++)); + A.push_back(SE.getSCEV(&*AI++)); + const SCEV *A_rec = SE.getAddRecExpr(A, &L, SCEV::FlagAnyWrap); + + SmallVector<const SCEV *, 5> B; + B.push_back(SE.getSCEV(&*AI++)); + B.push_back(SE.getSCEV(&*AI++)); + B.push_back(SE.getSCEV(&*AI++)); + B.push_back(SE.getSCEV(&*AI++)); + B.push_back(SE.getSCEV(&*AI++)); + const SCEV *B_rec = SE.getAddRecExpr(B, &L, SCEV::FlagAnyWrap); + + /* Spot check that we perform this transformation: + {A0,+,A1,+,A2,+,A3,+,A4} * {B0,+,B1,+,B2,+,B3,+,B4} = + {A0*B0,+, + A1*B0 + A0*B1 + A1*B1,+, + A2*B0 + 2A1*B1 + A0*B2 + 2A2*B1 + 2A1*B2 + A2*B2,+, + A3*B0 + 3A2*B1 + 3A1*B2 + A0*B3 + 3A3*B1 + 6A2*B2 + 3A1*B3 + 3A3*B2 + + 3A2*B3 + A3*B3,+, + A4*B0 + 4A3*B1 + 6A2*B2 + 4A1*B3 + A0*B4 + 4A4*B1 + 12A3*B2 + 12A2*B3 + + 4A1*B4 + 6A4*B2 + 12A3*B3 + 6A2*B4 + 4A4*B3 + 4A3*B4 + A4*B4,+, + 5A4*B1 + 10A3*B2 + 10A2*B3 + 5A1*B4 + 20A4*B2 + 30A3*B3 + 20A2*B4 + + 30A4*B3 + 30A3*B4 + 20A4*B4,+, + 15A4*B2 + 20A3*B3 + 15A2*B4 + 60A4*B3 + 60A3*B4 + 90A4*B4,+, + 35A4*B3 + 35A3*B4 + 140A4*B4,+, + 70A4*B4} + */ + + const SCEVAddRecExpr *Product = + dyn_cast<SCEVAddRecExpr>(SE.getMulExpr(A_rec, B_rec)); + ASSERT_TRUE(Product); + ASSERT_EQ(Product->getNumOperands(), 9u); + + SmallVector<const SCEV *, 16> Sum; + Sum.push_back(SE.getMulExpr(A[0], B[0])); + EXPECT_EQ(Product->getOperand(0), SE.getAddExpr(Sum)); + Sum.clear(); + + // SCEV produces an equal but different expression for these. + // Re-enable when PR11052 is fixed. 
+#if 0 + Sum.push_back(SE.getMulExpr(A[1], B[0])); + Sum.push_back(SE.getMulExpr(A[0], B[1])); + Sum.push_back(SE.getMulExpr(A[1], B[1])); + EXPECT_EQ(Product->getOperand(1), SE.getAddExpr(Sum)); + Sum.clear(); + + Sum.push_back(SE.getMulExpr(A[2], B[0])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[1], B[1])); + Sum.push_back(SE.getMulExpr(A[0], B[2])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[2], B[1])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[1], B[2])); + Sum.push_back(SE.getMulExpr(A[2], B[2])); + EXPECT_EQ(Product->getOperand(2), SE.getAddExpr(Sum)); + Sum.clear(); + + Sum.push_back(SE.getMulExpr(A[3], B[0])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[2], B[1])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[1], B[2])); + Sum.push_back(SE.getMulExpr(A[0], B[3])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[3], B[1])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[2])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[1], B[3])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[3], B[2])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[2], B[3])); + Sum.push_back(SE.getMulExpr(A[3], B[3])); + EXPECT_EQ(Product->getOperand(3), SE.getAddExpr(Sum)); + Sum.clear(); + + Sum.push_back(SE.getMulExpr(A[4], B[0])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[3], B[1])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[2])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[1], B[3])); + Sum.push_back(SE.getMulExpr(A[0], B[4])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[4], B[1])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[3], B[2])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[2], B[3])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[1], B[4])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[4], B[2])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[3], B[3])); + 
Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[4])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[4], B[3])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[3], B[4])); + Sum.push_back(SE.getMulExpr(A[4], B[4])); + EXPECT_EQ(Product->getOperand(4), SE.getAddExpr(Sum)); + Sum.clear(); + + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 5), A[4], B[1])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 10), A[3], B[2])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 10), A[2], B[3])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 5), A[1], B[4])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[4], B[2])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[3], B[3])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[2], B[4])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[4], B[3])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[3], B[4])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[4], B[4])); + EXPECT_EQ(Product->getOperand(5), SE.getAddExpr(Sum)); + Sum.clear(); + + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 15), A[4], B[2])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[3], B[3])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 15), A[2], B[4])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 60), A[4], B[3])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 60), A[3], B[4])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 90), A[4], B[4])); + EXPECT_EQ(Product->getOperand(6), SE.getAddExpr(Sum)); + Sum.clear(); + + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 35), A[4], B[3])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 35), A[3], B[4])); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 140), A[4], B[4])); + EXPECT_EQ(Product->getOperand(7), SE.getAddExpr(Sum)); + Sum.clear(); +#endif - // Manually clean up, since we allocated new SCEV objects after the - // pass was finished. 
- SE.releaseMemory(); + Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 70), A[4], B[4])); + EXPECT_EQ(Product->getOperand(8), SE.getAddExpr(Sum)); } } // end anonymous namespace |