diff options
Diffstat (limited to 'lib/Transforms/Scalar/StraightLineStrengthReduce.cpp')
-rw-r--r-- | lib/Transforms/Scalar/StraightLineStrengthReduce.cpp | 395 |
1 files changed, 328 insertions, 67 deletions
diff --git a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index 4edc86c..e71031c 100644 --- a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -15,19 +15,30 @@ // // There are many optimizations we can perform in the domain of SLSR. This file // for now contains only an initial step. Specifically, we look for strength -// reduction candidate in the form of +// reduction candidates in two forms: // -// (B + i) * S +// Form 1: (B + i) * S +// Form 2: &B[i * S] // -// where B and S are integer constants or variables, and i is a constant -// integer. If we found two such candidates +// where S is an integer variable, and i is a constant integer. If we found two +// candidates // -// S1: X = (B + i) * S S2: Y = (B + i') * S +// S1: X = (B + i) * S +// S2: Y = (B + i') * S +// +// or +// +// S1: X = &B[i * S] +// S2: Y = &B[i' * S] // // and S1 dominates S2, we call S1 a basis of S2, and can replace S2 with // // Y = X + (i' - i) * S // +// or +// +// Y = &X[(i' - i) * S] +// // where (i' - i) * S is folded to the extent possible. When S2 has multiple // bases, we pick the one that is closest to S2, or S2's "immediate" basis. // @@ -35,8 +46,6 @@ // // - Handle candidates in the form of B + i * S // -// - Handle candidates in the form of pointer arithmetics. e.g., B[i * S] -// // - Floating point arithmetics when fast math is enabled. // // - SLSR may decrease ILP at the architecture level. Targets that are very @@ -45,6 +54,10 @@ #include <vector> #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/FoldingSet.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" @@ -58,14 +71,30 @@ using namespace PatternMatch; namespace { class StraightLineStrengthReduce : public FunctionPass { - public: +public: // SLSR candidate. Such a candidate must be in the form of // (Base + Index) * Stride + // or + // Base[..][Index * Stride][..] struct Candidate : public ilist_node<Candidate> { - Candidate(Value *B = nullptr, ConstantInt *Idx = nullptr, - Value *S = nullptr, Instruction *I = nullptr) - : Base(B), Index(Idx), Stride(S), Ins(I), Basis(nullptr) {} - Value *Base; + enum Kind { + Invalid, // reserved for the default constructor + Mul, // (B + i) * S + GEP, // &B[..][i * S][..] + }; + + Candidate() + : CandidateKind(Invalid), Base(nullptr), Index(nullptr), + Stride(nullptr), Ins(nullptr), Basis(nullptr) {} + Candidate(Kind CT, const SCEV *B, ConstantInt *Idx, Value *S, + Instruction *I) + : CandidateKind(CT), Base(B), Index(Idx), Stride(S), Ins(I), + Basis(nullptr) {} + Kind CandidateKind; + const SCEV *Base; + // Note that Index and Stride of a GEP candidate may not have the same + // integer type. In that case, during rewriting, Stride will be + // sign-extended or truncated to Index's type. ConstantInt *Index; Value *Stride; // The instruction this candidate corresponds to. It helps us to rewrite a @@ -90,33 +119,70 @@ class StraightLineStrengthReduce : public FunctionPass { static char ID; - StraightLineStrengthReduce() : FunctionPass(ID), DT(nullptr) { + StraightLineStrengthReduce() + : FunctionPass(ID), DL(nullptr), DT(nullptr), TTI(nullptr) { initializeStraightLineStrengthReducePass(*PassRegistry::getPassRegistry()); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<ScalarEvolution>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); // We do not modify the shape of the CFG. AU.setPreservesCFG(); } + bool doInitialization(Module &M) override { + DL = &M.getDataLayout(); + return false; + } + bool runOnFunction(Function &F) override; - private: +private: // Returns true if Basis is a basis for C, i.e., Basis dominates C and they // share the same base and stride. bool isBasisFor(const Candidate &Basis, const Candidate &C); // Checks whether I is in a candidate form. If so, adds all the matching forms // to Candidates, and tries to find the immediate basis for each of them. void allocateCandidateAndFindBasis(Instruction *I); - // Given that I is in the form of "(B + Idx) * S", adds this form to - // Candidates, and finds its immediate basis. - void allocateCandidateAndFindBasis(Value *B, ConstantInt *Idx, Value *S, + // Allocate candidates and find bases for Mul instructions. + void allocateCandidateAndFindBasisForMul(Instruction *I); + // Splits LHS into Base + Index and, if succeeds, calls + // allocateCandidateAndFindBasis. + void allocateCandidateAndFindBasisForMul(Value *LHS, Value *RHS, + Instruction *I); + // Allocate candidates and find bases for GetElementPtr instructions. + void allocateCandidateAndFindBasisForGEP(GetElementPtrInst *GEP); + // A helper function that scales Idx with ElementSize before invoking + // allocateCandidateAndFindBasis. + void allocateCandidateAndFindBasisForGEP(const SCEV *B, ConstantInt *Idx, + Value *S, uint64_t ElementSize, + Instruction *I); + // Adds the given form <CT, B, Idx, S> to Candidates, and finds its immediate + // basis. + void allocateCandidateAndFindBasis(Candidate::Kind CT, const SCEV *B, + ConstantInt *Idx, Value *S, Instruction *I); // Rewrites candidate C with respect to Basis. void rewriteCandidateWithBasis(const Candidate &C, const Candidate &Basis); + // A helper function that factors ArrayIdx to a product of a stride and a + // constant index, and invokes allocateCandidateAndFindBasis with the + // factorings. + void factorArrayIndex(Value *ArrayIdx, const SCEV *Base, uint64_t ElementSize, + GetElementPtrInst *GEP); + // Emit code that computes the "bump" from Basis to C. If the candidate is a + // GEP and the bump is not divisible by the element size of the GEP, this + // function sets the BumpWithUglyGEP flag to notify its caller to bump the + // basis using an ugly GEP. + static Value *emitBump(const Candidate &Basis, const Candidate &C, + IRBuilder<> &Builder, const DataLayout *DL, + bool &BumpWithUglyGEP); + const DataLayout *DL; DominatorTree *DT; + ScalarEvolution *SE; + TargetTransformInfo *TTI; ilist<Candidate> Candidates; // Temporarily holds all instructions that are unlinked (but not deleted) by // rewriteCandidateWithBasis. These instructions will be actually removed @@ -129,6 +195,8 @@ char StraightLineStrengthReduce::ID = 0; INITIALIZE_PASS_BEGIN(StraightLineStrengthReduce, "slsr", "Straight line strength reduction", false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolution) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(StraightLineStrengthReduce, "slsr", "Straight line strength reduction", false, false) @@ -141,9 +209,47 @@ bool StraightLineStrengthReduce::isBasisFor(const Candidate &Basis, return (Basis.Ins != C.Ins && // skip the same instruction // Basis must dominate C in order to rewrite C with respect to Basis. DT->dominates(Basis.Ins->getParent(), C.Ins->getParent()) && - // They share the same base and stride. + // They share the same base, stride, and candidate kind. Basis.Base == C.Base && - Basis.Stride == C.Stride); + Basis.Stride == C.Stride && + Basis.CandidateKind == C.CandidateKind); +} + +static bool isCompletelyFoldable(GetElementPtrInst *GEP, + const TargetTransformInfo *TTI, + const DataLayout *DL) { + GlobalVariable *BaseGV = nullptr; + int64_t BaseOffset = 0; + bool HasBaseReg = false; + int64_t Scale = 0; + + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getPointerOperand())) + BaseGV = GV; + else + HasBaseReg = true; + + gep_type_iterator GTI = gep_type_begin(GEP); + for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I, ++GTI) { + if (isa<SequentialType>(*GTI)) { + int64_t ElementSize = DL->getTypeAllocSize(GTI.getIndexedType()); + if (ConstantInt *ConstIdx = dyn_cast<ConstantInt>(*I)) { + BaseOffset += ConstIdx->getSExtValue() * ElementSize; + } else { + // Needs scale register. + if (Scale != 0) { + // No addressing mode takes two scale registers. + return false; + } + Scale = ElementSize; + } + } else { + StructType *STy = cast<StructType>(*GTI); + uint64_t Field = cast<ConstantInt>(*I)->getZExtValue(); + BaseOffset += DL->getStructLayout(STy)->getElementOffset(Field); + } + } + return TTI->isLegalAddressingMode(GEP->getType()->getElementType(), BaseGV, + BaseOffset, HasBaseReg, Scale); } // TODO: We currently implement an algorithm whose time complexity is linear to @@ -153,11 +259,17 @@ bool StraightLineStrengthReduce::isBasisFor(const Candidate &Basis, // table is indexed by the base and the stride of a candidate. Therefore, // finding the immediate basis of a candidate boils down to one hash-table look // up. -void StraightLineStrengthReduce::allocateCandidateAndFindBasis(Value *B, - ConstantInt *Idx, - Value *S, - Instruction *I) { - Candidate C(B, Idx, S, I); +void StraightLineStrengthReduce::allocateCandidateAndFindBasis( + Candidate::Kind CT, const SCEV *B, ConstantInt *Idx, Value *S, + Instruction *I) { + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) { + // If &B[Idx * S] fits into an addressing mode, do not turn it into + // non-free computation. + if (isCompletelyFoldable(GEP, TTI, DL)) + return; + } + + Candidate C(CT, B, Idx, S, I); // Try to compute the immediate basis of C. unsigned NumIterations = 0; // Limit the scan radius to avoid running forever. @@ -176,60 +288,209 @@ void StraightLineStrengthReduce::allocateCandidateAndFindBasis(Value *B, } void StraightLineStrengthReduce::allocateCandidateAndFindBasis(Instruction *I) { + switch (I->getOpcode()) { + case Instruction::Mul: + allocateCandidateAndFindBasisForMul(I); + break; + case Instruction::GetElementPtr: + allocateCandidateAndFindBasisForGEP(cast<GetElementPtrInst>(I)); + break; + } +} + +void StraightLineStrengthReduce::allocateCandidateAndFindBasisForMul( + Value *LHS, Value *RHS, Instruction *I) { Value *B = nullptr; ConstantInt *Idx = nullptr; - // "(Base + Index) * Stride" must be a Mul instruction at the first hand. - if (I->getOpcode() == Instruction::Mul) { - if (IntegerType *ITy = dyn_cast<IntegerType>(I->getType())) { - Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); - for (unsigned Swapped = 0; Swapped < 2; ++Swapped) { - // Only handle the canonical operand ordering. - if (match(LHS, m_Add(m_Value(B), m_ConstantInt(Idx)))) { - // If LHS is in the form of "Base + Index", then I is in the form of - // "(Base + Index) * RHS". - allocateCandidateAndFindBasis(B, Idx, RHS, I); - } else { - // Otherwise, at least try the form (LHS + 0) * RHS. - allocateCandidateAndFindBasis(LHS, ConstantInt::get(ITy, 0), RHS, I); - } - // Swap LHS and RHS so that we also cover the cases where LHS is the - // stride. - if (LHS == RHS) - break; - std::swap(LHS, RHS); - } - } + // Only handle the canonical operand ordering. + if (match(LHS, m_Add(m_Value(B), m_ConstantInt(Idx)))) { + // If LHS is in the form of "Base + Index", then I is in the form of + // "(Base + Index) * RHS". + allocateCandidateAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I); + } else { + // Otherwise, at least try the form (LHS + 0) * RHS. + ConstantInt *Zero = ConstantInt::get(cast<IntegerType>(I->getType()), 0); + allocateCandidateAndFindBasis(Candidate::Mul, SE->getSCEV(LHS), Zero, RHS, + I); + } +} + +void StraightLineStrengthReduce::allocateCandidateAndFindBasisForMul( + Instruction *I) { + // Try matching (B + i) * S. + // TODO: we could extend SLSR to float and vector types. + if (!isa<IntegerType>(I->getType())) + return; + + Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); + allocateCandidateAndFindBasisForMul(LHS, RHS, I); + if (LHS != RHS) { + // Symmetrically, try to split RHS to Base + Index. + allocateCandidateAndFindBasisForMul(RHS, LHS, I); + } +} + +void StraightLineStrengthReduce::allocateCandidateAndFindBasisForGEP( + const SCEV *B, ConstantInt *Idx, Value *S, uint64_t ElementSize, + Instruction *I) { + // I = B + sext(Idx *nsw S) *nsw ElementSize + // = B + (sext(Idx) * ElementSize) * sext(S) + // Casting to IntegerType is safe because we skipped vector GEPs. + IntegerType *IntPtrTy = cast<IntegerType>(DL->getIntPtrType(I->getType())); + ConstantInt *ScaledIdx = ConstantInt::get( + IntPtrTy, Idx->getSExtValue() * (int64_t)ElementSize, true); + allocateCandidateAndFindBasis(Candidate::GEP, B, ScaledIdx, S, I); +} + +void StraightLineStrengthReduce::factorArrayIndex(Value *ArrayIdx, + const SCEV *Base, + uint64_t ElementSize, + GetElementPtrInst *GEP) { + // At least, ArrayIdx = ArrayIdx *s 1. + allocateCandidateAndFindBasisForGEP( + Base, ConstantInt::get(cast<IntegerType>(ArrayIdx->getType()), 1), + ArrayIdx, ElementSize, GEP); + Value *LHS = nullptr; + ConstantInt *RHS = nullptr; + // TODO: handle shl. e.g., we could treat (S << 2) as (S * 4). + // + // One alternative is matching the SCEV of ArrayIdx instead of ArrayIdx + // itself. This would allow us to handle the shl case for free. However, + // matching SCEVs has two issues: + // + // 1. this would complicate rewriting because the rewriting procedure + // would have to translate SCEVs back to IR instructions. This translation + // is difficult when LHS is further evaluated to a composite SCEV. + // + // 2. ScalarEvolution is designed to be control-flow oblivious. It tends + // to strip nsw/nuw flags which are critical for SLSR to trace into + // sext'ed multiplication. + if (match(ArrayIdx, m_NSWMul(m_Value(LHS), m_ConstantInt(RHS)))) { + // SLSR is currently unsafe if i * S may overflow. + // GEP = Base + sext(LHS *nsw RHS) *nsw ElementSize + allocateCandidateAndFindBasisForGEP(Base, RHS, LHS, ElementSize, GEP); + } +} + +void StraightLineStrengthReduce::allocateCandidateAndFindBasisForGEP( + GetElementPtrInst *GEP) { + // TODO: handle vector GEPs + if (GEP->getType()->isVectorTy()) + return; + + const SCEV *GEPExpr = SE->getSCEV(GEP); + Type *IntPtrTy = DL->getIntPtrType(GEP->getType()); + + gep_type_iterator GTI = gep_type_begin(GEP); + for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I) { + if (!isa<SequentialType>(*GTI++)) + continue; + Value *ArrayIdx = *I; + // Compute the byte offset of this index. + uint64_t ElementSize = DL->getTypeAllocSize(*GTI); + const SCEV *ElementSizeExpr = SE->getSizeOfExpr(IntPtrTy, *GTI); + const SCEV *ArrayIdxExpr = SE->getSCEV(ArrayIdx); + ArrayIdxExpr = SE->getTruncateOrSignExtend(ArrayIdxExpr, IntPtrTy); + const SCEV *LocalOffset = + SE->getMulExpr(ArrayIdxExpr, ElementSizeExpr, SCEV::FlagNSW); + // The base of this candidate equals GEPExpr less the byte offset of this + // index. + const SCEV *Base = SE->getMinusSCEV(GEPExpr, LocalOffset); + factorArrayIndex(ArrayIdx, Base, ElementSize, GEP); + // When ArrayIdx is the sext of a value, we try to factor that value as + // well. Handling this case is important because array indices are + // typically sign-extended to the pointer size. + Value *TruncatedArrayIdx = nullptr; + if (match(ArrayIdx, m_SExt(m_Value(TruncatedArrayIdx)))) + factorArrayIndex(TruncatedArrayIdx, Base, ElementSize, GEP); } } +// A helper function that unifies the bitwidth of A and B. +static void unifyBitWidth(APInt &A, APInt &B) { + if (A.getBitWidth() < B.getBitWidth()) + A = A.sext(B.getBitWidth()); + else if (A.getBitWidth() > B.getBitWidth()) + B = B.sext(A.getBitWidth()); +} + +Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis, + const Candidate &C, + IRBuilder<> &Builder, + const DataLayout *DL, + bool &BumpWithUglyGEP) { + APInt Idx = C.Index->getValue(), BasisIdx = Basis.Index->getValue(); + unifyBitWidth(Idx, BasisIdx); + APInt IndexOffset = Idx - BasisIdx; + + BumpWithUglyGEP = false; + if (Basis.CandidateKind == Candidate::GEP) { + APInt ElementSize( + IndexOffset.getBitWidth(), + DL->getTypeAllocSize( + cast<GetElementPtrInst>(Basis.Ins)->getType()->getElementType())); + APInt Q, R; + APInt::sdivrem(IndexOffset, ElementSize, Q, R); + if (R.getSExtValue() == 0) + IndexOffset = Q; + else + BumpWithUglyGEP = true; + } + // Compute Bump = C - Basis = (i' - i) * S. + // Common case 1: if (i' - i) is 1, Bump = S. + if (IndexOffset.getSExtValue() == 1) + return C.Stride; + // Common case 2: if (i' - i) is -1, Bump = -S. + if (IndexOffset.getSExtValue() == -1) + return Builder.CreateNeg(C.Stride); + // Otherwise, Bump = (i' - i) * sext/trunc(S). + ConstantInt *Delta = ConstantInt::get(Basis.Ins->getContext(), IndexOffset); + Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, Delta->getType()); + return Builder.CreateMul(ExtendedStride, Delta); +} + void StraightLineStrengthReduce::rewriteCandidateWithBasis( const Candidate &C, const Candidate &Basis) { + assert(C.CandidateKind == Basis.CandidateKind && C.Base == Basis.Base && + C.Stride == Basis.Stride); + // An instruction can correspond to multiple candidates. Therefore, instead of // simply deleting an instruction when we rewrite it, we mark its parent as // nullptr (i.e. unlink it) so that we can skip the candidates whose // instruction is already rewritten. if (!C.Ins->getParent()) return; - assert(C.Base == Basis.Base && C.Stride == Basis.Stride); - // Basis = (B + i) * S - // C = (B + i') * S - // ==> - // C = Basis + (i' - i) * S + IRBuilder<> Builder(C.Ins); - ConstantInt *IndexOffset = ConstantInt::get( - C.Ins->getContext(), C.Index->getValue() - Basis.Index->getValue()); - Value *Reduced; - // TODO: preserve nsw/nuw in some cases. - if (IndexOffset->isOne()) { - // If (i' - i) is 1, fold C into Basis + S. - Reduced = Builder.CreateAdd(Basis.Ins, C.Stride); - } else if (IndexOffset->isMinusOne()) { - // If (i' - i) is -1, fold C into Basis - S. - Reduced = Builder.CreateSub(Basis.Ins, C.Stride); - } else { - Value *Bump = Builder.CreateMul(C.Stride, IndexOffset); + bool BumpWithUglyGEP; + Value *Bump = emitBump(Basis, C, Builder, DL, BumpWithUglyGEP); + Value *Reduced = nullptr; // equivalent to but weaker than C.Ins + switch (C.CandidateKind) { + case Candidate::Mul: Reduced = Builder.CreateAdd(Basis.Ins, Bump); - } + break; + case Candidate::GEP: + { + Type *IntPtrTy = DL->getIntPtrType(C.Ins->getType()); + if (BumpWithUglyGEP) { + // C = (char *)Basis + Bump + unsigned AS = Basis.Ins->getType()->getPointerAddressSpace(); + Type *CharTy = Type::getInt8PtrTy(Basis.Ins->getContext(), AS); + Reduced = Builder.CreateBitCast(Basis.Ins, CharTy); + // We only considered inbounds GEP as candidates. + Reduced = Builder.CreateInBoundsGEP(Reduced, Bump); + Reduced = Builder.CreateBitCast(Reduced, C.Ins->getType()); + } else { + // C = gep Basis, Bump + // Canonicalize bump to pointer size. + Bump = Builder.CreateSExtOrTrunc(Bump, IntPtrTy); + Reduced = Builder.CreateInBoundsGEP(Basis.Ins, Bump); + } + } + break; + default: + llvm_unreachable("C.CandidateKind is invalid"); + }; Reduced->takeName(C.Ins); C.Ins->replaceAllUsesWith(Reduced); C.Ins->dropAllReferences(); @@ -243,15 +504,15 @@ bool StraightLineStrengthReduce::runOnFunction(Function &F) { if (skipOptnoneFunction(F)) return false; + TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + SE = &getAnalysis<ScalarEvolution>(); // Traverse the dominator tree in the depth-first order. This order makes sure // all bases of a candidate are in Candidates when we process it. for (auto node = GraphTraits<DominatorTree *>::nodes_begin(DT); node != GraphTraits<DominatorTree *>::nodes_end(DT); ++node) { - BasicBlock *B = node->getBlock(); - for (auto I = B->begin(); I != B->end(); ++I) { - allocateCandidateAndFindBasis(I); - } + for (auto &I : *node->getBlock()) + allocateCandidateAndFindBasis(&I); } // Rewrite candidates in the reverse depth-first order. This order makes sure |