diff options
-rw-r--r-- | include/llvm/Analysis/ScalarEvolutionExpressions.h | 140 |
1 files changed, 94 insertions, 46 deletions
diff --git a/include/llvm/Analysis/ScalarEvolutionExpressions.h b/include/llvm/Analysis/ScalarEvolutionExpressions.h index eac9113..1fc03be 100644 --- a/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -549,53 +549,60 @@ namespace llvm { T.visitAll(Root); } - /// The SCEVRewriter takes a scalar evolution expression and copies all its - /// components. The result after a rewrite is an identical SCEV. - struct SCEVRewriter - : public SCEVVisitor<SCEVRewriter, const SCEV*> { + typedef DenseMap<const Value*, Value*> ValueToValueMap; + + /// The SCEVParameterRewriter takes a scalar evolution expression and updates + /// the SCEVUnknown components following the Map (Value -> Value). + struct SCEVParameterRewriter + : public SCEVVisitor<SCEVParameterRewriter, const SCEV*> { public: - SCEVRewriter(ScalarEvolution &S) : SE(S) {} + static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, + ValueToValueMap &Map) { + SCEVParameterRewriter Rewriter(SE, Map); + return Rewriter.visit(Scev); + } - virtual ~SCEVRewriter() {} + SCEVParameterRewriter(ScalarEvolution &S, ValueToValueMap &M) + : SE(S), Map(M) {} - virtual const SCEV *visitConstant(const SCEVConstant *Constant) { + const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } - virtual const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { + const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { const SCEV *Operand = visit(Expr->getOperand()); return SE.getTruncateExpr(Operand, Expr->getType()); } - virtual const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { const SCEV *Operand = visit(Expr->getOperand()); return SE.getZeroExtendExpr(Operand, Expr->getType()); } - virtual const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { const SCEV *Operand = visit(Expr->getOperand()); return SE.getSignExtendExpr(Operand, Expr->getType()); } - virtual const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { + const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { SmallVector<const SCEV *, 2> Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) Operands.push_back(visit(Expr->getOperand(i))); return SE.getAddExpr(Operands); } - virtual const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { + const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { SmallVector<const SCEV *, 2> Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) Operands.push_back(visit(Expr->getOperand(i))); return SE.getMulExpr(Operands); } - virtual const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { + const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS())); } - virtual const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { SmallVector<const SCEV *, 2> Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) Operands.push_back(visit(Expr->getOperand(i))); @@ -603,54 +610,33 @@ namespace llvm { Expr->getNoWrapFlags()); } - virtual const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { + const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { SmallVector<const SCEV *, 2> Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) Operands.push_back(visit(Expr->getOperand(i))); return SE.getSMaxExpr(Operands); } - virtual const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { + const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { SmallVector<const SCEV *, 2> Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) Operands.push_back(visit(Expr->getOperand(i))); return SE.getUMaxExpr(Operands); } - virtual const SCEV *visitUnknown(const SCEVUnknown *Expr) { - return Expr; - } - - virtual const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { - return Expr; - } - - protected: - ScalarEvolution &SE; - }; - - typedef DenseMap<const Value*, Value*> ValueToValueMap; - - /// The SCEVParameterRewriter takes a scalar evolution expression and updates - /// the SCEVUnknown components following the Map (Value -> Value). - struct SCEVParameterRewriter: public SCEVRewriter { - public: - static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, - ValueToValueMap &Map) { - SCEVParameterRewriter Rewriter(SE, Map); - return Rewriter.visit(Scev); - } - SCEVParameterRewriter(ScalarEvolution &S, ValueToValueMap &M) - : SCEVRewriter(S), Map(M) {} - - virtual const SCEV *visitUnknown(const SCEVUnknown *Expr) { + const SCEV *visitUnknown(const SCEVUnknown *Expr) { Value *V = Expr->getValue(); if (Map.count(V)) return SE.getUnknown(Map[V]); return Expr; } + const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { + return Expr; + } + private: + ScalarEvolution &SE; ValueToValueMap ⤅ }; @@ -658,17 +644,56 @@ namespace llvm { /// The SCEVApplyRewriter takes a scalar evolution expression and applies /// the Map (Loop -> SCEV) to all AddRecExprs. - struct SCEVApplyRewriter: public SCEVRewriter { + struct SCEVApplyRewriter + : public SCEVVisitor<SCEVApplyRewriter, const SCEV*> { public: static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, ScalarEvolution &SE) { SCEVApplyRewriter Rewriter(SE, Map); return Rewriter.visit(Scev); } + SCEVApplyRewriter(ScalarEvolution &S, LoopToScevMapT &M) - : SCEVRewriter(S), Map(M) {} + : SE(S), Map(M) {} + + const SCEV *visitConstant(const SCEVConstant *Constant) { + return Constant; + } + + const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + return SE.getTruncateExpr(Operand, Expr->getType()); + } + + const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + return SE.getZeroExtendExpr(Operand, Expr->getType()); + } + + const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + return SE.getSignExtendExpr(Operand, Expr->getType()); + } + + const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { + SmallVector<const SCEV *, 2> Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getAddExpr(Operands); + } - virtual const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { + SmallVector<const SCEV *, 2> Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getMulExpr(Operands); + } + + const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { + return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS())); + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { SmallVector<const SCEV *, 2> Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) Operands.push_back(visit(Expr->getOperand(i))); @@ -683,7 +708,30 @@ namespace llvm { return Rec->evaluateAtIteration(Map[L], SE); } + const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { + SmallVector<const SCEV *, 2> Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getSMaxExpr(Operands); + } + + const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { + SmallVector<const SCEV *, 2> Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getUMaxExpr(Operands); + } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + return Expr; + } + + const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { + return Expr; + } + private: + ScalarEvolution &SE; LoopToScevMapT ⤅ }; |