diff options
Diffstat (limited to 'lib/Transforms/Utils')
26 files changed, 3819 insertions, 2524 deletions
diff --git a/lib/Transforms/Utils/AddDiscriminators.cpp b/lib/Transforms/Utils/AddDiscriminators.cpp index 196ac79..f8e5af5 100644 --- a/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/lib/Transforms/Utils/AddDiscriminators.cpp @@ -193,13 +193,11 @@ bool AddDiscriminators::runOnFunction(Function &F) { // Create a new lexical scope and compute a new discriminator // number for it. StringRef Filename = FirstDIL.getFilename(); - unsigned LineNumber = FirstDIL.getLineNumber(); - unsigned ColumnNumber = FirstDIL.getColumnNumber(); DIScope Scope = FirstDIL.getScope(); DIFile File = Builder.createFile(Filename, Scope.getDirectory()); unsigned Discriminator = FirstDIL.computeNewDiscriminator(Ctx); - DILexicalBlock NewScope = Builder.createLexicalBlock( - Scope, File, LineNumber, ColumnNumber, Discriminator); + DILexicalBlockFile NewScope = + Builder.createLexicalBlockFile(Scope, File, Discriminator); DILocation NewDIL = FirstDIL.copyWithNewScope(Ctx, NewScope); DebugLoc newDebugLoc = DebugLoc::getFromDILocation(NewDIL); diff --git a/lib/Transforms/Utils/Android.mk b/lib/Transforms/Utils/Android.mk index 2390027..e20dc0a 100644 --- a/lib/Transforms/Utils/Android.mk +++ b/lib/Transforms/Utils/Android.mk @@ -13,6 +13,7 @@ transforms_utils_SRC_FILES := \ CodeExtractor.cpp \ CtorUtils.cpp \ DemoteRegToStack.cpp \ + FlattenCFG.cpp \ GlobalStatus.cpp \ InlineFunction.cpp \ InstructionNamer.cpp \ @@ -33,6 +34,7 @@ transforms_utils_SRC_FILES := \ SimplifyIndVar.cpp \ SimplifyInstructions.cpp \ SimplifyLibCalls.cpp \ + SymbolRewriter.cpp \ UnifyFunctionExitNodes.cpp \ Utils.cpp \ ValueMapper.cpp diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp index 80b7e22..983f025 100644 --- a/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -265,6 +265,18 @@ BasicBlock *llvm::SplitEdge(BasicBlock *BB, BasicBlock *Succ, Pass *P) { return SplitBlock(BB, BB->getTerminator(), P); } +unsigned 
llvm::SplitAllCriticalEdges(Function &F, Pass *P) { + unsigned NumBroken = 0; + for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) { + TerminatorInst *TI = I->getTerminator(); + if (TI->getNumSuccessors() > 1 && !isa<IndirectBrInst>(TI)) + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + if (SplitCriticalEdge(TI, i, P)) + ++NumBroken; + } + return NumBroken; +} + /// SplitBlock - Split the specified block at the specified instruction - every /// thing before SplitPt stays in Old and everything starting with SplitPt moves /// to a new block. The two blocks are joined by an unconditional branch and @@ -673,7 +685,8 @@ ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, TerminatorInst *llvm::SplitBlockAndInsertIfThen(Value *Cond, Instruction *SplitBefore, bool Unreachable, - MDNode *BranchWeights) { + MDNode *BranchWeights, + DominatorTree *DT) { BasicBlock *Head = SplitBefore->getParent(); BasicBlock *Tail = Head->splitBasicBlock(SplitBefore); TerminatorInst *HeadOldTerm = Head->getTerminator(); @@ -690,6 +703,20 @@ TerminatorInst *llvm::SplitBlockAndInsertIfThen(Value *Cond, HeadNewTerm->setDebugLoc(SplitBefore->getDebugLoc()); HeadNewTerm->setMetadata(LLVMContext::MD_prof, BranchWeights); ReplaceInstWithInst(HeadOldTerm, HeadNewTerm); + + if (DT) { + if (DomTreeNode *OldNode = DT->getNode(Head)) { + std::vector<DomTreeNode *> Children(OldNode->begin(), OldNode->end()); + + DomTreeNode *NewNode = DT->addNewBlock(Tail, Head); + for (auto Child : Children) + DT->changeImmediateDominator(Child, NewNode); + + // Head dominates ThenBlock. 
+ DT->addNewBlock(ThenBlock, Head); + } + } + return CheckTerm; } diff --git a/lib/Transforms/Utils/BreakCriticalEdges.cpp b/lib/Transforms/Utils/BreakCriticalEdges.cpp index 80bd516..eda22cf 100644 --- a/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -40,7 +40,11 @@ namespace { initializeBreakCriticalEdgesPass(*PassRegistry::getPassRegistry()); } - bool runOnFunction(Function &F) override; + bool runOnFunction(Function &F) override { + unsigned N = SplitAllCriticalEdges(F, this); + NumBroken += N; + return N > 0; + } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addPreserved<DominatorTreeWrapperPass>(); @@ -62,24 +66,6 @@ FunctionPass *llvm::createBreakCriticalEdgesPass() { return new BreakCriticalEdges(); } -// runOnFunction - Loop over all of the edges in the CFG, breaking critical -// edges as they are found. -// -bool BreakCriticalEdges::runOnFunction(Function &F) { - bool Changed = false; - for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) { - TerminatorInst *TI = I->getTerminator(); - if (TI->getNumSuccessors() > 1 && !isa<IndirectBrInst>(TI)) - for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) - if (SplitCriticalEdge(TI, i, this)) { - ++NumBroken; - Changed = true; - } - } - - return Changed; -} - //===----------------------------------------------------------------------===// // Implementation of the external critical edge manipulation functions //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/Utils/BuildLibCalls.cpp b/lib/Transforms/Utils/BuildLibCalls.cpp index be00b69..112d26c 100644 --- a/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/lib/Transforms/Utils/BuildLibCalls.cpp @@ -42,8 +42,7 @@ Value *llvm::EmitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout *TD, AttributeSet AS[2]; AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, 
Attribute::NoUnwind }; - AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Constant *StrLen = M->getOrInsertFunction("strlen", @@ -51,7 +50,7 @@ Value *llvm::EmitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout *TD, AS), TD->getIntPtrType(Context), B.getInt8PtrTy(), - NULL); + nullptr); CallInst *CI = B.CreateCall(StrLen, CastToCStr(Ptr, B), "strlen"); if (const Function *F = dyn_cast<Function>(StrLen->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -71,8 +70,7 @@ Value *llvm::EmitStrNLen(Value *Ptr, Value *MaxLen, IRBuilder<> &B, AttributeSet AS[2]; AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Constant *StrNLen = M->getOrInsertFunction("strnlen", @@ -81,7 +79,7 @@ Value *llvm::EmitStrNLen(Value *Ptr, Value *MaxLen, IRBuilder<> &B, TD->getIntPtrType(Context), B.getInt8PtrTy(), TD->getIntPtrType(Context), - NULL); + nullptr); CallInst *CI = B.CreateCall2(StrNLen, CastToCStr(Ptr, B), MaxLen, "strnlen"); if (const Function *F = dyn_cast<Function>(StrNLen->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -100,15 +98,14 @@ Value *llvm::EmitStrChr(Value *Ptr, char C, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getParent()->getParent(); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; AttributeSet AS = - AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AttributeSet::get(M->getContext(), 
AttributeSet::FunctionIndex, AVs); Type *I8Ptr = B.getInt8PtrTy(); Type *I32Ty = B.getInt32Ty(); Constant *StrChr = M->getOrInsertFunction("strchr", AttributeSet::get(M->getContext(), AS), - I8Ptr, I8Ptr, I32Ty, NULL); + I8Ptr, I8Ptr, I32Ty, nullptr); CallInst *CI = B.CreateCall2(StrChr, CastToCStr(Ptr, B), ConstantInt::get(I32Ty, C), "strchr"); if (const Function *F = dyn_cast<Function>(StrChr->stripPointerCasts())) @@ -128,8 +125,7 @@ Value *llvm::EmitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); AS[1] = AttributeSet::get(M->getContext(), 2, Attribute::NoCapture); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *StrNCmp = M->getOrInsertFunction("strncmp", @@ -138,7 +134,7 @@ Value *llvm::EmitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, B.getInt32Ty(), B.getInt8PtrTy(), B.getInt8PtrTy(), - TD->getIntPtrType(Context), NULL); + TD->getIntPtrType(Context), nullptr); CallInst *CI = B.CreateCall3(StrNCmp, CastToCStr(Ptr1, B), CastToCStr(Ptr2, B), Len, "strncmp"); @@ -164,7 +160,7 @@ Value *llvm::EmitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B, Type *I8Ptr = B.getInt8PtrTy(); Value *StrCpy = M->getOrInsertFunction(Name, AttributeSet::get(M->getContext(), AS), - I8Ptr, I8Ptr, I8Ptr, NULL); + I8Ptr, I8Ptr, I8Ptr, nullptr); CallInst *CI = B.CreateCall2(StrCpy, CastToCStr(Dst, B), CastToCStr(Src, B), Name); if (const Function *F = dyn_cast<Function>(StrCpy->stripPointerCasts())) @@ -190,7 +186,7 @@ Value *llvm::EmitStrNCpy(Value *Dst, Value *Src, Value *Len, AttributeSet::get(M->getContext(), AS), I8Ptr, I8Ptr, I8Ptr, - Len->getType(), NULL); + Len->getType(), nullptr); CallInst *CI = B.CreateCall3(StrNCpy, CastToCStr(Dst, 
B), CastToCStr(Src, B), Len, "strncpy"); if (const Function *F = dyn_cast<Function>(StrNCpy->stripPointerCasts())) @@ -218,7 +214,7 @@ Value *llvm::EmitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, B.getInt8PtrTy(), B.getInt8PtrTy(), TD->getIntPtrType(Context), - TD->getIntPtrType(Context), NULL); + TD->getIntPtrType(Context), nullptr); Dst = CastToCStr(Dst, B); Src = CastToCStr(Src, B); CallInst *CI = B.CreateCall4(MemCpy, Dst, Src, Len, ObjSize); @@ -238,8 +234,7 @@ Value *llvm::EmitMemChr(Value *Ptr, Value *Val, Module *M = B.GetInsertBlock()->getParent()->getParent(); AttributeSet AS; Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *MemChr = M->getOrInsertFunction("memchr", AttributeSet::get(M->getContext(), AS), @@ -247,7 +242,7 @@ Value *llvm::EmitMemChr(Value *Ptr, Value *Val, B.getInt8PtrTy(), B.getInt32Ty(), TD->getIntPtrType(Context), - NULL); + nullptr); CallInst *CI = B.CreateCall3(MemChr, CastToCStr(Ptr, B), Val, Len, "memchr"); if (const Function *F = dyn_cast<Function>(MemChr->stripPointerCasts())) @@ -268,8 +263,7 @@ Value *llvm::EmitMemCmp(Value *Ptr1, Value *Ptr2, AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); AS[1] = AttributeSet::get(M->getContext(), 2, Attribute::NoCapture); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *MemCmp = M->getOrInsertFunction("memcmp", @@ -277,7 +271,7 @@ Value *llvm::EmitMemCmp(Value *Ptr1, Value *Ptr2, B.getInt32Ty(), 
B.getInt8PtrTy(), B.getInt8PtrTy(), - TD->getIntPtrType(Context), NULL); + TD->getIntPtrType(Context), nullptr); CallInst *CI = B.CreateCall3(MemCmp, CastToCStr(Ptr1, B), CastToCStr(Ptr2, B), Len, "memcmp"); @@ -313,7 +307,7 @@ Value *llvm::EmitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getParent()->getParent(); Value *Callee = M->getOrInsertFunction(Name, Op->getType(), - Op->getType(), NULL); + Op->getType(), nullptr); CallInst *CI = B.CreateCall(Callee, Op, Name); CI->setAttributes(Attrs); if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) @@ -334,7 +328,7 @@ Value *llvm::EmitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, Module *M = B.GetInsertBlock()->getParent()->getParent(); Value *Callee = M->getOrInsertFunction(Name, Op1->getType(), - Op1->getType(), Op2->getType(), NULL); + Op1->getType(), Op2->getType(), nullptr); CallInst *CI = B.CreateCall2(Callee, Op1, Op2, Name); CI->setAttributes(Attrs); if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) @@ -352,7 +346,7 @@ Value *llvm::EmitPutChar(Value *Char, IRBuilder<> &B, const DataLayout *TD, Module *M = B.GetInsertBlock()->getParent()->getParent(); Value *PutChar = M->getOrInsertFunction("putchar", B.getInt32Ty(), - B.getInt32Ty(), NULL); + B.getInt32Ty(), nullptr); CallInst *CI = B.CreateCall(PutChar, B.CreateIntCast(Char, B.getInt32Ty(), @@ -382,7 +376,7 @@ Value *llvm::EmitPutS(Value *Str, IRBuilder<> &B, const DataLayout *TD, AttributeSet::get(M->getContext(), AS), B.getInt32Ty(), B.getInt8PtrTy(), - NULL); + nullptr); CallInst *CI = B.CreateCall(PutS, CastToCStr(Str, B), "puts"); if (const Function *F = dyn_cast<Function>(PutS->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -407,12 +401,12 @@ Value *llvm::EmitFPutC(Value *Char, Value *File, IRBuilder<> &B, AttributeSet::get(M->getContext(), AS), B.getInt32Ty(), B.getInt32Ty(), File->getType(), - NULL); + nullptr); else F = 
M->getOrInsertFunction("fputc", B.getInt32Ty(), B.getInt32Ty(), - File->getType(), NULL); + File->getType(), nullptr); Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/true, "chari"); CallInst *CI = B.CreateCall2(F, Char, File, "fputc"); @@ -442,11 +436,11 @@ Value *llvm::EmitFPutS(Value *Str, Value *File, IRBuilder<> &B, AttributeSet::get(M->getContext(), AS), B.getInt32Ty(), B.getInt8PtrTy(), - File->getType(), NULL); + File->getType(), nullptr); else F = M->getOrInsertFunction(FPutsName, B.getInt32Ty(), B.getInt8PtrTy(), - File->getType(), NULL); + File->getType(), nullptr); CallInst *CI = B.CreateCall2(F, CastToCStr(Str, B), File, "fputs"); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) @@ -478,13 +472,13 @@ Value *llvm::EmitFWrite(Value *Ptr, Value *Size, Value *File, B.getInt8PtrTy(), TD->getIntPtrType(Context), TD->getIntPtrType(Context), - File->getType(), NULL); + File->getType(), nullptr); else F = M->getOrInsertFunction(FWriteName, TD->getIntPtrType(Context), B.getInt8PtrTy(), TD->getIntPtrType(Context), TD->getIntPtrType(Context), - File->getType(), NULL); + File->getType(), nullptr); CallInst *CI = B.CreateCall4(F, CastToCStr(Ptr, B), Size, ConstantInt::get(TD->getIntPtrType(Context), 1), File); diff --git a/lib/Transforms/Utils/CMakeLists.txt b/lib/Transforms/Utils/CMakeLists.txt index fcf548f..6ce22b1 100644 --- a/lib/Transforms/Utils/CMakeLists.txt +++ b/lib/Transforms/Utils/CMakeLists.txt @@ -1,16 +1,17 @@ add_llvm_library(LLVMTransformUtils - AddDiscriminators.cpp ASanStackFrameLayout.cpp + AddDiscriminators.cpp BasicBlockUtils.cpp BreakCriticalEdges.cpp BuildLibCalls.cpp BypassSlowDivision.cpp - CtorUtils.cpp CloneFunction.cpp CloneModule.cpp CmpInstAnalysis.cpp CodeExtractor.cpp + CtorUtils.cpp DemoteRegToStack.cpp + FlattenCFG.cpp GlobalStatus.cpp InlineFunction.cpp InstructionNamer.cpp @@ -29,10 +30,10 @@ add_llvm_library(LLVMTransformUtils PromoteMemoryToRegister.cpp SSAUpdater.cpp SimplifyCFG.cpp - 
FlattenCFG.cpp SimplifyIndVar.cpp SimplifyInstructions.cpp SimplifyLibCalls.cpp + SymbolRewriter.cpp UnifyFunctionExitNodes.cpp Utils.cpp ValueMapper.cpp diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp index 3f75b3e..d078c96 100644 --- a/lib/Transforms/Utils/CloneModule.cpp +++ b/lib/Transforms/Utils/CloneModule.cpp @@ -17,6 +17,7 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/ValueMapper.h" +#include "llvm-c/Core.h" using namespace llvm; /// CloneModule - Return an exact copy of the specified module. This is not as @@ -122,3 +123,11 @@ Module *llvm::CloneModule(const Module *M, ValueToValueMapTy &VMap) { return New; } + +extern "C" { + +LLVMModuleRef LLVMCloneModule(LLVMModuleRef M) { + return wrap(CloneModule(unwrap(M))); +} + +} diff --git a/lib/Transforms/Utils/CtorUtils.cpp b/lib/Transforms/Utils/CtorUtils.cpp index a359424..26875e8 100644 --- a/lib/Transforms/Utils/CtorUtils.cpp +++ b/lib/Transforms/Utils/CtorUtils.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/BitVector.h" #include "llvm/Transforms/Utils/CtorUtils.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -24,41 +25,22 @@ namespace llvm { namespace { -/// Given a specified llvm.global_ctors list, install the -/// specified array. -void installGlobalCtors(GlobalVariable *GCL, - const std::vector<Function *> &Ctors) { - // If we made a change, reassemble the initializer list. - Constant *CSVals[3]; - - StructType *StructTy = - cast<StructType>(GCL->getType()->getElementType()->getArrayElementType()); - - // Create the new init list. 
- std::vector<Constant *> CAList; - for (Function *F : Ctors) { - Type *Int32Ty = Type::getInt32Ty(GCL->getContext()); - if (F) { - CSVals[0] = ConstantInt::get(Int32Ty, 65535); - CSVals[1] = F; - } else { - CSVals[0] = ConstantInt::get(Int32Ty, 0x7fffffff); - CSVals[1] = Constant::getNullValue(StructTy->getElementType(1)); - } - // FIXME: Only allow the 3-field form in LLVM 4.0. - size_t NumElts = StructTy->getNumElements(); - if (NumElts > 2) - CSVals[2] = Constant::getNullValue(StructTy->getElementType(2)); - CAList.push_back( - ConstantStruct::get(StructTy, makeArrayRef(CSVals, NumElts))); - } - - // Create the array initializer. - Constant *CA = - ConstantArray::get(ArrayType::get(StructTy, CAList.size()), CAList); +/// Given a specified llvm.global_ctors list, remove the listed elements. +void removeGlobalCtors(GlobalVariable *GCL, const BitVector &CtorsToRemove) { + // Filter out the initializer elements to remove. + ConstantArray *OldCA = cast<ConstantArray>(GCL->getInitializer()); + SmallVector<Constant *, 10> CAList; + for (unsigned I = 0, E = OldCA->getNumOperands(); I < E; ++I) + if (!CtorsToRemove.test(I)) + CAList.push_back(OldCA->getOperand(I)); + + // Create the new array initializer. + ArrayType *ATy = + ArrayType::get(OldCA->getType()->getElementType(), CAList.size()); + Constant *CA = ConstantArray::get(ATy, CAList); // If we didn't change the number of elements, don't create a new GV. - if (CA->getType() == GCL->getInitializer()->getType()) { + if (CA->getType() == OldCA->getType()) { GCL->setInitializer(CA); return; } @@ -82,7 +64,7 @@ void installGlobalCtors(GlobalVariable *GCL, /// Given a llvm.global_ctors list that we can understand, /// return a list of the functions and null terminator as a vector. 
-std::vector<Function*> parseGlobalCtors(GlobalVariable *GV) { +std::vector<Function *> parseGlobalCtors(GlobalVariable *GV) { if (GV->getInitializer()->isNullValue()) return std::vector<Function *>(); ConstantArray *CA = cast<ConstantArray>(GV->getInitializer()); @@ -147,17 +129,15 @@ bool optimizeGlobalCtorsList(Module &M, bool MadeChange = false; // Loop over global ctors, optimizing them when we can. - for (unsigned i = 0; i != Ctors.size(); ++i) { + unsigned NumCtors = Ctors.size(); + BitVector CtorsToRemove(NumCtors); + for (unsigned i = 0; i != Ctors.size() && NumCtors > 0; ++i) { Function *F = Ctors[i]; // Found a null terminator in the middle of the list, prune off the rest of // the list. - if (!F) { - if (i != Ctors.size() - 1) { - Ctors.resize(i + 1); - MadeChange = true; - } - break; - } + if (!F) + continue; + DEBUG(dbgs() << "Optimizing Global Constructor: " << *F << "\n"); // We cannot simplify external ctor functions. @@ -166,9 +146,10 @@ bool optimizeGlobalCtorsList(Module &M, // If we can evaluate the ctor at compile time, do. if (ShouldRemove(F)) { - Ctors.erase(Ctors.begin() + i); + Ctors[i] = nullptr; + CtorsToRemove.set(i); + NumCtors--; MadeChange = true; - --i; continue; } } @@ -176,7 +157,7 @@ bool optimizeGlobalCtorsList(Module &M, if (!MadeChange) return false; - installGlobalCtors(GlobalCtors, Ctors); + removeGlobalCtors(GlobalCtors, CtorsToRemove); return true; } diff --git a/lib/Transforms/Utils/FlattenCFG.cpp b/lib/Transforms/Utils/FlattenCFG.cpp index 51ead40..4eb3e3d 100644 --- a/lib/Transforms/Utils/FlattenCFG.cpp +++ b/lib/Transforms/Utils/FlattenCFG.cpp @@ -238,9 +238,13 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder, // Do branch inversion. 
BasicBlock *CurrBlock = LastCondBlock; bool EverChanged = false; - while (1) { + for (;CurrBlock != FirstCondBlock; + CurrBlock = CurrBlock->getSinglePredecessor()) { BranchInst *BI = dyn_cast<BranchInst>(CurrBlock->getTerminator()); CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition()); + if (!CI) + continue; + CmpInst::Predicate Predicate = CI->getPredicate(); // Canonicalize icmp_ne -> icmp_eq, fcmp_one -> fcmp_oeq if ((Predicate == CmpInst::ICMP_NE) || (Predicate == CmpInst::FCMP_ONE)) { @@ -248,9 +252,6 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder, BI->swapSuccessors(); EverChanged = true; } - if (CurrBlock == FirstCondBlock) - break; - CurrBlock = CurrBlock->getSinglePredecessor(); } return EverChanged; } diff --git a/lib/Transforms/Utils/GlobalStatus.cpp b/lib/Transforms/Utils/GlobalStatus.cpp index 12057e4..52e2d59 100644 --- a/lib/Transforms/Utils/GlobalStatus.cpp +++ b/lib/Transforms/Utils/GlobalStatus.cpp @@ -35,6 +35,9 @@ bool llvm::isSafeToDestroyConstant(const Constant *C) { if (isa<GlobalValue>(C)) return false; + if (isa<ConstantInt>(C) || isa<ConstantFP>(C)) + return false; + for (const User *U : C->users()) if (const Constant *CU = dyn_cast<Constant>(U)) { if (!isSafeToDestroyConstant(CU)) @@ -45,7 +48,7 @@ bool llvm::isSafeToDestroyConstant(const Constant *C) { } static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, - SmallPtrSet<const PHINode *, 16> &PhiUsers) { + SmallPtrSetImpl<const PHINode *> &PhiUsers) { for (const Use &U : V->uses()) { const User *UR = U.getUser(); if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(UR)) { @@ -130,7 +133,7 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, } else if (const PHINode *PN = dyn_cast<PHINode>(I)) { // PHI nodes we can check just like select or GEP instructions, but we // have to be careful about infinite recursion. - if (PhiUsers.insert(PN)) // Not already visited. + if (PhiUsers.insert(PN).second) // Not already visited. 
if (analyzeGlobalAux(I, GS, PhiUsers)) return true; } else if (isa<CmpInst>(I)) { diff --git a/lib/Transforms/Utils/InlineFunction.cpp b/lib/Transforms/Utils/InlineFunction.cpp index f0a9f2b..2d0b7dc 100644 --- a/lib/Transforms/Utils/InlineFunction.cpp +++ b/lib/Transforms/Utils/InlineFunction.cpp @@ -13,10 +13,16 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/CFG.h" @@ -24,14 +30,28 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/CommandLine.h" +#include <algorithm> using namespace llvm; +static cl::opt<bool> +EnableNoAliasConversion("enable-noalias-to-md-conversion", cl::init(true), + cl::Hidden, + cl::desc("Convert noalias attributes to metadata during inlining.")); + +static cl::opt<bool> +PreserveAlignmentAssumptions("preserve-alignment-assumptions-during-inlining", + cl::init(true), cl::Hidden, + cl::desc("Convert align attributes to assumptions during inlining.")); + bool llvm::InlineFunction(CallInst *CI, InlineFunctionInfo &IFI, bool InsertLifetime) { return InlineFunction(CallSite(CI), IFI, InsertLifetime); @@ -84,7 +104,7 @@ namespace { /// split the landing pad block after the 
landingpad instruction and jump /// to there. void forwardResume(ResumeInst *RI, - SmallPtrSet<LandingPadInst*, 16> &InlinedLPads); + SmallPtrSetImpl<LandingPadInst*> &InlinedLPads); /// addIncomingPHIValuesFor - Add incoming-PHI values to the unwind /// destination block for the given basic block, using the values for the @@ -143,7 +163,7 @@ BasicBlock *InvokeInliningInfo::getInnerResumeDest() { /// branch. When there is more than one predecessor, we need to split the /// landing pad block after the landingpad instruction and jump to there. void InvokeInliningInfo::forwardResume(ResumeInst *RI, - SmallPtrSet<LandingPadInst*, 16> &InlinedLPads) { + SmallPtrSetImpl<LandingPadInst*> &InlinedLPads) { BasicBlock *Dest = getInnerResumeDest(); BasicBlock *Src = RI->getParent(); @@ -233,9 +253,7 @@ static void HandleInlinedInvoke(InvokeInst *II, BasicBlock *FirstNewBlock, // Append the clauses from the outer landing pad instruction into the inlined // landing pad instructions. LandingPadInst *OuterLPad = Invoke.getLandingPadInst(); - for (SmallPtrSet<LandingPadInst*, 16>::iterator I = InlinedLPads.begin(), - E = InlinedLPads.end(); I != E; ++I) { - LandingPadInst *InlinedLPad = *I; + for (LandingPadInst *InlinedLPad : InlinedLPads) { unsigned OuterNum = OuterLPad->getNumClauses(); InlinedLPad->reserveClauses(OuterNum); for (unsigned OuterIdx = 0; OuterIdx != OuterNum; ++OuterIdx) @@ -260,6 +278,385 @@ static void HandleInlinedInvoke(InvokeInst *II, BasicBlock *FirstNewBlock, InvokeDest->removePredecessor(II->getParent()); } +/// CloneAliasScopeMetadata - When inlining a function that contains noalias +/// scope metadata, this metadata needs to be cloned so that the inlined blocks +/// have different "unique scopes" at every call site. 
Were this not done, then +/// aliasing scopes from a function inlined into a caller multiple times could +/// not be differentiated (and this would lead to miscompiles because the +/// non-aliasing property communicated by the metadata could have +/// call-site-specific control dependencies). +static void CloneAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap) { + const Function *CalledFunc = CS.getCalledFunction(); + SetVector<const MDNode *> MD; + + // Note: We could only clone the metadata if it is already used in the + // caller. I'm omitting that check here because it might confuse + // inter-procedural alias analysis passes. We can revisit this if it becomes + // an efficiency or overhead problem. + + for (Function::const_iterator I = CalledFunc->begin(), IE = CalledFunc->end(); + I != IE; ++I) + for (BasicBlock::const_iterator J = I->begin(), JE = I->end(); J != JE; ++J) { + if (const MDNode *M = J->getMetadata(LLVMContext::MD_alias_scope)) + MD.insert(M); + if (const MDNode *M = J->getMetadata(LLVMContext::MD_noalias)) + MD.insert(M); + } + + if (MD.empty()) + return; + + // Walk the existing metadata, adding the complete (perhaps cyclic) chain to + // the set. + SmallVector<const Value *, 16> Queue(MD.begin(), MD.end()); + while (!Queue.empty()) { + const MDNode *M = cast<MDNode>(Queue.pop_back_val()); + for (unsigned i = 0, ie = M->getNumOperands(); i != ie; ++i) + if (const MDNode *M1 = dyn_cast<MDNode>(M->getOperand(i))) + if (MD.insert(M1)) + Queue.push_back(M1); + } + + // Now we have a complete set of all metadata in the chains used to specify + // the noalias scopes and the lists of those scopes. 
+ SmallVector<MDNode *, 16> DummyNodes; + DenseMap<const MDNode *, TrackingVH<MDNode> > MDMap; + for (SetVector<const MDNode *>::iterator I = MD.begin(), IE = MD.end(); + I != IE; ++I) { + MDNode *Dummy = MDNode::getTemporary(CalledFunc->getContext(), None); + DummyNodes.push_back(Dummy); + MDMap[*I] = Dummy; + } + + // Create new metadata nodes to replace the dummy nodes, replacing old + // metadata references with either a dummy node or an already-created new + // node. + for (SetVector<const MDNode *>::iterator I = MD.begin(), IE = MD.end(); + I != IE; ++I) { + SmallVector<Value *, 4> NewOps; + for (unsigned i = 0, ie = (*I)->getNumOperands(); i != ie; ++i) { + const Value *V = (*I)->getOperand(i); + if (const MDNode *M = dyn_cast<MDNode>(V)) + NewOps.push_back(MDMap[M]); + else + NewOps.push_back(const_cast<Value *>(V)); + } + + MDNode *NewM = MDNode::get(CalledFunc->getContext(), NewOps), + *TempM = MDMap[*I]; + + TempM->replaceAllUsesWith(NewM); + } + + // Now replace the metadata in the new inlined instructions with the + // replacements from the map. + for (ValueToValueMapTy::iterator VMI = VMap.begin(), VMIE = VMap.end(); + VMI != VMIE; ++VMI) { + if (!VMI->second) + continue; + + Instruction *NI = dyn_cast<Instruction>(VMI->second); + if (!NI) + continue; + + if (MDNode *M = NI->getMetadata(LLVMContext::MD_alias_scope)) { + MDNode *NewMD = MDMap[M]; + // If the call site also had alias scope metadata (a list of scopes to + // which instructions inside it might belong), propagate those scopes to + // the inlined instructions. 
+ if (MDNode *CSM = + CS.getInstruction()->getMetadata(LLVMContext::MD_alias_scope)) + NewMD = MDNode::concatenate(NewMD, CSM); + NI->setMetadata(LLVMContext::MD_alias_scope, NewMD); + } else if (NI->mayReadOrWriteMemory()) { + if (MDNode *M = + CS.getInstruction()->getMetadata(LLVMContext::MD_alias_scope)) + NI->setMetadata(LLVMContext::MD_alias_scope, M); + } + + if (MDNode *M = NI->getMetadata(LLVMContext::MD_noalias)) { + MDNode *NewMD = MDMap[M]; + // If the call site also had noalias metadata (a list of scopes with + // which instructions inside it don't alias), propagate those scopes to + // the inlined instructions. + if (MDNode *CSM = + CS.getInstruction()->getMetadata(LLVMContext::MD_noalias)) + NewMD = MDNode::concatenate(NewMD, CSM); + NI->setMetadata(LLVMContext::MD_noalias, NewMD); + } else if (NI->mayReadOrWriteMemory()) { + if (MDNode *M = CS.getInstruction()->getMetadata(LLVMContext::MD_noalias)) + NI->setMetadata(LLVMContext::MD_noalias, M); + } + } + + // Now that everything has been replaced, delete the dummy nodes. + for (unsigned i = 0, ie = DummyNodes.size(); i != ie; ++i) + MDNode::deleteTemporary(DummyNodes[i]); +} + +/// AddAliasScopeMetadata - If the inlined function has noalias arguments, then +/// add new alias scopes for each noalias argument, tag the mapped noalias +/// parameters with noalias metadata specifying the new scope, and tag all +/// non-derived loads, stores and memory intrinsics with the new alias scopes. 
+static void AddAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap, + const DataLayout *DL, AliasAnalysis *AA) { + if (!EnableNoAliasConversion) + return; + + const Function *CalledFunc = CS.getCalledFunction(); + SmallVector<const Argument *, 4> NoAliasArgs; + + for (Function::const_arg_iterator I = CalledFunc->arg_begin(), + E = CalledFunc->arg_end(); I != E; ++I) { + if (I->hasNoAliasAttr() && !I->hasNUses(0)) + NoAliasArgs.push_back(I); + } + + if (NoAliasArgs.empty()) + return; + + // To do a good job, if a noalias variable is captured, we need to know if + // the capture point dominates the particular use we're considering. + DominatorTree DT; + DT.recalculate(const_cast<Function&>(*CalledFunc)); + + // noalias indicates that pointer values based on the argument do not alias + // pointer values which are not based on it. So we add a new "scope" for each + // noalias function argument. Accesses using pointers based on that argument + // become part of that alias scope, accesses using pointers not based on that + // argument are tagged as noalias with that scope. + + DenseMap<const Argument *, MDNode *> NewScopes; + MDBuilder MDB(CalledFunc->getContext()); + + // Create a new scope domain for this function. + MDNode *NewDomain = + MDB.createAnonymousAliasScopeDomain(CalledFunc->getName()); + for (unsigned i = 0, e = NoAliasArgs.size(); i != e; ++i) { + const Argument *A = NoAliasArgs[i]; + + std::string Name = CalledFunc->getName(); + if (A->hasName()) { + Name += ": %"; + Name += A->getName(); + } else { + Name += ": argument "; + Name += utostr(i); + } + + // Note: We always create a new anonymous root here. This is true regardless + // of the linkage of the callee because the aliasing "scope" is not just a + // property of the callee, but also all control dependencies in the caller. 
+ MDNode *NewScope = MDB.createAnonymousAliasScope(NewDomain, Name); + NewScopes.insert(std::make_pair(A, NewScope)); + } + + // Iterate over all new instructions in the map; for all memory-access + // instructions, add the alias scope metadata. + for (ValueToValueMapTy::iterator VMI = VMap.begin(), VMIE = VMap.end(); + VMI != VMIE; ++VMI) { + if (const Instruction *I = dyn_cast<Instruction>(VMI->first)) { + if (!VMI->second) + continue; + + Instruction *NI = dyn_cast<Instruction>(VMI->second); + if (!NI) + continue; + + bool IsArgMemOnlyCall = false, IsFuncCall = false; + SmallVector<const Value *, 2> PtrArgs; + + if (const LoadInst *LI = dyn_cast<LoadInst>(I)) + PtrArgs.push_back(LI->getPointerOperand()); + else if (const StoreInst *SI = dyn_cast<StoreInst>(I)) + PtrArgs.push_back(SI->getPointerOperand()); + else if (const VAArgInst *VAAI = dyn_cast<VAArgInst>(I)) + PtrArgs.push_back(VAAI->getPointerOperand()); + else if (const AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(I)) + PtrArgs.push_back(CXI->getPointerOperand()); + else if (const AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) + PtrArgs.push_back(RMWI->getPointerOperand()); + else if (ImmutableCallSite ICS = ImmutableCallSite(I)) { + // If we know that the call does not access memory, then we'll still + // know that about the inlined clone of this call site, and we don't + // need to add metadata. + if (ICS.doesNotAccessMemory()) + continue; + + IsFuncCall = true; + if (AA) { + AliasAnalysis::ModRefBehavior MRB = AA->getModRefBehavior(ICS); + if (MRB == AliasAnalysis::OnlyAccessesArgumentPointees || + MRB == AliasAnalysis::OnlyReadsArgumentPointees) + IsArgMemOnlyCall = true; + } + + for (ImmutableCallSite::arg_iterator AI = ICS.arg_begin(), + AE = ICS.arg_end(); AI != AE; ++AI) { + // We need to check the underlying objects of all arguments, not just + // the pointer arguments, because we might be passing pointers as + // integers, etc. 
+          // However, if we know that the call only accesses pointer arguments,
+          // then we only need to check the pointer arguments.
+          if (IsArgMemOnlyCall && !(*AI)->getType()->isPointerTy())
+            continue;
+
+          PtrArgs.push_back(*AI);
+        }
+      }
+
+      // If we found no pointers, then this instruction is not suitable for
+      // pairing with an instruction to receive aliasing metadata.
+      // However, if this is a call, then we might just alias with none of the
+      // noalias arguments.
+      if (PtrArgs.empty() && !IsFuncCall)
+        continue;
+
+      // It is possible that there is only one underlying object, but you
+      // need to go through several PHIs to see it, and thus could be
+      // repeated in the Objects list.
+      SmallPtrSet<const Value *, 4> ObjSet;
+      SmallVector<Value *, 4> Scopes, NoAliases;
+
+      SmallSetVector<const Argument *, 4> NAPtrArgs;
+      for (unsigned i = 0, ie = PtrArgs.size(); i != ie; ++i) {
+        SmallVector<Value *, 4> Objects;
+        GetUnderlyingObjects(const_cast<Value*>(PtrArgs[i]),
+                             Objects, DL, /* MaxLookup = */ 0);
+
+        for (Value *O : Objects)
+          ObjSet.insert(O);
+      }
+
+      // Figure out if we're derived from anything that is not a noalias
+      // argument.
+      bool CanDeriveViaCapture = false, UsesAliasingPtr = false;
+      for (const Value *V : ObjSet) {
+        // Is this value a constant that cannot be derived from any pointer
+        // value (we need to exclude constant expressions, for example, that
+        // are formed from arithmetic on global symbols).
+        bool IsNonPtrConst = isa<ConstantInt>(V) || isa<ConstantFP>(V) ||
+                             isa<ConstantPointerNull>(V) ||
+                             isa<ConstantDataVector>(V) || isa<UndefValue>(V);
+        if (IsNonPtrConst)
+          continue;
+
+        // If this is anything other than a noalias argument, then we cannot
+        // completely describe the aliasing properties using alias.scope
+        // metadata (and, thus, won't add any).
+        if (const Argument *A = dyn_cast<Argument>(V)) {
+          if (!A->hasNoAliasAttr())
+            UsesAliasingPtr = true;
+        } else {
+          UsesAliasingPtr = true;
+        }
+
+        // If this is not some identified function-local object (which cannot
+        // directly alias a noalias argument), or some other argument (which,
+        // by definition, also cannot alias a noalias argument), then we could
+        // alias a noalias argument that has been captured.
+        if (!isa<Argument>(V) &&
+            !isIdentifiedFunctionLocal(const_cast<Value*>(V)))
+          CanDeriveViaCapture = true;
+      }
+
+      // A function call can always get captured noalias pointers (via other
+      // parameters, globals, etc.).
+      if (IsFuncCall && !IsArgMemOnlyCall)
+        CanDeriveViaCapture = true;
+
+      // First, we want to figure out all of the sets with which we definitely
+      // don't alias. Iterate over all noalias sets, and add those for which:
+      //   1. The noalias argument is not in the set of objects from which we
+      //      definitely derive.
+      //   2. The noalias argument has not yet been captured.
+      // An arbitrary function that might load pointers could see captured
+      // noalias arguments via other noalias arguments or globals, and so we
+      // must always check for prior capture.
+      for (const Argument *A : NoAliasArgs) {
+        if (!ObjSet.count(A) && (!CanDeriveViaCapture ||
+                                 // It might be tempting to skip the
+                                 // PointerMayBeCapturedBefore check if
+                                 // A->hasNoCaptureAttr() is true, but this is
+                                 // incorrect because nocapture only guarantees
+                                 // that no copies outlive the function, not
+                                 // that the value cannot be locally captured.
+                                 !PointerMayBeCapturedBefore(A,
+                                   /* ReturnCaptures */ false,
+                                   /* StoreCaptures */ false, I, &DT)))
+          NoAliases.push_back(NewScopes[A]);
+      }
+
+      if (!NoAliases.empty())
+        NI->setMetadata(LLVMContext::MD_noalias,
+                        MDNode::concatenate(
+                            NI->getMetadata(LLVMContext::MD_noalias),
+                            MDNode::get(CalledFunc->getContext(), NoAliases)));
+
+      // Next, we want to figure out all of the sets to which we might belong.
+ // We might belong to a set if the noalias argument is in the set of + // underlying objects. If there is some non-noalias argument in our list + // of underlying objects, then we cannot add a scope because the fact + // that some access does not alias with any set of our noalias arguments + // cannot itself guarantee that it does not alias with this access + // (because there is some pointer of unknown origin involved and the + // other access might also depend on this pointer). We also cannot add + // scopes to arbitrary functions unless we know they don't access any + // non-parameter pointer-values. + bool CanAddScopes = !UsesAliasingPtr; + if (CanAddScopes && IsFuncCall) + CanAddScopes = IsArgMemOnlyCall; + + if (CanAddScopes) + for (const Argument *A : NoAliasArgs) { + if (ObjSet.count(A)) + Scopes.push_back(NewScopes[A]); + } + + if (!Scopes.empty()) + NI->setMetadata( + LLVMContext::MD_alias_scope, + MDNode::concatenate(NI->getMetadata(LLVMContext::MD_alias_scope), + MDNode::get(CalledFunc->getContext(), Scopes))); + } + } +} + +/// If the inlined function has non-byval align arguments, then +/// add @llvm.assume-based alignment assumptions to preserve this information. +static void AddAlignmentAssumptions(CallSite CS, InlineFunctionInfo &IFI) { + if (!PreserveAlignmentAssumptions || !IFI.DL) + return; + + // To avoid inserting redundant assumptions, we should check for assumptions + // already in the caller. To do this, we might need a DT of the caller. + DominatorTree DT; + bool DTCalculated = false; + + const Function *CalledFunc = CS.getCalledFunction(); + for (Function::const_arg_iterator I = CalledFunc->arg_begin(), + E = CalledFunc->arg_end(); I != E; ++I) { + unsigned Align = I->getType()->isPointerTy() ? 
I->getParamAlignment() : 0; + if (Align && !I->hasByValOrInAllocaAttr() && !I->hasNUses(0)) { + if (!DTCalculated) { + DT.recalculate(const_cast<Function&>(*CS.getInstruction()->getParent() + ->getParent())); + DTCalculated = true; + } + + // If we can already prove the asserted alignment in the context of the + // caller, then don't bother inserting the assumption. + Value *Arg = CS.getArgument(I->getArgNo()); + if (getKnownAlignment(Arg, IFI.DL, IFI.AT, CS.getInstruction(), + &DT) >= Align) + continue; + + IRBuilder<>(CS.getInstruction()).CreateAlignmentAssumption(*IFI.DL, Arg, + Align); + } + } +} + /// UpdateCallGraphAfterInlining - Once we have cloned code over from a callee /// into the caller, update the specified callgraph to reflect the changes we /// made. Note that it's possible that not all code was copied over, so only @@ -327,31 +724,19 @@ static void UpdateCallGraphAfterInlining(CallSite CS, static void HandleByValArgumentInit(Value *Dst, Value *Src, Module *M, BasicBlock *InsertBlock, InlineFunctionInfo &IFI) { - LLVMContext &Context = Src->getContext(); - Type *VoidPtrTy = Type::getInt8PtrTy(Context); Type *AggTy = cast<PointerType>(Src->getType())->getElementType(); - Type *Tys[3] = { VoidPtrTy, VoidPtrTy, Type::getInt64Ty(Context) }; - Function *MemCpyFn = Intrinsic::getDeclaration(M, Intrinsic::memcpy, Tys); - IRBuilder<> builder(InsertBlock->begin()); - Value *DstCast = builder.CreateBitCast(Dst, VoidPtrTy, "tmp"); - Value *SrcCast = builder.CreateBitCast(Src, VoidPtrTy, "tmp"); + IRBuilder<> Builder(InsertBlock->begin()); Value *Size; if (IFI.DL == nullptr) Size = ConstantExpr::getSizeOf(AggTy); else - Size = ConstantInt::get(Type::getInt64Ty(Context), - IFI.DL->getTypeStoreSize(AggTy)); + Size = Builder.getInt64(IFI.DL->getTypeStoreSize(AggTy)); // Always generate a memcpy of alignment 1 here because we don't know // the alignment of the src pointer. Other optimizations can infer // better alignment. 
- Value *CallArgs[] = { - DstCast, SrcCast, Size, - ConstantInt::get(Type::getInt32Ty(Context), 1), - ConstantInt::getFalse(Context) // isVolatile - }; - builder.CreateCall(MemCpyFn, CallArgs); + Builder.CreateMemCpy(Dst, Src, Size, /*Align=*/1); } /// HandleByValArgument - When inlining a call site that has a byval argument, @@ -376,7 +761,7 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, // If the pointer is already known to be sufficiently aligned, or if we can // round it up to a larger alignment, then we don't need a temporary. if (getOrEnforceKnownAlignment(Arg, ByValAlignment, - IFI.DL) >= ByValAlignment) + IFI.DL, IFI.AT, TheCall) >= ByValAlignment) return Arg; // Otherwise, we have to make a memcpy to get a safe alignment. This is bad @@ -472,6 +857,12 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, // originates from the call location. This is important for // ((__always_inline__, __nodebug__)) functions which must use caller // location for all instructions in their function body. + + // Don't update static allocas, as they may get moved later. + if (auto *AI = dyn_cast<AllocaInst>(BI)) + if (isa<Constant>(AI->getArraySize())) + continue; + BI->setDebugLoc(TheCallDL); } else { BI->setDebugLoc(updateInlinedAtInfo(DL, TheCallDL, BI->getContext())); @@ -486,33 +877,6 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, } } -/// Returns a musttail call instruction if one immediately precedes the given -/// return instruction with an optional bitcast instruction between them. -static CallInst *getPrecedingMustTailCall(ReturnInst *RI) { - Instruction *Prev = RI->getPrevNode(); - if (!Prev) - return nullptr; - - if (Value *RV = RI->getReturnValue()) { - if (RV != Prev) - return nullptr; - - // Look through the optional bitcast. 
- if (auto *BI = dyn_cast<BitCastInst>(Prev)) { - RV = BI->getOperand(0); - Prev = BI->getPrevNode(); - if (!Prev || RV != Prev) - return nullptr; - } - } - - if (auto *CI = dyn_cast<CallInst>(Prev)) { - if (CI->isMustTailCall()) - return CI; - } - return nullptr; -} - /// InlineFunction - This function inlines the called function into the basic /// block of the caller. This returns false if it is not possible to inline /// this call. The program is still in a well defined state if this occurs @@ -626,6 +990,11 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, VMap[I] = ActualArg; } + // Add alignment assumptions if necessary. We do this before the inlined + // instructions are actually cloned into the caller so that we can easily + // check what will be known at the start of the inlined code. + AddAlignmentAssumptions(CS, IFI); + // We want the inliner to prune the code as it copies. We would LOVE to // have no dead or constant instructions leftover after inlining occurs // (which can happen, e.g., because an argument was constant), but we'll be @@ -648,6 +1017,17 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Update inlined instructions' line number information. fixupLineNumbers(Caller, FirstNewBlock, TheCall); + + // Clone existing noalias metadata if necessary. + CloneAliasScopeMetadata(CS, VMap); + + // Add noalias metadata if necessary. + AddAliasScopeMetadata(CS, VMap, IFI.DL, IFI.AA); + + // FIXME: We could register any cloned assumptions instead of clearing the + // whole function's cache. + if (IFI.AT) + IFI.AT->forgetCachedAssumptions(Caller); } // If there are any alloca instructions in the block that used to be the entry @@ -765,7 +1145,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, for (ReturnInst *RI : Returns) { // Don't insert llvm.lifetime.end calls between a musttail call and a // return. The return kills all local allocas. 
- if (InlinedMustTailCalls && getPrecedingMustTailCall(RI)) + if (InlinedMustTailCalls && + RI->getParent()->getTerminatingMustTailCall()) continue; IRBuilder<>(RI).CreateLifetimeEnd(AI, AllocaSize); } @@ -789,7 +1170,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, for (ReturnInst *RI : Returns) { // Don't insert llvm.stackrestore calls between a musttail call and a // return. The return will restore the stack pointer. - if (InlinedMustTailCalls && getPrecedingMustTailCall(RI)) + if (InlinedMustTailCalls && RI->getParent()->getTerminatingMustTailCall()) continue; IRBuilder<>(RI).CreateCall(StackRestore, SavedPtr); } @@ -812,7 +1193,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Handle the returns preceded by musttail calls separately. SmallVector<ReturnInst *, 8> NormalReturns; for (ReturnInst *RI : Returns) { - CallInst *ReturnedMustTail = getPrecedingMustTailCall(RI); + CallInst *ReturnedMustTail = + RI->getParent()->getTerminatingMustTailCall(); if (!ReturnedMustTail) { NormalReturns.push_back(RI); continue; @@ -1016,7 +1398,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // the entries are the same or undef). If so, remove the PHI so it doesn't // block other optimizations. 
if (PHI) { - if (Value *V = SimplifyInstruction(PHI, IFI.DL)) { + if (Value *V = SimplifyInstruction(PHI, IFI.DL, nullptr, nullptr, IFI.AT)) { PHI->replaceAllUsesWith(V); PHI->eraseFromParent(); } diff --git a/lib/Transforms/Utils/IntegerDivision.cpp b/lib/Transforms/Utils/IntegerDivision.cpp index 9f91eeb..0ae746c 100644 --- a/lib/Transforms/Utils/IntegerDivision.cpp +++ b/lib/Transforms/Utils/IntegerDivision.cpp @@ -398,11 +398,13 @@ bool llvm::expandRemainder(BinaryOperator *Rem) { Rem->dropAllReferences(); Rem->eraseFromParent(); - // If we didn't actually generate a udiv instruction, we're done - BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint()); - if (!BO || BO->getOpcode() != Instruction::URem) + // If we didn't actually generate an urem instruction, we're done + // This happens for example if the input were constant. In this case the + // Builder insertion point was unchanged + if (Rem == Builder.GetInsertPoint()) return true; + BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint()); Rem = BO; } @@ -456,11 +458,13 @@ bool llvm::expandDivision(BinaryOperator *Div) { Div->dropAllReferences(); Div->eraseFromParent(); - // If we didn't actually generate a udiv instruction, we're done - BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint()); - if (!BO || BO->getOpcode() != Instruction::UDiv) + // If we didn't actually generate an udiv instruction, we're done + // This happens for example if the input were constant. 
In this case the + // Builder insertion point was unchanged + if (Div == Builder.GetInsertPoint()) return true; + BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint()); Div = BO; } diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp index aedd787..c963c51 100644 --- a/lib/Transforms/Utils/Local.cpp +++ b/lib/Transforms/Utils/Local.cpp @@ -128,7 +128,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Check to see if this branch is going to the same place as the default // dest. If so, eliminate it as an explicit compare. if (i.getCaseSuccessor() == DefaultDest) { - MDNode* MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); unsigned NCases = SI->getNumCases(); // Fold the case metadata into the default if there will be any branches // left, unless the metadata doesn't match the switch. @@ -206,7 +206,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, BranchInst *NewBr = Builder.CreateCondBr(Cond, FirstCase.getCaseSuccessor(), SI->getDefaultDest()); - MDNode* MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); if (MD && MD->getNumOperands() == 3) { ConstantInt *SICase = dyn_cast<ConstantInt>(MD->getOperand(2)); ConstantInt *SIDef = dyn_cast<ConstantInt>(MD->getOperand(1)); @@ -301,6 +301,14 @@ bool llvm::isInstructionTriviallyDead(Instruction *I, if (II->getIntrinsicID() == Intrinsic::lifetime_start || II->getIntrinsicID() == Intrinsic::lifetime_end) return isa<UndefValue>(II->getArgOperand(1)); + + // Assumptions are dead if their condition is trivially true. 
+ if (II->getIntrinsicID() == Intrinsic::assume) { + if (ConstantInt *Cond = dyn_cast<ConstantInt>(II->getArgOperand(0))) + return !Cond->isZero(); + + return false; + } } if (isAllocLikeFn(I, TLI)) return true; @@ -384,7 +392,7 @@ bool llvm::RecursivelyDeleteDeadPHINode(PHINode *PN, // If we find an instruction more than once, we're on a cycle that // won't prove fruitful. - if (!Visited.insert(I)) { + if (!Visited.insert(I).second) { // Break the cycle and delete the instruction and its operands. I->replaceAllUsesWith(UndefValue::get(I->getType())); (void)RecursivelyDeleteTriviallyDeadInstructions(I, TLI); @@ -509,6 +517,11 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, Pass *P) { PredBB->getTerminator()->eraseFromParent(); DestBB->getInstList().splice(DestBB->begin(), PredBB->getInstList()); + // If the PredBB is the entry block of the function, move DestBB up to + // become the entry block after we erase PredBB. + if (PredBB == &DestBB->getParent()->getEntryBlock()) + DestBB->moveAfter(PredBB); + if (P) { if (DominatorTreeWrapperPass *DTWP = P->getAnalysisIfAvailable<DominatorTreeWrapperPass>()) { @@ -926,13 +939,16 @@ static unsigned enforceKnownAlignment(Value *V, unsigned Align, /// and it is more than the alignment of the ultimate object, see if we can /// increase the alignment of the ultimate object, making this check succeed. unsigned llvm::getOrEnforceKnownAlignment(Value *V, unsigned PrefAlign, - const DataLayout *DL) { + const DataLayout *DL, + AssumptionTracker *AT, + const Instruction *CxtI, + const DominatorTree *DT) { assert(V->getType()->isPointerTy() && "getOrEnforceKnownAlignment expects a pointer!"); unsigned BitWidth = DL ? 
DL->getPointerTypeSizeInBits(V->getType()) : 64; APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, DL); + computeKnownBits(V, KnownZero, KnownOne, DL, 0, AT, CxtI, DT); unsigned TrailZ = KnownZero.countTrailingOnes(); // Avoid trouble with ridiculously large TrailZ values, such as @@ -977,6 +993,7 @@ static bool LdStHasDebugValue(DIVariable &DIVar, Instruction *I) { bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, StoreInst *SI, DIBuilder &Builder) { DIVariable DIVar(DDI->getVariable()); + DIExpression DIExpr(DDI->getExpression()); assert((!DIVar || DIVar.isVariable()) && "Variable in DbgDeclareInst should be either null or a DIVariable."); if (!DIVar) @@ -994,9 +1011,10 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, if (SExtInst *SExt = dyn_cast<SExtInst>(SI->getOperand(0))) ExtendedArg = dyn_cast<Argument>(SExt->getOperand(0)); if (ExtendedArg) - DbgVal = Builder.insertDbgValueIntrinsic(ExtendedArg, 0, DIVar, SI); + DbgVal = Builder.insertDbgValueIntrinsic(ExtendedArg, 0, DIVar, DIExpr, SI); else - DbgVal = Builder.insertDbgValueIntrinsic(SI->getOperand(0), 0, DIVar, SI); + DbgVal = Builder.insertDbgValueIntrinsic(SI->getOperand(0), 0, DIVar, + DIExpr, SI); DbgVal->setDebugLoc(DDI->getDebugLoc()); return true; } @@ -1006,6 +1024,7 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, LoadInst *LI, DIBuilder &Builder) { DIVariable DIVar(DDI->getVariable()); + DIExpression DIExpr(DDI->getExpression()); assert((!DIVar || DIVar.isVariable()) && "Variable in DbgDeclareInst should be either null or a DIVariable."); if (!DIVar) @@ -1015,8 +1034,7 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, return true; Instruction *DbgVal = - Builder.insertDbgValueIntrinsic(LI->getOperand(0), 0, - DIVar, LI); + Builder.insertDbgValueIntrinsic(LI->getOperand(0), 0, DIVar, DIExpr, LI); 
DbgVal->setDebugLoc(DDI->getDebugLoc()); return true; } @@ -1056,14 +1074,14 @@ bool llvm::LowerDbgDeclare(Function &F) { else if (LoadInst *LI = dyn_cast<LoadInst>(U)) ConvertDebugDeclareToDebugValue(DDI, LI, DIB); else if (CallInst *CI = dyn_cast<CallInst>(U)) { - // This is a call by-value or some other instruction that - // takes a pointer to the variable. Insert a *value* - // intrinsic that describes the alloca. - auto DbgVal = - DIB.insertDbgValueIntrinsic(AI, 0, - DIVariable(DDI->getVariable()), CI); - DbgVal->setDebugLoc(DDI->getDebugLoc()); - } + // This is a call by-value or some other instruction that + // takes a pointer to the variable. Insert a *value* + // intrinsic that describes the alloca. + auto DbgVal = DIB.insertDbgValueIntrinsic( + AI, 0, DIVariable(DDI->getVariable()), + DIExpression(DDI->getExpression()), CI); + DbgVal->setDebugLoc(DDI->getDebugLoc()); + } DDI->eraseFromParent(); } } @@ -1087,6 +1105,7 @@ bool llvm::replaceDbgDeclareForAlloca(AllocaInst *AI, Value *NewAllocaAddress, if (!DDI) return false; DIVariable DIVar(DDI->getVariable()); + DIExpression DIExpr(DDI->getExpression()); assert((!DIVar || DIVar.isVariable()) && "Variable in DbgDeclareInst should be either null or a DIVariable."); if (!DIVar) @@ -1096,24 +1115,19 @@ bool llvm::replaceDbgDeclareForAlloca(AllocaInst *AI, Value *NewAllocaAddress, // "deref" operation to a list of address elements, as new llvm.dbg.declare // will take a value storing address of the memory for variable, not // alloca itself. 
- Type *Int64Ty = Type::getInt64Ty(AI->getContext()); - SmallVector<Value*, 4> NewDIVarAddress; - if (DIVar.hasComplexAddress()) { - for (unsigned i = 0, n = DIVar.getNumAddrElements(); i < n; ++i) { - NewDIVarAddress.push_back( - ConstantInt::get(Int64Ty, DIVar.getAddrElement(i))); + SmallVector<int64_t, 4> NewDIExpr; + if (DIExpr) { + for (unsigned i = 0, n = DIExpr.getNumElements(); i < n; ++i) { + NewDIExpr.push_back(DIExpr.getElement(i)); } } - NewDIVarAddress.push_back(ConstantInt::get(Int64Ty, DIBuilder::OpDeref)); - DIVariable NewDIVar = Builder.createComplexVariable( - DIVar.getTag(), DIVar.getContext(), DIVar.getName(), - DIVar.getFile(), DIVar.getLineNumber(), DIVar.getType(), - NewDIVarAddress, DIVar.getArgNumber()); + NewDIExpr.push_back(dwarf::DW_OP_deref); // Insert llvm.dbg.declare in the same basic block as the original alloca, // and remove old llvm.dbg.declare. BasicBlock *BB = AI->getParent(); - Builder.insertDeclare(NewAllocaAddress, NewDIVar, BB); + Builder.insertDeclare(NewAllocaAddress, DIVar, + Builder.createExpression(NewDIExpr), BB); DDI->eraseFromParent(); return true; } @@ -1165,7 +1179,7 @@ static void changeToCall(InvokeInst *II) { } static bool markAliveBlocks(BasicBlock *BB, - SmallPtrSet<BasicBlock*, 128> &Reachable) { + SmallPtrSetImpl<BasicBlock*> &Reachable) { SmallVector<BasicBlock*, 128> Worklist; Worklist.push_back(BB); @@ -1178,6 +1192,26 @@ static bool markAliveBlocks(BasicBlock *BB, // instructions into LLVM unreachable insts. The instruction combining pass // canonicalizes unreachable insts into stores to null or undef. for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E;++BBI){ + // Assumptions that are known to be false are equivalent to unreachable. + // Also, if the condition is undefined, then we make the choice most + // beneficial to the optimizer, and choose that to also be unreachable. 
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(BBI)) + if (II->getIntrinsicID() == Intrinsic::assume) { + bool MakeUnreachable = false; + if (isa<UndefValue>(II->getArgOperand(0))) + MakeUnreachable = true; + else if (ConstantInt *Cond = + dyn_cast<ConstantInt>(II->getArgOperand(0))) + MakeUnreachable = Cond->isZero(); + + if (MakeUnreachable) { + // Don't insert a call to llvm.trap right before the unreachable. + changeToUnreachable(BBI, false); + Changed = true; + break; + } + } + if (CallInst *CI = dyn_cast<CallInst>(BBI)) { if (CI->doesNotReturn()) { // If we found a call to a no-return function, insert an unreachable @@ -1232,7 +1266,7 @@ static bool markAliveBlocks(BasicBlock *BB, Changed |= ConstantFoldTerminator(BB, true); for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) - if (Reachable.insert(*SI)) + if (Reachable.insert(*SI).second) Worklist.push_back(*SI); } while (!Worklist.empty()); return Changed; @@ -1272,3 +1306,43 @@ bool llvm::removeUnreachableBlocks(Function &F) { return true; } + +void llvm::combineMetadata(Instruction *K, const Instruction *J, ArrayRef<unsigned> KnownIDs) { + SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata; + K->dropUnknownMetadata(KnownIDs); + K->getAllMetadataOtherThanDebugLoc(Metadata); + for (unsigned i = 0, n = Metadata.size(); i < n; ++i) { + unsigned Kind = Metadata[i].first; + MDNode *JMD = J->getMetadata(Kind); + MDNode *KMD = Metadata[i].second; + + switch (Kind) { + default: + K->setMetadata(Kind, nullptr); // Remove unknown metadata + break; + case LLVMContext::MD_dbg: + llvm_unreachable("getAllMetadataOtherThanDebugLoc returned a MD_dbg"); + case LLVMContext::MD_tbaa: + K->setMetadata(Kind, MDNode::getMostGenericTBAA(JMD, KMD)); + break; + case LLVMContext::MD_alias_scope: + case LLVMContext::MD_noalias: + K->setMetadata(Kind, MDNode::intersect(JMD, KMD)); + break; + case LLVMContext::MD_range: + K->setMetadata(Kind, MDNode::getMostGenericRange(JMD, KMD)); + break; + case 
LLVMContext::MD_fpmath: + K->setMetadata(Kind, MDNode::getMostGenericFPMath(JMD, KMD)); + break; + case LLVMContext::MD_invariant_load: + // Only set the !invariant.load if it is present in both instructions. + K->setMetadata(Kind, JMD); + break; + case LLVMContext::MD_nonnull: + // Only set the !nonnull if it is present in both instructions. + K->setMetadata(Kind, JMD); + break; + } + } +} diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp index ef42291..af0501f 100644 --- a/lib/Transforms/Utils/LoopSimplify.cpp +++ b/lib/Transforms/Utils/LoopSimplify.cpp @@ -44,6 +44,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -173,8 +174,7 @@ static BasicBlock *rewriteLoopExitBlock(Loop *L, BasicBlock *Exit, Pass *PP) { if (Exit->isLandingPad()) { SmallVector<BasicBlock*, 2> NewBBs; - SplitLandingPadPredecessors(Exit, ArrayRef<BasicBlock*>(&LoopBlocks[0], - LoopBlocks.size()), + SplitLandingPadPredecessors(Exit, LoopBlocks, ".loopexit", ".nonloopexit", PP, NewBBs); NewExitBB = NewBBs[0]; @@ -209,11 +209,12 @@ static void addBlockAndPredsToSet(BasicBlock *InputBB, BasicBlock *StopBlock, /// \brief The first part of loop-nestification is to find a PHI node that tells /// us how to partition the loops. static PHINode *findPHIToPartitionLoops(Loop *L, AliasAnalysis *AA, - DominatorTree *DT) { + DominatorTree *DT, + AssumptionTracker *AT) { for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ) { PHINode *PN = cast<PHINode>(I); ++I; - if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT)) { + if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT, AT)) { // This is a degenerate PHI already, don't modify it! 
PN->replaceAllUsesWith(V); if (AA) AA->deleteValue(PN); @@ -252,7 +253,8 @@ static PHINode *findPHIToPartitionLoops(Loop *L, AliasAnalysis *AA, /// static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, AliasAnalysis *AA, DominatorTree *DT, - LoopInfo *LI, ScalarEvolution *SE, Pass *PP) { + LoopInfo *LI, ScalarEvolution *SE, Pass *PP, + AssumptionTracker *AT) { // Don't try to separate loops without a preheader. if (!Preheader) return nullptr; @@ -261,7 +263,7 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, assert(!L->getHeader()->isLandingPad() && "Can't insert backedge to landing pad"); - PHINode *PN = findPHIToPartitionLoops(L, AA, DT); + PHINode *PN = findPHIToPartitionLoops(L, AA, DT, AT); if (!PN) return nullptr; // No known way to partition. // Pull out all predecessors that have varying values in the loop. This @@ -475,7 +477,7 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, static bool simplifyOneLoop(Loop *L, SmallVectorImpl<Loop *> &Worklist, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, Pass *PP, - const DataLayout *DL) { + const DataLayout *DL, AssumptionTracker *AT) { bool Changed = false; ReprocessLoop: @@ -496,20 +498,19 @@ ReprocessLoop: } // Delete each unique out-of-loop (and thus dead) predecessor. - for (SmallPtrSet<BasicBlock*, 4>::iterator I = BadPreds.begin(), - E = BadPreds.end(); I != E; ++I) { + for (BasicBlock *P : BadPreds) { DEBUG(dbgs() << "LoopSimplify: Deleting edge from dead predecessor " - << (*I)->getName() << "\n"); + << P->getName() << "\n"); // Inform each successor of each dead pred. - for (succ_iterator SI = succ_begin(*I), SE = succ_end(*I); SI != SE; ++SI) - (*SI)->removePredecessor(*I); + for (succ_iterator SI = succ_begin(P), SE = succ_end(P); SI != SE; ++SI) + (*SI)->removePredecessor(P); // Zap the dead pred's terminator and replace it with unreachable. 
- TerminatorInst *TI = (*I)->getTerminator(); + TerminatorInst *TI = P->getTerminator(); TI->replaceAllUsesWith(UndefValue::get(TI->getType())); - (*I)->getTerminator()->eraseFromParent(); - new UnreachableInst((*I)->getContext(), *I); + P->getTerminator()->eraseFromParent(); + new UnreachableInst(P->getContext(), P); Changed = true; } } @@ -582,7 +583,8 @@ ReprocessLoop: // this for loops with a giant number of backedges, just factor them into a // common backedge instead. if (L->getNumBackEdges() < 8) { - if (Loop *OuterL = separateNestedLoop(L, Preheader, AA, DT, LI, SE, PP)) { + if (Loop *OuterL = separateNestedLoop(L, Preheader, AA, DT, LI, SE, + PP, AT)) { ++NumNested; // Enqueue the outer loop as it should be processed next in our // depth-first nest walk. @@ -612,7 +614,7 @@ ReprocessLoop: PHINode *PN; for (BasicBlock::iterator I = L->getHeader()->begin(); (PN = dyn_cast<PHINode>(I++)); ) - if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT)) { + if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT, AT)) { if (AA) AA->deleteValue(PN); if (SE) SE->forgetValue(PN); PN->replaceAllUsesWith(V); @@ -712,7 +714,7 @@ ReprocessLoop: bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, Pass *PP, AliasAnalysis *AA, ScalarEvolution *SE, - const DataLayout *DL) { + const DataLayout *DL, AssumptionTracker *AT) { bool Changed = false; // Worklist maintains our depth-first queue of loops in this nest to process. 
@@ -730,7 +732,7 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, Pass *PP, while (!Worklist.empty()) Changed |= simplifyOneLoop(Worklist.pop_back_val(), Worklist, AA, DT, LI, - SE, PP, DL); + SE, PP, DL, AT); return Changed; } @@ -749,10 +751,13 @@ namespace { LoopInfo *LI; ScalarEvolution *SE; const DataLayout *DL; + AssumptionTracker *AT; bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionTracker>(); + // We need loop information to identify the loops... AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); @@ -773,11 +778,12 @@ namespace { char LoopSimplify::ID = 0; INITIALIZE_PASS_BEGIN(LoopSimplify, "loop-simplify", - "Canonicalize natural loops", true, false) + "Canonicalize natural loops", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfo) INITIALIZE_PASS_END(LoopSimplify, "loop-simplify", - "Canonicalize natural loops", true, false) + "Canonicalize natural loops", false, false) // Publicly exposed interface to pass... char &llvm::LoopSimplifyID = LoopSimplify::ID; @@ -794,10 +800,11 @@ bool LoopSimplify::runOnFunction(Function &F) { SE = getAnalysisIfAvailable<ScalarEvolution>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; + AT = &getAnalysis<AssumptionTracker>(); // Simplify each loop nest in the function. 
for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) - Changed |= simplifyLoop(*I, DT, LI, this, AA, SE, DL); + Changed |= simplifyLoop(*I, DT, LI, this, AA, SE, DL, AT); return Changed; } diff --git a/lib/Transforms/Utils/LoopUnroll.cpp b/lib/Transforms/Utils/LoopUnroll.cpp index c86b82c..0e1baa1 100644 --- a/lib/Transforms/Utils/LoopUnroll.cpp +++ b/lib/Transforms/Utils/LoopUnroll.cpp @@ -17,7 +17,9 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/UnrollLoop.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" @@ -64,10 +66,15 @@ static inline void RemapInstruction(Instruction *I, /// FoldBlockIntoPredecessor - Folds a basic block into its predecessor if it /// only has one predecessor, and that predecessor only has one successor. -/// The LoopInfo Analysis that is passed will be kept consistent. -/// Returns the new combined block. -static BasicBlock *FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI, - LPPassManager *LPM) { +/// The LoopInfo Analysis that is passed will be kept consistent. If folding is +/// successful references to the containing loop must be removed from +/// ScalarEvolution by calling ScalarEvolution::forgetLoop because SE may have +/// references to the eliminated BB. The argument ForgottenLoops contains a set +/// of loops that have already been forgotten to prevent redundant, expensive +/// calls to ScalarEvolution::forgetLoop. Returns the new combined block. 
+static BasicBlock * +FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI, LPPassManager *LPM, + SmallPtrSetImpl<Loop *> &ForgottenLoops) { // Merge basic blocks into their predecessor if there is only one distinct // pred, and if there is only one distinct successor of the predecessor, and // if there are no PHI nodes. @@ -104,8 +111,10 @@ static BasicBlock *FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI, // ScalarEvolution holds references to loop exit blocks. if (LPM) { if (ScalarEvolution *SE = LPM->getAnalysisIfAvailable<ScalarEvolution>()) { - if (Loop *L = LI->getLoopFor(BB)) - SE->forgetLoop(L); + if (Loop *L = LI->getLoopFor(BB)) { + if (ForgottenLoops.insert(L).second) + SE->forgetLoop(L); + } } } LI->removeBlock(BB); @@ -146,7 +155,8 @@ static BasicBlock *FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI, /// available from the Pass it must also preserve those analyses. bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool AllowRuntime, unsigned TripMultiple, - LoopInfo *LI, Pass *PP, LPPassManager *LPM) { + LoopInfo *LI, Pass *PP, LPPassManager *LPM, + AssumptionTracker *AT) { BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) { DEBUG(dbgs() << " Can't unroll; loop preheader-insertion failed.\n"); @@ -214,11 +224,10 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, // Notify ScalarEvolution that the loop will be substantially changed, // if not outright eliminated. - if (PP) { - ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>(); - if (SE) - SE->forgetLoop(L); - } + ScalarEvolution *SE = + PP ? PP->getAnalysisIfAvailable<ScalarEvolution>() : nullptr; + if (SE) + SE->forgetLoop(L); // If we know the trip count, we know the multiple... 
unsigned BreakoutTrip = 0; @@ -292,15 +301,45 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, for (unsigned It = 1; It != Count; ++It) { std::vector<BasicBlock*> NewBlocks; + SmallDenseMap<const Loop *, Loop *, 4> NewLoops; + NewLoops[L] = L; for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { ValueToValueMapTy VMap; BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It)); Header->getParent()->getBasicBlockList().push_back(New); - // Loop over all of the PHI nodes in the block, changing them to use the - // incoming values from the previous block. + // Tell LI about New. + if (*BB == Header) { + assert(LI->getLoopFor(*BB) == L && "Header should not be in a sub-loop"); + L->addBasicBlockToLoop(New, LI->getBase()); + } else { + // Figure out which loop New is in. + const Loop *OldLoop = LI->getLoopFor(*BB); + assert(OldLoop && "Should (at least) be in the loop being unrolled!"); + + Loop *&NewLoop = NewLoops[OldLoop]; + if (!NewLoop) { + // Found a new sub-loop. + assert(*BB == OldLoop->getHeader() && + "Header should be first in RPO"); + + Loop *NewLoopParent = NewLoops.lookup(OldLoop->getParentLoop()); + assert(NewLoopParent && + "Expected parent loop before sub-loop in RPO"); + NewLoop = new Loop; + NewLoopParent->addChildLoop(NewLoop); + + // Forget the old loop, since its inputs may have changed. + if (SE) + SE->forgetLoop(OldLoop); + } + NewLoop->addBasicBlockToLoop(New, LI->getBase()); + } + if (*BB == Header) + // Loop over all of the PHI nodes in the block, changing them to use + // the incoming values from the previous block. 
for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) { PHINode *NewPHI = cast<PHINode>(VMap[OrigPHINode[i]]); Value *InVal = NewPHI->getIncomingValueForBlock(LatchBlock); @@ -317,8 +356,6 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, VI != VE; ++VI) LastValueMap[VI->first] = VI->second; - L->addBasicBlockToLoop(New, LI->getBase()); - // Add phi entries for newly created values to all exit blocks. for (succ_iterator SI = succ_begin(*BB), SE = succ_end(*BB); SI != SE; ++SI) { @@ -423,15 +460,21 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, } // Merge adjacent basic blocks, if possible. + SmallPtrSet<Loop *, 4> ForgottenLoops; for (unsigned i = 0, e = Latches.size(); i != e; ++i) { BranchInst *Term = cast<BranchInst>(Latches[i]->getTerminator()); if (Term->isUnconditional()) { BasicBlock *Dest = Term->getSuccessor(0); - if (BasicBlock *Fold = FoldBlockIntoPredecessor(Dest, LI, LPM)) + if (BasicBlock *Fold = FoldBlockIntoPredecessor(Dest, LI, LPM, + ForgottenLoops)) std::replace(Latches.begin(), Latches.end(), Dest, Fold); } } + // FIXME: We could register any cloned assumptions instead of clearing the + // whole function's cache. + AT->forgetCachedAssumptions(F); + DominatorTree *DT = nullptr; if (PP) { // FIXME: Reconstruct dom info, because it is not preserved properly. @@ -443,7 +486,6 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, } // Simplify any new induction variables in the partially unrolled loop. - ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>(); if (SE && !CompletelyUnroll) { SmallVector<WeakVH, 16> DeadInsts; simplifyLoopIVs(L, SE, LPM, DeadInsts); @@ -492,8 +534,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, if (OuterL) { DataLayoutPass *DLP = PP->getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? 
&DLP->getDataLayout() : nullptr; - ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>(); - simplifyLoop(OuterL, DT, LI, PP, /*AliasAnalysis*/ nullptr, SE, DL); + simplifyLoop(OuterL, DT, LI, PP, /*AliasAnalysis*/ nullptr, SE, DL, AT); // LCSSA must be performed on the outermost affected loop. The unrolled // loop's last loop latch is guaranteed to be in the outermost loop after diff --git a/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/lib/Transforms/Utils/LoopUnrollRuntime.cpp index a96c46a..3d91336 100644 --- a/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -28,6 +28,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Metadata.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -57,7 +58,7 @@ STATISTIC(NumRuntimeUnrolled, static void ConnectProlog(Loop *L, Value *TripCount, unsigned Count, BasicBlock *LastPrologBB, BasicBlock *PrologEnd, BasicBlock *OrigPH, BasicBlock *NewPH, - ValueToValueMapTy &LVMap, Pass *P) { + ValueToValueMapTy &VMap, Pass *P) { BasicBlock *Latch = L->getLoopLatch(); assert(Latch && "Loop must have a latch"); @@ -86,7 +87,7 @@ static void ConnectProlog(Loop *L, Value *TripCount, unsigned Count, Value *V = PN->getIncomingValueForBlock(Latch); if (Instruction *I = dyn_cast<Instruction>(V)) { if (L->contains(I)) { - V = LVMap[I]; + V = VMap[I]; } } // Adding a value to the new PHI node from the last prolog block @@ -127,76 +128,122 @@ static void ConnectProlog(Loop *L, Value *TripCount, unsigned Count, } /// Create a clone of the blocks in a loop and connect them together. -/// This function doesn't create a clone of the loop structure. 
+/// If UnrollProlog is true, loop structure will not be cloned, otherwise a new +/// loop will be created including all cloned blocks, and the iterator of it +/// switches to count NewIter down to 0. /// -/// There are two value maps that are defined and used. VMap is -/// for the values in the current loop instance. LVMap contains -/// the values from the last loop instance. We need the LVMap values -/// to update the initial values for the current loop instance. -/// -static void CloneLoopBlocks(Loop *L, - bool FirstCopy, - BasicBlock *InsertTop, - BasicBlock *InsertBot, +static void CloneLoopBlocks(Loop *L, Value *NewIter, const bool UnrollProlog, + BasicBlock *InsertTop, BasicBlock *InsertBot, std::vector<BasicBlock *> &NewBlocks, - LoopBlocksDFS &LoopBlocks, - ValueToValueMapTy &VMap, - ValueToValueMapTy &LVMap, + LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, LoopInfo *LI) { - BasicBlock *Preheader = L->getLoopPreheader(); BasicBlock *Header = L->getHeader(); BasicBlock *Latch = L->getLoopLatch(); Function *F = Header->getParent(); LoopBlocksDFS::RPOIterator BlockBegin = LoopBlocks.beginRPO(); LoopBlocksDFS::RPOIterator BlockEnd = LoopBlocks.endRPO(); + Loop *NewLoop = 0; + Loop *ParentLoop = L->getParentLoop(); + if (!UnrollProlog) { + NewLoop = new Loop(); + if (ParentLoop) + ParentLoop->addChildLoop(NewLoop); + else + LI->addTopLevelLoop(NewLoop); + } + // For each block in the original loop, create a new copy, // and update the value map with the newly created values. 
for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { - BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, ".unr", F); + BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, ".prol", F); NewBlocks.push_back(NewBB); - if (Loop *ParentLoop = L->getParentLoop()) + if (NewLoop) + NewLoop->addBasicBlockToLoop(NewBB, LI->getBase()); + else if (ParentLoop) ParentLoop->addBasicBlockToLoop(NewBB, LI->getBase()); VMap[*BB] = NewBB; if (Header == *BB) { // For the first block, add a CFG connection to this newly - // created block + // created block. InsertTop->getTerminator()->setSuccessor(0, NewBB); - // Change the incoming values to the ones defined in the - // previously cloned loop. - for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { - PHINode *NewPHI = cast<PHINode>(VMap[I]); - if (FirstCopy) { - // We replace the first phi node with the value from the preheader - VMap[I] = NewPHI->getIncomingValueForBlock(Preheader); - NewBB->getInstList().erase(NewPHI); - } else { - // Update VMap with values from the previous block - unsigned idx = NewPHI->getBasicBlockIndex(Latch); - Value *InVal = NewPHI->getIncomingValue(idx); - if (Instruction *I = dyn_cast<Instruction>(InVal)) - if (L->contains(I)) - InVal = LVMap[InVal]; - NewPHI->setIncomingValue(idx, InVal); - NewPHI->setIncomingBlock(idx, InsertTop); - } - } } - if (Latch == *BB) { + // For the last block, if UnrollProlog is true, create a direct jump to + // InsertBot. If not, create a loop back to cloned head. 
VMap.erase((*BB)->getTerminator()); - NewBB->getTerminator()->eraseFromParent(); - BranchInst::Create(InsertBot, NewBB); + BasicBlock *FirstLoopBB = cast<BasicBlock>(VMap[Header]); + BranchInst *LatchBR = cast<BranchInst>(NewBB->getTerminator()); + if (UnrollProlog) { + LatchBR->eraseFromParent(); + BranchInst::Create(InsertBot, NewBB); + } else { + PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2, "prol.iter", + FirstLoopBB->getFirstNonPHI()); + IRBuilder<> Builder(LatchBR); + Value *IdxSub = + Builder.CreateSub(NewIdx, ConstantInt::get(NewIdx->getType(), 1), + NewIdx->getName() + ".sub"); + Value *IdxCmp = + Builder.CreateIsNotNull(IdxSub, NewIdx->getName() + ".cmp"); + BranchInst::Create(FirstLoopBB, InsertBot, IdxCmp, NewBB); + NewIdx->addIncoming(NewIter, InsertTop); + NewIdx->addIncoming(IdxSub, NewBB); + LatchBR->eraseFromParent(); + } } } - // LastValueMap is updated with the values for the current loop - // which are used the next time this function is called. - for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end(); - VI != VE; ++VI) { - LVMap[VI->first] = VI->second; + + // Change the incoming values to the ones defined in the preheader or + // cloned loop. + for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { + PHINode *NewPHI = cast<PHINode>(VMap[I]); + if (UnrollProlog) { + VMap[I] = NewPHI->getIncomingValueForBlock(Preheader); + cast<BasicBlock>(VMap[Header])->getInstList().erase(NewPHI); + } else { + unsigned idx = NewPHI->getBasicBlockIndex(Preheader); + NewPHI->setIncomingBlock(idx, InsertTop); + BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]); + idx = NewPHI->getBasicBlockIndex(Latch); + Value *InVal = NewPHI->getIncomingValue(idx); + NewPHI->setIncomingBlock(idx, NewLatch); + if (VMap[InVal]) + NewPHI->setIncomingValue(idx, VMap[InVal]); + } + } + if (NewLoop) { + // Add unroll disable metadata to disable future unrolling for this loop. 
+ SmallVector<Value *, 4> Vals; + // Reserve first location for self reference to the LoopID metadata node. + Vals.push_back(nullptr); + MDNode *LoopID = NewLoop->getLoopID(); + if (LoopID) { + // First remove any existing loop unrolling metadata. + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + bool IsUnrollMetadata = false; + MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); + if (MD) { + const MDString *S = dyn_cast<MDString>(MD->getOperand(0)); + IsUnrollMetadata = S && S->getString().startswith("llvm.loop.unroll."); + } + if (!IsUnrollMetadata) Vals.push_back(LoopID->getOperand(i)); + } + } + + LLVMContext &Context = NewLoop->getHeader()->getContext(); + SmallVector<Value *, 1> DisableOperands; + DisableOperands.push_back(MDString::get(Context, "llvm.loop.unroll.disable")); + MDNode *DisableNode = MDNode::get(Context, DisableOperands); + Vals.push_back(DisableNode); + + MDNode *NewLoopID = MDNode::get(Context, Vals); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + NewLoop->setLoopID(NewLoopID); } } @@ -212,18 +259,16 @@ static void CloneLoopBlocks(Loop *L, /// instruction in SimplifyCFG.cpp. Then, the backend decides how code for /// the switch instruction is generated. /// -/// extraiters = tripcount % loopfactor -/// if (extraiters == 0) jump Loop: -/// if (extraiters == loopfactor) jump L1 -/// if (extraiters == loopfactor-1) jump L2 -/// ... -/// L1: LoopBody; -/// L2: LoopBody; -/// ... -/// if tripcount < loopfactor jump End -/// Loop: -/// ... -/// End: +/// extraiters = tripcount % loopfactor +/// if (extraiters == 0) jump Loop: +/// else jump Prol +/// Prol: LoopBody; +/// extraiters -= 1 // Omitted if unroll factor is 2. +/// if (extraiters != 0) jump Prol: // Omitted if unroll factor is 2. +/// if (tripcount < loopfactor) jump End +/// Loop: +/// ... 
+/// End: /// bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, LoopInfo *LI, LPPassManager *LPM) { @@ -250,6 +295,10 @@ bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, LoopInfo *LI, if (isa<SCEVCouldNotCompute>(BECount) || !BECount->getType()->isIntegerTy()) return false; + // If BECount is INT_MAX, we can't compute trip-count without overflow. + if (BECount->isAllOnesValue()) + return false; + // Add 1 since the backedge count doesn't include the first loop iteration const SCEV *TripCountSC = SE->getAddExpr(BECount, SE->getConstant(BECount->getType(), 1)); @@ -284,26 +333,21 @@ bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, LoopInfo *LI, IRBuilder<> B(PreHeaderBR); Value *ModVal = B.CreateAnd(TripCount, Count - 1, "xtraiter"); - // Check if for no extra iterations, then jump to unrolled loop. We have to - // check that the trip count computation didn't overflow when adding one to - // the backedge taken count. + // Check if for no extra iterations, then jump to cloned/unrolled loop. + // We have to check that the trip count computation didn't overflow when + // adding one to the backedge taken count. 
Value *LCmp = B.CreateIsNotNull(ModVal, "lcmp.mod"); Value *OverflowCheck = B.CreateIsNull(TripCount, "lcmp.overflow"); Value *BranchVal = B.CreateOr(OverflowCheck, LCmp, "lcmp.or"); - // Branch to either the extra iterations or the unrolled loop + // Branch to either the extra iterations or the cloned/unrolled loop // We will fix up the true branch label when adding loop body copies BranchInst::Create(PEnd, PEnd, BranchVal, PreHeaderBR); assert(PreHeaderBR->isUnconditional() && PreHeaderBR->getSuccessor(0) == PEnd && "CFG edges in Preheader are not correct"); PreHeaderBR->eraseFromParent(); - - ValueToValueMapTy LVMap; Function *F = Header->getParent(); - // These variables are used to update the CFG links in each iteration - BasicBlock *CompareBB = nullptr; - BasicBlock *LastLoopBB = PH; // Get an ordered list of blocks in the loop to help with the ordering of the // cloned blocks in the prolog code LoopBlocksDFS LoopBlocks(L); @@ -314,62 +358,39 @@ bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, LoopInfo *LI, // and generate a condition that branches to the copy depending on the // number of 'left over' iterations. // - for (unsigned leftOverIters = Count-1; leftOverIters > 0; --leftOverIters) { - std::vector<BasicBlock*> NewBlocks; - ValueToValueMapTy VMap; - - // Clone all the basic blocks in the loop, but we don't clone the loop - // This function adds the appropriate CFG connections. - CloneLoopBlocks(L, (leftOverIters == Count-1), LastLoopBB, PEnd, NewBlocks, - LoopBlocks, VMap, LVMap, LI); - LastLoopBB = cast<BasicBlock>(VMap[Latch]); - - // Insert the cloned blocks into function just before the original loop - F->getBasicBlockList().splice(PEnd, F->getBasicBlockList(), - NewBlocks[0], F->end()); - - // Generate the code for the comparison which determines if the loop - // prolog code needs to be executed. 
- if (leftOverIters == Count-1) { - // There is no compare block for the fall-thru case when for the last - // left over iteration - CompareBB = NewBlocks[0]; - } else { - // Create a new block for the comparison - BasicBlock *NewBB = BasicBlock::Create(CompareBB->getContext(), "unr.cmp", - F, CompareBB); - if (Loop *ParentLoop = L->getParentLoop()) { - // Add the new block to the parent loop, if needed - ParentLoop->addBasicBlockToLoop(NewBB, LI->getBase()); - } - - // The comparison w/ the extra iteration value and branch - Type *CountTy = TripCount->getType(); - Value *BranchVal = new ICmpInst(*NewBB, ICmpInst::ICMP_EQ, ModVal, - ConstantInt::get(CountTy, leftOverIters), - "un.tmp"); - // Branch to either the extra iterations or the unrolled loop - BranchInst::Create(NewBlocks[0], CompareBB, - BranchVal, NewBB); - CompareBB = NewBB; - PH->getTerminator()->setSuccessor(0, NewBB); - VMap[NewPH] = CompareBB; - } - - // Rewrite the cloned instruction operands to use the values - // created when the clone is created. - for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) { - for (BasicBlock::iterator I = NewBlocks[i]->begin(), - E = NewBlocks[i]->end(); I != E; ++I) { - RemapInstruction(I, VMap, - RF_NoModuleLevelChanges|RF_IgnoreMissingEntries); - } + std::vector<BasicBlock *> NewBlocks; + ValueToValueMapTy VMap; + + // If unroll count is 2 and we can't overflow in tripcount computation (which + // is BECount + 1), then we don't need a loop for prologue, and we can unroll + // it. We can be sure that we don't overflow only if tripcount is a constant. + bool UnrollPrologue = (Count == 2 && isa<ConstantInt>(TripCount)); + + // Clone all the basic blocks in the loop. If Count is 2, we don't clone + // the loop, otherwise we create a cloned loop to execute the extra + // iterations. This function adds the appropriate CFG connections. 
+ CloneLoopBlocks(L, ModVal, UnrollPrologue, PH, PEnd, NewBlocks, LoopBlocks, + VMap, LI); + + // Insert the cloned blocks into function just before the original loop + F->getBasicBlockList().splice(PEnd, F->getBasicBlockList(), NewBlocks[0], + F->end()); + + // Rewrite the cloned instruction operands to use the values + // created when the clone is created. + for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) { + for (BasicBlock::iterator I = NewBlocks[i]->begin(), + E = NewBlocks[i]->end(); + I != E; ++I) { + RemapInstruction(I, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingEntries); } } // Connect the prolog code to the original loop and update the // PHI functions. - ConnectProlog(L, TripCount, Count, LastLoopBB, PEnd, PH, NewPH, LVMap, + BasicBlock *LastLoopBB = cast<BasicBlock>(VMap[Latch]); + ConnectProlog(L, TripCount, Count, LastLoopBB, PEnd, PH, NewPH, VMap, LPM->getAsPass()); NumRuntimeUnrolled++; return true; diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp index eac693b..a0105c2 100644 --- a/lib/Transforms/Utils/LowerSwitch.cpp +++ b/lib/Transforms/Utils/LowerSwitch.cpp @@ -67,8 +67,8 @@ namespace { BasicBlock *switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, ConstantInt *UpperBound, - Value *Val, BasicBlock *OrigBlock, - BasicBlock *Default); + Value *Val, BasicBlock *Predecessor, + BasicBlock *OrigBlock, BasicBlock *Default); BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock, BasicBlock *Default); unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); @@ -131,6 +131,28 @@ static raw_ostream& operator<<(raw_ostream &O, return O << "]"; } +/// \brief Update the first occurrence of the "switch statement" BB in the PHI +/// node with the "new" BB. The other occurrences will be updated by subsequent +/// calls to this function. +/// +/// Switch statements may have more than one incoming edge into the same BB if +/// they all have the same value. 
When the switch statement is converted these +/// incoming edges are now coming from multiple BBs. +static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB) { + for (BasicBlock::iterator I = SuccBB->begin(), E = SuccBB->getFirstNonPHI(); + I != E; ++I) { + PHINode *PN = cast<PHINode>(I); + + // Only update the first occurence. + for (unsigned Idx = 0, E = PN->getNumIncomingValues(); Idx != E; ++Idx) { + if (PN->getIncomingBlock(Idx) == OrigBB) { + PN->setIncomingBlock(Idx, NewBB); + break; + } + } + } +} + // switchConvert - Convert the switch statement into a binary lookup of // the case values. The function recursively builds this tree. // LowerBound and UpperBound are used to keep track of the bounds for Val @@ -139,6 +161,7 @@ static raw_ostream& operator<<(raw_ostream &O, BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, ConstantInt *UpperBound, Value *Val, + BasicBlock *Predecessor, BasicBlock *OrigBlock, BasicBlock *Default) { unsigned Size = End - Begin; @@ -149,6 +172,7 @@ BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, // emitting the code that checks if the value actually falls in the range // because the bounds already tell us so. if (Begin->Low == LowerBound && Begin->High == UpperBound) { + fixPhis(Begin->BB, OrigBlock, Predecessor); return Begin->BB; } return newLeafBlock(*Begin, Val, OrigBlock, Default); @@ -200,21 +224,25 @@ BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, dbgs() << "NONE\n"; }); - BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound, - NewUpperBound, Val, OrigBlock, Default); - BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound, - UpperBound, Val, OrigBlock, Default); - // Create a new node that checks if the value is < pivot. Go to the // left branch if it is and right branch if not. 
Function* F = OrigBlock->getParent(); BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); - Function::iterator FI = OrigBlock; - F->getBasicBlockList().insert(++FI, NewNode); ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, Val, Pivot.Low, "Pivot"); + + BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound, + NewUpperBound, Val, NewNode, OrigBlock, + Default); + BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound, + UpperBound, Val, NewNode, OrigBlock, + Default); + + Function::iterator FI = OrigBlock; + F->getBasicBlockList().insert(++FI, NewNode); NewNode->getInstList().push_back(Comp); + BranchInst::Create(LBranch, RBranch, Comp, NewNode); return NewNode; } @@ -386,7 +414,7 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI) { } BasicBlock *SwitchBlock = switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, - OrigBlock, NewDefault); + OrigBlock, OrigBlock, NewDefault); // Branch to our shiny new if-then stuff... 
BranchInst::Create(SwitchBlock, OrigBlock); diff --git a/lib/Transforms/Utils/Mem2Reg.cpp b/lib/Transforms/Utils/Mem2Reg.cpp index 189caa7..477ee7a 100644 --- a/lib/Transforms/Utils/Mem2Reg.cpp +++ b/lib/Transforms/Utils/Mem2Reg.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -38,6 +39,7 @@ namespace { bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.setPreservesCFG(); // This is a cluster of orthogonal Transforms @@ -51,6 +53,7 @@ namespace { char PromotePass::ID = 0; INITIALIZE_PASS_BEGIN(PromotePass, "mem2reg", "Promote Memory to Register", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(PromotePass, "mem2reg", "Promote Memory to Register", false, false) @@ -63,6 +66,7 @@ bool PromotePass::runOnFunction(Function &F) { bool Changed = false; DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + AssumptionTracker *AT = &getAnalysis<AssumptionTracker>(); while (1) { Allocas.clear(); @@ -76,7 +80,7 @@ bool PromotePass::runOnFunction(Function &F) { if (Allocas.empty()) break; - PromoteMemToReg(Allocas, DT); + PromoteMemToReg(Allocas, DT, nullptr, AT); NumPromoted += Allocas.size(); Changed = true; } diff --git a/lib/Transforms/Utils/ModuleUtils.cpp b/lib/Transforms/Utils/ModuleUtils.cpp index d9dbbca..35c701e 100644 --- a/lib/Transforms/Utils/ModuleUtils.cpp +++ b/lib/Transforms/Utils/ModuleUtils.cpp @@ -78,7 +78,7 @@ void llvm::appendToGlobalDtors(Module &M, Function *F, int Priority) { } GlobalVariable * -llvm::collectUsedGlobalVariables(Module &M, SmallPtrSet<GlobalValue *, 8> &Set, +llvm::collectUsedGlobalVariables(Module &M, 
SmallPtrSetImpl<GlobalValue *> &Set, bool CompilerUsed) { const char *Name = CompilerUsed ? "llvm.compiler.used" : "llvm.used"; GlobalVariable *GV = M.getGlobalVariable(Name); diff --git a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index 06d73fe..1fd7071 100644 --- a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -238,6 +238,9 @@ struct PromoteMem2Reg { /// An AliasSetTracker object to update. If null, don't update it. AliasSetTracker *AST; + /// A cache of @llvm.assume intrinsics used by SimplifyInstruction. + AssumptionTracker *AT; + /// Reverse mapping of Allocas. DenseMap<AllocaInst *, unsigned> AllocaLookup; @@ -279,9 +282,9 @@ struct PromoteMem2Reg { public: PromoteMem2Reg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, - AliasSetTracker *AST) + AliasSetTracker *AST, AssumptionTracker *AT) : Allocas(Allocas.begin(), Allocas.end()), DT(DT), - DIB(*DT.getRoot()->getParent()->getParent()), AST(AST) {} + DIB(*DT.getRoot()->getParent()->getParent()), AST(AST), AT(AT) {} void run(); @@ -302,8 +305,8 @@ private: void DetermineInsertionPoint(AllocaInst *AI, unsigned AllocaNum, AllocaInfo &Info); void ComputeLiveInBlocks(AllocaInst *AI, AllocaInfo &Info, - const SmallPtrSet<BasicBlock *, 32> &DefBlocks, - SmallPtrSet<BasicBlock *, 32> &LiveInBlocks); + const SmallPtrSetImpl<BasicBlock *> &DefBlocks, + SmallPtrSetImpl<BasicBlock *> &LiveInBlocks); void RenamePass(BasicBlock *BB, BasicBlock *Pred, RenamePassData::ValVector &IncVals, std::vector<RenamePassData> &Worklist); @@ -685,7 +688,7 @@ void PromoteMem2Reg::run() { PHINode *PN = I->second; // If this PHI node merges one value and/or undefs, get the value. 
- if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, &DT)) { + if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, &DT, AT)) { if (AST && PN->getType()->isPointerTy()) AST->deleteValue(PN); PN->replaceAllUsesWith(V); @@ -766,8 +769,8 @@ void PromoteMem2Reg::run() { /// inserted phi nodes would be dead). void PromoteMem2Reg::ComputeLiveInBlocks( AllocaInst *AI, AllocaInfo &Info, - const SmallPtrSet<BasicBlock *, 32> &DefBlocks, - SmallPtrSet<BasicBlock *, 32> &LiveInBlocks) { + const SmallPtrSetImpl<BasicBlock *> &DefBlocks, + SmallPtrSetImpl<BasicBlock *> &LiveInBlocks) { // To determine liveness, we must iterate through the predecessors of blocks // where the def is live. Blocks are added to the worklist if we need to @@ -816,7 +819,7 @@ void PromoteMem2Reg::ComputeLiveInBlocks( // The block really is live in here, insert it into the set. If already in // the set, then it has already been processed. - if (!LiveInBlocks.insert(BB)) + if (!LiveInBlocks.insert(BB).second) continue; // Since the value is live into BB, it is either defined in a predecessor or @@ -857,10 +860,8 @@ void PromoteMem2Reg::DetermineInsertionPoint(AllocaInst *AI, unsigned AllocaNum, less_second> IDFPriorityQueue; IDFPriorityQueue PQ; - for (SmallPtrSet<BasicBlock *, 32>::const_iterator I = DefBlocks.begin(), - E = DefBlocks.end(); - I != E; ++I) { - if (DomTreeNode *Node = DT.getNode(*I)) + for (BasicBlock *BB : DefBlocks) { + if (DomTreeNode *Node = DT.getNode(BB)) PQ.push(std::make_pair(Node, DomLevels[Node])); } @@ -898,7 +899,7 @@ void PromoteMem2Reg::DetermineInsertionPoint(AllocaInst *AI, unsigned AllocaNum, if (SuccLevel > RootLevel) continue; - if (!Visited.insert(SuccNode)) + if (!Visited.insert(SuccNode).second) continue; BasicBlock *SuccBB = SuccNode->getBlock(); @@ -1003,7 +1004,7 @@ NextIteration: } // Don't revisit blocks. 
- if (!Visited.insert(BB)) + if (!Visited.insert(BB).second) return; for (BasicBlock::iterator II = BB->begin(); !isa<TerminatorInst>(II);) { @@ -1060,17 +1061,17 @@ NextIteration: ++I; for (; I != E; ++I) - if (VisitedSuccs.insert(*I)) + if (VisitedSuccs.insert(*I).second) Worklist.push_back(RenamePassData(*I, Pred, IncomingVals)); goto NextIteration; } void llvm::PromoteMemToReg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, - AliasSetTracker *AST) { + AliasSetTracker *AST, AssumptionTracker *AT) { // If there is nothing to do, bail out... if (Allocas.empty()) return; - PromoteMem2Reg(Allocas, DT, AST).run(); + PromoteMem2Reg(Allocas, DT, AST, AT).run(); } diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index 960b198..92fd56a 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -43,6 +43,8 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <map> #include <set> @@ -68,12 +70,23 @@ static cl::opt<bool> HoistCondStores( cl::desc("Hoist conditional stores if an unconditional store precedes")); STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps"); +STATISTIC(NumLinearMaps, "Number of switch instructions turned into linear mapping"); STATISTIC(NumLookupTables, "Number of switch instructions turned into lookup tables"); STATISTIC(NumLookupTablesHoles, "Number of switch instructions turned into lookup tables (holes checked)"); STATISTIC(NumSinkCommons, "Number of common instructions sunk down to the end block"); STATISTIC(NumSpeculations, "Number of speculative executed instructions"); namespace { + // The first field contains the value that the switch produces when a certain + // case group is selected, and the second field is a vector containing the cases + // composing 
the case group. + typedef SmallVector<std::pair<Constant *, SmallVector<ConstantInt *, 4>>, 2> + SwitchCaseResultVectorTy; + // The first field contains the phi node that generates a result of the switch + // and the second field contains the value generated for a certain case in the switch + // for that PHI. + typedef SmallVector<std::pair<PHINode *, Constant *>, 4> SwitchCaseResultsTy; + /// ValueEqualityComparisonCase - Represents a case of a switch. struct ValueEqualityComparisonCase { ConstantInt *Value; @@ -92,7 +105,9 @@ namespace { class SimplifyCFGOpt { const TargetTransformInfo &TTI; + unsigned BonusInstThreshold; const DataLayout *const DL; + AssumptionTracker *AT; Value *isValueEqualityComparison(TerminatorInst *TI); BasicBlock *GetValueEqualityComparisonCases(TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases); @@ -111,8 +126,9 @@ class SimplifyCFGOpt { bool SimplifyCondBranch(BranchInst *BI, IRBuilder <>&Builder); public: - SimplifyCFGOpt(const TargetTransformInfo &TTI, const DataLayout *DL) - : TTI(TTI), DL(DL) {} + SimplifyCFGOpt(const TargetTransformInfo &TTI, unsigned BonusInstThreshold, + const DataLayout *DL, AssumptionTracker *AT) + : TTI(TTI), BonusInstThreshold(BonusInstThreshold), DL(DL), AT(AT) {} bool run(BasicBlock *BB); }; } @@ -256,7 +272,7 @@ static unsigned ComputeSpeculationCost(const User *I, const DataLayout *DL) { /// V plus its non-dominating operands. If that cost is greater than /// CostRemaining, false is returned and CostRemaining is undefined. 
static bool DominatesMergePoint(Value *V, BasicBlock *BB, - SmallPtrSet<Instruction*, 4> *AggressiveInsts, + SmallPtrSetImpl<Instruction*> *AggressiveInsts, unsigned &CostRemaining, const DataLayout *DL) { Instruction *I = dyn_cast<Instruction>(V); @@ -341,114 +357,177 @@ static ConstantInt *GetConstantInt(Value *V, const DataLayout *DL) { return nullptr; } -/// GatherConstantCompares - Given a potentially 'or'd or 'and'd together -/// collection of icmp eq/ne instructions that compare a value against a -/// constant, return the value being compared, and stick the constant into the -/// Values vector. -static Value * -GatherConstantCompares(Value *V, std::vector<ConstantInt*> &Vals, Value *&Extra, - const DataLayout *DL, bool isEQ, unsigned &UsedICmps) { - Instruction *I = dyn_cast<Instruction>(V); - if (!I) return nullptr; - - // If this is an icmp against a constant, handle this as one of the cases. - if (ICmpInst *ICI = dyn_cast<ICmpInst>(I)) { - if (ConstantInt *C = GetConstantInt(I->getOperand(1), DL)) { - Value *RHSVal; - ConstantInt *RHSC; - - if (ICI->getPredicate() == (isEQ ? ICmpInst::ICMP_EQ:ICmpInst::ICMP_NE)) { - // (x & ~2^x) == y --> x == y || x == y|2^x - // This undoes a transformation done by instcombine to fuse 2 compares. - if (match(ICI->getOperand(0), - m_And(m_Value(RHSVal), m_ConstantInt(RHSC)))) { - APInt Not = ~RHSC->getValue(); - if (Not.isPowerOf2()) { - Vals.push_back(C); - Vals.push_back( - ConstantInt::get(C->getContext(), C->getValue() | Not)); - UsedICmps++; - return RHSVal; - } - } +namespace { + +/// Given a chain of or (||) or and (&&) comparison of a value against a +/// constant, this will try to recover the information required for a switch +/// structure. +/// It will depth-first traverse the chain of comparison, seeking for patterns +/// like %a == 12 or %a < 4 and combine them to produce a set of integer +/// representing the different cases for the switch. 
+/// Note that if the chain is composed of '||' it will build the set of elements
+/// that matches the comparisons (i.e. any of these values validates the chain)
+/// while for a chain of '&&' it will build the set of elements that make the test
+/// fail.
+struct ConstantComparesGatherer {
+
+  Value *CompValue; /// Value found for the switch comparison
+  Value *Extra; /// Extra clause to be checked before the switch
+  SmallVector<ConstantInt *, 8> Vals; /// Set of integers to match in switch
+  unsigned UsedICmps; /// Number of comparisons matched in the and/or chain
+
+  /// Construct and compute the result for the comparison instruction Cond
+  ConstantComparesGatherer(Instruction *Cond, const DataLayout *DL)
+      : CompValue(nullptr), Extra(nullptr), UsedICmps(0) {
+    gather(Cond, DL);
+  }
 
-      UsedICmps++;
-      Vals.push_back(C);
-      return I->getOperand(0);
+  /// Prevent copy
+  ConstantComparesGatherer(const ConstantComparesGatherer &)
+      LLVM_DELETED_FUNCTION;
+  ConstantComparesGatherer &
+  operator=(const ConstantComparesGatherer &) LLVM_DELETED_FUNCTION;
+
+private:
+
+  /// Try to set the current value used for the comparison; it succeeds only if
+  /// it wasn't set before or if the new value is the same as the old one
+  bool setValueOnce(Value *NewVal) {
+    if(CompValue && CompValue != NewVal) return false;
+    CompValue = NewVal;
+    return (CompValue != nullptr);
+  }
+
+  /// Try to match Instruction "I" as a comparison against a constant and
+  /// populate the array Vals with the set of values that match (or do not
+  /// match depending on isEQ).
+  /// Return false on failure. On success, the Value the comparison matched
+  /// against is placed in CompValue.
+  /// If CompValue is already set, the function is expected to fail if a match
+  /// is found but the value compared to is different.
+  bool matchInstruction(Instruction *I, const DataLayout *DL, bool isEQ) {
+    // If this is an icmp against a constant, handle this as one of the cases.
+ ICmpInst *ICI; + ConstantInt *C; + if (!((ICI = dyn_cast<ICmpInst>(I)) && + (C = GetConstantInt(I->getOperand(1), DL)))) { + return false; + } + + Value *RHSVal; + ConstantInt *RHSC; + + // Pattern match a special case + // (x & ~2^x) == y --> x == y || x == y|2^x + // This undoes a transformation done by instcombine to fuse 2 compares. + if (ICI->getPredicate() == (isEQ ? ICmpInst::ICMP_EQ:ICmpInst::ICMP_NE)) { + if (match(ICI->getOperand(0), + m_And(m_Value(RHSVal), m_ConstantInt(RHSC)))) { + APInt Not = ~RHSC->getValue(); + if (Not.isPowerOf2()) { + // If we already have a value for the switch, it has to match! + if(!setValueOnce(RHSVal)) + return false; + + Vals.push_back(C); + Vals.push_back(ConstantInt::get(C->getContext(), + C->getValue() | Not)); + UsedICmps++; + return true; + } } - // If we have "x ult 3" comparison, for example, then we can add 0,1,2 to - // the set. - ConstantRange Span = - ConstantRange::makeICmpRegion(ICI->getPredicate(), C->getValue()); - - // Shift the range if the compare is fed by an add. This is the range - // compare idiom as emitted by instcombine. - bool hasAdd = - match(I->getOperand(0), m_Add(m_Value(RHSVal), m_ConstantInt(RHSC))); - if (hasAdd) - Span = Span.subtract(RHSC->getValue()); - - // If this is an and/!= check then we want to optimize "x ugt 2" into - // x != 0 && x != 1. - if (!isEQ) - Span = Span.inverse(); - - // If there are a ton of values, we don't want to make a ginormous switch. - if (Span.getSetSize().ugt(8) || Span.isEmptySet()) - return nullptr; - - for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp) - Vals.push_back(ConstantInt::get(V->getContext(), Tmp)); + // If we already have a value for the switch, it has to match! + if(!setValueOnce(ICI->getOperand(0))) + return false; + UsedICmps++; - return hasAdd ? RHSVal : I->getOperand(0); + Vals.push_back(C); + return ICI->getOperand(0); } - return nullptr; - } - // Otherwise, we can only handle an | or &, depending on isEQ. 
- if (I->getOpcode() != (isEQ ? Instruction::Or : Instruction::And)) - return nullptr; + // If we have "x ult 3", for example, then we can add 0,1,2 to the set. + ConstantRange Span = ConstantRange::makeICmpRegion(ICI->getPredicate(), + C->getValue()); - unsigned NumValsBeforeLHS = Vals.size(); - unsigned UsedICmpsBeforeLHS = UsedICmps; - if (Value *LHS = GatherConstantCompares(I->getOperand(0), Vals, Extra, DL, - isEQ, UsedICmps)) { - unsigned NumVals = Vals.size(); - unsigned UsedICmpsBeforeRHS = UsedICmps; - if (Value *RHS = GatherConstantCompares(I->getOperand(1), Vals, Extra, DL, - isEQ, UsedICmps)) { - if (LHS == RHS) - return LHS; - Vals.resize(NumVals); - UsedICmps = UsedICmpsBeforeRHS; + // Shift the range if the compare is fed by an add. This is the range + // compare idiom as emitted by instcombine. + Value *CandidateVal = I->getOperand(0); + if(match(I->getOperand(0), m_Add(m_Value(RHSVal), m_ConstantInt(RHSC)))) { + Span = Span.subtract(RHSC->getValue()); + CandidateVal = RHSVal; } - // The RHS of the or/and can't be folded in and we haven't used "Extra" yet, - // set it and return success. - if (Extra == nullptr || Extra == I->getOperand(1)) { - Extra = I->getOperand(1); - return LHS; + // If this is an and/!= check, then we are looking to build the set of + // value that *don't* pass the and chain. I.e. to turn "x ugt 2" into + // x != 0 && x != 1. + if (!isEQ) + Span = Span.inverse(); + + // If there are a ton of values, we don't want to make a ginormous switch. + if (Span.getSetSize().ugt(8) || Span.isEmptySet()) { + return false; } - Vals.resize(NumValsBeforeLHS); - UsedICmps = UsedICmpsBeforeLHS; - return nullptr; + // If we already have a value for the switch, it has to match! 
+      if(!setValueOnce(CandidateVal))
+        return false;
+
+      // Add all values from the range to the set
+      for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp)
+        Vals.push_back(ConstantInt::get(I->getContext(), Tmp));
+
+      UsedICmps++;
+      return true;
+
   }
 
-  // If the LHS can't be folded in, but Extra is available and RHS can, try to
-  // use LHS as Extra.
-  if (Extra == nullptr || Extra == I->getOperand(0)) {
-    Value *OldExtra = Extra;
-    Extra = I->getOperand(0);
-    if (Value *RHS = GatherConstantCompares(I->getOperand(1), Vals, Extra, DL,
-                                            isEQ, UsedICmps))
-      return RHS;
-    assert(Vals.size() == NumValsBeforeLHS);
-    Extra = OldExtra;
+  /// gather - Given a potentially 'or'd or 'and'd together collection of icmp
+  /// eq/ne/lt/gt instructions that compare a value against a constant, extract
+  /// the value being compared, and stick the list of constants into the Vals
+  /// vector.
+  /// One "Extra" case is allowed to differ from the other.
+  void gather(Value *V, const DataLayout *DL) {
+    Instruction *I = dyn_cast<Instruction>(V);
+    bool isEQ = (I->getOpcode() == Instruction::Or);
+
+    // Keep a stack (SmallVector for efficiency) for depth-first traversal
+    SmallVector<Value *, 8> DFT;
+
+    // Initialize
+    DFT.push_back(V);
+
+    while(!DFT.empty()) {
+      V = DFT.pop_back_val();
+
+      if (Instruction *I = dyn_cast<Instruction>(V)) {
+        // If it is a || (or && depending on isEQ), process the operands.
+        if (I->getOpcode() == (isEQ ? Instruction::Or : Instruction::And)) {
+          DFT.push_back(I->getOperand(1));
+          DFT.push_back(I->getOperand(0));
+          continue;
+        }
+
+        // Try to match the current instruction
+        if (matchInstruction(I, DL, isEQ))
+          // Match succeeded, continue the loop
+          continue;
+      }
+
+      // One element of the sequence of || (or &&) could not be matched as a
+      // comparison against the same value as the others.
+ // We allow only one "Extra" case to be checked before the switch + if (!Extra) { + Extra = V; + continue; + } + // Failed to parse a proper sequence, abort now + CompValue = nullptr; + break; + } } +}; - return nullptr; } static void EraseTerminatorInstAndDCECond(TerminatorInst *TI) { @@ -628,7 +707,7 @@ SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, // Collect branch weights into a vector. SmallVector<uint32_t, 8> Weights; - MDNode* MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); bool HasWeight = MD && (MD->getNumOperands() == 2 + SI->getNumCases()); if (HasWeight) for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; @@ -723,7 +802,7 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1, } static inline bool HasBranchWeights(const Instruction* I) { - MDNode* ProfMD = I->getMetadata(LLVMContext::MD_prof); + MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof); if (ProfMD && ProfMD->getOperand(0)) if (MDString* MDS = dyn_cast<MDString>(ProfMD->getOperand(0))) return MDS->getString().equals("branch_weights"); @@ -736,7 +815,7 @@ static inline bool HasBranchWeights(const Instruction* I) { /// metadata. static void GetBranchWeights(TerminatorInst *TI, SmallVectorImpl<uint64_t> &Weights) { - MDNode* MD = TI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = TI->getMetadata(LLVMContext::MD_prof); assert(MD); for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) { ConstantInt *CI = cast<ConstantInt>(MD->getOperand(i)); @@ -995,6 +1074,8 @@ static bool isSafeToHoistInvoke(BasicBlock *BB1, BasicBlock *BB2, return true; } +static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I); + /// HoistThenElseCodeToIf - Given a conditional branch that goes to BB1 and /// BB2, hoist any common code in the two blocks up into the branch block. The /// caller of this function guarantees that BI's block dominates BB1 and BB2. 
@@ -1040,6 +1121,14 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, const DataLayout *DL) { if (!I2->use_empty()) I2->replaceAllUsesWith(I1); I1->intersectOptionalDataWith(I2); + unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, + LLVMContext::MD_range, + LLVMContext::MD_fpmath, + LLVMContext::MD_invariant_load, + LLVMContext::MD_nonnull + }; + combineMetadata(I1, I2, KnownIDs); I2->eraseFromParent(); Changed = true; @@ -1072,6 +1161,12 @@ HoistTerminator: if (BB1V == BB2V) continue; + // Check for passingValueIsAlwaysUndefined here because we would rather + // eliminate undefined control flow then converting it to a select. + if (passingValueIsAlwaysUndefined(BB1V, PN) || + passingValueIsAlwaysUndefined(BB2V, PN)) + return Changed; + if (isa<ConstantExpr>(BB1V) && !isSafeToSpeculativelyExecute(BB1V, DL)) return Changed; if (isa<ConstantExpr>(BB2V) && !isSafeToSpeculativelyExecute(BB2V, DL)) @@ -1281,6 +1376,8 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { if (!I2->use_empty()) I2->replaceAllUsesWith(I1); I1->intersectOptionalDataWith(I2); + // TODO: Use combineMetadata here to preserve what metadata we can + // (analogous to the hoisting case above). I2->eraseFromParent(); if (UpdateRE1) @@ -1486,6 +1583,11 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, if (ThenV == OrigV) continue; + // Don't convert to selects if we could remove undefined behavior instead. + if (passingValueIsAlwaysUndefined(OrigV, PN) || + passingValueIsAlwaysUndefined(ThenV, PN)) + return false; + HaveRewritablePHIs = true; ConstantExpr *OrigCE = dyn_cast<ConstantExpr>(OrigV); ConstantExpr *ThenCE = dyn_cast<ConstantExpr>(ThenV); @@ -1963,7 +2065,8 @@ static bool checkCSEInPredecessor(Instruction *Inst, BasicBlock *PB) { /// FoldBranchToCommonDest - If this basic block is simple enough, and if a /// predecessor branches to us and one of our successors, fold the block into /// the predecessor and use logical operations to pick the right destination. 
-bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) { +bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL, + unsigned BonusInstThreshold) { BasicBlock *BB = BI->getParent(); Instruction *Cond = nullptr; @@ -2000,33 +2103,6 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) { Cond->getParent() != BB || !Cond->hasOneUse()) return false; - // Only allow this if the condition is a simple instruction that can be - // executed unconditionally. It must be in the same block as the branch, and - // must be at the front of the block. - BasicBlock::iterator FrontIt = BB->front(); - - // Ignore dbg intrinsics. - while (isa<DbgInfoIntrinsic>(FrontIt)) ++FrontIt; - - // Allow a single instruction to be hoisted in addition to the compare - // that feeds the branch. We later ensure that any values that _it_ uses - // were also live in the predecessor, so that we don't unnecessarily create - // register pressure or inhibit out-of-order execution. - Instruction *BonusInst = nullptr; - if (&*FrontIt != Cond && - FrontIt->hasOneUse() && FrontIt->user_back() == Cond && - isSafeToSpeculativelyExecute(FrontIt, DL)) { - BonusInst = &*FrontIt; - ++FrontIt; - - // Ignore dbg intrinsics. - while (isa<DbgInfoIntrinsic>(FrontIt)) ++FrontIt; - } - - // Only a single bonus inst is allowed. - if (&*FrontIt != Cond) - return false; - // Make sure the instruction after the condition is the cond branch. BasicBlock::iterator CondIt = Cond; ++CondIt; @@ -2036,6 +2112,31 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) { if (&*CondIt != BI) return false; + // Only allow this transformation if computing the condition doesn't involve + // too many instructions and these involved instructions can be executed + // unconditionally. 
We denote all involved instructions except the condition + // as "bonus instructions", and only allow this transformation when the + // number of the bonus instructions does not exceed a certain threshold. + unsigned NumBonusInsts = 0; + for (auto I = BB->begin(); Cond != I; ++I) { + // Ignore dbg intrinsics. + if (isa<DbgInfoIntrinsic>(I)) + continue; + if (!I->hasOneUse() || !isSafeToSpeculativelyExecute(I, DL)) + return false; + // I has only one use and can be executed unconditionally. + Instruction *User = dyn_cast<Instruction>(I->user_back()); + if (User == nullptr || User->getParent() != BB) + return false; + // I is used in the same BB. Since BI uses Cond and doesn't have more slots + // to use any other instruction, User must be an instruction between next(I) + // and Cond. + ++NumBonusInsts; + // Early exits once we reach the limit. + if (NumBonusInsts > BonusInstThreshold) + return false; + } + // Cond is known to be a compare or binary operator. Check to make sure that // neither operand is a potentially-trapping constant expression. if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Cond->getOperand(0))) @@ -2086,49 +2187,6 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) { continue; } - // Ensure that any values used in the bonus instruction are also used - // by the terminator of the predecessor. This means that those values - // must already have been resolved, so we won't be inhibiting the - // out-of-order core by speculating them earlier. We also allow - // instructions that are used by the terminator's condition because it - // exposes more merging opportunities. 
-  bool UsedByBranch = (BonusInst && BonusInst->hasOneUse() &&
-                       BonusInst->user_back() == Cond);
-
-  if (BonusInst && !UsedByBranch) {
-    // Collect the values used by the bonus inst
-    SmallPtrSet<Value*, 4> UsedValues;
-    for (Instruction::op_iterator OI = BonusInst->op_begin(),
-         OE = BonusInst->op_end(); OI != OE; ++OI) {
-      Value *V = *OI;
-      if (!isa<Constant>(V) && !isa<Argument>(V))
-        UsedValues.insert(V);
-    }
-
-    SmallVector<std::pair<Value*, unsigned>, 4> Worklist;
-    Worklist.push_back(std::make_pair(PBI->getOperand(0), 0));
-
-    // Walk up to four levels back up the use-def chain of the predecessor's
-    // terminator to see if all those values were used. The choice of four
-    // levels is arbitrary, to provide a compile-time-cost bound.
-    while (!Worklist.empty()) {
-      std::pair<Value*, unsigned> Pair = Worklist.back();
-      Worklist.pop_back();
-
-      if (Pair.second >= 4) continue;
-      UsedValues.erase(Pair.first);
-      if (UsedValues.empty()) break;
-
-      if (Instruction *I = dyn_cast<Instruction>(Pair.first)) {
-        for (Instruction::op_iterator OI = I->op_begin(), OE = I->op_end();
-             OI != OE; ++OI)
-          Worklist.push_back(std::make_pair(OI->get(), Pair.second+1));
-      }
-    }
-
-    if (!UsedValues.empty()) return false;
-  }
-
   DEBUG(dbgs() << "FOLDING BRANCH TO COMMON DEST:\n" << *PBI << *BB);
   IRBuilder<> Builder(PBI);
@@ -2148,30 +2206,41 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) {
     PBI->swapSuccessors();
   }
 
-  // If we have a bonus inst, clone it into the predecessor block.
-  Instruction *NewBonus = nullptr;
-  if (BonusInst) {
-    NewBonus = BonusInst->clone();
+  // If we have bonus instructions, clone them into the predecessor block.
+  // Note that there may be multiple predecessor blocks, so we cannot move
+  // bonus instructions to a predecessor block.
+  ValueToValueMapTy VMap; // maps original values to cloned values
+  // We already make sure Cond is the last instruction before BI.
Therefore,
+  // all instructions before Cond other than DbgInfoIntrinsic are bonus
+  // instructions.
+  for (auto BonusInst = BB->begin(); Cond != BonusInst; ++BonusInst) {
+    if (isa<DbgInfoIntrinsic>(BonusInst))
+      continue;
+    Instruction *NewBonusInst = BonusInst->clone();
+    RemapInstruction(NewBonusInst, VMap,
+                     RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
+    VMap[BonusInst] = NewBonusInst;
 
     // If we moved a load, we cannot any longer claim any knowledge about
     // its potential value. The previous information might have been valid
     // only given the branch precondition.
     // For an analogous reason, we must also drop all the metadata whose
     // semantics we don't understand.
-    NewBonus->dropUnknownMetadata(LLVMContext::MD_dbg);
+    NewBonusInst->dropUnknownMetadata(LLVMContext::MD_dbg);
 
-    PredBlock->getInstList().insert(PBI, NewBonus);
-    NewBonus->takeName(BonusInst);
-    BonusInst->setName(BonusInst->getName()+".old");
+    PredBlock->getInstList().insert(PBI, NewBonusInst);
+    NewBonusInst->takeName(BonusInst);
+    BonusInst->setName(BonusInst->getName() + ".old");
   }
 
   // Clone Cond into the predecessor basic block, and or/and the
   // two conditions together.
   Instruction *New = Cond->clone();
-  if (BonusInst) New->replaceUsesOfWith(BonusInst, NewBonus);
+  RemapInstruction(New, VMap,
+                   RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
   PredBlock->getInstList().insert(PBI, New);
   New->takeName(Cond);
-  Cond->setName(New->getName()+".old");
+  Cond->setName(New->getName() + ".old");
 
   if (BI->isConditional()) {
     Instruction *NewCond =
@@ -2649,7 +2718,7 @@ static bool SimplifyIndirectBrOnSelect(IndirectBrInst *IBI, SelectInst *SI) {
 /// the PHI, merging the third icmp into the switch.
static bool TryToSimplifyUncondBranchWithICmpInIt( ICmpInst *ICI, IRBuilder<> &Builder, const TargetTransformInfo &TTI, - const DataLayout *DL) { + unsigned BonusInstThreshold, const DataLayout *DL, AssumptionTracker *AT) { BasicBlock *BB = ICI->getParent(); // If the block has any PHIs in it or the icmp has multiple uses, it is too @@ -2682,7 +2751,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( ICI->eraseFromParent(); } // BB is now empty, so it is likely to simplify away. - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; } // Ok, the block is reachable from the default dest. If the constant we're @@ -2698,7 +2767,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( ICI->replaceAllUsesWith(V); ICI->eraseFromParent(); // BB is now empty, so it is likely to simplify away. - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; } // The use of the icmp has to be in the 'end' block, by the only PHI node in @@ -2759,24 +2828,17 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, const DataLayout *DL, Instruction *Cond = dyn_cast<Instruction>(BI->getCondition()); if (!Cond) return false; - // Change br (X == 0 | X == 1), T, F into a switch instruction. // If this is a bunch of seteq's or'd together, or if it's a bunch of // 'setne's and'ed together, collect them. 
- Value *CompVal = nullptr; - std::vector<ConstantInt*> Values; - bool TrueWhenEqual = true; - Value *ExtraCase = nullptr; - unsigned UsedICmps = 0; - - if (Cond->getOpcode() == Instruction::Or) { - CompVal = GatherConstantCompares(Cond, Values, ExtraCase, DL, true, - UsedICmps); - } else if (Cond->getOpcode() == Instruction::And) { - CompVal = GatherConstantCompares(Cond, Values, ExtraCase, DL, false, - UsedICmps); - TrueWhenEqual = false; - } + + // Try to gather values from a chain of and/or to be turned into a switch + ConstantComparesGatherer ConstantCompare(Cond, DL); + // Unpack the result + SmallVectorImpl<ConstantInt*> &Values = ConstantCompare.Vals; + Value *CompVal = ConstantCompare.CompValue; + unsigned UsedICmps = ConstantCompare.UsedICmps; + Value *ExtraCase = ConstantCompare.Extra; // If we didn't have a multiply compared value, fail. if (!CompVal) return false; @@ -2785,6 +2847,8 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, const DataLayout *DL, if (UsedICmps <= 1) return false; + bool TrueWhenEqual = (Cond->getOpcode() == Instruction::Or); + // There might be duplicate constants in the list, which the switch // instruction can't handle, remove them now. array_pod_sort(Values.begin(), Values.end(), ConstantIntSortPredicate); @@ -3208,11 +3272,12 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { /// EliminateDeadSwitchCases - Compute masked bits for the condition of a switch /// and use it to remove dead cases. -static bool EliminateDeadSwitchCases(SwitchInst *SI) { +static bool EliminateDeadSwitchCases(SwitchInst *SI, const DataLayout *DL, + AssumptionTracker *AT) { Value *Cond = SI->getCondition(); unsigned Bits = Cond->getType()->getIntegerBitWidth(); APInt KnownZero(Bits, 0), KnownOne(Bits, 0); - computeKnownBits(Cond, KnownZero, KnownOne); + computeKnownBits(Cond, KnownZero, KnownOne, DL, 0, AT, SI); // Gather dead cases. 
SmallVector<ConstantInt*, 8> DeadCases; @@ -3460,6 +3525,163 @@ GetCaseResults(SwitchInst *SI, return Res.size() > 0; } +// MapCaseToResult - Helper function used to +// add CaseVal to the list of cases that generate Result. +static void MapCaseToResult(ConstantInt *CaseVal, + SwitchCaseResultVectorTy &UniqueResults, + Constant *Result) { + for (auto &I : UniqueResults) { + if (I.first == Result) { + I.second.push_back(CaseVal); + return; + } + } + UniqueResults.push_back(std::make_pair(Result, + SmallVector<ConstantInt*, 4>(1, CaseVal))); +} + +// InitializeUniqueCases - Helper function that initializes a map containing +// results for the PHI node of the common destination block for a switch +// instruction. Returns false if multiple PHI nodes have been found or if +// there is not a common destination block for the switch. +static bool InitializeUniqueCases( + SwitchInst *SI, const DataLayout *DL, PHINode *&PHI, + BasicBlock *&CommonDest, + SwitchCaseResultVectorTy &UniqueResults, + Constant *&DefaultResult) { + for (auto &I : SI->cases()) { + ConstantInt *CaseVal = I.getCaseValue(); + + // Resulting value at phi nodes for this case value. + SwitchCaseResultsTy Results; + if (!GetCaseResults(SI, CaseVal, I.getCaseSuccessor(), &CommonDest, Results, + DL)) + return false; + + // Only one value per case is permitted + if (Results.size() > 1) + return false; + MapCaseToResult(CaseVal, UniqueResults, Results.begin()->second); + + // Check the PHI consistency. + if (!PHI) + PHI = Results[0].first; + else if (PHI != Results[0].first) + return false; + } + // Find the default result value. + SmallVector<std::pair<PHINode *, Constant *>, 1> DefaultResults; + BasicBlock *DefaultDest = SI->getDefaultDest(); + GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResults, + DL); + // If the default value is not found abort unless the default destination + // is unreachable. + DefaultResult = + DefaultResults.size() == 1 ? 
DefaultResults.begin()->second : nullptr; + if ((!DefaultResult && + !isa<UnreachableInst>(DefaultDest->getFirstNonPHIOrDbg()))) + return false; + + return true; +} + +// ConvertTwoCaseSwitch - Helper function that checks if it is possible to +// transform a switch with only two cases (or two cases + default) +// that produces a result into a value select. +// Example: +// switch (a) { +// case 10: %0 = icmp eq i32 %a, 10 +// return 10; %1 = select i1 %0, i32 10, i32 4 +// case 20: ----> %2 = icmp eq i32 %a, 20 +// return 2; %3 = select i1 %2, i32 2, i32 %1 +// default: +// return 4; +// } +static Value * +ConvertTwoCaseSwitch(const SwitchCaseResultVectorTy &ResultVector, + Constant *DefaultResult, Value *Condition, + IRBuilder<> &Builder) { + assert(ResultVector.size() == 2 && + "We should have exactly two unique results at this point"); + // If we are selecting between only two cases transform into a simple + // select or a two-way select if default is possible. + if (ResultVector[0].second.size() == 1 && + ResultVector[1].second.size() == 1) { + ConstantInt *const FirstCase = ResultVector[0].second[0]; + ConstantInt *const SecondCase = ResultVector[1].second[0]; + + bool DefaultCanTrigger = DefaultResult; + Value *SelectValue = ResultVector[1].first; + if (DefaultCanTrigger) { + Value *const ValueCompare = + Builder.CreateICmpEQ(Condition, SecondCase, "switch.selectcmp"); + SelectValue = Builder.CreateSelect(ValueCompare, ResultVector[1].first, + DefaultResult, "switch.select"); + } + Value *const ValueCompare = + Builder.CreateICmpEQ(Condition, FirstCase, "switch.selectcmp"); + return Builder.CreateSelect(ValueCompare, ResultVector[0].first, SelectValue, + "switch.select"); + } + + return nullptr; +} + +// RemoveSwitchAfterSelectConversion - Helper function to cleanup a switch +// instruction that has been converted into a select, fixing up PHI nodes and +// basic blocks. 
+static void RemoveSwitchAfterSelectConversion(SwitchInst *SI, PHINode *PHI, + Value *SelectValue, + IRBuilder<> &Builder) { + BasicBlock *SelectBB = SI->getParent(); + while (PHI->getBasicBlockIndex(SelectBB) >= 0) + PHI->removeIncomingValue(SelectBB); + PHI->addIncoming(SelectValue, SelectBB); + + Builder.CreateBr(PHI->getParent()); + + // Remove the switch. + for (unsigned i = 0, e = SI->getNumSuccessors(); i < e; ++i) { + BasicBlock *Succ = SI->getSuccessor(i); + + if (Succ == PHI->getParent()) + continue; + Succ->removePredecessor(SelectBB); + } + SI->eraseFromParent(); +} + +/// SwitchToSelect - If the switch is only used to initialize one or more +/// phi nodes in a common successor block with only two different +/// constant values, replace the switch with select. +static bool SwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder, + const DataLayout *DL, AssumptionTracker *AT) { + Value *const Cond = SI->getCondition(); + PHINode *PHI = nullptr; + BasicBlock *CommonDest = nullptr; + Constant *DefaultResult; + SwitchCaseResultVectorTy UniqueResults; + // Collect all the cases that will deliver the same value from the switch. + if (!InitializeUniqueCases(SI, DL, PHI, CommonDest, UniqueResults, + DefaultResult)) + return false; + // Selects choose between maximum two values. + if (UniqueResults.size() != 2) + return false; + assert(PHI != nullptr && "PHI for value select not found"); + + Builder.SetInsertPoint(SI); + Value *SelectValue = ConvertTwoCaseSwitch( + UniqueResults, + DefaultResult, Cond, Builder); + if (SelectValue) { + RemoveSwitchAfterSelectConversion(SI, PHI, SelectValue, Builder); + return true; + } + // The switch couldn't be converted into a select. + return false; +} + namespace { /// SwitchLookupTable - This class represents a lookup table that can be used /// to replace a switch. @@ -3493,6 +3715,11 @@ namespace { // store that single value and return it for each lookup. 
SingleValueKind, + // For tables where there is a linear relationship between table index + // and values. We calculate the result with a simple multiplication + // and addition instead of a table lookup. + LinearMapKind, + // For small tables with integer elements, we can pack them into a bitmap // that fits into a target-legal register. Values are retrieved by // shift and mask operations. @@ -3510,6 +3737,10 @@ namespace { ConstantInt *BitMap; IntegerType *BitMapElementTy; + // For LinearMapKind, these are the constants used to derive the value. + ConstantInt *LinearOffset; + ConstantInt *LinearMultiplier; + // For ArrayKind, this is the array. GlobalVariable *Array; }; @@ -3522,7 +3753,7 @@ SwitchLookupTable::SwitchLookupTable(Module &M, Constant *DefaultValue, const DataLayout *DL) : SingleValue(nullptr), BitMap(nullptr), BitMapElementTy(nullptr), - Array(nullptr) { + LinearOffset(nullptr), LinearMultiplier(nullptr), Array(nullptr) { assert(Values.size() && "Can't build lookup table without values!"); assert(TableSize >= Values.size() && "Can't fit values in table!"); @@ -3567,6 +3798,43 @@ SwitchLookupTable::SwitchLookupTable(Module &M, return; } + // Check if we can derive the value with a linear transformation from the + // table index. + if (isa<IntegerType>(ValueType)) { + bool LinearMappingPossible = true; + APInt PrevVal; + APInt DistToPrev; + assert(TableSize >= 2 && "Should be a SingleValue table."); + // Check if there is the same distance between two consecutive values. + for (uint64_t I = 0; I < TableSize; ++I) { + ConstantInt *ConstVal = dyn_cast<ConstantInt>(TableContents[I]); + if (!ConstVal) { + // This is an undef. We could deal with it, but undefs in lookup tables + // are very seldom. It's probably not worth the additional complexity. 
+ LinearMappingPossible = false; + break; + } + APInt Val = ConstVal->getValue(); + if (I != 0) { + APInt Dist = Val - PrevVal; + if (I == 1) { + DistToPrev = Dist; + } else if (Dist != DistToPrev) { + LinearMappingPossible = false; + break; + } + } + PrevVal = Val; + } + if (LinearMappingPossible) { + LinearOffset = cast<ConstantInt>(TableContents[0]); + LinearMultiplier = ConstantInt::get(M.getContext(), DistToPrev); + Kind = LinearMapKind; + ++NumLinearMaps; + return; + } + } + // If the type is integer and the table fits in a register, build a bitmap. if (WouldFitInRegister(DL, TableSize, ValueType)) { IntegerType *IT = cast<IntegerType>(ValueType); @@ -3602,6 +3870,16 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) { switch (Kind) { case SingleValueKind: return SingleValue; + case LinearMapKind: { + // Derive the result value from the input value. + Value *Result = Builder.CreateIntCast(Index, LinearMultiplier->getType(), + false, "switch.idx.cast"); + if (!LinearMultiplier->isOne()) + Result = Builder.CreateMul(Result, LinearMultiplier, "switch.idx.mult"); + if (!LinearOffset->isZero()) + Result = Builder.CreateAdd(Result, LinearOffset, "switch.offset"); + return Result; + } case BitMapKind: { // Type of the bitmap (e.g. i59). IntegerType *MapTy = BitMap->getType(); @@ -3624,6 +3902,16 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) { "switch.masked"); } case ArrayKind: { + // Make sure the table index will not overflow when treated as signed. 
+ IntegerType *IT = cast<IntegerType>(Index->getType()); + uint64_t TableSize = Array->getInitializer()->getType() + ->getArrayNumElements(); + if (TableSize > (1ULL << (IT->getBitWidth() - 1))) + Index = Builder.CreateZExt(Index, + IntegerType::get(IT->getContext(), + IT->getBitWidth() + 1), + "switch.tableidx.zext"); + Value *GEPIndices[] = { Builder.getInt32(0), Index }; Value *GEP = Builder.CreateInBoundsGEP(Array, GEPIndices, "switch.gep"); @@ -3663,9 +3951,8 @@ static bool ShouldBuildLookupTable(SwitchInst *SI, bool AllTablesFitInRegister = true; bool HasIllegalType = false; - for (SmallDenseMap<PHINode*, Type*>::const_iterator I = ResultTypes.begin(), - E = ResultTypes.end(); I != E; ++I) { - Type *Ty = I->second; + for (const auto &I : ResultTypes) { + Type *Ty = I.second; // Saturate this flag to true. HasIllegalType = HasIllegalType || !TTI.isTypeLegal(Ty); @@ -3749,16 +4036,17 @@ static bool SwitchToLookupTable(SwitchInst *SI, return false; // Append the result from this case to the list for each phi. - for (ResultsTy::iterator I = Results.begin(), E = Results.end(); I!=E; ++I) { - if (!ResultLists.count(I->first)) - PHIs.push_back(I->first); - ResultLists[I->first].push_back(std::make_pair(CaseVal, I->second)); + for (const auto &I : Results) { + PHINode *PHI = I.first; + Constant *Value = I.second; + if (!ResultLists.count(PHI)) + PHIs.push_back(PHI); + ResultLists[PHI].push_back(std::make_pair(CaseVal, Value)); } } // Keep track of the result types. - for (size_t I = 0, E = PHIs.size(); I != E; ++I) { - PHINode *PHI = PHIs[I]; + for (PHINode *PHI : PHIs) { ResultTypes[PHI] = ResultLists[PHI][0].second->getType(); } @@ -3775,6 +4063,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, HasDefaultResults = GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResultsList, DL); } + bool NeedMask = (TableHasHoles && !HasDefaultResults); if (NeedMask) { // As an extra penalty for the validity test we require more cases. 
@@ -3784,9 +4073,9 @@ static bool SwitchToLookupTable(SwitchInst *SI, return false; } - for (size_t I = 0, E = DefaultResultsList.size(); I != E; ++I) { - PHINode *PHI = DefaultResultsList[I].first; - Constant *Result = DefaultResultsList[I].second; + for (const auto &I : DefaultResultsList) { + PHINode *PHI = I.first; + Constant *Result = I.second; DefaultResults[PHI] = Result; } @@ -3820,10 +4109,13 @@ static bool SwitchToLookupTable(SwitchInst *SI, const bool GeneratingCoveredLookupTable = MaxTableSize == TableSize; if (GeneratingCoveredLookupTable) { Builder.CreateBr(LookupBB); - SI->getDefaultDest()->removePredecessor(SI->getParent()); + // We cached PHINodes in PHIs, to avoid accessing deleted PHINodes later, + // do not delete PHINodes here. + SI->getDefaultDest()->removePredecessor(SI->getParent(), + true/*DontDeleteUselessPHIs*/); } else { Value *Cmp = Builder.CreateICmpULT(TableIndex, ConstantInt::get( - MinCaseVal->getType(), TableSize)); + MinCaseVal->getType(), TableSize)); Builder.CreateCondBr(Cmp, LookupBB, SI->getDefaultDest()); } @@ -3841,9 +4133,12 @@ static bool SwitchToLookupTable(SwitchInst *SI, CommonDest->getParent(), CommonDest); + // Make the mask's bitwidth at least 8bit and a power-of-2 to avoid + // unnecessary illegal types. + uint64_t TableSizePowOf2 = NextPowerOf2(std::max(7ULL, TableSize - 1ULL)); + APInt MaskInt(TableSizePowOf2, 0); + APInt One(TableSizePowOf2, 1); // Build bitmask; fill in a 1 bit for every case. - APInt MaskInt(TableSize, 0); - APInt One(TableSize, 1); const ResultListTy &ResultList = ResultLists[PHIs[0]]; for (size_t I = 0, E = ResultList.size(); I != E; ++I) { uint64_t Idx = (ResultList[I].first->getValue() - @@ -3919,12 +4214,12 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { // see if that predecessor totally determines the outcome of this switch. 
if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) if (SimplifyEqualityComparisonWithOnlyPredecessor(SI, OnlyPred, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; Value *Cond = SI->getCondition(); if (SelectInst *Select = dyn_cast<SelectInst>(Cond)) if (SimplifySwitchOnSelect(SI, Select)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; // If the block only contains the switch, see if we can fold the block // away into any preds. @@ -3934,22 +4229,25 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { ++BBI; if (SI == &*BBI) if (FoldValueComparisonIntoPredecessors(SI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; } // Try to transform the switch into an icmp and a branch. if (TurnSwitchRangeIntoICmp(SI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; // Remove unreachable cases. 
- if (EliminateDeadSwitchCases(SI)) - return SimplifyCFG(BB, TTI, DL) | true; + if (EliminateDeadSwitchCases(SI, DL, AT)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; + + if (SwitchToSelect(SI, Builder, DL, AT)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; if (ForwardSwitchConditionToPHI(SI)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; if (SwitchToLookupTable(SI, Builder, TTI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; return false; } @@ -3962,7 +4260,7 @@ bool SimplifyCFGOpt::SimplifyIndirectBr(IndirectBrInst *IBI) { SmallPtrSet<Value *, 8> Succs; for (unsigned i = 0, e = IBI->getNumDestinations(); i != e; ++i) { BasicBlock *Dest = IBI->getDestination(i); - if (!Dest->hasAddressTaken() || !Succs.insert(Dest)) { + if (!Dest->hasAddressTaken() || !Succs.insert(Dest).second) { Dest->removePredecessor(BB); IBI->removeDestination(i); --i; --e; @@ -3986,7 +4284,7 @@ bool SimplifyCFGOpt::SimplifyIndirectBr(IndirectBrInst *IBI) { if (SelectInst *SI = dyn_cast<SelectInst>(IBI->getAddress())) { if (SimplifyIndirectBrOnSelect(IBI, SI)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; } return Changed; } @@ -3998,7 +4296,7 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ return true; // If the Terminator is the only non-phi instruction, simplify the block. 
- BasicBlock::iterator I = BB->getFirstNonPHIOrDbgOrLifetime(); + BasicBlock::iterator I = BB->getFirstNonPHIOrDbg(); if (I->isTerminator() && BB != &BB->getParent()->getEntryBlock() && TryToSimplifyUncondBranchFromEmptyBlock(BB)) return true; @@ -4010,7 +4308,8 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ for (++I; isa<DbgInfoIntrinsic>(I); ++I) ; if (I->isTerminator() && - TryToSimplifyUncondBranchWithICmpInIt(ICI, Builder, TTI, DL)) + TryToSimplifyUncondBranchWithICmpInIt(ICI, Builder, TTI, + BonusInstThreshold, DL, AT)) return true; } @@ -4018,8 +4317,8 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ // branches to us and our successor, fold the comparison into the // predecessor and use logical operations to update the incoming value // for PHI nodes in common successor. - if (FoldBranchToCommonDest(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + if (FoldBranchToCommonDest(BI, DL, BonusInstThreshold)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; return false; } @@ -4034,7 +4333,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // switch. if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) if (SimplifyEqualityComparisonWithOnlyPredecessor(BI, OnlyPred, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; // This block must be empty, except for the setcond inst, if it exists. // Ignore dbg intrinsics. @@ -4044,14 +4343,14 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { ++I; if (&*I == BI) { if (FoldValueComparisonIntoPredecessors(BI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; } else if (&*I == cast<Instruction>(BI->getCondition())){ ++I; // Ignore dbg intrinsics. 
while (isa<DbgInfoIntrinsic>(I)) ++I; if (&*I == BI && FoldValueComparisonIntoPredecessors(BI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; } } @@ -4062,8 +4361,8 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // If this basic block is ONLY a compare and a branch, and if a predecessor // branches to us and one of our successors, fold the comparison into the // predecessor and use logical operations to pick the right destination. - if (FoldBranchToCommonDest(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + if (FoldBranchToCommonDest(BI, DL, BonusInstThreshold)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; // We have a conditional branch to two blocks that are only reachable // from BI. We know that the condbr dominates the two blocks, so see if @@ -4072,7 +4371,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (BI->getSuccessor(0)->getSinglePredecessor()) { if (BI->getSuccessor(1)->getSinglePredecessor()) { if (HoistThenElseCodeToIf(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; } else { // If Successor #1 has multiple preds, we may be able to conditionally // execute Successor #0 if it branches to Successor #1. 
@@ -4080,7 +4379,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (Succ0TI->getNumSuccessors() == 1 && Succ0TI->getSuccessor(0) == BI->getSuccessor(1)) if (SpeculativelyExecuteBB(BI, BI->getSuccessor(0), DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; } } else if (BI->getSuccessor(1)->getSinglePredecessor()) { // If Successor #0 has multiple preds, we may be able to conditionally @@ -4089,7 +4388,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (Succ1TI->getNumSuccessors() == 1 && Succ1TI->getSuccessor(0) == BI->getSuccessor(0)) if (SpeculativelyExecuteBB(BI, BI->getSuccessor(1), DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; } // If this is a branch on a phi node in the current block, thread control @@ -4097,14 +4396,14 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (PHINode *PN = dyn_cast<PHINode>(BI->getCondition())) if (PN->getParent() == BI->getParent()) if (FoldCondBranchOnPHI(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; // Scan predecessor blocks for conditional branches. for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) if (BranchInst *PBI = dyn_cast<BranchInst>((*PI)->getTerminator())) if (PBI != BI && PBI->isConditional()) if (SimplifyCondBranchToCondBranch(PBI, BI)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AT) | true; return false; } @@ -4248,6 +4547,7 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { /// of the CFG. It returns true if a modification was made. 
/// bool llvm::SimplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, - const DataLayout *DL) { - return SimplifyCFGOpt(TTI, DL).run(BB); + unsigned BonusInstThreshold, + const DataLayout *DL, AssumptionTracker *AT) { + return SimplifyCFGOpt(TTI, BonusInstThreshold, DL, AT).run(BB); } diff --git a/lib/Transforms/Utils/SimplifyIndVar.cpp b/lib/Transforms/Utils/SimplifyIndVar.cpp index b284e6f..a4fdd55 100644 --- a/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -40,7 +40,7 @@ STATISTIC(NumElimRem , "Number of IV remainder operations eliminated"); STATISTIC(NumElimCmp , "Number of IV comparisons eliminated"); namespace { - /// SimplifyIndvar - This is a utility for simplifying induction variables + /// This is a utility for simplifying induction variables /// based on ScalarEvolution. It is the primary instrument of the /// IndvarSimplify pass, but it may also be directly invoked to cleanup after /// other loop passes that preserve SCEV. @@ -86,7 +86,7 @@ namespace { }; } -/// foldIVUser - Fold an IV operand into its use. This removes increments of an +/// Fold an IV operand into its use. This removes increments of an /// aligned IV when used by a instruction that ignores the low bits. /// /// IVOperand is guaranteed SCEVable, but UseInst may not be. @@ -152,7 +152,7 @@ Value *SimplifyIndvar::foldIVUser(Instruction *UseInst, Instruction *IVOperand) return IVSrc; } -/// eliminateIVComparison - SimplifyIVUsers helper for eliminating useless +/// SimplifyIVUsers helper for eliminating useless /// comparisons against an induction variable. 
void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { unsigned IVOperIdx = 0; @@ -188,7 +188,7 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { DeadInsts.push_back(ICmp); } -/// eliminateIVRemainder - SimplifyIVUsers helper for eliminating useless +/// SimplifyIVUsers helper for eliminating useless /// remainder operations operating on an induction variable. void SimplifyIndvar::eliminateIVRemainder(BinaryOperator *Rem, Value *IVOperand, @@ -239,7 +239,7 @@ void SimplifyIndvar::eliminateIVRemainder(BinaryOperator *Rem, DeadInsts.push_back(Rem); } -/// eliminateIVUser - Eliminate an operation that consumes a simple IV and has +/// Eliminate an operation that consumes a simple IV and has /// no observable side-effect given the range of IV values. /// IVOperand is guaranteed SCEVable, but UseInst may not be. bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst, @@ -334,8 +334,7 @@ Instruction *SimplifyIndvar::splitOverflowIntrinsic(Instruction *IVUser, return AddInst; } -/// pushIVUsers - Add all uses of Def to the current IV's worklist. -/// +/// Add all uses of Def to the current IV's worklist. static void pushIVUsers( Instruction *Def, SmallPtrSet<Instruction*,16> &Simplified, @@ -348,12 +347,12 @@ static void pushIVUsers( // Also ensure unique worklist users. // If Def is a LoopPhi, it may not be in the Simplified set, so check for // self edges first. - if (UI != Def && Simplified.insert(UI)) + if (UI != Def && Simplified.insert(UI).second) SimpleIVUsers.push_back(std::make_pair(UI, Def)); } } -/// isSimpleIVUser - Return true if this instruction generates a simple SCEV +/// Return true if this instruction generates a simple SCEV /// expression in terms of that IV. 
/// /// This is similar to IVUsers' isInteresting() but processes each instruction @@ -374,7 +373,7 @@ static bool isSimpleIVUser(Instruction *I, const Loop *L, ScalarEvolution *SE) { return false; } -/// simplifyUsers - Iteratively perform simplification on a worklist of users +/// Iteratively perform simplification on a worklist of users /// of the specified induction variable. Each successive simplification may push /// more users which may themselves be candidates for simplification. /// @@ -446,7 +445,7 @@ namespace llvm { void IVVisitor::anchor() { } -/// simplifyUsersOfIV - Simplify instructions that use this induction variable +/// Simplify instructions that use this induction variable /// by using ScalarEvolution to analyze the IV's recurrence. bool simplifyUsersOfIV(PHINode *CurrIV, ScalarEvolution *SE, LPPassManager *LPM, SmallVectorImpl<WeakVH> &Dead, IVVisitor *V) @@ -457,7 +456,7 @@ bool simplifyUsersOfIV(PHINode *CurrIV, ScalarEvolution *SE, LPPassManager *LPM, return SIV.hasChanged(); } -/// simplifyLoopIVs - Simplify users of induction variables within this +/// Simplify users of induction variables within this /// loop. This does not actually change or add IVs. 
bool simplifyLoopIVs(Loop *L, ScalarEvolution *SE, LPPassManager *LPM, SmallVectorImpl<WeakVH> &Dead) { diff --git a/lib/Transforms/Utils/SimplifyInstructions.cpp b/lib/Transforms/Utils/SimplifyInstructions.cpp index 33b3637..5632095 100644 --- a/lib/Transforms/Utils/SimplifyInstructions.cpp +++ b/lib/Transforms/Utils/SimplifyInstructions.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" @@ -41,6 +42,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequired<AssumptionTracker>(); AU.addRequired<TargetLibraryInfo>(); } @@ -52,6 +54,7 @@ namespace { DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? &DLP->getDataLayout() : nullptr; const TargetLibraryInfo *TLI = &getAnalysis<TargetLibraryInfo>(); + AssumptionTracker *AT = &getAnalysis<AssumptionTracker>(); SmallPtrSet<const Instruction*, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; bool Changed = false; @@ -68,7 +71,7 @@ namespace { continue; // Don't waste time simplifying unused instructions. if (!I->use_empty()) - if (Value *V = SimplifyInstruction(I, DL, TLI, DT)) { + if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AT)) { // Mark all uses for resimplification next time round the loop. 
for (User *U : I->users()) Next->insert(cast<Instruction>(U)); @@ -101,6 +104,7 @@ namespace { char InstSimplifier::ID = 0; INITIALIZE_PASS_BEGIN(InstSimplifier, "instsimplify", "Remove redundant instructions", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_END(InstSimplifier, "instsimplify", "Remove redundant instructions", false, false) diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index 3b61bb5..a39f128 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -27,65 +27,43 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Target/TargetLibraryInfo.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" using namespace llvm; +using namespace PatternMatch; static cl::opt<bool> -ColdErrorCalls("error-reporting-is-cold", cl::init(true), - cl::Hidden, cl::desc("Treat error-reporting calls as cold")); - -/// This class is the abstract base class for the set of optimizations that -/// corresponds to one library call. -namespace { -class LibCallOptimization { -protected: - Function *Caller; - const DataLayout *DL; - const TargetLibraryInfo *TLI; - const LibCallSimplifier *LCS; - LLVMContext* Context; -public: - LibCallOptimization() { } - virtual ~LibCallOptimization() {} - - /// callOptimizer - This pure virtual method is implemented by base classes to - /// do various optimizations. If this returns null then no transformation was - /// performed. If it returns CI, then it transformed the call and CI is to be - /// deleted. If it returns something else, replace CI with the new value and - /// delete CI. 
- virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) - =0; - - /// ignoreCallingConv - Returns false if this transformation could possibly - /// change the calling convention. - virtual bool ignoreCallingConv() { return false; } - - Value *optimizeCall(CallInst *CI, const DataLayout *DL, - const TargetLibraryInfo *TLI, - const LibCallSimplifier *LCS, IRBuilder<> &B) { - Caller = CI->getParent()->getParent(); - this->DL = DL; - this->TLI = TLI; - this->LCS = LCS; - if (CI->getCalledFunction()) - Context = &CI->getCalledFunction()->getContext(); + ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden, + cl::desc("Treat error-reporting calls as cold")); - // We never change the calling convention. - if (!ignoreCallingConv() && CI->getCallingConv() != llvm::CallingConv::C) - return nullptr; +static cl::opt<bool> + EnableUnsafeFPShrink("enable-double-float-shrink", cl::Hidden, + cl::init(false), + cl::desc("Enable unsafe double to float " + "shrinking for math lib calls")); - return callOptimizer(CI->getCalledFunction(), CI, B); - } -}; //===----------------------------------------------------------------------===// // Helper Functions //===----------------------------------------------------------------------===// +static bool ignoreCallingConv(LibFunc::Func Func) { + switch (Func) { + case LibFunc::abs: + case LibFunc::labs: + case LibFunc::llabs: + case LibFunc::strlen: + return true; + default: + return false; + } + llvm_unreachable("All cases should be covered in the switch."); +} + /// isOnlyUsedInZeroEqualityComparison - Return true if it only matters that the /// value is equal or not-equal to zero. 
static bool isOnlyUsedInZeroEqualityComparison(Value *V) { @@ -142,967 +120,912 @@ static bool hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, // Fortified Library Call Optimizations //===----------------------------------------------------------------------===// -struct FortifiedLibCallOptimization : public LibCallOptimization { -protected: - virtual bool isFoldable(unsigned SizeCIOp, unsigned SizeArgOp, - bool isString) const = 0; -}; - -struct InstFortifiedLibCallOptimization : public FortifiedLibCallOptimization { - CallInst *CI; - - bool isFoldable(unsigned SizeCIOp, unsigned SizeArgOp, - bool isString) const override { - if (CI->getArgOperand(SizeCIOp) == CI->getArgOperand(SizeArgOp)) +static bool isFortifiedCallFoldable(CallInst *CI, unsigned SizeCIOp, unsigned SizeArgOp, + bool isString) { + if (CI->getArgOperand(SizeCIOp) == CI->getArgOperand(SizeArgOp)) + return true; + if (ConstantInt *SizeCI = + dyn_cast<ConstantInt>(CI->getArgOperand(SizeCIOp))) { + if (SizeCI->isAllOnesValue()) return true; - if (ConstantInt *SizeCI = - dyn_cast<ConstantInt>(CI->getArgOperand(SizeCIOp))) { - if (SizeCI->isAllOnesValue()) - return true; - if (isString) { - uint64_t Len = GetStringLength(CI->getArgOperand(SizeArgOp)); - // If the length is 0 we don't know how long it is and so we can't - // remove the check. - if (Len == 0) return false; - return SizeCI->getZExtValue() >= Len; - } - if (ConstantInt *Arg = dyn_cast<ConstantInt>( - CI->getArgOperand(SizeArgOp))) - return SizeCI->getZExtValue() >= Arg->getZExtValue(); + if (isString) { + uint64_t Len = GetStringLength(CI->getArgOperand(SizeArgOp)); + // If the length is 0 we don't know how long it is and so we can't + // remove the check. 
+ if (Len == 0) + return false; + return SizeCI->getZExtValue() >= Len; } - return false; + if (ConstantInt *Arg = dyn_cast<ConstantInt>(CI->getArgOperand(SizeArgOp))) + return SizeCI->getZExtValue() >= Arg->getZExtValue(); } -}; - -struct MemCpyChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); - - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != DL->getIntPtrType(Context) || - FT->getParamType(3) != DL->getIntPtrType(Context)) - return nullptr; + return false; +} - if (isFoldable(3, 2, false)) { - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } +Value *LibCallSimplifier::optimizeMemCpyChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + LLVMContext &Context = CI->getContext(); + + // Check if this has the right signature. + if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + FT->getParamType(2) != DL->getIntPtrType(Context) || + FT->getParamType(3) != DL->getIntPtrType(Context)) return nullptr; - } -}; - -struct MemMoveChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); - // Check if this has the right signature. 
- if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != DL->getIntPtrType(Context) || - FT->getParamType(3) != DL->getIntPtrType(Context)) - return nullptr; - - if (isFoldable(3, 2, false)) { - B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } - return nullptr; + if (isFortifiedCallFoldable(CI, 3, 2, false)) { + B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), 1); + return CI->getArgOperand(0); } -}; - -struct MemSetChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); - - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy() || - FT->getParamType(2) != DL->getIntPtrType(Context) || - FT->getParamType(3) != DL->getIntPtrType(Context)) - return nullptr; + return nullptr; +} - if (isFoldable(3, 2, false)) { - Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), - false); - B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } +Value *LibCallSimplifier::optimizeMemMoveChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + LLVMContext &Context = CI->getContext(); + + // Check if this has the right signature. 
+ if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + FT->getParamType(2) != DL->getIntPtrType(Context) || + FT->getParamType(3) != DL->getIntPtrType(Context)) return nullptr; + + if (isFortifiedCallFoldable(CI, 3, 2, false)) { + B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), 1); + return CI->getArgOperand(0); } -}; + return nullptr; +} -struct StrCpyChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - StringRef Name = Callee->getName(); - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); +Value *LibCallSimplifier::optimizeMemSetChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + LLVMContext &Context = CI->getContext(); + + // Check if this has the right signature. + if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isIntegerTy() || + FT->getParamType(2) != DL->getIntPtrType(Context) || + FT->getParamType(3) != DL->getIntPtrType(Context)) + return nullptr; - // Check if this has the right signature. 
- if (FT->getNumParams() != 3 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - FT->getParamType(2) != DL->getIntPtrType(Context)) - return nullptr; + if (isFortifiedCallFoldable(CI, 3, 2, false)) { + Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); + B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); + return CI->getArgOperand(0); + } + return nullptr; +} - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); - if (Dst == Src) // __strcpy_chk(x,x) -> x - return Src; - - // If a) we don't have any length information, or b) we know this will - // fit then just lower to a plain strcpy. Otherwise we'll keep our - // strcpy_chk call which may fail at runtime if the size is too long. - // TODO: It might be nice to get a maximum length out of the possible - // string lengths for varying. - if (isFoldable(2, 1, true)) { - Value *Ret = EmitStrCpy(Dst, Src, B, DL, TLI, Name.substr(2, 6)); - return Ret; - } else { - // Maybe we can stil fold __strcpy_chk to __memcpy_chk. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; - - // This optimization require DataLayout. - if (!DL) return nullptr; - - Value *Ret = - EmitMemCpyChk(Dst, Src, - ConstantInt::get(DL->getIntPtrType(Context), Len), - CI->getArgOperand(2), B, DL, TLI); - return Ret; - } +Value *LibCallSimplifier::optimizeStrCpyChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + StringRef Name = Callee->getName(); + FunctionType *FT = Callee->getFunctionType(); + LLVMContext &Context = CI->getContext(); + + // Check if this has the right signature. 
+ if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != Type::getInt8PtrTy(Context) || + FT->getParamType(2) != DL->getIntPtrType(Context)) return nullptr; - } -}; -struct StpCpyChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - StringRef Name = Callee->getName(); - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); + Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); + if (Dst == Src) // __strcpy_chk(x,x) -> x + return Src; + + // If a) we don't have any length information, or b) we know this will + // fit then just lower to a plain strcpy. Otherwise we'll keep our + // strcpy_chk call which may fail at runtime if the size is too long. + // TODO: It might be nice to get a maximum length out of the possible + // string lengths for varying. + if (isFortifiedCallFoldable(CI, 2, 1, true)) { + Value *Ret = EmitStrCpy(Dst, Src, B, DL, TLI, Name.substr(2, 6)); + return Ret; + } else { + // Maybe we can stil fold __strcpy_chk to __memcpy_chk. + uint64_t Len = GetStringLength(Src); + if (Len == 0) + return nullptr; - // Check if this has the right signature. - if (FT->getNumParams() != 3 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - FT->getParamType(2) != DL->getIntPtrType(FT->getParamType(0))) + // This optimization require DataLayout. + if (!DL) return nullptr; - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); - if (Dst == Src) { // stpcpy(x,x) -> x+strlen(x) - Value *StrLen = EmitStrLen(Src, B, DL, TLI); - return StrLen ? 
B.CreateInBoundsGEP(Dst, StrLen) : nullptr; - } + Value *Ret = EmitMemCpyChk( + Dst, Src, ConstantInt::get(DL->getIntPtrType(Context), Len), + CI->getArgOperand(2), B, DL, TLI); + return Ret; + } + return nullptr; +} - // If a) we don't have any length information, or b) we know this will - // fit then just lower to a plain stpcpy. Otherwise we'll keep our - // stpcpy_chk call which may fail at runtime if the size is too long. - // TODO: It might be nice to get a maximum length out of the possible - // string lengths for varying. - if (isFoldable(2, 1, true)) { - Value *Ret = EmitStrCpy(Dst, Src, B, DL, TLI, Name.substr(2, 6)); - return Ret; - } else { - // Maybe we can stil fold __stpcpy_chk to __memcpy_chk. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; - - // This optimization require DataLayout. - if (!DL) return nullptr; - - Type *PT = FT->getParamType(0); - Value *LenV = ConstantInt::get(DL->getIntPtrType(PT), Len); - Value *DstEnd = B.CreateGEP(Dst, - ConstantInt::get(DL->getIntPtrType(PT), - Len - 1)); - if (!EmitMemCpyChk(Dst, Src, LenV, CI->getArgOperand(2), B, DL, TLI)) - return nullptr; - return DstEnd; - } +Value *LibCallSimplifier::optimizeStpCpyChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + StringRef Name = Callee->getName(); + FunctionType *FT = Callee->getFunctionType(); + LLVMContext &Context = CI->getContext(); + + // Check if this has the right signature. + if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != Type::getInt8PtrTy(Context) || + FT->getParamType(2) != DL->getIntPtrType(FT->getParamType(0))) return nullptr; + + Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); + if (Dst == Src) { // stpcpy(x,x) -> x+strlen(x) + Value *StrLen = EmitStrLen(Src, B, DL, TLI); + return StrLen ? 
B.CreateInBoundsGEP(Dst, StrLen) : nullptr; } -}; -struct StrNCpyChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - StringRef Name = Callee->getName(); - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); + // If a) we don't have any length information, or b) we know this will + // fit then just lower to a plain stpcpy. Otherwise we'll keep our + // stpcpy_chk call which may fail at runtime if the size is too long. + // TODO: It might be nice to get a maximum length out of the possible + // string lengths for varying. + if (isFortifiedCallFoldable(CI, 2, 1, true)) { + Value *Ret = EmitStrCpy(Dst, Src, B, DL, TLI, Name.substr(2, 6)); + return Ret; + } else { + // Maybe we can stil fold __stpcpy_chk to __memcpy_chk. + uint64_t Len = GetStringLength(Src); + if (Len == 0) + return nullptr; + + // This optimization require DataLayout. + if (!DL) + return nullptr; - // Check if this has the right signature. 
- if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - !FT->getParamType(2)->isIntegerTy() || - FT->getParamType(3) != DL->getIntPtrType(Context)) + Type *PT = FT->getParamType(0); + Value *LenV = ConstantInt::get(DL->getIntPtrType(PT), Len); + Value *DstEnd = + B.CreateGEP(Dst, ConstantInt::get(DL->getIntPtrType(PT), Len - 1)); + if (!EmitMemCpyChk(Dst, Src, LenV, CI->getArgOperand(2), B, DL, TLI)) return nullptr; + return DstEnd; + } + return nullptr; +} - if (isFoldable(3, 2, false)) { - Value *Ret = EmitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), B, DL, TLI, - Name.substr(2, 7)); - return Ret; - } +Value *LibCallSimplifier::optimizeStrNCpyChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + StringRef Name = Callee->getName(); + FunctionType *FT = Callee->getFunctionType(); + LLVMContext &Context = CI->getContext(); + + // Check if this has the right signature. + if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != Type::getInt8PtrTy(Context) || + !FT->getParamType(2)->isIntegerTy() || + FT->getParamType(3) != DL->getIntPtrType(Context)) return nullptr; + + if (isFortifiedCallFoldable(CI, 3, 2, false)) { + Value *Ret = + EmitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), B, DL, TLI, Name.substr(2, 7)); + return Ret; } -}; + return nullptr; +} //===----------------------------------------------------------------------===// // String and Memory Library Call Optimizations //===----------------------------------------------------------------------===// -struct StrCatOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strcat" function prototype. 
- FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - FT->getParamType(1) != FT->getReturnType()) - return nullptr; +Value *LibCallSimplifier::optimizeStrCat(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strcat" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2|| + FT->getReturnType() != B.getInt8PtrTy() || + FT->getParamType(0) != FT->getReturnType() || + FT->getParamType(1) != FT->getReturnType()) + return nullptr; - // Extract some information from the instruction - Value *Dst = CI->getArgOperand(0); - Value *Src = CI->getArgOperand(1); + // Extract some information from the instruction + Value *Dst = CI->getArgOperand(0); + Value *Src = CI->getArgOperand(1); - // See if we can get the length of the input string. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; - --Len; // Unbias length. + // See if we can get the length of the input string. + uint64_t Len = GetStringLength(Src); + if (Len == 0) + return nullptr; + --Len; // Unbias length. - // Handle the simple, do-nothing case: strcat(x, "") -> x - if (Len == 0) - return Dst; + // Handle the simple, do-nothing case: strcat(x, "") -> x + if (Len == 0) + return Dst; - // These optimizations require DataLayout. - if (!DL) return nullptr; + // These optimizations require DataLayout. + if (!DL) + return nullptr; - return emitStrLenMemCpy(Src, Dst, Len, B); - } + return emitStrLenMemCpy(Src, Dst, Len, B); +} - Value *emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len, - IRBuilder<> &B) { - // We need to find the end of the destination string. That's where the - // memory is to be moved to. We just generate a call to strlen. 
- Value *DstLen = EmitStrLen(Dst, B, DL, TLI); - if (!DstLen) - return nullptr; +Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len, + IRBuilder<> &B) { + // We need to find the end of the destination string. That's where the + // memory is to be moved to. We just generate a call to strlen. + Value *DstLen = EmitStrLen(Dst, B, DL, TLI); + if (!DstLen) + return nullptr; - // Now that we have the destination's length, we must index into the - // destination's pointer to get the actual memcpy destination (end of - // the string .. we're concatenating). - Value *CpyDst = B.CreateGEP(Dst, DstLen, "endptr"); + // Now that we have the destination's length, we must index into the + // destination's pointer to get the actual memcpy destination (end of + // the string .. we're concatenating). + Value *CpyDst = B.CreateGEP(Dst, DstLen, "endptr"); + + // We have enough information to now generate the memcpy call to do the + // concatenation for us. Make a memcpy to copy the nul byte with align = 1. + B.CreateMemCpy( + CpyDst, Src, + ConstantInt::get(DL->getIntPtrType(Src->getContext()), Len + 1), 1); + return Dst; +} - // We have enough information to now generate the memcpy call to do the - // concatenation for us. Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(CpyDst, Src, - ConstantInt::get(DL->getIntPtrType(*Context), Len + 1), 1); - return Dst; - } -}; - -struct StrNCatOpt : public StrCatOpt { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strncat" function prototype. 
- FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - FT->getParamType(1) != FT->getReturnType() || - !FT->getParamType(2)->isIntegerTy()) - return nullptr; +Value *LibCallSimplifier::optimizeStrNCat(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strncat" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || FT->getReturnType() != B.getInt8PtrTy() || + FT->getParamType(0) != FT->getReturnType() || + FT->getParamType(1) != FT->getReturnType() || + !FT->getParamType(2)->isIntegerTy()) + return nullptr; - // Extract some information from the instruction - Value *Dst = CI->getArgOperand(0); - Value *Src = CI->getArgOperand(1); - uint64_t Len; + // Extract some information from the instruction + Value *Dst = CI->getArgOperand(0); + Value *Src = CI->getArgOperand(1); + uint64_t Len; - // We don't do anything if length is not constant - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) - Len = LengthArg->getZExtValue(); - else - return nullptr; + // We don't do anything if length is not constant + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) + Len = LengthArg->getZExtValue(); + else + return nullptr; - // See if we can get the length of the input string. - uint64_t SrcLen = GetStringLength(Src); - if (SrcLen == 0) return nullptr; - --SrcLen; // Unbias length. + // See if we can get the length of the input string. + uint64_t SrcLen = GetStringLength(Src); + if (SrcLen == 0) + return nullptr; + --SrcLen; // Unbias length. 
- // Handle the simple, do-nothing cases: - // strncat(x, "", c) -> x - // strncat(x, c, 0) -> x - if (SrcLen == 0 || Len == 0) return Dst; + // Handle the simple, do-nothing cases: + // strncat(x, "", c) -> x + // strncat(x, c, 0) -> x + if (SrcLen == 0 || Len == 0) + return Dst; - // These optimizations require DataLayout. - if (!DL) return nullptr; - - // We don't optimize this case - if (Len < SrcLen) return nullptr; - - // strncat(x, s, c) -> strcat(x, s) - // s is constant so the strcat can be optimized further - return emitStrLenMemCpy(Src, Dst, SrcLen, B); - } -}; - -struct StrChrOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strchr" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - !FT->getParamType(1)->isIntegerTy(32)) - return nullptr; + // These optimizations require DataLayout. + if (!DL) + return nullptr; - Value *SrcStr = CI->getArgOperand(0); + // We don't optimize this case + if (Len < SrcLen) + return nullptr; - // If the second operand is non-constant, see if we can compute the length - // of the input string and turn this into memchr. - ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); - if (!CharC) { - // These optimizations require DataLayout. - if (!DL) return nullptr; + // strncat(x, s, c) -> strcat(x, s) + // s is constant so the strcat can be optimized further + return emitStrLenMemCpy(Src, Dst, SrcLen, B); +} - uint64_t Len = GetStringLength(SrcStr); - if (Len == 0 || !FT->getParamType(1)->isIntegerTy(32))// memchr needs i32. - return nullptr; +Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strchr" function prototype. 
+ FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getReturnType() != B.getInt8PtrTy() || + FT->getParamType(0) != FT->getReturnType() || + !FT->getParamType(1)->isIntegerTy(32)) + return nullptr; - return EmitMemChr(SrcStr, CI->getArgOperand(1), // include nul. - ConstantInt::get(DL->getIntPtrType(*Context), Len), - B, DL, TLI); - } + Value *SrcStr = CI->getArgOperand(0); - // Otherwise, the character is a constant, see if the first argument is - // a string literal. If so, we can constant fold. - StringRef Str; - if (!getConstantStringInfo(SrcStr, Str)) { - if (DL && CharC->isZero()) // strchr(p, 0) -> p + strlen(p) - return B.CreateGEP(SrcStr, EmitStrLen(SrcStr, B, DL, TLI), "strchr"); + // If the second operand is non-constant, see if we can compute the length + // of the input string and turn this into memchr. + ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + if (!CharC) { + // These optimizations require DataLayout. + if (!DL) return nullptr; - } - // Compute the offset, make sure to handle the case when we're searching for - // zero (a weird way to spell strlen). - size_t I = (0xFF & CharC->getSExtValue()) == 0 ? - Str.size() : Str.find(CharC->getSExtValue()); - if (I == StringRef::npos) // Didn't find the char. strchr returns null. - return Constant::getNullValue(CI->getType()); + uint64_t Len = GetStringLength(SrcStr); + if (Len == 0 || !FT->getParamType(1)->isIntegerTy(32)) // memchr needs i32. + return nullptr; - // strchr(s+n,c) -> gep(s+n+i,c) - return B.CreateGEP(SrcStr, B.getInt64(I), "strchr"); + return EmitMemChr( + SrcStr, CI->getArgOperand(1), // include nul. + ConstantInt::get(DL->getIntPtrType(CI->getContext()), Len), B, DL, TLI); } -}; -struct StrRChrOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strrchr" function prototype. 
- FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - !FT->getParamType(1)->isIntegerTy(32)) - return nullptr; + // Otherwise, the character is a constant, see if the first argument is + // a string literal. If so, we can constant fold. + StringRef Str; + if (!getConstantStringInfo(SrcStr, Str)) { + if (DL && CharC->isZero()) // strchr(p, 0) -> p + strlen(p) + return B.CreateGEP(SrcStr, EmitStrLen(SrcStr, B, DL, TLI), "strchr"); + return nullptr; + } - Value *SrcStr = CI->getArgOperand(0); - ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + // Compute the offset, make sure to handle the case when we're searching for + // zero (a weird way to spell strlen). + size_t I = (0xFF & CharC->getSExtValue()) == 0 + ? Str.size() + : Str.find(CharC->getSExtValue()); + if (I == StringRef::npos) // Didn't find the char. strchr returns null. + return Constant::getNullValue(CI->getType()); - // Cannot fold anything if we're not looking for a constant. - if (!CharC) - return nullptr; + // strchr(s+n,c) -> gep(s+n+i,c) + return B.CreateGEP(SrcStr, B.getInt64(I), "strchr"); +} - StringRef Str; - if (!getConstantStringInfo(SrcStr, Str)) { - // strrchr(s, 0) -> strchr(s, 0) - if (DL && CharC->isZero()) - return EmitStrChr(SrcStr, '\0', B, DL, TLI); - return nullptr; - } +Value *LibCallSimplifier::optimizeStrRChr(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strrchr" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getReturnType() != B.getInt8PtrTy() || + FT->getParamType(0) != FT->getReturnType() || + !FT->getParamType(1)->isIntegerTy(32)) + return nullptr; - // Compute the offset. - size_t I = (0xFF & CharC->getSExtValue()) == 0 ? - Str.size() : Str.rfind(CharC->getSExtValue()); - if (I == StringRef::npos) // Didn't find the char. Return null. 
- return Constant::getNullValue(CI->getType()); + Value *SrcStr = CI->getArgOperand(0); + ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); - // strrchr(s+n,c) -> gep(s+n+i,c) - return B.CreateGEP(SrcStr, B.getInt64(I), "strrchr"); + // Cannot fold anything if we're not looking for a constant. + if (!CharC) + return nullptr; + + StringRef Str; + if (!getConstantStringInfo(SrcStr, Str)) { + // strrchr(s, 0) -> strchr(s, 0) + if (DL && CharC->isZero()) + return EmitStrChr(SrcStr, '\0', B, DL, TLI); + return nullptr; } -}; -struct StrCmpOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strcmp" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - !FT->getReturnType()->isIntegerTy(32) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy()) - return nullptr; + // Compute the offset. + size_t I = (0xFF & CharC->getSExtValue()) == 0 + ? Str.size() + : Str.rfind(CharC->getSExtValue()); + if (I == StringRef::npos) // Didn't find the char. Return null. + return Constant::getNullValue(CI->getType()); - Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); - if (Str1P == Str2P) // strcmp(x,x) -> 0 - return ConstantInt::get(CI->getType(), 0); + // strrchr(s+n,c) -> gep(s+n+i,c) + return B.CreateGEP(SrcStr, B.getInt64(I), "strrchr"); +} - StringRef Str1, Str2; - bool HasStr1 = getConstantStringInfo(Str1P, Str1); - bool HasStr2 = getConstantStringInfo(Str2P, Str2); +Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strcmp" function prototype. 
+ FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getReturnType()->isIntegerTy(32) || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != B.getInt8PtrTy()) + return nullptr; - // strcmp(x, y) -> cnst (if both x and y are constant strings) - if (HasStr1 && HasStr2) - return ConstantInt::get(CI->getType(), Str1.compare(Str2)); + Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); + if (Str1P == Str2P) // strcmp(x,x) -> 0 + return ConstantInt::get(CI->getType(), 0); - if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x - return B.CreateNeg(B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), - CI->getType())); + StringRef Str1, Str2; + bool HasStr1 = getConstantStringInfo(Str1P, Str1); + bool HasStr2 = getConstantStringInfo(Str2P, Str2); - if (HasStr2 && Str2.empty()) // strcmp(x,"") -> *x - return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); + // strcmp(x, y) -> cnst (if both x and y are constant strings) + if (HasStr1 && HasStr2) + return ConstantInt::get(CI->getType(), Str1.compare(Str2)); - // strcmp(P, "x") -> memcmp(P, "x", 2) - uint64_t Len1 = GetStringLength(Str1P); - uint64_t Len2 = GetStringLength(Str2P); - if (Len1 && Len2) { - // These optimizations require DataLayout. - if (!DL) return nullptr; + if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x + return B.CreateNeg( + B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), CI->getType())); - return EmitMemCmp(Str1P, Str2P, - ConstantInt::get(DL->getIntPtrType(*Context), - std::min(Len1, Len2)), B, DL, TLI); - } + if (HasStr2 && Str2.empty()) // strcmp(x,"") -> *x + return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); - return nullptr; - } -}; - -struct StrNCmpOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strncmp" function prototype. 
- FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || - !FT->getReturnType()->isIntegerTy(32) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy() || - !FT->getParamType(2)->isIntegerTy()) + // strcmp(P, "x") -> memcmp(P, "x", 2) + uint64_t Len1 = GetStringLength(Str1P); + uint64_t Len2 = GetStringLength(Str2P); + if (Len1 && Len2) { + // These optimizations require DataLayout. + if (!DL) return nullptr; - Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); - if (Str1P == Str2P) // strncmp(x,x,n) -> 0 - return ConstantInt::get(CI->getType(), 0); + return EmitMemCmp(Str1P, Str2P, + ConstantInt::get(DL->getIntPtrType(CI->getContext()), + std::min(Len1, Len2)), + B, DL, TLI); + } - // Get the length argument if it is constant. - uint64_t Length; - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) - Length = LengthArg->getZExtValue(); - else - return nullptr; + return nullptr; +} - if (Length == 0) // strncmp(x,y,0) -> 0 - return ConstantInt::get(CI->getType(), 0); +Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strncmp" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || !FT->getReturnType()->isIntegerTy(32) || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != B.getInt8PtrTy() || + !FT->getParamType(2)->isIntegerTy()) + return nullptr; - if (DL && Length == 1) // strncmp(x,y,1) -> memcmp(x,y,1) - return EmitMemCmp(Str1P, Str2P, CI->getArgOperand(2), B, DL, TLI); + Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); + if (Str1P == Str2P) // strncmp(x,x,n) -> 0 + return ConstantInt::get(CI->getType(), 0); - StringRef Str1, Str2; - bool HasStr1 = getConstantStringInfo(Str1P, Str1); - bool HasStr2 = getConstantStringInfo(Str2P, Str2); + // Get the length argument if it is constant. 
+ uint64_t Length; + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) + Length = LengthArg->getZExtValue(); + else + return nullptr; - // strncmp(x, y) -> cnst (if both x and y are constant strings) - if (HasStr1 && HasStr2) { - StringRef SubStr1 = Str1.substr(0, Length); - StringRef SubStr2 = Str2.substr(0, Length); - return ConstantInt::get(CI->getType(), SubStr1.compare(SubStr2)); - } + if (Length == 0) // strncmp(x,y,0) -> 0 + return ConstantInt::get(CI->getType(), 0); - if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x - return B.CreateNeg(B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), - CI->getType())); + if (DL && Length == 1) // strncmp(x,y,1) -> memcmp(x,y,1) + return EmitMemCmp(Str1P, Str2P, CI->getArgOperand(2), B, DL, TLI); - if (HasStr2 && Str2.empty()) // strncmp(x, "", n) -> *x - return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); + StringRef Str1, Str2; + bool HasStr1 = getConstantStringInfo(Str1P, Str1); + bool HasStr2 = getConstantStringInfo(Str2P, Str2); - return nullptr; + // strncmp(x, y) -> cnst (if both x and y are constant strings) + if (HasStr1 && HasStr2) { + StringRef SubStr1 = Str1.substr(0, Length); + StringRef SubStr2 = Str2.substr(0, Length); + return ConstantInt::get(CI->getType(), SubStr1.compare(SubStr2)); } -}; -struct StrCpyOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strcpy" function prototype. 
- FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy()) - return nullptr; + if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x + return B.CreateNeg( + B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), CI->getType())); - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); - if (Dst == Src) // strcpy(x,x) -> x - return Src; + if (HasStr2 && Str2.empty()) // strncmp(x, "", n) -> *x + return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); - // These optimizations require DataLayout. - if (!DL) return nullptr; - - // See if we can get the length of the input string. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; - - // We have enough information to now generate the memcpy call to do the - // copy for us. Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(Dst, Src, - ConstantInt::get(DL->getIntPtrType(*Context), Len), 1); - return Dst; - } -}; - -struct StpCpyOpt: public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "stpcpy" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy()) - return nullptr; + return nullptr; +} - // These optimizations require DataLayout. - if (!DL) return nullptr; +Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strcpy" function prototype. 
+ FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != B.getInt8PtrTy()) + return nullptr; - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); - if (Dst == Src) { // stpcpy(x,x) -> x+strlen(x) - Value *StrLen = EmitStrLen(Src, B, DL, TLI); - return StrLen ? B.CreateInBoundsGEP(Dst, StrLen) : nullptr; - } + Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); + if (Dst == Src) // strcpy(x,x) -> x + return Src; - // See if we can get the length of the input string. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; + // These optimizations require DataLayout. + if (!DL) + return nullptr; - Type *PT = FT->getParamType(0); - Value *LenV = ConstantInt::get(DL->getIntPtrType(PT), Len); - Value *DstEnd = B.CreateGEP(Dst, - ConstantInt::get(DL->getIntPtrType(PT), - Len - 1)); + // See if we can get the length of the input string. + uint64_t Len = GetStringLength(Src); + if (Len == 0) + return nullptr; - // We have enough information to now generate the memcpy call to do the - // copy for us. Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(Dst, Src, LenV, 1); - return DstEnd; - } -}; + // We have enough information to now generate the memcpy call to do the + // copy for us. Make a memcpy to copy the nul byte with align = 1. 
+ B.CreateMemCpy(Dst, Src, + ConstantInt::get(DL->getIntPtrType(CI->getContext()), Len), 1); + return Dst; +} -struct StrNCpyOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy() || - !FT->getParamType(2)->isIntegerTy()) - return nullptr; +Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "stpcpy" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != B.getInt8PtrTy()) + return nullptr; - Value *Dst = CI->getArgOperand(0); - Value *Src = CI->getArgOperand(1); - Value *LenOp = CI->getArgOperand(2); + // These optimizations require DataLayout. + if (!DL) + return nullptr; - // See if we can get the length of the input string. - uint64_t SrcLen = GetStringLength(Src); - if (SrcLen == 0) return nullptr; - --SrcLen; + Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); + if (Dst == Src) { // stpcpy(x,x) -> x+strlen(x) + Value *StrLen = EmitStrLen(Src, B, DL, TLI); + return StrLen ? B.CreateInBoundsGEP(Dst, StrLen) : nullptr; + } - if (SrcLen == 0) { - // strncpy(x, "", y) -> memset(x, '\0', y, 1) - B.CreateMemSet(Dst, B.getInt8('\0'), LenOp, 1); - return Dst; - } + // See if we can get the length of the input string. 
+ uint64_t Len = GetStringLength(Src); + if (Len == 0) + return nullptr; - uint64_t Len; - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(LenOp)) - Len = LengthArg->getZExtValue(); - else - return nullptr; + Type *PT = FT->getParamType(0); + Value *LenV = ConstantInt::get(DL->getIntPtrType(PT), Len); + Value *DstEnd = + B.CreateGEP(Dst, ConstantInt::get(DL->getIntPtrType(PT), Len - 1)); - if (Len == 0) return Dst; // strncpy(x, y, 0) -> x + // We have enough information to now generate the memcpy call to do the + // copy for us. Make a memcpy to copy the nul byte with align = 1. + B.CreateMemCpy(Dst, Src, LenV, 1); + return DstEnd; +} - // These optimizations require DataLayout. - if (!DL) return nullptr; +Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != B.getInt8PtrTy() || + !FT->getParamType(2)->isIntegerTy()) + return nullptr; - // Let strncpy handle the zero padding - if (Len > SrcLen+1) return nullptr; + Value *Dst = CI->getArgOperand(0); + Value *Src = CI->getArgOperand(1); + Value *LenOp = CI->getArgOperand(2); - Type *PT = FT->getParamType(0); - // strncpy(x, s, c) -> memcpy(x, s, c, 1) [s and c are constant] - B.CreateMemCpy(Dst, Src, - ConstantInt::get(DL->getIntPtrType(PT), Len), 1); + // See if we can get the length of the input string. 
+ uint64_t SrcLen = GetStringLength(Src); + if (SrcLen == 0) + return nullptr; + --SrcLen; + if (SrcLen == 0) { + // strncpy(x, "", y) -> memset(x, '\0', y, 1) + B.CreateMemSet(Dst, B.getInt8('\0'), LenOp, 1); return Dst; } -}; - -struct StrLenOpt : public LibCallOptimization { - bool ignoreCallingConv() override { return true; } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 1 || - FT->getParamType(0) != B.getInt8PtrTy() || - !FT->getReturnType()->isIntegerTy()) - return nullptr; - Value *Src = CI->getArgOperand(0); - - // Constant folding: strlen("xyz") -> 3 - if (uint64_t Len = GetStringLength(Src)) - return ConstantInt::get(CI->getType(), Len-1); - - // strlen(x?"foo":"bars") --> x ? 3 : 4 - if (SelectInst *SI = dyn_cast<SelectInst>(Src)) { - uint64_t LenTrue = GetStringLength(SI->getTrueValue()); - uint64_t LenFalse = GetStringLength(SI->getFalseValue()); - if (LenTrue && LenFalse) { - emitOptimizationRemark(*Context, "simplify-libcalls", *Caller, - SI->getDebugLoc(), - "folded strlen(select) to select of constants"); - return B.CreateSelect(SI->getCondition(), - ConstantInt::get(CI->getType(), LenTrue-1), - ConstantInt::get(CI->getType(), LenFalse-1)); - } - } + uint64_t Len; + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(LenOp)) + Len = LengthArg->getZExtValue(); + else + return nullptr; - // strlen(x) != 0 --> *x != 0 - // strlen(x) == 0 --> *x == 0 - if (isOnlyUsedInZeroEqualityComparison(CI)) - return B.CreateZExt(B.CreateLoad(Src, "strlenfirst"), CI->getType()); + if (Len == 0) + return Dst; // strncpy(x, y, 0) -> x + // These optimizations require DataLayout. 
+ if (!DL) return nullptr; - } -}; -struct StrPBrkOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getParamType(0) != B.getInt8PtrTy() || - FT->getParamType(1) != FT->getParamType(0) || - FT->getReturnType() != FT->getParamType(0)) - return nullptr; + // Let strncpy handle the zero padding + if (Len > SrcLen + 1) + return nullptr; - StringRef S1, S2; - bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); - bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); + Type *PT = FT->getParamType(0); + // strncpy(x, s, c) -> memcpy(x, s, c, 1) [s and c are constant] + B.CreateMemCpy(Dst, Src, ConstantInt::get(DL->getIntPtrType(PT), Len), 1); - // strpbrk(s, "") -> NULL - // strpbrk("", s) -> NULL - if ((HasS1 && S1.empty()) || (HasS2 && S2.empty())) - return Constant::getNullValue(CI->getType()); + return Dst; +} - // Constant folding. - if (HasS1 && HasS2) { - size_t I = S1.find_first_of(S2); - if (I == StringRef::npos) // No match. - return Constant::getNullValue(CI->getType()); +Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 1 || FT->getParamType(0) != B.getInt8PtrTy() || + !FT->getReturnType()->isIntegerTy()) + return nullptr; - return B.CreateGEP(CI->getArgOperand(0), B.getInt64(I), "strpbrk"); + Value *Src = CI->getArgOperand(0); + + // Constant folding: strlen("xyz") -> 3 + if (uint64_t Len = GetStringLength(Src)) + return ConstantInt::get(CI->getType(), Len - 1); + + // strlen(x?"foo":"bars") --> x ? 
3 : 4 + if (SelectInst *SI = dyn_cast<SelectInst>(Src)) { + uint64_t LenTrue = GetStringLength(SI->getTrueValue()); + uint64_t LenFalse = GetStringLength(SI->getFalseValue()); + if (LenTrue && LenFalse) { + Function *Caller = CI->getParent()->getParent(); + emitOptimizationRemark(CI->getContext(), "simplify-libcalls", *Caller, + SI->getDebugLoc(), + "folded strlen(select) to select of constants"); + return B.CreateSelect(SI->getCondition(), + ConstantInt::get(CI->getType(), LenTrue - 1), + ConstantInt::get(CI->getType(), LenFalse - 1)); } - - // strpbrk(s, "a") -> strchr(s, 'a') - if (DL && HasS2 && S2.size() == 1) - return EmitStrChr(CI->getArgOperand(0), S2[0], B, DL, TLI); - - return nullptr; } -}; -struct StrToOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if ((FT->getNumParams() != 2 && FT->getNumParams() != 3) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy()) - return nullptr; + // strlen(x) != 0 --> *x != 0 + // strlen(x) == 0 --> *x == 0 + if (isOnlyUsedInZeroEqualityComparison(CI)) + return B.CreateZExt(B.CreateLoad(Src, "strlenfirst"), CI->getType()); - Value *EndPtr = CI->getArgOperand(1); - if (isa<ConstantPointerNull>(EndPtr)) { - // With a null EndPtr, this function won't capture the main argument. - // It would be readonly too, except that it still may write to errno. 
- CI->addAttribute(1, Attribute::NoCapture); - } + return nullptr; +} +Value *LibCallSimplifier::optimizeStrPBrk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getParamType(0) != B.getInt8PtrTy() || + FT->getParamType(1) != FT->getParamType(0) || + FT->getReturnType() != FT->getParamType(0)) return nullptr; - } -}; -struct StrSpnOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getParamType(0) != B.getInt8PtrTy() || - FT->getParamType(1) != FT->getParamType(0) || - !FT->getReturnType()->isIntegerTy()) - return nullptr; + StringRef S1, S2; + bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); + bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); - StringRef S1, S2; - bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); - bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); + // strpbrk(s, "") -> nullptr + // strpbrk("", s) -> nullptr + if ((HasS1 && S1.empty()) || (HasS2 && S2.empty())) + return Constant::getNullValue(CI->getType()); - // strspn(s, "") -> 0 - // strspn("", s) -> 0 - if ((HasS1 && S1.empty()) || (HasS2 && S2.empty())) + // Constant folding. + if (HasS1 && HasS2) { + size_t I = S1.find_first_of(S2); + if (I == StringRef::npos) // No match. return Constant::getNullValue(CI->getType()); - // Constant folding. 
- if (HasS1 && HasS2) { - size_t Pos = S1.find_first_not_of(S2); - if (Pos == StringRef::npos) Pos = S1.size(); - return ConstantInt::get(CI->getType(), Pos); - } - - return nullptr; + return B.CreateGEP(CI->getArgOperand(0), B.getInt64(I), "strpbrk"); } -}; -struct StrCSpnOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getParamType(0) != B.getInt8PtrTy() || - FT->getParamType(1) != FT->getParamType(0) || - !FT->getReturnType()->isIntegerTy()) - return nullptr; + // strpbrk(s, "a") -> strchr(s, 'a') + if (DL && HasS2 && S2.size() == 1) + return EmitStrChr(CI->getArgOperand(0), S2[0], B, DL, TLI); - StringRef S1, S2; - bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); - bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); + return nullptr; +} - // strcspn("", s) -> 0 - if (HasS1 && S1.empty()) - return Constant::getNullValue(CI->getType()); +Value *LibCallSimplifier::optimizeStrTo(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if ((FT->getNumParams() != 2 && FT->getNumParams() != 3) || + !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy()) + return nullptr; - // Constant folding. - if (HasS1 && HasS2) { - size_t Pos = S1.find_first_of(S2); - if (Pos == StringRef::npos) Pos = S1.size(); - return ConstantInt::get(CI->getType(), Pos); - } + Value *EndPtr = CI->getArgOperand(1); + if (isa<ConstantPointerNull>(EndPtr)) { + // With a null EndPtr, this function won't capture the main argument. + // It would be readonly too, except that it still may write to errno. 
+ CI->addAttribute(1, Attribute::NoCapture); + } - // strcspn(s, "") -> strlen(s) - if (DL && HasS2 && S2.empty()) - return EmitStrLen(CI->getArgOperand(0), B, DL, TLI); + return nullptr; +} +Value *LibCallSimplifier::optimizeStrSpn(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getParamType(0) != B.getInt8PtrTy() || + FT->getParamType(1) != FT->getParamType(0) || + !FT->getReturnType()->isIntegerTy()) return nullptr; + + StringRef S1, S2; + bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); + bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); + + // strspn(s, "") -> 0 + // strspn("", s) -> 0 + if ((HasS1 && S1.empty()) || (HasS2 && S2.empty())) + return Constant::getNullValue(CI->getType()); + + // Constant folding. + if (HasS1 && HasS2) { + size_t Pos = S1.find_first_not_of(S2); + if (Pos == StringRef::npos) + Pos = S1.size(); + return ConstantInt::get(CI->getType(), Pos); } -}; -struct StrStrOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isPointerTy()) - return nullptr; + return nullptr; +} - // fold strstr(x, x) -> x. 
- if (CI->getArgOperand(0) == CI->getArgOperand(1)) - return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); +Value *LibCallSimplifier::optimizeStrCSpn(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getParamType(0) != B.getInt8PtrTy() || + FT->getParamType(1) != FT->getParamType(0) || + !FT->getReturnType()->isIntegerTy()) + return nullptr; - // fold strstr(a, b) == a -> strncmp(a, b, strlen(b)) == 0 - if (DL && isOnlyUsedInEqualityComparison(CI, CI->getArgOperand(0))) { - Value *StrLen = EmitStrLen(CI->getArgOperand(1), B, DL, TLI); - if (!StrLen) - return nullptr; - Value *StrNCmp = EmitStrNCmp(CI->getArgOperand(0), CI->getArgOperand(1), - StrLen, B, DL, TLI); - if (!StrNCmp) - return nullptr; - for (auto UI = CI->user_begin(), UE = CI->user_end(); UI != UE;) { - ICmpInst *Old = cast<ICmpInst>(*UI++); - Value *Cmp = B.CreateICmp(Old->getPredicate(), StrNCmp, - ConstantInt::getNullValue(StrNCmp->getType()), - "cmp"); - LCS->replaceAllUsesWith(Old, Cmp); - } - return CI; - } + StringRef S1, S2; + bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); + bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); - // See if either input string is a constant string. - StringRef SearchStr, ToFindStr; - bool HasStr1 = getConstantStringInfo(CI->getArgOperand(0), SearchStr); - bool HasStr2 = getConstantStringInfo(CI->getArgOperand(1), ToFindStr); + // strcspn("", s) -> 0 + if (HasS1 && S1.empty()) + return Constant::getNullValue(CI->getType()); - // fold strstr(x, "") -> x. - if (HasStr2 && ToFindStr.empty()) - return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); + // Constant folding. + if (HasS1 && HasS2) { + size_t Pos = S1.find_first_of(S2); + if (Pos == StringRef::npos) + Pos = S1.size(); + return ConstantInt::get(CI->getType(), Pos); + } - // If both strings are known, constant fold it. 
- if (HasStr1 && HasStr2) { - size_t Offset = SearchStr.find(ToFindStr); + // strcspn(s, "") -> strlen(s) + if (DL && HasS2 && S2.empty()) + return EmitStrLen(CI->getArgOperand(0), B, DL, TLI); - if (Offset == StringRef::npos) // strstr("foo", "bar") -> null - return Constant::getNullValue(CI->getType()); + return nullptr; +} - // strstr("abcd", "bc") -> gep((char*)"abcd", 1) - Value *Result = CastToCStr(CI->getArgOperand(0), B); - Result = B.CreateConstInBoundsGEP1_64(Result, Offset, "strstr"); - return B.CreateBitCast(Result, CI->getType()); - } +Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + !FT->getReturnType()->isPointerTy()) + return nullptr; + + // fold strstr(x, x) -> x. + if (CI->getArgOperand(0) == CI->getArgOperand(1)) + return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); - // fold strstr(x, "y") -> strchr(x, 'y'). - if (HasStr2 && ToFindStr.size() == 1) { - Value *StrChr= EmitStrChr(CI->getArgOperand(0), ToFindStr[0], B, DL, TLI); - return StrChr ? 
B.CreateBitCast(StrChr, CI->getType()) : nullptr; + // fold strstr(a, b) == a -> strncmp(a, b, strlen(b)) == 0 + if (DL && isOnlyUsedInEqualityComparison(CI, CI->getArgOperand(0))) { + Value *StrLen = EmitStrLen(CI->getArgOperand(1), B, DL, TLI); + if (!StrLen) + return nullptr; + Value *StrNCmp = EmitStrNCmp(CI->getArgOperand(0), CI->getArgOperand(1), + StrLen, B, DL, TLI); + if (!StrNCmp) + return nullptr; + for (auto UI = CI->user_begin(), UE = CI->user_end(); UI != UE;) { + ICmpInst *Old = cast<ICmpInst>(*UI++); + Value *Cmp = + B.CreateICmp(Old->getPredicate(), StrNCmp, + ConstantInt::getNullValue(StrNCmp->getType()), "cmp"); + replaceAllUsesWith(Old, Cmp); } - return nullptr; + return CI; } -}; -struct MemCmpOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isIntegerTy(32)) - return nullptr; + // See if either input string is a constant string. + StringRef SearchStr, ToFindStr; + bool HasStr1 = getConstantStringInfo(CI->getArgOperand(0), SearchStr); + bool HasStr2 = getConstantStringInfo(CI->getArgOperand(1), ToFindStr); + + // fold strstr(x, "") -> x. + if (HasStr2 && ToFindStr.empty()) + return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); - Value *LHS = CI->getArgOperand(0), *RHS = CI->getArgOperand(1); + // If both strings are known, constant fold it. + if (HasStr1 && HasStr2) { + size_t Offset = SearchStr.find(ToFindStr); - if (LHS == RHS) // memcmp(s,s,x) -> 0 + if (Offset == StringRef::npos) // strstr("foo", "bar") -> null return Constant::getNullValue(CI->getType()); - // Make sure we have a constant length. 
- ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); - if (!LenC) return nullptr; - uint64_t Len = LenC->getZExtValue(); + // strstr("abcd", "bc") -> gep((char*)"abcd", 1) + Value *Result = CastToCStr(CI->getArgOperand(0), B); + Result = B.CreateConstInBoundsGEP1_64(Result, Offset, "strstr"); + return B.CreateBitCast(Result, CI->getType()); + } - if (Len == 0) // memcmp(s1,s2,0) -> 0 - return Constant::getNullValue(CI->getType()); + // fold strstr(x, "y") -> strchr(x, 'y'). + if (HasStr2 && ToFindStr.size() == 1) { + Value *StrChr = EmitStrChr(CI->getArgOperand(0), ToFindStr[0], B, DL, TLI); + return StrChr ? B.CreateBitCast(StrChr, CI->getType()) : nullptr; + } + return nullptr; +} - // memcmp(S1,S2,1) -> *(unsigned char*)LHS - *(unsigned char*)RHS - if (Len == 1) { - Value *LHSV = B.CreateZExt(B.CreateLoad(CastToCStr(LHS, B), "lhsc"), - CI->getType(), "lhsv"); - Value *RHSV = B.CreateZExt(B.CreateLoad(CastToCStr(RHS, B), "rhsc"), - CI->getType(), "rhsv"); - return B.CreateSub(LHSV, RHSV, "chardiff"); - } +Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + !FT->getReturnType()->isIntegerTy(32)) + return nullptr; - // Constant folding: memcmp(x, y, l) -> cnst (all arguments are constant) - StringRef LHSStr, RHSStr; - if (getConstantStringInfo(LHS, LHSStr) && - getConstantStringInfo(RHS, RHSStr)) { - // Make sure we're not reading out-of-bounds memory. - if (Len > LHSStr.size() || Len > RHSStr.size()) - return nullptr; - // Fold the memcmp and normalize the result. This way we get consistent - // results across multiple platforms. 
- uint64_t Ret = 0; - int Cmp = memcmp(LHSStr.data(), RHSStr.data(), Len); - if (Cmp < 0) - Ret = -1; - else if (Cmp > 0) - Ret = 1; - return ConstantInt::get(CI->getType(), Ret); - } + Value *LHS = CI->getArgOperand(0), *RHS = CI->getArgOperand(1); + + if (LHS == RHS) // memcmp(s,s,x) -> 0 + return Constant::getNullValue(CI->getType()); + // Make sure we have a constant length. + ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); + if (!LenC) return nullptr; + uint64_t Len = LenC->getZExtValue(); + + if (Len == 0) // memcmp(s1,s2,0) -> 0 + return Constant::getNullValue(CI->getType()); + + // memcmp(S1,S2,1) -> *(unsigned char*)LHS - *(unsigned char*)RHS + if (Len == 1) { + Value *LHSV = B.CreateZExt(B.CreateLoad(CastToCStr(LHS, B), "lhsc"), + CI->getType(), "lhsv"); + Value *RHSV = B.CreateZExt(B.CreateLoad(CastToCStr(RHS, B), "rhsc"), + CI->getType(), "rhsv"); + return B.CreateSub(LHSV, RHSV, "chardiff"); + } + + // Constant folding: memcmp(x, y, l) -> cnst (all arguments are constant) + StringRef LHSStr, RHSStr; + if (getConstantStringInfo(LHS, LHSStr) && + getConstantStringInfo(RHS, RHSStr)) { + // Make sure we're not reading out-of-bounds memory. + if (Len > LHSStr.size() || Len > RHSStr.size()) + return nullptr; + // Fold the memcmp and normalize the result. This way we get consistent + // results across multiple platforms. + uint64_t Ret = 0; + int Cmp = memcmp(LHSStr.data(), RHSStr.data(), Len); + if (Cmp < 0) + Ret = -1; + else if (Cmp > 0) + Ret = 1; + return ConstantInt::get(CI->getType(), Ret); } -}; -struct MemCpyOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // These optimizations require DataLayout. 
- if (!DL) return nullptr; + return nullptr; +} - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != DL->getIntPtrType(*Context)) - return nullptr; +Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // These optimizations require DataLayout. + if (!DL) + return nullptr; - // memcpy(x, y, n) -> llvm.memcpy(x, y, n, 1) - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } -}; + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + FT->getParamType(2) != DL->getIntPtrType(CI->getContext())) + return nullptr; -struct MemMoveOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // These optimizations require DataLayout. - if (!DL) return nullptr; + // memcpy(x, y, n) -> llvm.memcpy(x, y, n, 1) + B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), 1); + return CI->getArgOperand(0); +} - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != DL->getIntPtrType(*Context)) - return nullptr; +Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // These optimizations require DataLayout. 
+ if (!DL) + return nullptr; - // memmove(x, y, n) -> llvm.memmove(x, y, n, 1) - B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } -}; + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + FT->getParamType(2) != DL->getIntPtrType(CI->getContext())) + return nullptr; -struct MemSetOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // These optimizations require DataLayout. - if (!DL) return nullptr; + // memmove(x, y, n) -> llvm.memmove(x, y, n, 1) + B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), 1); + return CI->getArgOperand(0); +} - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy() || - FT->getParamType(2) != DL->getIntPtrType(FT->getParamType(0))) - return nullptr; +Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // These optimizations require DataLayout. 
+ if (!DL) + return nullptr; - // memset(p, v, n) -> llvm.memset(p, v, n, 1) - Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); - B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } -}; + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isIntegerTy() || + FT->getParamType(2) != DL->getIntPtrType(FT->getParamType(0))) + return nullptr; + + // memset(p, v, n) -> llvm.memset(p, v, n, 1) + Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); + B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); + return CI->getArgOperand(0); +} //===----------------------------------------------------------------------===// // Math Library Optimizations @@ -1111,935 +1034,959 @@ struct MemSetOpt : public LibCallOptimization { //===----------------------------------------------------------------------===// // Double -> Float Shrinking Optimizations for Unary Functions like 'floor' -struct UnaryDoubleFPOpt : public LibCallOptimization { - bool CheckRetType; - UnaryDoubleFPOpt(bool CheckReturnType): CheckRetType(CheckReturnType) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 1 || !FT->getReturnType()->isDoubleTy() || - !FT->getParamType(0)->isDoubleTy()) - return nullptr; +Value *LibCallSimplifier::optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, + bool CheckRetType) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 1 || !FT->getReturnType()->isDoubleTy() || + !FT->getParamType(0)->isDoubleTy()) + return nullptr; - if (CheckRetType) { - // Check if all the uses for function like 'sin' are converted to float. 
- for (User *U : CI->users()) { - FPTruncInst *Cast = dyn_cast<FPTruncInst>(U); - if (!Cast || !Cast->getType()->isFloatTy()) - return nullptr; - } + if (CheckRetType) { + // Check if all the uses for function like 'sin' are converted to float. + for (User *U : CI->users()) { + FPTruncInst *Cast = dyn_cast<FPTruncInst>(U); + if (!Cast || !Cast->getType()->isFloatTy()) + return nullptr; } + } - // If this is something like 'floor((double)floatval)', convert to floorf. - FPExtInst *Cast = dyn_cast<FPExtInst>(CI->getArgOperand(0)); - if (!Cast || !Cast->getOperand(0)->getType()->isFloatTy()) - return nullptr; + // If this is something like 'floor((double)floatval)', convert to floorf. + FPExtInst *Cast = dyn_cast<FPExtInst>(CI->getArgOperand(0)); + if (!Cast || !Cast->getOperand(0)->getType()->isFloatTy()) + return nullptr; - // floor((double)floatval) -> (double)floorf(floatval) - Value *V = Cast->getOperand(0); + // floor((double)floatval) -> (double)floorf(floatval) + Value *V = Cast->getOperand(0); + if (Callee->isIntrinsic()) { + Module *M = CI->getParent()->getParent()->getParent(); + Intrinsic::ID IID = (Intrinsic::ID) Callee->getIntrinsicID(); + Function *F = Intrinsic::getDeclaration(M, IID, B.getFloatTy()); + V = B.CreateCall(F, V); + } else { + // The call is a library call rather than an intrinsic. V = EmitUnaryFloatFnCall(V, Callee->getName(), B, Callee->getAttributes()); - return B.CreateFPExt(V, B.getDoubleTy()); } -}; + + return B.CreateFPExt(V, B.getDoubleTy()); +} // Double -> Float Shrinking Optimizations for Binary Functions like 'fmin/fmax' -struct BinaryDoubleFPOpt : public LibCallOptimization { - bool CheckRetType; - BinaryDoubleFPOpt(bool CheckReturnType): CheckRetType(CheckReturnType) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 2 arguments of the same FP type, which match the - // result type. 
- if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - !FT->getParamType(0)->isFloatingPointTy()) - return nullptr; +Value *LibCallSimplifier::optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 2 arguments of the same FP type, which match the + // result type. + if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + !FT->getParamType(0)->isFloatingPointTy()) + return nullptr; - if (CheckRetType) { - // Check if all the uses for function like 'fmin/fmax' are converted to - // float. - for (User *U : CI->users()) { - FPTruncInst *Cast = dyn_cast<FPTruncInst>(U); - if (!Cast || !Cast->getType()->isFloatTy()) - return nullptr; - } - } + // If this is something like 'fmin((double)floatval1, (double)floatval2)', + // we convert it to fminf. + FPExtInst *Cast1 = dyn_cast<FPExtInst>(CI->getArgOperand(0)); + FPExtInst *Cast2 = dyn_cast<FPExtInst>(CI->getArgOperand(1)); + if (!Cast1 || !Cast1->getOperand(0)->getType()->isFloatTy() || !Cast2 || + !Cast2->getOperand(0)->getType()->isFloatTy()) + return nullptr; - // If this is something like 'fmin((double)floatval1, (double)floatval2)', - // we convert it to fminf. - FPExtInst *Cast1 = dyn_cast<FPExtInst>(CI->getArgOperand(0)); - FPExtInst *Cast2 = dyn_cast<FPExtInst>(CI->getArgOperand(1)); - if (!Cast1 || !Cast1->getOperand(0)->getType()->isFloatTy() || - !Cast2 || !Cast2->getOperand(0)->getType()->isFloatTy()) - return nullptr; + // fmin((double)floatval1, (double)floatval2) + // -> (double)fmin(floatval1, floatval2) + Value *V = nullptr; + Value *V1 = Cast1->getOperand(0); + Value *V2 = Cast2->getOperand(0); + // TODO: Handle intrinsics in the same way as in optimizeUnaryDoubleFP(). 
+ V = EmitBinaryFloatFnCall(V1, V2, Callee->getName(), B, + Callee->getAttributes()); + return B.CreateFPExt(V, B.getDoubleTy()); +} - // fmin((double)floatval1, (double)floatval2) - // -> (double)fmin(floatval1, floatval2) - Value *V = nullptr; - Value *V1 = Cast1->getOperand(0); - Value *V2 = Cast2->getOperand(0); - V = EmitBinaryFloatFnCall(V1, V2, Callee->getName(), B, - Callee->getAttributes()); - return B.CreateFPExt(V, B.getDoubleTy()); - } -}; - -struct UnsafeFPLibCallOptimization : public LibCallOptimization { - bool UnsafeFPShrink; - UnsafeFPLibCallOptimization(bool UnsafeFPShrink) { - this->UnsafeFPShrink = UnsafeFPShrink; - } -}; - -struct CosOpt : public UnsafeFPLibCallOptimization { - CosOpt(bool UnsafeFPShrink) : UnsafeFPLibCallOptimization(UnsafeFPShrink) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - Value *Ret = nullptr; - if (UnsafeFPShrink && Callee->getName() == "cos" && - TLI->has(LibFunc::cosf)) { - UnaryDoubleFPOpt UnsafeUnaryDoubleFP(true); - Ret = UnsafeUnaryDoubleFP.callOptimizer(Callee, CI, B); - } +Value *LibCallSimplifier::optimizeCos(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + Value *Ret = nullptr; + if (UnsafeFPShrink && Callee->getName() == "cos" && TLI->has(LibFunc::cosf)) { + Ret = optimizeUnaryDoubleFP(CI, B, true); + } - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 1 argument of FP type, which matches the - // result type. - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 1 argument of FP type, which matches the + // result type. 
+ if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isFloatingPointTy()) + return Ret; - // cos(-x) -> cos(x) - Value *Op1 = CI->getArgOperand(0); - if (BinaryOperator::isFNeg(Op1)) { - BinaryOperator *BinExpr = cast<BinaryOperator>(Op1); - return B.CreateCall(Callee, BinExpr->getOperand(1), "cos"); - } + // cos(-x) -> cos(x) + Value *Op1 = CI->getArgOperand(0); + if (BinaryOperator::isFNeg(Op1)) { + BinaryOperator *BinExpr = cast<BinaryOperator>(Op1); + return B.CreateCall(Callee, BinExpr->getOperand(1), "cos"); + } + return Ret; +} + +Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + + Value *Ret = nullptr; + if (UnsafeFPShrink && Callee->getName() == "pow" && TLI->has(LibFunc::powf)) { + Ret = optimizeUnaryDoubleFP(CI, B, true); + } + + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 2 arguments of the same FP type, which match the + // result type. 
+ if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + !FT->getParamType(0)->isFloatingPointTy()) + return Ret; + + Value *Op1 = CI->getArgOperand(0), *Op2 = CI->getArgOperand(1); + if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) { + // pow(1.0, x) -> 1.0 + if (Op1C->isExactlyValue(1.0)) + return Op1C; + // pow(2.0, x) -> exp2(x) + if (Op1C->isExactlyValue(2.0) && + hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp2, LibFunc::exp2f, + LibFunc::exp2l)) + return EmitUnaryFloatFnCall(Op2, "exp2", B, Callee->getAttributes()); + // pow(10.0, x) -> exp10(x) + if (Op1C->isExactlyValue(10.0) && + hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp10, LibFunc::exp10f, + LibFunc::exp10l)) + return EmitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp10), B, + Callee->getAttributes()); + } + + ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2); + if (!Op2C) return Ret; + + if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0 + return ConstantFP::get(CI->getType(), 1.0); + + if (Op2C->isExactlyValue(0.5) && + hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf, + LibFunc::sqrtl) && + hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::fabs, LibFunc::fabsf, + LibFunc::fabsl)) { + // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))). + // This is faster than calling pow, and still handles negative zero + // and negative infinity correctly. + // TODO: In fast-math mode, this could be just sqrt(x). + // TODO: In finite-only mode, this could be just fabs(sqrt(x)). 
+ Value *Inf = ConstantFP::getInfinity(CI->getType()); + Value *NegInf = ConstantFP::getInfinity(CI->getType(), true); + Value *Sqrt = EmitUnaryFloatFnCall(Op1, "sqrt", B, Callee->getAttributes()); + Value *FAbs = + EmitUnaryFloatFnCall(Sqrt, "fabs", B, Callee->getAttributes()); + Value *FCmp = B.CreateFCmpOEQ(Op1, NegInf); + Value *Sel = B.CreateSelect(FCmp, Inf, FAbs); + return Sel; + } + + if (Op2C->isExactlyValue(1.0)) // pow(x, 1.0) -> x + return Op1; + if (Op2C->isExactlyValue(2.0)) // pow(x, 2.0) -> x*x + return B.CreateFMul(Op1, Op1, "pow2"); + if (Op2C->isExactlyValue(-1.0)) // pow(x, -1.0) -> 1.0/x + return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Op1, "powrecip"); + return nullptr; +} + +Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + Function *Caller = CI->getParent()->getParent(); + + Value *Ret = nullptr; + if (UnsafeFPShrink && Callee->getName() == "exp2" && + TLI->has(LibFunc::exp2f)) { + Ret = optimizeUnaryDoubleFP(CI, B, true); } -}; - -struct PowOpt : public UnsafeFPLibCallOptimization { - PowOpt(bool UnsafeFPShrink) : UnsafeFPLibCallOptimization(UnsafeFPShrink) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - Value *Ret = nullptr; - if (UnsafeFPShrink && Callee->getName() == "pow" && - TLI->has(LibFunc::powf)) { - UnaryDoubleFPOpt UnsafeUnaryDoubleFP(true); - Ret = UnsafeUnaryDoubleFP.callOptimizer(Callee, CI, B); - } - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 2 arguments of the same FP type, which match the - // result type. - if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 1 argument of FP type, which matches the + // result type. 
+ if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isFloatingPointTy()) + return Ret; - Value *Op1 = CI->getArgOperand(0), *Op2 = CI->getArgOperand(1); - if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) { - // pow(1.0, x) -> 1.0 - if (Op1C->isExactlyValue(1.0)) - return Op1C; - // pow(2.0, x) -> exp2(x) - if (Op1C->isExactlyValue(2.0) && - hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp2, LibFunc::exp2f, - LibFunc::exp2l)) - return EmitUnaryFloatFnCall(Op2, "exp2", B, Callee->getAttributes()); - // pow(10.0, x) -> exp10(x) - if (Op1C->isExactlyValue(10.0) && - hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp10, LibFunc::exp10f, - LibFunc::exp10l)) - return EmitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp10), B, - Callee->getAttributes()); + Value *Op = CI->getArgOperand(0); + // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= 32 + // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < 32 + LibFunc::Func LdExp = LibFunc::ldexpl; + if (Op->getType()->isFloatTy()) + LdExp = LibFunc::ldexpf; + else if (Op->getType()->isDoubleTy()) + LdExp = LibFunc::ldexp; + + if (TLI->has(LdExp)) { + Value *LdExpArg = nullptr; + if (SIToFPInst *OpC = dyn_cast<SIToFPInst>(Op)) { + if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() <= 32) + LdExpArg = B.CreateSExt(OpC->getOperand(0), B.getInt32Ty()); + } else if (UIToFPInst *OpC = dyn_cast<UIToFPInst>(Op)) { + if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() < 32) + LdExpArg = B.CreateZExt(OpC->getOperand(0), B.getInt32Ty()); } - ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2); - if (!Op2C) return Ret; - - if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0 - return ConstantFP::get(CI->getType(), 1.0); - - if (Op2C->isExactlyValue(0.5) && - hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf, - LibFunc::sqrtl) && - hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::fabs, LibFunc::fabsf, - LibFunc::fabsl)) { - // Expand pow(x, 
0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))). - // This is faster than calling pow, and still handles negative zero - // and negative infinity correctly. - // TODO: In fast-math mode, this could be just sqrt(x). - // TODO: In finite-only mode, this could be just fabs(sqrt(x)). - Value *Inf = ConstantFP::getInfinity(CI->getType()); - Value *NegInf = ConstantFP::getInfinity(CI->getType(), true); - Value *Sqrt = EmitUnaryFloatFnCall(Op1, "sqrt", B, - Callee->getAttributes()); - Value *FAbs = EmitUnaryFloatFnCall(Sqrt, "fabs", B, - Callee->getAttributes()); - Value *FCmp = B.CreateFCmpOEQ(Op1, NegInf); - Value *Sel = B.CreateSelect(FCmp, Inf, FAbs); - return Sel; - } + if (LdExpArg) { + Constant *One = ConstantFP::get(CI->getContext(), APFloat(1.0f)); + if (!Op->getType()->isFloatTy()) + One = ConstantExpr::getFPExtend(One, Op->getType()); + + Module *M = Caller->getParent(); + Value *Callee = + M->getOrInsertFunction(TLI->getName(LdExp), Op->getType(), + Op->getType(), B.getInt32Ty(), nullptr); + CallInst *CI = B.CreateCall2(Callee, One, LdExpArg); + if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) + CI->setCallingConv(F->getCallingConv()); - if (Op2C->isExactlyValue(1.0)) // pow(x, 1.0) -> x - return Op1; - if (Op2C->isExactlyValue(2.0)) // pow(x, 2.0) -> x*x - return B.CreateFMul(Op1, Op1, "pow2"); - if (Op2C->isExactlyValue(-1.0)) // pow(x, -1.0) -> 1.0/x - return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), - Op1, "powrecip"); - return nullptr; - } -}; - -struct Exp2Opt : public UnsafeFPLibCallOptimization { - Exp2Opt(bool UnsafeFPShrink) : UnsafeFPLibCallOptimization(UnsafeFPShrink) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - Value *Ret = nullptr; - if (UnsafeFPShrink && Callee->getName() == "exp2" && - TLI->has(LibFunc::exp2f)) { - UnaryDoubleFPOpt UnsafeUnaryDoubleFP(true); - Ret = UnsafeUnaryDoubleFP.callOptimizer(Callee, CI, B); + return CI; } + } + return Ret; +} - FunctionType 
*FT = Callee->getFunctionType(); - // Just make sure this has 1 argument of FP type, which matches the - // result type. - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; +Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); - Value *Op = CI->getArgOperand(0); - // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= 32 - // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < 32 - LibFunc::Func LdExp = LibFunc::ldexpl; - if (Op->getType()->isFloatTy()) - LdExp = LibFunc::ldexpf; - else if (Op->getType()->isDoubleTy()) - LdExp = LibFunc::ldexp; - - if (TLI->has(LdExp)) { - Value *LdExpArg = nullptr; - if (SIToFPInst *OpC = dyn_cast<SIToFPInst>(Op)) { - if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() <= 32) - LdExpArg = B.CreateSExt(OpC->getOperand(0), B.getInt32Ty()); - } else if (UIToFPInst *OpC = dyn_cast<UIToFPInst>(Op)) { - if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() < 32) - LdExpArg = B.CreateZExt(OpC->getOperand(0), B.getInt32Ty()); - } + Value *Ret = nullptr; + if (Callee->getName() == "fabs" && TLI->has(LibFunc::fabsf)) { + Ret = optimizeUnaryDoubleFP(CI, B, false); + } - if (LdExpArg) { - Constant *One = ConstantFP::get(*Context, APFloat(1.0f)); - if (!Op->getType()->isFloatTy()) - One = ConstantExpr::getFPExtend(One, Op->getType()); + FunctionType *FT = Callee->getFunctionType(); + // Make sure this has 1 argument of FP type which matches the result type. 
+ if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isFloatingPointTy()) + return Ret; - Module *M = Caller->getParent(); - Value *Callee = - M->getOrInsertFunction(TLI->getName(LdExp), Op->getType(), - Op->getType(), B.getInt32Ty(), NULL); - CallInst *CI = B.CreateCall2(Callee, One, LdExpArg); - if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) - CI->setCallingConv(F->getCallingConv()); + Value *Op = CI->getArgOperand(0); + if (Instruction *I = dyn_cast<Instruction>(Op)) { + // Fold fabs(x * x) -> x * x; any squared FP value must already be positive. + if (I->getOpcode() == Instruction::FMul) + if (I->getOperand(0) == I->getOperand(1)) + return Op; + } + return Ret; +} - return CI; +Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + + Value *Ret = nullptr; + if (TLI->has(LibFunc::sqrtf) && (Callee->getName() == "sqrt" || + Callee->getIntrinsicID() == Intrinsic::sqrt)) + Ret = optimizeUnaryDoubleFP(CI, B, true); + + // FIXME: For finer-grain optimization, we need intrinsics to have the same + // fast-math flag decorations that are applied to FP instructions. For now, + // we have to rely on the function-level unsafe-fp-math attribute to do this + // optimization because there's no other way to express that the sqrt can be + // reassociated. + Function *F = CI->getParent()->getParent(); + if (F->hasFnAttribute("unsafe-fp-math")) { + // Check for unsafe-fp-math = true. + Attribute Attr = F->getFnAttribute("unsafe-fp-math"); + if (Attr.getValueAsString() != "true") + return Ret; + } + Value *Op = CI->getArgOperand(0); + if (Instruction *I = dyn_cast<Instruction>(Op)) { + if (I->getOpcode() == Instruction::FMul && I->hasUnsafeAlgebra()) { + // We're looking for a repeated factor in a multiplication tree, + // so we can do this fold: sqrt(x * x) -> fabs(x); + // or this fold: sqrt(x * x * y) -> fabs(x) * sqrt(y). 
+ Value *Op0 = I->getOperand(0); + Value *Op1 = I->getOperand(1); + Value *RepeatOp = nullptr; + Value *OtherOp = nullptr; + if (Op0 == Op1) { + // Simple match: the operands of the multiply are identical. + RepeatOp = Op0; + } else { + // Look for a more complicated pattern: one of the operands is itself + // a multiply, so search for a common factor in that multiply. + // Note: We don't bother looking any deeper than this first level or for + // variations of this pattern because instcombine's visitFMUL and/or the + // reassociation pass should give us this form. + Value *OtherMul0, *OtherMul1; + if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) { + // Pattern: sqrt((x * y) * z) + if (OtherMul0 == OtherMul1) { + // Matched: sqrt((x * x) * z) + RepeatOp = OtherMul0; + OtherOp = Op1; + } + } + } + if (RepeatOp) { + // Fast math flags for any created instructions should match the sqrt + // and multiply. + // FIXME: We're not checking the sqrt because it doesn't have + // fast-math-flags (see earlier comment). + IRBuilder<true, ConstantFolder, + IRBuilderDefaultInserter<true> >::FastMathFlagGuard Guard(B); + B.SetFastMathFlags(I->getFastMathFlags()); + // If we found a repeated factor, hoist it out of the square root and + // replace it with the fabs of that factor. + Module *M = Callee->getParent(); + Type *ArgType = Op->getType(); + Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType); + Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs"); + if (OtherOp) { + // If we found a non-repeated factor, we still need to get its square + // root. We then multiply that by the value that was simplified out + // of the square root calculation. 
+ Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType); + Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt"); + return B.CreateFMul(FabsCall, SqrtCall); + } + return FabsCall; } } - return Ret; } -}; + return Ret; +} -struct SinCosPiOpt : public LibCallOptimization { - SinCosPiOpt() {} +static bool isTrigLibCall(CallInst *CI); +static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg, + bool UseFloat, Value *&Sin, Value *&Cos, + Value *&SinCos); - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Make sure the prototype is as expected, otherwise the rest of the - // function is probably invalid and likely to abort. - if (!isTrigLibCall(CI)) - return nullptr; +Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilder<> &B) { - Value *Arg = CI->getArgOperand(0); - SmallVector<CallInst *, 1> SinCalls; - SmallVector<CallInst *, 1> CosCalls; - SmallVector<CallInst *, 1> SinCosCalls; + // Make sure the prototype is as expected, otherwise the rest of the + // function is probably invalid and likely to abort. + if (!isTrigLibCall(CI)) + return nullptr; - bool IsFloat = Arg->getType()->isFloatTy(); + Value *Arg = CI->getArgOperand(0); + SmallVector<CallInst *, 1> SinCalls; + SmallVector<CallInst *, 1> CosCalls; + SmallVector<CallInst *, 1> SinCosCalls; - // Look for all compatible sinpi, cospi and sincospi calls with the same - // argument. If there are enough (in some sense) we can make the - // substitution. - for (User *U : Arg->users()) - classifyArgUse(U, CI->getParent(), IsFloat, SinCalls, CosCalls, - SinCosCalls); + bool IsFloat = Arg->getType()->isFloatTy(); - // It's only worthwhile if both sinpi and cospi are actually used. - if (SinCosCalls.empty() && (SinCalls.empty() || CosCalls.empty())) - return nullptr; + // Look for all compatible sinpi, cospi and sincospi calls with the same + // argument. If there are enough (in some sense) we can make the + // substitution. 
+ for (User *U : Arg->users()) + classifyArgUse(U, CI->getParent(), IsFloat, SinCalls, CosCalls, + SinCosCalls); - Value *Sin, *Cos, *SinCos; - insertSinCosCall(B, CI->getCalledFunction(), Arg, IsFloat, Sin, Cos, - SinCos); - - replaceTrigInsts(SinCalls, Sin); - replaceTrigInsts(CosCalls, Cos); - replaceTrigInsts(SinCosCalls, SinCos); - - return nullptr; - } - - bool isTrigLibCall(CallInst *CI) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - - // We can only hope to do anything useful if we can ignore things like errno - // and floating-point exceptions. - bool AttributesSafe = CI->hasFnAttr(Attribute::NoUnwind) && - CI->hasFnAttr(Attribute::ReadNone); - - // Other than that we need float(float) or double(double) - return AttributesSafe && FT->getNumParams() == 1 && - FT->getReturnType() == FT->getParamType(0) && - (FT->getParamType(0)->isFloatTy() || - FT->getParamType(0)->isDoubleTy()); - } - - void classifyArgUse(Value *Val, BasicBlock *BB, bool IsFloat, - SmallVectorImpl<CallInst *> &SinCalls, - SmallVectorImpl<CallInst *> &CosCalls, - SmallVectorImpl<CallInst *> &SinCosCalls) { - CallInst *CI = dyn_cast<CallInst>(Val); - - if (!CI) - return; - - Function *Callee = CI->getCalledFunction(); - StringRef FuncName = Callee->getName(); - LibFunc::Func Func; - if (!TLI->getLibFunc(FuncName, Func) || !TLI->has(Func) || - !isTrigLibCall(CI)) - return; - - if (IsFloat) { - if (Func == LibFunc::sinpif) - SinCalls.push_back(CI); - else if (Func == LibFunc::cospif) - CosCalls.push_back(CI); - else if (Func == LibFunc::sincospif_stret) - SinCosCalls.push_back(CI); - } else { - if (Func == LibFunc::sinpi) - SinCalls.push_back(CI); - else if (Func == LibFunc::cospi) - CosCalls.push_back(CI); - else if (Func == LibFunc::sincospi_stret) - SinCosCalls.push_back(CI); - } - } + // It's only worthwhile if both sinpi and cospi are actually used. 
+ if (SinCosCalls.empty() && (SinCalls.empty() || CosCalls.empty())) + return nullptr; - void replaceTrigInsts(SmallVectorImpl<CallInst*> &Calls, Value *Res) { - for (SmallVectorImpl<CallInst*>::iterator I = Calls.begin(), - E = Calls.end(); - I != E; ++I) { - LCS->replaceAllUsesWith(*I, Res); - } - } + Value *Sin, *Cos, *SinCos; + insertSinCosCall(B, CI->getCalledFunction(), Arg, IsFloat, Sin, Cos, SinCos); - void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg, - bool UseFloat, Value *&Sin, Value *&Cos, - Value *&SinCos) { - Type *ArgTy = Arg->getType(); - Type *ResTy; - StringRef Name; - - Triple T(OrigCallee->getParent()->getTargetTriple()); - if (UseFloat) { - Name = "__sincospif_stret"; - - assert(T.getArch() != Triple::x86 && "x86 messy and unsupported for now"); - // x86_64 can't use {float, float} since that would be returned in both - // xmm0 and xmm1, which isn't what a real struct would do. - ResTy = T.getArch() == Triple::x86_64 - ? static_cast<Type *>(VectorType::get(ArgTy, 2)) - : static_cast<Type *>(StructType::get(ArgTy, ArgTy, NULL)); - } else { - Name = "__sincospi_stret"; - ResTy = StructType::get(ArgTy, ArgTy, NULL); - } + replaceTrigInsts(SinCalls, Sin); + replaceTrigInsts(CosCalls, Cos); + replaceTrigInsts(SinCosCalls, SinCos); - Module *M = OrigCallee->getParent(); - Value *Callee = M->getOrInsertFunction(Name, OrigCallee->getAttributes(), - ResTy, ArgTy, NULL); - - if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) { - // If the argument is an instruction, it must dominate all uses so put our - // sincos call there. - BasicBlock::iterator Loc = ArgInst; - B.SetInsertPoint(ArgInst->getParent(), ++Loc); - } else { - // Otherwise (e.g. for a constant) the beginning of the function is as - // good a place as any. 
- BasicBlock &EntryBB = B.GetInsertBlock()->getParent()->getEntryBlock(); - B.SetInsertPoint(&EntryBB, EntryBB.begin()); - } + return nullptr; +} - SinCos = B.CreateCall(Callee, Arg, "sincospi"); +static bool isTrigLibCall(CallInst *CI) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + + // We can only hope to do anything useful if we can ignore things like errno + // and floating-point exceptions. + bool AttributesSafe = + CI->hasFnAttr(Attribute::NoUnwind) && CI->hasFnAttr(Attribute::ReadNone); + + // Other than that we need float(float) or double(double) + return AttributesSafe && FT->getNumParams() == 1 && + FT->getReturnType() == FT->getParamType(0) && + (FT->getParamType(0)->isFloatTy() || + FT->getParamType(0)->isDoubleTy()); +} - if (SinCos->getType()->isStructTy()) { - Sin = B.CreateExtractValue(SinCos, 0, "sinpi"); - Cos = B.CreateExtractValue(SinCos, 1, "cospi"); - } else { - Sin = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 0), - "sinpi"); - Cos = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 1), - "cospi"); - } +void +LibCallSimplifier::classifyArgUse(Value *Val, BasicBlock *BB, bool IsFloat, + SmallVectorImpl<CallInst *> &SinCalls, + SmallVectorImpl<CallInst *> &CosCalls, + SmallVectorImpl<CallInst *> &SinCosCalls) { + CallInst *CI = dyn_cast<CallInst>(Val); + + if (!CI) + return; + + Function *Callee = CI->getCalledFunction(); + StringRef FuncName = Callee->getName(); + LibFunc::Func Func; + if (!TLI->getLibFunc(FuncName, Func) || !TLI->has(Func) || !isTrigLibCall(CI)) + return; + + if (IsFloat) { + if (Func == LibFunc::sinpif) + SinCalls.push_back(CI); + else if (Func == LibFunc::cospif) + CosCalls.push_back(CI); + else if (Func == LibFunc::sincospif_stret) + SinCosCalls.push_back(CI); + } else { + if (Func == LibFunc::sinpi) + SinCalls.push_back(CI); + else if (Func == LibFunc::cospi) + CosCalls.push_back(CI); + else if (Func == LibFunc::sincospi_stret) + 
SinCosCalls.push_back(CI); } +} -}; +void LibCallSimplifier::replaceTrigInsts(SmallVectorImpl<CallInst *> &Calls, + Value *Res) { + for (SmallVectorImpl<CallInst *>::iterator I = Calls.begin(), E = Calls.end(); + I != E; ++I) { + replaceAllUsesWith(*I, Res); + } +} + +void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg, + bool UseFloat, Value *&Sin, Value *&Cos, Value *&SinCos) { + Type *ArgTy = Arg->getType(); + Type *ResTy; + StringRef Name; + + Triple T(OrigCallee->getParent()->getTargetTriple()); + if (UseFloat) { + Name = "__sincospif_stret"; + + assert(T.getArch() != Triple::x86 && "x86 messy and unsupported for now"); + // x86_64 can't use {float, float} since that would be returned in both + // xmm0 and xmm1, which isn't what a real struct would do. + ResTy = T.getArch() == Triple::x86_64 + ? static_cast<Type *>(VectorType::get(ArgTy, 2)) + : static_cast<Type *>(StructType::get(ArgTy, ArgTy, nullptr)); + } else { + Name = "__sincospi_stret"; + ResTy = StructType::get(ArgTy, ArgTy, nullptr); + } + + Module *M = OrigCallee->getParent(); + Value *Callee = M->getOrInsertFunction(Name, OrigCallee->getAttributes(), + ResTy, ArgTy, nullptr); + + if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) { + // If the argument is an instruction, it must dominate all uses so put our + // sincos call there. + BasicBlock::iterator Loc = ArgInst; + B.SetInsertPoint(ArgInst->getParent(), ++Loc); + } else { + // Otherwise (e.g. for a constant) the beginning of the function is as + // good a place as any. 
+ BasicBlock &EntryBB = B.GetInsertBlock()->getParent()->getEntryBlock(); + B.SetInsertPoint(&EntryBB, EntryBB.begin()); + } + + SinCos = B.CreateCall(Callee, Arg, "sincospi"); + + if (SinCos->getType()->isStructTy()) { + Sin = B.CreateExtractValue(SinCos, 0, "sinpi"); + Cos = B.CreateExtractValue(SinCos, 1, "cospi"); + } else { + Sin = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 0), + "sinpi"); + Cos = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 1), + "cospi"); + } +} //===----------------------------------------------------------------------===// // Integer Library Call Optimizations //===----------------------------------------------------------------------===// -struct FFSOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 2 arguments of the same FP type, which match the - // result type. - if (FT->getNumParams() != 1 || - !FT->getReturnType()->isIntegerTy(32) || - !FT->getParamType(0)->isIntegerTy()) - return nullptr; +Value *LibCallSimplifier::optimizeFFS(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 2 arguments of the same FP type, which match the + // result type. + if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy(32) || + !FT->getParamType(0)->isIntegerTy()) + return nullptr; - Value *Op = CI->getArgOperand(0); + Value *Op = CI->getArgOperand(0); - // Constant fold. - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) { - if (CI->isZero()) // ffs(0) -> 0. - return B.getInt32(0); - // ffs(c) -> cttz(c)+1 - return B.getInt32(CI->getValue().countTrailingZeros() + 1); - } + // Constant fold. + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) { + if (CI->isZero()) // ffs(0) -> 0. 
+ return B.getInt32(0); + // ffs(c) -> cttz(c)+1 + return B.getInt32(CI->getValue().countTrailingZeros() + 1); + } - // ffs(x) -> x != 0 ? (i32)llvm.cttz(x)+1 : 0 - Type *ArgType = Op->getType(); - Value *F = Intrinsic::getDeclaration(Callee->getParent(), - Intrinsic::cttz, ArgType); - Value *V = B.CreateCall2(F, Op, B.getFalse(), "cttz"); - V = B.CreateAdd(V, ConstantInt::get(V->getType(), 1)); - V = B.CreateIntCast(V, B.getInt32Ty(), false); - - Value *Cond = B.CreateICmpNE(Op, Constant::getNullValue(ArgType)); - return B.CreateSelect(Cond, V, B.getInt32(0)); - } -}; - -struct AbsOpt : public LibCallOptimization { - bool ignoreCallingConv() override { return true; } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // We require integer(integer) where the types agree. - if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || - FT->getParamType(0) != FT->getReturnType()) - return nullptr; + // ffs(x) -> x != 0 ? (i32)llvm.cttz(x)+1 : 0 + Type *ArgType = Op->getType(); + Value *F = + Intrinsic::getDeclaration(Callee->getParent(), Intrinsic::cttz, ArgType); + Value *V = B.CreateCall2(F, Op, B.getFalse(), "cttz"); + V = B.CreateAdd(V, ConstantInt::get(V->getType(), 1)); + V = B.CreateIntCast(V, B.getInt32Ty(), false); - // abs(x) -> x >s -1 ? 
x : -x - Value *Op = CI->getArgOperand(0); - Value *Pos = B.CreateICmpSGT(Op, Constant::getAllOnesValue(Op->getType()), - "ispos"); - Value *Neg = B.CreateNeg(Op, "neg"); - return B.CreateSelect(Pos, Op, Neg); - } -}; - -struct IsDigitOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // We require integer(i32) - if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || - !FT->getParamType(0)->isIntegerTy(32)) - return nullptr; + Value *Cond = B.CreateICmpNE(Op, Constant::getNullValue(ArgType)); + return B.CreateSelect(Cond, V, B.getInt32(0)); +} - // isdigit(c) -> (c-'0') <u 10 - Value *Op = CI->getArgOperand(0); - Op = B.CreateSub(Op, B.getInt32('0'), "isdigittmp"); - Op = B.CreateICmpULT(Op, B.getInt32(10), "isdigit"); - return B.CreateZExt(Op, CI->getType()); - } -}; - -struct IsAsciiOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // We require integer(i32) - if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || - !FT->getParamType(0)->isIntegerTy(32)) - return nullptr; +Value *LibCallSimplifier::optimizeAbs(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // We require integer(integer) where the types agree. + if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || + FT->getParamType(0) != FT->getReturnType()) + return nullptr; - // isascii(c) -> c <u 128 - Value *Op = CI->getArgOperand(0); - Op = B.CreateICmpULT(Op, B.getInt32(128), "isascii"); - return B.CreateZExt(Op, CI->getType()); - } -}; + // abs(x) -> x >s -1 ? 
x : -x + Value *Op = CI->getArgOperand(0); + Value *Pos = + B.CreateICmpSGT(Op, Constant::getAllOnesValue(Op->getType()), "ispos"); + Value *Neg = B.CreateNeg(Op, "neg"); + return B.CreateSelect(Pos, Op, Neg); +} -struct ToAsciiOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // We require i32(i32) - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isIntegerTy(32)) - return nullptr; +Value *LibCallSimplifier::optimizeIsDigit(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // We require integer(i32) + if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || + !FT->getParamType(0)->isIntegerTy(32)) + return nullptr; - // toascii(c) -> c & 0x7f - return B.CreateAnd(CI->getArgOperand(0), - ConstantInt::get(CI->getType(),0x7F)); - } -}; + // isdigit(c) -> (c-'0') <u 10 + Value *Op = CI->getArgOperand(0); + Op = B.CreateSub(Op, B.getInt32('0'), "isdigittmp"); + Op = B.CreateICmpULT(Op, B.getInt32(10), "isdigit"); + return B.CreateZExt(Op, CI->getType()); +} + +Value *LibCallSimplifier::optimizeIsAscii(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // We require integer(i32) + if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || + !FT->getParamType(0)->isIntegerTy(32)) + return nullptr; + + // isascii(c) -> c <u 128 + Value *Op = CI->getArgOperand(0); + Op = B.CreateICmpULT(Op, B.getInt32(128), "isascii"); + return B.CreateZExt(Op, CI->getType()); +} + +Value *LibCallSimplifier::optimizeToAscii(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // We require i32(i32) + if (FT->getNumParams() != 1 || FT->getReturnType() != 
FT->getParamType(0) || + !FT->getParamType(0)->isIntegerTy(32)) + return nullptr; + + // toascii(c) -> c & 0x7f + return B.CreateAnd(CI->getArgOperand(0), + ConstantInt::get(CI->getType(), 0x7F)); +} //===----------------------------------------------------------------------===// // Formatting and IO Library Call Optimizations //===----------------------------------------------------------------------===// -struct ErrorReportingOpt : public LibCallOptimization { - ErrorReportingOpt(int S = -1) : StreamArg(S) {} +static bool isReportingError(Function *Callee, CallInst *CI, int StreamArg); - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &) override { - // Error reporting calls should be cold, mark them as such. - // This applies even to non-builtin calls: it is only a hint and applies to - // functions that the frontend might not understand as builtins. +Value *LibCallSimplifier::optimizeErrorReporting(CallInst *CI, IRBuilder<> &B, + int StreamArg) { + // Error reporting calls should be cold, mark them as such. + // This applies even to non-builtin calls: it is only a hint and applies to + // functions that the frontend might not understand as builtins. - // This heuristic was suggested in: - // Improving Static Branch Prediction in a Compiler - // Brian L. Deitrich, Ben-Chung Cheng, Wen-mei W. Hwu - // Proceedings of PACT'98, Oct. 1998, IEEE - - if (!CI->hasFnAttr(Attribute::Cold) && isReportingError(Callee, CI)) { - CI->addAttribute(AttributeSet::FunctionIndex, Attribute::Cold); - } + // This heuristic was suggested in: + // Improving Static Branch Prediction in a Compiler + // Brian L. Deitrich, Ben-Chung Cheng, Wen-mei W. Hwu + // Proceedings of PACT'98, Oct. 
1998, IEEE + Function *Callee = CI->getCalledFunction(); - return nullptr; + if (!CI->hasFnAttr(Attribute::Cold) && + isReportingError(Callee, CI, StreamArg)) { + CI->addAttribute(AttributeSet::FunctionIndex, Attribute::Cold); } -protected: - bool isReportingError(Function *Callee, CallInst *CI) { - if (!ColdErrorCalls) - return false; - - if (!Callee || !Callee->isDeclaration()) - return false; - - if (StreamArg < 0) - return true; + return nullptr; +} - // These functions might be considered cold, but only if their stream - // argument is stderr. - - if (StreamArg >= (int) CI->getNumArgOperands()) - return false; - LoadInst *LI = dyn_cast<LoadInst>(CI->getArgOperand(StreamArg)); - if (!LI) - return false; - GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getPointerOperand()); - if (!GV || !GV->isDeclaration()) - return false; - return GV->getName() == "stderr"; - } - - int StreamArg; -}; - -struct PrintFOpt : public LibCallOptimization { - Value *optimizeFixedFormatString(Function *Callee, CallInst *CI, - IRBuilder<> &B) { - // Check for a fixed format string. - StringRef FormatStr; - if (!getConstantStringInfo(CI->getArgOperand(0), FormatStr)) - return nullptr; +static bool isReportingError(Function *Callee, CallInst *CI, int StreamArg) { + if (!ColdErrorCalls) + return false; - // Empty format string -> noop. - if (FormatStr.empty()) // Tolerate printf's declared void. - return CI->use_empty() ? (Value*)CI : - ConstantInt::get(CI->getType(), 0); + if (!Callee || !Callee->isDeclaration()) + return false; - // Do not do any of the following transformations if the printf return value - // is used, in general the printf return value is not compatible with either - // putchar() or puts(). - if (!CI->use_empty()) - return nullptr; + if (StreamArg < 0) + return true; - // printf("x") -> putchar('x'), even for '%'. 
- if (FormatStr.size() == 1) { - Value *Res = EmitPutChar(B.getInt32(FormatStr[0]), B, DL, TLI); - if (CI->use_empty() || !Res) return Res; - return B.CreateIntCast(Res, CI->getType(), true); - } + // These functions might be considered cold, but only if their stream + // argument is stderr. - // printf("foo\n") --> puts("foo") - if (FormatStr[FormatStr.size()-1] == '\n' && - FormatStr.find('%') == StringRef::npos) { // No format characters. - // Create a string literal with no \n on it. We expect the constant merge - // pass to be run after this pass, to merge duplicate strings. - FormatStr = FormatStr.drop_back(); - Value *GV = B.CreateGlobalString(FormatStr, "str"); - Value *NewCI = EmitPutS(GV, B, DL, TLI); - return (CI->use_empty() || !NewCI) ? - NewCI : - ConstantInt::get(CI->getType(), FormatStr.size()+1); - } + if (StreamArg >= (int)CI->getNumArgOperands()) + return false; + LoadInst *LI = dyn_cast<LoadInst>(CI->getArgOperand(StreamArg)); + if (!LI) + return false; + GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getPointerOperand()); + if (!GV || !GV->isDeclaration()) + return false; + return GV->getName() == "stderr"; +} - // Optimize specific format strings. - // printf("%c", chr) --> putchar(chr) - if (FormatStr == "%c" && CI->getNumArgOperands() > 1 && - CI->getArgOperand(1)->getType()->isIntegerTy()) { - Value *Res = EmitPutChar(CI->getArgOperand(1), B, DL, TLI); +Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilder<> &B) { + // Check for a fixed format string. + StringRef FormatStr; + if (!getConstantStringInfo(CI->getArgOperand(0), FormatStr)) + return nullptr; - if (CI->use_empty() || !Res) return Res; - return B.CreateIntCast(Res, CI->getType(), true); - } + // Empty format string -> noop. + if (FormatStr.empty()) // Tolerate printf's declared void. + return CI->use_empty() ? 
(Value *)CI : ConstantInt::get(CI->getType(), 0); - // printf("%s\n", str) --> puts(str) - if (FormatStr == "%s\n" && CI->getNumArgOperands() > 1 && - CI->getArgOperand(1)->getType()->isPointerTy()) { - return EmitPutS(CI->getArgOperand(1), B, DL, TLI); - } + // Do not do any of the following transformations if the printf return value + // is used, in general the printf return value is not compatible with either + // putchar() or puts(). + if (!CI->use_empty()) return nullptr; - } - - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Require one fixed pointer argument and an integer/void result. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || - !(FT->getReturnType()->isIntegerTy() || - FT->getReturnType()->isVoidTy())) - return nullptr; - if (Value *V = optimizeFixedFormatString(Callee, CI, B)) { - return V; - } + // printf("x") -> putchar('x'), even for '%'. + if (FormatStr.size() == 1) { + Value *Res = EmitPutChar(B.getInt32(FormatStr[0]), B, DL, TLI); + if (CI->use_empty() || !Res) + return Res; + return B.CreateIntCast(Res, CI->getType(), true); + } - // printf(format, ...) -> iprintf(format, ...) if no floating point - // arguments. - if (TLI->has(LibFunc::iprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - Constant *IPrintFFn = - M->getOrInsertFunction("iprintf", FT, Callee->getAttributes()); - CallInst *New = cast<CallInst>(CI->clone()); - New->setCalledFunction(IPrintFFn); - B.Insert(New); - return New; - } - return nullptr; + // printf("foo\n") --> puts("foo") + if (FormatStr[FormatStr.size() - 1] == '\n' && + FormatStr.find('%') == StringRef::npos) { // No format characters. + // Create a string literal with no \n on it. We expect the constant merge + // pass to be run after this pass, to merge duplicate strings. 
+ FormatStr = FormatStr.drop_back(); + Value *GV = B.CreateGlobalString(FormatStr, "str"); + Value *NewCI = EmitPutS(GV, B, DL, TLI); + return (CI->use_empty() || !NewCI) + ? NewCI + : ConstantInt::get(CI->getType(), FormatStr.size() + 1); } -}; -struct SPrintFOpt : public LibCallOptimization { - Value *OptimizeFixedFormatString(Function *Callee, CallInst *CI, - IRBuilder<> &B) { - // Check for a fixed format string. - StringRef FormatStr; - if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) - return nullptr; + // Optimize specific format strings. + // printf("%c", chr) --> putchar(chr) + if (FormatStr == "%c" && CI->getNumArgOperands() > 1 && + CI->getArgOperand(1)->getType()->isIntegerTy()) { + Value *Res = EmitPutChar(CI->getArgOperand(1), B, DL, TLI); - // If we just have a format string (nothing else crazy) transform it. - if (CI->getNumArgOperands() == 2) { - // Make sure there's no % in the constant array. We could try to handle - // %% -> % in the future if we cared. - for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) - if (FormatStr[i] == '%') - return nullptr; // we found a format specifier, bail out. - - // These optimizations require DataLayout. - if (!DL) return nullptr; - - // sprintf(str, fmt) -> llvm.memcpy(str, fmt, strlen(fmt)+1, 1) - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - ConstantInt::get(DL->getIntPtrType(*Context), // Copy the - FormatStr.size() + 1), 1); // nul byte. - return ConstantInt::get(CI->getType(), FormatStr.size()); - } + if (CI->use_empty() || !Res) + return Res; + return B.CreateIntCast(Res, CI->getType(), true); + } - // The remaining optimizations require the format string to be "%s" or "%c" - // and have an extra operand. 
- if (FormatStr.size() != 2 || FormatStr[0] != '%' || - CI->getNumArgOperands() < 3) - return nullptr; + // printf("%s\n", str) --> puts(str) + if (FormatStr == "%s\n" && CI->getNumArgOperands() > 1 && + CI->getArgOperand(1)->getType()->isPointerTy()) { + return EmitPutS(CI->getArgOperand(1), B, DL, TLI); + } + return nullptr; +} - // Decode the second character of the format string. - if (FormatStr[1] == 'c') { - // sprintf(dst, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0 - if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr; - Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); - Value *Ptr = CastToCStr(CI->getArgOperand(0), B); - B.CreateStore(V, Ptr); - Ptr = B.CreateGEP(Ptr, B.getInt32(1), "nul"); - B.CreateStore(B.getInt8(0), Ptr); - - return ConstantInt::get(CI->getType(), 1); - } +Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilder<> &B) { - if (FormatStr[1] == 's') { - // These optimizations require DataLayout. - if (!DL) return nullptr; + Function *Callee = CI->getCalledFunction(); + // Require one fixed pointer argument and an integer/void result. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || + !(FT->getReturnType()->isIntegerTy() || FT->getReturnType()->isVoidTy())) + return nullptr; - // sprintf(dest, "%s", str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) - if (!CI->getArgOperand(2)->getType()->isPointerTy()) return nullptr; + if (Value *V = optimizePrintFString(CI, B)) { + return V; + } - Value *Len = EmitStrLen(CI->getArgOperand(2), B, DL, TLI); - if (!Len) - return nullptr; - Value *IncLen = B.CreateAdd(Len, - ConstantInt::get(Len->getType(), 1), - "leninc"); - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(2), IncLen, 1); + // printf(format, ...) -> iprintf(format, ...) if no floating point + // arguments. 
+ if (TLI->has(LibFunc::iprintf) && !callHasFloatingPointArgument(CI)) { + Module *M = B.GetInsertBlock()->getParent()->getParent(); + Constant *IPrintFFn = + M->getOrInsertFunction("iprintf", FT, Callee->getAttributes()); + CallInst *New = cast<CallInst>(CI->clone()); + New->setCalledFunction(IPrintFFn); + B.Insert(New); + return New; + } + return nullptr; +} - // The sprintf result is the unincremented number of bytes in the string. - return B.CreateIntCast(Len, CI->getType(), false); - } +Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, IRBuilder<> &B) { + // Check for a fixed format string. + StringRef FormatStr; + if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) return nullptr; - } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Require two fixed pointer arguments and an integer result. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isIntegerTy()) + // If we just have a format string (nothing else crazy) transform it. + if (CI->getNumArgOperands() == 2) { + // Make sure there's no % in the constant array. We could try to handle + // %% -> % in the future if we cared. + for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) + if (FormatStr[i] == '%') + return nullptr; // we found a format specifier, bail out. + + // These optimizations require DataLayout. + if (!DL) return nullptr; - if (Value *V = OptimizeFixedFormatString(Callee, CI, B)) { - return V; - } + // sprintf(str, fmt) -> llvm.memcpy(str, fmt, strlen(fmt)+1, 1) + B.CreateMemCpy( + CI->getArgOperand(0), CI->getArgOperand(1), + ConstantInt::get(DL->getIntPtrType(CI->getContext()), + FormatStr.size() + 1), + 1); // Copy the null byte. + return ConstantInt::get(CI->getType(), FormatStr.size()); + } - // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating - // point arguments. 
- if (TLI->has(LibFunc::siprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - Constant *SIPrintFFn = - M->getOrInsertFunction("siprintf", FT, Callee->getAttributes()); - CallInst *New = cast<CallInst>(CI->clone()); - New->setCalledFunction(SIPrintFFn); - B.Insert(New); - return New; - } + // The remaining optimizations require the format string to be "%s" or "%c" + // and have an extra operand. + if (FormatStr.size() != 2 || FormatStr[0] != '%' || + CI->getNumArgOperands() < 3) return nullptr; + + // Decode the second character of the format string. + if (FormatStr[1] == 'c') { + // sprintf(dst, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0 + if (!CI->getArgOperand(2)->getType()->isIntegerTy()) + return nullptr; + Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); + Value *Ptr = CastToCStr(CI->getArgOperand(0), B); + B.CreateStore(V, Ptr); + Ptr = B.CreateGEP(Ptr, B.getInt32(1), "nul"); + B.CreateStore(B.getInt8(0), Ptr); + + return ConstantInt::get(CI->getType(), 1); } -}; -struct FPrintFOpt : public LibCallOptimization { - Value *optimizeFixedFormatString(Function *Callee, CallInst *CI, - IRBuilder<> &B) { - ErrorReportingOpt ER(/* StreamArg = */ 0); - (void) ER.callOptimizer(Callee, CI, B); + if (FormatStr[1] == 's') { + // These optimizations require DataLayout. + if (!DL) + return nullptr; - // All the optimizations depend on the format string. - StringRef FormatStr; - if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) + // sprintf(dest, "%s", str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) + if (!CI->getArgOperand(2)->getType()->isPointerTy()) return nullptr; - // Do not do any of the following transformations if the fprintf return - // value is used, in general the fprintf return value is not compatible - // with fwrite(), fputc() or fputs(). 
- if (!CI->use_empty()) + Value *Len = EmitStrLen(CI->getArgOperand(2), B, DL, TLI); + if (!Len) return nullptr; + Value *IncLen = + B.CreateAdd(Len, ConstantInt::get(Len->getType(), 1), "leninc"); + B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(2), IncLen, 1); - // fprintf(F, "foo") --> fwrite("foo", 3, 1, F) - if (CI->getNumArgOperands() == 2) { - for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) - if (FormatStr[i] == '%') // Could handle %% -> % if we cared. - return nullptr; // We found a format specifier. + // The sprintf result is the unincremented number of bytes in the string. + return B.CreateIntCast(Len, CI->getType(), false); + } + return nullptr; +} - // These optimizations require DataLayout. - if (!DL) return nullptr; +Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Require two fixed pointer arguments and an integer result. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + !FT->getReturnType()->isIntegerTy()) + return nullptr; - return EmitFWrite(CI->getArgOperand(1), - ConstantInt::get(DL->getIntPtrType(*Context), - FormatStr.size()), - CI->getArgOperand(0), B, DL, TLI); - } + if (Value *V = optimizeSPrintFString(CI, B)) { + return V; + } - // The remaining optimizations require the format string to be "%s" or "%c" - // and have an extra operand. - if (FormatStr.size() != 2 || FormatStr[0] != '%' || - CI->getNumArgOperands() < 3) - return nullptr; + // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating + // point arguments. 
+ if (TLI->has(LibFunc::siprintf) && !callHasFloatingPointArgument(CI)) { + Module *M = B.GetInsertBlock()->getParent()->getParent(); + Constant *SIPrintFFn = + M->getOrInsertFunction("siprintf", FT, Callee->getAttributes()); + CallInst *New = cast<CallInst>(CI->clone()); + New->setCalledFunction(SIPrintFFn); + B.Insert(New); + return New; + } + return nullptr; +} - // Decode the second character of the format string. - if (FormatStr[1] == 'c') { - // fprintf(F, "%c", chr) --> fputc(chr, F) - if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr; - return EmitFPutC(CI->getArgOperand(2), CI->getArgOperand(0), B, DL, TLI); - } +Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, IRBuilder<> &B) { + optimizeErrorReporting(CI, B, 0); - if (FormatStr[1] == 's') { - // fprintf(F, "%s", str) --> fputs(str, F) - if (!CI->getArgOperand(2)->getType()->isPointerTy()) - return nullptr; - return EmitFPutS(CI->getArgOperand(2), CI->getArgOperand(0), B, DL, TLI); - } + // All the optimizations depend on the format string. + StringRef FormatStr; + if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) return nullptr; + + // Do not do any of the following transformations if the fprintf return + // value is used, in general the fprintf return value is not compatible + // with fwrite(), fputc() or fputs(). + if (!CI->use_empty()) + return nullptr; + + // fprintf(F, "foo") --> fwrite("foo", 3, 1, F) + if (CI->getNumArgOperands() == 2) { + for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) + if (FormatStr[i] == '%') // Could handle %% -> % if we cared. + return nullptr; // We found a format specifier. + + // These optimizations require DataLayout. 
+ if (!DL) + return nullptr; + + return EmitFWrite( + CI->getArgOperand(1), + ConstantInt::get(DL->getIntPtrType(CI->getContext()), FormatStr.size()), + CI->getArgOperand(0), B, DL, TLI); } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Require two fixed paramters as pointers and integer result. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isIntegerTy()) + // The remaining optimizations require the format string to be "%s" or "%c" + // and have an extra operand. + if (FormatStr.size() != 2 || FormatStr[0] != '%' || + CI->getNumArgOperands() < 3) + return nullptr; + + // Decode the second character of the format string. + if (FormatStr[1] == 'c') { + // fprintf(F, "%c", chr) --> fputc(chr, F) + if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr; + return EmitFPutC(CI->getArgOperand(2), CI->getArgOperand(0), B, DL, TLI); + } - if (Value *V = optimizeFixedFormatString(Callee, CI, B)) { - return V; - } + if (FormatStr[1] == 's') { + // fprintf(F, "%s", str) --> fputs(str, F) + if (!CI->getArgOperand(2)->getType()->isPointerTy()) + return nullptr; + return EmitFPutS(CI->getArgOperand(2), CI->getArgOperand(0), B, DL, TLI); + } + return nullptr; +} - // fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no - // floating point arguments. - if (TLI->has(LibFunc::fiprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - Constant *FIPrintFFn = - M->getOrInsertFunction("fiprintf", FT, Callee->getAttributes()); - CallInst *New = cast<CallInst>(CI->clone()); - New->setCalledFunction(FIPrintFFn); - B.Insert(New); - return New; - } +Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Require two fixed paramters as pointers and integer result. 
+ FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + !FT->getReturnType()->isIntegerTy()) return nullptr; + + if (Value *V = optimizeFPrintFString(CI, B)) { + return V; } -}; -struct FWriteOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - ErrorReportingOpt ER(/* StreamArg = */ 3); - (void) ER.callOptimizer(Callee, CI, B); + // fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no + // floating point arguments. + if (TLI->has(LibFunc::fiprintf) && !callHasFloatingPointArgument(CI)) { + Module *M = B.GetInsertBlock()->getParent()->getParent(); + Constant *FIPrintFFn = + M->getOrInsertFunction("fiprintf", FT, Callee->getAttributes()); + CallInst *New = cast<CallInst>(CI->clone()); + New->setCalledFunction(FIPrintFFn); + B.Insert(New); + return New; + } + return nullptr; +} - // Require a pointer, an integer, an integer, a pointer, returning integer. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 4 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy() || - !FT->getParamType(2)->isIntegerTy() || - !FT->getParamType(3)->isPointerTy() || - !FT->getReturnType()->isIntegerTy()) - return nullptr; +Value *LibCallSimplifier::optimizeFWrite(CallInst *CI, IRBuilder<> &B) { + optimizeErrorReporting(CI, B, 3); - // Get the element size and count. - ConstantInt *SizeC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); - ConstantInt *CountC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); - if (!SizeC || !CountC) return nullptr; - uint64_t Bytes = SizeC->getZExtValue()*CountC->getZExtValue(); - - // If this is writing zero records, remove the call (it's a noop). - if (Bytes == 0) - return ConstantInt::get(CI->getType(), 0); - - // If this is writing one byte, turn it into fputc. 
- // This optimisation is only valid, if the return value is unused. - if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F) - Value *Char = B.CreateLoad(CastToCStr(CI->getArgOperand(0), B), "char"); - Value *NewCI = EmitFPutC(Char, CI->getArgOperand(3), B, DL, TLI); - return NewCI ? ConstantInt::get(CI->getType(), 1) : nullptr; - } + Function *Callee = CI->getCalledFunction(); + // Require a pointer, an integer, an integer, a pointer, returning integer. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 4 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isIntegerTy() || + !FT->getParamType(2)->isIntegerTy() || + !FT->getParamType(3)->isPointerTy() || + !FT->getReturnType()->isIntegerTy()) + return nullptr; + // Get the element size and count. + ConstantInt *SizeC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + ConstantInt *CountC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); + if (!SizeC || !CountC) return nullptr; - } -}; + uint64_t Bytes = SizeC->getZExtValue() * CountC->getZExtValue(); -struct FPutsOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - ErrorReportingOpt ER(/* StreamArg = */ 1); - (void) ER.callOptimizer(Callee, CI, B); + // If this is writing zero records, remove the call (it's a noop). + if (Bytes == 0) + return ConstantInt::get(CI->getType(), 0); - // These optimizations require DataLayout. - if (!DL) return nullptr; + // If this is writing one byte, turn it into fputc. + // This optimisation is only valid, if the return value is unused. + if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F) + Value *Char = B.CreateLoad(CastToCStr(CI->getArgOperand(0), B), "char"); + Value *NewCI = EmitFPutC(Char, CI->getArgOperand(3), B, DL, TLI); + return NewCI ? ConstantInt::get(CI->getType(), 1) : nullptr; + } - // Require two pointers. Also, we can't optimize if return value is used. 
- FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !CI->use_empty()) - return nullptr; + return nullptr; +} - // fputs(s,F) --> fwrite(s,1,strlen(s),F) - uint64_t Len = GetStringLength(CI->getArgOperand(0)); - if (!Len) return nullptr; - // Known to have no uses (see above). - return EmitFWrite(CI->getArgOperand(0), - ConstantInt::get(DL->getIntPtrType(*Context), Len-1), - CI->getArgOperand(1), B, DL, TLI); - } -}; - -struct PutsOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Require one fixed pointer argument and an integer/void result. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || - !(FT->getReturnType()->isIntegerTy() || - FT->getReturnType()->isVoidTy())) - return nullptr; +Value *LibCallSimplifier::optimizeFPuts(CallInst *CI, IRBuilder<> &B) { + optimizeErrorReporting(CI, B, 1); - // Check for a constant string. - StringRef Str; - if (!getConstantStringInfo(CI->getArgOperand(0), Str)) - return nullptr; + Function *Callee = CI->getCalledFunction(); - if (Str.empty() && CI->use_empty()) { - // puts("") -> putchar('\n') - Value *Res = EmitPutChar(B.getInt32('\n'), B, DL, TLI); - if (CI->use_empty() || !Res) return Res; - return B.CreateIntCast(Res, CI->getType(), true); - } + // These optimizations require DataLayout. + if (!DL) + return nullptr; + // Require two pointers. Also, we can't optimize if return value is used. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || !CI->use_empty()) return nullptr; - } -}; -} // End anonymous namespace. 
+ // fputs(s,F) --> fwrite(s,1,strlen(s),F) + uint64_t Len = GetStringLength(CI->getArgOperand(0)); + if (!Len) + return nullptr; + + // Known to have no uses (see above). + return EmitFWrite( + CI->getArgOperand(0), + ConstantInt::get(DL->getIntPtrType(CI->getContext()), Len - 1), + CI->getArgOperand(1), B, DL, TLI); +} -namespace llvm { +Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Require one fixed pointer argument and an integer/void result. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || + !(FT->getReturnType()->isIntegerTy() || FT->getReturnType()->isVoidTy())) + return nullptr; -class LibCallSimplifierImpl { - const DataLayout *DL; - const TargetLibraryInfo *TLI; - const LibCallSimplifier *LCS; - bool UnsafeFPShrink; + // Check for a constant string. + StringRef Str; + if (!getConstantStringInfo(CI->getArgOperand(0), Str)) + return nullptr; - // Math library call optimizations. 
- CosOpt Cos; - PowOpt Pow; - Exp2Opt Exp2; -public: - LibCallSimplifierImpl(const DataLayout *DL, const TargetLibraryInfo *TLI, - const LibCallSimplifier *LCS, - bool UnsafeFPShrink = false) - : Cos(UnsafeFPShrink), Pow(UnsafeFPShrink), Exp2(UnsafeFPShrink) { - this->DL = DL; - this->TLI = TLI; - this->LCS = LCS; - this->UnsafeFPShrink = UnsafeFPShrink; + if (Str.empty() && CI->use_empty()) { + // puts("") -> putchar('\n') + Value *Res = EmitPutChar(B.getInt32('\n'), B, DL, TLI); + if (CI->use_empty() || !Res) + return Res; + return B.CreateIntCast(Res, CI->getType(), true); } - Value *optimizeCall(CallInst *CI); - LibCallOptimization *lookupOptimization(CallInst *CI); - bool hasFloatVersion(StringRef FuncName); -}; + return nullptr; +} -bool LibCallSimplifierImpl::hasFloatVersion(StringRef FuncName) { +bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) { LibFunc::Func Func; SmallString<20> FloatFuncName = FuncName; FloatFuncName += 'f'; @@ -2048,263 +1995,219 @@ bool LibCallSimplifierImpl::hasFloatVersion(StringRef FuncName) { return false; } -// Fortified library call optimizations. -static MemCpyChkOpt MemCpyChk; -static MemMoveChkOpt MemMoveChk; -static MemSetChkOpt MemSetChk; -static StrCpyChkOpt StrCpyChk; -static StpCpyChkOpt StpCpyChk; -static StrNCpyChkOpt StrNCpyChk; - -// String library call optimizations. -static StrCatOpt StrCat; -static StrNCatOpt StrNCat; -static StrChrOpt StrChr; -static StrRChrOpt StrRChr; -static StrCmpOpt StrCmp; -static StrNCmpOpt StrNCmp; -static StrCpyOpt StrCpy; -static StpCpyOpt StpCpy; -static StrNCpyOpt StrNCpy; -static StrLenOpt StrLen; -static StrPBrkOpt StrPBrk; -static StrToOpt StrTo; -static StrSpnOpt StrSpn; -static StrCSpnOpt StrCSpn; -static StrStrOpt StrStr; - -// Memory library call optimizations. -static MemCmpOpt MemCmp; -static MemCpyOpt MemCpy; -static MemMoveOpt MemMove; -static MemSetOpt MemSet; - -// Math library call optimizations. 
-static UnaryDoubleFPOpt UnaryDoubleFP(false); -static BinaryDoubleFPOpt BinaryDoubleFP(false); -static UnaryDoubleFPOpt UnsafeUnaryDoubleFP(true); -static SinCosPiOpt SinCosPi; - - // Integer library call optimizations. -static FFSOpt FFS; -static AbsOpt Abs; -static IsDigitOpt IsDigit; -static IsAsciiOpt IsAscii; -static ToAsciiOpt ToAscii; - -// Formatting and IO library call optimizations. -static ErrorReportingOpt ErrorReporting; -static ErrorReportingOpt ErrorReporting0(0); -static ErrorReportingOpt ErrorReporting1(1); -static PrintFOpt PrintF; -static SPrintFOpt SPrintF; -static FPrintFOpt FPrintF; -static FWriteOpt FWrite; -static FPutsOpt FPuts; -static PutsOpt Puts; - -LibCallOptimization *LibCallSimplifierImpl::lookupOptimization(CallInst *CI) { +Value *LibCallSimplifier::optimizeCall(CallInst *CI) { + if (CI->isNoBuiltin()) + return nullptr; + LibFunc::Func Func; Function *Callee = CI->getCalledFunction(); StringRef FuncName = Callee->getName(); + IRBuilder<> Builder(CI); + bool isCallingConvC = CI->getCallingConv() == llvm::CallingConv::C; - // Next check for intrinsics. + // Command-line parameter overrides function attribute. + if (EnableUnsafeFPShrink.getNumOccurrences() > 0) + UnsafeFPShrink = EnableUnsafeFPShrink; + else if (Callee->hasFnAttribute("unsafe-fp-math")) { + // FIXME: This is the same problem as described in optimizeSqrt(). + // If calls gain access to IR-level FMF, then use that instead of a + // function attribute. + + // Check for unsafe-fp-math = true. + Attribute Attr = Callee->getFnAttribute("unsafe-fp-math"); + if (Attr.getValueAsString() == "true") + UnsafeFPShrink = true; + } + + // First, check for intrinsics. 
if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { + if (!isCallingConvC) + return nullptr; switch (II->getIntrinsicID()) { case Intrinsic::pow: - return &Pow; + return optimizePow(CI, Builder); case Intrinsic::exp2: - return &Exp2; + return optimizeExp2(CI, Builder); + case Intrinsic::fabs: + return optimizeFabs(CI, Builder); + case Intrinsic::sqrt: + return optimizeSqrt(CI, Builder); default: - return nullptr; + return nullptr; } } // Then check for known library functions. if (TLI->getLibFunc(FuncName, Func) && TLI->has(Func)) { + // We never change the calling convention. + if (!ignoreCallingConv(Func) && !isCallingConvC) + return nullptr; switch (Func) { - case LibFunc::strcat: - return &StrCat; - case LibFunc::strncat: - return &StrNCat; - case LibFunc::strchr: - return &StrChr; - case LibFunc::strrchr: - return &StrRChr; - case LibFunc::strcmp: - return &StrCmp; - case LibFunc::strncmp: - return &StrNCmp; - case LibFunc::strcpy: - return &StrCpy; - case LibFunc::stpcpy: - return &StpCpy; - case LibFunc::strncpy: - return &StrNCpy; - case LibFunc::strlen: - return &StrLen; - case LibFunc::strpbrk: - return &StrPBrk; - case LibFunc::strtol: - case LibFunc::strtod: - case LibFunc::strtof: - case LibFunc::strtoul: - case LibFunc::strtoll: - case LibFunc::strtold: - case LibFunc::strtoull: - return &StrTo; - case LibFunc::strspn: - return &StrSpn; - case LibFunc::strcspn: - return &StrCSpn; - case LibFunc::strstr: - return &StrStr; - case LibFunc::memcmp: - return &MemCmp; - case LibFunc::memcpy: - return &MemCpy; - case LibFunc::memmove: - return &MemMove; - case LibFunc::memset: - return &MemSet; - case LibFunc::cosf: - case LibFunc::cos: - case LibFunc::cosl: - return &Cos; - case LibFunc::sinpif: - case LibFunc::sinpi: - case LibFunc::cospif: - case LibFunc::cospi: - return &SinCosPi; - case LibFunc::powf: - case LibFunc::pow: - case LibFunc::powl: - return &Pow; - case LibFunc::exp2l: - case LibFunc::exp2: - case LibFunc::exp2f: - return &Exp2; - case 
LibFunc::ffs: - case LibFunc::ffsl: - case LibFunc::ffsll: - return &FFS; - case LibFunc::abs: - case LibFunc::labs: - case LibFunc::llabs: - return &Abs; - case LibFunc::isdigit: - return &IsDigit; - case LibFunc::isascii: - return &IsAscii; - case LibFunc::toascii: - return &ToAscii; - case LibFunc::printf: - return &PrintF; - case LibFunc::sprintf: - return &SPrintF; - case LibFunc::fprintf: - return &FPrintF; - case LibFunc::fwrite: - return &FWrite; - case LibFunc::fputs: - return &FPuts; - case LibFunc::puts: - return &Puts; - case LibFunc::perror: - return &ErrorReporting; - case LibFunc::vfprintf: - case LibFunc::fiprintf: - return &ErrorReporting0; - case LibFunc::fputc: - return &ErrorReporting1; - case LibFunc::ceil: - case LibFunc::fabs: - case LibFunc::floor: - case LibFunc::rint: - case LibFunc::round: - case LibFunc::nearbyint: - case LibFunc::trunc: - if (hasFloatVersion(FuncName)) - return &UnaryDoubleFP; - return nullptr; - case LibFunc::acos: - case LibFunc::acosh: - case LibFunc::asin: - case LibFunc::asinh: - case LibFunc::atan: - case LibFunc::atanh: - case LibFunc::cbrt: - case LibFunc::cosh: - case LibFunc::exp: - case LibFunc::exp10: - case LibFunc::expm1: - case LibFunc::log: - case LibFunc::log10: - case LibFunc::log1p: - case LibFunc::log2: - case LibFunc::logb: - case LibFunc::sin: - case LibFunc::sinh: - case LibFunc::sqrt: - case LibFunc::tan: - case LibFunc::tanh: - if (UnsafeFPShrink && hasFloatVersion(FuncName)) - return &UnsafeUnaryDoubleFP; - return nullptr; - case LibFunc::fmin: - case LibFunc::fmax: - if (hasFloatVersion(FuncName)) - return &BinaryDoubleFP; - return nullptr; - case LibFunc::memcpy_chk: - return &MemCpyChk; - default: - return nullptr; - } - } - - // Finally check for fortified library calls. 
- if (FuncName.endswith("_chk")) { - if (FuncName == "__memmove_chk") - return &MemMoveChk; - else if (FuncName == "__memset_chk") - return &MemSetChk; - else if (FuncName == "__strcpy_chk") - return &StrCpyChk; - else if (FuncName == "__stpcpy_chk") - return &StpCpyChk; - else if (FuncName == "__strncpy_chk") - return &StrNCpyChk; - else if (FuncName == "__stpncpy_chk") - return &StrNCpyChk; + case LibFunc::strcat: + return optimizeStrCat(CI, Builder); + case LibFunc::strncat: + return optimizeStrNCat(CI, Builder); + case LibFunc::strchr: + return optimizeStrChr(CI, Builder); + case LibFunc::strrchr: + return optimizeStrRChr(CI, Builder); + case LibFunc::strcmp: + return optimizeStrCmp(CI, Builder); + case LibFunc::strncmp: + return optimizeStrNCmp(CI, Builder); + case LibFunc::strcpy: + return optimizeStrCpy(CI, Builder); + case LibFunc::stpcpy: + return optimizeStpCpy(CI, Builder); + case LibFunc::strncpy: + return optimizeStrNCpy(CI, Builder); + case LibFunc::strlen: + return optimizeStrLen(CI, Builder); + case LibFunc::strpbrk: + return optimizeStrPBrk(CI, Builder); + case LibFunc::strtol: + case LibFunc::strtod: + case LibFunc::strtof: + case LibFunc::strtoul: + case LibFunc::strtoll: + case LibFunc::strtold: + case LibFunc::strtoull: + return optimizeStrTo(CI, Builder); + case LibFunc::strspn: + return optimizeStrSpn(CI, Builder); + case LibFunc::strcspn: + return optimizeStrCSpn(CI, Builder); + case LibFunc::strstr: + return optimizeStrStr(CI, Builder); + case LibFunc::memcmp: + return optimizeMemCmp(CI, Builder); + case LibFunc::memcpy: + return optimizeMemCpy(CI, Builder); + case LibFunc::memmove: + return optimizeMemMove(CI, Builder); + case LibFunc::memset: + return optimizeMemSet(CI, Builder); + case LibFunc::cosf: + case LibFunc::cos: + case LibFunc::cosl: + return optimizeCos(CI, Builder); + case LibFunc::sinpif: + case LibFunc::sinpi: + case LibFunc::cospif: + case LibFunc::cospi: + return optimizeSinCosPi(CI, Builder); + case LibFunc::powf: + case 
LibFunc::pow: + case LibFunc::powl: + return optimizePow(CI, Builder); + case LibFunc::exp2l: + case LibFunc::exp2: + case LibFunc::exp2f: + return optimizeExp2(CI, Builder); + case LibFunc::fabsf: + case LibFunc::fabs: + case LibFunc::fabsl: + return optimizeFabs(CI, Builder); + case LibFunc::sqrtf: + case LibFunc::sqrt: + case LibFunc::sqrtl: + return optimizeSqrt(CI, Builder); + case LibFunc::ffs: + case LibFunc::ffsl: + case LibFunc::ffsll: + return optimizeFFS(CI, Builder); + case LibFunc::abs: + case LibFunc::labs: + case LibFunc::llabs: + return optimizeAbs(CI, Builder); + case LibFunc::isdigit: + return optimizeIsDigit(CI, Builder); + case LibFunc::isascii: + return optimizeIsAscii(CI, Builder); + case LibFunc::toascii: + return optimizeToAscii(CI, Builder); + case LibFunc::printf: + return optimizePrintF(CI, Builder); + case LibFunc::sprintf: + return optimizeSPrintF(CI, Builder); + case LibFunc::fprintf: + return optimizeFPrintF(CI, Builder); + case LibFunc::fwrite: + return optimizeFWrite(CI, Builder); + case LibFunc::fputs: + return optimizeFPuts(CI, Builder); + case LibFunc::puts: + return optimizePuts(CI, Builder); + case LibFunc::perror: + return optimizeErrorReporting(CI, Builder); + case LibFunc::vfprintf: + case LibFunc::fiprintf: + return optimizeErrorReporting(CI, Builder, 0); + case LibFunc::fputc: + return optimizeErrorReporting(CI, Builder, 1); + case LibFunc::ceil: + case LibFunc::floor: + case LibFunc::rint: + case LibFunc::round: + case LibFunc::nearbyint: + case LibFunc::trunc: + if (hasFloatVersion(FuncName)) + return optimizeUnaryDoubleFP(CI, Builder, false); + return nullptr; + case LibFunc::acos: + case LibFunc::acosh: + case LibFunc::asin: + case LibFunc::asinh: + case LibFunc::atan: + case LibFunc::atanh: + case LibFunc::cbrt: + case LibFunc::cosh: + case LibFunc::exp: + case LibFunc::exp10: + case LibFunc::expm1: + case LibFunc::log: + case LibFunc::log10: + case LibFunc::log1p: + case LibFunc::log2: + case LibFunc::logb: + case 
LibFunc::sin: + case LibFunc::sinh: + case LibFunc::tan: + case LibFunc::tanh: + if (UnsafeFPShrink && hasFloatVersion(FuncName)) + return optimizeUnaryDoubleFP(CI, Builder, true); + return nullptr; + case LibFunc::fmin: + case LibFunc::fmax: + if (hasFloatVersion(FuncName)) + return optimizeBinaryDoubleFP(CI, Builder); + return nullptr; + case LibFunc::memcpy_chk: + return optimizeMemCpyChk(CI, Builder); + case LibFunc::memmove_chk: + return optimizeMemMoveChk(CI, Builder); + case LibFunc::memset_chk: + return optimizeMemSetChk(CI, Builder); + case LibFunc::strcpy_chk: + return optimizeStrCpyChk(CI, Builder); + case LibFunc::stpcpy_chk: + return optimizeStpCpyChk(CI, Builder); + case LibFunc::stpncpy_chk: + case LibFunc::strncpy_chk: + return optimizeStrNCpyChk(CI, Builder); + default: + return nullptr; + } } return nullptr; - -} - -Value *LibCallSimplifierImpl::optimizeCall(CallInst *CI) { - LibCallOptimization *LCO = lookupOptimization(CI); - if (LCO) { - IRBuilder<> Builder(CI); - return LCO->optimizeCall(CI, DL, TLI, LCS, Builder); - } - return nullptr; } LibCallSimplifier::LibCallSimplifier(const DataLayout *DL, - const TargetLibraryInfo *TLI, - bool UnsafeFPShrink) { - Impl = new LibCallSimplifierImpl(DL, TLI, this, UnsafeFPShrink); -} - -LibCallSimplifier::~LibCallSimplifier() { - delete Impl; -} - -Value *LibCallSimplifier::optimizeCall(CallInst *CI) { - if (CI->isNoBuiltin()) return nullptr; - return Impl->optimizeCall(CI); + const TargetLibraryInfo *TLI) : + DL(DL), + TLI(TLI), + UnsafeFPShrink(false) { } void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) const { @@ -2312,8 +2215,6 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) const { I->eraseFromParent(); } -} - // TODO: // Additional cases that we need to add to this file: // diff --git a/lib/Transforms/Utils/SymbolRewriter.cpp b/lib/Transforms/Utils/SymbolRewriter.cpp new file mode 100644 index 0000000..aacc945 --- /dev/null +++ 
b/lib/Transforms/Utils/SymbolRewriter.cpp @@ -0,0 +1,525 @@ +//===- SymbolRewriter.cpp - Symbol Rewriter ---------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// SymbolRewriter is a LLVM pass which can rewrite symbols transparently within +// existing code. It is implemented as a compiler pass and is configured via a +// YAML configuration file. +// +// The YAML configuration file format is as follows: +// +// RewriteMapFile := RewriteDescriptors +// RewriteDescriptors := RewriteDescriptor | RewriteDescriptors +// RewriteDescriptor := RewriteDescriptorType ':' '{' RewriteDescriptorFields '}' +// RewriteDescriptorFields := RewriteDescriptorField | RewriteDescriptorFields +// RewriteDescriptorField := FieldIdentifier ':' FieldValue ',' +// RewriteDescriptorType := Identifier +// FieldIdentifier := Identifier +// FieldValue := Identifier +// Identifier := [0-9a-zA-Z]+ +// +// Currently, the following descriptor types are supported: +// +// - function: (function rewriting) +// + Source (original name of the function) +// + Target (explicit transformation) +// + Transform (pattern transformation) +// + Naked (boolean, whether the function is undecorated) +// - global variable: (external linkage global variable rewriting) +// + Source (original name of externally visible variable) +// + Target (explicit transformation) +// + Transform (pattern transformation) +// - global alias: (global alias rewriting) +// + Source (original name of the aliased name) +// + Target (explicit transformation) +// + Transform (pattern transformation) +// +// Note that source and exactly one of [Target, Transform] must be provided +// +// New rewrite descriptors can be created. 
Addding a new rewrite descriptor +// involves: +// +// a) extended the rewrite descriptor kind enumeration +// (<anonymous>::RewriteDescriptor::RewriteDescriptorType) +// b) implementing the new descriptor +// (c.f. <anonymous>::ExplicitRewriteFunctionDescriptor) +// c) extending the rewrite map parser +// (<anonymous>::RewriteMapParser::parseEntry) +// +// Specify to rewrite the symbols using the `-rewrite-symbols` option, and +// specify the map file to use for the rewriting via the `-rewrite-map-file` +// option. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "symbol-rewriter" +#include "llvm/CodeGen/Passes.h" +#include "llvm/Pass.h" +#include "llvm/PassManager.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Regex.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/YAMLParser.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" +#include "llvm/Transforms/Utils/SymbolRewriter.h" + +using namespace llvm; + +static cl::list<std::string> RewriteMapFiles("rewrite-map-file", + cl::desc("Symbol Rewrite Map"), + cl::value_desc("filename")); + +namespace llvm { +namespace SymbolRewriter { +template <RewriteDescriptor::Type DT, typename ValueType, + ValueType *(llvm::Module::*Get)(StringRef) const> +class ExplicitRewriteDescriptor : public RewriteDescriptor { +public: + const std::string Source; + const std::string Target; + + ExplicitRewriteDescriptor(StringRef S, StringRef T, const bool Naked) + : RewriteDescriptor(DT), Source(Naked ? 
StringRef("\01" + S.str()) : S), + Target(T) {} + + bool performOnModule(Module &M) override; + + static bool classof(const RewriteDescriptor *RD) { + return RD->getType() == DT; + } +}; + +template <RewriteDescriptor::Type DT, typename ValueType, + ValueType *(llvm::Module::*Get)(StringRef) const> +bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) { + bool Changed = false; + if (ValueType *S = (M.*Get)(Source)) { + if (Value *T = (M.*Get)(Target)) + S->setValueName(T->getValueName()); + else + S->setName(Target); + Changed = true; + } + return Changed; +} + +template <RewriteDescriptor::Type DT, typename ValueType, + ValueType *(llvm::Module::*Get)(StringRef) const, + iterator_range<typename iplist<ValueType>::iterator> (llvm::Module::*Iterator)()> +class PatternRewriteDescriptor : public RewriteDescriptor { +public: + const std::string Pattern; + const std::string Transform; + + PatternRewriteDescriptor(StringRef P, StringRef T) + : RewriteDescriptor(DT), Pattern(P), Transform(T) { } + + bool performOnModule(Module &M) override; + + static bool classof(const RewriteDescriptor *RD) { + return RD->getType() == DT; + } +}; + +template <RewriteDescriptor::Type DT, typename ValueType, + ValueType *(llvm::Module::*Get)(StringRef) const, + iterator_range<typename iplist<ValueType>::iterator> (llvm::Module::*Iterator)()> +bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>:: +performOnModule(Module &M) { + bool Changed = false; + for (auto &C : (M.*Iterator)()) { + std::string Error; + + std::string Name = Regex(Pattern).sub(Transform, C.getName(), &Error); + if (!Error.empty()) + report_fatal_error("unable to transforn " + C.getName() + " in " + + M.getModuleIdentifier() + ": " + Error); + + if (Value *V = (M.*Get)(Name)) + C.setValueName(V->getValueName()); + else + C.setName(Name); + + Changed = true; + } + return Changed; +} + +/// Represents a rewrite for an explicitly named (function) symbol. 
Both the +/// source function name and target function name of the transformation are +/// explicitly spelt out. +typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function, + llvm::Function, &llvm::Module::getFunction> + ExplicitRewriteFunctionDescriptor; + +/// Represents a rewrite for an explicitly named (global variable) symbol. Both +/// the source variable name and target variable name are spelt out. This +/// applies only to module level variables. +typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable, + llvm::GlobalVariable, + &llvm::Module::getGlobalVariable> + ExplicitRewriteGlobalVariableDescriptor; + +/// Represents a rewrite for an explicitly named global alias. Both the source +/// and target name are explicitly spelt out. +typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, + llvm::GlobalAlias, + &llvm::Module::getNamedAlias> + ExplicitRewriteNamedAliasDescriptor; + +/// Represents a rewrite for a regular expression based pattern for functions. +/// A pattern for the function name is provided and a transformation for that +/// pattern to determine the target function name create the rewrite rule. +typedef PatternRewriteDescriptor<RewriteDescriptor::Type::Function, + llvm::Function, &llvm::Module::getFunction, + &llvm::Module::functions> + PatternRewriteFunctionDescriptor; + +/// Represents a rewrite for a global variable based upon a matching pattern. +/// Each global variable matching the provided pattern will be transformed as +/// described in the transformation pattern for the target. Applies only to +/// module level variables. +typedef PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable, + llvm::GlobalVariable, + &llvm::Module::getGlobalVariable, + &llvm::Module::globals> + PatternRewriteGlobalVariableDescriptor; + +/// PatternRewriteNamedAliasDescriptor - represents a rewrite for global +/// aliases which match a given pattern. 
The provided transformation will be +/// applied to each of the matching names. +typedef PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, + llvm::GlobalAlias, + &llvm::Module::getNamedAlias, + &llvm::Module::aliases> + PatternRewriteNamedAliasDescriptor; + +bool RewriteMapParser::parse(const std::string &MapFile, + RewriteDescriptorList *DL) { + ErrorOr<std::unique_ptr<MemoryBuffer>> Mapping = + MemoryBuffer::getFile(MapFile); + + if (!Mapping) + report_fatal_error("unable to read rewrite map '" + MapFile + "': " + + Mapping.getError().message()); + + if (!parse(*Mapping, DL)) + report_fatal_error("unable to parse rewrite map '" + MapFile + "'"); + + return true; +} + +bool RewriteMapParser::parse(std::unique_ptr<MemoryBuffer> &MapFile, + RewriteDescriptorList *DL) { + SourceMgr SM; + yaml::Stream YS(MapFile->getBuffer(), SM); + + for (auto &Document : YS) { + yaml::MappingNode *DescriptorList; + + // ignore empty documents + if (isa<yaml::NullNode>(Document.getRoot())) + continue; + + DescriptorList = dyn_cast<yaml::MappingNode>(Document.getRoot()); + if (!DescriptorList) { + YS.printError(Document.getRoot(), "DescriptorList node must be a map"); + return false; + } + + for (auto &Descriptor : *DescriptorList) + if (!parseEntry(YS, Descriptor, DL)) + return false; + } + + return true; +} + +bool RewriteMapParser::parseEntry(yaml::Stream &YS, yaml::KeyValueNode &Entry, + RewriteDescriptorList *DL) { + yaml::ScalarNode *Key; + yaml::MappingNode *Value; + SmallString<32> KeyStorage; + StringRef RewriteType; + + Key = dyn_cast<yaml::ScalarNode>(Entry.getKey()); + if (!Key) { + YS.printError(Entry.getKey(), "rewrite type must be a scalar"); + return false; + } + + Value = dyn_cast<yaml::MappingNode>(Entry.getValue()); + if (!Value) { + YS.printError(Entry.getValue(), "rewrite descriptor must be a map"); + return false; + } + + RewriteType = Key->getValue(KeyStorage); + if (RewriteType.equals("function")) + return parseRewriteFunctionDescriptor(YS, Key, 
Value, DL); + else if (RewriteType.equals("global variable")) + return parseRewriteGlobalVariableDescriptor(YS, Key, Value, DL); + else if (RewriteType.equals("global alias")) + return parseRewriteGlobalAliasDescriptor(YS, Key, Value, DL); + + YS.printError(Entry.getKey(), "unknown rewrite type"); + return false; +} + +bool RewriteMapParser:: +parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, + yaml::MappingNode *Descriptor, + RewriteDescriptorList *DL) { + bool Naked = false; + std::string Source; + std::string Target; + std::string Transform; + + for (auto &Field : *Descriptor) { + yaml::ScalarNode *Key; + yaml::ScalarNode *Value; + SmallString<32> KeyStorage; + SmallString<32> ValueStorage; + StringRef KeyValue; + + Key = dyn_cast<yaml::ScalarNode>(Field.getKey()); + if (!Key) { + YS.printError(Field.getKey(), "descriptor key must be a scalar"); + return false; + } + + Value = dyn_cast<yaml::ScalarNode>(Field.getValue()); + if (!Value) { + YS.printError(Field.getValue(), "descriptor value must be a scalar"); + return false; + } + + KeyValue = Key->getValue(KeyStorage); + if (KeyValue.equals("source")) { + std::string Error; + + Source = Value->getValue(ValueStorage); + if (!Regex(Source).isValid(Error)) { + YS.printError(Field.getKey(), "invalid regex: " + Error); + return false; + } + } else if (KeyValue.equals("target")) { + Target = Value->getValue(ValueStorage); + } else if (KeyValue.equals("transform")) { + Transform = Value->getValue(ValueStorage); + } else if (KeyValue.equals("naked")) { + std::string Undecorated; + + Undecorated = Value->getValue(ValueStorage); + Naked = StringRef(Undecorated).lower() == "true" || Undecorated == "1"; + } else { + YS.printError(Field.getKey(), "unknown key for function"); + return false; + } + } + + if (Transform.empty() == Target.empty()) { + YS.printError(Descriptor, + "exactly one of transform or target must be specified"); + return false; + } + + // TODO see if there is a more elegant solution to 
selecting the rewrite + // descriptor type + if (!Target.empty()) + DL->push_back(new ExplicitRewriteFunctionDescriptor(Source, Target, Naked)); + else + DL->push_back(new PatternRewriteFunctionDescriptor(Source, Transform)); + + return true; +} + +bool RewriteMapParser:: +parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, + yaml::MappingNode *Descriptor, + RewriteDescriptorList *DL) { + std::string Source; + std::string Target; + std::string Transform; + + for (auto &Field : *Descriptor) { + yaml::ScalarNode *Key; + yaml::ScalarNode *Value; + SmallString<32> KeyStorage; + SmallString<32> ValueStorage; + StringRef KeyValue; + + Key = dyn_cast<yaml::ScalarNode>(Field.getKey()); + if (!Key) { + YS.printError(Field.getKey(), "descriptor Key must be a scalar"); + return false; + } + + Value = dyn_cast<yaml::ScalarNode>(Field.getValue()); + if (!Value) { + YS.printError(Field.getValue(), "descriptor value must be a scalar"); + return false; + } + + KeyValue = Key->getValue(KeyStorage); + if (KeyValue.equals("source")) { + std::string Error; + + Source = Value->getValue(ValueStorage); + if (!Regex(Source).isValid(Error)) { + YS.printError(Field.getKey(), "invalid regex: " + Error); + return false; + } + } else if (KeyValue.equals("target")) { + Target = Value->getValue(ValueStorage); + } else if (KeyValue.equals("transform")) { + Transform = Value->getValue(ValueStorage); + } else { + YS.printError(Field.getKey(), "unknown Key for Global Variable"); + return false; + } + } + + if (Transform.empty() == Target.empty()) { + YS.printError(Descriptor, + "exactly one of transform or target must be specified"); + return false; + } + + if (!Target.empty()) + DL->push_back(new ExplicitRewriteGlobalVariableDescriptor(Source, Target, + /*Naked*/false)); + else + DL->push_back(new PatternRewriteGlobalVariableDescriptor(Source, + Transform)); + + return true; +} + +bool RewriteMapParser:: +parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode 
*K, + yaml::MappingNode *Descriptor, + RewriteDescriptorList *DL) { + std::string Source; + std::string Target; + std::string Transform; + + for (auto &Field : *Descriptor) { + yaml::ScalarNode *Key; + yaml::ScalarNode *Value; + SmallString<32> KeyStorage; + SmallString<32> ValueStorage; + StringRef KeyValue; + + Key = dyn_cast<yaml::ScalarNode>(Field.getKey()); + if (!Key) { + YS.printError(Field.getKey(), "descriptor key must be a scalar"); + return false; + } + + Value = dyn_cast<yaml::ScalarNode>(Field.getValue()); + if (!Value) { + YS.printError(Field.getValue(), "descriptor value must be a scalar"); + return false; + } + + KeyValue = Key->getValue(KeyStorage); + if (KeyValue.equals("source")) { + std::string Error; + + Source = Value->getValue(ValueStorage); + if (!Regex(Source).isValid(Error)) { + YS.printError(Field.getKey(), "invalid regex: " + Error); + return false; + } + } else if (KeyValue.equals("target")) { + Target = Value->getValue(ValueStorage); + } else if (KeyValue.equals("transform")) { + Transform = Value->getValue(ValueStorage); + } else { + YS.printError(Field.getKey(), "unknown key for Global Alias"); + return false; + } + } + + if (Transform.empty() == Target.empty()) { + YS.printError(Descriptor, + "exactly one of transform or target must be specified"); + return false; + } + + if (!Target.empty()) + DL->push_back(new ExplicitRewriteNamedAliasDescriptor(Source, Target, + /*Naked*/false)); + else + DL->push_back(new PatternRewriteNamedAliasDescriptor(Source, Transform)); + + return true; +} +} +} + +namespace { +class RewriteSymbols : public ModulePass { +public: + static char ID; // Pass identification, replacement for typeid + + RewriteSymbols(); + RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL); + + bool runOnModule(Module &M) override; + +private: + void loadAndParseMapFiles(); + + SymbolRewriter::RewriteDescriptorList Descriptors; +}; + +char RewriteSymbols::ID = 0; + +RewriteSymbols::RewriteSymbols() : ModulePass(ID) { + 
initializeRewriteSymbolsPass(*PassRegistry::getPassRegistry()); + loadAndParseMapFiles(); +} + +RewriteSymbols::RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL) + : ModulePass(ID) { + std::swap(Descriptors, DL); +} + +bool RewriteSymbols::runOnModule(Module &M) { + bool Changed; + + Changed = false; + for (auto &Descriptor : Descriptors) + Changed |= Descriptor.performOnModule(M); + + return Changed; +} + +void RewriteSymbols::loadAndParseMapFiles() { + const std::vector<std::string> MapFiles(RewriteMapFiles); + SymbolRewriter::RewriteMapParser parser; + + for (const auto &MapFile : MapFiles) + parser.parse(MapFile, &Descriptors); +} +} + +INITIALIZE_PASS(RewriteSymbols, "rewrite-symbols", "Rewrite Symbols", false, + false) + +ModulePass *llvm::createRewriteSymbolsPass() { return new RewriteSymbols(); } + +ModulePass * +llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) { + return new RewriteSymbols(DL); +} diff --git a/lib/Transforms/Utils/ValueMapper.cpp b/lib/Transforms/Utils/ValueMapper.cpp index 0f20e6d..a2f69d1 100644 --- a/lib/Transforms/Utils/ValueMapper.cpp +++ b/lib/Transforms/Utils/ValueMapper.cpp @@ -210,8 +210,10 @@ void llvm::RemapInstruction(Instruction *I, ValueToValueMapTy &VMap, // Remap attached metadata. SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; I->getAllMetadata(MDs); - for (SmallVectorImpl<std::pair<unsigned, MDNode *> >::iterator - MI = MDs.begin(), ME = MDs.end(); MI != ME; ++MI) { + for (SmallVectorImpl<std::pair<unsigned, MDNode *>>::iterator + MI = MDs.begin(), + ME = MDs.end(); + MI != ME; ++MI) { MDNode *Old = MI->second; MDNode *New = MapValue(Old, VMap, Flags, TypeMapper, Materializer); if (New != Old) |