From f17a25c88b892d30c2b41ba7ecdfbdfb2b4be9cc Mon Sep 17 00:00:00 2001 From: Dan Gohman Date: Wed, 18 Jul 2007 16:29:46 +0000 Subject: It's not necessary to do rounding for alloca operations when the requested alignment is equal to the stack alignment. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@40004 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/Hello/Hello.cpp | 64 + lib/Transforms/Hello/Makefile | 16 + lib/Transforms/IPO/ArgumentPromotion.cpp | 559 + lib/Transforms/IPO/ConstantMerge.cpp | 116 + lib/Transforms/IPO/DeadArgumentElimination.cpp | 703 ++ lib/Transforms/IPO/DeadTypeElimination.cpp | 106 + lib/Transforms/IPO/ExtractFunction.cpp | 144 + lib/Transforms/IPO/GlobalDCE.cpp | 203 + lib/Transforms/IPO/GlobalOpt.cpp | 1988 ++++ lib/Transforms/IPO/IPConstantPropagation.cpp | 197 + lib/Transforms/IPO/IndMemRemoval.cpp | 89 + lib/Transforms/IPO/InlineSimple.cpp | 323 + lib/Transforms/IPO/Inliner.cpp | 217 + lib/Transforms/IPO/Internalize.cpp | 154 + lib/Transforms/IPO/LoopExtractor.cpp | 201 + lib/Transforms/IPO/LowerSetJmp.cpp | 534 + lib/Transforms/IPO/Makefile | 15 + lib/Transforms/IPO/PruneEH.cpp | 233 + lib/Transforms/IPO/RaiseAllocations.cpp | 249 + lib/Transforms/IPO/SimplifyLibCalls.cpp | 2021 ++++ lib/Transforms/IPO/StripDeadPrototypes.cpp | 70 + lib/Transforms/IPO/StripSymbols.cpp | 206 + lib/Transforms/Instrumentation/BlockProfiling.cpp | 126 + lib/Transforms/Instrumentation/EdgeProfiling.cpp | 101 + lib/Transforms/Instrumentation/Makefile | 15 + lib/Transforms/Instrumentation/ProfilingUtils.cpp | 119 + lib/Transforms/Instrumentation/ProfilingUtils.h | 31 + lib/Transforms/Instrumentation/RSProfiling.cpp | 650 ++ lib/Transforms/Instrumentation/RSProfiling.h | 31 + lib/Transforms/Makefile | 14 + lib/Transforms/Scalar/ADCE.cpp | 497 + lib/Transforms/Scalar/BasicBlockPlacement.cpp | 148 + lib/Transforms/Scalar/CodeGenPrepare.cpp | 988 ++ lib/Transforms/Scalar/CondPropagate.cpp | 219 + lib/Transforms/Scalar/ConstantProp.cpp | 90 + 
lib/Transforms/Scalar/CorrelatedExprs.cpp | 1487 +++ lib/Transforms/Scalar/DCE.cpp | 130 + lib/Transforms/Scalar/DeadStoreElimination.cpp | 179 + lib/Transforms/Scalar/FastDSE.cpp | 387 + lib/Transforms/Scalar/GCSE.cpp | 201 + lib/Transforms/Scalar/GVNPRE.cpp | 1819 ++++ lib/Transforms/Scalar/IndVarSimplify.cpp | 604 ++ lib/Transforms/Scalar/InstructionCombining.cpp | 10090 +++++++++++++++++++ lib/Transforms/Scalar/LICM.cpp | 797 ++ lib/Transforms/Scalar/LoopRotation.cpp | 579 ++ lib/Transforms/Scalar/LoopStrengthReduce.cpp | 1504 +++ lib/Transforms/Scalar/LoopUnroll.cpp | 500 + lib/Transforms/Scalar/LoopUnswitch.cpp | 1074 ++ lib/Transforms/Scalar/LowerGC.cpp | 330 + lib/Transforms/Scalar/LowerPacked.cpp | 462 + lib/Transforms/Scalar/Makefile | 15 + lib/Transforms/Scalar/PredicateSimplifier.cpp | 2640 +++++ lib/Transforms/Scalar/Reassociate.cpp | 868 ++ lib/Transforms/Scalar/Reg2Mem.cpp | 91 + lib/Transforms/Scalar/SCCP.cpp | 1691 ++++ lib/Transforms/Scalar/ScalarReplAggregates.cpp | 1335 +++ lib/Transforms/Scalar/SimplifyCFG.cpp | 145 + lib/Transforms/Scalar/TailDuplication.cpp | 364 + lib/Transforms/Scalar/TailRecursionElimination.cpp | 462 + lib/Transforms/Utils/BasicBlockUtils.cpp | 175 + lib/Transforms/Utils/BreakCriticalEdges.cpp | 269 + lib/Transforms/Utils/CloneFunction.cpp | 485 + lib/Transforms/Utils/CloneModule.cpp | 124 + lib/Transforms/Utils/CloneTrace.cpp | 120 + lib/Transforms/Utils/CodeExtractor.cpp | 737 ++ lib/Transforms/Utils/DemoteRegToStack.cpp | 133 + lib/Transforms/Utils/InlineFunction.cpp | 496 + lib/Transforms/Utils/LCSSA.cpp | 269 + lib/Transforms/Utils/Local.cpp | 200 + lib/Transforms/Utils/LoopSimplify.cpp | 692 ++ lib/Transforms/Utils/LowerAllocations.cpp | 176 + lib/Transforms/Utils/LowerInvoke.cpp | 585 ++ lib/Transforms/Utils/LowerSelect.cpp | 105 + lib/Transforms/Utils/LowerSwitch.cpp | 324 + lib/Transforms/Utils/Makefile | 15 + lib/Transforms/Utils/Mem2Reg.cpp | 93 + lib/Transforms/Utils/PromoteMemoryToRegister.cpp | 835 ++ 
lib/Transforms/Utils/SimplifyCFG.cpp | 1905 ++++ lib/Transforms/Utils/UnifyFunctionExitNodes.cpp | 138 + lib/Transforms/Utils/ValueMapper.cpp | 118 + lib/Transforms/Utils/ValueMapper.h | 29 + 81 files changed, 47214 insertions(+) create mode 100644 lib/Transforms/Hello/Hello.cpp create mode 100644 lib/Transforms/Hello/Makefile create mode 100644 lib/Transforms/IPO/ArgumentPromotion.cpp create mode 100644 lib/Transforms/IPO/ConstantMerge.cpp create mode 100644 lib/Transforms/IPO/DeadArgumentElimination.cpp create mode 100644 lib/Transforms/IPO/DeadTypeElimination.cpp create mode 100644 lib/Transforms/IPO/ExtractFunction.cpp create mode 100644 lib/Transforms/IPO/GlobalDCE.cpp create mode 100644 lib/Transforms/IPO/GlobalOpt.cpp create mode 100644 lib/Transforms/IPO/IPConstantPropagation.cpp create mode 100644 lib/Transforms/IPO/IndMemRemoval.cpp create mode 100644 lib/Transforms/IPO/InlineSimple.cpp create mode 100644 lib/Transforms/IPO/Inliner.cpp create mode 100644 lib/Transforms/IPO/Internalize.cpp create mode 100644 lib/Transforms/IPO/LoopExtractor.cpp create mode 100644 lib/Transforms/IPO/LowerSetJmp.cpp create mode 100644 lib/Transforms/IPO/Makefile create mode 100644 lib/Transforms/IPO/PruneEH.cpp create mode 100644 lib/Transforms/IPO/RaiseAllocations.cpp create mode 100644 lib/Transforms/IPO/SimplifyLibCalls.cpp create mode 100644 lib/Transforms/IPO/StripDeadPrototypes.cpp create mode 100644 lib/Transforms/IPO/StripSymbols.cpp create mode 100644 lib/Transforms/Instrumentation/BlockProfiling.cpp create mode 100644 lib/Transforms/Instrumentation/EdgeProfiling.cpp create mode 100644 lib/Transforms/Instrumentation/Makefile create mode 100644 lib/Transforms/Instrumentation/ProfilingUtils.cpp create mode 100644 lib/Transforms/Instrumentation/ProfilingUtils.h create mode 100644 lib/Transforms/Instrumentation/RSProfiling.cpp create mode 100644 lib/Transforms/Instrumentation/RSProfiling.h create mode 100644 lib/Transforms/Makefile create mode 100644 
lib/Transforms/Scalar/ADCE.cpp create mode 100644 lib/Transforms/Scalar/BasicBlockPlacement.cpp create mode 100644 lib/Transforms/Scalar/CodeGenPrepare.cpp create mode 100644 lib/Transforms/Scalar/CondPropagate.cpp create mode 100644 lib/Transforms/Scalar/ConstantProp.cpp create mode 100644 lib/Transforms/Scalar/CorrelatedExprs.cpp create mode 100644 lib/Transforms/Scalar/DCE.cpp create mode 100644 lib/Transforms/Scalar/DeadStoreElimination.cpp create mode 100644 lib/Transforms/Scalar/FastDSE.cpp create mode 100644 lib/Transforms/Scalar/GCSE.cpp create mode 100644 lib/Transforms/Scalar/GVNPRE.cpp create mode 100644 lib/Transforms/Scalar/IndVarSimplify.cpp create mode 100644 lib/Transforms/Scalar/InstructionCombining.cpp create mode 100644 lib/Transforms/Scalar/LICM.cpp create mode 100644 lib/Transforms/Scalar/LoopRotation.cpp create mode 100644 lib/Transforms/Scalar/LoopStrengthReduce.cpp create mode 100644 lib/Transforms/Scalar/LoopUnroll.cpp create mode 100644 lib/Transforms/Scalar/LoopUnswitch.cpp create mode 100644 lib/Transforms/Scalar/LowerGC.cpp create mode 100644 lib/Transforms/Scalar/LowerPacked.cpp create mode 100644 lib/Transforms/Scalar/Makefile create mode 100644 lib/Transforms/Scalar/PredicateSimplifier.cpp create mode 100644 lib/Transforms/Scalar/Reassociate.cpp create mode 100644 lib/Transforms/Scalar/Reg2Mem.cpp create mode 100644 lib/Transforms/Scalar/SCCP.cpp create mode 100644 lib/Transforms/Scalar/ScalarReplAggregates.cpp create mode 100644 lib/Transforms/Scalar/SimplifyCFG.cpp create mode 100644 lib/Transforms/Scalar/TailDuplication.cpp create mode 100644 lib/Transforms/Scalar/TailRecursionElimination.cpp create mode 100644 lib/Transforms/Utils/BasicBlockUtils.cpp create mode 100644 lib/Transforms/Utils/BreakCriticalEdges.cpp create mode 100644 lib/Transforms/Utils/CloneFunction.cpp create mode 100644 lib/Transforms/Utils/CloneModule.cpp create mode 100644 lib/Transforms/Utils/CloneTrace.cpp create mode 100644 
lib/Transforms/Utils/CodeExtractor.cpp create mode 100644 lib/Transforms/Utils/DemoteRegToStack.cpp create mode 100644 lib/Transforms/Utils/InlineFunction.cpp create mode 100644 lib/Transforms/Utils/LCSSA.cpp create mode 100644 lib/Transforms/Utils/Local.cpp create mode 100644 lib/Transforms/Utils/LoopSimplify.cpp create mode 100644 lib/Transforms/Utils/LowerAllocations.cpp create mode 100644 lib/Transforms/Utils/LowerInvoke.cpp create mode 100644 lib/Transforms/Utils/LowerSelect.cpp create mode 100644 lib/Transforms/Utils/LowerSwitch.cpp create mode 100644 lib/Transforms/Utils/Makefile create mode 100644 lib/Transforms/Utils/Mem2Reg.cpp create mode 100644 lib/Transforms/Utils/PromoteMemoryToRegister.cpp create mode 100644 lib/Transforms/Utils/SimplifyCFG.cpp create mode 100644 lib/Transforms/Utils/UnifyFunctionExitNodes.cpp create mode 100644 lib/Transforms/Utils/ValueMapper.cpp create mode 100644 lib/Transforms/Utils/ValueMapper.h (limited to 'lib/Transforms') diff --git a/lib/Transforms/Hello/Hello.cpp b/lib/Transforms/Hello/Hello.cpp new file mode 100644 index 0000000..a437215 --- /dev/null +++ b/lib/Transforms/Hello/Hello.cpp @@ -0,0 +1,64 @@ +//===- Hello.cpp - Example code from "Writing an LLVM Pass" ---------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This file implements two versions of the LLVM "Hello World" pass described +// in docs/WritingAnLLVMPass.html +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "hello" +#include "llvm/Pass.h" +#include "llvm/Function.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Streams.h" +#include "llvm/ADT/Statistic.h" +using namespace llvm; + +STATISTIC(HelloCounter, "Counts number of functions greeted"); + +namespace { + // Hello - The first implementation, without getAnalysisUsage. + struct Hello : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + Hello() : FunctionPass((intptr_t)&ID) {} + + virtual bool runOnFunction(Function &F) { + HelloCounter++; + std::string fname = F.getName(); + EscapeString(fname); + cerr << "Hello: " << fname << "\n"; + return false; + } + }; + + char Hello::ID = 0; + RegisterPass X("hello", "Hello World Pass"); + + // Hello2 - The second implementation with getAnalysisUsage implemented. 
+ struct Hello2 : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + Hello2() : FunctionPass((intptr_t)&ID) {} + + virtual bool runOnFunction(Function &F) { + HelloCounter++; + std::string fname = F.getName(); + EscapeString(fname); + cerr << "Hello: " << fname << "\n"; + return false; + } + + // We don't modify the program, so we preserve all analyses + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + }; + }; + char Hello2::ID = 0; + RegisterPass Y("hello2", + "Hello World Pass (with getAnalysisUsage implemented)"); +} diff --git a/lib/Transforms/Hello/Makefile b/lib/Transforms/Hello/Makefile new file mode 100644 index 0000000..0a02fe9 --- /dev/null +++ b/lib/Transforms/Hello/Makefile @@ -0,0 +1,16 @@ +##===- lib/Transforms/Hello/Makefile -----------------------*- Makefile -*-===## +# +# The LLVM Compiler Infrastructure +# +# This file was developed by the LLVM research group and is distributed under +# the University of Illinois Open Source License. See LICENSE.TXT for details. +# +##===----------------------------------------------------------------------===## + +LEVEL = ../../.. +LIBRARYNAME = LLVMHello +LOADABLE_MODULE = 1 +USEDLIBS = + +include $(LEVEL)/Makefile.common + diff --git a/lib/Transforms/IPO/ArgumentPromotion.cpp b/lib/Transforms/IPO/ArgumentPromotion.cpp new file mode 100644 index 0000000..9a7bcc7 --- /dev/null +++ b/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -0,0 +1,559 @@ +//===-- ArgumentPromotion.cpp - Promote by-reference arguments ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass promotes "by reference" arguments to be "by value" arguments. 
In +// practice, this means looking for internal functions that have pointer +// arguments. If we can prove, through the use of alias analysis, that an +// argument is *only* loaded, then we can pass the value into the function +// instead of the address of the value. This can cause recursive simplification +// of code and lead to the elimination of allocas (especially in C++ template +// code like the STL). +// +// This pass also handles aggregate arguments that are passed into a function, +// scalarizing them if the elements of the aggregate are only loaded. Note that +// we refuse to scalarize aggregates which would require passing in more than +// three operands to the function, because we don't want to pass thousands of +// operands for a large array or structure! +// +// Note that this transformation could also be done for arguments that are only +// stored to (returning the value instead), but we do not currently handle that +// case. This case would be best handled when and if we start supporting +// multiple return values from functions. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "argpromotion" +#include "llvm/Transforms/IPO.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Module.h" +#include "llvm/CallGraphSCCPass.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Support/CallSite.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Compiler.h" +#include +using namespace llvm; + +STATISTIC(NumArgumentsPromoted , "Number of pointer arguments promoted"); +STATISTIC(NumAggregatesPromoted, "Number of aggregate arguments promoted"); +STATISTIC(NumArgumentsDead , "Number of dead pointer args eliminated"); + +namespace { + /// ArgPromotion - The 'by reference' to 'by value' argument promotion pass. + /// + struct VISIBILITY_HIDDEN ArgPromotion : public CallGraphSCCPass { + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + AU.addRequired(); + CallGraphSCCPass::getAnalysisUsage(AU); + } + + virtual bool runOnSCC(const std::vector &SCC); + static char ID; // Pass identification, replacement for typeid + ArgPromotion() : CallGraphSCCPass((intptr_t)&ID) {} + + private: + bool PromoteArguments(CallGraphNode *CGN); + bool isSafeToPromoteArgument(Argument *Arg) const; + Function *DoPromotion(Function *F, std::vector &ArgsToPromote); + }; + + char ArgPromotion::ID = 0; + RegisterPass X("argpromotion", + "Promote 'by reference' arguments to scalars"); +} + +Pass *llvm::createArgumentPromotionPass() { + return new ArgPromotion(); +} + +bool ArgPromotion::runOnSCC(const std::vector &SCC) { + bool Changed = false, LocalChange; + + do { // Iterate until we stop promoting from this SCC. 
+ LocalChange = false; + // Attempt to promote arguments from all functions in this SCC. + for (unsigned i = 0, e = SCC.size(); i != e; ++i) + LocalChange |= PromoteArguments(SCC[i]); + Changed |= LocalChange; // Remember that we changed something. + } while (LocalChange); + + return Changed; +} + +/// PromoteArguments - This method checks the specified function to see if there +/// are any promotable arguments and if it is safe to promote the function (for +/// example, all callers are direct). If safe to promote some arguments, it +/// calls the DoPromotion method. +/// +bool ArgPromotion::PromoteArguments(CallGraphNode *CGN) { + Function *F = CGN->getFunction(); + + // Make sure that it is local to this module. + if (!F || !F->hasInternalLinkage()) return false; + + // First check: see if there are any pointer arguments! If not, quick exit. + std::vector PointerArgs; + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; ++I) + if (isa(I->getType())) + PointerArgs.push_back(I); + if (PointerArgs.empty()) return false; + + // Second check: make sure that all callers are direct callers. We can't + // transform functions that have indirect callers. + for (Value::use_iterator UI = F->use_begin(), E = F->use_end(); + UI != E; ++UI) { + CallSite CS = CallSite::get(*UI); + if (!CS.getInstruction()) // "Taking the address" of the function + return false; + + // Ensure that this call site is CALLING the function, not passing it as + // an argument. + for (CallSite::arg_iterator AI = CS.arg_begin(), E = CS.arg_end(); + AI != E; ++AI) + if (*AI == F) return false; // Passing the function address in! + } + + // Check to see which arguments are promotable. If an argument is not + // promotable, remove it from the PointerArgs vector. + for (unsigned i = 0; i != PointerArgs.size(); ++i) + if (!isSafeToPromoteArgument(PointerArgs[i])) { + std::swap(PointerArgs[i--], PointerArgs.back()); + PointerArgs.pop_back(); + } + + // No promotable pointer arguments. 
+ if (PointerArgs.empty()) return false; + + // Okay, promote all of the arguments are rewrite the callees! + Function *NewF = DoPromotion(F, PointerArgs); + + // Update the call graph to know that the old function is gone. + getAnalysis().changeFunction(F, NewF); + return true; +} + +/// IsAlwaysValidPointer - Return true if the specified pointer is always legal +/// to load. +static bool IsAlwaysValidPointer(Value *V) { + if (isa(V) || isa(V)) return true; + if (GetElementPtrInst *GEP = dyn_cast(V)) + return IsAlwaysValidPointer(GEP->getOperand(0)); + if (ConstantExpr *CE = dyn_cast(V)) + if (CE->getOpcode() == Instruction::GetElementPtr) + return IsAlwaysValidPointer(CE->getOperand(0)); + + return false; +} + +/// AllCalleesPassInValidPointerForArgument - Return true if we can prove that +/// all callees pass in a valid pointer for the specified function argument. +static bool AllCalleesPassInValidPointerForArgument(Argument *Arg) { + Function *Callee = Arg->getParent(); + + unsigned ArgNo = std::distance(Callee->arg_begin(), + Function::arg_iterator(Arg)); + + // Look at all call sites of the function. At this pointer we know we only + // have direct callees. + for (Value::use_iterator UI = Callee->use_begin(), E = Callee->use_end(); + UI != E; ++UI) { + CallSite CS = CallSite::get(*UI); + assert(CS.getInstruction() && "Should only have direct calls!"); + + if (!IsAlwaysValidPointer(CS.getArgument(ArgNo))) + return false; + } + return true; +} + + +/// isSafeToPromoteArgument - As you might guess from the name of this method, +/// it checks to see if it is both safe and useful to promote the argument. +/// This method limits promotion of aggregates to only promote up to three +/// elements of the aggregate in order to avoid exploding the number of +/// arguments passed in. 
+bool ArgPromotion::isSafeToPromoteArgument(Argument *Arg) const { + // We can only promote this argument if all of the uses are loads, or are GEP + // instructions (with constant indices) that are subsequently loaded. + bool HasLoadInEntryBlock = false; + BasicBlock *EntryBlock = Arg->getParent()->begin(); + std::vector Loads; + std::vector > GEPIndices; + for (Value::use_iterator UI = Arg->use_begin(), E = Arg->use_end(); + UI != E; ++UI) + if (LoadInst *LI = dyn_cast(*UI)) { + if (LI->isVolatile()) return false; // Don't hack volatile loads + Loads.push_back(LI); + HasLoadInEntryBlock |= LI->getParent() == EntryBlock; + } else if (GetElementPtrInst *GEP = dyn_cast(*UI)) { + if (GEP->use_empty()) { + // Dead GEP's cause trouble later. Just remove them if we run into + // them. + getAnalysis().deleteValue(GEP); + GEP->getParent()->getInstList().erase(GEP); + return isSafeToPromoteArgument(Arg); + } + // Ensure that all of the indices are constants. + std::vector Operands; + for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i) + if (ConstantInt *C = dyn_cast(GEP->getOperand(i))) + Operands.push_back(C); + else + return false; // Not a constant operand GEP! + + // Ensure that the only users of the GEP are load instructions. + for (Value::use_iterator UI = GEP->use_begin(), E = GEP->use_end(); + UI != E; ++UI) + if (LoadInst *LI = dyn_cast(*UI)) { + if (LI->isVolatile()) return false; // Don't hack volatile loads + Loads.push_back(LI); + HasLoadInEntryBlock |= LI->getParent() == EntryBlock; + } else { + return false; + } + + // See if there is already a GEP with these indices. If not, check to + // make sure that we aren't promoting too many elements. If so, nothing + // to do. 
+ if (std::find(GEPIndices.begin(), GEPIndices.end(), Operands) == + GEPIndices.end()) { + if (GEPIndices.size() == 3) { + DOUT << "argpromotion disable promoting argument '" + << Arg->getName() << "' because it would require adding more " + << "than 3 arguments to the function.\n"; + // We limit aggregate promotion to only promoting up to three elements + // of the aggregate. + return false; + } + GEPIndices.push_back(Operands); + } + } else { + return false; // Not a load or a GEP. + } + + if (Loads.empty()) return true; // No users, this is a dead argument. + + // If we decide that we want to promote this argument, the value is going to + // be unconditionally loaded in all callees. This is only safe to do if the + // pointer was going to be unconditionally loaded anyway (i.e. there is a load + // of the pointer in the entry block of the function) or if we can prove that + // all pointers passed in are always to legal locations (for example, no null + // pointers are passed in, no pointers to free'd memory, etc). + if (!HasLoadInEntryBlock && !AllCalleesPassInValidPointerForArgument(Arg)) + return false; // Cannot prove that this is safe!! + + // Okay, now we know that the argument is only used by load instructions and + // it is safe to unconditionally load the pointer. Use alias analysis to + // check to see if the pointer is guaranteed to not be modified from entry of + // the function to each of the load instructions. + + // Because there could be several/many load instructions, remember which + // blocks we know to be transparent to the load. + std::set TranspBlocks; + + AliasAnalysis &AA = getAnalysis(); + TargetData &TD = getAnalysis(); + + for (unsigned i = 0, e = Loads.size(); i != e; ++i) { + // Check to see if the load is invalidated from the start of the block to + // the load itself. 
+ LoadInst *Load = Loads[i]; + BasicBlock *BB = Load->getParent(); + + const PointerType *LoadTy = + cast(Load->getOperand(0)->getType()); + unsigned LoadSize = (unsigned)TD.getTypeSize(LoadTy->getElementType()); + + if (AA.canInstructionRangeModify(BB->front(), *Load, Arg, LoadSize)) + return false; // Pointer is invalidated! + + // Now check every path from the entry block to the load for transparency. + // To do this, we perform a depth first search on the inverse CFG from the + // loading block. + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) + for (idf_ext_iterator I = idf_ext_begin(*PI, TranspBlocks), + E = idf_ext_end(*PI, TranspBlocks); I != E; ++I) + if (AA.canBasicBlockModify(**I, Arg, LoadSize)) + return false; + } + + // If the path from the entry of the function to each load is free of + // instructions that potentially invalidate the load, we can make the + // transformation! + return true; +} + +namespace { + /// GEPIdxComparator - Provide a strong ordering for GEP indices. All Value* + /// elements are instances of ConstantInt. + /// + struct GEPIdxComparator { + bool operator()(const std::vector &LHS, + const std::vector &RHS) const { + unsigned idx = 0; + for (; idx < LHS.size() && idx < RHS.size(); ++idx) { + if (LHS[idx] != RHS[idx]) { + return cast(LHS[idx])->getZExtValue() < + cast(RHS[idx])->getZExtValue(); + } + } + + // Return less than if we ran out of stuff in LHS and we didn't run out of + // stuff in RHS. + return idx == LHS.size() && idx != RHS.size(); + } + }; +} + + +/// DoPromotion - This method actually performs the promotion of the specified +/// arguments, and returns the new function. At this point, we know that it's +/// safe to do so. 
+Function *ArgPromotion::DoPromotion(Function *F, + std::vector &Args2Prom) { + std::set ArgsToPromote(Args2Prom.begin(), Args2Prom.end()); + + // Start by computing a new prototype for the function, which is the same as + // the old function, but has modified arguments. + const FunctionType *FTy = F->getFunctionType(); + std::vector Params; + + typedef std::set, GEPIdxComparator> ScalarizeTable; + + // ScalarizedElements - If we are promoting a pointer that has elements + // accessed out of it, keep track of which elements are accessed so that we + // can add one argument for each. + // + // Arguments that are directly loaded will have a zero element value here, to + // handle cases where there are both a direct load and GEP accesses. + // + std::map ScalarizedElements; + + // OriginalLoads - Keep track of a representative load instruction from the + // original function so that we can tell the alias analysis implementation + // what the new GEP/Load instructions we are inserting look like. + std::map, LoadInst*> OriginalLoads; + + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; ++I) + if (!ArgsToPromote.count(I)) { + Params.push_back(I->getType()); + } else if (I->use_empty()) { + ++NumArgumentsDead; + } else { + // Okay, this is being promoted. Check to see if there are any GEP uses + // of the argument. + ScalarizeTable &ArgIndices = ScalarizedElements[I]; + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); UI != E; + ++UI) { + Instruction *User = cast(*UI); + assert(isa(User) || isa(User)); + std::vector Indices(User->op_begin()+1, User->op_end()); + ArgIndices.insert(Indices); + LoadInst *OrigLoad; + if (LoadInst *L = dyn_cast(User)) + OrigLoad = L; + else + OrigLoad = cast(User->use_back()); + OriginalLoads[Indices] = OrigLoad; + } + + // Add a parameter to the function for each element passed in. 
+ for (ScalarizeTable::iterator SI = ArgIndices.begin(), + E = ArgIndices.end(); SI != E; ++SI) + Params.push_back(GetElementPtrInst::getIndexedType(I->getType(), + &(*SI)[0], + SI->size())); + + if (ArgIndices.size() == 1 && ArgIndices.begin()->empty()) + ++NumArgumentsPromoted; + else + ++NumAggregatesPromoted; + } + + const Type *RetTy = FTy->getReturnType(); + + // Work around LLVM bug PR56: the CWriter cannot emit varargs functions which + // have zero fixed arguments. + bool ExtraArgHack = false; + if (Params.empty() && FTy->isVarArg()) { + ExtraArgHack = true; + Params.push_back(Type::Int32Ty); + } + FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg()); + + // Create the new function body and insert it into the module... + Function *NF = new Function(NFTy, F->getLinkage(), F->getName()); + NF->setCallingConv(F->getCallingConv()); + F->getParent()->getFunctionList().insert(F, NF); + + // Get the alias analysis information that we need to update to reflect our + // changes. + AliasAnalysis &AA = getAnalysis(); + + // Loop over all of the callers of the function, transforming the call sites + // to pass in the loaded pointers. + // + std::vector Args; + while (!F->use_empty()) { + CallSite CS = CallSite::get(F->use_back()); + Instruction *Call = CS.getInstruction(); + + // Loop over the operands, inserting GEP and loads in the caller as + // appropriate. + CallSite::arg_iterator AI = CS.arg_begin(); + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); + I != E; ++I, ++AI) + if (!ArgsToPromote.count(I)) + Args.push_back(*AI); // Unmodified argument + else if (!I->use_empty()) { + // Non-dead argument: insert GEPs and loads as appropriate. 
+ ScalarizeTable &ArgIndices = ScalarizedElements[I]; + for (ScalarizeTable::iterator SI = ArgIndices.begin(), + E = ArgIndices.end(); SI != E; ++SI) { + Value *V = *AI; + LoadInst *OrigLoad = OriginalLoads[*SI]; + if (!SI->empty()) { + V = new GetElementPtrInst(V, &(*SI)[0], SI->size(), + V->getName()+".idx", Call); + AA.copyValue(OrigLoad->getOperand(0), V); + } + Args.push_back(new LoadInst(V, V->getName()+".val", Call)); + AA.copyValue(OrigLoad, Args.back()); + } + } + + if (ExtraArgHack) + Args.push_back(Constant::getNullValue(Type::Int32Ty)); + + // Push any varargs arguments on the list + for (; AI != CS.arg_end(); ++AI) + Args.push_back(*AI); + + Instruction *New; + if (InvokeInst *II = dyn_cast(Call)) { + New = new InvokeInst(NF, II->getNormalDest(), II->getUnwindDest(), + &Args[0], Args.size(), "", Call); + cast(New)->setCallingConv(CS.getCallingConv()); + } else { + New = new CallInst(NF, &Args[0], Args.size(), "", Call); + cast(New)->setCallingConv(CS.getCallingConv()); + if (cast(Call)->isTailCall()) + cast(New)->setTailCall(); + } + Args.clear(); + + // Update the alias analysis implementation to know that we are replacing + // the old call with a new one. + AA.replaceWithNewValue(Call, New); + + if (!Call->use_empty()) { + Call->replaceAllUsesWith(New); + New->takeName(Call); + } + + // Finally, remove the old call from the program, reducing the use-count of + // F. + Call->getParent()->getInstList().erase(Call); + } + + // Since we have now created the new function, splice the body of the old + // function right into the new function, leaving the old rotting hulk of the + // function empty. + NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); + + // Loop over the argument list, transfering uses of the old arguments over to + // the new arguments, also transfering over the names as well. 
+ // + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(), + I2 = NF->arg_begin(); I != E; ++I) + if (!ArgsToPromote.count(I)) { + // If this is an unmodified argument, move the name and users over to the + // new version. + I->replaceAllUsesWith(I2); + I2->takeName(I); + AA.replaceWithNewValue(I, I2); + ++I2; + } else if (I->use_empty()) { + AA.deleteValue(I); + } else { + // Otherwise, if we promoted this argument, then all users are load + // instructions, and all loads should be using the new argument that we + // added. + ScalarizeTable &ArgIndices = ScalarizedElements[I]; + + while (!I->use_empty()) { + if (LoadInst *LI = dyn_cast(I->use_back())) { + assert(ArgIndices.begin()->empty() && + "Load element should sort to front!"); + I2->setName(I->getName()+".val"); + LI->replaceAllUsesWith(I2); + AA.replaceWithNewValue(LI, I2); + LI->getParent()->getInstList().erase(LI); + DOUT << "*** Promoted load of argument '" << I->getName() + << "' in function '" << F->getName() << "'\n"; + } else { + GetElementPtrInst *GEP = cast(I->use_back()); + std::vector Operands(GEP->op_begin()+1, GEP->op_end()); + + Function::arg_iterator TheArg = I2; + for (ScalarizeTable::iterator It = ArgIndices.begin(); + *It != Operands; ++It, ++TheArg) { + assert(It != ArgIndices.end() && "GEP not handled??"); + } + + std::string NewName = I->getName(); + for (unsigned i = 0, e = Operands.size(); i != e; ++i) + if (ConstantInt *CI = dyn_cast(Operands[i])) + NewName += "." + CI->getValue().toString(10); + else + NewName += ".x"; + TheArg->setName(NewName+".val"); + + DOUT << "*** Promoted agg argument '" << TheArg->getName() + << "' of function '" << F->getName() << "'\n"; + + // All of the uses must be load instructions. Replace them all with + // the argument specified by ArgNo. 
+ while (!GEP->use_empty()) { + LoadInst *L = cast(GEP->use_back()); + L->replaceAllUsesWith(TheArg); + AA.replaceWithNewValue(L, TheArg); + L->getParent()->getInstList().erase(L); + } + AA.deleteValue(GEP); + GEP->getParent()->getInstList().erase(GEP); + } + } + + // Increment I2 past all of the arguments added for this promoted pointer. + for (unsigned i = 0, e = ArgIndices.size(); i != e; ++i) + ++I2; + } + + // Notify the alias analysis implementation that we inserted a new argument. + if (ExtraArgHack) + AA.copyValue(Constant::getNullValue(Type::Int32Ty), NF->arg_begin()); + + + // Tell the alias analysis that the old function is about to disappear. + AA.replaceWithNewValue(F, NF); + + // Now that the old function is dead, delete it. + F->getParent()->getFunctionList().erase(F); + return NF; +} diff --git a/lib/Transforms/IPO/ConstantMerge.cpp b/lib/Transforms/IPO/ConstantMerge.cpp new file mode 100644 index 0000000..0c7ee59 --- /dev/null +++ b/lib/Transforms/IPO/ConstantMerge.cpp @@ -0,0 +1,116 @@ +//===- ConstantMerge.cpp - Merge duplicate global constants ---------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the interface to a pass that merges duplicate global +// constants together into a single constant that is shared. This is useful +// because some passes (ie TraceValues) insert a lot of string constants into +// the program, regardless of whether or not an existing string is available. +// +// Algorithm: ConstantMerge is designed to build up a map of available constants +// and eliminate duplicates when it is initialized. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "constmerge" +#include "llvm/Transforms/IPO.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Compiler.h" +using namespace llvm; + +STATISTIC(NumMerged, "Number of global constants merged"); + +namespace { + struct VISIBILITY_HIDDEN ConstantMerge : public ModulePass { + static char ID; // Pass identification, replacement for typeid + ConstantMerge() : ModulePass((intptr_t)&ID) {} + + // run - For this pass, process all of the globals in the module, + // eliminating duplicate constants. + // + bool runOnModule(Module &M); + }; + + char ConstantMerge::ID = 0; + RegisterPassX("constmerge","Merge Duplicate Global Constants"); +} + +ModulePass *llvm::createConstantMergePass() { return new ConstantMerge(); } + +bool ConstantMerge::runOnModule(Module &M) { + // Map unique constant/section pairs to globals. We don't want to merge + // globals in different sections. + std::map, GlobalVariable*> CMap; + + // Replacements - This vector contains a list of replacements to perform. + std::vector > Replacements; + + bool MadeChange = false; + + // Iterate constant merging while we are still making progress. Merging two + // constants together may allow us to merge other constants together if the + // second level constants have initializers which point to the globals that + // were just merged. + while (1) { + // First pass: identify all globals that can be merged together, filling in + // the Replacements vector. We cannot do the replacement in this pass + // because doing so may cause initializers of other globals to be rewritten, + // invalidating the Constant* pointers in CMap. + // + for (Module::global_iterator GVI = M.global_begin(), E = M.global_end(); + GVI != E; ) { + GlobalVariable *GV = GVI++; + + // If this GV is dead, remove it. 
+      GV->removeDeadConstantUsers();
+      if (GV->use_empty() && GV->hasInternalLinkage()) {
+        GV->eraseFromParent();
+        continue;
+      }
+
+      // Only process constants with initializers.
+      if (GV->isConstant() && GV->hasInitializer()) {
+        Constant *Init = GV->getInitializer();
+
+        // Check to see if the initializer is already known.
+        GlobalVariable *&Slot = CMap[std::make_pair(Init, GV->getSection())];
+
+        if (Slot == 0) {    // Nope, add it to the map.
+          Slot = GV;
+        } else if (GV->hasInternalLinkage()) {    // Yup, this is a duplicate!
+          // Make all uses of the duplicate constant use the canonical version.
+          Replacements.push_back(std::make_pair(GV, Slot));
+        } else if (Slot->hasInternalLinkage()) {  // Yup, this is a duplicate!
+          // Make all uses of the canonical version of the global use the
+          // duplicate, and make the duplicate the new canonical version.
+          Replacements.push_back(std::make_pair(Slot, GV));
+          Slot = GV;
+        }
+      }
+    }
+
+    if (Replacements.empty())
+      return MadeChange;
+    CMap.clear();
+
+    // Now that we have figured out which replacements must be made, do them all
+    // now.  This avoid invalidating the pointers in CMap, which are unneeded
+    // now.
+    for (unsigned i = 0, e = Replacements.size(); i != e; ++i) {
+      // Eliminate any uses of the dead global...
+      Replacements[i].first->replaceAllUsesWith(Replacements[i].second);
+
+      // Delete the global value from the module...
+      M.getGlobalList().erase(Replacements[i].first);
+    }
+
+    NumMerged += Replacements.size();
+    Replacements.clear();
+  }
+}
diff --git a/lib/Transforms/IPO/DeadArgumentElimination.cpp b/lib/Transforms/IPO/DeadArgumentElimination.cpp
new file mode 100644
index 0000000..943ea30
--- /dev/null
+++ b/lib/Transforms/IPO/DeadArgumentElimination.cpp
@@ -0,0 +1,703 @@
+//===-- DeadArgumentElimination.cpp - Eliminate dead arguments ------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file was developed by the LLVM research group and is distributed under
+// the University of Illinois Open Source License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// +// +// This pass deletes dead arguments from internal functions. Dead argument +// elimination removes arguments which are directly dead, as well as arguments +// only passed into function calls as dead arguments of other functions. This +// pass also deletes dead arguments in a similar way. +// +// This pass is often useful as a cleanup pass to run after aggressive +// interprocedural passes, which add possibly-dead arguments. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "deadargelim" +#include "llvm/Transforms/IPO.h" +#include "llvm/CallingConv.h" +#include "llvm/Constant.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/CallSite.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Compiler.h" +#include +using namespace llvm; + +STATISTIC(NumArgumentsEliminated, "Number of unread args removed"); +STATISTIC(NumRetValsEliminated , "Number of unused return values removed"); + +namespace { + /// DAE - The dead argument elimination pass. + /// + class VISIBILITY_HIDDEN DAE : public ModulePass { + /// Liveness enum - During our initial pass over the program, we determine + /// that things are either definately alive, definately dead, or in need of + /// interprocedural analysis (MaybeLive). + /// + enum Liveness { Live, MaybeLive, Dead }; + + /// LiveArguments, MaybeLiveArguments, DeadArguments - These sets contain + /// all of the arguments in the program. The Dead set contains arguments + /// which are completely dead (never used in the function). The MaybeLive + /// set contains arguments which are only passed into other function calls, + /// thus may be live and may be dead. The Live set contains arguments which + /// are known to be alive. 
+ /// + std::set DeadArguments, MaybeLiveArguments, LiveArguments; + + /// DeadRetVal, MaybeLiveRetVal, LifeRetVal - These sets contain all of the + /// functions in the program. The Dead set contains functions whose return + /// value is known to be dead. The MaybeLive set contains functions whose + /// return values are only used by return instructions, and the Live set + /// contains functions whose return values are used, functions that are + /// external, and functions that already return void. + /// + std::set DeadRetVal, MaybeLiveRetVal, LiveRetVal; + + /// InstructionsToInspect - As we mark arguments and return values + /// MaybeLive, we keep track of which instructions could make the values + /// live here. Once the entire program has had the return value and + /// arguments analyzed, this set is scanned to promote the MaybeLive objects + /// to be Live if they really are used. + std::vector InstructionsToInspect; + + /// CallSites - Keep track of the call sites of functions that have + /// MaybeLive arguments or return values. + std::multimap CallSites; + + public: + static char ID; // Pass identification, replacement for typeid + DAE() : ModulePass((intptr_t)&ID) {} + bool runOnModule(Module &M); + + virtual bool ShouldHackArguments() const { return false; } + + private: + Liveness getArgumentLiveness(const Argument &A); + bool isMaybeLiveArgumentNowLive(Argument *Arg); + + bool DeleteDeadVarargs(Function &Fn); + void SurveyFunction(Function &Fn); + + void MarkArgumentLive(Argument *Arg); + void MarkRetValLive(Function *F); + void MarkReturnInstArgumentLive(ReturnInst *RI); + + void RemoveDeadArgumentsFromFunction(Function *F); + }; + char DAE::ID = 0; + RegisterPass X("deadargelim", "Dead Argument Elimination"); + + /// DAH - DeadArgumentHacking pass - Same as dead argument elimination, but + /// deletes arguments to functions which are external. This is only for use + /// by bugpoint. 
+ struct DAH : public DAE { + static char ID; + virtual bool ShouldHackArguments() const { return true; } + }; + char DAH::ID = 0; + RegisterPass Y("deadarghaX0r", + "Dead Argument Hacking (BUGPOINT USE ONLY; DO NOT USE)"); +} + +/// createDeadArgEliminationPass - This pass removes arguments from functions +/// which are not used by the body of the function. +/// +ModulePass *llvm::createDeadArgEliminationPass() { return new DAE(); } +ModulePass *llvm::createDeadArgHackingPass() { return new DAH(); } + +/// DeleteDeadVarargs - If this is an function that takes a ... list, and if +/// llvm.vastart is never called, the varargs list is dead for the function. +bool DAE::DeleteDeadVarargs(Function &Fn) { + assert(Fn.getFunctionType()->isVarArg() && "Function isn't varargs!"); + if (Fn.isDeclaration() || !Fn.hasInternalLinkage()) return false; + + // Ensure that the function is only directly called. + for (Value::use_iterator I = Fn.use_begin(), E = Fn.use_end(); I != E; ++I) { + // If this use is anything other than a call site, give up. + CallSite CS = CallSite::get(*I); + Instruction *TheCall = CS.getInstruction(); + if (!TheCall) return false; // Not a direct call site? + + // The addr of this function is passed to the call. + if (I.getOperandNo() != 0) return false; + } + + // Okay, we know we can transform this function if safe. Scan its body + // looking for calls to llvm.vastart. + for (Function::iterator BB = Fn.begin(), E = Fn.end(); BB != E; ++BB) { + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + if (IntrinsicInst *II = dyn_cast(I)) { + if (II->getIntrinsicID() == Intrinsic::vastart) + return false; + } + } + } + + // If we get here, there are no calls to llvm.vastart in the function body, + // remove the "..." and adjust all the calls. + + // Start by computing a new prototype for the function, which is the same as + // the old function, but has fewer arguments. 
+ const FunctionType *FTy = Fn.getFunctionType(); + std::vector Params(FTy->param_begin(), FTy->param_end()); + FunctionType *NFTy = FunctionType::get(FTy->getReturnType(), Params, false); + unsigned NumArgs = Params.size(); + + // Create the new function body and insert it into the module... + Function *NF = new Function(NFTy, Fn.getLinkage()); + NF->setCallingConv(Fn.getCallingConv()); + Fn.getParent()->getFunctionList().insert(&Fn, NF); + NF->takeName(&Fn); + + // Loop over all of the callers of the function, transforming the call sites + // to pass in a smaller number of arguments into the new function. + // + std::vector Args; + while (!Fn.use_empty()) { + CallSite CS = CallSite::get(Fn.use_back()); + Instruction *Call = CS.getInstruction(); + + // Loop over the operands, dropping extraneous ones at the end of the list. + Args.assign(CS.arg_begin(), CS.arg_begin()+NumArgs); + + Instruction *New; + if (InvokeInst *II = dyn_cast(Call)) { + New = new InvokeInst(NF, II->getNormalDest(), II->getUnwindDest(), + &Args[0], Args.size(), "", Call); + cast(New)->setCallingConv(CS.getCallingConv()); + } else { + New = new CallInst(NF, &Args[0], Args.size(), "", Call); + cast(New)->setCallingConv(CS.getCallingConv()); + if (cast(Call)->isTailCall()) + cast(New)->setTailCall(); + } + Args.clear(); + + if (!Call->use_empty()) + Call->replaceAllUsesWith(Constant::getNullValue(Call->getType())); + + New->takeName(Call); + + // Finally, remove the old call from the program, reducing the use-count of + // F. + Call->getParent()->getInstList().erase(Call); + } + + // Since we have now created the new function, splice the body of the old + // function right into the new function, leaving the old rotting hulk of the + // function empty. + NF->getBasicBlockList().splice(NF->begin(), Fn.getBasicBlockList()); + + // Loop over the argument list, transfering uses of the old arguments over to + // the new arguments, also transfering over the names as well. 
While we're at + // it, remove the dead arguments from the DeadArguments list. + // + for (Function::arg_iterator I = Fn.arg_begin(), E = Fn.arg_end(), + I2 = NF->arg_begin(); I != E; ++I, ++I2) { + // Move the name and users over to the new version. + I->replaceAllUsesWith(I2); + I2->takeName(I); + } + + // Finally, nuke the old function. + Fn.eraseFromParent(); + return true; +} + + +static inline bool CallPassesValueThoughVararg(Instruction *Call, + const Value *Arg) { + CallSite CS = CallSite::get(Call); + const Type *CalledValueTy = CS.getCalledValue()->getType(); + const Type *FTy = cast(CalledValueTy)->getElementType(); + unsigned NumFixedArgs = cast(FTy)->getNumParams(); + for (CallSite::arg_iterator AI = CS.arg_begin()+NumFixedArgs; + AI != CS.arg_end(); ++AI) + if (AI->get() == Arg) + return true; + return false; +} + +// getArgumentLiveness - Inspect an argument, determining if is known Live +// (used in a computation), MaybeLive (only passed as an argument to a call), or +// Dead (not used). +DAE::Liveness DAE::getArgumentLiveness(const Argument &A) { + const FunctionType *FTy = A.getParent()->getFunctionType(); + + // If this is the return value of a struct function, it's not really dead. + if (FTy->isStructReturn() && &*A.getParent()->arg_begin() == &A) + return Live; + + if (A.use_empty()) // First check, directly dead? + return Dead; + + // Scan through all of the uses, looking for non-argument passing uses. + for (Value::use_const_iterator I = A.use_begin(), E = A.use_end(); I!=E;++I) { + // Return instructions do not immediately effect liveness. + if (isa(*I)) + continue; + + CallSite CS = CallSite::get(const_cast(*I)); + if (!CS.getInstruction()) { + // If its used by something that is not a call or invoke, it's alive! + return Live; + } + // If it's an indirect call, mark it alive... + Function *Callee = CS.getCalledFunction(); + if (!Callee) return Live; + + // Check to see if it's passed through a va_arg area: if so, we cannot + // remove it. 
+ if (CallPassesValueThoughVararg(CS.getInstruction(), &A)) + return Live; // If passed through va_arg area, we cannot remove it + } + + return MaybeLive; // It must be used, but only as argument to a function +} + + +// SurveyFunction - This performs the initial survey of the specified function, +// checking out whether or not it uses any of its incoming arguments or whether +// any callers use the return value. This fills in the +// (Dead|MaybeLive|Live)(Arguments|RetVal) sets. +// +// We consider arguments of non-internal functions to be intrinsically alive as +// well as arguments to functions which have their "address taken". +// +void DAE::SurveyFunction(Function &F) { + bool FunctionIntrinsicallyLive = false; + Liveness RetValLiveness = F.getReturnType() == Type::VoidTy ? Live : Dead; + + if (!F.hasInternalLinkage() && + (!ShouldHackArguments() || F.getIntrinsicID())) + FunctionIntrinsicallyLive = true; + else + for (Value::use_iterator I = F.use_begin(), E = F.use_end(); I != E; ++I) { + // If this use is anything other than a call site, the function is alive. + CallSite CS = CallSite::get(*I); + Instruction *TheCall = CS.getInstruction(); + if (!TheCall) { // Not a direct call site? + FunctionIntrinsicallyLive = true; + break; + } + + // Check to see if the return value is used... 
+ if (RetValLiveness != Live) + for (Value::use_iterator I = TheCall->use_begin(), + E = TheCall->use_end(); I != E; ++I) + if (isa(cast(*I))) { + RetValLiveness = MaybeLive; + } else if (isa(cast(*I)) || + isa(cast(*I))) { + if (CallPassesValueThoughVararg(cast(*I), TheCall) || + !CallSite::get(cast(*I)).getCalledFunction()) { + RetValLiveness = Live; + break; + } else { + RetValLiveness = MaybeLive; + } + } else { + RetValLiveness = Live; + break; + } + + // If the function is PASSED IN as an argument, its address has been taken + for (CallSite::arg_iterator AI = CS.arg_begin(), E = CS.arg_end(); + AI != E; ++AI) + if (AI->get() == &F) { + FunctionIntrinsicallyLive = true; + break; + } + if (FunctionIntrinsicallyLive) break; + } + + if (FunctionIntrinsicallyLive) { + DOUT << " Intrinsically live fn: " << F.getName() << "\n"; + for (Function::arg_iterator AI = F.arg_begin(), E = F.arg_end(); + AI != E; ++AI) + LiveArguments.insert(AI); + LiveRetVal.insert(&F); + return; + } + + switch (RetValLiveness) { + case Live: LiveRetVal.insert(&F); break; + case MaybeLive: MaybeLiveRetVal.insert(&F); break; + case Dead: DeadRetVal.insert(&F); break; + } + + DOUT << " Inspecting args for fn: " << F.getName() << "\n"; + + // If it is not intrinsically alive, we know that all users of the + // function are call sites. Mark all of the arguments live which are + // directly used, and keep track of all of the call sites of this function + // if there are any arguments we assume that are dead. 
+ // + bool AnyMaybeLiveArgs = false; + for (Function::arg_iterator AI = F.arg_begin(), E = F.arg_end(); + AI != E; ++AI) + switch (getArgumentLiveness(*AI)) { + case Live: + DOUT << " Arg live by use: " << AI->getName() << "\n"; + LiveArguments.insert(AI); + break; + case Dead: + DOUT << " Arg definitely dead: " << AI->getName() <<"\n"; + DeadArguments.insert(AI); + break; + case MaybeLive: + DOUT << " Arg only passed to calls: " << AI->getName() << "\n"; + AnyMaybeLiveArgs = true; + MaybeLiveArguments.insert(AI); + break; + } + + // If there are any "MaybeLive" arguments, we need to check callees of + // this function when/if they become alive. Record which functions are + // callees... + if (AnyMaybeLiveArgs || RetValLiveness == MaybeLive) + for (Value::use_iterator I = F.use_begin(), E = F.use_end(); + I != E; ++I) { + if (AnyMaybeLiveArgs) + CallSites.insert(std::make_pair(&F, CallSite::get(*I))); + + if (RetValLiveness == MaybeLive) + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); + UI != E; ++UI) + InstructionsToInspect.push_back(cast(*UI)); + } +} + +// isMaybeLiveArgumentNowLive - Check to see if Arg is alive. At this point, we +// know that the only uses of Arg are to be passed in as an argument to a +// function call or return. Check to see if the formal argument passed in is in +// the LiveArguments set. If so, return true. +// +bool DAE::isMaybeLiveArgumentNowLive(Argument *Arg) { + for (Value::use_iterator I = Arg->use_begin(), E = Arg->use_end(); I!=E; ++I){ + if (isa(*I)) { + if (LiveRetVal.count(Arg->getParent())) return true; + continue; + } + + CallSite CS = CallSite::get(*I); + + // We know that this can only be used for direct calls... + Function *Callee = CS.getCalledFunction(); + + // Loop over all of the arguments (because Arg may be passed into the call + // multiple times) and check to see if any are now alive... 
+ CallSite::arg_iterator CSAI = CS.arg_begin(); + for (Function::arg_iterator AI = Callee->arg_begin(), E = Callee->arg_end(); + AI != E; ++AI, ++CSAI) + // If this is the argument we are looking for, check to see if it's alive + if (*CSAI == Arg && LiveArguments.count(AI)) + return true; + } + return false; +} + +/// MarkArgumentLive - The MaybeLive argument 'Arg' is now known to be alive. +/// Mark it live in the specified sets and recursively mark arguments in callers +/// live that are needed to pass in a value. +/// +void DAE::MarkArgumentLive(Argument *Arg) { + std::set::iterator It = MaybeLiveArguments.lower_bound(Arg); + if (It == MaybeLiveArguments.end() || *It != Arg) return; + + DOUT << " MaybeLive argument now live: " << Arg->getName() <<"\n"; + MaybeLiveArguments.erase(It); + LiveArguments.insert(Arg); + + // Loop over all of the call sites of the function, making any arguments + // passed in to provide a value for this argument live as necessary. + // + Function *Fn = Arg->getParent(); + unsigned ArgNo = std::distance(Fn->arg_begin(), Function::arg_iterator(Arg)); + + std::multimap::iterator I = CallSites.lower_bound(Fn); + for (; I != CallSites.end() && I->first == Fn; ++I) { + CallSite CS = I->second; + Value *ArgVal = *(CS.arg_begin()+ArgNo); + if (Argument *ActualArg = dyn_cast(ArgVal)) { + MarkArgumentLive(ActualArg); + } else { + // If the value passed in at this call site is a return value computed by + // some other call site, make sure to mark the return value at the other + // call site as being needed. + CallSite ArgCS = CallSite::get(ArgVal); + if (ArgCS.getInstruction()) + if (Function *Fn = ArgCS.getCalledFunction()) + MarkRetValLive(Fn); + } + } +} + +/// MarkArgumentLive - The MaybeLive return value for the specified function is +/// now known to be alive. Propagate this fact to the return instructions which +/// produce it. 
+void DAE::MarkRetValLive(Function *F) { + assert(F && "Shame shame, we can't have null pointers here!"); + + // Check to see if we already knew it was live + std::set::iterator I = MaybeLiveRetVal.lower_bound(F); + if (I == MaybeLiveRetVal.end() || *I != F) return; // It's already alive! + + DOUT << " MaybeLive retval now live: " << F->getName() << "\n"; + + MaybeLiveRetVal.erase(I); + LiveRetVal.insert(F); // It is now known to be live! + + // Loop over all of the functions, noticing that the return value is now live. + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) + if (ReturnInst *RI = dyn_cast(BB->getTerminator())) + MarkReturnInstArgumentLive(RI); +} + +void DAE::MarkReturnInstArgumentLive(ReturnInst *RI) { + Value *Op = RI->getOperand(0); + if (Argument *A = dyn_cast(Op)) { + MarkArgumentLive(A); + } else if (CallInst *CI = dyn_cast(Op)) { + if (Function *F = CI->getCalledFunction()) + MarkRetValLive(F); + } else if (InvokeInst *II = dyn_cast(Op)) { + if (Function *F = II->getCalledFunction()) + MarkRetValLive(F); + } +} + +// RemoveDeadArgumentsFromFunction - We know that F has dead arguments, as +// specified by the DeadArguments list. Transform the function and all of the +// callees of the function to not have these arguments. +// +void DAE::RemoveDeadArgumentsFromFunction(Function *F) { + // Start by computing a new prototype for the function, which is the same as + // the old function, but has fewer arguments. + const FunctionType *FTy = F->getFunctionType(); + std::vector Params; + + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; ++I) + if (!DeadArguments.count(I)) + Params.push_back(I->getType()); + + const Type *RetTy = FTy->getReturnType(); + if (DeadRetVal.count(F)) { + RetTy = Type::VoidTy; + DeadRetVal.erase(F); + } + + // Work around LLVM bug PR56: the CWriter cannot emit varargs functions which + // have zero fixed arguments. 
+ // + bool ExtraArgHack = false; + if (Params.empty() && FTy->isVarArg()) { + ExtraArgHack = true; + Params.push_back(Type::Int32Ty); + } + + FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg()); + + // Create the new function body and insert it into the module... + Function *NF = new Function(NFTy, F->getLinkage()); + NF->setCallingConv(F->getCallingConv()); + F->getParent()->getFunctionList().insert(F, NF); + NF->takeName(F); + + // Loop over all of the callers of the function, transforming the call sites + // to pass in a smaller number of arguments into the new function. + // + std::vector Args; + while (!F->use_empty()) { + CallSite CS = CallSite::get(F->use_back()); + Instruction *Call = CS.getInstruction(); + + // Loop over the operands, deleting dead ones... + CallSite::arg_iterator AI = CS.arg_begin(); + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); + I != E; ++I, ++AI) + if (!DeadArguments.count(I)) // Remove operands for dead arguments + Args.push_back(*AI); + + if (ExtraArgHack) + Args.push_back(UndefValue::get(Type::Int32Ty)); + + // Push any varargs arguments on the list + for (; AI != CS.arg_end(); ++AI) + Args.push_back(*AI); + + Instruction *New; + if (InvokeInst *II = dyn_cast(Call)) { + New = new InvokeInst(NF, II->getNormalDest(), II->getUnwindDest(), + &Args[0], Args.size(), "", Call); + cast(New)->setCallingConv(CS.getCallingConv()); + } else { + New = new CallInst(NF, &Args[0], Args.size(), "", Call); + cast(New)->setCallingConv(CS.getCallingConv()); + if (cast(Call)->isTailCall()) + cast(New)->setTailCall(); + } + Args.clear(); + + if (!Call->use_empty()) { + if (New->getType() == Type::VoidTy) + Call->replaceAllUsesWith(Constant::getNullValue(Call->getType())); + else { + Call->replaceAllUsesWith(New); + New->takeName(Call); + } + } + + // Finally, remove the old call from the program, reducing the use-count of + // F. 
+ Call->getParent()->getInstList().erase(Call); + } + + // Since we have now created the new function, splice the body of the old + // function right into the new function, leaving the old rotting hulk of the + // function empty. + NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); + + // Loop over the argument list, transfering uses of the old arguments over to + // the new arguments, also transfering over the names as well. While we're at + // it, remove the dead arguments from the DeadArguments list. + // + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(), + I2 = NF->arg_begin(); + I != E; ++I) + if (!DeadArguments.count(I)) { + // If this is a live argument, move the name and users over to the new + // version. + I->replaceAllUsesWith(I2); + I2->takeName(I); + ++I2; + } else { + // If this argument is dead, replace any uses of it with null constants + // (these are guaranteed to only be operands to call instructions which + // will later be simplified). + I->replaceAllUsesWith(Constant::getNullValue(I->getType())); + DeadArguments.erase(I); + } + + // If we change the return value of the function we must rewrite any return + // instructions. Check this now. + if (F->getReturnType() != NF->getReturnType()) + for (Function::iterator BB = NF->begin(), E = NF->end(); BB != E; ++BB) + if (ReturnInst *RI = dyn_cast(BB->getTerminator())) { + new ReturnInst(0, RI); + BB->getInstList().erase(RI); + } + + // Now that the old function is dead, delete it. + F->getParent()->getFunctionList().erase(F); +} + +bool DAE::runOnModule(Module &M) { + // First phase: loop through the module, determining which arguments are live. + // We assume all arguments are dead unless proven otherwise (allowing us to + // determine that dead arguments passed into recursive functions are dead). 
+ // + DOUT << "DAE - Determining liveness\n"; + for (Module::iterator I = M.begin(), E = M.end(); I != E; ) { + Function &F = *I++; + if (F.getFunctionType()->isVarArg()) + if (DeleteDeadVarargs(F)) + continue; + + SurveyFunction(F); + } + + // Loop over the instructions to inspect, propagating liveness among arguments + // and return values which are MaybeLive. + + while (!InstructionsToInspect.empty()) { + Instruction *I = InstructionsToInspect.back(); + InstructionsToInspect.pop_back(); + + if (ReturnInst *RI = dyn_cast(I)) { + // For return instructions, we just have to check to see if the return + // value for the current function is known now to be alive. If so, any + // arguments used by it are now alive, and any call instruction return + // value is alive as well. + if (LiveRetVal.count(RI->getParent()->getParent())) + MarkReturnInstArgumentLive(RI); + + } else { + CallSite CS = CallSite::get(I); + assert(CS.getInstruction() && "Unknown instruction for the I2I list!"); + + Function *Callee = CS.getCalledFunction(); + + // If we found a call or invoke instruction on this list, that means that + // an argument of the function is a call instruction. If the argument is + // live, then the return value of the called instruction is now live. + // + CallSite::arg_iterator AI = CS.arg_begin(); // ActualIterator + for (Function::arg_iterator FI = Callee->arg_begin(), + E = Callee->arg_end(); FI != E; ++AI, ++FI) { + // If this argument is another call... + CallSite ArgCS = CallSite::get(*AI); + if (ArgCS.getInstruction() && LiveArguments.count(FI)) + if (Function *Callee = ArgCS.getCalledFunction()) + MarkRetValLive(Callee); + } + } + } + + // Now we loop over all of the MaybeLive arguments, promoting them to be live + // arguments if one of the calls that uses the arguments to the calls they are + // passed into requires them to be live. Of course this could make other + // arguments live, so process callers recursively. 
+ // + // Because elements can be removed from the MaybeLiveArguments set, copy it to + // a temporary vector. + // + std::vector TmpArgList(MaybeLiveArguments.begin(), + MaybeLiveArguments.end()); + for (unsigned i = 0, e = TmpArgList.size(); i != e; ++i) { + Argument *MLA = TmpArgList[i]; + if (MaybeLiveArguments.count(MLA) && + isMaybeLiveArgumentNowLive(MLA)) + MarkArgumentLive(MLA); + } + + // Recover memory early... + CallSites.clear(); + + // At this point, we know that all arguments in DeadArguments and + // MaybeLiveArguments are dead. If the two sets are empty, there is nothing + // to do. + if (MaybeLiveArguments.empty() && DeadArguments.empty() && + MaybeLiveRetVal.empty() && DeadRetVal.empty()) + return false; + + // Otherwise, compact into one set, and start eliminating the arguments from + // the functions. + DeadArguments.insert(MaybeLiveArguments.begin(), MaybeLiveArguments.end()); + MaybeLiveArguments.clear(); + DeadRetVal.insert(MaybeLiveRetVal.begin(), MaybeLiveRetVal.end()); + MaybeLiveRetVal.clear(); + + LiveArguments.clear(); + LiveRetVal.clear(); + + NumArgumentsEliminated += DeadArguments.size(); + NumRetValsEliminated += DeadRetVal.size(); + while (!DeadArguments.empty()) + RemoveDeadArgumentsFromFunction((*DeadArguments.begin())->getParent()); + + while (!DeadRetVal.empty()) + RemoveDeadArgumentsFromFunction(*DeadRetVal.begin()); + return true; +} diff --git a/lib/Transforms/IPO/DeadTypeElimination.cpp b/lib/Transforms/IPO/DeadTypeElimination.cpp new file mode 100644 index 0000000..87b725a --- /dev/null +++ b/lib/Transforms/IPO/DeadTypeElimination.cpp @@ -0,0 +1,106 @@ +//===- DeadTypeElimination.cpp - Eliminate unused types for symbol table --===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This pass is used to cleanup the output of GCC. It eliminate names for types +// that are unused in the entire translation unit, using the FindUsedTypes pass. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "deadtypeelim" +#include "llvm/Transforms/IPO.h" +#include "llvm/Analysis/FindUsedTypes.h" +#include "llvm/Module.h" +#include "llvm/TypeSymbolTable.h" +#include "llvm/DerivedTypes.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Compiler.h" +using namespace llvm; + +STATISTIC(NumKilled, "Number of unused typenames removed from symtab"); + +namespace { + struct VISIBILITY_HIDDEN DTE : public ModulePass { + static char ID; // Pass identification, replacement for typeid + DTE() : ModulePass((intptr_t)&ID) {} + + // doPassInitialization - For this pass, it removes global symbol table + // entries for primitive types. These are never used for linking in GCC and + // they make the output uglier to look at, so we nuke them. + // + // Also, initialize instance variables. + // + bool runOnModule(Module &M); + + // getAnalysisUsage - This function needs FindUsedTypes to do its job... + // + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + } + }; + char DTE::ID = 0; + RegisterPass X("deadtypeelim", "Dead Type Elimination"); +} + +ModulePass *llvm::createDeadTypeEliminationPass() { + return new DTE(); +} + + +// ShouldNukeSymtabEntry - Return true if this module level symbol table entry +// should be eliminated. +// +static inline bool ShouldNukeSymtabEntry(const Type *Ty){ + // Nuke all names for primitive types! + if (Ty->isPrimitiveType() || Ty->isInteger()) + return true; + + // Nuke all pointers to primitive types as well... 
+ if (const PointerType *PT = dyn_cast(Ty)) + if (PT->getElementType()->isPrimitiveType() || + PT->getElementType()->isInteger()) + return true; + + return false; +} + +// run - For this pass, it removes global symbol table entries for primitive +// types. These are never used for linking in GCC and they make the output +// uglier to look at, so we nuke them. Also eliminate types that are never used +// in the entire program as indicated by FindUsedTypes. +// +bool DTE::runOnModule(Module &M) { + bool Changed = false; + + TypeSymbolTable &ST = M.getTypeSymbolTable(); + std::set UsedTypes = getAnalysis().getTypes(); + + // Check the symbol table for superfluous type entries... + // + // Grab the 'type' plane of the module symbol... + TypeSymbolTable::iterator TI = ST.begin(); + TypeSymbolTable::iterator TE = ST.end(); + while ( TI != TE ) { + // If this entry should be unconditionally removed, or if we detect that + // the type is not used, remove it. + const Type *RHS = TI->second; + if (ShouldNukeSymtabEntry(RHS) || !UsedTypes.count(RHS)) { + ST.remove(TI++); + ++NumKilled; + Changed = true; + } else { + ++TI; + // We only need to leave one name for each type. + UsedTypes.erase(RHS); + } + } + + return Changed; +} + +// vim: sw=2 diff --git a/lib/Transforms/IPO/ExtractFunction.cpp b/lib/Transforms/IPO/ExtractFunction.cpp new file mode 100644 index 0000000..8d6af41 --- /dev/null +++ b/lib/Transforms/IPO/ExtractFunction.cpp @@ -0,0 +1,144 @@ +//===-- ExtractFunction.cpp - Function extraction pass --------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This pass extracts +// +//===----------------------------------------------------------------------===// + +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Support/Compiler.h" +using namespace llvm; + +namespace { + /// @brief A pass to extract specific functions and their dependencies. + class VISIBILITY_HIDDEN FunctionExtractorPass : public ModulePass { + Function *Named; + bool deleteFunc; + bool reLink; + public: + static char ID; // Pass identification, replacement for typeid + + /// FunctionExtractorPass - If deleteFn is true, this pass deletes as the + /// specified function. Otherwise, it deletes as much of the module as + /// possible, except for the function specified. + /// + FunctionExtractorPass(Function *F = 0, bool deleteFn = true, + bool relinkCallees = false) + : ModulePass((intptr_t)&ID), Named(F), deleteFunc(deleteFn), + reLink(relinkCallees) {} + + bool runOnModule(Module &M) { + if (Named == 0) { + Named = M.getFunction("main"); + if (Named == 0) return false; // No function to extract + } + + if (deleteFunc) + return deleteFunction(); + M.setModuleInlineAsm(""); + return isolateFunction(M); + } + + bool deleteFunction() { + // If we're in relinking mode, set linkage of all internal callees to + // external. 
This will allow us extract function, and then - link + // everything together + if (reLink) { + for (Function::iterator B = Named->begin(), BE = Named->end(); + B != BE; ++B) { + for (BasicBlock::iterator I = B->begin(), E = B->end(); + I != E; ++I) { + if (CallInst* callInst = dyn_cast(&*I)) { + Function* Callee = callInst->getCalledFunction(); + if (Callee && Callee->hasInternalLinkage()) + Callee->setLinkage(GlobalValue::ExternalLinkage); + } + } + } + } + + Named->setLinkage(GlobalValue::ExternalLinkage); + Named->deleteBody(); + assert(Named->isDeclaration() && "This didn't make the function external!"); + return true; + } + + bool isolateFunction(Module &M) { + // Make sure our result is globally accessible... + Named->setLinkage(GlobalValue::ExternalLinkage); + + // Mark all global variables internal + for (Module::global_iterator I = M.global_begin(), E = M.global_end(); I != E; ++I) + if (!I->isDeclaration()) { + I->setInitializer(0); // Make all variables external + I->setLinkage(GlobalValue::ExternalLinkage); + } + + // All of the functions may be used by global variables or the named + // function. Loop through them and create a new, external functions that + // can be "used", instead of ones with bodies. + std::vector NewFunctions; + + Function *Last = --M.end(); // Figure out where the last real fn is. + + for (Module::iterator I = M.begin(); ; ++I) { + if (&*I != Named) { + Function *New = new Function(I->getFunctionType(), + GlobalValue::ExternalLinkage); + New->setCallingConv(I->getCallingConv()); + + // If it's not the named function, delete the body of the function + I->dropAllReferences(); + + M.getFunctionList().push_back(New); + NewFunctions.push_back(New); + New->takeName(I); + } + + if (&*I == Last) break; // Stop after processing the last function + } + + // Now that we have replacements all set up, loop through the module, + // deleting the old functions, replacing them with the newly created + // functions. 
+ if (!NewFunctions.empty()) { + unsigned FuncNum = 0; + Module::iterator I = M.begin(); + do { + if (&*I != Named) { + // Make everything that uses the old function use the new dummy fn + I->replaceAllUsesWith(NewFunctions[FuncNum++]); + + Function *Old = I; + ++I; // Move the iterator to the new function + + // Delete the old function! + M.getFunctionList().erase(Old); + + } else { + ++I; // Skip the function we are extracting + } + } while (&*I != NewFunctions[0]); + } + + return true; + } + }; + + char FunctionExtractorPass::ID = 0; + RegisterPass X("extract", "Function Extractor"); +} + +ModulePass *llvm::createFunctionExtractionPass(Function *F, bool deleteFn, + bool relinkCallees) { + return new FunctionExtractorPass(F, deleteFn, relinkCallees); +} diff --git a/lib/Transforms/IPO/GlobalDCE.cpp b/lib/Transforms/IPO/GlobalDCE.cpp new file mode 100644 index 0000000..09cfa21 --- /dev/null +++ b/lib/Transforms/IPO/GlobalDCE.cpp @@ -0,0 +1,203 @@ +//===-- GlobalDCE.cpp - DCE unreachable internal functions ----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This transform is designed to eliminate unreachable internal globals from the +// program. It uses an aggressive algorithm, searching out globals that are +// known to be alive. After it finds all of the globals which are needed, it +// deletes whatever is left over. This allows it to delete recursive chunks of +// the program which are unreachable. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "globaldce" +#include "llvm/Transforms/IPO.h" +#include "llvm/Constants.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Compiler.h" +#include +using namespace llvm; + +STATISTIC(NumFunctions, "Number of functions removed"); +STATISTIC(NumVariables, "Number of global variables removed"); + +namespace { + struct VISIBILITY_HIDDEN GlobalDCE : public ModulePass { + static char ID; // Pass identification, replacement for typeid + GlobalDCE() : ModulePass((intptr_t)&ID) {} + + // run - Do the GlobalDCE pass on the specified module, optionally updating + // the specified callgraph to reflect the changes. + // + bool runOnModule(Module &M); + + private: + std::set AliveGlobals; + + /// MarkGlobalIsNeeded - the specific global value as needed, and + /// recursively mark anything that it uses as also needed. + void GlobalIsNeeded(GlobalValue *GV); + void MarkUsedGlobalsAsNeeded(Constant *C); + + bool SafeToDestroyConstant(Constant* C); + bool RemoveUnusedGlobalValue(GlobalValue &GV); + }; + char GlobalDCE::ID = 0; + RegisterPass X("globaldce", "Dead Global Elimination"); +} + +ModulePass *llvm::createGlobalDCEPass() { return new GlobalDCE(); } + +bool GlobalDCE::runOnModule(Module &M) { + bool Changed = false; + // Loop over the module, adding globals which are obviously necessary. + for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) { + Changed |= RemoveUnusedGlobalValue(*I); + // Functions with external linkage are needed if they have a body + if ((!I->hasInternalLinkage() && !I->hasLinkOnceLinkage()) && + !I->isDeclaration()) + GlobalIsNeeded(I); + } + + for (Module::global_iterator I = M.global_begin(), E = M.global_end(); + I != E; ++I) { + Changed |= RemoveUnusedGlobalValue(*I); + // Externally visible & appending globals are needed, if they have an + // initializer. 
+ if ((!I->hasInternalLinkage() && !I->hasLinkOnceLinkage()) && + !I->isDeclaration()) + GlobalIsNeeded(I); + } + + + for (Module::alias_iterator I = M.alias_begin(), E = M.alias_end(); + I != E; ++I) { + // Aliases are always needed even if they are not used. + MarkUsedGlobalsAsNeeded(I->getAliasee()); + } + + // Now that all globals which are needed are in the AliveGlobals set, we loop + // through the program, deleting those which are not alive. + // + + // The first pass is to drop initializers of global variables which are dead. + std::vector DeadGlobalVars; // Keep track of dead globals + for (Module::global_iterator I = M.global_begin(), E = M.global_end(); I != E; ++I) + if (!AliveGlobals.count(I)) { + DeadGlobalVars.push_back(I); // Keep track of dead globals + I->setInitializer(0); + } + + + // The second pass drops the bodies of functions which are dead... + std::vector DeadFunctions; + for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) + if (!AliveGlobals.count(I)) { + DeadFunctions.push_back(I); // Keep track of dead globals + if (!I->isDeclaration()) + I->deleteBody(); + } + + if (!DeadFunctions.empty()) { + // Now that all interreferences have been dropped, delete the actual objects + // themselves. + for (unsigned i = 0, e = DeadFunctions.size(); i != e; ++i) { + RemoveUnusedGlobalValue(*DeadFunctions[i]); + M.getFunctionList().erase(DeadFunctions[i]); + } + NumFunctions += DeadFunctions.size(); + Changed = true; + } + + if (!DeadGlobalVars.empty()) { + for (unsigned i = 0, e = DeadGlobalVars.size(); i != e; ++i) { + RemoveUnusedGlobalValue(*DeadGlobalVars[i]); + M.getGlobalList().erase(DeadGlobalVars[i]); + } + NumVariables += DeadGlobalVars.size(); + Changed = true; + } + + // Make sure that all memory is released + AliveGlobals.clear(); + return Changed; +} + +/// MarkGlobalIsNeeded - the specific global value as needed, and +/// recursively mark anything that it uses as also needed. 
+void GlobalDCE::GlobalIsNeeded(GlobalValue *G) { + std::set::iterator I = AliveGlobals.lower_bound(G); + + // If the global is already in the set, no need to reprocess it. + if (I != AliveGlobals.end() && *I == G) return; + + // Otherwise insert it now, so we do not infinitely recurse + AliveGlobals.insert(I, G); + + if (GlobalVariable *GV = dyn_cast(G)) { + // If this is a global variable, we must make sure to add any global values + // referenced by the initializer to the alive set. + if (GV->hasInitializer()) + MarkUsedGlobalsAsNeeded(GV->getInitializer()); + } else if (!isa(G)) { + // Otherwise this must be a function object. We have to scan the body of + // the function looking for constants and global values which are used as + // operands. Any operands of these types must be processed to ensure that + // any globals used will be marked as needed. + Function *F = cast(G); + // For all basic blocks... + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) + // For all instructions... + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) + // For all operands... + for (User::op_iterator U = I->op_begin(), E = I->op_end(); U != E; ++U) + if (GlobalValue *GV = dyn_cast(*U)) + GlobalIsNeeded(GV); + else if (Constant *C = dyn_cast(*U)) + MarkUsedGlobalsAsNeeded(C); + } +} + +void GlobalDCE::MarkUsedGlobalsAsNeeded(Constant *C) { + if (GlobalValue *GV = dyn_cast(C)) + GlobalIsNeeded(GV); + else { + // Loop over all of the operands of the constant, adding any globals they + // use to the list of needed globals. + for (User::op_iterator I = C->op_begin(), E = C->op_end(); I != E; ++I) + MarkUsedGlobalsAsNeeded(cast(*I)); + } +} + +// RemoveUnusedGlobalValue - Loop over all of the uses of the specified +// GlobalValue, looking for the constant pointer ref that may be pointing to it. +// If found, check to see if the constant pointer ref is safe to destroy, and if +// so, nuke it. 
This will reduce the reference count on the global value, which +// might make it deader. +// +bool GlobalDCE::RemoveUnusedGlobalValue(GlobalValue &GV) { + if (GV.use_empty()) return false; + GV.removeDeadConstantUsers(); + return GV.use_empty(); +} + +// SafeToDestroyConstant - It is safe to destroy a constant iff it is only used +// by constants itself. Note that constants cannot be cyclic, so this test is +// pretty easy to implement recursively. +// +bool GlobalDCE::SafeToDestroyConstant(Constant *C) { + for (Value::use_iterator I = C->use_begin(), E = C->use_end(); I != E; ++I) + if (Constant *User = dyn_cast(*I)) { + if (!SafeToDestroyConstant(User)) return false; + } else { + return false; + } + return true; +} diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp new file mode 100644 index 0000000..520af87 --- /dev/null +++ b/lib/Transforms/IPO/GlobalOpt.cpp @@ -0,0 +1,1988 @@ +//===- GlobalOpt.cpp - Optimize Global Variables --------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass transforms simple global variables that never have their address +// taken. If obviously true, it marks read/write globals as constant, deletes +// variables only stored to, etc. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "globalopt" +#include "llvm/Transforms/IPO.h" +#include "llvm/CallingConv.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" +#include +#include +using namespace llvm; + +STATISTIC(NumMarked , "Number of globals marked constant"); +STATISTIC(NumSRA , "Number of aggregate globals broken into scalars"); +STATISTIC(NumHeapSRA , "Number of heap objects SRA'd"); +STATISTIC(NumSubstitute,"Number of globals with initializers stored into them"); +STATISTIC(NumDeleted , "Number of globals deleted"); +STATISTIC(NumFnDeleted , "Number of functions deleted"); +STATISTIC(NumGlobUses , "Number of global uses devirtualized"); +STATISTIC(NumLocalized , "Number of globals localized"); +STATISTIC(NumShrunkToBool , "Number of global vars shrunk to booleans"); +STATISTIC(NumFastCallFns , "Number of functions converted to fastcc"); +STATISTIC(NumCtorsEvaluated, "Number of static ctors evaluated"); + +namespace { + struct VISIBILITY_HIDDEN GlobalOpt : public ModulePass { + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + } + static char ID; // Pass identification, replacement for typeid + GlobalOpt() : ModulePass((intptr_t)&ID) {} + + bool runOnModule(Module &M); + + private: + GlobalVariable *FindGlobalCtors(Module &M); + bool OptimizeFunctions(Module &M); + bool OptimizeGlobalVars(Module &M); + bool OptimizeGlobalCtorsList(GlobalVariable *&GCL); + bool ProcessInternalGlobal(GlobalVariable *GV,Module::global_iterator &GVI); + }; + + char GlobalOpt::ID = 0; + RegisterPass 
X("globalopt", "Global Variable Optimizer"); +} + +ModulePass *llvm::createGlobalOptimizerPass() { return new GlobalOpt(); } + +/// GlobalStatus - As we analyze each global, keep track of some information +/// about it. If we find out that the address of the global is taken, none of +/// this info will be accurate. +struct VISIBILITY_HIDDEN GlobalStatus { + /// isLoaded - True if the global is ever loaded. If the global isn't ever + /// loaded it can be deleted. + bool isLoaded; + + /// StoredType - Keep track of what stores to the global look like. + /// + enum StoredType { + /// NotStored - There is no store to this global. It can thus be marked + /// constant. + NotStored, + + /// isInitializerStored - This global is stored to, but the only thing + /// stored is the constant it was initialized with. This is only tracked + /// for scalar globals. + isInitializerStored, + + /// isStoredOnce - This global is stored to, but only its initializer and + /// one other value is ever stored to it. If this global isStoredOnce, we + /// track the value stored to it in StoredOnceValue below. This is only + /// tracked for scalar globals. + isStoredOnce, + + /// isStored - This global is stored to by multiple values or something else + /// that we cannot track. + isStored + } StoredType; + + /// StoredOnceValue - If only one value (besides the initializer constant) is + /// ever stored to this global, keep track of what value it is. + Value *StoredOnceValue; + + /// AccessingFunction/HasMultipleAccessingFunctions - These start out + /// null/false. When the first accessing function is noticed, it is recorded. + /// When a second different accessing function is noticed, + /// HasMultipleAccessingFunctions is set to true. + Function *AccessingFunction; + bool HasMultipleAccessingFunctions; + + /// HasNonInstructionUser - Set to true if this global has a user that is not + /// an instruction (e.g. a constant expr or GV initializer). 
+ bool HasNonInstructionUser; + + /// HasPHIUser - Set to true if this global has a user that is a PHI node. + bool HasPHIUser; + + /// isNotSuitableForSRA - Keep track of whether any SRA preventing users of + /// the global exist. Such users include GEP instruction with variable + /// indexes, and non-gep/load/store users like constant expr casts. + bool isNotSuitableForSRA; + + GlobalStatus() : isLoaded(false), StoredType(NotStored), StoredOnceValue(0), + AccessingFunction(0), HasMultipleAccessingFunctions(false), + HasNonInstructionUser(false), HasPHIUser(false), + isNotSuitableForSRA(false) {} +}; + + + +/// ConstantIsDead - Return true if the specified constant is (transitively) +/// dead. The constant may be used by other constants (e.g. constant arrays and +/// constant exprs) as long as they are dead, but it cannot be used by anything +/// else. +static bool ConstantIsDead(Constant *C) { + if (isa(C)) return false; + + for (Value::use_iterator UI = C->use_begin(), E = C->use_end(); UI != E; ++UI) + if (Constant *CU = dyn_cast(*UI)) { + if (!ConstantIsDead(CU)) return false; + } else + return false; + return true; +} + + +/// AnalyzeGlobal - Look at all uses of the global and fill in the GlobalStatus +/// structure. If the global has its address taken, return true to indicate we +/// can't do anything with it. +/// +static bool AnalyzeGlobal(Value *V, GlobalStatus &GS, + std::set &PHIUsers) { + for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E; ++UI) + if (ConstantExpr *CE = dyn_cast(*UI)) { + GS.HasNonInstructionUser = true; + + if (AnalyzeGlobal(CE, GS, PHIUsers)) return true; + if (CE->getOpcode() != Instruction::GetElementPtr) + GS.isNotSuitableForSRA = true; + else if (!GS.isNotSuitableForSRA) { + // Check to see if this ConstantExpr GEP is SRA'able. In particular, we + // don't like < 3 operand CE's, and we don't like non-constant integer + // indices. 
+ if (CE->getNumOperands() < 3 || !CE->getOperand(1)->isNullValue()) + GS.isNotSuitableForSRA = true; + else { + for (unsigned i = 1, e = CE->getNumOperands(); i != e; ++i) + if (!isa(CE->getOperand(i))) { + GS.isNotSuitableForSRA = true; + break; + } + } + } + + } else if (Instruction *I = dyn_cast(*UI)) { + if (!GS.HasMultipleAccessingFunctions) { + Function *F = I->getParent()->getParent(); + if (GS.AccessingFunction == 0) + GS.AccessingFunction = F; + else if (GS.AccessingFunction != F) + GS.HasMultipleAccessingFunctions = true; + } + if (isa(I)) { + GS.isLoaded = true; + } else if (StoreInst *SI = dyn_cast(I)) { + // Don't allow a store OF the address, only stores TO the address. + if (SI->getOperand(0) == V) return true; + + // If this is a direct store to the global (i.e., the global is a scalar + // value, not an aggregate), keep more specific information about + // stores. + if (GS.StoredType != GlobalStatus::isStored) + if (GlobalVariable *GV = dyn_cast(SI->getOperand(1))){ + Value *StoredVal = SI->getOperand(0); + if (StoredVal == GV->getInitializer()) { + if (GS.StoredType < GlobalStatus::isInitializerStored) + GS.StoredType = GlobalStatus::isInitializerStored; + } else if (isa(StoredVal) && + cast(StoredVal)->getOperand(0) == GV) { + // G = G + if (GS.StoredType < GlobalStatus::isInitializerStored) + GS.StoredType = GlobalStatus::isInitializerStored; + } else if (GS.StoredType < GlobalStatus::isStoredOnce) { + GS.StoredType = GlobalStatus::isStoredOnce; + GS.StoredOnceValue = StoredVal; + } else if (GS.StoredType == GlobalStatus::isStoredOnce && + GS.StoredOnceValue == StoredVal) { + // noop. + } else { + GS.StoredType = GlobalStatus::isStored; + } + } else { + GS.StoredType = GlobalStatus::isStored; + } + } else if (isa(I)) { + if (AnalyzeGlobal(I, GS, PHIUsers)) return true; + + // If the first two indices are constants, this can be SRA'd. 
+ if (isa(I->getOperand(0))) { + if (I->getNumOperands() < 3 || !isa(I->getOperand(1)) || + !cast(I->getOperand(1))->isNullValue() || + !isa(I->getOperand(2))) + GS.isNotSuitableForSRA = true; + } else if (ConstantExpr *CE = dyn_cast(I->getOperand(0))){ + if (CE->getOpcode() != Instruction::GetElementPtr || + CE->getNumOperands() < 3 || I->getNumOperands() < 2 || + !isa(I->getOperand(0)) || + !cast(I->getOperand(0))->isNullValue()) + GS.isNotSuitableForSRA = true; + } else { + GS.isNotSuitableForSRA = true; + } + } else if (isa(I)) { + if (AnalyzeGlobal(I, GS, PHIUsers)) return true; + GS.isNotSuitableForSRA = true; + } else if (PHINode *PN = dyn_cast(I)) { + // PHI nodes we can check just like select or GEP instructions, but we + // have to be careful about infinite recursion. + if (PHIUsers.insert(PN).second) // Not already visited. + if (AnalyzeGlobal(I, GS, PHIUsers)) return true; + GS.isNotSuitableForSRA = true; + GS.HasPHIUser = true; + } else if (isa(I)) { + GS.isNotSuitableForSRA = true; + } else if (isa(I) || isa(I)) { + if (I->getOperand(1) == V) + GS.StoredType = GlobalStatus::isStored; + if (I->getOperand(2) == V) + GS.isLoaded = true; + GS.isNotSuitableForSRA = true; + } else if (isa(I)) { + assert(I->getOperand(1) == V && "Memset only takes one pointer!"); + GS.StoredType = GlobalStatus::isStored; + GS.isNotSuitableForSRA = true; + } else { + return true; // Any other non-load instruction might take address! + } + } else if (Constant *C = dyn_cast(*UI)) { + GS.HasNonInstructionUser = true; + // We might have a dead and dangling constant hanging off of here. + if (!ConstantIsDead(C)) + return true; + } else { + GS.HasNonInstructionUser = true; + // Otherwise must be some other user. 
+ return true; + } + + return false; +} + +static Constant *getAggregateConstantElement(Constant *Agg, Constant *Idx) { + ConstantInt *CI = dyn_cast(Idx); + if (!CI) return 0; + unsigned IdxV = CI->getZExtValue(); + + if (ConstantStruct *CS = dyn_cast(Agg)) { + if (IdxV < CS->getNumOperands()) return CS->getOperand(IdxV); + } else if (ConstantArray *CA = dyn_cast(Agg)) { + if (IdxV < CA->getNumOperands()) return CA->getOperand(IdxV); + } else if (ConstantVector *CP = dyn_cast(Agg)) { + if (IdxV < CP->getNumOperands()) return CP->getOperand(IdxV); + } else if (isa(Agg)) { + if (const StructType *STy = dyn_cast(Agg->getType())) { + if (IdxV < STy->getNumElements()) + return Constant::getNullValue(STy->getElementType(IdxV)); + } else if (const SequentialType *STy = + dyn_cast(Agg->getType())) { + return Constant::getNullValue(STy->getElementType()); + } + } else if (isa(Agg)) { + if (const StructType *STy = dyn_cast(Agg->getType())) { + if (IdxV < STy->getNumElements()) + return UndefValue::get(STy->getElementType(IdxV)); + } else if (const SequentialType *STy = + dyn_cast(Agg->getType())) { + return UndefValue::get(STy->getElementType()); + } + } + return 0; +} + + +/// CleanupConstantGlobalUsers - We just marked GV constant. Loop over all +/// users of the global, cleaning up the obvious ones. This is largely just a +/// quick scan over the use list to clean up the easy and obvious cruft. This +/// returns true if it made a change. +static bool CleanupConstantGlobalUsers(Value *V, Constant *Init) { + bool Changed = false; + for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E;) { + User *U = *UI++; + + if (LoadInst *LI = dyn_cast(U)) { + if (Init) { + // Replace the load with the initializer. + LI->replaceAllUsesWith(Init); + LI->eraseFromParent(); + Changed = true; + } + } else if (StoreInst *SI = dyn_cast(U)) { + // Store must be unreachable or storing Init into the global. 
+ SI->eraseFromParent(); + Changed = true; + } else if (ConstantExpr *CE = dyn_cast(U)) { + if (CE->getOpcode() == Instruction::GetElementPtr) { + Constant *SubInit = 0; + if (Init) + SubInit = ConstantFoldLoadThroughGEPConstantExpr(Init, CE); + Changed |= CleanupConstantGlobalUsers(CE, SubInit); + } else if (CE->getOpcode() == Instruction::BitCast && + isa(CE->getType())) { + // Pointer cast, delete any stores and memsets to the global. + Changed |= CleanupConstantGlobalUsers(CE, 0); + } + + if (CE->use_empty()) { + CE->destroyConstant(); + Changed = true; + } + } else if (GetElementPtrInst *GEP = dyn_cast(U)) { + Constant *SubInit = 0; + ConstantExpr *CE = + dyn_cast_or_null(ConstantFoldInstruction(GEP)); + if (Init && CE && CE->getOpcode() == Instruction::GetElementPtr) + SubInit = ConstantFoldLoadThroughGEPConstantExpr(Init, CE); + Changed |= CleanupConstantGlobalUsers(GEP, SubInit); + + if (GEP->use_empty()) { + GEP->eraseFromParent(); + Changed = true; + } + } else if (MemIntrinsic *MI = dyn_cast(U)) { // memset/cpy/mv + if (MI->getRawDest() == V) { + MI->eraseFromParent(); + Changed = true; + } + + } else if (Constant *C = dyn_cast(U)) { + // If we have a chain of dead constantexprs or other things dangling from + // us, and if they are all dead, nuke them without remorse. + if (ConstantIsDead(C)) { + C->destroyConstant(); + // This could have invalidated UI, start over from scratch. + CleanupConstantGlobalUsers(V, Init); + return true; + } + } + } + return Changed; +} + +/// SRAGlobal - Perform scalar replacement of aggregates on the specified global +/// variable. This opens the door for other optimizations by exposing the +/// behavior of the program in a more fine-grained way. We have determined that +/// this transformation is safe already. We return the first global variable we +/// insert so that the caller can reprocess it. 
+static GlobalVariable *SRAGlobal(GlobalVariable *GV) { + assert(GV->hasInternalLinkage() && !GV->isConstant()); + Constant *Init = GV->getInitializer(); + const Type *Ty = Init->getType(); + + std::vector NewGlobals; + Module::GlobalListType &Globals = GV->getParent()->getGlobalList(); + + if (const StructType *STy = dyn_cast(Ty)) { + NewGlobals.reserve(STy->getNumElements()); + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + Constant *In = getAggregateConstantElement(Init, + ConstantInt::get(Type::Int32Ty, i)); + assert(In && "Couldn't get element of initializer?"); + GlobalVariable *NGV = new GlobalVariable(STy->getElementType(i), false, + GlobalVariable::InternalLinkage, + In, GV->getName()+"."+utostr(i), + (Module *)NULL, + GV->isThreadLocal()); + Globals.insert(GV, NGV); + NewGlobals.push_back(NGV); + } + } else if (const SequentialType *STy = dyn_cast(Ty)) { + unsigned NumElements = 0; + if (const ArrayType *ATy = dyn_cast(STy)) + NumElements = ATy->getNumElements(); + else if (const VectorType *PTy = dyn_cast(STy)) + NumElements = PTy->getNumElements(); + else + assert(0 && "Unknown aggregate sequential type!"); + + if (NumElements > 16 && GV->hasNUsesOrMore(16)) + return 0; // It's not worth it. 
+ NewGlobals.reserve(NumElements); + for (unsigned i = 0, e = NumElements; i != e; ++i) { + Constant *In = getAggregateConstantElement(Init, + ConstantInt::get(Type::Int32Ty, i)); + assert(In && "Couldn't get element of initializer?"); + + GlobalVariable *NGV = new GlobalVariable(STy->getElementType(), false, + GlobalVariable::InternalLinkage, + In, GV->getName()+"."+utostr(i), + (Module *)NULL, + GV->isThreadLocal()); + Globals.insert(GV, NGV); + NewGlobals.push_back(NGV); + } + } + + if (NewGlobals.empty()) + return 0; + + DOUT << "PERFORMING GLOBAL SRA ON: " << *GV; + + Constant *NullInt = Constant::getNullValue(Type::Int32Ty); + + // Loop over all of the uses of the global, replacing the constantexpr geps, + // with smaller constantexpr geps or direct references. + while (!GV->use_empty()) { + User *GEP = GV->use_back(); + assert(((isa(GEP) && + cast(GEP)->getOpcode()==Instruction::GetElementPtr)|| + isa(GEP)) && "NonGEP CE's are not SRAable!"); + + // Ignore the 1th operand, which has to be zero or else the program is quite + // broken (undefined). Get the 2nd operand, which is the structure or array + // index. + unsigned Val = cast(GEP->getOperand(2))->getZExtValue(); + if (Val >= NewGlobals.size()) Val = 0; // Out of bound array access. + + Value *NewPtr = NewGlobals[Val]; + + // Form a shorter GEP if needed. 
+ if (GEP->getNumOperands() > 3) + if (ConstantExpr *CE = dyn_cast(GEP)) { + SmallVector Idxs; + Idxs.push_back(NullInt); + for (unsigned i = 3, e = CE->getNumOperands(); i != e; ++i) + Idxs.push_back(CE->getOperand(i)); + NewPtr = ConstantExpr::getGetElementPtr(cast(NewPtr), + &Idxs[0], Idxs.size()); + } else { + GetElementPtrInst *GEPI = cast(GEP); + SmallVector Idxs; + Idxs.push_back(NullInt); + for (unsigned i = 3, e = GEPI->getNumOperands(); i != e; ++i) + Idxs.push_back(GEPI->getOperand(i)); + NewPtr = new GetElementPtrInst(NewPtr, &Idxs[0], Idxs.size(), + GEPI->getName()+"."+utostr(Val), GEPI); + } + GEP->replaceAllUsesWith(NewPtr); + + if (GetElementPtrInst *GEPI = dyn_cast(GEP)) + GEPI->eraseFromParent(); + else + cast(GEP)->destroyConstant(); + } + + // Delete the old global, now that it is dead. + Globals.erase(GV); + ++NumSRA; + + // Loop over the new globals array deleting any globals that are obviously + // dead. This can arise due to scalarization of a structure or an array that + // has elements that are dead. + unsigned FirstGlobal = 0; + for (unsigned i = 0, e = NewGlobals.size(); i != e; ++i) + if (NewGlobals[i]->use_empty()) { + Globals.erase(NewGlobals[i]); + if (FirstGlobal == i) ++FirstGlobal; + } + + return FirstGlobal != NewGlobals.size() ? NewGlobals[FirstGlobal] : 0; +} + +/// AllUsesOfValueWillTrapIfNull - Return true if all users of the specified +/// value will trap if the value is dynamically null. +static bool AllUsesOfValueWillTrapIfNull(Value *V) { + for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E; ++UI) + if (isa(*UI)) { + // Will trap. + } else if (StoreInst *SI = dyn_cast(*UI)) { + if (SI->getOperand(0) == V) { + //cerr << "NONTRAPPING USE: " << **UI; + return false; // Storing the value. 
+ } + } else if (CallInst *CI = dyn_cast(*UI)) { + if (CI->getOperand(0) != V) { + //cerr << "NONTRAPPING USE: " << **UI; + return false; // Not calling the ptr + } + } else if (InvokeInst *II = dyn_cast(*UI)) { + if (II->getOperand(0) != V) { + //cerr << "NONTRAPPING USE: " << **UI; + return false; // Not calling the ptr + } + } else if (CastInst *CI = dyn_cast(*UI)) { + if (!AllUsesOfValueWillTrapIfNull(CI)) return false; + } else if (GetElementPtrInst *GEPI = dyn_cast(*UI)) { + if (!AllUsesOfValueWillTrapIfNull(GEPI)) return false; + } else if (isa(*UI) && + isa(UI->getOperand(1))) { + // Ignore setcc X, null + } else { + //cerr << "NONTRAPPING USE: " << **UI; + return false; + } + return true; +} + +/// AllUsesOfLoadedValueWillTrapIfNull - Return true if all uses of any loads +/// from GV will trap if the loaded value is null. Note that this also permits +/// comparisons of the loaded value against null, as a special case. +static bool AllUsesOfLoadedValueWillTrapIfNull(GlobalVariable *GV) { + for (Value::use_iterator UI = GV->use_begin(), E = GV->use_end(); UI!=E; ++UI) + if (LoadInst *LI = dyn_cast(*UI)) { + if (!AllUsesOfValueWillTrapIfNull(LI)) + return false; + } else if (isa(*UI)) { + // Ignore stores to the global. + } else { + // We don't know or understand this user, bail out. + //cerr << "UNKNOWN USER OF GLOBAL!: " << **UI; + return false; + } + + return true; +} + +static bool OptimizeAwayTrappingUsesOfValue(Value *V, Constant *NewV) { + bool Changed = false; + for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E; ) { + Instruction *I = cast(*UI++); + if (LoadInst *LI = dyn_cast(I)) { + LI->setOperand(0, NewV); + Changed = true; + } else if (StoreInst *SI = dyn_cast(I)) { + if (SI->getOperand(1) == V) { + SI->setOperand(1, NewV); + Changed = true; + } + } else if (isa(I) || isa(I)) { + if (I->getOperand(0) == V) { + // Calling through the pointer! 
Turn into a direct call, but be careful + // that the pointer is not also being passed as an argument. + I->setOperand(0, NewV); + Changed = true; + bool PassedAsArg = false; + for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i) + if (I->getOperand(i) == V) { + PassedAsArg = true; + I->setOperand(i, NewV); + } + + if (PassedAsArg) { + // Being passed as an argument also. Be careful to not invalidate UI! + UI = V->use_begin(); + } + } + } else if (CastInst *CI = dyn_cast(I)) { + Changed |= OptimizeAwayTrappingUsesOfValue(CI, + ConstantExpr::getCast(CI->getOpcode(), + NewV, CI->getType())); + if (CI->use_empty()) { + Changed = true; + CI->eraseFromParent(); + } + } else if (GetElementPtrInst *GEPI = dyn_cast(I)) { + // Should handle GEP here. + SmallVector Idxs; + Idxs.reserve(GEPI->getNumOperands()-1); + for (unsigned i = 1, e = GEPI->getNumOperands(); i != e; ++i) + if (Constant *C = dyn_cast(GEPI->getOperand(i))) + Idxs.push_back(C); + else + break; + if (Idxs.size() == GEPI->getNumOperands()-1) + Changed |= OptimizeAwayTrappingUsesOfValue(GEPI, + ConstantExpr::getGetElementPtr(NewV, &Idxs[0], + Idxs.size())); + if (GEPI->use_empty()) { + Changed = true; + GEPI->eraseFromParent(); + } + } + } + + return Changed; +} + + +/// OptimizeAwayTrappingUsesOfLoads - The specified global has only one non-null +/// value stored into it. If there are uses of the loaded value that would trap +/// if the loaded value is dynamically null, then we know that they cannot be +/// reachable with a null optimize away the load. +static bool OptimizeAwayTrappingUsesOfLoads(GlobalVariable *GV, Constant *LV) { + std::vector Loads; + bool Changed = false; + + // Replace all uses of loads with uses of uses of the stored value. 
+ for (Value::use_iterator GUI = GV->use_begin(), E = GV->use_end(); + GUI != E; ++GUI) + if (LoadInst *LI = dyn_cast(*GUI)) { + Loads.push_back(LI); + Changed |= OptimizeAwayTrappingUsesOfValue(LI, LV); + } else { + // If we get here we could have stores, selects, or phi nodes whose values + // are loaded. + assert((isa(*GUI) || isa(*GUI) || + isa(*GUI)) && + "Only expect load and stores!"); + } + + if (Changed) { + DOUT << "OPTIMIZED LOADS FROM STORED ONCE POINTER: " << *GV; + ++NumGlobUses; + } + + // Delete all of the loads we can, keeping track of whether we nuked them all! + bool AllLoadsGone = true; + while (!Loads.empty()) { + LoadInst *L = Loads.back(); + if (L->use_empty()) { + L->eraseFromParent(); + Changed = true; + } else { + AllLoadsGone = false; + } + Loads.pop_back(); + } + + // If we nuked all of the loads, then none of the stores are needed either, + // nor is the global. + if (AllLoadsGone) { + DOUT << " *** GLOBAL NOW DEAD!\n"; + CleanupConstantGlobalUsers(GV, 0); + if (GV->use_empty()) { + GV->eraseFromParent(); + ++NumDeleted; + } + Changed = true; + } + return Changed; +} + +/// ConstantPropUsersOf - Walk the use list of V, constant folding all of the +/// instructions that are foldable. +static void ConstantPropUsersOf(Value *V) { + for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E; ) + if (Instruction *I = dyn_cast(*UI++)) + if (Constant *NewC = ConstantFoldInstruction(I)) { + I->replaceAllUsesWith(NewC); + + // Advance UI to the next non-I use to avoid invalidating it! + // Instructions could multiply use V. + while (UI != E && *UI == I) + ++UI; + I->eraseFromParent(); + } +} + +/// OptimizeGlobalAddressOfMalloc - This function takes the specified global +/// variable, and transforms the program as if it always contained the result of +/// the specified malloc. Because it is always the result of the specified +/// malloc, there is no reason to actually DO the malloc. 
Instead, turn the +/// malloc into a global, and any loads of GV as uses of the new global. +static GlobalVariable *OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, + MallocInst *MI) { + DOUT << "PROMOTING MALLOC GLOBAL: " << *GV << " MALLOC = " << *MI; + ConstantInt *NElements = cast(MI->getArraySize()); + + if (NElements->getZExtValue() != 1) { + // If we have an array allocation, transform it to a single element + // allocation to make the code below simpler. + Type *NewTy = ArrayType::get(MI->getAllocatedType(), + NElements->getZExtValue()); + MallocInst *NewMI = + new MallocInst(NewTy, Constant::getNullValue(Type::Int32Ty), + MI->getAlignment(), MI->getName(), MI); + Value* Indices[2]; + Indices[0] = Indices[1] = Constant::getNullValue(Type::Int32Ty); + Value *NewGEP = new GetElementPtrInst(NewMI, Indices, 2, + NewMI->getName()+".el0", MI); + MI->replaceAllUsesWith(NewGEP); + MI->eraseFromParent(); + MI = NewMI; + } + + // Create the new global variable. The contents of the malloc'd memory is + // undefined, so initialize with an undef value. + Constant *Init = UndefValue::get(MI->getAllocatedType()); + GlobalVariable *NewGV = new GlobalVariable(MI->getAllocatedType(), false, + GlobalValue::InternalLinkage, Init, + GV->getName()+".body", + (Module *)NULL, + GV->isThreadLocal()); + GV->getParent()->getGlobalList().insert(GV, NewGV); + + // Anything that used the malloc now uses the global directly. + MI->replaceAllUsesWith(NewGV); + + Constant *RepValue = NewGV; + if (NewGV->getType() != GV->getType()->getElementType()) + RepValue = ConstantExpr::getBitCast(RepValue, + GV->getType()->getElementType()); + + // If there is a comparison against null, we will insert a global bool to + // keep track of whether the global was initialized yet or not. 
+ GlobalVariable *InitBool = + new GlobalVariable(Type::Int1Ty, false, GlobalValue::InternalLinkage, + ConstantInt::getFalse(), GV->getName()+".init", + (Module *)NULL, GV->isThreadLocal()); + bool InitBoolUsed = false; + + // Loop over all uses of GV, processing them in turn. + std::vector Stores; + while (!GV->use_empty()) + if (LoadInst *LI = dyn_cast(GV->use_back())) { + while (!LI->use_empty()) { + Use &LoadUse = LI->use_begin().getUse(); + if (!isa(LoadUse.getUser())) + LoadUse = RepValue; + else { + ICmpInst *CI = cast(LoadUse.getUser()); + // Replace the cmp X, 0 with a use of the bool value. + Value *LV = new LoadInst(InitBool, InitBool->getName()+".val", CI); + InitBoolUsed = true; + switch (CI->getPredicate()) { + default: assert(0 && "Unknown ICmp Predicate!"); + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: + LV = ConstantInt::getFalse(); // X < null -> always false + break; + case ICmpInst::ICMP_ULE: + case ICmpInst::ICMP_SLE: + case ICmpInst::ICMP_EQ: + LV = BinaryOperator::createNot(LV, "notinit", CI); + break; + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_UGE: + case ICmpInst::ICMP_SGE: + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: + break; // no change. + } + CI->replaceAllUsesWith(LV); + CI->eraseFromParent(); + } + } + LI->eraseFromParent(); + } else { + StoreInst *SI = cast(GV->use_back()); + // The global is initialized when the store to it occurs. + new StoreInst(ConstantInt::getTrue(), InitBool, SI); + SI->eraseFromParent(); + } + + // If the initialization boolean was used, insert it, otherwise delete it. + if (!InitBoolUsed) { + while (!InitBool->use_empty()) // Delete initializations + cast(InitBool->use_back())->eraseFromParent(); + delete InitBool; + } else + GV->getParent()->getGlobalList().insert(GV, InitBool); + + + // Now the GV is dead, nuke it and the malloc. + GV->eraseFromParent(); + MI->eraseFromParent(); + + // To further other optimizations, loop over all users of NewGV and try to + // constant prop them. 
This will promote GEP instructions with constant + // indices into GEP constant-exprs, which will allow global-opt to hack on it. + ConstantPropUsersOf(NewGV); + if (RepValue != NewGV) + ConstantPropUsersOf(RepValue); + + return NewGV; +} + +/// ValueIsOnlyUsedLocallyOrStoredToOneGlobal - Scan the use-list of V checking +/// to make sure that there are no complex uses of V. We permit simple things +/// like dereferencing the pointer, but not storing through the address, unless +/// it is to the specified global. +static bool ValueIsOnlyUsedLocallyOrStoredToOneGlobal(Instruction *V, + GlobalVariable *GV) { + for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E;++UI) + if (isa(*UI) || isa(*UI)) { + // Fine, ignore. + } else if (StoreInst *SI = dyn_cast(*UI)) { + if (SI->getOperand(0) == V && SI->getOperand(1) != GV) + return false; // Storing the pointer itself... bad. + // Otherwise, storing through it, or storing into GV... fine. + } else if (isa(*UI) || isa(*UI)) { + if (!ValueIsOnlyUsedLocallyOrStoredToOneGlobal(cast(*UI),GV)) + return false; + } else { + return false; + } + return true; +} + +/// ReplaceUsesOfMallocWithGlobal - The Alloc pointer is stored into GV +/// somewhere. Transform all uses of the allocation into loads from the +/// global and uses of the resultant pointer. Further, delete the store into +/// GV. This assumes that these value pass the +/// 'ValueIsOnlyUsedLocallyOrStoredToOneGlobal' predicate. +static void ReplaceUsesOfMallocWithGlobal(Instruction *Alloc, + GlobalVariable *GV) { + while (!Alloc->use_empty()) { + Instruction *U = Alloc->use_back(); + if (StoreInst *SI = dyn_cast(U)) { + // If this is the store of the allocation into the global, remove it. + if (SI->getOperand(1) == GV) { + SI->eraseFromParent(); + continue; + } + } + + // Insert a load from the global, and use it instead of the malloc. 
+ Value *NL = new LoadInst(GV, GV->getName()+".val", U); + U->replaceUsesOfWith(Alloc, NL); + } +} + +/// GlobalLoadUsesSimpleEnoughForHeapSRA - If all users of values loaded from +/// GV are simple enough to perform HeapSRA, return true. +static bool GlobalLoadUsesSimpleEnoughForHeapSRA(GlobalVariable *GV) { + for (Value::use_iterator UI = GV->use_begin(), E = GV->use_end(); UI != E; + ++UI) + if (LoadInst *LI = dyn_cast(*UI)) { + // We permit two users of the load: setcc comparing against the null + // pointer, and a getelementptr of a specific form. + for (Value::use_iterator UI = LI->use_begin(), E = LI->use_end(); UI != E; + ++UI) { + // Comparison against null is ok. + if (ICmpInst *ICI = dyn_cast(*UI)) { + if (!isa(ICI->getOperand(1))) + return false; + continue; + } + + // getelementptr is also ok, but only a simple form. + GetElementPtrInst *GEPI = dyn_cast(*UI); + if (!GEPI) return false; + + // Must index into the array and into the struct. + if (GEPI->getNumOperands() < 3) + return false; + + // Otherwise the GEP is ok. + continue; + } + } + return true; +} + +/// RewriteUsesOfLoadForHeapSRoA - We are performing Heap SRoA on a global. Ptr +/// is a value loaded from the global. Eliminate all uses of Ptr, making them +/// use FieldGlobals instead. All uses of loaded values satisfy +/// GlobalLoadUsesSimpleEnoughForHeapSRA. +static void RewriteUsesOfLoadForHeapSRoA(LoadInst *Ptr, + const std::vector &FieldGlobals) { + std::vector InsertedLoadsForPtr; + //InsertedLoadsForPtr.resize(FieldGlobals.size()); + while (!Ptr->use_empty()) { + Instruction *User = Ptr->use_back(); + + // If this is a comparison against null, handle it. + if (ICmpInst *SCI = dyn_cast(User)) { + assert(isa(SCI->getOperand(1))); + // If we have a setcc of the loaded pointer, we can use a setcc of any + // field. 
+ Value *NPtr; + if (InsertedLoadsForPtr.empty()) { + NPtr = new LoadInst(FieldGlobals[0], Ptr->getName()+".f0", Ptr); + InsertedLoadsForPtr.push_back(Ptr); + } else { + NPtr = InsertedLoadsForPtr.back(); + } + + Value *New = new ICmpInst(SCI->getPredicate(), NPtr, + Constant::getNullValue(NPtr->getType()), + SCI->getName(), SCI); + SCI->replaceAllUsesWith(New); + SCI->eraseFromParent(); + continue; + } + + // Otherwise, this should be: 'getelementptr Ptr, Idx, uint FieldNo ...' + GetElementPtrInst *GEPI = cast(User); + assert(GEPI->getNumOperands() >= 3 && isa(GEPI->getOperand(2)) + && "Unexpected GEPI!"); + + // Load the pointer for this field. + unsigned FieldNo = cast(GEPI->getOperand(2))->getZExtValue(); + if (InsertedLoadsForPtr.size() <= FieldNo) + InsertedLoadsForPtr.resize(FieldNo+1); + if (InsertedLoadsForPtr[FieldNo] == 0) + InsertedLoadsForPtr[FieldNo] = new LoadInst(FieldGlobals[FieldNo], + Ptr->getName()+".f" + + utostr(FieldNo), Ptr); + Value *NewPtr = InsertedLoadsForPtr[FieldNo]; + + // Create the new GEP idx vector. + SmallVector GEPIdx; + GEPIdx.push_back(GEPI->getOperand(1)); + GEPIdx.append(GEPI->op_begin()+3, GEPI->op_end()); + + Value *NGEPI = new GetElementPtrInst(NewPtr, &GEPIdx[0], GEPIdx.size(), + GEPI->getName(), GEPI); + GEPI->replaceAllUsesWith(NGEPI); + GEPI->eraseFromParent(); + } +} + +/// PerformHeapAllocSRoA - MI is an allocation of an array of structures. Break +/// it up into multiple allocations of arrays of the fields. +static GlobalVariable *PerformHeapAllocSRoA(GlobalVariable *GV, MallocInst *MI){ + DOUT << "SROA HEAP ALLOC: " << *GV << " MALLOC = " << *MI; + const StructType *STy = cast(MI->getAllocatedType()); + + // There is guaranteed to be at least one use of the malloc (storing + // it into GV). If there are other uses, change them to be uses of + // the global to simplify later code. This also deletes the store + // into GV. 
+ ReplaceUsesOfMallocWithGlobal(MI, GV); + + // Okay, at this point, there are no users of the malloc. Insert N + // new mallocs at the same place as MI, and N globals. + std::vector FieldGlobals; + std::vector FieldMallocs; + + for (unsigned FieldNo = 0, e = STy->getNumElements(); FieldNo != e;++FieldNo){ + const Type *FieldTy = STy->getElementType(FieldNo); + const Type *PFieldTy = PointerType::get(FieldTy); + + GlobalVariable *NGV = + new GlobalVariable(PFieldTy, false, GlobalValue::InternalLinkage, + Constant::getNullValue(PFieldTy), + GV->getName() + ".f" + utostr(FieldNo), GV, + GV->isThreadLocal()); + FieldGlobals.push_back(NGV); + + MallocInst *NMI = new MallocInst(FieldTy, MI->getArraySize(), + MI->getName() + ".f" + utostr(FieldNo),MI); + FieldMallocs.push_back(NMI); + new StoreInst(NMI, NGV, MI); + } + + // The tricky aspect of this transformation is handling the case when malloc + // fails. In the original code, malloc failing would set the result pointer + // of malloc to null. In this case, some mallocs could succeed and others + // could fail. As such, we emit code that looks like this: + // F0 = malloc(field0) + // F1 = malloc(field1) + // F2 = malloc(field2) + // if (F0 == 0 || F1 == 0 || F2 == 0) { + // if (F0) { free(F0); F0 = 0; } + // if (F1) { free(F1); F1 = 0; } + // if (F2) { free(F2); F2 = 0; } + // } + Value *RunningOr = 0; + for (unsigned i = 0, e = FieldMallocs.size(); i != e; ++i) { + Value *Cond = new ICmpInst(ICmpInst::ICMP_EQ, FieldMallocs[i], + Constant::getNullValue(FieldMallocs[i]->getType()), + "isnull", MI); + if (!RunningOr) + RunningOr = Cond; // First seteq + else + RunningOr = BinaryOperator::createOr(RunningOr, Cond, "tmp", MI); + } + + // Split the basic block at the old malloc. + BasicBlock *OrigBB = MI->getParent(); + BasicBlock *ContBB = OrigBB->splitBasicBlock(MI, "malloc_cont"); + + // Create the block to check the first condition. 
Put all these blocks at the + // end of the function as they are unlikely to be executed. + BasicBlock *NullPtrBlock = new BasicBlock("malloc_ret_null", + OrigBB->getParent()); + + // Remove the uncond branch from OrigBB to ContBB, turning it into a cond + // branch on RunningOr. + OrigBB->getTerminator()->eraseFromParent(); + new BranchInst(NullPtrBlock, ContBB, RunningOr, OrigBB); + + // Within the NullPtrBlock, we need to emit a comparison and branch for each + // pointer, because some may be null while others are not. + for (unsigned i = 0, e = FieldGlobals.size(); i != e; ++i) { + Value *GVVal = new LoadInst(FieldGlobals[i], "tmp", NullPtrBlock); + Value *Cmp = new ICmpInst(ICmpInst::ICMP_NE, GVVal, + Constant::getNullValue(GVVal->getType()), + "tmp", NullPtrBlock); + BasicBlock *FreeBlock = new BasicBlock("free_it", OrigBB->getParent()); + BasicBlock *NextBlock = new BasicBlock("next", OrigBB->getParent()); + new BranchInst(FreeBlock, NextBlock, Cmp, NullPtrBlock); + + // Fill in FreeBlock. + new FreeInst(GVVal, FreeBlock); + new StoreInst(Constant::getNullValue(GVVal->getType()), FieldGlobals[i], + FreeBlock); + new BranchInst(NextBlock, FreeBlock); + + NullPtrBlock = NextBlock; + } + + new BranchInst(ContBB, NullPtrBlock); + + + // MI is no longer needed, remove it. + MI->eraseFromParent(); + + + // Okay, the malloc site is completely handled. All of the uses of GV are now + // loads, and all uses of those loads are simple. Rewrite them to use loads + // of the per-field globals instead. + while (!GV->use_empty()) { + if (LoadInst *LI = dyn_cast(GV->use_back())) { + RewriteUsesOfLoadForHeapSRoA(LI, FieldGlobals); + LI->eraseFromParent(); + } else { + // Must be a store of null. + StoreInst *SI = cast(GV->use_back()); + assert(isa(SI->getOperand(0)) && + cast(SI->getOperand(0))->isNullValue() && + "Unexpected heap-sra user!"); + + // Insert a store of null into each global. 
+ for (unsigned i = 0, e = FieldGlobals.size(); i != e; ++i) { + Constant *Null = + Constant::getNullValue(FieldGlobals[i]->getType()->getElementType()); + new StoreInst(Null, FieldGlobals[i], SI); + } + // Erase the original store. + SI->eraseFromParent(); + } + } + + // The old global is now dead, remove it. + GV->eraseFromParent(); + + ++NumHeapSRA; + return FieldGlobals[0]; +} + + +// OptimizeOnceStoredGlobal - Try to optimize globals based on the knowledge +// that only one value (besides its initializer) is ever stored to the global. +static bool OptimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal, + Module::global_iterator &GVI, + TargetData &TD) { + if (CastInst *CI = dyn_cast(StoredOnceVal)) + StoredOnceVal = CI->getOperand(0); + else if (GetElementPtrInst *GEPI =dyn_cast(StoredOnceVal)){ + // "getelementptr Ptr, 0, 0, 0" is really just a cast. + bool IsJustACast = true; + for (unsigned i = 1, e = GEPI->getNumOperands(); i != e; ++i) + if (!isa(GEPI->getOperand(i)) || + !cast(GEPI->getOperand(i))->isNullValue()) { + IsJustACast = false; + break; + } + if (IsJustACast) + StoredOnceVal = GEPI->getOperand(0); + } + + // If we are dealing with a pointer global that is initialized to null and + // only has one (non-null) value stored into it, then we can optimize any + // users of the loaded value (often calls and loads) that would trap if the + // value was null. + if (isa(GV->getInitializer()->getType()) && + GV->getInitializer()->isNullValue()) { + if (Constant *SOVC = dyn_cast(StoredOnceVal)) { + if (GV->getInitializer()->getType() != SOVC->getType()) + SOVC = ConstantExpr::getBitCast(SOVC, GV->getInitializer()->getType()); + + // Optimize away any trapping uses of the loaded value. + if (OptimizeAwayTrappingUsesOfLoads(GV, SOVC)) + return true; + } else if (MallocInst *MI = dyn_cast(StoredOnceVal)) { + // If this is a malloc of an abstract type, don't touch it. 
+ if (!MI->getAllocatedType()->isSized()) + return false; + + // We can't optimize this global unless all uses of it are *known* to be + // of the malloc value, not of the null initializer value (consider a use + // that compares the global's value against zero to see if the malloc has + // been reached). To do this, we check to see if all uses of the global + // would trap if the global were null: this proves that they must all + // happen after the malloc. + if (!AllUsesOfLoadedValueWillTrapIfNull(GV)) + return false; + + // We can't optimize this if the malloc itself is used in a complex way, + // for example, being stored into multiple globals. This allows the + // malloc to be stored into the specified global, loaded setcc'd, and + // GEP'd. These are all things we could transform to using the global + // for. + if (!ValueIsOnlyUsedLocallyOrStoredToOneGlobal(MI, GV)) + return false; + + + // If we have a global that is only initialized with a fixed size malloc, + // transform the program to use global memory instead of malloc'd memory. + // This eliminates dynamic allocation, avoids an indirection accessing the + // data, and exposes the resultant global to further GlobalOpt. + if (ConstantInt *NElements = dyn_cast(MI->getArraySize())) { + // Restrict this transformation to only working on small allocations + // (2048 bytes currently), as we don't want to introduce a 16M global or + // something. + if (NElements->getZExtValue()* + TD.getTypeSize(MI->getAllocatedType()) < 2048) { + GVI = OptimizeGlobalAddressOfMalloc(GV, MI); + return true; + } + } + + // If the allocation is an array of structures, consider transforming this + // into multiple malloc'd arrays, one for each field. This is basically + // SRoA for malloc'd memory. + if (const StructType *AllocTy = + dyn_cast(MI->getAllocatedType())) { + // This the structure has an unreasonable number of fields, leave it + // alone. 
+ if (AllocTy->getNumElements() <= 16 && AllocTy->getNumElements() > 0 && + GlobalLoadUsesSimpleEnoughForHeapSRA(GV)) { + GVI = PerformHeapAllocSRoA(GV, MI); + return true; + } + } + } + } + + return false; +} + +/// ShrinkGlobalToBoolean - At this point, we have learned that the only two +/// values ever stored into GV are its initializer and OtherVal. +static void ShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { + // Create the new global, initializing it to false. + GlobalVariable *NewGV = new GlobalVariable(Type::Int1Ty, false, + GlobalValue::InternalLinkage, ConstantInt::getFalse(), + GV->getName()+".b", + (Module *)NULL, + GV->isThreadLocal()); + GV->getParent()->getGlobalList().insert(GV, NewGV); + + Constant *InitVal = GV->getInitializer(); + assert(InitVal->getType() != Type::Int1Ty && "No reason to shrink to bool!"); + + // If initialized to zero and storing one into the global, we can use a cast + // instead of a select to synthesize the desired value. + bool IsOneZero = false; + if (ConstantInt *CI = dyn_cast(OtherVal)) + IsOneZero = InitVal->isNullValue() && CI->isOne(); + + while (!GV->use_empty()) { + Instruction *UI = cast(GV->use_back()); + if (StoreInst *SI = dyn_cast(UI)) { + // Change the store into a boolean store. + bool StoringOther = SI->getOperand(0) == OtherVal; + // Only do this if we weren't storing a loaded value. + Value *StoreVal; + if (StoringOther || SI->getOperand(0) == InitVal) + StoreVal = ConstantInt::get(Type::Int1Ty, StoringOther); + else { + // Otherwise, we are storing a previously loaded copy. To do this, + // change the copy from copying the original value to just copying the + // bool. + Instruction *StoredVal = cast(SI->getOperand(0)); + + // If we're already replaced the input, StoredVal will be a cast or + // select instruction. If not, it will be a load of the original + // global. 
+ if (LoadInst *LI = dyn_cast(StoredVal)) { + assert(LI->getOperand(0) == GV && "Not a copy!"); + // Insert a new load, to preserve the saved value. + StoreVal = new LoadInst(NewGV, LI->getName()+".b", LI); + } else { + assert((isa(StoredVal) || isa(StoredVal)) && + "This is not a form that we understand!"); + StoreVal = StoredVal->getOperand(0); + assert(isa(StoreVal) && "Not a load of NewGV!"); + } + } + new StoreInst(StoreVal, NewGV, SI); + } else if (!UI->use_empty()) { + // Change the load into a load of bool then a select. + LoadInst *LI = cast(UI); + LoadInst *NLI = new LoadInst(NewGV, LI->getName()+".b", LI); + Value *NSI; + if (IsOneZero) + NSI = new ZExtInst(NLI, LI->getType(), "", LI); + else + NSI = new SelectInst(NLI, OtherVal, InitVal, "", LI); + NSI->takeName(LI); + LI->replaceAllUsesWith(NSI); + } + UI->eraseFromParent(); + } + + GV->eraseFromParent(); +} + + +/// ProcessInternalGlobal - Analyze the specified global variable and optimize +/// it if possible. If we make a change, return true. 
+bool GlobalOpt::ProcessInternalGlobal(GlobalVariable *GV, + Module::global_iterator &GVI) { + std::set PHIUsers; + GlobalStatus GS; + GV->removeDeadConstantUsers(); + + if (GV->use_empty()) { + DOUT << "GLOBAL DEAD: " << *GV; + GV->eraseFromParent(); + ++NumDeleted; + return true; + } + + if (!AnalyzeGlobal(GV, GS, PHIUsers)) { +#if 0 + cerr << "Global: " << *GV; + cerr << " isLoaded = " << GS.isLoaded << "\n"; + cerr << " StoredType = "; + switch (GS.StoredType) { + case GlobalStatus::NotStored: cerr << "NEVER STORED\n"; break; + case GlobalStatus::isInitializerStored: cerr << "INIT STORED\n"; break; + case GlobalStatus::isStoredOnce: cerr << "STORED ONCE\n"; break; + case GlobalStatus::isStored: cerr << "stored\n"; break; + } + if (GS.StoredType == GlobalStatus::isStoredOnce && GS.StoredOnceValue) + cerr << " StoredOnceValue = " << *GS.StoredOnceValue << "\n"; + if (GS.AccessingFunction && !GS.HasMultipleAccessingFunctions) + cerr << " AccessingFunction = " << GS.AccessingFunction->getName() + << "\n"; + cerr << " HasMultipleAccessingFunctions = " + << GS.HasMultipleAccessingFunctions << "\n"; + cerr << " HasNonInstructionUser = " << GS.HasNonInstructionUser<<"\n"; + cerr << " isNotSuitableForSRA = " << GS.isNotSuitableForSRA << "\n"; + cerr << "\n"; +#endif + + // If this is a first class global and has only one accessing function + // and this function is main (which we know is not recursive we can make + // this global a local variable) we replace the global with a local alloca + // in this function. + // + // NOTE: It doesn't make sense to promote non first class types since we + // are just replacing static memory to stack memory. 
+ if (!GS.HasMultipleAccessingFunctions && + GS.AccessingFunction && !GS.HasNonInstructionUser && + GV->getType()->getElementType()->isFirstClassType() && + GS.AccessingFunction->getName() == "main" && + GS.AccessingFunction->hasExternalLinkage()) { + DOUT << "LOCALIZING GLOBAL: " << *GV; + Instruction* FirstI = GS.AccessingFunction->getEntryBlock().begin(); + const Type* ElemTy = GV->getType()->getElementType(); + // FIXME: Pass Global's alignment when globals have alignment + AllocaInst* Alloca = new AllocaInst(ElemTy, NULL, GV->getName(), FirstI); + if (!isa(GV->getInitializer())) + new StoreInst(GV->getInitializer(), Alloca, FirstI); + + GV->replaceAllUsesWith(Alloca); + GV->eraseFromParent(); + ++NumLocalized; + return true; + } + + // If the global is never loaded (but may be stored to), it is dead. + // Delete it now. + if (!GS.isLoaded) { + DOUT << "GLOBAL NEVER LOADED: " << *GV; + + // Delete any stores we can find to the global. We may not be able to + // make it completely dead though. + bool Changed = CleanupConstantGlobalUsers(GV, GV->getInitializer()); + + // If the global is dead now, delete it. + if (GV->use_empty()) { + GV->eraseFromParent(); + ++NumDeleted; + Changed = true; + } + return Changed; + + } else if (GS.StoredType <= GlobalStatus::isInitializerStored) { + DOUT << "MARKING CONSTANT: " << *GV; + GV->setConstant(true); + + // Clean up any obviously simplifiable users now. + CleanupConstantGlobalUsers(GV, GV->getInitializer()); + + // If the global is dead now, just nuke it. + if (GV->use_empty()) { + DOUT << " *** Marking constant allowed us to simplify " + << "all users and delete global!\n"; + GV->eraseFromParent(); + ++NumDeleted; + } + + ++NumMarked; + return true; + } else if (!GS.isNotSuitableForSRA && + !GV->getInitializer()->getType()->isFirstClassType()) { + if (GlobalVariable *FirstNewGV = SRAGlobal(GV)) { + GVI = FirstNewGV; // Don't skip the newly produced globals! 
+ return true; + } + } else if (GS.StoredType == GlobalStatus::isStoredOnce) { + // If the initial value for the global was an undef value, and if only + // one other value was stored into it, we can just change the + // initializer to be an undef value, then delete all stores to the + // global. This allows us to mark it constant. + if (Constant *SOVConstant = dyn_cast(GS.StoredOnceValue)) + if (isa(GV->getInitializer())) { + // Change the initial value here. + GV->setInitializer(SOVConstant); + + // Clean up any obviously simplifiable users now. + CleanupConstantGlobalUsers(GV, GV->getInitializer()); + + if (GV->use_empty()) { + DOUT << " *** Substituting initializer allowed us to " + << "simplify all users and delete global!\n"; + GV->eraseFromParent(); + ++NumDeleted; + } else { + GVI = GV; + } + ++NumSubstitute; + return true; + } + + // Try to optimize globals based on the knowledge that only one value + // (besides its initializer) is ever stored to the global. + if (OptimizeOnceStoredGlobal(GV, GS.StoredOnceValue, GVI, + getAnalysis())) + return true; + + // Otherwise, if the global was not a boolean, we can shrink it to be a + // boolean. + if (Constant *SOVConstant = dyn_cast(GS.StoredOnceValue)) + if (GV->getType()->getElementType() != Type::Int1Ty && + !GV->getType()->getElementType()->isFloatingPoint() && + !isa(GV->getType()->getElementType()) && + !GS.HasPHIUser && !GS.isNotSuitableForSRA) { + DOUT << " *** SHRINKING TO BOOL: " << *GV; + ShrinkGlobalToBoolean(GV, SOVConstant); + ++NumShrunkToBool; + return true; + } + } + } + return false; +} + +/// OnlyCalledDirectly - Return true if the specified function is only called +/// directly. In other words, its address is never taken. 
+static bool OnlyCalledDirectly(Function *F) { + for (Value::use_iterator UI = F->use_begin(), E = F->use_end(); UI != E;++UI){ + Instruction *User = dyn_cast(*UI); + if (!User) return false; + if (!isa(User) && !isa(User)) return false; + + // See if the function address is passed as an argument. + for (unsigned i = 1, e = User->getNumOperands(); i != e; ++i) + if (User->getOperand(i) == F) return false; + } + return true; +} + +/// ChangeCalleesToFastCall - Walk all of the direct calls of the specified +/// function, changing them to FastCC. +static void ChangeCalleesToFastCall(Function *F) { + for (Value::use_iterator UI = F->use_begin(), E = F->use_end(); UI != E;++UI){ + Instruction *User = cast(*UI); + if (CallInst *CI = dyn_cast(User)) + CI->setCallingConv(CallingConv::Fast); + else + cast(User)->setCallingConv(CallingConv::Fast); + } +} + +bool GlobalOpt::OptimizeFunctions(Module &M) { + bool Changed = false; + // Optimize functions. + for (Module::iterator FI = M.begin(), E = M.end(); FI != E; ) { + Function *F = FI++; + F->removeDeadConstantUsers(); + if (F->use_empty() && (F->hasInternalLinkage() || + F->hasLinkOnceLinkage())) { + M.getFunctionList().erase(F); + Changed = true; + ++NumFnDeleted; + } else if (F->hasInternalLinkage() && + F->getCallingConv() == CallingConv::C && !F->isVarArg() && + OnlyCalledDirectly(F)) { + // If this function has C calling conventions, is not a varargs + // function, and is only called directly, promote it to use the Fast + // calling convention. 
+ F->setCallingConv(CallingConv::Fast); + ChangeCalleesToFastCall(F); + ++NumFastCallFns; + Changed = true; + } + } + return Changed; +} + +bool GlobalOpt::OptimizeGlobalVars(Module &M) { + bool Changed = false; + for (Module::global_iterator GVI = M.global_begin(), E = M.global_end(); + GVI != E; ) { + GlobalVariable *GV = GVI++; + if (!GV->isConstant() && GV->hasInternalLinkage() && + GV->hasInitializer()) + Changed |= ProcessInternalGlobal(GV, GVI); + } + return Changed; +} + +/// FindGlobalCtors - Find the llvm.globalctors list, verifying that all +/// initializers have an init priority of 65535. +GlobalVariable *GlobalOpt::FindGlobalCtors(Module &M) { + for (Module::global_iterator I = M.global_begin(), E = M.global_end(); + I != E; ++I) + if (I->getName() == "llvm.global_ctors") { + // Found it, verify it's an array of { int, void()* }. + const ArrayType *ATy =dyn_cast(I->getType()->getElementType()); + if (!ATy) return 0; + const StructType *STy = dyn_cast(ATy->getElementType()); + if (!STy || STy->getNumElements() != 2 || + STy->getElementType(0) != Type::Int32Ty) return 0; + const PointerType *PFTy = dyn_cast(STy->getElementType(1)); + if (!PFTy) return 0; + const FunctionType *FTy = dyn_cast(PFTy->getElementType()); + if (!FTy || FTy->getReturnType() != Type::VoidTy || FTy->isVarArg() || + FTy->getNumParams() != 0) + return 0; + + // Verify that the initializer is simple enough for us to handle. + if (!I->hasInitializer()) return 0; + ConstantArray *CA = dyn_cast(I->getInitializer()); + if (!CA) return 0; + for (unsigned i = 0, e = CA->getNumOperands(); i != e; ++i) + if (ConstantStruct *CS = dyn_cast(CA->getOperand(i))) { + if (isa(CS->getOperand(1))) + continue; + + // Must have a function or null ptr. + if (!isa(CS->getOperand(1))) + return 0; + + // Init priority must be standard. 
+ ConstantInt *CI = dyn_cast(CS->getOperand(0)); + if (!CI || CI->getZExtValue() != 65535) + return 0; + } else { + return 0; + } + + return I; + } + return 0; +} + +/// ParseGlobalCtors - Given a llvm.global_ctors list that we can understand, +/// return a list of the functions and null terminator as a vector. +static std::vector ParseGlobalCtors(GlobalVariable *GV) { + ConstantArray *CA = cast(GV->getInitializer()); + std::vector Result; + Result.reserve(CA->getNumOperands()); + for (unsigned i = 0, e = CA->getNumOperands(); i != e; ++i) { + ConstantStruct *CS = cast(CA->getOperand(i)); + Result.push_back(dyn_cast(CS->getOperand(1))); + } + return Result; +} + +/// InstallGlobalCtors - Given a specified llvm.global_ctors list, install the +/// specified array, returning the new global to use. +static GlobalVariable *InstallGlobalCtors(GlobalVariable *GCL, + const std::vector &Ctors) { + // If we made a change, reassemble the initializer list. + std::vector CSVals; + CSVals.push_back(ConstantInt::get(Type::Int32Ty, 65535)); + CSVals.push_back(0); + + // Create the new init list. + std::vector CAList; + for (unsigned i = 0, e = Ctors.size(); i != e; ++i) { + if (Ctors[i]) { + CSVals[1] = Ctors[i]; + } else { + const Type *FTy = FunctionType::get(Type::VoidTy, + std::vector(), false); + const PointerType *PFTy = PointerType::get(FTy); + CSVals[1] = Constant::getNullValue(PFTy); + CSVals[0] = ConstantInt::get(Type::Int32Ty, 2147483647); + } + CAList.push_back(ConstantStruct::get(CSVals)); + } + + // Create the array initializer. + const Type *StructTy = + cast(GCL->getType()->getElementType())->getElementType(); + Constant *CA = ConstantArray::get(ArrayType::get(StructTy, CAList.size()), + CAList); + + // If we didn't change the number of elements, don't create a new GV. + if (CA->getType() == GCL->getInitializer()->getType()) { + GCL->setInitializer(CA); + return GCL; + } + + // Create the new global and insert it next to the existing list. 
+ GlobalVariable *NGV = new GlobalVariable(CA->getType(), GCL->isConstant(), + GCL->getLinkage(), CA, "", + (Module *)NULL, + GCL->isThreadLocal()); + GCL->getParent()->getGlobalList().insert(GCL, NGV); + NGV->takeName(GCL); + + // Nuke the old list, replacing any uses with the new one. + if (!GCL->use_empty()) { + Constant *V = NGV; + if (V->getType() != GCL->getType()) + V = ConstantExpr::getBitCast(V, GCL->getType()); + GCL->replaceAllUsesWith(V); + } + GCL->eraseFromParent(); + + if (Ctors.size()) + return NGV; + else + return 0; +} + + +static Constant *getVal(std::map &ComputedValues, + Value *V) { + if (Constant *CV = dyn_cast(V)) return CV; + Constant *R = ComputedValues[V]; + assert(R && "Reference to an uncomputed value!"); + return R; +} + +/// isSimpleEnoughPointerToCommit - Return true if this constant is simple +/// enough for us to understand. In particular, if it is a cast of something, +/// we punt. We basically just support direct accesses to globals and GEP's of +/// globals. This should be kept up to date with CommitValueTo. +static bool isSimpleEnoughPointerToCommit(Constant *C) { + if (GlobalVariable *GV = dyn_cast(C)) { + if (!GV->hasExternalLinkage() && !GV->hasInternalLinkage()) + return false; // do not allow weak/linkonce/dllimport/dllexport linkage. + return !GV->isDeclaration(); // reject external globals. + } + if (ConstantExpr *CE = dyn_cast(C)) + // Handle a constantexpr gep. + if (CE->getOpcode() == Instruction::GetElementPtr && + isa(CE->getOperand(0))) { + GlobalVariable *GV = cast(CE->getOperand(0)); + if (!GV->hasExternalLinkage() && !GV->hasInternalLinkage()) + return false; // do not allow weak/linkonce/dllimport/dllexport linkage. + return GV->hasInitializer() && + ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE); + } + return false; +} + +/// EvaluateStoreInto - Evaluate a piece of a constantexpr store into a global +/// initializer. This returns 'Init' modified to reflect 'Val' stored into it. 
+/// At this point, the GEP operands of Addr [0, OpNo) have been stepped into. +static Constant *EvaluateStoreInto(Constant *Init, Constant *Val, + ConstantExpr *Addr, unsigned OpNo) { + // Base case of the recursion. + if (OpNo == Addr->getNumOperands()) { + assert(Val->getType() == Init->getType() && "Type mismatch!"); + return Val; + } + + if (const StructType *STy = dyn_cast(Init->getType())) { + std::vector Elts; + + // Break up the constant into its elements. + if (ConstantStruct *CS = dyn_cast(Init)) { + for (unsigned i = 0, e = CS->getNumOperands(); i != e; ++i) + Elts.push_back(CS->getOperand(i)); + } else if (isa(Init)) { + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) + Elts.push_back(Constant::getNullValue(STy->getElementType(i))); + } else if (isa(Init)) { + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) + Elts.push_back(UndefValue::get(STy->getElementType(i))); + } else { + assert(0 && "This code is out of sync with " + " ConstantFoldLoadThroughGEPConstantExpr"); + } + + // Replace the element that we are supposed to. + ConstantInt *CU = cast(Addr->getOperand(OpNo)); + unsigned Idx = CU->getZExtValue(); + assert(Idx < STy->getNumElements() && "Struct index out of range!"); + Elts[Idx] = EvaluateStoreInto(Elts[Idx], Val, Addr, OpNo+1); + + // Return the modified struct. + return ConstantStruct::get(&Elts[0], Elts.size(), STy->isPacked()); + } else { + ConstantInt *CI = cast(Addr->getOperand(OpNo)); + const ArrayType *ATy = cast(Init->getType()); + + // Break up the array into elements. 
+ std::vector Elts; + if (ConstantArray *CA = dyn_cast(Init)) { + for (unsigned i = 0, e = CA->getNumOperands(); i != e; ++i) + Elts.push_back(CA->getOperand(i)); + } else if (isa(Init)) { + Constant *Elt = Constant::getNullValue(ATy->getElementType()); + Elts.assign(ATy->getNumElements(), Elt); + } else if (isa(Init)) { + Constant *Elt = UndefValue::get(ATy->getElementType()); + Elts.assign(ATy->getNumElements(), Elt); + } else { + assert(0 && "This code is out of sync with " + " ConstantFoldLoadThroughGEPConstantExpr"); + } + + assert(CI->getZExtValue() < ATy->getNumElements()); + Elts[CI->getZExtValue()] = + EvaluateStoreInto(Elts[CI->getZExtValue()], Val, Addr, OpNo+1); + return ConstantArray::get(ATy, Elts); + } +} + +/// CommitValueTo - We have decided that Addr (which satisfies the predicate +/// isSimpleEnoughPointerToCommit) should get Val as its value. Make it happen. +static void CommitValueTo(Constant *Val, Constant *Addr) { + if (GlobalVariable *GV = dyn_cast(Addr)) { + assert(GV->hasInitializer()); + GV->setInitializer(Val); + return; + } + + ConstantExpr *CE = cast(Addr); + GlobalVariable *GV = cast(CE->getOperand(0)); + + Constant *Init = GV->getInitializer(); + Init = EvaluateStoreInto(Init, Val, CE, 2); + GV->setInitializer(Init); +} + +/// ComputeLoadResult - Return the value that would be computed by a load from +/// P after the stores reflected by 'memory' have been performed. If we can't +/// decide, return null. +static Constant *ComputeLoadResult(Constant *P, + const std::map &Memory) { + // If this memory location has been recently stored, use the stored value: it + // is the most up-to-date. + std::map::const_iterator I = Memory.find(P); + if (I != Memory.end()) return I->second; + + // Access it. + if (GlobalVariable *GV = dyn_cast(P)) { + if (GV->hasInitializer()) + return GV->getInitializer(); + return 0; + } + + // Handle a constantexpr getelementptr. 
+ if (ConstantExpr *CE = dyn_cast(P)) + if (CE->getOpcode() == Instruction::GetElementPtr && + isa(CE->getOperand(0))) { + GlobalVariable *GV = cast(CE->getOperand(0)); + if (GV->hasInitializer()) + return ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE); + } + + return 0; // don't know how to evaluate. +} + +/// EvaluateFunction - Evaluate a call to function F, returning true if +/// successful, false if we can't evaluate it. ActualArgs contains the formal +/// arguments for the function. +static bool EvaluateFunction(Function *F, Constant *&RetVal, + const std::vector &ActualArgs, + std::vector &CallStack, + std::map &MutatedMemory, + std::vector &AllocaTmps) { + // Check to see if this function is already executing (recursion). If so, + // bail out. TODO: we might want to accept limited recursion. + if (std::find(CallStack.begin(), CallStack.end(), F) != CallStack.end()) + return false; + + CallStack.push_back(F); + + /// Values - As we compute SSA register values, we store their contents here. + std::map Values; + + // Initialize arguments to the incoming values specified. + unsigned ArgNo = 0; + for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); AI != E; + ++AI, ++ArgNo) + Values[AI] = ActualArgs[ArgNo]; + + /// ExecutedBlocks - We only handle non-looping, non-recursive code. As such, + /// we can only evaluate any one basic block at most once. This set keeps + /// track of what we have executed so we can detect recursive cases etc. + std::set ExecutedBlocks; + + // CurInst - The current instruction we're evaluating. + BasicBlock::iterator CurInst = F->begin()->begin(); + + // This is the main evaluation loop. + while (1) { + Constant *InstResult = 0; + + if (StoreInst *SI = dyn_cast(CurInst)) { + if (SI->isVolatile()) return false; // no volatile accesses. + Constant *Ptr = getVal(Values, SI->getOperand(1)); + if (!isSimpleEnoughPointerToCommit(Ptr)) + // If this is too complex for us to commit, reject it. 
+ return false; + Constant *Val = getVal(Values, SI->getOperand(0)); + MutatedMemory[Ptr] = Val; + } else if (BinaryOperator *BO = dyn_cast(CurInst)) { + InstResult = ConstantExpr::get(BO->getOpcode(), + getVal(Values, BO->getOperand(0)), + getVal(Values, BO->getOperand(1))); + } else if (CmpInst *CI = dyn_cast(CurInst)) { + InstResult = ConstantExpr::getCompare(CI->getPredicate(), + getVal(Values, CI->getOperand(0)), + getVal(Values, CI->getOperand(1))); + } else if (CastInst *CI = dyn_cast(CurInst)) { + InstResult = ConstantExpr::getCast(CI->getOpcode(), + getVal(Values, CI->getOperand(0)), + CI->getType()); + } else if (SelectInst *SI = dyn_cast(CurInst)) { + InstResult = ConstantExpr::getSelect(getVal(Values, SI->getOperand(0)), + getVal(Values, SI->getOperand(1)), + getVal(Values, SI->getOperand(2))); + } else if (GetElementPtrInst *GEP = dyn_cast(CurInst)) { + Constant *P = getVal(Values, GEP->getOperand(0)); + SmallVector GEPOps; + for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i) + GEPOps.push_back(getVal(Values, GEP->getOperand(i))); + InstResult = ConstantExpr::getGetElementPtr(P, &GEPOps[0], GEPOps.size()); + } else if (LoadInst *LI = dyn_cast(CurInst)) { + if (LI->isVolatile()) return false; // no volatile accesses. + InstResult = ComputeLoadResult(getVal(Values, LI->getOperand(0)), + MutatedMemory); + if (InstResult == 0) return false; // Could not evaluate load. + } else if (AllocaInst *AI = dyn_cast(CurInst)) { + if (AI->isArrayAllocation()) return false; // Cannot handle array allocs. + const Type *Ty = AI->getType()->getElementType(); + AllocaTmps.push_back(new GlobalVariable(Ty, false, + GlobalValue::InternalLinkage, + UndefValue::get(Ty), + AI->getName())); + InstResult = AllocaTmps.back(); + } else if (CallInst *CI = dyn_cast(CurInst)) { + // Cannot handle inline asm. + if (isa(CI->getOperand(0))) return false; + + // Resolve function pointers. 
+ Function *Callee = dyn_cast(getVal(Values, CI->getOperand(0))); + if (!Callee) return false; // Cannot resolve. + + std::vector Formals; + for (unsigned i = 1, e = CI->getNumOperands(); i != e; ++i) + Formals.push_back(getVal(Values, CI->getOperand(i))); + + if (Callee->isDeclaration()) { + // If this is a function we can constant fold, do it. + if (Constant *C = ConstantFoldCall(Callee, &Formals[0], + Formals.size())) { + InstResult = C; + } else { + return false; + } + } else { + if (Callee->getFunctionType()->isVarArg()) + return false; + + Constant *RetVal; + + // Execute the call, if successful, use the return value. + if (!EvaluateFunction(Callee, RetVal, Formals, CallStack, + MutatedMemory, AllocaTmps)) + return false; + InstResult = RetVal; + } + } else if (isa(CurInst)) { + BasicBlock *NewBB = 0; + if (BranchInst *BI = dyn_cast(CurInst)) { + if (BI->isUnconditional()) { + NewBB = BI->getSuccessor(0); + } else { + ConstantInt *Cond = + dyn_cast(getVal(Values, BI->getCondition())); + if (!Cond) return false; // Cannot determine. + + NewBB = BI->getSuccessor(!Cond->getZExtValue()); + } + } else if (SwitchInst *SI = dyn_cast(CurInst)) { + ConstantInt *Val = + dyn_cast(getVal(Values, SI->getCondition())); + if (!Val) return false; // Cannot determine. + NewBB = SI->getSuccessor(SI->findCaseValue(Val)); + } else if (ReturnInst *RI = dyn_cast(CurInst)) { + if (RI->getNumOperands()) + RetVal = getVal(Values, RI->getOperand(0)); + + CallStack.pop_back(); // return from fn. + return true; // We succeeded at evaluating this ctor! + } else { + // invoke, unwind, unreachable. + return false; // Cannot handle this terminator. + } + + // Okay, we succeeded in evaluating this control flow. See if we have + // executed the new block before. If so, we have a looping function, + // which we cannot evaluate in reasonable time. + if (!ExecutedBlocks.insert(NewBB).second) + return false; // looped! + + // Okay, we have never been in this block before. 
Check to see if there + // are any PHI nodes. If so, evaluate them with information about where + // we came from. + BasicBlock *OldBB = CurInst->getParent(); + CurInst = NewBB->begin(); + PHINode *PN; + for (; (PN = dyn_cast(CurInst)); ++CurInst) + Values[PN] = getVal(Values, PN->getIncomingValueForBlock(OldBB)); + + // Do NOT increment CurInst. We know that the terminator had no value. + continue; + } else { + // Did not know how to evaluate this! + return false; + } + + if (!CurInst->use_empty()) + Values[CurInst] = InstResult; + + // Advance program counter. + ++CurInst; + } +} + +/// EvaluateStaticConstructor - Evaluate static constructors in the function, if +/// we can. Return true if we can, false otherwise. +static bool EvaluateStaticConstructor(Function *F) { + /// MutatedMemory - For each store we execute, we update this map. Loads + /// check this to get the most up-to-date value. If evaluation is successful, + /// this state is committed to the process. + std::map MutatedMemory; + + /// AllocaTmps - To 'execute' an alloca, we create a temporary global variable + /// to represent its body. This vector is needed so we can delete the + /// temporary globals when we are done. + std::vector AllocaTmps; + + /// CallStack - This is used to detect recursion. In pathological situations + /// we could hit exponential behavior, but at least there is nothing + /// unbounded. + std::vector CallStack; + + // Call the function. + Constant *RetValDummy; + bool EvalSuccess = EvaluateFunction(F, RetValDummy, std::vector(), + CallStack, MutatedMemory, AllocaTmps); + if (EvalSuccess) { + // We succeeded at evaluation: commit the result. + DOUT << "FULLY EVALUATED GLOBAL CTOR FUNCTION '" + << F->getName() << "' to " << MutatedMemory.size() + << " stores.\n"; + for (std::map::iterator I = MutatedMemory.begin(), + E = MutatedMemory.end(); I != E; ++I) + CommitValueTo(I->second, I->first); + } + + // At this point, we are done interpreting. 
If we created any 'alloca' + // temporaries, release them now. + while (!AllocaTmps.empty()) { + GlobalVariable *Tmp = AllocaTmps.back(); + AllocaTmps.pop_back(); + + // If there are still users of the alloca, the program is doing something + // silly, e.g. storing the address of the alloca somewhere and using it + // later. Since this is undefined, we'll just make it be null. + if (!Tmp->use_empty()) + Tmp->replaceAllUsesWith(Constant::getNullValue(Tmp->getType())); + delete Tmp; + } + + return EvalSuccess; +} + + + +/// OptimizeGlobalCtorsList - Simplify and evaluation global ctors if possible. +/// Return true if anything changed. +bool GlobalOpt::OptimizeGlobalCtorsList(GlobalVariable *&GCL) { + std::vector Ctors = ParseGlobalCtors(GCL); + bool MadeChange = false; + if (Ctors.empty()) return false; + + // Loop over global ctors, optimizing them when we can. + for (unsigned i = 0; i != Ctors.size(); ++i) { + Function *F = Ctors[i]; + // Found a null terminator in the middle of the list, prune off the rest of + // the list. + if (F == 0) { + if (i != Ctors.size()-1) { + Ctors.resize(i+1); + MadeChange = true; + } + break; + } + + // We cannot simplify external ctor functions. + if (F->empty()) continue; + + // If we can evaluate the ctor at compile time, do. + if (EvaluateStaticConstructor(F)) { + Ctors.erase(Ctors.begin()+i); + MadeChange = true; + --i; + ++NumCtorsEvaluated; + continue; + } + } + + if (!MadeChange) return false; + + GCL = InstallGlobalCtors(GCL, Ctors); + return true; +} + + +bool GlobalOpt::runOnModule(Module &M) { + bool Changed = false; + + // Try to find the llvm.globalctors list. + GlobalVariable *GlobalCtors = FindGlobalCtors(M); + + bool LocalChange = true; + while (LocalChange) { + LocalChange = false; + + // Delete functions that are trivially dead, ccc -> fastcc + LocalChange |= OptimizeFunctions(M); + + // Optimize global_ctors list. 
+ if (GlobalCtors) + LocalChange |= OptimizeGlobalCtorsList(GlobalCtors); + + // Optimize non-address-taken globals. + LocalChange |= OptimizeGlobalVars(M); + Changed |= LocalChange; + } + + // TODO: Move all global ctors functions to the end of the module for code + // layout. + + return Changed; +} diff --git a/lib/Transforms/IPO/IPConstantPropagation.cpp b/lib/Transforms/IPO/IPConstantPropagation.cpp new file mode 100644 index 0000000..b55e538 --- /dev/null +++ b/lib/Transforms/IPO/IPConstantPropagation.cpp @@ -0,0 +1,197 @@ +//===-- IPConstantPropagation.cpp - Propagate constants through calls -----===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass implements an _extremely_ simple interprocedural constant +// propagation pass. It could certainly be improved in many different ways, +// like using a worklist. This pass makes arguments dead, but does not remove +// them. The existing dead argument elimination pass should be run after this +// to clean up the mess. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "ipconstprop" +#include "llvm/Transforms/IPO.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/CallSite.h" +#include "llvm/Support/Compiler.h" +#include "llvm/ADT/Statistic.h" +using namespace llvm; + +STATISTIC(NumArgumentsProped, "Number of args turned into constants"); +STATISTIC(NumReturnValProped, "Number of return values turned into constants"); + +namespace { + /// IPCP - The interprocedural constant propagation pass + /// + struct VISIBILITY_HIDDEN IPCP : public ModulePass { + static char ID; // Pass identification, replacement for typeid + IPCP() : ModulePass((intptr_t)&ID) {} + + bool runOnModule(Module &M); + private: + bool PropagateConstantsIntoArguments(Function &F); + bool PropagateConstantReturn(Function &F); + }; + char IPCP::ID = 0; + RegisterPass X("ipconstprop", "Interprocedural constant propagation"); +} + +ModulePass *llvm::createIPConstantPropagationPass() { return new IPCP(); } + +bool IPCP::runOnModule(Module &M) { + bool Changed = false; + bool LocalChange = true; + + // FIXME: instead of using smart algorithms, we just iterate until we stop + // making changes. + while (LocalChange) { + LocalChange = false; + for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) + if (!I->isDeclaration()) { + // Delete any klingons. + I->removeDeadConstantUsers(); + if (I->hasInternalLinkage()) + LocalChange |= PropagateConstantsIntoArguments(*I); + Changed |= PropagateConstantReturn(*I); + } + Changed |= LocalChange; + } + return Changed; +} + +/// PropagateConstantsIntoArguments - Look at all uses of the specified +/// function. If all uses are direct call sites, and all pass a particular +/// constant in for an argument, propagate that constant in as the argument. 
+/// +bool IPCP::PropagateConstantsIntoArguments(Function &F) { + if (F.arg_empty() || F.use_empty()) return false; // No arguments? Early exit. + + std::vector > ArgumentConstants; + ArgumentConstants.resize(F.arg_size()); + + unsigned NumNonconstant = 0; + + for (Value::use_iterator I = F.use_begin(), E = F.use_end(); I != E; ++I) + if (!isa(*I)) + return false; // Used by a non-instruction, do not transform + else { + CallSite CS = CallSite::get(cast(*I)); + if (CS.getInstruction() == 0 || + CS.getCalledFunction() != &F) + return false; // Not a direct call site? + + // Check out all of the potentially constant arguments + CallSite::arg_iterator AI = CS.arg_begin(); + Function::arg_iterator Arg = F.arg_begin(); + for (unsigned i = 0, e = ArgumentConstants.size(); i != e; + ++i, ++AI, ++Arg) { + if (*AI == &F) return false; // Passes the function into itself + + if (!ArgumentConstants[i].second) { + if (Constant *C = dyn_cast(*AI)) { + if (!ArgumentConstants[i].first) + ArgumentConstants[i].first = C; + else if (ArgumentConstants[i].first != C) { + // Became non-constant + ArgumentConstants[i].second = true; + ++NumNonconstant; + if (NumNonconstant == ArgumentConstants.size()) return false; + } + } else if (*AI != &*Arg) { // Ignore recursive calls with same arg + // This is not a constant argument. Mark the argument as + // non-constant. + ArgumentConstants[i].second = true; + ++NumNonconstant; + if (NumNonconstant == ArgumentConstants.size()) return false; + } + } + } + } + + // If we got to this point, there is a constant argument! + assert(NumNonconstant != ArgumentConstants.size()); + Function::arg_iterator AI = F.arg_begin(); + bool MadeChange = false; + for (unsigned i = 0, e = ArgumentConstants.size(); i != e; ++i, ++AI) + // Do we have a constant argument!? 
+    if (!ArgumentConstants[i].second && !AI->use_empty()) {
+      Value *V = ArgumentConstants[i].first;
+      if (V == 0) V = UndefValue::get(AI->getType());
+      AI->replaceAllUsesWith(V);
+      ++NumArgumentsProped;
+      MadeChange = true;
+    }
+  return MadeChange;
+}
+
+
+// Check to see if this function returns a constant.  If so, replace all callers
+// that use the return value with the returned value.  If we can replace ALL
+// callers,
+bool IPCP::PropagateConstantReturn(Function &F) {
+  if (F.getReturnType() == Type::VoidTy)
+    return false; // No return value.
+
+  // Check to see if this function returns a constant.
+  Value *RetVal = 0;
+  for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
+    if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator()))
+      if (isa<UndefValue>(RI->getOperand(0))) {
+        // Ignore.
+      } else if (Constant *C = dyn_cast<Constant>(RI->getOperand(0))) {
+        if (RetVal == 0)
+          RetVal = C;
+        else if (RetVal != C)
+          return false;  // Does not return the same constant.
+      } else {
+        return false;  // Does not return a constant.
+      }
+
+  if (RetVal == 0) RetVal = UndefValue::get(F.getReturnType());
+
+  // If we got here, the function returns a constant value.  Loop over all
+  // users, replacing any uses of the return value with the returned constant.
+  bool ReplacedAllUsers = true;
+  bool MadeChange = false;
+  for (Value::use_iterator I = F.use_begin(), E = F.use_end(); I != E; ++I)
+    if (!isa<Instruction>(*I))
+      ReplacedAllUsers = false;
+    else {
+      CallSite CS = CallSite::get(cast<Instruction>(*I));
+      if (CS.getInstruction() == 0 ||
+          CS.getCalledFunction() != &F) {
+        ReplacedAllUsers = false;
+      } else {
+        if (!CS.getInstruction()->use_empty()) {
+          CS.getInstruction()->replaceAllUsesWith(RetVal);
+          MadeChange = true;
+        }
+      }
+    }
+
+  // If we replace all users with the returned constant, and there can be no
+  // other callers of the function, replace the constant being returned in the
+  // function with an undef value.
+  if (ReplacedAllUsers && F.hasInternalLinkage() && !isa<UndefValue>(RetVal)) {
+    Value *RV = UndefValue::get(RetVal->getType());
+    for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
+      if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator())) {
+        if (RI->getOperand(0) != RV) {
+          RI->setOperand(0, RV);
+          MadeChange = true;
+        }
+      }
+  }
+
+  if (MadeChange) ++NumReturnValProped;
+  return MadeChange;
+}
diff --git a/lib/Transforms/IPO/IndMemRemoval.cpp b/lib/Transforms/IPO/IndMemRemoval.cpp
new file mode 100644
index 0000000..6b06469
--- /dev/null
+++ b/lib/Transforms/IPO/IndMemRemoval.cpp
@@ -0,0 +1,89 @@
+//===-- IndMemRemoval.cpp - Remove indirect allocations and frees ----------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file was developed by the LLVM research group and is distributed under
+// the University of Illinois Open Source License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass finds places where memory allocation functions may escape into
+// indirect land. Some transforms are much easier (aka possible) only if free
+// or malloc are not called indirectly.
+// Thus find places where the address of memory functions is taken and construct
+// bounce functions with direct calls of those functions.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "indmemrem"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Pass.h"
+#include "llvm/Module.h"
+#include "llvm/Instructions.h"
+#include "llvm/Type.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Support/Compiler.h"
+using namespace llvm;
+
+STATISTIC(NumBounceSites, "Number of sites modified");
+STATISTIC(NumBounce , "Number of bounce functions created");
+
+namespace {
+  class VISIBILITY_HIDDEN IndMemRemPass : public ModulePass {
+  public:
+    static char ID; // Pass identification, replacement for typeid
+    IndMemRemPass() : ModulePass((intptr_t)&ID) {}
+
+    virtual bool runOnModule(Module &M);
+  };
+  char IndMemRemPass::ID = 0;
+  RegisterPass<IndMemRemPass> X("indmemrem","Indirect Malloc and Free Removal");
+} // end anonymous namespace
+
+
+bool IndMemRemPass::runOnModule(Module &M) {
+  //in Theory, all direct calls of malloc and free should be promoted
+  //to intrinsics.  Therefore, this goes through and finds where the
+  //address of free or malloc are taken and replaces those with bounce
+  //functions, ensuring that all malloc and free that might happen
+  //happen through intrinsics.
+ bool changed = false; + if (Function* F = M.getFunction("free")) { + assert(F->isDeclaration() && "free not external?"); + if (!F->use_empty()) { + Function* FN = new Function(F->getFunctionType(), + GlobalValue::LinkOnceLinkage, + "free_llvm_bounce", &M); + BasicBlock* bb = new BasicBlock("entry",FN); + Instruction* R = new ReturnInst(bb); + new FreeInst(FN->arg_begin(), R); + ++NumBounce; + NumBounceSites += F->getNumUses(); + F->replaceAllUsesWith(FN); + changed = true; + } + } + if (Function* F = M.getFunction("malloc")) { + assert(F->isDeclaration() && "malloc not external?"); + if (!F->use_empty()) { + Function* FN = new Function(F->getFunctionType(), + GlobalValue::LinkOnceLinkage, + "malloc_llvm_bounce", &M); + BasicBlock* bb = new BasicBlock("entry",FN); + Instruction* c = CastInst::createIntegerCast( + FN->arg_begin(), Type::Int32Ty, false, "c", bb); + Instruction* a = new MallocInst(Type::Int8Ty, c, "m", bb); + new ReturnInst(a, bb); + ++NumBounce; + NumBounceSites += F->getNumUses(); + F->replaceAllUsesWith(FN); + changed = true; + } + } + return changed; +} + +ModulePass *llvm::createIndMemRemPass() { + return new IndMemRemPass(); +} diff --git a/lib/Transforms/IPO/InlineSimple.cpp b/lib/Transforms/IPO/InlineSimple.cpp new file mode 100644 index 0000000..2157dcd --- /dev/null +++ b/lib/Transforms/IPO/InlineSimple.cpp @@ -0,0 +1,323 @@ +//===- InlineSimple.cpp - Code to perform simple function inlining --------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements bottom-up inlining of functions into callees. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "inline" +#include "llvm/CallingConv.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Module.h" +#include "llvm/Type.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Support/CallSite.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/InlinerPass.h" +#include + +using namespace llvm; + +namespace { + struct VISIBILITY_HIDDEN ArgInfo { + unsigned ConstantWeight; + unsigned AllocaWeight; + + ArgInfo(unsigned CWeight, unsigned AWeight) + : ConstantWeight(CWeight), AllocaWeight(AWeight) {} + }; + + // FunctionInfo - For each function, calculate the size of it in blocks and + // instructions. + struct VISIBILITY_HIDDEN FunctionInfo { + // NumInsts, NumBlocks - Keep track of how large each function is, which is + // used to estimate the code size cost of inlining it. + unsigned NumInsts, NumBlocks; + + // ArgumentWeights - Each formal argument of the function is inspected to + // see if it is used in any contexts where making it a constant or alloca + // would reduce the code size. If so, we add some value to the argument + // entry here. + std::vector ArgumentWeights; + + FunctionInfo() : NumInsts(0), NumBlocks(0) {} + + /// analyzeFunction - Fill in the current structure with information gleaned + /// from the specified function. 
+ void analyzeFunction(Function *F); + }; + + class VISIBILITY_HIDDEN SimpleInliner : public Inliner { + std::map CachedFunctionInfo; + std::set NeverInline; // Functions that are never inlined + public: + SimpleInliner() : Inliner(&ID) {} + static char ID; // Pass identification, replacement for typeid + int getInlineCost(CallSite CS); + virtual bool doInitialization(CallGraph &CG); + }; + char SimpleInliner::ID = 0; + RegisterPass X("inline", "Function Integration/Inlining"); +} + +Pass *llvm::createFunctionInliningPass() { return new SimpleInliner(); } + +// CountCodeReductionForConstant - Figure out an approximation for how many +// instructions will be constant folded if the specified value is constant. +// +static unsigned CountCodeReductionForConstant(Value *V) { + unsigned Reduction = 0; + for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E; ++UI) + if (isa(*UI)) + Reduction += 40; // Eliminating a conditional branch is a big win + else if (SwitchInst *SI = dyn_cast(*UI)) + // Eliminating a switch is a big win, proportional to the number of edges + // deleted. + Reduction += (SI->getNumSuccessors()-1) * 40; + else if (CallInst *CI = dyn_cast(*UI)) { + // Turning an indirect call into a direct call is a BIG win + Reduction += CI->getCalledValue() == V ? 500 : 0; + } else if (InvokeInst *II = dyn_cast(*UI)) { + // Turning an indirect call into a direct call is a BIG win + Reduction += II->getCalledValue() == V ? 500 : 0; + } else { + // Figure out if this instruction will be removed due to simple constant + // propagation. + Instruction &Inst = cast(**UI); + bool AllOperandsConstant = true; + for (unsigned i = 0, e = Inst.getNumOperands(); i != e; ++i) + if (!isa(Inst.getOperand(i)) && Inst.getOperand(i) != V) { + AllOperandsConstant = false; + break; + } + + if (AllOperandsConstant) { + // We will get to remove this instruction... + Reduction += 7; + + // And any other instructions that use it which become constants + // themselves. 
+ Reduction += CountCodeReductionForConstant(&Inst); + } + } + + return Reduction; +} + +// CountCodeReductionForAlloca - Figure out an approximation of how much smaller +// the function will be if it is inlined into a context where an argument +// becomes an alloca. +// +static unsigned CountCodeReductionForAlloca(Value *V) { + if (!isa(V->getType())) return 0; // Not a pointer + unsigned Reduction = 0; + for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E;++UI){ + Instruction *I = cast(*UI); + if (isa(I) || isa(I)) + Reduction += 10; + else if (GetElementPtrInst *GEP = dyn_cast(I)) { + // If the GEP has variable indices, we won't be able to do much with it. + for (Instruction::op_iterator I = GEP->op_begin()+1, E = GEP->op_end(); + I != E; ++I) + if (!isa(*I)) return 0; + Reduction += CountCodeReductionForAlloca(GEP)+15; + } else { + // If there is some other strange instruction, we're not going to be able + // to do much if we inline this. + return 0; + } + } + + return Reduction; +} + +/// analyzeFunction - Fill in the current structure with information gleaned +/// from the specified function. +void FunctionInfo::analyzeFunction(Function *F) { + unsigned NumInsts = 0, NumBlocks = 0; + + // Look at the size of the callee. Each basic block counts as 20 units, and + // each instruction counts as 10. + for (Function::const_iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { + for (BasicBlock::const_iterator II = BB->begin(), E = BB->end(); + II != E; ++II) { + if (isa(II)) continue; // Debug intrinsics don't count. + + // Noop casts, including ptr <-> int, don't count. + if (const CastInst *CI = dyn_cast(II)) { + if (CI->isLosslessCast() || isa(CI) || + isa(CI)) + continue; + } else if (const GetElementPtrInst *GEPI = + dyn_cast(II)) { + // If a GEP has all constant indices, it will probably be folded with + // a load/store. 
+ bool AllConstant = true; + for (unsigned i = 1, e = GEPI->getNumOperands(); i != e; ++i) + if (!isa(GEPI->getOperand(i))) { + AllConstant = false; + break; + } + if (AllConstant) continue; + } + + ++NumInsts; + } + + ++NumBlocks; + } + + this->NumBlocks = NumBlocks; + this->NumInsts = NumInsts; + + // Check out all of the arguments to the function, figuring out how much + // code can be eliminated if one of the arguments is a constant. + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; ++I) + ArgumentWeights.push_back(ArgInfo(CountCodeReductionForConstant(I), + CountCodeReductionForAlloca(I))); +} + + +// getInlineCost - The heuristic used to determine if we should inline the +// function call or not. +// +int SimpleInliner::getInlineCost(CallSite CS) { + Instruction *TheCall = CS.getInstruction(); + Function *Callee = CS.getCalledFunction(); + const Function *Caller = TheCall->getParent()->getParent(); + + // Don't inline a directly recursive call. + if (Caller == Callee || + // Don't inline functions which can be redefined at link-time to mean + // something else. link-once linkage is ok though. + Callee->hasWeakLinkage() || + + // Don't inline functions marked noinline. + NeverInline.count(Callee)) + return 2000000000; + + // InlineCost - This value measures how good of an inline candidate this call + // site is to inline. A lower inline cost make is more likely for the call to + // be inlined. This value may go negative. + // + int InlineCost = 0; + + // If there is only one call of the function, and it has internal linkage, + // make it almost guaranteed to be inlined. + // + if (Callee->hasInternalLinkage() && Callee->hasOneUse()) + InlineCost -= 30000; + + // If this function uses the coldcc calling convention, prefer not to inline + // it. 
+  if (Callee->getCallingConv() == CallingConv::Cold)
+    InlineCost += 2000;
+
+  // If the instruction after the call, or if the normal destination of the
+  // invoke is an unreachable instruction, the function is noreturn. As such,
+  // there is little point in inlining this.
+  if (InvokeInst *II = dyn_cast<InvokeInst>(TheCall)) {
+    if (isa<UnreachableInst>(II->getNormalDest()->begin()))
+      InlineCost += 10000;
+  } else if (isa<UnreachableInst>(++BasicBlock::iterator(TheCall)))
+    InlineCost += 10000;
+
+  // Get information about the callee...
+  FunctionInfo &CalleeFI = CachedFunctionInfo[Callee];
+
+  // If we haven't calculated this information yet, do so now.
+  if (CalleeFI.NumBlocks == 0)
+    CalleeFI.analyzeFunction(Callee);
+
+  // Add to the inline quality for properties that make the call valuable to
+  // inline. This includes factors that indicate that the result of inlining
+  // the function will be optimizable. Currently this just looks at arguments
+  // passed into the function.
+  //
+  unsigned ArgNo = 0;
+  for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end();
+       I != E; ++I, ++ArgNo) {
+    // Each argument passed in has a cost at both the caller and the callee
+    // sides. This favors functions that take many arguments over functions
+    // that take few arguments.
+    InlineCost -= 20;
+
+    // If this is a function being passed in, it is very likely that we will be
+    // able to turn an indirect function call into a direct function call.
+    if (isa<Function>(I))
+      InlineCost -= 100;
+
+    // If an alloca is passed in, inlining this function is likely to allow
+    // significant future optimization possibilities (like scalar promotion, and
+    // scalarization), so encourage the inlining of the function.
+    //
+    else if (isa<AllocaInst>(I)) {
+      if (ArgNo < CalleeFI.ArgumentWeights.size())
+        InlineCost -= CalleeFI.ArgumentWeights[ArgNo].AllocaWeight;
+
+      // If this is a constant being passed into the function, use the argument
+      // weights calculated for the callee to determine how much will be folded
+      // away with this information.
+    } else if (isa<Constant>(I)) {
+      if (ArgNo < CalleeFI.ArgumentWeights.size())
+        InlineCost -= CalleeFI.ArgumentWeights[ArgNo].ConstantWeight;
+    }
+  }
+
+  // Now that we have considered all of the factors that make the call site more
+  // likely to be inlined, look at factors that make us not want to inline it.
+
+  // Don't inline into something too big, which would make it bigger. Here, we
+  // count each basic block as a single unit.
+  //
+  InlineCost += Caller->size()/20;
+
+
+  // Look at the size of the callee. Each basic block counts as 20 units, and
+  // each instruction counts as 5.
+  InlineCost += CalleeFI.NumInsts*5 + CalleeFI.NumBlocks*20;
+  return InlineCost;
+}
+
+// doInitialization - Initializes the vector of functions that have been
+// annotated with the noinline attribute.
+bool SimpleInliner::doInitialization(CallGraph &CG) {
+
+  Module &M = CG.getModule();
+
+  // Get llvm.noinline
+  GlobalVariable *GV = M.getNamedGlobal("llvm.noinline");
+
+  if (GV == 0)
+    return false;
+
+  const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
+
+  if (InitList == 0)
+    return false;
+
+  // Iterate over each element and add to the NeverInline set
+  for (unsigned i = 0, e = InitList->getNumOperands(); i != e; ++i) {
+
+    // Get Source
+    const Constant *Elt = InitList->getOperand(i);
+
+    if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(Elt))
+      if (CE->getOpcode() == Instruction::BitCast)
+        Elt = CE->getOperand(0);
+
+    // Insert into set of functions to never inline
+    if (const Function *F = dyn_cast<Function>(Elt))
+      NeverInline.insert(F);
+  }
+
+  return false;
+}
diff --git a/lib/Transforms/IPO/Inliner.cpp b/lib/Transforms/IPO/Inliner.cpp
new file mode 100644
index 0000000..85893d7
--- /dev/null
+++ b/lib/Transforms/IPO/Inliner.cpp
@@ -0,0 +1,217 @@
+//===- Inliner.cpp - Code common to all inliners --------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file was developed by the LLVM research group and is distributed under
+// the University of Illinois Open Source License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the mechanics required to implement inlining without
+// missing any calls and updating the call graph. The decisions of which calls
+// are profitable to inline are implemented elsewhere.
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "inline" +#include "llvm/Module.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Support/CallSite.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Transforms/IPO/InlinerPass.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/Statistic.h" +#include +using namespace llvm; + +STATISTIC(NumInlined, "Number of functions inlined"); +STATISTIC(NumDeleted, "Number of functions deleted because all callers found"); + +namespace { + cl::opt // FIXME: 200 is VERY conservative + InlineLimit("inline-threshold", cl::Hidden, cl::init(200), + cl::desc("Control the amount of inlining to perform (default = 200)")); +} + +Inliner::Inliner(const void *ID) + : CallGraphSCCPass((intptr_t)ID), InlineThreshold(InlineLimit) {} + +/// getAnalysisUsage - For this class, we declare that we require and preserve +/// the call graph. If the derived class implements this method, it should +/// always explicitly call the implementation here. +void Inliner::getAnalysisUsage(AnalysisUsage &Info) const { + Info.addRequired(); + CallGraphSCCPass::getAnalysisUsage(Info); +} + +// InlineCallIfPossible - If it is possible to inline the specified call site, +// do so and update the CallGraph for this operation. +static bool InlineCallIfPossible(CallSite CS, CallGraph &CG, + const std::set &SCCFunctions, + const TargetData &TD) { + Function *Callee = CS.getCalledFunction(); + if (!InlineFunction(CS, &CG, &TD)) return false; + + // If we inlined the last possible call site to the function, delete the + // function body now. + if (Callee->use_empty() && Callee->hasInternalLinkage() && + !SCCFunctions.count(Callee)) { + DOUT << " -> Deleting dead function: " << Callee->getName() << "\n"; + + // Remove any call graph edges from the callee to its callees. 
+ CallGraphNode *CalleeNode = CG[Callee]; + while (CalleeNode->begin() != CalleeNode->end()) + CalleeNode->removeCallEdgeTo((CalleeNode->end()-1)->second); + + // Removing the node for callee from the call graph and delete it. + delete CG.removeFunctionFromModule(CalleeNode); + ++NumDeleted; + } + return true; +} + +bool Inliner::runOnSCC(const std::vector &SCC) { + CallGraph &CG = getAnalysis(); + + std::set SCCFunctions; + DOUT << "Inliner visiting SCC:"; + for (unsigned i = 0, e = SCC.size(); i != e; ++i) { + Function *F = SCC[i]->getFunction(); + if (F) SCCFunctions.insert(F); + DOUT << " " << (F ? F->getName() : "INDIRECTNODE"); + } + + // Scan through and identify all call sites ahead of time so that we only + // inline call sites in the original functions, not call sites that result + // from inlining other functions. + std::vector CallSites; + + for (unsigned i = 0, e = SCC.size(); i != e; ++i) + if (Function *F = SCC[i]->getFunction()) + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) + for (BasicBlock::iterator I = BB->begin(); I != BB->end(); ++I) { + CallSite CS = CallSite::get(I); + if (CS.getInstruction() && (!CS.getCalledFunction() || + !CS.getCalledFunction()->isDeclaration())) + CallSites.push_back(CS); + } + + DOUT << ": " << CallSites.size() << " call sites.\n"; + + // Now that we have all of the call sites, move the ones to functions in the + // current SCC to the end of the list. + unsigned FirstCallInSCC = CallSites.size(); + for (unsigned i = 0; i < FirstCallInSCC; ++i) + if (Function *F = CallSites[i].getCalledFunction()) + if (SCCFunctions.count(F)) + std::swap(CallSites[i--], CallSites[--FirstCallInSCC]); + + // Now that we have all of the call sites, loop over them and inline them if + // it looks profitable to do so. + bool Changed = false; + bool LocalChange; + do { + LocalChange = false; + // Iterate over the outer loop because inlining functions can cause indirect + // calls to become direct calls. 
+ for (unsigned CSi = 0; CSi != CallSites.size(); ++CSi) + if (Function *Callee = CallSites[CSi].getCalledFunction()) { + // Calls to external functions are never inlinable. + if (Callee->isDeclaration() || + CallSites[CSi].getInstruction()->getParent()->getParent() ==Callee){ + if (SCC.size() == 1) { + std::swap(CallSites[CSi], CallSites.back()); + CallSites.pop_back(); + } else { + // Keep the 'in SCC / not in SCC' boundary correct. + CallSites.erase(CallSites.begin()+CSi); + } + --CSi; + continue; + } + + // If the policy determines that we should inline this function, + // try to do so. + CallSite CS = CallSites[CSi]; + int InlineCost = getInlineCost(CS); + if (InlineCost >= (int)InlineThreshold) { + DOUT << " NOT Inlining: cost=" << InlineCost + << ", Call: " << *CS.getInstruction(); + } else { + DOUT << " Inlining: cost=" << InlineCost + << ", Call: " << *CS.getInstruction(); + + // Attempt to inline the function... + if (InlineCallIfPossible(CS, CG, SCCFunctions, + getAnalysis())) { + // Remove this call site from the list. If possible, use + // swap/pop_back for efficiency, but do not use it if doing so would + // move a call site to a function in this SCC before the + // 'FirstCallInSCC' barrier. + if (SCC.size() == 1) { + std::swap(CallSites[CSi], CallSites.back()); + CallSites.pop_back(); + } else { + CallSites.erase(CallSites.begin()+CSi); + } + --CSi; + + ++NumInlined; + Changed = true; + LocalChange = true; + } + } + } + } while (LocalChange); + + return Changed; +} + +// doFinalization - Remove now-dead linkonce functions at the end of +// processing to avoid breaking the SCC traversal. +bool Inliner::doFinalization(CallGraph &CG) { + std::set FunctionsToRemove; + + // Scan for all of the functions, looking for ones that should now be removed + // from the program. Insert the dead ones in the FunctionsToRemove set. + for (CallGraph::iterator I = CG.begin(), E = CG.end(); I != E; ++I) { + CallGraphNode *CGN = I->second; + if (Function *F = CGN ? 
CGN->getFunction() : 0) { + // If the only remaining users of the function are dead constants, remove + // them. + F->removeDeadConstantUsers(); + + if ((F->hasLinkOnceLinkage() || F->hasInternalLinkage()) && + F->use_empty()) { + + // Remove any call graph edges from the function to its callees. + while (CGN->begin() != CGN->end()) + CGN->removeCallEdgeTo((CGN->end()-1)->second); + + // Remove any edges from the external node to the function's call graph + // node. These edges might have been made irrelegant due to + // optimization of the program. + CG.getExternalCallingNode()->removeAnyCallEdgeTo(CGN); + + // Removing the node for callee from the call graph and delete it. + FunctionsToRemove.insert(CGN); + } + } + } + + // Now that we know which functions to delete, do so. We didn't want to do + // this inline, because that would invalidate our CallGraph::iterator + // objects. :( + bool Changed = false; + for (std::set::iterator I = FunctionsToRemove.begin(), + E = FunctionsToRemove.end(); I != E; ++I) { + delete CG.removeFunctionFromModule(*I); + ++NumDeleted; + Changed = true; + } + + return Changed; +} diff --git a/lib/Transforms/IPO/Internalize.cpp b/lib/Transforms/IPO/Internalize.cpp new file mode 100644 index 0000000..7b5392c --- /dev/null +++ b/lib/Transforms/IPO/Internalize.cpp @@ -0,0 +1,154 @@ +//===-- Internalize.cpp - Mark functions internal -------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass loops over all of the functions in the input module, looking for a +// main function. If a main function is found, all other functions and all +// global variables with initializers are marked as internal. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "internalize" +#include "llvm/Transforms/IPO.h" +#include "llvm/Pass.h" +#include "llvm/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/Statistic.h" +#include +#include +using namespace llvm; + +STATISTIC(NumFunctions, "Number of functions internalized"); +STATISTIC(NumGlobals , "Number of global vars internalized"); + +namespace { + + // APIFile - A file which contains a list of symbols that should not be marked + // external. + cl::opt + APIFile("internalize-public-api-file", cl::value_desc("filename"), + cl::desc("A file containing list of symbol names to preserve")); + + // APIList - A list of symbols that should not be marked internal. + cl::list + APIList("internalize-public-api-list", cl::value_desc("list"), + cl::desc("A list of symbol names to preserve"), + cl::CommaSeparated); + + class VISIBILITY_HIDDEN InternalizePass : public ModulePass { + std::set ExternalNames; + bool DontInternalize; + public: + static char ID; // Pass identification, replacement for typeid + InternalizePass(bool InternalizeEverything = true); + InternalizePass(const std::vector & exportList); + void LoadFile(const char *Filename); + virtual bool runOnModule(Module &M); + }; + char InternalizePass::ID = 0; + RegisterPass X("internalize", "Internalize Global Symbols"); +} // end anonymous namespace + +InternalizePass::InternalizePass(bool InternalizeEverything) + : ModulePass((intptr_t)&ID), DontInternalize(false){ + if (!APIFile.empty()) // If a filename is specified, use it + LoadFile(APIFile.c_str()); + else if (!APIList.empty()) // Else, if a list is specified, use it. + ExternalNames.insert(APIList.begin(), APIList.end()); + else if (!InternalizeEverything) + // Finally, if we're allowed to, internalize all but main. 
+ DontInternalize = true; +} + +InternalizePass::InternalizePass(const std::vector&exportList) + : ModulePass((intptr_t)&ID), DontInternalize(false){ + for(std::vector::const_iterator itr = exportList.begin(); + itr != exportList.end(); itr++) { + ExternalNames.insert(*itr); + } +} + +void InternalizePass::LoadFile(const char *Filename) { + // Load the APIFile... + std::ifstream In(Filename); + if (!In.good()) { + cerr << "WARNING: Internalize couldn't load file '" << Filename << "'!\n"; + return; // Do not internalize anything... + } + while (In) { + std::string Symbol; + In >> Symbol; + if (!Symbol.empty()) + ExternalNames.insert(Symbol); + } +} + +bool InternalizePass::runOnModule(Module &M) { + if (DontInternalize) return false; + + // If no list or file of symbols was specified, check to see if there is a + // "main" symbol defined in the module. If so, use it, otherwise do not + // internalize the module, it must be a library or something. + // + if (ExternalNames.empty()) { + Function *MainFunc = M.getFunction("main"); + if (MainFunc == 0 || MainFunc->isDeclaration()) + return false; // No main found, must be a library... + + // Preserve main, internalize all else. + ExternalNames.insert(MainFunc->getName()); + } + + bool Changed = false; + + // Found a main function, mark all functions not named main as internal. + for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) + if (!I->isDeclaration() && // Function must be defined here + !I->hasInternalLinkage() && // Can't already have internal linkage + !ExternalNames.count(I->getName())) {// Not marked to keep external? + I->setLinkage(GlobalValue::InternalLinkage); + Changed = true; + ++NumFunctions; + DOUT << "Internalizing func " << I->getName() << "\n"; + } + + // Never internalize the llvm.used symbol. It is used to implement + // attribute((used)). + ExternalNames.insert("llvm.used"); + + // Never internalize anchors used by the machine module info, else the info + // won't find them. 
(see MachineModuleInfo.) + ExternalNames.insert("llvm.dbg.compile_units"); + ExternalNames.insert("llvm.dbg.global_variables"); + ExternalNames.insert("llvm.dbg.subprograms"); + ExternalNames.insert("llvm.global_ctors"); + ExternalNames.insert("llvm.global_dtors"); + + // Mark all global variables with initializers as internal as well. + for (Module::global_iterator I = M.global_begin(), E = M.global_end(); + I != E; ++I) + if (!I->isDeclaration() && !I->hasInternalLinkage() && + !ExternalNames.count(I->getName())) { + I->setLinkage(GlobalValue::InternalLinkage); + Changed = true; + ++NumGlobals; + DOUT << "Internalized gvar " << I->getName() << "\n"; + } + + return Changed; +} + +ModulePass *llvm::createInternalizePass(bool InternalizeEverything) { + return new InternalizePass(InternalizeEverything); +} + +ModulePass *llvm::createInternalizePass(const std::vector &el) { + return new InternalizePass(el); +} diff --git a/lib/Transforms/IPO/LoopExtractor.cpp b/lib/Transforms/IPO/LoopExtractor.cpp new file mode 100644 index 0000000..7b14ce0 --- /dev/null +++ b/lib/Transforms/IPO/LoopExtractor.cpp @@ -0,0 +1,201 @@ +//===- LoopExtractor.cpp - Extract each loop into a new function ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// A pass wrapper around the ExtractLoop() scalar transformation to extract each +// top-level loop into its own new function. If the loop is the ONLY loop in a +// given function, it is not touched. This is a pass most useful for debugging +// via bugpoint. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "loop-extract" +#include "llvm/Transforms/IPO.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/FunctionUtils.h" +#include "llvm/ADT/Statistic.h" +using namespace llvm; + +STATISTIC(NumExtracted, "Number of loops extracted"); + +namespace { + // FIXME: This is not a function pass, but the PassManager doesn't allow + // Module passes to require FunctionPasses, so we can't get loop info if we're + // not a function pass. + struct VISIBILITY_HIDDEN LoopExtractor : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + unsigned NumLoops; + + LoopExtractor(unsigned numLoops = ~0) + : FunctionPass((intptr_t)&ID), NumLoops(numLoops) {} + + virtual bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredID(BreakCriticalEdgesID); + AU.addRequiredID(LoopSimplifyID); + AU.addRequired(); + AU.addRequired(); + } + }; + + char LoopExtractor::ID = 0; + RegisterPass + X("loop-extract", "Extract loops into new functions"); + + /// SingleLoopExtractor - For bugpoint. + struct SingleLoopExtractor : public LoopExtractor { + static char ID; // Pass identification, replacement for typeid + SingleLoopExtractor() : LoopExtractor(1) {} + }; + + char SingleLoopExtractor::ID = 0; + RegisterPass + Y("loop-extract-single", "Extract at most one loop into a new function"); +} // End anonymous namespace + +// createLoopExtractorPass - This pass extracts all natural loops from the +// program into a function if it can. 
+//
+FunctionPass *llvm::createLoopExtractorPass() { return new LoopExtractor(); }
+
+bool LoopExtractor::runOnFunction(Function &F) {
+  LoopInfo &LI = getAnalysis<LoopInfo>();
+
+  // If this function has no loops, there is nothing to do.
+  if (LI.begin() == LI.end())
+    return false;
+
+  DominatorTree &DT = getAnalysis<DominatorTree>();
+
+  // If there is more than one top-level loop in this function, extract all of
+  // the loops.
+  bool Changed = false;
+  if (LI.end()-LI.begin() > 1) {
+    for (LoopInfo::iterator i = LI.begin(), e = LI.end(); i != e; ++i) {
+      if (NumLoops == 0) return Changed;
+      --NumLoops;
+      Changed |= ExtractLoop(DT, *i) != 0;
+      ++NumExtracted;
+    }
+  } else {
+    // Otherwise there is exactly one top-level loop. If this function is more
+    // than a minimal wrapper around the loop, extract the loop.
+    Loop *TLL = *LI.begin();
+    bool ShouldExtractLoop = false;
+
+    // Extract the loop if the entry block doesn't branch to the loop header.
+    TerminatorInst *EntryTI = F.getEntryBlock().getTerminator();
+    if (!isa<BranchInst>(EntryTI) ||
+        !cast<BranchInst>(EntryTI)->isUnconditional() ||
+        EntryTI->getSuccessor(0) != TLL->getHeader())
+      ShouldExtractLoop = true;
+    else {
+      // Check to see if any exits from the loop are more than just return
+      // blocks.
+      std::vector<BasicBlock*> ExitBlocks;
+      TLL->getExitBlocks(ExitBlocks);
+      for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i)
+        if (!isa<ReturnInst>(ExitBlocks[i]->getTerminator())) {
+          ShouldExtractLoop = true;
+          break;
+        }
+    }
+
+    if (ShouldExtractLoop) {
+      if (NumLoops == 0) return Changed;
+      --NumLoops;
+      Changed |= ExtractLoop(DT, TLL) != 0;
+      ++NumExtracted;
+    } else {
+      // Okay, this function is a minimal container around the specified loop.
+      // If we extract the loop, we will continue to just keep extracting it
+      // infinitely... so don't extract it. However, if the loop contains any
+      // subloops, extract them.
+ for (Loop::iterator i = TLL->begin(), e = TLL->end(); i != e; ++i) { + if (NumLoops == 0) return Changed; + --NumLoops; + Changed |= ExtractLoop(DT, *i) != 0; + ++NumExtracted; + } + } + } + + return Changed; +} + +// createSingleLoopExtractorPass - This pass extracts one natural loop from the +// program into a function if it can. This is used by bugpoint. +// +FunctionPass *llvm::createSingleLoopExtractorPass() { + return new SingleLoopExtractor(); +} + + +namespace { + /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks + /// from the module into their own functions except for those specified by the + /// BlocksToNotExtract list. + class BlockExtractorPass : public ModulePass { + std::vector BlocksToNotExtract; + public: + static char ID; // Pass identification, replacement for typeid + BlockExtractorPass(std::vector &B) + : ModulePass((intptr_t)&ID), BlocksToNotExtract(B) {} + BlockExtractorPass() : ModulePass((intptr_t)&ID) {} + + bool runOnModule(Module &M); + }; + + char BlockExtractorPass::ID = 0; + RegisterPass + XX("extract-blocks", "Extract Basic Blocks From Module (for bugpoint use)"); +} + +// createBlockExtractorPass - This pass extracts all blocks (except those +// specified in the argument list) from the functions in the module. +// +ModulePass *llvm::createBlockExtractorPass(std::vector &BTNE) { + return new BlockExtractorPass(BTNE); +} + +bool BlockExtractorPass::runOnModule(Module &M) { + std::set TranslatedBlocksToNotExtract; + for (unsigned i = 0, e = BlocksToNotExtract.size(); i != e; ++i) { + BasicBlock *BB = BlocksToNotExtract[i]; + Function *F = BB->getParent(); + + // Map the corresponding function in this module. + Function *MF = M.getFunction(F->getName()); + assert(MF->getFunctionType() == F->getFunctionType() && "Wrong function?"); + + // Figure out which index the basic block is in its function. 
+ Function::iterator BBI = MF->begin(); + std::advance(BBI, std::distance(F->begin(), Function::iterator(BB))); + TranslatedBlocksToNotExtract.insert(BBI); + } + + // Now that we know which blocks to not extract, figure out which ones we WANT + // to extract. + std::vector BlocksToExtract; + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) + if (!TranslatedBlocksToNotExtract.count(BB)) + BlocksToExtract.push_back(BB); + + for (unsigned i = 0, e = BlocksToExtract.size(); i != e; ++i) + ExtractBasicBlock(BlocksToExtract[i]); + + return !BlocksToExtract.empty(); +} diff --git a/lib/Transforms/IPO/LowerSetJmp.cpp b/lib/Transforms/IPO/LowerSetJmp.cpp new file mode 100644 index 0000000..0243980 --- /dev/null +++ b/lib/Transforms/IPO/LowerSetJmp.cpp @@ -0,0 +1,534 @@ +//===- LowerSetJmp.cpp - Code pertaining to lowering set/long jumps -------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the lowering of setjmp and longjmp to use the +// LLVM invoke and unwind instructions as necessary. +// +// Lowering of longjmp is fairly trivial. We replace the call with a +// call to the LLVM library function "__llvm_sjljeh_throw_longjmp()". +// This unwinds the stack for us calling all of the destructors for +// objects allocated on the stack. +// +// At a setjmp call, the basic block is split and the setjmp removed. +// The calls in a function that have a setjmp are converted to invoke +// where the except part checks to see if it's a longjmp exception and, +// if so, if it's handled in the function. If it is, then it gets the +// value returned by the longjmp and goes to where the basic block was +// split. 
Invoke instructions are handled in a similar fashion with the +// original except block being executed if it isn't a longjmp except +// that is handled by that function. +// +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// FIXME: This pass doesn't deal with PHI statements just yet. That is, +// we expect this to occur before SSAification is done. This would seem +// to make sense, but in general, it might be a good idea to make this +// pass invokable via the "opt" command at will. +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "lowersetjmp" +#include "llvm/Transforms/IPO.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Intrinsics.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/InstVisitor.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/VectorExtras.h" +using namespace llvm; + +STATISTIC(LongJmpsTransformed, "Number of longjmps transformed"); +STATISTIC(SetJmpsTransformed , "Number of setjmps transformed"); +STATISTIC(CallsTransformed , "Number of calls invokified"); +STATISTIC(InvokesTransformed , "Number of invokes modified"); + +namespace { + //===--------------------------------------------------------------------===// + // LowerSetJmp pass implementation. + class VISIBILITY_HIDDEN LowerSetJmp : public ModulePass, + public InstVisitor { + // LLVM library functions... 
+ Constant *InitSJMap; // __llvm_sjljeh_init_setjmpmap + Constant *DestroySJMap; // __llvm_sjljeh_destroy_setjmpmap + Constant *AddSJToMap; // __llvm_sjljeh_add_setjmp_to_map + Constant *ThrowLongJmp; // __llvm_sjljeh_throw_longjmp + Constant *TryCatchLJ; // __llvm_sjljeh_try_catching_longjmp_exception + Constant *IsLJException; // __llvm_sjljeh_is_longjmp_exception + Constant *GetLJValue; // __llvm_sjljeh_get_longjmp_value + + typedef std::pair SwitchValuePair; + + // Keep track of those basic blocks reachable via a depth-first search of + // the CFG from a setjmp call. We only need to transform those "call" and + // "invoke" instructions that are reachable from the setjmp call site. + std::set DFSBlocks; + + // The setjmp map is going to hold information about which setjmps + // were called (each setjmp gets its own number) and with which + // buffer it was called. + std::map SJMap; + + // The rethrow basic block map holds the basic block to branch to if + // the exception isn't handled in the current function and needs to + // be rethrown. + std::map RethrowBBMap; + + // The preliminary basic block map holds a basic block that grabs the + // exception and determines if it's handled by the current function. + std::map PrelimBBMap; + + // The switch/value map holds a switch inst/call inst pair. The + // switch inst controls which handler (if any) gets called and the + // value is the value returned to that handler by the call to + // __llvm_sjljeh_get_longjmp_value. + std::map SwitchValMap; + + // A map of which setjmps we've seen so far in a function. 
+ std::map SetJmpIDMap; + + AllocaInst* GetSetJmpMap(Function* Func); + BasicBlock* GetRethrowBB(Function* Func); + SwitchValuePair GetSJSwitch(Function* Func, BasicBlock* Rethrow); + + void TransformLongJmpCall(CallInst* Inst); + void TransformSetJmpCall(CallInst* Inst); + + bool IsTransformableFunction(const std::string& Name); + public: + static char ID; // Pass identification, replacement for typeid + LowerSetJmp() : ModulePass((intptr_t)&ID) {} + + void visitCallInst(CallInst& CI); + void visitInvokeInst(InvokeInst& II); + void visitReturnInst(ReturnInst& RI); + void visitUnwindInst(UnwindInst& UI); + + bool runOnModule(Module& M); + bool doInitialization(Module& M); + }; + + char LowerSetJmp::ID = 0; + RegisterPass X("lowersetjmp", "Lower Set Jump"); +} // end anonymous namespace + +// run - Run the transformation on the program. We grab the function +// prototypes for longjmp and setjmp. If they are used in the program, +// then we can go directly to the places they're at and transform them. +bool LowerSetJmp::runOnModule(Module& M) { + bool Changed = false; + + // These are what the functions are called. + Function* SetJmp = M.getFunction("llvm.setjmp"); + Function* LongJmp = M.getFunction("llvm.longjmp"); + + // This program doesn't have longjmp and setjmp calls. + if ((!LongJmp || LongJmp->use_empty()) && + (!SetJmp || SetJmp->use_empty())) return false; + + // Initialize some values and functions we'll need to transform the + // setjmp/longjmp functions. 
+ doInitialization(M); + + if (SetJmp) { + for (Value::use_iterator B = SetJmp->use_begin(), E = SetJmp->use_end(); + B != E; ++B) { + BasicBlock* BB = cast(*B)->getParent(); + for (df_ext_iterator I = df_ext_begin(BB, DFSBlocks), + E = df_ext_end(BB, DFSBlocks); I != E; ++I) + /* empty */; + } + + while (!SetJmp->use_empty()) { + assert(isa(SetJmp->use_back()) && + "User of setjmp intrinsic not a call?"); + TransformSetJmpCall(cast(SetJmp->use_back())); + Changed = true; + } + } + + if (LongJmp) + while (!LongJmp->use_empty()) { + assert(isa(LongJmp->use_back()) && + "User of longjmp intrinsic not a call?"); + TransformLongJmpCall(cast(LongJmp->use_back())); + Changed = true; + } + + // Now go through the affected functions and convert calls and invokes + // to new invokes... + for (std::map::iterator + B = SJMap.begin(), E = SJMap.end(); B != E; ++B) { + Function* F = B->first; + for (Function::iterator BB = F->begin(), BE = F->end(); BB != BE; ++BB) + for (BasicBlock::iterator IB = BB->begin(), IE = BB->end(); IB != IE; ) { + visit(*IB++); + if (IB != BB->end() && IB->getParent() != BB) + break; // The next instruction got moved to a different block! + } + } + + DFSBlocks.clear(); + SJMap.clear(); + RethrowBBMap.clear(); + PrelimBBMap.clear(); + SwitchValMap.clear(); + SetJmpIDMap.clear(); + + return Changed; +} + +// doInitialization - For the lower long/setjmp pass, this ensures that a +// module contains a declaration for the intrisic functions we are going +// to call to convert longjmp and setjmp calls. +// +// This function is always successful, unless it isn't. +bool LowerSetJmp::doInitialization(Module& M) +{ + const Type *SBPTy = PointerType::get(Type::Int8Ty); + const Type *SBPPTy = PointerType::get(SBPTy); + + // N.B. See llvm/runtime/GCCLibraries/libexception/SJLJ-Exception.h for + // a description of the following library functions. 
+ + // void __llvm_sjljeh_init_setjmpmap(void**) + InitSJMap = M.getOrInsertFunction("__llvm_sjljeh_init_setjmpmap", + Type::VoidTy, SBPPTy, (Type *)0); + // void __llvm_sjljeh_destroy_setjmpmap(void**) + DestroySJMap = M.getOrInsertFunction("__llvm_sjljeh_destroy_setjmpmap", + Type::VoidTy, SBPPTy, (Type *)0); + + // void __llvm_sjljeh_add_setjmp_to_map(void**, void*, unsigned) + AddSJToMap = M.getOrInsertFunction("__llvm_sjljeh_add_setjmp_to_map", + Type::VoidTy, SBPPTy, SBPTy, + Type::Int32Ty, (Type *)0); + + // void __llvm_sjljeh_throw_longjmp(int*, int) + ThrowLongJmp = M.getOrInsertFunction("__llvm_sjljeh_throw_longjmp", + Type::VoidTy, SBPTy, Type::Int32Ty, + (Type *)0); + + // unsigned __llvm_sjljeh_try_catching_longjmp_exception(void **) + TryCatchLJ = + M.getOrInsertFunction("__llvm_sjljeh_try_catching_longjmp_exception", + Type::Int32Ty, SBPPTy, (Type *)0); + + // bool __llvm_sjljeh_is_longjmp_exception() + IsLJException = M.getOrInsertFunction("__llvm_sjljeh_is_longjmp_exception", + Type::Int1Ty, (Type *)0); + + // int __llvm_sjljeh_get_longjmp_value() + GetLJValue = M.getOrInsertFunction("__llvm_sjljeh_get_longjmp_value", + Type::Int32Ty, (Type *)0); + return true; +} + +// IsTransformableFunction - Return true if the function name isn't one +// of the ones we don't want transformed. Currently, don't transform any +// "llvm.{setjmp,longjmp}" functions and none of the setjmp/longjmp error +// handling functions (beginning with __llvm_sjljeh_...they don't throw +// exceptions). +bool LowerSetJmp::IsTransformableFunction(const std::string& Name) { + std::string SJLJEh("__llvm_sjljeh"); + + if (Name.size() > SJLJEh.size()) + return std::string(Name.begin(), Name.begin() + SJLJEh.size()) != SJLJEh; + + return true; +} + +// TransformLongJmpCall - Transform a longjmp call into a call to the +// internal __llvm_sjljeh_throw_longjmp function. It then takes care of +// throwing the exception for us. 
+void LowerSetJmp::TransformLongJmpCall(CallInst* Inst) +{ + const Type* SBPTy = PointerType::get(Type::Int8Ty); + + // Create the call to "__llvm_sjljeh_throw_longjmp". This takes the + // same parameters as "longjmp", except that the buffer is cast to a + // char*. It returns "void", so it doesn't need to replace any of + // Inst's uses and doesn't get a name. + CastInst* CI = + new BitCastInst(Inst->getOperand(1), SBPTy, "LJBuf", Inst); + new CallInst(ThrowLongJmp, CI, Inst->getOperand(2), "", Inst); + + SwitchValuePair& SVP = SwitchValMap[Inst->getParent()->getParent()]; + + // If the function has a setjmp call in it (they are transformed first) + // we should branch to the basic block that determines if this longjmp + // is applicable here. Otherwise, issue an unwind. + if (SVP.first) + new BranchInst(SVP.first->getParent(), Inst); + else + new UnwindInst(Inst); + + // Remove all insts after the branch/unwind inst. Go from back to front to + // avoid replaceAllUsesWith if possible. + BasicBlock *BB = Inst->getParent(); + Instruction *Removed; + do { + Removed = &BB->back(); + // If the removed instructions have any users, replace them now. + if (!Removed->use_empty()) + Removed->replaceAllUsesWith(UndefValue::get(Removed->getType())); + Removed->eraseFromParent(); + } while (Removed != Inst); + + ++LongJmpsTransformed; +} + +// GetSetJmpMap - Retrieve (create and initialize, if necessary) the +// setjmp map. This map is going to hold information about which setjmps +// were called (each setjmp gets its own number) and with which buffer it +// was called. There can be only one! +AllocaInst* LowerSetJmp::GetSetJmpMap(Function* Func) +{ + if (SJMap[Func]) return SJMap[Func]; + + // Insert the setjmp map initialization before the first instruction in + // the function. + Instruction* Inst = Func->getEntryBlock().begin(); + assert(Inst && "Couldn't find even ONE instruction in entry block!"); + + // Fill in the alloca and call to initialize the SJ map. 
+ const Type *SBPTy = PointerType::get(Type::Int8Ty); + AllocaInst* Map = new AllocaInst(SBPTy, 0, "SJMap", Inst); + new CallInst(InitSJMap, Map, "", Inst); + return SJMap[Func] = Map; +} + +// GetRethrowBB - Only one rethrow basic block is needed per function. +// If this is a longjmp exception but not handled in this block, this BB +// performs the rethrow. +BasicBlock* LowerSetJmp::GetRethrowBB(Function* Func) +{ + if (RethrowBBMap[Func]) return RethrowBBMap[Func]; + + // The basic block we're going to jump to if we need to rethrow the + // exception. + BasicBlock* Rethrow = new BasicBlock("RethrowExcept", Func); + + // Fill in the "Rethrow" BB with a call to rethrow the exception. This + // is the last instruction in the BB since at this point the runtime + // should exit this function and go to the next function. + new UnwindInst(Rethrow); + return RethrowBBMap[Func] = Rethrow; +} + +// GetSJSwitch - Return the switch statement that controls which handler +// (if any) gets called and the value returned to that handler. +LowerSetJmp::SwitchValuePair LowerSetJmp::GetSJSwitch(Function* Func, + BasicBlock* Rethrow) +{ + if (SwitchValMap[Func].first) return SwitchValMap[Func]; + + BasicBlock* LongJmpPre = new BasicBlock("LongJmpBlkPre", Func); + BasicBlock::InstListType& LongJmpPreIL = LongJmpPre->getInstList(); + + // Keep track of the preliminary basic block for some of the other + // transformations. + PrelimBBMap[Func] = LongJmpPre; + + // Grab the exception. + CallInst* Cond = new CallInst(IsLJException, "IsLJExcept"); + LongJmpPreIL.push_back(Cond); + + // The "decision basic block" gets the number associated with the + // setjmp call returning to switch on and the value returned by + // longjmp. + BasicBlock* DecisionBB = new BasicBlock("LJDecisionBB", Func); + BasicBlock::InstListType& DecisionBBIL = DecisionBB->getInstList(); + + new BranchInst(DecisionBB, Rethrow, Cond, LongJmpPre); + + // Fill in the "decision" basic block. 
+ CallInst* LJVal = new CallInst(GetLJValue, "LJVal"); + DecisionBBIL.push_back(LJVal); + CallInst* SJNum = new CallInst(TryCatchLJ, GetSetJmpMap(Func), "SJNum"); + DecisionBBIL.push_back(SJNum); + + SwitchInst* SI = new SwitchInst(SJNum, Rethrow, 0, DecisionBB); + return SwitchValMap[Func] = SwitchValuePair(SI, LJVal); +} + +// TransformSetJmpCall - The setjmp call is a bit trickier to transform. +// We're going to convert all setjmp calls to nops. Then all "call" and +// "invoke" instructions in the function are converted to "invoke" where +// the "except" branch is used when returning from a longjmp call. +void LowerSetJmp::TransformSetJmpCall(CallInst* Inst) +{ + BasicBlock* ABlock = Inst->getParent(); + Function* Func = ABlock->getParent(); + + // Add this setjmp to the setjmp map. + const Type* SBPTy = PointerType::get(Type::Int8Ty); + CastInst* BufPtr = + new BitCastInst(Inst->getOperand(1), SBPTy, "SBJmpBuf", Inst); + std::vector Args = + make_vector(GetSetJmpMap(Func), BufPtr, + ConstantInt::get(Type::Int32Ty, + SetJmpIDMap[Func]++), 0); + new CallInst(AddSJToMap, &Args[0], Args.size(), "", Inst); + + // We are guaranteed that there are no values live across basic blocks + // (because we are "not in SSA form" yet), but there can still be values live + // in basic blocks. Because of this, splitting the setjmp block can cause + // values above the setjmp to not dominate uses which are after the setjmp + // call. For all of these occasions, we must spill the value to the stack. + // + std::set InstrsAfterCall; + + // The call is probably very close to the end of the basic block, for the + // common usage pattern of: 'if (setjmp(...))', so keep track of the + // instructions after the call. + for (BasicBlock::iterator I = ++BasicBlock::iterator(Inst), E = ABlock->end(); + I != E; ++I) + InstrsAfterCall.insert(I); + + for (BasicBlock::iterator II = ABlock->begin(); + II != BasicBlock::iterator(Inst); ++II) + // Loop over all of the uses of instruction. 
If any of them are after the + // call, "spill" the value to the stack. + for (Value::use_iterator UI = II->use_begin(), E = II->use_end(); + UI != E; ++UI) + if (cast(*UI)->getParent() != ABlock || + InstrsAfterCall.count(cast(*UI))) { + DemoteRegToStack(*II); + break; + } + InstrsAfterCall.clear(); + + // Change the setjmp call into a branch statement. We'll remove the + // setjmp call in a little bit. No worries. + BasicBlock* SetJmpContBlock = ABlock->splitBasicBlock(Inst); + assert(SetJmpContBlock && "Couldn't split setjmp BB!!"); + + SetJmpContBlock->setName(ABlock->getName()+"SetJmpCont"); + + // Add the SetJmpContBlock to the set of blocks reachable from a setjmp. + DFSBlocks.insert(SetJmpContBlock); + + // This PHI node will be in the new block created from the + // splitBasicBlock call. + PHINode* PHI = new PHINode(Type::Int32Ty, "SetJmpReturn", Inst); + + // Coming from a call to setjmp, the return is 0. + PHI->addIncoming(ConstantInt::getNullValue(Type::Int32Ty), ABlock); + + // Add the case for this setjmp's number... + SwitchValuePair SVP = GetSJSwitch(Func, GetRethrowBB(Func)); + SVP.first->addCase(ConstantInt::get(Type::Int32Ty, SetJmpIDMap[Func] - 1), + SetJmpContBlock); + + // Value coming from the handling of the exception. + PHI->addIncoming(SVP.second, SVP.second->getParent()); + + // Replace all uses of this instruction with the PHI node created by + // the eradication of setjmp. + Inst->replaceAllUsesWith(PHI); + Inst->getParent()->getInstList().erase(Inst); + + ++SetJmpsTransformed; +} + +// visitCallInst - This converts all LLVM call instructions into invoke +// instructions. The except part of the invoke goes to the "LongJmpBlkPre" +// that grabs the exception and proceeds to determine if it's a longjmp +// exception or not. 
+void LowerSetJmp::visitCallInst(CallInst& CI) +{ + if (CI.getCalledFunction()) + if (!IsTransformableFunction(CI.getCalledFunction()->getName()) || + CI.getCalledFunction()->isIntrinsic()) return; + + BasicBlock* OldBB = CI.getParent(); + + // If not reachable from a setjmp call, don't transform. + if (!DFSBlocks.count(OldBB)) return; + + BasicBlock* NewBB = OldBB->splitBasicBlock(CI); + assert(NewBB && "Couldn't split BB of \"call\" instruction!!"); + DFSBlocks.insert(NewBB); + NewBB->setName("Call2Invoke"); + + Function* Func = OldBB->getParent(); + + // Construct the new "invoke" instruction. + TerminatorInst* Term = OldBB->getTerminator(); + std::vector Params(CI.op_begin() + 1, CI.op_end()); + InvokeInst* II = new + InvokeInst(CI.getCalledValue(), NewBB, PrelimBBMap[Func], + &Params[0], Params.size(), CI.getName(), Term); + + // Replace the old call inst with the invoke inst and remove the call. + CI.replaceAllUsesWith(II); + CI.getParent()->getInstList().erase(&CI); + + // The old terminator is useless now that we have the invoke inst. + Term->getParent()->getInstList().erase(Term); + ++CallsTransformed; +} + +// visitInvokeInst - Converting the "invoke" instruction is fairly +// straight-forward. The old exception part is replaced by a query asking +// if this is a longjmp exception. If it is, then it goes to the longjmp +// exception blocks. Otherwise, control is passed the old exception. +void LowerSetJmp::visitInvokeInst(InvokeInst& II) +{ + if (II.getCalledFunction()) + if (!IsTransformableFunction(II.getCalledFunction()->getName()) || + II.getCalledFunction()->isIntrinsic()) return; + + BasicBlock* BB = II.getParent(); + + // If not reachable from a setjmp call, don't transform. 
+ if (!DFSBlocks.count(BB)) return; + + BasicBlock* ExceptBB = II.getUnwindDest(); + + Function* Func = BB->getParent(); + BasicBlock* NewExceptBB = new BasicBlock("InvokeExcept", Func); + BasicBlock::InstListType& InstList = NewExceptBB->getInstList(); + + // If this is a longjmp exception, then branch to the preliminary BB of + // the longjmp exception handling. Otherwise, go to the old exception. + CallInst* IsLJExcept = new CallInst(IsLJException, "IsLJExcept"); + InstList.push_back(IsLJExcept); + + new BranchInst(PrelimBBMap[Func], ExceptBB, IsLJExcept, NewExceptBB); + + II.setUnwindDest(NewExceptBB); + ++InvokesTransformed; +} + +// visitReturnInst - We want to destroy the setjmp map upon exit from the +// function. +void LowerSetJmp::visitReturnInst(ReturnInst &RI) { + Function* Func = RI.getParent()->getParent(); + new CallInst(DestroySJMap, GetSetJmpMap(Func), "", &RI); +} + +// visitUnwindInst - We want to destroy the setjmp map upon exit from the +// function. +void LowerSetJmp::visitUnwindInst(UnwindInst &UI) { + Function* Func = UI.getParent()->getParent(); + new CallInst(DestroySJMap, GetSetJmpMap(Func), "", &UI); +} + +ModulePass *llvm::createLowerSetJmpPass() { + return new LowerSetJmp(); +} + diff --git a/lib/Transforms/IPO/Makefile b/lib/Transforms/IPO/Makefile new file mode 100644 index 0000000..22a76d3 --- /dev/null +++ b/lib/Transforms/IPO/Makefile @@ -0,0 +1,15 @@ +##===- lib/Transforms/IPO/Makefile -------------------------*- Makefile -*-===## +# +# The LLVM Compiler Infrastructure +# +# This file was developed by the LLVM research group and is distributed under +# the University of Illinois Open Source License. See LICENSE.TXT for details. +# +##===----------------------------------------------------------------------===## + +LEVEL = ../../.. 
+LIBRARYNAME = LLVMipo +BUILD_ARCHIVE = 1 + +include $(LEVEL)/Makefile.common + diff --git a/lib/Transforms/IPO/PruneEH.cpp b/lib/Transforms/IPO/PruneEH.cpp new file mode 100644 index 0000000..a783272 --- /dev/null +++ b/lib/Transforms/IPO/PruneEH.cpp @@ -0,0 +1,233 @@ +//===- PruneEH.cpp - Pass which deletes unused exception handlers ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple interprocedural pass which walks the +// call-graph, turning invoke instructions into calls, iff the callee cannot +// throw an exception. It implements this as a bottom-up traversal of the +// call-graph. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "prune-eh" +#include "llvm/Transforms/IPO.h" +#include "llvm/CallGraphSCCPass.h" +#include "llvm/Constants.h" +#include "llvm/Function.h" +#include "llvm/Intrinsics.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include +#include +using namespace llvm; + +STATISTIC(NumRemoved, "Number of invokes removed"); +STATISTIC(NumUnreach, "Number of noreturn calls optimized"); + +namespace { + struct VISIBILITY_HIDDEN PruneEH : public CallGraphSCCPass { + static char ID; // Pass identification, replacement for typeid + PruneEH() : CallGraphSCCPass((intptr_t)&ID) {} + + /// DoesNotUnwind - This set contains all of the functions which we have + /// determined cannot unwind. + std::set DoesNotUnwind; + + /// DoesNotReturn - This set contains all of the functions which we have + /// determined cannot return normally (but might unwind). 
+ std::set DoesNotReturn; + + // runOnSCC - Analyze the SCC, performing the transformation if possible. + bool runOnSCC(const std::vector &SCC); + + bool SimplifyFunction(Function *F); + void DeleteBasicBlock(BasicBlock *BB); + }; + + char PruneEH::ID = 0; + RegisterPass X("prune-eh", "Remove unused exception handling info"); +} + +Pass *llvm::createPruneEHPass() { return new PruneEH(); } + + +bool PruneEH::runOnSCC(const std::vector &SCC) { + CallGraph &CG = getAnalysis(); + bool MadeChange = false; + + // First pass, scan all of the functions in the SCC, simplifying them + // according to what we know. + for (unsigned i = 0, e = SCC.size(); i != e; ++i) + if (Function *F = SCC[i]->getFunction()) + MadeChange |= SimplifyFunction(F); + + // Next, check to see if any callees might throw or if there are any external + // functions in this SCC: if so, we cannot prune any functions in this SCC. + // If this SCC includes the unwind instruction, we KNOW it throws, so + // obviously the SCC might throw. + // + bool SCCMightUnwind = false, SCCMightReturn = false; + for (unsigned i = 0, e = SCC.size(); + (!SCCMightUnwind || !SCCMightReturn) && i != e; ++i) { + Function *F = SCC[i]->getFunction(); + if (F == 0 || (F->isDeclaration() && !F->getIntrinsicID())) { + SCCMightUnwind = true; + SCCMightReturn = true; + } else { + if (F->isDeclaration()) + SCCMightReturn = true; + + // Check to see if this function performs an unwind or calls an + // unwinding function. + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { + if (isa(BB->getTerminator())) { // Uses unwind! + SCCMightUnwind = true; + } else if (isa(BB->getTerminator())) { + SCCMightReturn = true; + } + + // Invoke instructions don't allow unwinding to continue, so we are + // only interested in call instructions. 
+ if (!SCCMightUnwind) + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) + if (CallInst *CI = dyn_cast(I)) { + if (Function *Callee = CI->getCalledFunction()) { + CallGraphNode *CalleeNode = CG[Callee]; + // If the callee is outside our current SCC, or if it is not + // known to throw, then we might throw also. + if (std::find(SCC.begin(), SCC.end(), CalleeNode) == SCC.end()&& + !DoesNotUnwind.count(CalleeNode)) { + SCCMightUnwind = true; + break; + } + } else { + // Indirect call, it might throw. + SCCMightUnwind = true; + break; + } + } + if (SCCMightUnwind && SCCMightReturn) break; + } + } + } + + // If the SCC doesn't unwind or doesn't throw, note this fact. + if (!SCCMightUnwind) + for (unsigned i = 0, e = SCC.size(); i != e; ++i) + DoesNotUnwind.insert(SCC[i]); + if (!SCCMightReturn) + for (unsigned i = 0, e = SCC.size(); i != e; ++i) + DoesNotReturn.insert(SCC[i]); + + for (unsigned i = 0, e = SCC.size(); i != e; ++i) { + // Convert any invoke instructions to non-throwing functions in this node + // into call instructions with a branch. This makes the exception blocks + // dead. + if (Function *F = SCC[i]->getFunction()) + MadeChange |= SimplifyFunction(F); + } + + return MadeChange; +} + + +// SimplifyFunction - Given information about callees, simplify the specified +// function if we have invokes to non-unwinding functions or code after calls to +// no-return functions. +bool PruneEH::SimplifyFunction(Function *F) { + CallGraph &CG = getAnalysis(); + bool MadeChange = false; + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { + if (InvokeInst *II = dyn_cast(BB->getTerminator())) + if (Function *F = II->getCalledFunction()) + if (DoesNotUnwind.count(CG[F])) { + SmallVector Args(II->op_begin()+3, II->op_end()); + // Insert a call instruction before the invoke. 
+ CallInst *Call = new CallInst(II->getCalledValue(), + &Args[0], Args.size(), "", II); + Call->takeName(II); + Call->setCallingConv(II->getCallingConv()); + + // Anything that used the value produced by the invoke instruction + // now uses the value produced by the call instruction. + II->replaceAllUsesWith(Call); + BasicBlock *UnwindBlock = II->getUnwindDest(); + UnwindBlock->removePredecessor(II->getParent()); + + // Insert a branch to the normal destination right before the + // invoke. + new BranchInst(II->getNormalDest(), II); + + // Finally, delete the invoke instruction! + BB->getInstList().pop_back(); + + // If the unwind block is now dead, nuke it. + if (pred_begin(UnwindBlock) == pred_end(UnwindBlock)) + DeleteBasicBlock(UnwindBlock); // Delete the new BB. + + ++NumRemoved; + MadeChange = true; + } + + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ) + if (CallInst *CI = dyn_cast(I++)) + if (Function *Callee = CI->getCalledFunction()) + if (DoesNotReturn.count(CG[Callee]) && !isa(I)) { + // This call calls a function that cannot return. Insert an + // unreachable instruction after it and simplify the code. Do this + // by splitting the BB, adding the unreachable, then deleting the + // new BB. + BasicBlock *New = BB->splitBasicBlock(I); + + // Remove the uncond branch and add an unreachable. + BB->getInstList().pop_back(); + new UnreachableInst(BB); + + DeleteBasicBlock(New); // Delete the new BB. + MadeChange = true; + ++NumUnreach; + break; + } + + } + return MadeChange; +} + +/// DeleteBasicBlock - remove the specified basic block from the program, +/// updating the callgraph to reflect any now-obsolete edges due to calls that +/// exist in the BB. 
+void PruneEH::DeleteBasicBlock(BasicBlock *BB) { + assert(pred_begin(BB) == pred_end(BB) && "BB is not dead!"); + CallGraph &CG = getAnalysis(); + + CallGraphNode *CGN = CG[BB->getParent()]; + for (BasicBlock::iterator I = BB->end(), E = BB->begin(); I != E; ) { + --I; + if (CallInst *CI = dyn_cast(I)) { + if (Function *Callee = CI->getCalledFunction()) + CGN->removeCallEdgeTo(CG[Callee]); + } else if (InvokeInst *II = dyn_cast(I)) { + if (Function *Callee = II->getCalledFunction()) + CGN->removeCallEdgeTo(CG[Callee]); + } + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + } + + // Get the list of successors of this block. + std::vector Succs(succ_begin(BB), succ_end(BB)); + + for (unsigned i = 0, e = Succs.size(); i != e; ++i) + Succs[i]->removePredecessor(BB); + + BB->eraseFromParent(); +} diff --git a/lib/Transforms/IPO/RaiseAllocations.cpp b/lib/Transforms/IPO/RaiseAllocations.cpp new file mode 100644 index 0000000..5d2d9dd --- /dev/null +++ b/lib/Transforms/IPO/RaiseAllocations.cpp @@ -0,0 +1,249 @@ +//===- RaiseAllocations.cpp - Convert %malloc & %free calls to insts ------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the RaiseAllocations pass which convert malloc and free +// calls to malloc and free instructions. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "raiseallocs" +#include "llvm/Transforms/IPO.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Module.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/Support/CallSite.h" +#include "llvm/Support/Compiler.h" +#include "llvm/ADT/Statistic.h" +#include +using namespace llvm; + +STATISTIC(NumRaised, "Number of allocations raised"); + +namespace { + // RaiseAllocations - Turn %malloc and %free calls into the appropriate + // instruction. + // + class VISIBILITY_HIDDEN RaiseAllocations : public ModulePass { + Function *MallocFunc; // Functions in the module we are processing + Function *FreeFunc; // Initialized by doPassInitializationVirt + public: + static char ID; // Pass identification, replacement for typeid + RaiseAllocations() + : ModulePass((intptr_t)&ID), MallocFunc(0), FreeFunc(0) {} + + // doPassInitialization - For the raise allocations pass, this finds a + // declaration for malloc and free if they exist. + // + void doInitialization(Module &M); + + // run - This method does the actual work of converting instructions over. + // + bool runOnModule(Module &M); + }; + + char RaiseAllocations::ID = 0; + RegisterPass + X("raiseallocs", "Raise allocations from calls to instructions"); +} // end anonymous namespace + + +// createRaiseAllocationsPass - The interface to this file... +ModulePass *llvm::createRaiseAllocationsPass() { + return new RaiseAllocations(); +} + + +// If the module has a symbol table, they might be referring to the malloc and +// free functions. If this is the case, grab the method pointers that the +// module is using. +// +// Lookup %malloc and %free in the symbol table, for later use. If they don't +// exist, or are not external, we do not worry about converting calls to that +// function into the appropriate instruction. 
+// +void RaiseAllocations::doInitialization(Module &M) { + + // Get Malloc and free prototypes if they exist! + MallocFunc = M.getFunction("malloc"); + if (MallocFunc) { + const FunctionType* TyWeHave = MallocFunc->getFunctionType(); + + // Get the expected prototype for malloc + const FunctionType *Malloc1Type = + FunctionType::get(PointerType::get(Type::Int8Ty), + std::vector(1, Type::Int64Ty), false); + + // Chck to see if we got the expected malloc + if (TyWeHave != Malloc1Type) { + // Check to see if the prototype is wrong, giving us sbyte*(uint) * malloc + // This handles the common declaration of: 'void *malloc(unsigned);' + const FunctionType *Malloc2Type = + FunctionType::get(PointerType::get(Type::Int8Ty), + std::vector(1, Type::Int32Ty), false); + if (TyWeHave != Malloc2Type) { + // Check to see if the prototype is missing, giving us + // sbyte*(...) * malloc + // This handles the common declaration of: 'void *malloc();' + const FunctionType *Malloc3Type = + FunctionType::get(PointerType::get(Type::Int8Ty), + std::vector(), true); + if (TyWeHave != Malloc3Type) + // Give up + MallocFunc = 0; + } + } + } + + FreeFunc = M.getFunction("free"); + if (FreeFunc) { + const FunctionType* TyWeHave = FreeFunc->getFunctionType(); + + // Get the expected prototype for void free(i8*) + const FunctionType *Free1Type = FunctionType::get(Type::VoidTy, + std::vector(1, PointerType::get(Type::Int8Ty)), false); + + if (TyWeHave != Free1Type) { + // Check to see if the prototype was forgotten, giving us + // void (...) * free + // This handles the common forward declaration of: 'void free();' + const FunctionType* Free2Type = FunctionType::get(Type::VoidTy, + std::vector(),true); + + if (TyWeHave != Free2Type) { + // One last try, check to see if we can find free as + // int (...)* free. This handles the case where NOTHING was declared. + const FunctionType* Free3Type = FunctionType::get(Type::Int32Ty, + std::vector(),true); + + if (TyWeHave != Free3Type) { + // Give up. 
+ FreeFunc = 0; + } + } + } + } + + // Don't mess with locally defined versions of these functions... + if (MallocFunc && !MallocFunc->isDeclaration()) MallocFunc = 0; + if (FreeFunc && !FreeFunc->isDeclaration()) FreeFunc = 0; +} + +// run - Transform calls into instructions... +// +bool RaiseAllocations::runOnModule(Module &M) { + // Find the malloc/free prototypes... + doInitialization(M); + + bool Changed = false; + + // First, process all of the malloc calls... + if (MallocFunc) { + std::vector Users(MallocFunc->use_begin(), MallocFunc->use_end()); + std::vector EqPointers; // Values equal to MallocFunc + while (!Users.empty()) { + User *U = Users.back(); + Users.pop_back(); + + if (Instruction *I = dyn_cast(U)) { + CallSite CS = CallSite::get(I); + if (CS.getInstruction() && CS.arg_begin() != CS.arg_end() && + (CS.getCalledFunction() == MallocFunc || + std::find(EqPointers.begin(), EqPointers.end(), + CS.getCalledValue()) != EqPointers.end())) { + + Value *Source = *CS.arg_begin(); + + // If no prototype was provided for malloc, we may need to cast the + // source size. + if (Source->getType() != Type::Int32Ty) + Source = + CastInst::createIntegerCast(Source, Type::Int32Ty, false/*ZExt*/, + "MallocAmtCast", I); + + MallocInst *MI = new MallocInst(Type::Int8Ty, Source, "", I); + MI->takeName(I); + I->replaceAllUsesWith(MI); + + // If the old instruction was an invoke, add an unconditional branch + // before the invoke, which will become the new terminator. 
+ if (InvokeInst *II = dyn_cast(I)) + new BranchInst(II->getNormalDest(), I); + + // Delete the old call site + MI->getParent()->getInstList().erase(I); + Changed = true; + ++NumRaised; + } + } else if (GlobalValue *GV = dyn_cast(U)) { + Users.insert(Users.end(), GV->use_begin(), GV->use_end()); + EqPointers.push_back(GV); + } else if (ConstantExpr *CE = dyn_cast(U)) { + if (CE->isCast()) { + Users.insert(Users.end(), CE->use_begin(), CE->use_end()); + EqPointers.push_back(CE); + } + } + } + } + + // Next, process all free calls... + if (FreeFunc) { + std::vector Users(FreeFunc->use_begin(), FreeFunc->use_end()); + std::vector EqPointers; // Values equal to FreeFunc + + while (!Users.empty()) { + User *U = Users.back(); + Users.pop_back(); + + if (Instruction *I = dyn_cast(U)) { + CallSite CS = CallSite::get(I); + if (CS.getInstruction() && CS.arg_begin() != CS.arg_end() && + (CS.getCalledFunction() == FreeFunc || + std::find(EqPointers.begin(), EqPointers.end(), + CS.getCalledValue()) != EqPointers.end())) { + + // If no prototype was provided for free, we may need to cast the + // source pointer. This should be really uncommon, but it's necessary + // just in case we are dealing with weird code like this: + // free((long)ptr); + // + Value *Source = *CS.arg_begin(); + if (!isa(Source->getType())) + Source = new IntToPtrInst(Source, PointerType::get(Type::Int8Ty), + "FreePtrCast", I); + new FreeInst(Source, I); + + // If the old instruction was an invoke, add an unconditional branch + // before the invoke, which will become the new terminator. 
+ if (InvokeInst *II = dyn_cast(I)) + new BranchInst(II->getNormalDest(), I); + + // Delete the old call site + if (I->getType() != Type::VoidTy) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->eraseFromParent(); + Changed = true; + ++NumRaised; + } + } else if (GlobalValue *GV = dyn_cast(U)) { + Users.insert(Users.end(), GV->use_begin(), GV->use_end()); + EqPointers.push_back(GV); + } else if (ConstantExpr *CE = dyn_cast(U)) { + if (CE->isCast()) { + Users.insert(Users.end(), CE->use_begin(), CE->use_end()); + EqPointers.push_back(CE); + } + } + } + } + + return Changed; +} diff --git a/lib/Transforms/IPO/SimplifyLibCalls.cpp b/lib/Transforms/IPO/SimplifyLibCalls.cpp new file mode 100644 index 0000000..b0f9128 --- /dev/null +++ b/lib/Transforms/IPO/SimplifyLibCalls.cpp @@ -0,0 +1,2021 @@ +//===- SimplifyLibCalls.cpp - Optimize specific well-known library calls --===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by Reid Spencer and is distributed under the +// University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a module pass that applies a variety of small +// optimizations for calls to specific well-known function calls (e.g. runtime +// library functions). For example, a call to the function "exit(3)" that +// occurs within the main() function can be transformed into a simple "return 3" +// instruction. Any optimization that takes this form (replace call to library +// function with simpler code that provides the same result) belongs in this +// file. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "simplify-libcalls" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/ADT/hash_map" +#include "llvm/ADT/Statistic.h" +#include "llvm/Config/config.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Transforms/IPO.h" +using namespace llvm; + +/// This statistic keeps track of the total number of library calls that have +/// been simplified regardless of which call it is. +STATISTIC(SimplifiedLibCalls, "Number of library calls simplified"); + +namespace { + // Forward declarations + class LibCallOptimization; + class SimplifyLibCalls; + +/// This list is populated by the constructor for LibCallOptimization class. +/// Therefore all subclasses are registered here at static initialization time +/// and this list is what the SimplifyLibCalls pass uses to apply the individual +/// optimizations to the call sites. +/// @brief The list of optimizations deriving from LibCallOptimization +static LibCallOptimization *OptList = 0; + +/// This class is the abstract base class for the set of optimizations that +/// corresponds to one library call. The SimplifyLibCalls pass will call the +/// ValidateCalledFunction method to ask the optimization if a given Function +/// is the kind that the optimization can handle. If the subclass returns true, +/// then SImplifyLibCalls will also call the OptimizeCall method to perform, +/// or attempt to perform, the optimization(s) for the library call. Otherwise, +/// OptimizeCall won't be called. Subclasses are responsible for providing the +/// name of the library call (strlen, strcpy, etc.) to the LibCallOptimization +/// constructor. This is used to efficiently select which call instructions to +/// optimize. 
The criteria for a "lib call" is "anything with well known +/// semantics", typically a library function that is defined by an international +/// standard. Because the semantics are well known, the optimizations can +/// generally short-circuit actually calling the function if there's a simpler +/// way (e.g. strlen(X) can be reduced to a constant if X is a constant global). +/// @brief Base class for library call optimizations +class VISIBILITY_HIDDEN LibCallOptimization { + LibCallOptimization **Prev, *Next; + const char *FunctionName; ///< Name of the library call we optimize +#ifndef NDEBUG + Statistic occurrences; ///< debug statistic (-debug-only=simplify-libcalls) +#endif +public: + /// The \p fname argument must be the name of the library function being + /// optimized by the subclass. + /// @brief Constructor that registers the optimization. + LibCallOptimization(const char *FName, const char *Description) + : FunctionName(FName) { + +#ifndef NDEBUG + occurrences.construct("simplify-libcalls", Description); +#endif + // Register this optimizer in the list of optimizations. + Next = OptList; + OptList = this; + Prev = &OptList; + if (Next) Next->Prev = &Next; + } + + /// getNext - All libcall optimizations are chained together into a list, + /// return the next one in the list. + LibCallOptimization *getNext() { return Next; } + + /// @brief Deregister from the optlist + virtual ~LibCallOptimization() { + *Prev = Next; + if (Next) Next->Prev = Prev; + } + + /// The implementation of this function in subclasses should determine if + /// \p F is suitable for the optimization. This method is called by + /// SimplifyLibCalls::runOnModule to short circuit visiting all the call + /// sites of such a function if that function is not suitable in the first + /// place. If the called function is suitabe, this method should return true; + /// false, otherwise. 
This function should also perform any lazy + /// initialization that the LibCallOptimization needs to do, if its to return + /// true. This avoids doing initialization until the optimizer is actually + /// going to be called upon to do some optimization. + /// @brief Determine if the function is suitable for optimization + virtual bool ValidateCalledFunction( + const Function* F, ///< The function that is the target of call sites + SimplifyLibCalls& SLC ///< The pass object invoking us + ) = 0; + + /// The implementations of this function in subclasses is the heart of the + /// SimplifyLibCalls algorithm. Sublcasses of this class implement + /// OptimizeCall to determine if (a) the conditions are right for optimizing + /// the call and (b) to perform the optimization. If an action is taken + /// against ci, the subclass is responsible for returning true and ensuring + /// that ci is erased from its parent. + /// @brief Optimize a call, if possible. + virtual bool OptimizeCall( + CallInst* ci, ///< The call instruction that should be optimized. + SimplifyLibCalls& SLC ///< The pass object invoking us + ) = 0; + + /// @brief Get the name of the library call being optimized + const char *getFunctionName() const { return FunctionName; } + + bool ReplaceCallWith(CallInst *CI, Value *V) { + if (!CI->use_empty()) + CI->replaceAllUsesWith(V); + CI->eraseFromParent(); + return true; + } + + /// @brief Called by SimplifyLibCalls to update the occurrences statistic. + void succeeded() { +#ifndef NDEBUG + DEBUG(++occurrences); +#endif + } +}; + +/// This class is an LLVM Pass that applies each of the LibCallOptimization +/// instances to all the call sites in a module, relatively efficiently. The +/// purpose of this pass is to provide optimizations for calls to well-known +/// functions with well-known semantics, such as those in the c library. The +/// class provides the basic infrastructure for handling runOnModule. 
Whenever +/// this pass finds a function call, it asks the appropriate optimizer to +/// validate the call (ValidateLibraryCall). If it is validated, then +/// the OptimizeCall method is also called. +/// @brief A ModulePass for optimizing well-known function calls. +class VISIBILITY_HIDDEN SimplifyLibCalls : public ModulePass { +public: + static char ID; // Pass identification, replacement for typeid + SimplifyLibCalls() : ModulePass((intptr_t)&ID) {} + + /// We need some target data for accurate signature details that are + /// target dependent. So we require target data in our AnalysisUsage. + /// @brief Require TargetData from AnalysisUsage. + virtual void getAnalysisUsage(AnalysisUsage& Info) const { + // Ask that the TargetData analysis be performed before us so we can use + // the target data. + Info.addRequired(); + } + + /// For this pass, process all of the function calls in the module, calling + /// ValidateLibraryCall and OptimizeCall as appropriate. + /// @brief Run all the lib call optimizations on a Module. + virtual bool runOnModule(Module &M) { + reset(M); + + bool result = false; + hash_map OptznMap; + for (LibCallOptimization *Optzn = OptList; Optzn; Optzn = Optzn->getNext()) + OptznMap[Optzn->getFunctionName()] = Optzn; + + // The call optimizations can be recursive. That is, the optimization might + // generate a call to another function which can also be optimized. This way + // we make the LibCallOptimization instances very specific to the case they + // handle. It also means we need to keep running over the function calls in + // the module until we don't get any more optimizations possible. + bool found_optimization = false; + do { + found_optimization = false; + for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI) { + // All the "well-known" functions are external and have external linkage + // because they live in a runtime library somewhere and were (probably) + // not compiled by LLVM. 
So, we only act on external functions that + // have external or dllimport linkage and non-empty uses. + if (!FI->isDeclaration() || + !(FI->hasExternalLinkage() || FI->hasDLLImportLinkage()) || + FI->use_empty()) + continue; + + // Get the optimization class that pertains to this function + hash_map::iterator OMI = + OptznMap.find(FI->getName()); + if (OMI == OptznMap.end()) continue; + + LibCallOptimization *CO = OMI->second; + + // Make sure the called function is suitable for the optimization + if (!CO->ValidateCalledFunction(FI, *this)) + continue; + + // Loop over each of the uses of the function + for (Value::use_iterator UI = FI->use_begin(), UE = FI->use_end(); + UI != UE ; ) { + // If the use of the function is a call instruction + if (CallInst* CI = dyn_cast(*UI++)) { + // Do the optimization on the LibCallOptimization. + if (CO->OptimizeCall(CI, *this)) { + ++SimplifiedLibCalls; + found_optimization = result = true; + CO->succeeded(); + } + } + } + } + } while (found_optimization); + + return result; + } + + /// @brief Return the *current* module we're working on. + Module* getModule() const { return M; } + + /// @brief Return the *current* target data for the module we're working on. 
+ TargetData* getTargetData() const { return TD; } + + /// @brief Return the size_t type -- syntactic shortcut + const Type* getIntPtrType() const { return TD->getIntPtrType(); } + + /// @brief Return a Function* for the putchar libcall + Constant *get_putchar() { + if (!putchar_func) + putchar_func = + M->getOrInsertFunction("putchar", Type::Int32Ty, Type::Int32Ty, NULL); + return putchar_func; + } + + /// @brief Return a Function* for the puts libcall + Constant *get_puts() { + if (!puts_func) + puts_func = M->getOrInsertFunction("puts", Type::Int32Ty, + PointerType::get(Type::Int8Ty), + NULL); + return puts_func; + } + + /// @brief Return a Function* for the fputc libcall + Constant *get_fputc(const Type* FILEptr_type) { + if (!fputc_func) + fputc_func = M->getOrInsertFunction("fputc", Type::Int32Ty, Type::Int32Ty, + FILEptr_type, NULL); + return fputc_func; + } + + /// @brief Return a Function* for the fputs libcall + Constant *get_fputs(const Type* FILEptr_type) { + if (!fputs_func) + fputs_func = M->getOrInsertFunction("fputs", Type::Int32Ty, + PointerType::get(Type::Int8Ty), + FILEptr_type, NULL); + return fputs_func; + } + + /// @brief Return a Function* for the fwrite libcall + Constant *get_fwrite(const Type* FILEptr_type) { + if (!fwrite_func) + fwrite_func = M->getOrInsertFunction("fwrite", TD->getIntPtrType(), + PointerType::get(Type::Int8Ty), + TD->getIntPtrType(), + TD->getIntPtrType(), + FILEptr_type, NULL); + return fwrite_func; + } + + /// @brief Return a Function* for the sqrt libcall + Constant *get_sqrt() { + if (!sqrt_func) + sqrt_func = M->getOrInsertFunction("sqrt", Type::DoubleTy, + Type::DoubleTy, NULL); + return sqrt_func; + } + + /// @brief Return a Function* for the strcpy libcall + Constant *get_strcpy() { + if (!strcpy_func) + strcpy_func = M->getOrInsertFunction("strcpy", + PointerType::get(Type::Int8Ty), + PointerType::get(Type::Int8Ty), + PointerType::get(Type::Int8Ty), + NULL); + return strcpy_func; + } + + /// @brief Return a 
Function* for the strlen libcall + Constant *get_strlen() { + if (!strlen_func) + strlen_func = M->getOrInsertFunction("strlen", TD->getIntPtrType(), + PointerType::get(Type::Int8Ty), + NULL); + return strlen_func; + } + + /// @brief Return a Function* for the memchr libcall + Constant *get_memchr() { + if (!memchr_func) + memchr_func = M->getOrInsertFunction("memchr", + PointerType::get(Type::Int8Ty), + PointerType::get(Type::Int8Ty), + Type::Int32Ty, TD->getIntPtrType(), + NULL); + return memchr_func; + } + + /// @brief Return a Function* for the memcpy libcall + Constant *get_memcpy() { + if (!memcpy_func) { + const Type *SBP = PointerType::get(Type::Int8Ty); + const char *N = TD->getIntPtrType() == Type::Int32Ty ? + "llvm.memcpy.i32" : "llvm.memcpy.i64"; + memcpy_func = M->getOrInsertFunction(N, Type::VoidTy, SBP, SBP, + TD->getIntPtrType(), Type::Int32Ty, + NULL); + } + return memcpy_func; + } + + Constant *getUnaryFloatFunction(const char *Name, Constant *&Cache) { + if (!Cache) + Cache = M->getOrInsertFunction(Name, Type::FloatTy, Type::FloatTy, NULL); + return Cache; + } + + Constant *get_floorf() { return getUnaryFloatFunction("floorf", floorf_func);} + Constant *get_ceilf() { return getUnaryFloatFunction( "ceilf", ceilf_func);} + Constant *get_roundf() { return getUnaryFloatFunction("roundf", roundf_func);} + Constant *get_rintf() { return getUnaryFloatFunction( "rintf", rintf_func);} + Constant *get_nearbyintf() { return getUnaryFloatFunction("nearbyintf", + nearbyintf_func); } +private: + /// @brief Reset our cached data for a new Module + void reset(Module& mod) { + M = &mod; + TD = &getAnalysis(); + putchar_func = 0; + puts_func = 0; + fputc_func = 0; + fputs_func = 0; + fwrite_func = 0; + memcpy_func = 0; + memchr_func = 0; + sqrt_func = 0; + strcpy_func = 0; + strlen_func = 0; + floorf_func = 0; + ceilf_func = 0; + roundf_func = 0; + rintf_func = 0; + nearbyintf_func = 0; + } + +private: + /// Caches for function pointers. 
+ Constant *putchar_func, *puts_func; + Constant *fputc_func, *fputs_func, *fwrite_func; + Constant *memcpy_func, *memchr_func; + Constant *sqrt_func; + Constant *strcpy_func, *strlen_func; + Constant *floorf_func, *ceilf_func, *roundf_func; + Constant *rintf_func, *nearbyintf_func; + Module *M; ///< Cached Module + TargetData *TD; ///< Cached TargetData +}; + +char SimplifyLibCalls::ID = 0; +// Register the pass +RegisterPass +X("simplify-libcalls", "Simplify well-known library calls"); + +} // anonymous namespace + +// The only public symbol in this file which just instantiates the pass object +ModulePass *llvm::createSimplifyLibCallsPass() { + return new SimplifyLibCalls(); +} + +// Classes below here, in the anonymous namespace, are all subclasses of the +// LibCallOptimization class, each implementing all optimizations possible for a +// single well-known library call. Each has a static singleton instance that +// auto registers it into the "optlist" global above. +namespace { + +// Forward declare utility functions. +static bool GetConstantStringInfo(Value *V, std::string &Str); +static Value *CastToCStr(Value *V, Instruction *IP); + +/// This LibCallOptimization will find instances of a call to "exit" that occurs +/// within the "main" function and change it to a simple "ret" instruction with +/// the same value passed to the exit function. When this is done, it splits the +/// basic block at the exit(3) call and deletes the call instruction. +/// @brief Replace calls to exit in main with a simple return +struct VISIBILITY_HIDDEN ExitInMainOptimization : public LibCallOptimization { + ExitInMainOptimization() : LibCallOptimization("exit", + "Number of 'exit' calls simplified") {} + + // Make sure the called function looks like exit (int argument, int return + // type, external linkage, not varargs). 
+ virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + return F->arg_size() >= 1 && F->arg_begin()->getType()->isInteger(); + } + + virtual bool OptimizeCall(CallInst* ci, SimplifyLibCalls& SLC) { + // To be careful, we check that the call to exit is coming from "main", that + // main has external linkage, and the return type of main and the argument + // to exit have the same type. + Function *from = ci->getParent()->getParent(); + if (from->hasExternalLinkage()) + if (from->getReturnType() == ci->getOperand(1)->getType()) + if (from->getName() == "main") { + // Okay, time to actually do the optimization. First, get the basic + // block of the call instruction + BasicBlock* bb = ci->getParent(); + + // Create a return instruction that we'll replace the call with. + // Note that the argument of the return is the argument of the call + // instruction. + new ReturnInst(ci->getOperand(1), ci); + + // Split the block at the call instruction which places it in a new + // basic block. + bb->splitBasicBlock(ci); + + // The block split caused a branch instruction to be inserted into + // the end of the original block, right after the return instruction + // that we put there. That's not a valid block, so delete the branch + // instruction. + bb->getInstList().pop_back(); + + // Now we can finally get rid of the call instruction which now lives + // in the new basic block. + ci->eraseFromParent(); + + // Optimization succeeded, return true. + return true; + } + // We didn't pass the criteria for this optimization so return false + return false; + } +} ExitInMainOptimizer; + +/// This LibCallOptimization will simplify a call to the strcat library +/// function. The simplification is possible only if the string being +/// concatenated is a constant array or a constant expression that results in +/// a constant string. In this case we can replace it with strlen + llvm.memcpy +/// of the constant string. 
Both of these calls are further reduced, if possible +/// on subsequent passes. +/// @brief Simplify the strcat library function. +struct VISIBILITY_HIDDEN StrCatOptimization : public LibCallOptimization { +public: + /// @brief Default constructor + StrCatOptimization() : LibCallOptimization("strcat", + "Number of 'strcat' calls simplified") {} + +public: + + /// @brief Make sure that the "strcat" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + const FunctionType *FT = F->getFunctionType(); + return FT->getNumParams() == 2 && + FT->getReturnType() == PointerType::get(Type::Int8Ty) && + FT->getParamType(0) == FT->getReturnType() && + FT->getParamType(1) == FT->getReturnType(); + } + + /// @brief Optimize the strcat library function + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { + // Extract some information from the instruction + Value *Dst = CI->getOperand(1); + Value *Src = CI->getOperand(2); + + // Extract the initializer (while making numerous checks) from the + // source operand of the call to strcat. + std::string SrcStr; + if (!GetConstantStringInfo(Src, SrcStr)) + return false; + + // Handle the simple, do-nothing case + if (SrcStr.empty()) + return ReplaceCallWith(CI, Dst); + + // We need to find the end of the destination string. That's where the + // memory is to be moved to. We just generate a call to strlen. + CallInst *DstLen = new CallInst(SLC.get_strlen(), Dst, + Dst->getName()+".len", CI); + + // Now that we have the destination's length, we must index into the + // destination's pointer to get the actual memcpy destination (end of + // the string .. we're concatenating). + Dst = new GetElementPtrInst(Dst, DstLen, Dst->getName()+".indexed", CI); + + // We have enough information to now generate the memcpy call to + // do the concatenation for us. + Value *Vals[] = { + Dst, Src, + ConstantInt::get(SLC.getIntPtrType(), SrcStr.size()+1), // copy nul byte. 
+ ConstantInt::get(Type::Int32Ty, 1) // alignment + }; + new CallInst(SLC.get_memcpy(), Vals, 4, "", CI); + + return ReplaceCallWith(CI, Dst); + } +} StrCatOptimizer; + +/// This LibCallOptimization will simplify a call to the strchr library +/// function. It optimizes out cases where the arguments are both constant +/// and the result can be determined statically. +/// @brief Simplify the strcmp library function. +struct VISIBILITY_HIDDEN StrChrOptimization : public LibCallOptimization { +public: + StrChrOptimization() : LibCallOptimization("strchr", + "Number of 'strchr' calls simplified") {} + + /// @brief Make sure that the "strchr" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + const FunctionType *FT = F->getFunctionType(); + return FT->getNumParams() == 2 && + FT->getReturnType() == PointerType::get(Type::Int8Ty) && + FT->getParamType(0) == FT->getReturnType() && + isa(FT->getParamType(1)); + } + + /// @brief Perform the strchr optimizations + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { + // Check that the first argument to strchr is a constant array of sbyte. + std::string Str; + if (!GetConstantStringInfo(CI->getOperand(1), Str)) + return false; + + // If the second operand is not constant, just lower this to memchr since we + // know the length of the input string. + ConstantInt *CSI = dyn_cast(CI->getOperand(2)); + if (!CSI) { + Value *Args[3] = { + CI->getOperand(1), + CI->getOperand(2), + ConstantInt::get(SLC.getIntPtrType(), Str.size()+1) + }; + return ReplaceCallWith(CI, new CallInst(SLC.get_memchr(), Args, 3, + CI->getName(), CI)); + } + + // strchr can find the nul character. + Str += '\0'; + + // Get the character we're looking for + char CharValue = CSI->getSExtValue(); + + // Compute the offset + uint64_t i = 0; + while (1) { + if (i == Str.size()) // Didn't find the char. strchr returns null. 
+ return ReplaceCallWith(CI, Constant::getNullValue(CI->getType())); + // Did we find our match? + if (Str[i] == CharValue) + break; + ++i; + } + + // strchr(s+n,c) -> gep(s+n+i,c) + // (if c is a constant integer and s is a constant string) + Value *Idx = ConstantInt::get(Type::Int64Ty, i); + Value *GEP = new GetElementPtrInst(CI->getOperand(1), Idx, + CI->getOperand(1)->getName() + + ".strchr", CI); + return ReplaceCallWith(CI, GEP); + } +} StrChrOptimizer; + +/// This LibCallOptimization will simplify a call to the strcmp library +/// function. It optimizes out cases where one or both arguments are constant +/// and the result can be determined statically. +/// @brief Simplify the strcmp library function. +struct VISIBILITY_HIDDEN StrCmpOptimization : public LibCallOptimization { +public: + StrCmpOptimization() : LibCallOptimization("strcmp", + "Number of 'strcmp' calls simplified") {} + + /// @brief Make sure that the "strcmp" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + const FunctionType *FT = F->getFunctionType(); + return FT->getReturnType() == Type::Int32Ty && FT->getNumParams() == 2 && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == PointerType::get(Type::Int8Ty); + } + + /// @brief Perform the strcmp optimization + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { + // First, check to see if src and destination are the same. If they are, + // then the optimization is to replace the CallInst with a constant 0 + // because the call is a no-op. 
+ Value *Str1P = CI->getOperand(1); + Value *Str2P = CI->getOperand(2); + if (Str1P == Str2P) // strcmp(x,x) -> 0 + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 0)); + + std::string Str1; + if (!GetConstantStringInfo(Str1P, Str1)) + return false; + if (Str1.empty()) { + // strcmp("", x) -> *x + Value *V = new LoadInst(Str2P, CI->getName()+".load", CI); + V = new ZExtInst(V, CI->getType(), CI->getName()+".int", CI); + return ReplaceCallWith(CI, V); + } + + std::string Str2; + if (!GetConstantStringInfo(Str2P, Str2)) + return false; + if (Str2.empty()) { + // strcmp(x,"") -> *x + Value *V = new LoadInst(Str1P, CI->getName()+".load", CI); + V = new ZExtInst(V, CI->getType(), CI->getName()+".int", CI); + return ReplaceCallWith(CI, V); + } + + // strcmp(x, y) -> cnst (if both x and y are constant strings) + int R = strcmp(Str1.c_str(), Str2.c_str()); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), R)); + } +} StrCmpOptimizer; + +/// This LibCallOptimization will simplify a call to the strncmp library +/// function. It optimizes out cases where one or both arguments are constant +/// and the result can be determined statically. +/// @brief Simplify the strncmp library function. 
+struct VISIBILITY_HIDDEN StrNCmpOptimization : public LibCallOptimization { +public: + StrNCmpOptimization() : LibCallOptimization("strncmp", + "Number of 'strncmp' calls simplified") {} + + /// @brief Make sure that the "strncmp" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + const FunctionType *FT = F->getFunctionType(); + return FT->getReturnType() == Type::Int32Ty && FT->getNumParams() == 3 && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == PointerType::get(Type::Int8Ty) && + isa(FT->getParamType(2)); + return false; + } + + /// @brief Perform the strncmp optimization + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { + // First, check to see if src and destination are the same. If they are, + // then the optimization is to replace the CallInst with a constant 0 + // because the call is a no-op. + Value *Str1P = CI->getOperand(1); + Value *Str2P = CI->getOperand(2); + if (Str1P == Str2P) // strncmp(x,x, n) -> 0 + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 0)); + + // Check the length argument, if it is Constant zero then the strings are + // considered equal. 
+ uint64_t Length; + if (ConstantInt *LengthArg = dyn_cast(CI->getOperand(3))) + Length = LengthArg->getZExtValue(); + else + return false; + + if (Length == 0) // strncmp(x,y,0) -> 0 + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 0)); + + std::string Str1; + if (!GetConstantStringInfo(Str1P, Str1)) + return false; + if (Str1.empty()) { + // strncmp("", x, n) -> *x + Value *V = new LoadInst(Str2P, CI->getName()+".load", CI); + V = new ZExtInst(V, CI->getType(), CI->getName()+".int", CI); + return ReplaceCallWith(CI, V); + } + + std::string Str2; + if (!GetConstantStringInfo(Str2P, Str2)) + return false; + if (Str2.empty()) { + // strncmp(x, "", n) -> *x + Value *V = new LoadInst(Str1P, CI->getName()+".load", CI); + V = new ZExtInst(V, CI->getType(), CI->getName()+".int", CI); + return ReplaceCallWith(CI, V); + } + + // strncmp(x, y, n) -> cnst (if both x and y are constant strings) + int R = strncmp(Str1.c_str(), Str2.c_str(), Length); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), R)); + } +} StrNCmpOptimizer; + +/// This LibCallOptimization will simplify a call to the strcpy library +/// function. Two optimizations are possible: +/// (1) If src and dest are the same and not volatile, just return dest +/// (2) If the src is a constant then we can convert to llvm.memmove +/// @brief Simplify the strcpy library function. 
+struct VISIBILITY_HIDDEN StrCpyOptimization : public LibCallOptimization { +public: + StrCpyOptimization() : LibCallOptimization("strcpy", + "Number of 'strcpy' calls simplified") {} + + /// @brief Make sure that the "strcpy" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + const FunctionType *FT = F->getFunctionType(); + return FT->getNumParams() == 2 && + FT->getParamType(0) == FT->getParamType(1) && + FT->getReturnType() == FT->getParamType(0) && + FT->getParamType(0) == PointerType::get(Type::Int8Ty); + } + + /// @brief Perform the strcpy optimization + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { + // First, check to see if src and destination are the same. If they are, + // then the optimization is to replace the CallInst with the destination + // because the call is a no-op. Note that this corresponds to the + // degenerate strcpy(X,X) case which should have "undefined" results + // according to the C specification. However, it occurs sometimes and + // we optimize it as a no-op. + Value *Dst = CI->getOperand(1); + Value *Src = CI->getOperand(2); + if (Dst == Src) { + // strcpy(x, x) -> x + return ReplaceCallWith(CI, Dst); + } + + // Get the length of the constant string referenced by the Src operand. + std::string SrcStr; + if (!GetConstantStringInfo(Src, SrcStr)) + return false; + + // If the constant string's length is zero we can optimize this by just + // doing a store of 0 at the first byte of the destination + if (SrcStr.size() == 0) { + new StoreInst(ConstantInt::get(Type::Int8Ty, 0), Dst, CI); + return ReplaceCallWith(CI, Dst); + } + + // We have enough information to now generate the memcpy call to + // do the concatenation for us. + Value *MemcpyOps[] = { + Dst, Src, // Pass length including nul byte. 
+ ConstantInt::get(SLC.getIntPtrType(), SrcStr.size()+1), + ConstantInt::get(Type::Int32Ty, 1) // alignment + }; + new CallInst(SLC.get_memcpy(), MemcpyOps, 4, "", CI); + + return ReplaceCallWith(CI, Dst); + } +} StrCpyOptimizer; + +/// This LibCallOptimization will simplify a call to the strlen library +/// function by replacing it with a constant value if the string provided to +/// it is a constant array. +/// @brief Simplify the strlen library function. +struct VISIBILITY_HIDDEN StrLenOptimization : public LibCallOptimization { + StrLenOptimization() : LibCallOptimization("strlen", + "Number of 'strlen' calls simplified") {} + + /// @brief Make sure that the "strlen" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + const FunctionType *FT = F->getFunctionType(); + return FT->getNumParams() == 1 && + FT->getParamType(0) == PointerType::get(Type::Int8Ty) && + isa(FT->getReturnType()); + } + + /// @brief Perform the strlen optimization + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { + // Make sure we're dealing with an sbyte* here. + Value *Src = CI->getOperand(1); + + // Does the call to strlen have exactly one use? + if (CI->hasOneUse()) { + // Is that single use a icmp operator? + if (ICmpInst *Cmp = dyn_cast(CI->use_back())) + // Is it compared against a constant integer? + if (ConstantInt *Cst = dyn_cast(Cmp->getOperand(1))) { + // If its compared against length 0 with == or != + if (Cst->getZExtValue() == 0 && Cmp->isEquality()) { + // strlen(x) != 0 -> *x != 0 + // strlen(x) == 0 -> *x == 0 + Value *V = new LoadInst(Src, Src->getName()+".first", CI); + V = new ICmpInst(Cmp->getPredicate(), V, + ConstantInt::get(Type::Int8Ty, 0), + Cmp->getName()+".strlen", CI); + Cmp->replaceAllUsesWith(V); + Cmp->eraseFromParent(); + return ReplaceCallWith(CI, 0); // no uses. 
+ } + } + } + + // Get the length of the constant string operand + std::string Str; + if (!GetConstantStringInfo(Src, Str)) + return false; + + // strlen("xyz") -> 3 (for example) + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), Str.size())); + } +} StrLenOptimizer; + +/// IsOnlyUsedInEqualsComparison - Return true if it only matters that the value +/// is equal or not-equal to zero. +static bool IsOnlyUsedInEqualsZeroComparison(Instruction *I) { + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); + UI != E; ++UI) { + if (ICmpInst *IC = dyn_cast(*UI)) + if (IC->isEquality()) + if (Constant *C = dyn_cast(IC->getOperand(1))) + if (C->isNullValue()) + continue; + // Unknown instruction. + return false; + } + return true; +} + +/// This memcmpOptimization will simplify a call to the memcmp library +/// function. +struct VISIBILITY_HIDDEN memcmpOptimization : public LibCallOptimization { + /// @brief Default Constructor + memcmpOptimization() + : LibCallOptimization("memcmp", "Number of 'memcmp' calls simplified") {} + + /// @brief Make sure that the "memcmp" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &TD) { + Function::const_arg_iterator AI = F->arg_begin(); + if (F->arg_size() != 3 || !isa(AI->getType())) return false; + if (!isa((++AI)->getType())) return false; + if (!(++AI)->getType()->isInteger()) return false; + if (!F->getReturnType()->isInteger()) return false; + return true; + } + + /// Because of alignment and instruction information that we don't have, we + /// leave the bulk of this to the code generators. + /// + /// Note that we could do much more if we could force alignment on otherwise + /// small aligned allocas, or if we could indicate that loads have a small + /// alignment. + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &TD) { + Value *LHS = CI->getOperand(1), *RHS = CI->getOperand(2); + + // If the two operands are the same, return zero. 
+ if (LHS == RHS) { + // memcmp(s,s,x) -> 0 + return ReplaceCallWith(CI, Constant::getNullValue(CI->getType())); + } + + // Make sure we have a constant length. + ConstantInt *LenC = dyn_cast(CI->getOperand(3)); + if (!LenC) return false; + uint64_t Len = LenC->getZExtValue(); + + // If the length is zero, this returns 0. + switch (Len) { + case 0: + // memcmp(s1,s2,0) -> 0 + return ReplaceCallWith(CI, Constant::getNullValue(CI->getType())); + case 1: { + // memcmp(S1,S2,1) -> *(ubyte*)S1 - *(ubyte*)S2 + const Type *UCharPtr = PointerType::get(Type::Int8Ty); + CastInst *Op1Cast = CastInst::create( + Instruction::BitCast, LHS, UCharPtr, LHS->getName(), CI); + CastInst *Op2Cast = CastInst::create( + Instruction::BitCast, RHS, UCharPtr, RHS->getName(), CI); + Value *S1V = new LoadInst(Op1Cast, LHS->getName()+".val", CI); + Value *S2V = new LoadInst(Op2Cast, RHS->getName()+".val", CI); + Value *RV = BinaryOperator::createSub(S1V, S2V, CI->getName()+".diff",CI); + if (RV->getType() != CI->getType()) + RV = CastInst::createIntegerCast(RV, CI->getType(), false, + RV->getName(), CI); + return ReplaceCallWith(CI, RV); + } + case 2: + if (IsOnlyUsedInEqualsZeroComparison(CI)) { + // TODO: IF both are aligned, use a short load/compare. 
+ + // memcmp(S1,S2,2) -> S1[0]-S2[0] | S1[1]-S2[1] iff only ==/!= 0 matters + const Type *UCharPtr = PointerType::get(Type::Int8Ty); + CastInst *Op1Cast = CastInst::create( + Instruction::BitCast, LHS, UCharPtr, LHS->getName(), CI); + CastInst *Op2Cast = CastInst::create( + Instruction::BitCast, RHS, UCharPtr, RHS->getName(), CI); + Value *S1V1 = new LoadInst(Op1Cast, LHS->getName()+".val1", CI); + Value *S2V1 = new LoadInst(Op2Cast, RHS->getName()+".val1", CI); + Value *D1 = BinaryOperator::createSub(S1V1, S2V1, + CI->getName()+".d1", CI); + Constant *One = ConstantInt::get(Type::Int32Ty, 1); + Value *G1 = new GetElementPtrInst(Op1Cast, One, "next1v", CI); + Value *G2 = new GetElementPtrInst(Op2Cast, One, "next2v", CI); + Value *S1V2 = new LoadInst(G1, LHS->getName()+".val2", CI); + Value *S2V2 = new LoadInst(G2, RHS->getName()+".val2", CI); + Value *D2 = BinaryOperator::createSub(S1V2, S2V2, + CI->getName()+".d1", CI); + Value *Or = BinaryOperator::createOr(D1, D2, CI->getName()+".res", CI); + if (Or->getType() != CI->getType()) + Or = CastInst::createIntegerCast(Or, CI->getType(), false /*ZExt*/, + Or->getName(), CI); + return ReplaceCallWith(CI, Or); + } + break; + default: + break; + } + + return false; + } +} memcmpOptimizer; + + +/// This LibCallOptimization will simplify a call to the memcpy library +/// function by expanding it out to a single store of size 0, 1, 2, 4, or 8 +/// bytes depending on the length of the string and the alignment. Additional +/// optimizations are possible in code generation (sequence of immediate store) +/// @brief Simplify the memcpy library function. 
+struct VISIBILITY_HIDDEN LLVMMemCpyMoveOptzn : public LibCallOptimization { + LLVMMemCpyMoveOptzn(const char* fname, const char* desc) + : LibCallOptimization(fname, desc) {} + + /// @brief Make sure that the "memcpy" function has the right prototype + virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& TD) { + // Just make sure this has 4 arguments per LLVM spec. + return (f->arg_size() == 4); + } + + /// Because of alignment and instruction information that we don't have, we + /// leave the bulk of this to the code generators. The optimization here just + /// deals with a few degenerate cases where the length of the string and the + /// alignment match the sizes of our intrinsic types so we can do a load and + /// store instead of the memcpy call. + /// @brief Perform the memcpy optimization. + virtual bool OptimizeCall(CallInst* ci, SimplifyLibCalls& TD) { + // Make sure we have constant int values to work with + ConstantInt* LEN = dyn_cast(ci->getOperand(3)); + if (!LEN) + return false; + ConstantInt* ALIGN = dyn_cast(ci->getOperand(4)); + if (!ALIGN) + return false; + + // If the length is larger than the alignment, we can't optimize + uint64_t len = LEN->getZExtValue(); + uint64_t alignment = ALIGN->getZExtValue(); + if (alignment == 0) + alignment = 1; // Alignment 0 is identity for alignment 1 + if (len > alignment) + return false; + + // Get the type we will cast to, based on size of the string + Value* dest = ci->getOperand(1); + Value* src = ci->getOperand(2); + const Type* castType = 0; + switch (len) { + case 0: + // memcpy(d,s,0,a) -> d + return ReplaceCallWith(ci, 0); + case 1: castType = Type::Int8Ty; break; + case 2: castType = Type::Int16Ty; break; + case 4: castType = Type::Int32Ty; break; + case 8: castType = Type::Int64Ty; break; + default: + return false; + } + + // Cast source and dest to the right sized primitive and then load/store + CastInst* SrcCast = CastInst::create(Instruction::BitCast, + src, 
PointerType::get(castType), src->getName()+".cast", ci); + CastInst* DestCast = CastInst::create(Instruction::BitCast, + dest, PointerType::get(castType),dest->getName()+".cast", ci); + LoadInst* LI = new LoadInst(SrcCast,SrcCast->getName()+".val",ci); + new StoreInst(LI, DestCast, ci); + return ReplaceCallWith(ci, 0); + } +}; + +/// This LibCallOptimization will simplify a call to the memcpy/memmove library +/// functions. +LLVMMemCpyMoveOptzn LLVMMemCpyOptimizer32("llvm.memcpy.i32", + "Number of 'llvm.memcpy' calls simplified"); +LLVMMemCpyMoveOptzn LLVMMemCpyOptimizer64("llvm.memcpy.i64", + "Number of 'llvm.memcpy' calls simplified"); +LLVMMemCpyMoveOptzn LLVMMemMoveOptimizer32("llvm.memmove.i32", + "Number of 'llvm.memmove' calls simplified"); +LLVMMemCpyMoveOptzn LLVMMemMoveOptimizer64("llvm.memmove.i64", + "Number of 'llvm.memmove' calls simplified"); + +/// This LibCallOptimization will simplify a call to the memset library +/// function by expanding it out to a single store of size 0, 1, 2, 4, or 8 +/// bytes depending on the length argument. +struct VISIBILITY_HIDDEN LLVMMemSetOptimization : public LibCallOptimization { + /// @brief Default Constructor + LLVMMemSetOptimization(const char *Name) : LibCallOptimization(Name, + "Number of 'llvm.memset' calls simplified") {} + + /// @brief Make sure that the "memset" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &TD) { + // Just make sure this has 3 arguments per LLVM spec. + return F->arg_size() == 4; + } + + /// Because of alignment and instruction information that we don't have, we + /// leave the bulk of this to the code generators. The optimization here just + /// deals with a few degenerate cases where the length parameter is constant + /// and the alignment matches the sizes of our intrinsic types so we can do + /// store instead of the memcpy call. Other calls are transformed into the + /// llvm.memset intrinsic. 
+ /// @brief Perform the memset optimization. + virtual bool OptimizeCall(CallInst *ci, SimplifyLibCalls &TD) { + // Make sure we have constant int values to work with + ConstantInt* LEN = dyn_cast(ci->getOperand(3)); + if (!LEN) + return false; + ConstantInt* ALIGN = dyn_cast(ci->getOperand(4)); + if (!ALIGN) + return false; + + // Extract the length and alignment + uint64_t len = LEN->getZExtValue(); + uint64_t alignment = ALIGN->getZExtValue(); + + // Alignment 0 is identity for alignment 1 + if (alignment == 0) + alignment = 1; + + // If the length is zero, this is a no-op + if (len == 0) { + // memset(d,c,0,a) -> noop + return ReplaceCallWith(ci, 0); + } + + // If the length is larger than the alignment, we can't optimize + if (len > alignment) + return false; + + // Make sure we have a constant ubyte to work with so we can extract + // the value to be filled. + ConstantInt* FILL = dyn_cast(ci->getOperand(2)); + if (!FILL) + return false; + if (FILL->getType() != Type::Int8Ty) + return false; + + // memset(s,c,n) -> store s, c (for n=1,2,4,8) + + // Extract the fill character + uint64_t fill_char = FILL->getZExtValue(); + uint64_t fill_value = fill_char; + + // Get the type we will cast to, based on size of memory area to fill, and + // and the value we will store there. 
+ Value* dest = ci->getOperand(1); + const Type* castType = 0; + switch (len) { + case 1: + castType = Type::Int8Ty; + break; + case 2: + castType = Type::Int16Ty; + fill_value |= fill_char << 8; + break; + case 4: + castType = Type::Int32Ty; + fill_value |= fill_char << 8 | fill_char << 16 | fill_char << 24; + break; + case 8: + castType = Type::Int64Ty; + fill_value |= fill_char << 8 | fill_char << 16 | fill_char << 24; + fill_value |= fill_char << 32 | fill_char << 40 | fill_char << 48; + fill_value |= fill_char << 56; + break; + default: + return false; + } + + // Cast dest to the right sized primitive and then load/store + CastInst* DestCast = new BitCastInst(dest, PointerType::get(castType), + dest->getName()+".cast", ci); + new StoreInst(ConstantInt::get(castType,fill_value),DestCast, ci); + return ReplaceCallWith(ci, 0); + } +}; + +LLVMMemSetOptimization MemSet32Optimizer("llvm.memset.i32"); +LLVMMemSetOptimization MemSet64Optimizer("llvm.memset.i64"); + + +/// This LibCallOptimization will simplify calls to the "pow" library +/// function. It looks for cases where the result of pow is well known and +/// substitutes the appropriate value. +/// @brief Simplify the pow library function. +struct VISIBILITY_HIDDEN PowOptimization : public LibCallOptimization { +public: + /// @brief Default Constructor + PowOptimization() : LibCallOptimization("pow", + "Number of 'pow' calls simplified") {} + + /// @brief Make sure that the "pow" function has the right prototype + virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC){ + // Just make sure this has 2 arguments + return (f->arg_size() == 2); + } + + /// @brief Perform the pow optimization. 
+ virtual bool OptimizeCall(CallInst *ci, SimplifyLibCalls &SLC) { + const Type *Ty = cast(ci->getOperand(0))->getReturnType(); + Value* base = ci->getOperand(1); + Value* expn = ci->getOperand(2); + if (ConstantFP *Op1 = dyn_cast(base)) { + double Op1V = Op1->getValue(); + if (Op1V == 1.0) // pow(1.0,x) -> 1.0 + return ReplaceCallWith(ci, ConstantFP::get(Ty, 1.0)); + } else if (ConstantFP* Op2 = dyn_cast(expn)) { + double Op2V = Op2->getValue(); + if (Op2V == 0.0) { + // pow(x,0.0) -> 1.0 + return ReplaceCallWith(ci, ConstantFP::get(Ty,1.0)); + } else if (Op2V == 0.5) { + // pow(x,0.5) -> sqrt(x) + CallInst* sqrt_inst = new CallInst(SLC.get_sqrt(), base, + ci->getName()+".pow",ci); + return ReplaceCallWith(ci, sqrt_inst); + } else if (Op2V == 1.0) { + // pow(x,1.0) -> x + return ReplaceCallWith(ci, base); + } else if (Op2V == -1.0) { + // pow(x,-1.0) -> 1.0/x + Value *div_inst = + BinaryOperator::createFDiv(ConstantFP::get(Ty, 1.0), base, + ci->getName()+".pow", ci); + return ReplaceCallWith(ci, div_inst); + } + } + return false; // opt failed + } +} PowOptimizer; + +/// This LibCallOptimization will simplify calls to the "printf" library +/// function. It looks for cases where the result of printf is not used and the +/// operation can be reduced to something simpler. +/// @brief Simplify the printf library function. +struct VISIBILITY_HIDDEN PrintfOptimization : public LibCallOptimization { +public: + /// @brief Default Constructor + PrintfOptimization() : LibCallOptimization("printf", + "Number of 'printf' calls simplified") {} + + /// @brief Make sure that the "printf" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + // Just make sure this has at least 1 argument and returns an integer or + // void type. 
+ const FunctionType *FT = F->getFunctionType(); + return FT->getNumParams() >= 1 && + (isa(FT->getReturnType()) || + FT->getReturnType() == Type::VoidTy); + } + + /// @brief Perform the printf optimization. + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { + // All the optimizations depend on the length of the first argument and the + // fact that it is a constant string array. Check that now + std::string FormatStr; + if (!GetConstantStringInfo(CI->getOperand(1), FormatStr)) + return false; + + // If this is a simple constant string with no format specifiers that ends + // with a \n, turn it into a puts call. + if (FormatStr.empty()) { + // Tolerate printf's declared void. + if (CI->use_empty()) return ReplaceCallWith(CI, 0); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 0)); + } + + if (FormatStr.size() == 1) { + // Turn this into a putchar call, even if it is a %. + Value *V = ConstantInt::get(Type::Int32Ty, FormatStr[0]); + new CallInst(SLC.get_putchar(), V, "", CI); + if (CI->use_empty()) return ReplaceCallWith(CI, 0); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 1)); + } + + // Check to see if the format str is something like "foo\n", in which case + // we convert it to a puts call. We don't allow it to contain any format + // characters. + if (FormatStr[FormatStr.size()-1] == '\n' && + FormatStr.find('%') == std::string::npos) { + // Create a string literal with no \n on it. We expect the constant merge + // pass to be run after this pass, to merge duplicate strings. + FormatStr.erase(FormatStr.end()-1); + Constant *Init = ConstantArray::get(FormatStr, true); + Constant *GV = new GlobalVariable(Init->getType(), true, + GlobalVariable::InternalLinkage, + Init, "str", + CI->getParent()->getParent()->getParent()); + // Cast GV to be a pointer to char. 
+ GV = ConstantExpr::getBitCast(GV, PointerType::get(Type::Int8Ty)); + new CallInst(SLC.get_puts(), GV, "", CI); + + if (CI->use_empty()) return ReplaceCallWith(CI, 0); + return ReplaceCallWith(CI, + ConstantInt::get(CI->getType(), FormatStr.size())); + } + + + // Only support %c or "%s\n" for now. + if (FormatStr.size() < 2 || FormatStr[0] != '%') + return false; + + // Get the second character and switch on its value + switch (FormatStr[1]) { + default: return false; + case 's': + if (FormatStr != "%s\n" || CI->getNumOperands() < 3 || + // TODO: could insert strlen call to compute string length. + !CI->use_empty()) + return false; + + // printf("%s\n",str) -> puts(str) + new CallInst(SLC.get_puts(), CastToCStr(CI->getOperand(2), CI), + CI->getName(), CI); + return ReplaceCallWith(CI, 0); + case 'c': { + // printf("%c",c) -> putchar(c) + if (FormatStr.size() != 2 || CI->getNumOperands() < 3) + return false; + + Value *V = CI->getOperand(2); + if (!isa(V->getType()) || + cast(V->getType())->getBitWidth() > 32) + return false; + + V = CastInst::createZExtOrBitCast(V, Type::Int32Ty, CI->getName()+".int", + CI); + new CallInst(SLC.get_putchar(), V, "", CI); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 1)); + } + } + } +} PrintfOptimizer; + +/// This LibCallOptimization will simplify calls to the "fprintf" library +/// function. It looks for cases where the result of fprintf is not used and the +/// operation can be reduced to something simpler. +/// @brief Simplify the fprintf library function. 
+struct VISIBILITY_HIDDEN FPrintFOptimization : public LibCallOptimization { +public: + /// @brief Default Constructor + FPrintFOptimization() : LibCallOptimization("fprintf", + "Number of 'fprintf' calls simplified") {} + + /// @brief Make sure that the "fprintf" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + const FunctionType *FT = F->getFunctionType(); + return FT->getNumParams() == 2 && // two fixed arguments. + FT->getParamType(1) == PointerType::get(Type::Int8Ty) && + isa(FT->getParamType(0)) && + isa(FT->getReturnType()); + } + + /// @brief Perform the fprintf optimization. + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { + // If the call has more than 3 operands, we can't optimize it + if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4) + return false; + + // All the optimizations depend on the format string. + std::string FormatStr; + if (!GetConstantStringInfo(CI->getOperand(2), FormatStr)) + return false; + + // If this is just a format string, turn it into fwrite. + if (CI->getNumOperands() == 3) { + for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) + if (FormatStr[i] == '%') + return false; // we found a format specifier + + // fprintf(file,fmt) -> fwrite(fmt,strlen(fmt),file) + const Type *FILEty = CI->getOperand(1)->getType(); + + Value *FWriteArgs[] = { + CI->getOperand(2), + ConstantInt::get(SLC.getIntPtrType(), FormatStr.size()), + ConstantInt::get(SLC.getIntPtrType(), 1), + CI->getOperand(1) + }; + new CallInst(SLC.get_fwrite(FILEty), FWriteArgs, 4, CI->getName(), CI); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), + FormatStr.size())); + } + + // The remaining optimizations require the format string to be length 2: + // "%s" or "%c". 
+ if (FormatStr.size() != 2 || FormatStr[0] != '%') + return false; + + // Get the second character and switch on its value + switch (FormatStr[1]) { + case 'c': { + // fprintf(file,"%c",c) -> fputc(c,file) + const Type *FILETy = CI->getOperand(1)->getType(); + Value *C = CastInst::createZExtOrBitCast(CI->getOperand(3), Type::Int32Ty, + CI->getName()+".int", CI); + new CallInst(SLC.get_fputc(FILETy), C, CI->getOperand(1), "", CI); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 1)); + } + case 's': { + const Type *FILETy = CI->getOperand(1)->getType(); + + // If the result of the fprintf call is used, we can't do this. + // TODO: we should insert a strlen call. + if (!CI->use_empty()) + return false; + + // fprintf(file,"%s",str) -> fputs(str,file) + new CallInst(SLC.get_fputs(FILETy), CastToCStr(CI->getOperand(3), CI), + CI->getOperand(1), CI->getName(), CI); + return ReplaceCallWith(CI, 0); + } + default: + return false; + } + } +} FPrintFOptimizer; + +/// This LibCallOptimization will simplify calls to the "sprintf" library +/// function. It looks for cases where the result of sprintf is not used and the +/// operation can be reduced to something simpler. +/// @brief Simplify the sprintf library function. +struct VISIBILITY_HIDDEN SPrintFOptimization : public LibCallOptimization { +public: + /// @brief Default Constructor + SPrintFOptimization() : LibCallOptimization("sprintf", + "Number of 'sprintf' calls simplified") {} + + /// @brief Make sure that the "sprintf" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + const FunctionType *FT = F->getFunctionType(); + return FT->getNumParams() == 2 && // two fixed arguments. + FT->getParamType(1) == PointerType::get(Type::Int8Ty) && + FT->getParamType(0) == FT->getParamType(1) && + isa(FT->getReturnType()); + } + + /// @brief Perform the sprintf optimization. 
+ virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { + // If the call has more than 3 operands, we can't optimize it + if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4) + return false; + + std::string FormatStr; + if (!GetConstantStringInfo(CI->getOperand(2), FormatStr)) + return false; + + if (CI->getNumOperands() == 3) { + // Make sure there's no % in the constant array + for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) + if (FormatStr[i] == '%') + return false; // we found a format specifier + + // sprintf(str,fmt) -> llvm.memcpy(str,fmt,strlen(fmt),1) + Value *MemCpyArgs[] = { + CI->getOperand(1), CI->getOperand(2), + ConstantInt::get(SLC.getIntPtrType(), + FormatStr.size()+1), // Copy the nul byte. + ConstantInt::get(Type::Int32Ty, 1) + }; + new CallInst(SLC.get_memcpy(), MemCpyArgs, 4, "", CI); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), + FormatStr.size())); + } + + // The remaining optimizations require the format string to be "%s" or "%c". 
+ if (FormatStr.size() != 2 || FormatStr[0] != '%') + return false; + + // Get the second character and switch on its value + switch (FormatStr[1]) { + case 'c': { + // sprintf(dest,"%c",chr) -> store chr, dest + Value *V = CastInst::createTruncOrBitCast(CI->getOperand(3), + Type::Int8Ty, "char", CI); + new StoreInst(V, CI->getOperand(1), CI); + Value *Ptr = new GetElementPtrInst(CI->getOperand(1), + ConstantInt::get(Type::Int32Ty, 1), + CI->getOperand(1)->getName()+".end", + CI); + new StoreInst(ConstantInt::get(Type::Int8Ty,0), Ptr, CI); + return ReplaceCallWith(CI, ConstantInt::get(Type::Int32Ty, 1)); + } + case 's': { + // sprintf(dest,"%s",str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) + Value *Len = new CallInst(SLC.get_strlen(), + CastToCStr(CI->getOperand(3), CI), + CI->getOperand(3)->getName()+".len", CI); + Value *UnincLen = Len; + Len = BinaryOperator::createAdd(Len, ConstantInt::get(Len->getType(), 1), + Len->getName()+"1", CI); + Value *MemcpyArgs[4] = { + CI->getOperand(1), + CastToCStr(CI->getOperand(3), CI), + Len, + ConstantInt::get(Type::Int32Ty, 1) + }; + new CallInst(SLC.get_memcpy(), MemcpyArgs, 4, "", CI); + + // The strlen result is the unincremented number of bytes in the string. + if (!CI->use_empty()) { + if (UnincLen->getType() != CI->getType()) + UnincLen = CastInst::createIntegerCast(UnincLen, CI->getType(), false, + Len->getName(), CI); + CI->replaceAllUsesWith(UnincLen); + } + return ReplaceCallWith(CI, 0); + } + } + return false; + } +} SPrintFOptimizer; + +/// This LibCallOptimization will simplify calls to the "fputs" library +/// function. It looks for cases where the result of fputs is not used and the +/// operation can be reduced to something simpler. +/// @brief Simplify the fputs library function. 
+struct VISIBILITY_HIDDEN FPutsOptimization : public LibCallOptimization { +public: + /// @brief Default Constructor + FPutsOptimization() : LibCallOptimization("fputs", + "Number of 'fputs' calls simplified") {} + + /// @brief Make sure that the "fputs" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + // Just make sure this has 2 arguments + return F->arg_size() == 2; + } + + /// @brief Perform the fputs optimization. + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { + // If the result is used, none of these optimizations work. + if (!CI->use_empty()) + return false; + + // All the optimizations depend on the length of the first argument and the + // fact that it is a constant string array. Check that now + std::string Str; + if (!GetConstantStringInfo(CI->getOperand(1), Str)) + return false; + + const Type *FILETy = CI->getOperand(2)->getType(); + // fputs(s,F) -> fwrite(s,1,len,F) (if s is constant and strlen(s) > 1) + Value *FWriteParms[4] = { + CI->getOperand(1), + ConstantInt::get(SLC.getIntPtrType(), Str.size()), + ConstantInt::get(SLC.getIntPtrType(), 1), + CI->getOperand(2) + }; + new CallInst(SLC.get_fwrite(FILETy), FWriteParms, 4, "", CI); + return ReplaceCallWith(CI, 0); // Known to have no uses (see above). + } +} FPutsOptimizer; + +/// This LibCallOptimization will simplify calls to the "fwrite" function. 
+struct VISIBILITY_HIDDEN FWriteOptimization : public LibCallOptimization { +public: + /// @brief Default Constructor + FWriteOptimization() : LibCallOptimization("fwrite", + "Number of 'fwrite' calls simplified") {} + + /// @brief Make sure that the "fputs" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + const FunctionType *FT = F->getFunctionType(); + return FT->getNumParams() == 4 && + FT->getParamType(0) == PointerType::get(Type::Int8Ty) && + FT->getParamType(1) == FT->getParamType(2) && + isa(FT->getParamType(1)) && + isa(FT->getParamType(3)) && + isa(FT->getReturnType()); + } + + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { + // Get the element size and count. + uint64_t EltSize, EltCount; + if (ConstantInt *C = dyn_cast(CI->getOperand(2))) + EltSize = C->getZExtValue(); + else + return false; + if (ConstantInt *C = dyn_cast(CI->getOperand(3))) + EltCount = C->getZExtValue(); + else + return false; + + // If this is writing zero records, remove the call (it's a noop). + if (EltSize * EltCount == 0) + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 0)); + + // If this is writing one byte, turn it into fputc. + if (EltSize == 1 && EltCount == 1) { + // fwrite(s,1,1,F) -> fputc(s[0],F) + Value *Ptr = CI->getOperand(1); + Value *Val = new LoadInst(Ptr, Ptr->getName()+".byte", CI); + Val = new ZExtInst(Val, Type::Int32Ty, Val->getName()+".int", CI); + const Type *FILETy = CI->getOperand(4)->getType(); + new CallInst(SLC.get_fputc(FILETy), Val, CI->getOperand(4), "", CI); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 1)); + } + return false; + } +} FWriteOptimizer; + +/// This LibCallOptimization will simplify calls to the "isdigit" library +/// function. It simply does range checks the parameter explicitly. +/// @brief Simplify the isdigit library function. 
+struct VISIBILITY_HIDDEN isdigitOptimization : public LibCallOptimization { +public: + isdigitOptimization() : LibCallOptimization("isdigit", + "Number of 'isdigit' calls simplified") {} + + /// @brief Make sure that the "isdigit" function has the right prototype + virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC){ + // Just make sure this has 1 argument + return (f->arg_size() == 1); + } + + /// @brief Perform the toascii optimization. + virtual bool OptimizeCall(CallInst *ci, SimplifyLibCalls &SLC) { + if (ConstantInt* CI = dyn_cast(ci->getOperand(1))) { + // isdigit(c) -> 0 or 1, if 'c' is constant + uint64_t val = CI->getZExtValue(); + if (val >= '0' && val <= '9') + return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty, 1)); + else + return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty, 0)); + } + + // isdigit(c) -> (unsigned)c - '0' <= 9 + CastInst* cast = CastInst::createIntegerCast(ci->getOperand(1), + Type::Int32Ty, false/*ZExt*/, ci->getOperand(1)->getName()+".uint", ci); + BinaryOperator* sub_inst = BinaryOperator::createSub(cast, + ConstantInt::get(Type::Int32Ty,0x30), + ci->getOperand(1)->getName()+".sub",ci); + ICmpInst* setcond_inst = new ICmpInst(ICmpInst::ICMP_ULE,sub_inst, + ConstantInt::get(Type::Int32Ty,9), + ci->getOperand(1)->getName()+".cmp",ci); + CastInst* c2 = new ZExtInst(setcond_inst, Type::Int32Ty, + ci->getOperand(1)->getName()+".isdigit", ci); + return ReplaceCallWith(ci, c2); + } +} isdigitOptimizer; + +struct VISIBILITY_HIDDEN isasciiOptimization : public LibCallOptimization { +public: + isasciiOptimization() + : LibCallOptimization("isascii", "Number of 'isascii' calls simplified") {} + + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + return F->arg_size() == 1 && F->arg_begin()->getType()->isInteger() && + F->getReturnType()->isInteger(); + } + + /// @brief Perform the isascii optimization. 
+ virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { + // isascii(c) -> (unsigned)c < 128 + Value *V = CI->getOperand(1); + Value *Cmp = new ICmpInst(ICmpInst::ICMP_ULT, V, + ConstantInt::get(V->getType(), 128), + V->getName()+".isascii", CI); + if (Cmp->getType() != CI->getType()) + Cmp = new ZExtInst(Cmp, CI->getType(), Cmp->getName(), CI); + return ReplaceCallWith(CI, Cmp); + } +} isasciiOptimizer; + + +/// This LibCallOptimization will simplify calls to the "toascii" library +/// function. It simply does the corresponding and operation to restrict the +/// range of values to the ASCII character set (0-127). +/// @brief Simplify the toascii library function. +struct VISIBILITY_HIDDEN ToAsciiOptimization : public LibCallOptimization { +public: + /// @brief Default Constructor + ToAsciiOptimization() : LibCallOptimization("toascii", + "Number of 'toascii' calls simplified") {} + + /// @brief Make sure that the "fputs" function has the right prototype + virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC){ + // Just make sure this has 2 arguments + return (f->arg_size() == 1); + } + + /// @brief Perform the toascii optimization. + virtual bool OptimizeCall(CallInst *ci, SimplifyLibCalls &SLC) { + // toascii(c) -> (c & 0x7f) + Value *chr = ci->getOperand(1); + Value *and_inst = BinaryOperator::createAnd(chr, + ConstantInt::get(chr->getType(),0x7F),ci->getName()+".toascii",ci); + return ReplaceCallWith(ci, and_inst); + } +} ToAsciiOptimizer; + +/// This LibCallOptimization will simplify calls to the "ffs" library +/// calls which find the first set bit in an int, long, or long long. The +/// optimization is to compute the result at compile time if the argument is +/// a constant. +/// @brief Simplify the ffs library function. 
+struct VISIBILITY_HIDDEN FFSOptimization : public LibCallOptimization { +protected: + /// @brief Subclass Constructor + FFSOptimization(const char* funcName, const char* description) + : LibCallOptimization(funcName, description) {} + +public: + /// @brief Default Constructor + FFSOptimization() : LibCallOptimization("ffs", + "Number of 'ffs' calls simplified") {} + + /// @brief Make sure that the "ffs" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + // Just make sure this has 2 arguments + return F->arg_size() == 1 && F->getReturnType() == Type::Int32Ty; + } + + /// @brief Perform the ffs optimization. + virtual bool OptimizeCall(CallInst *TheCall, SimplifyLibCalls &SLC) { + if (ConstantInt *CI = dyn_cast(TheCall->getOperand(1))) { + // ffs(cnst) -> bit# + // ffsl(cnst) -> bit# + // ffsll(cnst) -> bit# + uint64_t val = CI->getZExtValue(); + int result = 0; + if (val) { + ++result; + while ((val & 1) == 0) { + ++result; + val >>= 1; + } + } + return ReplaceCallWith(TheCall, ConstantInt::get(Type::Int32Ty, result)); + } + + // ffs(x) -> x == 0 ? 0 : llvm.cttz(x)+1 + // ffsl(x) -> x == 0 ? 0 : llvm.cttz(x)+1 + // ffsll(x) -> x == 0 ? 
0 : llvm.cttz(x)+1 + const Type *ArgType = TheCall->getOperand(1)->getType(); + const char *CTTZName; + assert(ArgType->getTypeID() == Type::IntegerTyID && + "llvm.cttz argument is not an integer?"); + unsigned BitWidth = cast(ArgType)->getBitWidth(); + if (BitWidth == 8) + CTTZName = "llvm.cttz.i8"; + else if (BitWidth == 16) + CTTZName = "llvm.cttz.i16"; + else if (BitWidth == 32) + CTTZName = "llvm.cttz.i32"; + else { + assert(BitWidth == 64 && "Unknown bitwidth"); + CTTZName = "llvm.cttz.i64"; + } + + Constant *F = SLC.getModule()->getOrInsertFunction(CTTZName, ArgType, + ArgType, NULL); + Value *V = CastInst::createIntegerCast(TheCall->getOperand(1), ArgType, + false/*ZExt*/, "tmp", TheCall); + Value *V2 = new CallInst(F, V, "tmp", TheCall); + V2 = CastInst::createIntegerCast(V2, Type::Int32Ty, false/*ZExt*/, + "tmp", TheCall); + V2 = BinaryOperator::createAdd(V2, ConstantInt::get(Type::Int32Ty, 1), + "tmp", TheCall); + Value *Cond = new ICmpInst(ICmpInst::ICMP_EQ, V, + Constant::getNullValue(V->getType()), "tmp", + TheCall); + V2 = new SelectInst(Cond, ConstantInt::get(Type::Int32Ty, 0), V2, + TheCall->getName(), TheCall); + return ReplaceCallWith(TheCall, V2); + } +} FFSOptimizer; + +/// This LibCallOptimization will simplify calls to the "ffsl" library +/// calls. It simply uses FFSOptimization for which the transformation is +/// identical. +/// @brief Simplify the ffsl library function. +struct VISIBILITY_HIDDEN FFSLOptimization : public FFSOptimization { +public: + /// @brief Default Constructor + FFSLOptimization() : FFSOptimization("ffsl", + "Number of 'ffsl' calls simplified") {} + +} FFSLOptimizer; + +/// This LibCallOptimization will simplify calls to the "ffsll" library +/// calls. It simply uses FFSOptimization for which the transformation is +/// identical. +/// @brief Simplify the ffsl library function. 
+struct VISIBILITY_HIDDEN FFSLLOptimization : public FFSOptimization { +public: + /// @brief Default Constructor + FFSLLOptimization() : FFSOptimization("ffsll", + "Number of 'ffsll' calls simplified") {} + +} FFSLLOptimizer; + +/// This optimizes unary functions that take and return doubles. +struct UnaryDoubleFPOptimizer : public LibCallOptimization { + UnaryDoubleFPOptimizer(const char *Fn, const char *Desc) + : LibCallOptimization(Fn, Desc) {} + + // Make sure that this function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + return F->arg_size() == 1 && F->arg_begin()->getType() == Type::DoubleTy && + F->getReturnType() == Type::DoubleTy; + } + + /// ShrinkFunctionToFloatVersion - If the input to this function is really a + /// float, strength reduce this to a float version of the function, + /// e.g. floor((double)FLT) -> (double)floorf(FLT). This can only be called + /// when the target supports the destination function and where there can be + /// no precision loss. + static bool ShrinkFunctionToFloatVersion(CallInst *CI, SimplifyLibCalls &SLC, + Constant *(SimplifyLibCalls::*FP)()){ + if (FPExtInst *Cast = dyn_cast(CI->getOperand(1))) + if (Cast->getOperand(0)->getType() == Type::FloatTy) { + Value *New = new CallInst((SLC.*FP)(), Cast->getOperand(0), + CI->getName(), CI); + New = new FPExtInst(New, Type::DoubleTy, CI->getName(), CI); + CI->replaceAllUsesWith(New); + CI->eraseFromParent(); + if (Cast->use_empty()) + Cast->eraseFromParent(); + return true; + } + return false; + } +}; + + +struct VISIBILITY_HIDDEN FloorOptimization : public UnaryDoubleFPOptimizer { + FloorOptimization() + : UnaryDoubleFPOptimizer("floor", "Number of 'floor' calls simplified") {} + + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { +#ifdef HAVE_FLOORF + // If this is a float argument passed in, convert to floorf. 
+ if (ShrinkFunctionToFloatVersion(CI, SLC, &SimplifyLibCalls::get_floorf)) + return true; +#endif + return false; // opt failed + } +} FloorOptimizer; + +struct VISIBILITY_HIDDEN CeilOptimization : public UnaryDoubleFPOptimizer { + CeilOptimization() + : UnaryDoubleFPOptimizer("ceil", "Number of 'ceil' calls simplified") {} + + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { +#ifdef HAVE_CEILF + // If this is a float argument passed in, convert to ceilf. + if (ShrinkFunctionToFloatVersion(CI, SLC, &SimplifyLibCalls::get_ceilf)) + return true; +#endif + return false; // opt failed + } +} CeilOptimizer; + +struct VISIBILITY_HIDDEN RoundOptimization : public UnaryDoubleFPOptimizer { + RoundOptimization() + : UnaryDoubleFPOptimizer("round", "Number of 'round' calls simplified") {} + + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { +#ifdef HAVE_ROUNDF + // If this is a float argument passed in, convert to roundf. + if (ShrinkFunctionToFloatVersion(CI, SLC, &SimplifyLibCalls::get_roundf)) + return true; +#endif + return false; // opt failed + } +} RoundOptimizer; + +struct VISIBILITY_HIDDEN RintOptimization : public UnaryDoubleFPOptimizer { + RintOptimization() + : UnaryDoubleFPOptimizer("rint", "Number of 'rint' calls simplified") {} + + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { +#ifdef HAVE_RINTF + // If this is a float argument passed in, convert to rintf. + if (ShrinkFunctionToFloatVersion(CI, SLC, &SimplifyLibCalls::get_rintf)) + return true; +#endif + return false; // opt failed + } +} RintOptimizer; + +struct VISIBILITY_HIDDEN NearByIntOptimization : public UnaryDoubleFPOptimizer { + NearByIntOptimization() + : UnaryDoubleFPOptimizer("nearbyint", + "Number of 'nearbyint' calls simplified") {} + + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { +#ifdef HAVE_NEARBYINTF + // If this is a float argument passed in, convert to nearbyintf. 
+ if (ShrinkFunctionToFloatVersion(CI, SLC,&SimplifyLibCalls::get_nearbyintf)) + return true; +#endif + return false; // opt failed + } +} NearByIntOptimizer; + +/// GetConstantStringInfo - This function computes the length of a +/// null-terminated constant array of integers. This function can't rely on the +/// size of the constant array because there could be a null terminator in the +/// middle of the array. +/// +/// We also have to bail out if we find a non-integer constant initializer +/// of one of the elements or if there is no null-terminator. The logic +/// below checks each of these conditions and will return true only if all +/// conditions are met. If the conditions aren't met, this returns false. +/// +/// If successful, the \p Array param is set to the constant array being +/// indexed, the \p Length parameter is set to the length of the null-terminated +/// string pointed to by V, the \p StartIdx value is set to the first +/// element of the Array that V points to, and true is returned. +static bool GetConstantStringInfo(Value *V, std::string &Str) { + // Look through noop bitcast instructions. + if (BitCastInst *BCI = dyn_cast(V)) { + if (BCI->getType() == BCI->getOperand(0)->getType()) + return GetConstantStringInfo(BCI->getOperand(0), Str); + return false; + } + + // If the value is not a GEP instruction nor a constant expression with a + // GEP instruction, then return false because ConstantArray can't occur + // any other way + User *GEP = 0; + if (GetElementPtrInst *GEPI = dyn_cast(V)) { + GEP = GEPI; + } else if (ConstantExpr *CE = dyn_cast(V)) { + if (CE->getOpcode() != Instruction::GetElementPtr) + return false; + GEP = CE; + } else { + return false; + } + + // Make sure the GEP has exactly three arguments. + if (GEP->getNumOperands() != 3) + return false; + + // Check to make sure that the first operand of the GEP is an integer and + // has value 0 so that we are sure we're indexing into the initializer. 
+ if (ConstantInt *Idx = dyn_cast(GEP->getOperand(1))) { + if (!Idx->isZero()) + return false; + } else + return false; + + // If the second index isn't a ConstantInt, then this is a variable index + // into the array. If this occurs, we can't say anything meaningful about + // the string. + uint64_t StartIdx = 0; + if (ConstantInt *CI = dyn_cast(GEP->getOperand(2))) + StartIdx = CI->getZExtValue(); + else + return false; + + // The GEP instruction, constant or instruction, must reference a global + // variable that is a constant and is initialized. The referenced constant + // initializer is the array that we'll use for optimization. + GlobalVariable* GV = dyn_cast(GEP->getOperand(0)); + if (!GV || !GV->isConstant() || !GV->hasInitializer()) + return false; + Constant *GlobalInit = GV->getInitializer(); + + // Handle the ConstantAggregateZero case + if (isa(GlobalInit)) { + // This is a degenerate case. The initializer is constant zero so the + // length of the string must be zero. + Str.clear(); + return true; + } + + // Must be a Constant Array + ConstantArray *Array = dyn_cast(GlobalInit); + if (!Array) return false; + + // Get the number of elements in the array + uint64_t NumElts = Array->getType()->getNumElements(); + + // Traverse the constant array from StartIdx (derived above) which is + // the place the GEP refers to in the array. + for (unsigned i = StartIdx; i < NumElts; ++i) { + Constant *Elt = Array->getOperand(i); + ConstantInt *CI = dyn_cast(Elt); + if (!CI) // This array isn't suitable, non-int initializer. + return false; + if (CI->isZero()) + return true; // we found end of string, success! + Str += (char)CI->getZExtValue(); + } + + return false; // The array isn't null terminated. +} + +/// CastToCStr - Return V if it is an sbyte*, otherwise cast it to sbyte*, +/// inserting the cast before IP, and return the cast. +/// @brief Cast a value to a "C" string. 
+static Value *CastToCStr(Value *V, Instruction *IP) { + assert(isa(V->getType()) && + "Can't cast non-pointer type to C string type"); + const Type *SBPTy = PointerType::get(Type::Int8Ty); + if (V->getType() != SBPTy) + return new BitCastInst(V, SBPTy, V->getName(), IP); + return V; +} + +// TODO: +// Additional cases that we need to add to this file: +// +// cbrt: +// * cbrt(expN(X)) -> expN(x/3) +// * cbrt(sqrt(x)) -> pow(x,1/6) +// * cbrt(sqrt(x)) -> pow(x,1/9) +// +// cos, cosf, cosl: +// * cos(-x) -> cos(x) +// +// exp, expf, expl: +// * exp(log(x)) -> x +// +// log, logf, logl: +// * log(exp(x)) -> x +// * log(x**y) -> y*log(x) +// * log(exp(y)) -> y*log(e) +// * log(exp2(y)) -> y*log(2) +// * log(exp10(y)) -> y*log(10) +// * log(sqrt(x)) -> 0.5*log(x) +// * log(pow(x,y)) -> y*log(x) +// +// lround, lroundf, lroundl: +// * lround(cnst) -> cnst' +// +// memcmp: +// * memcmp(x,y,l) -> cnst +// (if all arguments are constant and strlen(x) <= l and strlen(y) <= l) +// +// memmove: +// * memmove(d,s,l,a) -> memcpy(d,s,l,a) +// (if s is a global constant array) +// +// pow, powf, powl: +// * pow(exp(x),y) -> exp(x*y) +// * pow(sqrt(x),y) -> pow(x,y*0.5) +// * pow(pow(x,y),z)-> pow(x,y*z) +// +// puts: +// * puts("") -> putchar("\n") +// +// round, roundf, roundl: +// * round(cnst) -> cnst' +// +// signbit: +// * signbit(cnst) -> cnst' +// * signbit(nncst) -> 0 (if pstv is a non-negative constant) +// +// sqrt, sqrtf, sqrtl: +// * sqrt(expN(x)) -> expN(x*0.5) +// * sqrt(Nroot(x)) -> pow(x,1/(2*N)) +// * sqrt(pow(x,y)) -> pow(|x|,y*0.5) +// +// stpcpy: +// * stpcpy(str, "literal") -> +// llvm.memcpy(str,"literal",strlen("literal")+1,1) +// strrchr: +// * strrchr(s,c) -> reverse_offset_of_in(c,s) +// (if c is a constant integer and s is a constant string) +// * strrchr(s1,0) -> strchr(s1,0) +// +// strncat: +// * strncat(x,y,0) -> x +// * strncat(x,y,0) -> x (if strlen(y) = 0) +// * strncat(x,y,l) -> strcat(x,y) (if y and l are constants an l > strlen(y)) +// +// 
strncpy: +// * strncpy(d,s,0) -> d +// * strncpy(d,s,l) -> memcpy(d,s,l,1) +// (if s and l are constants) +// +// strpbrk: +// * strpbrk(s,a) -> offset_in_for(s,a) +// (if s and a are both constant strings) +// * strpbrk(s,"") -> 0 +// * strpbrk(s,a) -> strchr(s,a[0]) (if a is constant string of length 1) +// +// strspn, strcspn: +// * strspn(s,a) -> const_int (if both args are constant) +// * strspn("",a) -> 0 +// * strspn(s,"") -> 0 +// * strcspn(s,a) -> const_int (if both args are constant) +// * strcspn("",a) -> 0 +// * strcspn(s,"") -> strlen(a) +// +// strstr: +// * strstr(x,x) -> x +// * strstr(s1,s2) -> offset_of_s2_in(s1) +// (if s1 and s2 are constant strings) +// +// tan, tanf, tanl: +// * tan(atan(x)) -> x +// +// trunc, truncf, truncl: +// * trunc(cnst) -> cnst' +// +// +} diff --git a/lib/Transforms/IPO/StripDeadPrototypes.cpp b/lib/Transforms/IPO/StripDeadPrototypes.cpp new file mode 100644 index 0000000..9851b26 --- /dev/null +++ b/lib/Transforms/IPO/StripDeadPrototypes.cpp @@ -0,0 +1,70 @@ +//===-- StripDeadPrototypes.cpp - Removed unused function declarations ----===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by Reid Spencer and is distributed under the +// University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass loops over all of the functions in the input module, looking for +// dead declarations and removes them. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "strip-dead-prototypes" +#include "llvm/Transforms/IPO.h" +#include "llvm/Pass.h" +#include "llvm/Module.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Compiler.h" +using namespace llvm; + +STATISTIC(NumDeadPrototypes, "Number of dead prototypes removed"); + +namespace { + +/// @brief Pass to remove unused function declarations. 
+class VISIBILITY_HIDDEN StripDeadPrototypesPass : public ModulePass { +public: + static char ID; // Pass identification, replacement for typeid + StripDeadPrototypesPass() : ModulePass((intptr_t)&ID) { } + virtual bool runOnModule(Module &M); +}; + +char StripDeadPrototypesPass::ID = 0; +RegisterPass X("strip-dead-prototypes", + "Strip Unused Function Prototypes"); + +} // end anonymous namespace + +bool StripDeadPrototypesPass::runOnModule(Module &M) { + bool MadeChange = false; + + // Erase dead function prototypes. + for (Module::iterator I = M.begin(), E = M.end(); I != E; ) { + Function *F = I++; + // Function must be a prototype and unused. + if (F->isDeclaration() && F->use_empty()) { + F->eraseFromParent(); + ++NumDeadPrototypes; + MadeChange = true; + } + } + + // Erase dead global var prototypes. + for (Module::global_iterator I = M.global_begin(), E = M.global_end(); + I != E; ) { + GlobalVariable *GV = I++; + // Global must be a prototype and unused. + if (GV->isDeclaration() && GV->use_empty()) + GV->eraseFromParent(); + } + + // Return an indication of whether we changed anything or not. + return MadeChange; +} + +ModulePass *llvm::createStripDeadPrototypesPass() { + return new StripDeadPrototypesPass(); +} diff --git a/lib/Transforms/IPO/StripSymbols.cpp b/lib/Transforms/IPO/StripSymbols.cpp new file mode 100644 index 0000000..c8f8926 --- /dev/null +++ b/lib/Transforms/IPO/StripSymbols.cpp @@ -0,0 +1,206 @@ +//===- StripSymbols.cpp - Strip symbols and debug info from a module ------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements stripping symbols out of symbol tables. 
+// +// Specifically, this allows you to strip all of the symbols out of: +// * All functions in a module +// * All non-essential symbols in a module (all function symbols + all module +// scope symbols) +// * Debug information. +// +// Notice that: +// * This pass makes code much less readable, so it should only be used in +// situations where the 'strip' utility would be used (such as reducing +// code size, and making it harder to reverse engineer code). +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/ValueSymbolTable.h" +#include "llvm/TypeSymbolTable.h" +#include "llvm/Support/Compiler.h" +using namespace llvm; + +namespace { + class VISIBILITY_HIDDEN StripSymbols : public ModulePass { + bool OnlyDebugInfo; + public: + static char ID; // Pass identification, replacement for typeid + StripSymbols(bool ODI = false) + : ModulePass((intptr_t)&ID), OnlyDebugInfo(ODI) {} + + virtual bool runOnModule(Module &M); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + } + }; + + char StripSymbols::ID = 0; + RegisterPass X("strip", "Strip all symbols from a module"); +} + +ModulePass *llvm::createStripSymbolsPass(bool OnlyDebugInfo) { + return new StripSymbols(OnlyDebugInfo); +} + +static void RemoveDeadConstant(Constant *C) { + assert(C->use_empty() && "Constant is not dead!"); + std::vector Operands; + for (unsigned i = 0, e = C->getNumOperands(); i != e; ++i) + if (isa(C->getOperand(i)->getType()) && + C->getOperand(i)->hasOneUse()) + Operands.push_back(C->getOperand(i)); + if (GlobalVariable *GV = dyn_cast(C)) { + if (!GV->hasInternalLinkage()) return; // Don't delete non static globals. 
+ GV->eraseFromParent(); + } + else if (!isa(C)) + C->destroyConstant(); + + // If the constant referenced anything, see if we can delete it as well. + while (!Operands.empty()) { + RemoveDeadConstant(Operands.back()); + Operands.pop_back(); + } +} + +// Strip the symbol table of its names. +// +static void StripSymtab(ValueSymbolTable &ST) { + for (ValueSymbolTable::iterator VI = ST.begin(), VE = ST.end(); VI != VE; ) { + Value *V = VI->getValue(); + ++VI; + if (!isa(V) || cast(V)->hasInternalLinkage()) { + // Set name to "", removing from symbol table! + V->setName(""); + } + } +} + +// Strip the symbol table of its names. +static void StripTypeSymtab(TypeSymbolTable &ST) { + for (TypeSymbolTable::iterator TI = ST.begin(), E = ST.end(); TI != E; ) + ST.remove(TI++); +} + + + +bool StripSymbols::runOnModule(Module &M) { + // If we're not just stripping debug info, strip all symbols from the + // functions and the names from any internal globals. + if (!OnlyDebugInfo) { + for (Module::global_iterator I = M.global_begin(), E = M.global_end(); + I != E; ++I) + if (I->hasInternalLinkage()) + I->setName(""); // Internal symbols can't participate in linkage + + for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) { + if (I->hasInternalLinkage()) + I->setName(""); // Internal symbols can't participate in linkage + StripSymtab(I->getValueSymbolTable()); + } + + // Remove all names from types. + StripTypeSymtab(M.getTypeSymbolTable()); + } + + // Strip debug info in the module if it exists. To do this, we remove + // llvm.dbg.func.start, llvm.dbg.stoppoint, and llvm.dbg.region.end calls, and + // any globals they point to if now dead. 
+ Function *FuncStart = M.getFunction("llvm.dbg.func.start"); + Function *StopPoint = M.getFunction("llvm.dbg.stoppoint"); + Function *RegionStart = M.getFunction("llvm.dbg.region.start"); + Function *RegionEnd = M.getFunction("llvm.dbg.region.end"); + Function *Declare = M.getFunction("llvm.dbg.declare"); + if (!FuncStart && !StopPoint && !RegionStart && !RegionEnd && !Declare) + return true; + + std::vector DeadGlobals; + + // Remove all of the calls to the debugger intrinsics, and remove them from + // the module. + if (FuncStart) { + while (!FuncStart->use_empty()) { + CallInst *CI = cast(FuncStart->use_back()); + Value *Arg = CI->getOperand(1); + assert(CI->use_empty() && "llvm.dbg intrinsic should have void result"); + CI->eraseFromParent(); + if (Arg->use_empty()) + if (GlobalVariable *GV = dyn_cast(Arg)) + DeadGlobals.push_back(GV); + } + FuncStart->eraseFromParent(); + } + if (StopPoint) { + while (!StopPoint->use_empty()) { + CallInst *CI = cast(StopPoint->use_back()); + Value *Arg = CI->getOperand(3); + assert(CI->use_empty() && "llvm.dbg intrinsic should have void result"); + CI->eraseFromParent(); + if (Arg->use_empty()) + if (GlobalVariable *GV = dyn_cast(Arg)) + DeadGlobals.push_back(GV); + } + StopPoint->eraseFromParent(); + } + if (RegionStart) { + while (!RegionStart->use_empty()) { + CallInst *CI = cast(RegionStart->use_back()); + Value *Arg = CI->getOperand(1); + assert(CI->use_empty() && "llvm.dbg intrinsic should have void result"); + CI->eraseFromParent(); + if (Arg->use_empty()) + if (GlobalVariable *GV = dyn_cast(Arg)) + DeadGlobals.push_back(GV); + } + RegionStart->eraseFromParent(); + } + if (RegionEnd) { + while (!RegionEnd->use_empty()) { + CallInst *CI = cast(RegionEnd->use_back()); + Value *Arg = CI->getOperand(1); + assert(CI->use_empty() && "llvm.dbg intrinsic should have void result"); + CI->eraseFromParent(); + if (Arg->use_empty()) + if (GlobalVariable *GV = dyn_cast(Arg)) + DeadGlobals.push_back(GV); + } + 
RegionEnd->eraseFromParent(); + } + if (Declare) { + while (!Declare->use_empty()) { + CallInst *CI = cast(Declare->use_back()); + Value *Arg = CI->getOperand(2); + assert(CI->use_empty() && "llvm.dbg intrinsic should have void result"); + CI->eraseFromParent(); + if (Arg->use_empty()) + if (GlobalVariable *GV = dyn_cast(Arg)) + DeadGlobals.push_back(GV); + } + Declare->eraseFromParent(); + } + + // Finally, delete any internal globals that were only used by the debugger + // intrinsics. + while (!DeadGlobals.empty()) { + GlobalVariable *GV = DeadGlobals.back(); + DeadGlobals.pop_back(); + if (GV->hasInternalLinkage()) + RemoveDeadConstant(GV); + } + + return true; +} diff --git a/lib/Transforms/Instrumentation/BlockProfiling.cpp b/lib/Transforms/Instrumentation/BlockProfiling.cpp new file mode 100644 index 0000000..f772dd4 --- /dev/null +++ b/lib/Transforms/Instrumentation/BlockProfiling.cpp @@ -0,0 +1,126 @@ +//===- BlockProfiling.cpp - Insert counters for block profiling -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass instruments the specified program with counters for basic block or +// function profiling. This is the most basic form of profiling, which can tell +// which blocks are hot, but cannot reliably detect hot paths through the CFG. +// Block profiling counts the number of times each basic block executes, and +// function profiling counts the number of times each function is called. +// +// Note that this implementation is very naive. Control equivalent regions of +// the CFG should not require duplicate counters, but we do put duplicate +// counters in. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Streams.h" +#include "llvm/Transforms/Instrumentation.h" +#include "RSProfiling.h" +#include "ProfilingUtils.h" +using namespace llvm; + +namespace { + class VISIBILITY_HIDDEN FunctionProfiler : public RSProfilers_std { + public: + static char ID; + bool runOnModule(Module &M); + }; + + char FunctionProfiler::ID = 0; + + RegisterPass X("insert-function-profiling", + "Insert instrumentation for function profiling"); + RegisterAnalysisGroup XG(X); + +} + +ModulePass *llvm::createFunctionProfilerPass() { + return new FunctionProfiler(); +} + +bool FunctionProfiler::runOnModule(Module &M) { + Function *Main = M.getFunction("main"); + if (Main == 0) { + cerr << "WARNING: cannot insert function profiling into a module" + << " with no main function!\n"; + return false; // No main, no instrumentation! + } + + unsigned NumFunctions = 0; + for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) + if (!I->isDeclaration()) + ++NumFunctions; + + const Type *ATy = ArrayType::get(Type::Int32Ty, NumFunctions); + GlobalVariable *Counters = + new GlobalVariable(ATy, false, GlobalValue::InternalLinkage, + Constant::getNullValue(ATy), "FuncProfCounters", &M); + + // Instrument all of the functions... + unsigned i = 0; + for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) + if (!I->isDeclaration()) + // Insert counter at the start of the function + IncrementCounterInBlock(I->begin(), i++, Counters); + + // Add the initialization call to main. 
+ InsertProfilingInitCall(Main, "llvm_start_func_profiling", Counters); + return true; +} + + +namespace { + class BlockProfiler : public RSProfilers_std { + bool runOnModule(Module &M); + public: + static char ID; + }; + + char BlockProfiler::ID = 0; + RegisterPass Y("insert-block-profiling", + "Insert instrumentation for block profiling"); + RegisterAnalysisGroup YG(Y); +} + +ModulePass *llvm::createBlockProfilerPass() { return new BlockProfiler(); } + +bool BlockProfiler::runOnModule(Module &M) { + Function *Main = M.getFunction("main"); + if (Main == 0) { + cerr << "WARNING: cannot insert block profiling into a module" + << " with no main function!\n"; + return false; // No main, no instrumentation! + } + + unsigned NumBlocks = 0; + for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) + NumBlocks += I->size(); + + const Type *ATy = ArrayType::get(Type::Int32Ty, NumBlocks); + GlobalVariable *Counters = + new GlobalVariable(ATy, false, GlobalValue::InternalLinkage, + Constant::getNullValue(ATy), "BlockProfCounters", &M); + + // Instrument all of the blocks... + unsigned i = 0; + for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) + for (Function::iterator BB = I->begin(), E = I->end(); BB != E; ++BB) + // Insert counter at the start of the block + IncrementCounterInBlock(BB, i++, Counters); + + // Add the initialization call to main. + InsertProfilingInitCall(Main, "llvm_start_block_profiling", Counters); + return true; +} + diff --git a/lib/Transforms/Instrumentation/EdgeProfiling.cpp b/lib/Transforms/Instrumentation/EdgeProfiling.cpp new file mode 100644 index 0000000..360d2b7 --- /dev/null +++ b/lib/Transforms/Instrumentation/EdgeProfiling.cpp @@ -0,0 +1,101 @@ +//===- EdgeProfiling.cpp - Insert counters for edge profiling -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. 
See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass instruments the specified program with counters for edge profiling. +// Edge profiling can give a reasonable approximation of the hot paths through a +// program, and is used for a wide variety of program transformations. +// +// Note that this implementation is very naive. We insert a counter for *every* +// edge in the program, instead of using control flow information to prune the +// number of counters inserted. +// +//===----------------------------------------------------------------------===// + +#include "ProfilingUtils.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Streams.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Instrumentation.h" +#include +using namespace llvm; + +namespace { + class VISIBILITY_HIDDEN EdgeProfiler : public ModulePass { + bool runOnModule(Module &M); + public: + static char ID; // Pass identification, replacement for typeid + EdgeProfiler() : ModulePass((intptr_t)&ID) {} + }; + + char EdgeProfiler::ID = 0; + RegisterPass X("insert-edge-profiling", + "Insert instrumentation for edge profiling"); +} + +ModulePass *llvm::createEdgeProfilerPass() { return new EdgeProfiler(); } + +bool EdgeProfiler::runOnModule(Module &M) { + Function *Main = M.getFunction("main"); + if (Main == 0) { + cerr << "WARNING: cannot insert edge profiling into a module" + << " with no main function!\n"; + return false; // No main, no instrumentation! + } + + std::set BlocksToInstrument; + unsigned NumEdges = 0; + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { + // Keep track of which blocks need to be instrumented. 
We don't want to + // instrument blocks that are added as the result of breaking critical + // edges! + BlocksToInstrument.insert(BB); + NumEdges += BB->getTerminator()->getNumSuccessors(); + } + + const Type *ATy = ArrayType::get(Type::Int32Ty, NumEdges); + GlobalVariable *Counters = + new GlobalVariable(ATy, false, GlobalValue::InternalLinkage, + Constant::getNullValue(ATy), "EdgeProfCounters", &M); + + // Instrument all of the edges... + unsigned i = 0; + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) + if (BlocksToInstrument.count(BB)) { // Don't instrument inserted blocks + // Okay, we have to add a counter of each outgoing edge. If the + // outgoing edge is not critical don't split it, just insert the counter + // in the source or destination of the edge. + TerminatorInst *TI = BB->getTerminator(); + for (unsigned s = 0, e = TI->getNumSuccessors(); s != e; ++s) { + // If the edge is critical, split it. + SplitCriticalEdge(TI, s, this); + + // Okay, we are guaranteed that the edge is no longer critical. If we + // only have a single successor, insert the counter in this block, + // otherwise insert it in the successor block. + if (TI->getNumSuccessors() == 0) { + // Insert counter at the start of the block + IncrementCounterInBlock(BB, i++, Counters); + } else { + // Insert counter at the start of the block + IncrementCounterInBlock(TI->getSuccessor(s), i++, Counters); + } + } + } + + // Add the initialization call to main. 
+ InsertProfilingInitCall(Main, "llvm_start_edge_profiling", Counters); + return true; +} + diff --git a/lib/Transforms/Instrumentation/Makefile b/lib/Transforms/Instrumentation/Makefile new file mode 100644 index 0000000..bf5c3d3 --- /dev/null +++ b/lib/Transforms/Instrumentation/Makefile @@ -0,0 +1,15 @@ +##===- lib/Transforms/Instrumentation/Makefile -------------*- Makefile -*-===## +# +# The LLVM Compiler Infrastructure +# +# This file was developed by the LLVM research group and is distributed under +# the University of Illinois Open Source License. See LICENSE.TXT for details. +# +##===----------------------------------------------------------------------===## + +LEVEL = ../../.. +LIBRARYNAME = LLVMInstrumentation +BUILD_ARCHIVE = 1 + +include $(LEVEL)/Makefile.common + diff --git a/lib/Transforms/Instrumentation/ProfilingUtils.cpp b/lib/Transforms/Instrumentation/ProfilingUtils.cpp new file mode 100644 index 0000000..54ea803 --- /dev/null +++ b/lib/Transforms/Instrumentation/ProfilingUtils.cpp @@ -0,0 +1,119 @@ +//===- ProfilingUtils.cpp - Helper functions shared by profilers ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This files implements a few helper functions which are used by profile +// instrumentation code to instrument the code. This allows the profiler pass +// to worry about *what* to insert, and these functions take care of *how* to do +// it. 
+// +//===----------------------------------------------------------------------===// + +#include "ProfilingUtils.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" + +void llvm::InsertProfilingInitCall(Function *MainFn, const char *FnName, + GlobalValue *Array) { + const Type *ArgVTy = PointerType::get(PointerType::get(Type::Int8Ty)); + const PointerType *UIntPtr = PointerType::get(Type::Int32Ty); + Module &M = *MainFn->getParent(); + Constant *InitFn = M.getOrInsertFunction(FnName, Type::Int32Ty, Type::Int32Ty, + ArgVTy, UIntPtr, Type::Int32Ty, + (Type *)0); + + // This could force argc and argv into programs that wouldn't otherwise have + // them, but instead we just pass null values in. + std::vector Args(4); + Args[0] = Constant::getNullValue(Type::Int32Ty); + Args[1] = Constant::getNullValue(ArgVTy); + + // Skip over any allocas in the entry block. + BasicBlock *Entry = MainFn->begin(); + BasicBlock::iterator InsertPos = Entry->begin(); + while (isa(InsertPos)) ++InsertPos; + + std::vector GEPIndices(2, Constant::getNullValue(Type::Int32Ty)); + unsigned NumElements = 0; + if (Array) { + Args[2] = ConstantExpr::getGetElementPtr(Array, &GEPIndices[0], + GEPIndices.size()); + NumElements = + cast(Array->getType()->getElementType())->getNumElements(); + } else { + // If this profiling instrumentation doesn't have a constant array, just + // pass null. + Args[2] = ConstantPointerNull::get(UIntPtr); + } + Args[3] = ConstantInt::get(Type::Int32Ty, NumElements); + + Instruction *InitCall = new CallInst(InitFn, &Args[0], Args.size(), + "newargc", InsertPos); + + // If argc or argv are not available in main, just pass null values in. 
+ Function::arg_iterator AI; + switch (MainFn->arg_size()) { + default: + case 2: + AI = MainFn->arg_begin(); ++AI; + if (AI->getType() != ArgVTy) { + Instruction::CastOps opcode = CastInst::getCastOpcode(AI, false, ArgVTy, + false); + InitCall->setOperand(2, + CastInst::create(opcode, AI, ArgVTy, "argv.cast", InitCall)); + } else { + InitCall->setOperand(2, AI); + } + /* FALL THROUGH */ + + case 1: + AI = MainFn->arg_begin(); + // If the program looked at argc, have it look at the return value of the + // init call instead. + if (AI->getType() != Type::Int32Ty) { + Instruction::CastOps opcode; + if (!AI->use_empty()) { + opcode = CastInst::getCastOpcode(InitCall, true, AI->getType(), true); + AI->replaceAllUsesWith( + CastInst::create(opcode, InitCall, AI->getType(), "", InsertPos)); + } + opcode = CastInst::getCastOpcode(AI, true, Type::Int32Ty, true); + InitCall->setOperand(1, + CastInst::create(opcode, AI, Type::Int32Ty, "argc.cast", InitCall)); + } else { + AI->replaceAllUsesWith(InitCall); + InitCall->setOperand(1, AI); + } + + case 0: break; + } +} + +void llvm::IncrementCounterInBlock(BasicBlock *BB, unsigned CounterNum, + GlobalValue *CounterArray) { + // Insert the increment after any alloca or PHI instructions... + BasicBlock::iterator InsertPos = BB->begin(); + while (isa(InsertPos) || isa(InsertPos)) + ++InsertPos; + + // Create the getelementptr constant expression + std::vector Indices(2); + Indices[0] = Constant::getNullValue(Type::Int32Ty); + Indices[1] = ConstantInt::get(Type::Int32Ty, CounterNum); + Constant *ElementPtr = + ConstantExpr::getGetElementPtr(CounterArray, &Indices[0], Indices.size()); + + // Load, increment and store the value back. 
+ Value *OldVal = new LoadInst(ElementPtr, "OldFuncCounter", InsertPos); + Value *NewVal = BinaryOperator::create(Instruction::Add, OldVal, + ConstantInt::get(Type::Int32Ty, 1), + "NewFuncCounter", InsertPos); + new StoreInst(NewVal, ElementPtr, InsertPos); +} diff --git a/lib/Transforms/Instrumentation/ProfilingUtils.h b/lib/Transforms/Instrumentation/ProfilingUtils.h new file mode 100644 index 0000000..52c6e04 --- /dev/null +++ b/lib/Transforms/Instrumentation/ProfilingUtils.h @@ -0,0 +1,31 @@ +//===- ProfilingUtils.h - Helper functions shared by profilers --*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines a few helper functions which are used by profile +// instrumentation code to instrument the code. This allows the profiler pass +// to worry about *what* to insert, and these functions take care of *how* to do +// it. 
+// +//===----------------------------------------------------------------------===// + +#ifndef PROFILINGUTILS_H +#define PROFILINGUTILS_H + +namespace llvm { + class Function; + class GlobalValue; + class BasicBlock; + + void InsertProfilingInitCall(Function *MainFn, const char *FnName, + GlobalValue *Arr = 0); + void IncrementCounterInBlock(BasicBlock *BB, unsigned CounterNum, + GlobalValue *CounterArray); +} + +#endif diff --git a/lib/Transforms/Instrumentation/RSProfiling.cpp b/lib/Transforms/Instrumentation/RSProfiling.cpp new file mode 100644 index 0000000..3c7efb1 --- /dev/null +++ b/lib/Transforms/Instrumentation/RSProfiling.cpp @@ -0,0 +1,650 @@ +//===- RSProfiling.cpp - Various profiling using random sampling ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// These passes implement a random sampling based profiling. Different methods +// of choosing when to sample are supported, as well as different types of +// profiling. This is done as two passes. The first is a sequence of profiling +// passes which insert profiling into the program, and remember what they +// inserted. +// +// The second stage duplicates all instructions in a function, ignoring the +// profiling code, then connects the two versions togeather at the entry and at +// backedges. At each connection point a choice is made as to whether to jump +// to the profiled code (take a sample) or execute the unprofiled code. 
+// +// It is highly recommeneded that after this pass one runs mem2reg and adce +// (instcombine load-vn gdce dse also are good to run afterwards) +// +// This design is intended to make the profiling passes independent of the RS +// framework, but any profiling pass that implements the RSProfiling interface +// is compatible with the rs framework (and thus can be sampled) +// +// TODO: obviously the block and function profiling are almost identical to the +// existing ones, so they can be unified (esp since these passes are valid +// without the rs framework). +// TODO: Fix choice code so that frequency is not hard coded +// +//===----------------------------------------------------------------------===// + +#include "llvm/Pass.h" +#include "llvm/Module.h" +#include "llvm/Instructions.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Instrumentation.h" +#include "RSProfiling.h" +#include +#include +#include +#include +using namespace llvm; + +namespace { + enum RandomMeth { + GBV, GBVO, HOSTCC + }; + + cl::opt RandomMethod("profile-randomness", + cl::desc("How to randomly choose to profile:"), + cl::values( + clEnumValN(GBV, "global", "global counter"), + clEnumValN(GBVO, "ra_global", + "register allocated global counter"), + clEnumValN(HOSTCC, "rdcc", "cycle counter"), + clEnumValEnd)); + + /// NullProfilerRS - The basic profiler that does nothing. It is the default + /// profiler and thus terminates RSProfiler chains. 
It is useful for + /// measuring framework overhead + class VISIBILITY_HIDDEN NullProfilerRS : public RSProfilers { + public: + static char ID; // Pass identification, replacement for typeid + bool isProfiling(Value* v) { + return false; + } + bool runOnModule(Module &M) { + return false; + } + void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + } + }; + + static RegisterAnalysisGroup A("Profiling passes"); + static RegisterPass NP("insert-null-profiling-rs", + "Measure profiling framework overhead"); + static RegisterAnalysisGroup NPT(NP); + + /// Chooser - Something that chooses when to make a sample of the profiled code + class VISIBILITY_HIDDEN Chooser { + public: + /// ProcessChoicePoint - is called for each basic block inserted to choose + /// between normal and sample code + virtual void ProcessChoicePoint(BasicBlock*) = 0; + /// PrepFunction - is called once per function before other work is done. + /// This gives the opertunity to insert new allocas and such. + virtual void PrepFunction(Function*) = 0; + virtual ~Chooser() {} + }; + + //Things that implement sampling policies + //A global value that is read-mod-stored to choose when to sample. 
+ //A sample is taken when the global counter hits 0 + class VISIBILITY_HIDDEN GlobalRandomCounter : public Chooser { + GlobalVariable* Counter; + Value* ResetValue; + const Type* T; + public: + GlobalRandomCounter(Module& M, const Type* t, uint64_t resetval); + virtual ~GlobalRandomCounter(); + virtual void PrepFunction(Function* F); + virtual void ProcessChoicePoint(BasicBlock* bb); + }; + + //Same is GRC, but allow register allocation of the global counter + class VISIBILITY_HIDDEN GlobalRandomCounterOpt : public Chooser { + GlobalVariable* Counter; + Value* ResetValue; + AllocaInst* AI; + const Type* T; + public: + GlobalRandomCounterOpt(Module& M, const Type* t, uint64_t resetval); + virtual ~GlobalRandomCounterOpt(); + virtual void PrepFunction(Function* F); + virtual void ProcessChoicePoint(BasicBlock* bb); + }; + + //Use the cycle counter intrinsic as a source of pseudo randomness when + //deciding when to sample. + class VISIBILITY_HIDDEN CycleCounter : public Chooser { + uint64_t rm; + Constant *F; + public: + CycleCounter(Module& m, uint64_t resetmask); + virtual ~CycleCounter(); + virtual void PrepFunction(Function* F); + virtual void ProcessChoicePoint(BasicBlock* bb); + }; + + /// ProfilerRS - Insert the random sampling framework + struct VISIBILITY_HIDDEN ProfilerRS : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + ProfilerRS() : FunctionPass((intptr_t)&ID) {} + + std::map TransCache; + std::set ChoicePoints; + Chooser* c; + + //Translate and duplicate values for the new profile free version of stuff + Value* Translate(Value* v); + //Duplicate an entire function (with out profiling) + void Duplicate(Function& F, RSProfilers& LI); + //Called once for each backedge, handle the insertion of choice points and + //the interconection of the two versions of the code + void ProcessBackEdge(BasicBlock* src, BasicBlock* dst, Function& F); + bool runOnFunction(Function& F); + bool doInitialization(Module &M); + virtual 
void getAnalysisUsage(AnalysisUsage &AU) const; + }; + + RegisterPass X("insert-rs-profiling-framework", + "Insert random sampling instrumentation framework"); +} + +char RSProfilers::ID = 0; +char NullProfilerRS::ID = 0; +char ProfilerRS::ID = 0; + +//Local utilities +static void ReplacePhiPred(BasicBlock* btarget, + BasicBlock* bold, BasicBlock* bnew); + +static void CollapsePhi(BasicBlock* btarget, BasicBlock* bsrc); + +template +static void recBackEdge(BasicBlock* bb, T& BackEdges, + std::map& color, + std::map& depth, + std::map& finish, + int& time); + +//find the back edges and where they go to +template +static void getBackEdges(Function& F, T& BackEdges); + + +/////////////////////////////////////// +// Methods of choosing when to profile +/////////////////////////////////////// + +GlobalRandomCounter::GlobalRandomCounter(Module& M, const Type* t, + uint64_t resetval) : T(t) { + ConstantInt* Init = ConstantInt::get(T, resetval); + ResetValue = Init; + Counter = new GlobalVariable(T, false, GlobalValue::InternalLinkage, + Init, "RandomSteeringCounter", &M); +} + +GlobalRandomCounter::~GlobalRandomCounter() {} + +void GlobalRandomCounter::PrepFunction(Function* F) {} + +void GlobalRandomCounter::ProcessChoicePoint(BasicBlock* bb) { + BranchInst* t = cast(bb->getTerminator()); + + //decrement counter + LoadInst* l = new LoadInst(Counter, "counter", t); + + ICmpInst* s = new ICmpInst(ICmpInst::ICMP_EQ, l, ConstantInt::get(T, 0), + "countercc", t); + + Value* nv = BinaryOperator::createSub(l, ConstantInt::get(T, 1), + "counternew", t); + new StoreInst(nv, Counter, t); + t->setCondition(s); + + //reset counter + BasicBlock* oldnext = t->getSuccessor(0); + BasicBlock* resetblock = new BasicBlock("reset", oldnext->getParent(), + oldnext); + TerminatorInst* t2 = new BranchInst(oldnext, resetblock); + t->setSuccessor(0, resetblock); + new StoreInst(ResetValue, Counter, t2); + ReplacePhiPred(oldnext, bb, resetblock); +} + 
+GlobalRandomCounterOpt::GlobalRandomCounterOpt(Module& M, const Type* t, + uint64_t resetval) + : AI(0), T(t) { + ConstantInt* Init = ConstantInt::get(T, resetval); + ResetValue = Init; + Counter = new GlobalVariable(T, false, GlobalValue::InternalLinkage, + Init, "RandomSteeringCounter", &M); +} + +GlobalRandomCounterOpt::~GlobalRandomCounterOpt() {} + +void GlobalRandomCounterOpt::PrepFunction(Function* F) { + //make a local temporary to cache the global + BasicBlock& bb = F->getEntryBlock(); + BasicBlock::iterator InsertPt = bb.begin(); + AI = new AllocaInst(T, 0, "localcounter", InsertPt); + LoadInst* l = new LoadInst(Counter, "counterload", InsertPt); + new StoreInst(l, AI, InsertPt); + + //modify all functions and return values to restore the local variable to/from + //the global variable + for(Function::iterator fib = F->begin(), fie = F->end(); + fib != fie; ++fib) + for(BasicBlock::iterator bib = fib->begin(), bie = fib->end(); + bib != bie; ++bib) + if (isa(bib)) { + LoadInst* l = new LoadInst(AI, "counter", bib); + new StoreInst(l, Counter, bib); + l = new LoadInst(Counter, "counter", ++bib); + new StoreInst(l, AI, bib--); + } else if (isa(bib)) { + LoadInst* l = new LoadInst(AI, "counter", bib); + new StoreInst(l, Counter, bib); + + BasicBlock* bb = cast(bib)->getNormalDest(); + BasicBlock::iterator i = bb->begin(); + while (isa(i)) + ++i; + l = new LoadInst(Counter, "counter", i); + + bb = cast(bib)->getUnwindDest(); + i = bb->begin(); + while (isa(i)) ++i; + l = new LoadInst(Counter, "counter", i); + new StoreInst(l, AI, i); + } else if (isa(&*bib) || isa(&*bib)) { + LoadInst* l = new LoadInst(AI, "counter", bib); + new StoreInst(l, Counter, bib); + } +} + +void GlobalRandomCounterOpt::ProcessChoicePoint(BasicBlock* bb) { + BranchInst* t = cast(bb->getTerminator()); + + //decrement counter + LoadInst* l = new LoadInst(AI, "counter", t); + + ICmpInst* s = new ICmpInst(ICmpInst::ICMP_EQ, l, ConstantInt::get(T, 0), + "countercc", t); + + Value* nv = 
BinaryOperator::createSub(l, ConstantInt::get(T, 1), + "counternew", t); + new StoreInst(nv, AI, t); + t->setCondition(s); + + //reset counter + BasicBlock* oldnext = t->getSuccessor(0); + BasicBlock* resetblock = new BasicBlock("reset", oldnext->getParent(), + oldnext); + TerminatorInst* t2 = new BranchInst(oldnext, resetblock); + t->setSuccessor(0, resetblock); + new StoreInst(ResetValue, AI, t2); + ReplacePhiPred(oldnext, bb, resetblock); +} + + +CycleCounter::CycleCounter(Module& m, uint64_t resetmask) : rm(resetmask) { + F = m.getOrInsertFunction("llvm.readcyclecounter", Type::Int64Ty, NULL); +} + +CycleCounter::~CycleCounter() {} + +void CycleCounter::PrepFunction(Function* F) {} + +void CycleCounter::ProcessChoicePoint(BasicBlock* bb) { + BranchInst* t = cast(bb->getTerminator()); + + CallInst* c = new CallInst(F, "rdcc", t); + BinaryOperator* b = + BinaryOperator::createAnd(c, ConstantInt::get(Type::Int64Ty, rm), + "mrdcc", t); + + ICmpInst *s = new ICmpInst(ICmpInst::ICMP_EQ, b, + ConstantInt::get(Type::Int64Ty, 0), + "mrdccc", t); + + t->setCondition(s); +} + +/////////////////////////////////////// +// Profiling: +/////////////////////////////////////// +bool RSProfilers_std::isProfiling(Value* v) { + if (profcode.find(v) != profcode.end()) + return true; + //else + RSProfilers& LI = getAnalysis(); + return LI.isProfiling(v); +} + +void RSProfilers_std::IncrementCounterInBlock(BasicBlock *BB, unsigned CounterNum, + GlobalValue *CounterArray) { + // Insert the increment after any alloca or PHI instructions... + BasicBlock::iterator InsertPos = BB->begin(); + while (isa(InsertPos) || isa(InsertPos)) + ++InsertPos; + + // Create the getelementptr constant expression + std::vector Indices(2); + Indices[0] = Constant::getNullValue(Type::Int32Ty); + Indices[1] = ConstantInt::get(Type::Int32Ty, CounterNum); + Constant *ElementPtr = ConstantExpr::getGetElementPtr(CounterArray, + &Indices[0], 2); + + // Load, increment and store the value back. 
+ Value *OldVal = new LoadInst(ElementPtr, "OldCounter", InsertPos); + profcode.insert(OldVal); + Value *NewVal = BinaryOperator::createAdd(OldVal, + ConstantInt::get(Type::Int32Ty, 1), + "NewCounter", InsertPos); + profcode.insert(NewVal); + profcode.insert(new StoreInst(NewVal, ElementPtr, InsertPos)); +} + +void RSProfilers_std::getAnalysisUsage(AnalysisUsage &AU) const { + //grab any outstanding profiler, or get the null one + AU.addRequired(); +} + +/////////////////////////////////////// +// RS Framework +/////////////////////////////////////// + +Value* ProfilerRS::Translate(Value* v) { + if(TransCache[v]) + return TransCache[v]; + + if (BasicBlock* bb = dyn_cast(v)) { + if (bb == &bb->getParent()->getEntryBlock()) + TransCache[bb] = bb; //don't translate entry block + else + TransCache[bb] = new BasicBlock("dup_" + bb->getName(), bb->getParent(), + NULL); + return TransCache[bb]; + } else if (Instruction* i = dyn_cast(v)) { + //we have already translated this + //do not translate entry block allocas + if(&i->getParent()->getParent()->getEntryBlock() == i->getParent()) { + TransCache[i] = i; + return i; + } else { + //translate this + Instruction* i2 = i->clone(); + if (i->hasName()) + i2->setName("dup_" + i->getName()); + TransCache[i] = i2; + //NumNewInst++; + for (unsigned x = 0; x < i2->getNumOperands(); ++x) + i2->setOperand(x, Translate(i2->getOperand(x))); + return i2; + } + } else if (isa(v) || isa(v) || isa(v)) { + TransCache[v] = v; + return v; + } + assert(0 && "Value not handled"); + return 0; +} + +void ProfilerRS::Duplicate(Function& F, RSProfilers& LI) +{ + //perform a breadth first search, building up a duplicate of the code + std::queue worklist; + std::set seen; + + //This loop ensures proper BB order, to help performance + for (Function::iterator fib = F.begin(), fie = F.end(); fib != fie; ++fib) + worklist.push(fib); + while (!worklist.empty()) { + Translate(worklist.front()); + worklist.pop(); + } + + //remember than reg2mem created a 
new entry block we don't want to duplicate + worklist.push(F.getEntryBlock().getTerminator()->getSuccessor(0)); + seen.insert(&F.getEntryBlock()); + + while (!worklist.empty()) { + BasicBlock* bb = worklist.front(); + worklist.pop(); + if(seen.find(bb) == seen.end()) { + BasicBlock* bbtarget = cast(Translate(bb)); + BasicBlock::InstListType& instlist = bbtarget->getInstList(); + for (BasicBlock::iterator iib = bb->begin(), iie = bb->end(); + iib != iie; ++iib) { + //NumOldInst++; + if (!LI.isProfiling(&*iib)) { + Instruction* i = cast(Translate(iib)); + instlist.insert(bbtarget->end(), i); + } + } + //updated search state; + seen.insert(bb); + TerminatorInst* ti = bb->getTerminator(); + for (unsigned x = 0; x < ti->getNumSuccessors(); ++x) { + BasicBlock* bbs = ti->getSuccessor(x); + if (seen.find(bbs) == seen.end()) { + worklist.push(bbs); + } + } + } + } +} + +void ProfilerRS::ProcessBackEdge(BasicBlock* src, BasicBlock* dst, Function& F) { + //given a backedge from B -> A, and translations A' and B', + //a: insert C and C' + //b: add branches in C to A and A' and in C' to A and A' + //c: mod terminators@B, replace A with C + //d: mod terminators@B', replace A' with C' + //e: mod phis@A for pred B to be pred C + // if multiple entries, simplify to one + //f: mod phis@A' for pred B' to be pred C' + // if multiple entries, simplify to one + //g: for all phis@A with pred C using x + // add in edge from C' using x' + // add in edge from C using x in A' + + //a: + Function::iterator BBN = src; ++BBN; + BasicBlock* bbC = new BasicBlock("choice", &F, BBN); + //ChoicePoints.insert(bbC); + BBN = cast(Translate(src)); + BasicBlock* bbCp = new BasicBlock("choice", &F, ++BBN); + ChoicePoints.insert(bbCp); + + //b: + new BranchInst(cast(Translate(dst)), bbC); + new BranchInst(dst, cast(Translate(dst)), + ConstantInt::get(Type::Int1Ty, true), bbCp); + //c: + { + TerminatorInst* iB = src->getTerminator(); + for (unsigned x = 0; x < iB->getNumSuccessors(); ++x) + if 
(iB->getSuccessor(x) == dst) + iB->setSuccessor(x, bbC); + } + //d: + { + TerminatorInst* iBp = cast(Translate(src->getTerminator())); + for (unsigned x = 0; x < iBp->getNumSuccessors(); ++x) + if (iBp->getSuccessor(x) == cast(Translate(dst))) + iBp->setSuccessor(x, bbCp); + } + //e: + ReplacePhiPred(dst, src, bbC); + //src could be a switch, in which case we are replacing several edges with one + //thus collapse those edges int the Phi + CollapsePhi(dst, bbC); + //f: + ReplacePhiPred(cast(Translate(dst)), + cast(Translate(src)),bbCp); + CollapsePhi(cast(Translate(dst)), bbCp); + //g: + for(BasicBlock::iterator ib = dst->begin(), ie = dst->end(); ib != ie; + ++ib) + if (PHINode* phi = dyn_cast(&*ib)) { + for(unsigned x = 0; x < phi->getNumIncomingValues(); ++x) + if(bbC == phi->getIncomingBlock(x)) { + phi->addIncoming(Translate(phi->getIncomingValue(x)), bbCp); + cast(Translate(phi))->addIncoming(phi->getIncomingValue(x), + bbC); + } + phi->removeIncomingValue(bbC); + } +} + +bool ProfilerRS::runOnFunction(Function& F) { + if (!F.isDeclaration()) { + std::set > BackEdges; + RSProfilers& LI = getAnalysis(); + + getBackEdges(F, BackEdges); + Duplicate(F, LI); + //assume that stuff worked. now connect the duplicated basic blocks + //with the originals in such a way as to preserve ssa. yuk! 
+ for (std::set >::iterator + ib = BackEdges.begin(), ie = BackEdges.end(); ib != ie; ++ib) + ProcessBackEdge(ib->first, ib->second, F); + + //oh, and add the edge from the reg2mem created entry node to the + //duplicated second node + TerminatorInst* T = F.getEntryBlock().getTerminator(); + ReplaceInstWithInst(T, new BranchInst(T->getSuccessor(0), + cast( + Translate(T->getSuccessor(0))), + ConstantInt::get(Type::Int1Ty, true))); + + //do whatever is needed now that the function is duplicated + c->PrepFunction(&F); + + //add entry node to choice points + ChoicePoints.insert(&F.getEntryBlock()); + + for (std::set::iterator + ii = ChoicePoints.begin(), ie = ChoicePoints.end(); ii != ie; ++ii) + c->ProcessChoicePoint(*ii); + + ChoicePoints.clear(); + TransCache.clear(); + + return true; + } + return false; +} + +bool ProfilerRS::doInitialization(Module &M) { + switch (RandomMethod) { + case GBV: + c = new GlobalRandomCounter(M, Type::Int32Ty, (1 << 14) - 1); + break; + case GBVO: + c = new GlobalRandomCounterOpt(M, Type::Int32Ty, (1 << 14) - 1); + break; + case HOSTCC: + c = new CycleCounter(M, (1 << 14) - 1); + break; + }; + return true; +} + +void ProfilerRS::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + AU.addRequiredID(DemoteRegisterToMemoryID); +} + +/////////////////////////////////////// +// Utilities: +/////////////////////////////////////// +static void ReplacePhiPred(BasicBlock* btarget, + BasicBlock* bold, BasicBlock* bnew) { + for(BasicBlock::iterator ib = btarget->begin(), ie = btarget->end(); + ib != ie; ++ib) + if (PHINode* phi = dyn_cast(&*ib)) { + for(unsigned x = 0; x < phi->getNumIncomingValues(); ++x) + if(bold == phi->getIncomingBlock(x)) + phi->setIncomingBlock(x, bnew); + } +} + +static void CollapsePhi(BasicBlock* btarget, BasicBlock* bsrc) { + for(BasicBlock::iterator ib = btarget->begin(), ie = btarget->end(); + ib != ie; ++ib) + if (PHINode* phi = dyn_cast(&*ib)) { + std::map counter; + for(unsigned i = 0; i < 
phi->getNumIncomingValues(); ) { + if (counter[phi->getIncomingBlock(i)]) { + assert(phi->getIncomingValue(i) == counter[phi->getIncomingBlock(i)]); + phi->removeIncomingValue(i, false); + } else { + counter[phi->getIncomingBlock(i)] = phi->getIncomingValue(i); + ++i; + } + } + } +} + +template +static void recBackEdge(BasicBlock* bb, T& BackEdges, + std::map& color, + std::map& depth, + std::map& finish, + int& time) +{ + color[bb] = 1; + ++time; + depth[bb] = time; + TerminatorInst* t= bb->getTerminator(); + for(unsigned i = 0; i < t->getNumSuccessors(); ++i) { + BasicBlock* bbnew = t->getSuccessor(i); + if (color[bbnew] == 0) + recBackEdge(bbnew, BackEdges, color, depth, finish, time); + else if (color[bbnew] == 1) { + BackEdges.insert(std::make_pair(bb, bbnew)); + //NumBackEdges++; + } + } + color[bb] = 2; + ++time; + finish[bb] = time; +} + + + +//find the back edges and where they go to +template +static void getBackEdges(Function& F, T& BackEdges) { + std::map color; + std::map depth; + std::map finish; + int time = 0; + recBackEdge(&F.getEntryBlock(), BackEdges, color, depth, finish, time); + DOUT << F.getName() << " " << BackEdges.size() << "\n"; +} + + +//Creation functions +ModulePass* llvm::createNullProfilerRSPass() { + return new NullProfilerRS(); +} + +FunctionPass* llvm::createRSProfilingPass() { + return new ProfilerRS(); +} diff --git a/lib/Transforms/Instrumentation/RSProfiling.h b/lib/Transforms/Instrumentation/RSProfiling.h new file mode 100644 index 0000000..b7c31f2 --- /dev/null +++ b/lib/Transforms/Instrumentation/RSProfiling.h @@ -0,0 +1,31 @@ +//===- RSProfiling.h - Various profiling using random sampling ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// See notes in RSProfiling.cpp +// +//===----------------------------------------------------------------------===// +#include "llvm/Transforms/RSProfiling.h" +#include + +namespace llvm { + /// RSProfilers_std - a simple support class for profilers that handles most + /// of the work of chaining and tracking inserted code. + struct RSProfilers_std : public RSProfilers { + static char ID; + std::set profcode; + // Lookup up values in profcode + virtual bool isProfiling(Value* v); + // handles required chaining + virtual void getAnalysisUsage(AnalysisUsage &AU) const; + // places counter updates in basic blocks and recordes added instructions in + // profcode + void IncrementCounterInBlock(BasicBlock *BB, unsigned CounterNum, + GlobalValue *CounterArray); + }; +} diff --git a/lib/Transforms/Makefile b/lib/Transforms/Makefile new file mode 100644 index 0000000..bc6cc99 --- /dev/null +++ b/lib/Transforms/Makefile @@ -0,0 +1,14 @@ +##===- lib/Transforms/Makefile -----------------------------*- Makefile -*-===## +# +# The LLVM Compiler Infrastructure +# +# This file was developed by the LLVM research group and is distributed under +# the University of Illinois Open Source License. See LICENSE.TXT for details. +# +##===----------------------------------------------------------------------===## + +LEVEL = ../.. +PARALLEL_DIRS = Utils Instrumentation Scalar IPO Hello + +include $(LEVEL)/Makefile.common + diff --git a/lib/Transforms/Scalar/ADCE.cpp b/lib/Transforms/Scalar/ADCE.cpp new file mode 100644 index 0000000..4968fc9 --- /dev/null +++ b/lib/Transforms/Scalar/ADCE.cpp @@ -0,0 +1,497 @@ +//===- ADCE.cpp - Code to perform aggressive dead code elimination --------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This file implements "aggressive" dead code elimination. ADCE is DCe where +// values are assumed to be dead until proven otherwise. This is similar to +// SCCP, except applied to the liveness of values. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "adce" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Support/CFG.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Compiler.h" +#include +using namespace llvm; + +STATISTIC(NumBlockRemoved, "Number of basic blocks removed"); +STATISTIC(NumInstRemoved , "Number of instructions removed"); +STATISTIC(NumCallRemoved , "Number of calls and invokes removed"); + +namespace { +//===----------------------------------------------------------------------===// +// ADCE Class +// +// This class does all of the work of Aggressive Dead Code Elimination. +// It's public interface consists of a constructor and a doADCE() method. 
+// +class VISIBILITY_HIDDEN ADCE : public FunctionPass { + Function *Func; // The function that we are working on + std::vector WorkList; // Instructions that just became live + std::set LiveSet; // The set of live instructions + + //===--------------------------------------------------------------------===// + // The public interface for this class + // +public: + static char ID; // Pass identification, replacement for typeid + ADCE() : FunctionPass((intptr_t)&ID) {} + + // Execute the Aggressive Dead Code Elimination Algorithm + // + virtual bool runOnFunction(Function &F) { + Func = &F; + bool Changed = doADCE(); + assert(WorkList.empty()); + LiveSet.clear(); + return Changed; + } + // getAnalysisUsage - We require post dominance frontiers (aka Control + // Dependence Graph) + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + // We require that all function nodes are unified, because otherwise code + // can be marked live that wouldn't necessarily be otherwise. + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + } + + + //===--------------------------------------------------------------------===// + // The implementation of this class + // +private: + // doADCE() - Run the Aggressive Dead Code Elimination algorithm, returning + // true if the function was modified. + // + bool doADCE(); + + void markBlockAlive(BasicBlock *BB); + + + // deleteDeadInstructionsInLiveBlock - Loop over all of the instructions in + // the specified basic block, deleting ones that are dead according to + // LiveSet. 
+ bool deleteDeadInstructionsInLiveBlock(BasicBlock *BB); + + TerminatorInst *convertToUnconditionalBranch(TerminatorInst *TI); + + inline void markInstructionLive(Instruction *I) { + if (!LiveSet.insert(I).second) return; + DOUT << "Insn Live: " << *I; + WorkList.push_back(I); + } + + inline void markTerminatorLive(const BasicBlock *BB) { + DOUT << "Terminator Live: " << *BB->getTerminator(); + markInstructionLive(const_cast(BB->getTerminator())); + } +}; + + char ADCE::ID = 0; + RegisterPass X("adce", "Aggressive Dead Code Elimination"); +} // End of anonymous namespace + +FunctionPass *llvm::createAggressiveDCEPass() { return new ADCE(); } + +void ADCE::markBlockAlive(BasicBlock *BB) { + // Mark the basic block as being newly ALIVE... and mark all branches that + // this block is control dependent on as being alive also... + // + PostDominanceFrontier &CDG = getAnalysis(); + + PostDominanceFrontier::const_iterator It = CDG.find(BB); + if (It != CDG.end()) { + // Get the blocks that this node is control dependent on... + const PostDominanceFrontier::DomSetType &CDB = It->second; + for (PostDominanceFrontier::DomSetType::const_iterator I = + CDB.begin(), E = CDB.end(); I != E; ++I) + markTerminatorLive(*I); // Mark all their terminators as live + } + + // If this basic block is live, and it ends in an unconditional branch, then + // the branch is alive as well... + if (BranchInst *BI = dyn_cast(BB->getTerminator())) + if (BI->isUnconditional()) + markTerminatorLive(BB); +} + +// deleteDeadInstructionsInLiveBlock - Loop over all of the instructions in the +// specified basic block, deleting ones that are dead according to LiveSet. +bool ADCE::deleteDeadInstructionsInLiveBlock(BasicBlock *BB) { + bool Changed = false; + for (BasicBlock::iterator II = BB->begin(), E = --BB->end(); II != E; ) { + Instruction *I = II++; + if (!LiveSet.count(I)) { // Is this instruction alive? + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + + // Nope... 
remove the instruction from it's basic block... + if (isa(I)) + ++NumCallRemoved; + else + ++NumInstRemoved; + BB->getInstList().erase(I); + Changed = true; + } + } + return Changed; +} + + +/// convertToUnconditionalBranch - Transform this conditional terminator +/// instruction into an unconditional branch because we don't care which of the +/// successors it goes to. This eliminate a use of the condition as well. +/// +TerminatorInst *ADCE::convertToUnconditionalBranch(TerminatorInst *TI) { + BranchInst *NB = new BranchInst(TI->getSuccessor(0), TI); + BasicBlock *BB = TI->getParent(); + + // Remove entries from PHI nodes to avoid confusing ourself later... + for (unsigned i = 1, e = TI->getNumSuccessors(); i != e; ++i) + TI->getSuccessor(i)->removePredecessor(BB); + + // Delete the old branch itself... + BB->getInstList().erase(TI); + return NB; +} + + +// doADCE() - Run the Aggressive Dead Code Elimination algorithm, returning +// true if the function was modified. +// +bool ADCE::doADCE() { + bool MadeChanges = false; + + AliasAnalysis &AA = getAnalysis(); + + + // Iterate over all invokes in the function, turning invokes into calls if + // they cannot throw. + for (Function::iterator BB = Func->begin(), E = Func->end(); BB != E; ++BB) + if (InvokeInst *II = dyn_cast(BB->getTerminator())) + if (Function *F = II->getCalledFunction()) + if (AA.onlyReadsMemory(F)) { + // The function cannot unwind. Convert it to a call with a branch + // after it to the normal destination. 
+ SmallVector Args(II->op_begin()+3, II->op_end()); + CallInst *NewCall = new CallInst(F, &Args[0], Args.size(), "", II); + NewCall->takeName(II); + NewCall->setCallingConv(II->getCallingConv()); + II->replaceAllUsesWith(NewCall); + new BranchInst(II->getNormalDest(), II); + + // Update PHI nodes in the unwind destination + II->getUnwindDest()->removePredecessor(BB); + BB->getInstList().erase(II); + + if (NewCall->use_empty()) { + BB->getInstList().erase(NewCall); + ++NumCallRemoved; + } + } + + // Iterate over all of the instructions in the function, eliminating trivially + // dead instructions, and marking instructions live that are known to be + // needed. Perform the walk in depth first order so that we avoid marking any + // instructions live in basic blocks that are unreachable. These blocks will + // be eliminated later, along with the instructions inside. + // + std::set ReachableBBs; + for (df_ext_iterator + BBI = df_ext_begin(&Func->front(), ReachableBBs), + BBE = df_ext_end(&Func->front(), ReachableBBs); BBI != BBE; ++BBI) { + BasicBlock *BB = *BBI; + for (BasicBlock::iterator II = BB->begin(), EI = BB->end(); II != EI; ) { + Instruction *I = II++; + if (CallInst *CI = dyn_cast(I)) { + Function *F = CI->getCalledFunction(); + if (F && AA.onlyReadsMemory(F)) { + if (CI->use_empty()) { + BB->getInstList().erase(CI); + ++NumCallRemoved; + } + } else { + markInstructionLive(I); + } + } else if (I->mayWriteToMemory() || isa(I) || + isa(I) || isa(I)) { + // FIXME: Unreachable instructions should not be marked intrinsically + // live here. + markInstructionLive(I); + } else if (isInstructionTriviallyDead(I)) { + // Remove the instruction from it's basic block... + BB->getInstList().erase(I); + ++NumInstRemoved; + } + } + } + + // Check to ensure we have an exit node for this CFG. If we don't, we won't + // have any post-dominance information, thus we cannot perform our + // transformations safely. 
+ // + PostDominatorTree &DT = getAnalysis(); + if (DT[&Func->getEntryBlock()] == 0) { + WorkList.clear(); + return MadeChanges; + } + + // Scan the function marking blocks without post-dominance information as + // live. Blocks without post-dominance information occur when there is an + // infinite loop in the program. Because the infinite loop could contain a + // function which unwinds, exits or has side-effects, we don't want to delete + // the infinite loop or those blocks leading up to it. + for (Function::iterator I = Func->begin(), E = Func->end(); I != E; ++I) + if (DT[I] == 0 && ReachableBBs.count(I)) + for (pred_iterator PI = pred_begin(I), E = pred_end(I); PI != E; ++PI) + markInstructionLive((*PI)->getTerminator()); + + DOUT << "Processing work list\n"; + + // AliveBlocks - Set of basic blocks that we know have instructions that are + // alive in them... + // + std::set AliveBlocks; + + // Process the work list of instructions that just became live... if they + // became live, then that means that all of their operands are necessary as + // well... make them live as well. + // + while (!WorkList.empty()) { + Instruction *I = WorkList.back(); // Get an instruction that became live... + WorkList.pop_back(); + + BasicBlock *BB = I->getParent(); + if (!ReachableBBs.count(BB)) continue; + if (AliveBlocks.insert(BB).second) // Basic block not alive yet. + markBlockAlive(BB); // Make it so now! + + // PHI nodes are a special case, because the incoming values are actually + // defined in the predecessor nodes of this block, meaning that the PHI + // makes the predecessors alive. + // + if (PHINode *PN = dyn_cast(I)) { + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + // If the incoming edge is clearly dead, it won't have control + // dependence information. Do not mark it live. 
+ BasicBlock *PredBB = PN->getIncomingBlock(i); + if (ReachableBBs.count(PredBB)) { + // FIXME: This should mark the control dependent edge as live, not + // necessarily the predecessor itself! + if (AliveBlocks.insert(PredBB).second) + markBlockAlive(PN->getIncomingBlock(i)); // Block is newly ALIVE! + if (Instruction *Op = dyn_cast(PN->getIncomingValue(i))) + markInstructionLive(Op); + } + } + } else { + // Loop over all of the operands of the live instruction, making sure that + // they are known to be alive as well. + // + for (unsigned op = 0, End = I->getNumOperands(); op != End; ++op) + if (Instruction *Operand = dyn_cast(I->getOperand(op))) + markInstructionLive(Operand); + } + } + + DEBUG( + DOUT << "Current Function: X = Live\n"; + for (Function::iterator I = Func->begin(), E = Func->end(); I != E; ++I){ + DOUT << I->getName() << ":\t" + << (AliveBlocks.count(I) ? "LIVE\n" : "DEAD\n"); + for (BasicBlock::iterator BI = I->begin(), BE = I->end(); BI != BE; ++BI){ + if (LiveSet.count(BI)) DOUT << "X "; + DOUT << *BI; + } + }); + + // All blocks being live is a common case, handle it specially. + if (AliveBlocks.size() == Func->size()) { // No dead blocks? + for (Function::iterator I = Func->begin(), E = Func->end(); I != E; ++I) { + // Loop over all of the instructions in the function deleting instructions + // to drop their references. + deleteDeadInstructionsInLiveBlock(I); + + // Check to make sure the terminator instruction is live. If it isn't, + // this means that the condition that it branches on (we know it is not an + // unconditional branch), is not needed to make the decision of where to + // go to, because all outgoing edges go to the same place. We must remove + // the use of the condition (because it's probably dead), so we convert + // the terminator to an unconditional branch. 
+ // + TerminatorInst *TI = I->getTerminator(); + if (!LiveSet.count(TI)) + convertToUnconditionalBranch(TI); + } + + return MadeChanges; + } + + + // If the entry node is dead, insert a new entry node to eliminate the entry + // node as a special case. + // + if (!AliveBlocks.count(&Func->front())) { + BasicBlock *NewEntry = new BasicBlock(); + new BranchInst(&Func->front(), NewEntry); + Func->getBasicBlockList().push_front(NewEntry); + AliveBlocks.insert(NewEntry); // This block is always alive! + LiveSet.insert(NewEntry->getTerminator()); // The branch is live + } + + // Loop over all of the alive blocks in the function. If any successor + // blocks are not alive, we adjust the outgoing branches to branch to the + // first live postdominator of the live block, adjusting any PHI nodes in + // the block to reflect this. + // + for (Function::iterator I = Func->begin(), E = Func->end(); I != E; ++I) + if (AliveBlocks.count(I)) { + BasicBlock *BB = I; + TerminatorInst *TI = BB->getTerminator(); + + // If the terminator instruction is alive, but the block it is contained + // in IS alive, this means that this terminator is a conditional branch on + // a condition that doesn't matter. Make it an unconditional branch to + // ONE of the successors. This has the side effect of dropping a use of + // the conditional value, which may also be dead. + if (!LiveSet.count(TI)) + TI = convertToUnconditionalBranch(TI); + + // Loop over all of the successors, looking for ones that are not alive. + // We cannot save the number of successors in the terminator instruction + // here because we may remove them if we don't have a postdominator. + // + for (unsigned i = 0; i != TI->getNumSuccessors(); ++i) + if (!AliveBlocks.count(TI->getSuccessor(i))) { + // Scan up the postdominator tree, looking for the first + // postdominator that is alive, and the last postdominator that is + // dead... 
+ // + DomTreeNode *LastNode = DT[TI->getSuccessor(i)]; + DomTreeNode *NextNode = 0; + + if (LastNode) { + NextNode = LastNode->getIDom(); + while (!AliveBlocks.count(NextNode->getBlock())) { + LastNode = NextNode; + NextNode = NextNode->getIDom(); + if (NextNode == 0) { + LastNode = 0; + break; + } + } + } + + // There is a special case here... if there IS no post-dominator for + // the block we have nowhere to point our branch to. Instead, convert + // it to a return. This can only happen if the code branched into an + // infinite loop. Note that this may not be desirable, because we + // _are_ altering the behavior of the code. This is a well known + // drawback of ADCE, so in the future if we choose to revisit the + // decision, this is where it should be. + // + if (LastNode == 0) { // No postdominator! + if (!isa(TI)) { + // Call RemoveSuccessor to transmogrify the terminator instruction + // to not contain the outgoing branch, or to create a new + // terminator if the form fundamentally changes (i.e., + // unconditional branch to return). Note that this will change a + // branch into an infinite loop into a return instruction! + // + RemoveSuccessor(TI, i); + + // RemoveSuccessor may replace TI... make sure we have a fresh + // pointer. + // + TI = BB->getTerminator(); + + // Rescan this successor... + --i; + } else { + + } + } else { + // Get the basic blocks that we need... + BasicBlock *LastDead = LastNode->getBlock(); + BasicBlock *NextAlive = NextNode->getBlock(); + + // Make the conditional branch now go to the next alive block... + TI->getSuccessor(i)->removePredecessor(BB); + TI->setSuccessor(i, NextAlive); + + // If there are PHI nodes in NextAlive, we need to add entries to + // the PHI nodes for the new incoming edge. The incoming values + // should be identical to the incoming values for LastDead. 
+ // + for (BasicBlock::iterator II = NextAlive->begin(); + isa(II); ++II) { + PHINode *PN = cast(II); + if (LiveSet.count(PN)) { // Only modify live phi nodes + // Get the incoming value for LastDead... + int OldIdx = PN->getBasicBlockIndex(LastDead); + assert(OldIdx != -1 &&"LastDead is not a pred of NextAlive!"); + Value *InVal = PN->getIncomingValue(OldIdx); + + // Add an incoming value for BB now... + PN->addIncoming(InVal, BB); + } + } + } + } + + // Now loop over all of the instructions in the basic block, deleting + // dead instructions. This is so that the next sweep over the program + // can safely delete dead instructions without other dead instructions + // still referring to them. + // + deleteDeadInstructionsInLiveBlock(BB); + } + + // Loop over all of the basic blocks in the function, dropping references of + // the dead basic blocks. We must do this after the previous step to avoid + // dropping references to PHIs which still have entries... + // + std::vector DeadBlocks; + for (Function::iterator BB = Func->begin(), E = Func->end(); BB != E; ++BB) + if (!AliveBlocks.count(BB)) { + // Remove PHI node entries for this block in live successor blocks. + for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; ++SI) + if (!SI->empty() && isa(SI->front()) && AliveBlocks.count(*SI)) + (*SI)->removePredecessor(BB); + + BB->dropAllReferences(); + MadeChanges = true; + DeadBlocks.push_back(BB); + } + + NumBlockRemoved += DeadBlocks.size(); + + // Now loop through all of the blocks and delete the dead ones. We can safely + // do this now because we know that there are no references to dead blocks + // (because they have dropped all of their references). 
+ for (std::vector::iterator I = DeadBlocks.begin(), + E = DeadBlocks.end(); I != E; ++I) + Func->getBasicBlockList().erase(*I); + + return MadeChanges; +} diff --git a/lib/Transforms/Scalar/BasicBlockPlacement.cpp b/lib/Transforms/Scalar/BasicBlockPlacement.cpp new file mode 100644 index 0000000..7521ea3 --- /dev/null +++ b/lib/Transforms/Scalar/BasicBlockPlacement.cpp @@ -0,0 +1,148 @@ +//===-- BasicBlockPlacement.cpp - Basic Block Code Layout optimization ----===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a very simple profile guided basic block placement +// algorithm. The idea is to put frequently executed blocks together at the +// start of the function, and hopefully increase the number of fall-through +// conditional branches. If there is no profile information for a particular +// function, this pass basically orders blocks in depth-first order +// +// The algorithm implemented here is basically "Algo1" from "Profile Guided Code +// Positioning" by Pettis and Hansen, except that it uses basic block counts +// instead of edge counts. This should be improved in many ways, but is very +// simple for now. +// +// Basically we "place" the entry block, then loop over all successors in a DFO, +// placing the most frequently executed successor until we run out of blocks. I +// told you this was _extremely_ simplistic. :) This is also much slower than it +// could be. When it becomes important, this pass will be rewritten to use a +// better algorithm, and then we can worry about efficiency. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "block-placement" +#include "llvm/Analysis/ProfileInfo.h" +#include "llvm/Function.h" +#include "llvm/Pass.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Transforms/Scalar.h" +#include +using namespace llvm; + +STATISTIC(NumMoved, "Number of basic blocks moved"); + +namespace { + struct VISIBILITY_HIDDEN BlockPlacement : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + BlockPlacement() : FunctionPass((intptr_t)&ID) {} + + virtual bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequired(); + //AU.addPreserved(); // Does this work? + } + private: + /// PI - The profile information that is guiding us. + /// + ProfileInfo *PI; + + /// NumMovedBlocks - Every time we move a block, increment this counter. + /// + unsigned NumMovedBlocks; + + /// PlacedBlocks - Every time we place a block, remember it so we don't get + /// into infinite loops. + std::set PlacedBlocks; + + /// InsertPos - This an iterator to the next place we want to insert a + /// block. + Function::iterator InsertPos; + + /// PlaceBlocks - Recursively place the specified blocks and any unplaced + /// successors. + void PlaceBlocks(BasicBlock *BB); + }; + + char BlockPlacement::ID = 0; + RegisterPass X("block-placement", + "Profile Guided Basic Block Placement"); +} + +FunctionPass *llvm::createBlockPlacementPass() { return new BlockPlacement(); } + +bool BlockPlacement::runOnFunction(Function &F) { + PI = &getAnalysis(); + + NumMovedBlocks = 0; + InsertPos = F.begin(); + + // Recursively place all blocks. + PlaceBlocks(F.begin()); + + PlacedBlocks.clear(); + NumMoved += NumMovedBlocks; + return NumMovedBlocks != 0; +} + + +/// PlaceBlocks - Recursively place the specified blocks and any unplaced +/// successors. 
+void BlockPlacement::PlaceBlocks(BasicBlock *BB) { + assert(!PlacedBlocks.count(BB) && "Already placed this block!"); + PlacedBlocks.insert(BB); + + // Place the specified block. + if (&*InsertPos != BB) { + // Use splice to move the block into the right place. This avoids having to + // remove the block from the function then readd it, which causes a bunch of + // symbol table traffic that is entirely pointless. + Function::BasicBlockListType &Blocks = BB->getParent()->getBasicBlockList(); + Blocks.splice(InsertPos, Blocks, BB); + + ++NumMovedBlocks; + } else { + // This block is already in the right place, we don't have to do anything. + ++InsertPos; + } + + // Keep placing successors until we run out of ones to place. Note that this + // loop is very inefficient (N^2) for blocks with many successors, like switch + // statements. FIXME! + while (1) { + // Okay, now place any unplaced successors. + succ_iterator SI = succ_begin(BB), E = succ_end(BB); + + // Scan for the first unplaced successor. + for (; SI != E && PlacedBlocks.count(*SI); ++SI) + /*empty*/; + if (SI == E) return; // No more successors to place. + + unsigned MaxExecutionCount = PI->getExecutionCount(*SI); + BasicBlock *MaxSuccessor = *SI; + + // Scan for more frequently executed successors + for (; SI != E; ++SI) + if (!PlacedBlocks.count(*SI)) { + unsigned Count = PI->getExecutionCount(*SI); + if (Count > MaxExecutionCount || + // Prefer to not disturb the code. + (Count == MaxExecutionCount && *SI == &*InsertPos)) { + MaxExecutionCount = Count; + MaxSuccessor = *SI; + } + } + + // Now that we picked the maximally executed successor, place it. 
+ PlaceBlocks(MaxSuccessor); + } +} diff --git a/lib/Transforms/Scalar/CodeGenPrepare.cpp b/lib/Transforms/Scalar/CodeGenPrepare.cpp new file mode 100644 index 0000000..2969df3 --- /dev/null +++ b/lib/Transforms/Scalar/CodeGenPrepare.cpp @@ -0,0 +1,988 @@ +//===- CodeGenPrepare.cpp - Prepare a function for code generation --------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by Chris Lattner and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass munges the code in the input function to better prepare it for +// SelectionDAG-based code generation. This works around limitations in it's +// basic-block-at-a-time approach. It should eventually be removed. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "codegenprepare" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/Target/TargetAsmInfo.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Target/TargetLowering.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +using namespace llvm; + +namespace { + class VISIBILITY_HIDDEN CodeGenPrepare : public FunctionPass { + /// TLI - Keep a pointer of a TargetLowering to consult for determining + /// transformation profitability. 
+ const TargetLowering *TLI; + public: + static char ID; // Pass identification, replacement for typeid + CodeGenPrepare(const TargetLowering *tli = 0) : FunctionPass((intptr_t)&ID), + TLI(tli) {} + bool runOnFunction(Function &F); + + private: + bool EliminateMostlyEmptyBlocks(Function &F); + bool CanMergeBlocks(const BasicBlock *BB, const BasicBlock *DestBB) const; + void EliminateMostlyEmptyBlock(BasicBlock *BB); + bool OptimizeBlock(BasicBlock &BB); + bool OptimizeLoadStoreInst(Instruction *I, Value *Addr, + const Type *AccessTy, + DenseMap &SunkAddrs); + }; +} + +char CodeGenPrepare::ID = 0; +static RegisterPass X("codegenprepare", + "Optimize for code generation"); + +FunctionPass *llvm::createCodeGenPreparePass(const TargetLowering *TLI) { + return new CodeGenPrepare(TLI); +} + + +bool CodeGenPrepare::runOnFunction(Function &F) { + bool EverMadeChange = false; + + // First pass, eliminate blocks that contain only PHI nodes and an + // unconditional branch. + EverMadeChange |= EliminateMostlyEmptyBlocks(F); + + bool MadeChange = true; + while (MadeChange) { + MadeChange = false; + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + MadeChange |= OptimizeBlock(*BB); + EverMadeChange |= MadeChange; + } + return EverMadeChange; +} + +/// EliminateMostlyEmptyBlocks - eliminate blocks that contain only PHI nodes +/// and an unconditional branch. Passes before isel (e.g. LSR/loopsimplify) +/// often split edges in ways that are non-optimal for isel. Start by +/// eliminating these blocks so we can split them the way we want them. +bool CodeGenPrepare::EliminateMostlyEmptyBlocks(Function &F) { + bool MadeChange = false; + // Note that this intentionally skips the entry block. + for (Function::iterator I = ++F.begin(), E = F.end(); I != E; ) { + BasicBlock *BB = I++; + + // If this block doesn't end with an uncond branch, ignore it. 
+ BranchInst *BI = dyn_cast(BB->getTerminator()); + if (!BI || !BI->isUnconditional()) + continue; + + // If the instruction before the branch isn't a phi node, then other stuff + // is happening here. + BasicBlock::iterator BBI = BI; + if (BBI != BB->begin()) { + --BBI; + if (!isa(BBI)) continue; + } + + // Do not break infinite loops. + BasicBlock *DestBB = BI->getSuccessor(0); + if (DestBB == BB) + continue; + + if (!CanMergeBlocks(BB, DestBB)) + continue; + + EliminateMostlyEmptyBlock(BB); + MadeChange = true; + } + return MadeChange; +} + +/// CanMergeBlocks - Return true if we can merge BB into DestBB if there is a +/// single uncond branch between them, and BB contains no other non-phi +/// instructions. +bool CodeGenPrepare::CanMergeBlocks(const BasicBlock *BB, + const BasicBlock *DestBB) const { + // We only want to eliminate blocks whose phi nodes are used by phi nodes in + // the successor. If there are more complex condition (e.g. preheaders), + // don't mess around with them. + BasicBlock::const_iterator BBI = BB->begin(); + while (const PHINode *PN = dyn_cast(BBI++)) { + for (Value::use_const_iterator UI = PN->use_begin(), E = PN->use_end(); + UI != E; ++UI) { + const Instruction *User = cast(*UI); + if (User->getParent() != DestBB || !isa(User)) + return false; + // If User is inside DestBB block and it is a PHINode then check + // incoming value. If incoming value is not from BB then this is + // a complex condition (e.g. preheaders) we want to avoid here. + if (User->getParent() == DestBB) { + if (const PHINode *UPN = dyn_cast(User)) + for (unsigned I = 0, E = UPN->getNumIncomingValues(); I != E; ++I) { + Instruction *Insn = dyn_cast(UPN->getIncomingValue(I)); + if (Insn && Insn->getParent() == BB && + Insn->getParent() != UPN->getIncomingBlock(I)) + return false; + } + } + } + } + + // If BB and DestBB contain any common predecessors, then the phi nodes in BB + // and DestBB may have conflicting incoming values for the block. 
If so, we + // can't merge the block. + const PHINode *DestBBPN = dyn_cast(DestBB->begin()); + if (!DestBBPN) return true; // no conflict. + + // Collect the preds of BB. + SmallPtrSet BBPreds; + if (const PHINode *BBPN = dyn_cast(BB->begin())) { + // It is faster to get preds from a PHI than with pred_iterator. + for (unsigned i = 0, e = BBPN->getNumIncomingValues(); i != e; ++i) + BBPreds.insert(BBPN->getIncomingBlock(i)); + } else { + BBPreds.insert(pred_begin(BB), pred_end(BB)); + } + + // Walk the preds of DestBB. + for (unsigned i = 0, e = DestBBPN->getNumIncomingValues(); i != e; ++i) { + BasicBlock *Pred = DestBBPN->getIncomingBlock(i); + if (BBPreds.count(Pred)) { // Common predecessor? + BBI = DestBB->begin(); + while (const PHINode *PN = dyn_cast(BBI++)) { + const Value *V1 = PN->getIncomingValueForBlock(Pred); + const Value *V2 = PN->getIncomingValueForBlock(BB); + + // If V2 is a phi node in BB, look up what the mapped value will be. + if (const PHINode *V2PN = dyn_cast(V2)) + if (V2PN->getParent() == BB) + V2 = V2PN->getIncomingValueForBlock(Pred); + + // If there is a conflict, bail out. + if (V1 != V2) return false; + } + } + } + + return true; +} + + +/// EliminateMostlyEmptyBlock - Eliminate a basic block that have only phi's and +/// an unconditional branch in it. +void CodeGenPrepare::EliminateMostlyEmptyBlock(BasicBlock *BB) { + BranchInst *BI = cast(BB->getTerminator()); + BasicBlock *DestBB = BI->getSuccessor(0); + + DOUT << "MERGING MOSTLY EMPTY BLOCKS - BEFORE:\n" << *BB << *DestBB; + + // If the destination block has a single pred, then this is a trivial edge, + // just collapse it. + if (DestBB->getSinglePredecessor()) { + // If DestBB has single-entry PHI nodes, fold them. + while (PHINode *PN = dyn_cast(DestBB->begin())) { + PN->replaceAllUsesWith(PN->getIncomingValue(0)); + PN->eraseFromParent(); + } + + // Splice all the PHI nodes from BB over to DestBB. 
+ DestBB->getInstList().splice(DestBB->begin(), BB->getInstList(), + BB->begin(), BI); + + // Anything that branched to BB now branches to DestBB. + BB->replaceAllUsesWith(DestBB); + + // Nuke BB. + BB->eraseFromParent(); + + DOUT << "AFTER:\n" << *DestBB << "\n\n\n"; + return; + } + + // Otherwise, we have multiple predecessors of BB. Update the PHIs in DestBB + // to handle the new incoming edges it is about to have. + PHINode *PN; + for (BasicBlock::iterator BBI = DestBB->begin(); + (PN = dyn_cast(BBI)); ++BBI) { + // Remove the incoming value for BB, and remember it. + Value *InVal = PN->removeIncomingValue(BB, false); + + // Two options: either the InVal is a phi node defined in BB or it is some + // value that dominates BB. + PHINode *InValPhi = dyn_cast(InVal); + if (InValPhi && InValPhi->getParent() == BB) { + // Add all of the input values of the input PHI as inputs of this phi. + for (unsigned i = 0, e = InValPhi->getNumIncomingValues(); i != e; ++i) + PN->addIncoming(InValPhi->getIncomingValue(i), + InValPhi->getIncomingBlock(i)); + } else { + // Otherwise, add one instance of the dominating value for each edge that + // we will be adding. + if (PHINode *BBPN = dyn_cast(BB->begin())) { + for (unsigned i = 0, e = BBPN->getNumIncomingValues(); i != e; ++i) + PN->addIncoming(InVal, BBPN->getIncomingBlock(i)); + } else { + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) + PN->addIncoming(InVal, *PI); + } + } + } + + // The PHIs are now updated, change everything that refers to BB to use + // DestBB and remove BB. + BB->replaceAllUsesWith(DestBB); + BB->eraseFromParent(); + + DOUT << "AFTER:\n" << *DestBB << "\n\n\n"; +} + + +/// SplitEdgeNicely - Split the critical edge from TI to it's specified +/// successor if it will improve codegen. We only do this if the successor has +/// phi nodes (otherwise critical edges are ok). 
If there is already another +/// predecessor of the succ that is empty (and thus has no phi nodes), use it +/// instead of introducing a new block. +static void SplitEdgeNicely(TerminatorInst *TI, unsigned SuccNum, Pass *P) { + BasicBlock *TIBB = TI->getParent(); + BasicBlock *Dest = TI->getSuccessor(SuccNum); + assert(isa(Dest->begin()) && + "This should only be called if Dest has a PHI!"); + + /// TIPHIValues - This array is lazily computed to determine the values of + /// PHIs in Dest that TI would provide. + std::vector TIPHIValues; + + // Check to see if Dest has any blocks that can be used as a split edge for + // this terminator. + for (pred_iterator PI = pred_begin(Dest), E = pred_end(Dest); PI != E; ++PI) { + BasicBlock *Pred = *PI; + // To be usable, the pred has to end with an uncond branch to the dest. + BranchInst *PredBr = dyn_cast(Pred->getTerminator()); + if (!PredBr || !PredBr->isUnconditional() || + // Must be empty other than the branch. + &Pred->front() != PredBr || + // Cannot be the entry block; its label does not get emitted. + Pred == &(Dest->getParent()->getEntryBlock())) + continue; + + // Finally, since we know that Dest has phi nodes in it, we have to make + // sure that jumping to Pred will have the same affect as going to Dest in + // terms of PHI values. + PHINode *PN; + unsigned PHINo = 0; + bool FoundMatch = true; + for (BasicBlock::iterator I = Dest->begin(); + (PN = dyn_cast(I)); ++I, ++PHINo) { + if (PHINo == TIPHIValues.size()) + TIPHIValues.push_back(PN->getIncomingValueForBlock(TIBB)); + + // If the PHI entry doesn't work, we can't use this pred. + if (TIPHIValues[PHINo] != PN->getIncomingValueForBlock(Pred)) { + FoundMatch = false; + break; + } + } + + // If we found a workable predecessor, change TI to branch to Succ. 
+ if (FoundMatch) { + Dest->removePredecessor(TIBB); + TI->setSuccessor(SuccNum, Pred); + return; + } + } + + SplitCriticalEdge(TI, SuccNum, P, true); +} + +/// OptimizeNoopCopyExpression - If the specified cast instruction is a noop +/// copy (e.g. it's casting from one pointer type to another, int->uint, or +/// int->sbyte on PPC), sink it into user blocks to reduce the number of virtual +/// registers that must be created and coalesced. +/// +/// Return true if any changes are made. +static bool OptimizeNoopCopyExpression(CastInst *CI, const TargetLowering &TLI){ + // If this is a noop copy, + MVT::ValueType SrcVT = TLI.getValueType(CI->getOperand(0)->getType()); + MVT::ValueType DstVT = TLI.getValueType(CI->getType()); + + // This is an fp<->int conversion? + if (MVT::isInteger(SrcVT) != MVT::isInteger(DstVT)) + return false; + + // If this is an extension, it will be a zero or sign extension, which + // isn't a noop. + if (SrcVT < DstVT) return false; + + // If these values will be promoted, find out what they will be promoted + // to. This helps us consider truncates on PPC as noop copies when they + // are. + if (TLI.getTypeAction(SrcVT) == TargetLowering::Promote) + SrcVT = TLI.getTypeToTransformTo(SrcVT); + if (TLI.getTypeAction(DstVT) == TargetLowering::Promote) + DstVT = TLI.getTypeToTransformTo(DstVT); + + // If, after promotion, these are the same types, this is a noop copy. + if (SrcVT != DstVT) + return false; + + BasicBlock *DefBB = CI->getParent(); + + /// InsertedCasts - Only insert a cast in each block once. + DenseMap InsertedCasts; + + bool MadeChange = false; + for (Value::use_iterator UI = CI->use_begin(), E = CI->use_end(); + UI != E; ) { + Use &TheUse = UI.getUse(); + Instruction *User = cast(*UI); + + // Figure out which BB this cast is used in. For PHI's this is the + // appropriate predecessor block. 
+ BasicBlock *UserBB = User->getParent(); + if (PHINode *PN = dyn_cast(User)) { + unsigned OpVal = UI.getOperandNo()/2; + UserBB = PN->getIncomingBlock(OpVal); + } + + // Preincrement use iterator so we don't invalidate it. + ++UI; + + // If this user is in the same block as the cast, don't change the cast. + if (UserBB == DefBB) continue; + + // If we have already inserted a cast into this block, use it. + CastInst *&InsertedCast = InsertedCasts[UserBB]; + + if (!InsertedCast) { + BasicBlock::iterator InsertPt = UserBB->begin(); + while (isa(InsertPt)) ++InsertPt; + + InsertedCast = + CastInst::create(CI->getOpcode(), CI->getOperand(0), CI->getType(), "", + InsertPt); + MadeChange = true; + } + + // Replace a use of the cast with a use of the new cast. + TheUse = InsertedCast; + } + + // If we removed all uses, nuke the cast. + if (CI->use_empty()) + CI->eraseFromParent(); + + return MadeChange; +} + +/// OptimizeCmpExpression - sink the given CmpInst into user blocks to reduce +/// the number of virtual registers that must be created and coalesced. This is +/// a clear win except on targets with multiple condition code registers (powerPC), +/// where it might lose; some adjustment may be wanted there. +/// +/// Return true if any changes are made. +static bool OptimizeCmpExpression(CmpInst *CI){ + + BasicBlock *DefBB = CI->getParent(); + + /// InsertedCmp - Only insert a cmp in each block once. + DenseMap InsertedCmps; + + bool MadeChange = false; + for (Value::use_iterator UI = CI->use_begin(), E = CI->use_end(); + UI != E; ) { + Use &TheUse = UI.getUse(); + Instruction *User = cast(*UI); + + // Preincrement use iterator so we don't invalidate it. + ++UI; + + // Don't bother for PHI nodes. + if (isa(User)) + continue; + + // Figure out which BB this cmp is used in. + BasicBlock *UserBB = User->getParent(); + + // If this user is in the same block as the cmp, don't change the cmp. 
+ if (UserBB == DefBB) continue; + + // If we have already inserted a cmp into this block, use it. + CmpInst *&InsertedCmp = InsertedCmps[UserBB]; + + if (!InsertedCmp) { + BasicBlock::iterator InsertPt = UserBB->begin(); + while (isa(InsertPt)) ++InsertPt; + + InsertedCmp = + CmpInst::create(CI->getOpcode(), CI->getPredicate(), CI->getOperand(0), + CI->getOperand(1), "", InsertPt); + MadeChange = true; + } + + // Replace a use of the cmp with a use of the new cmp. + TheUse = InsertedCmp; + } + + // If we removed all uses, nuke the cmp. + if (CI->use_empty()) + CI->eraseFromParent(); + + return MadeChange; +} + +/// EraseDeadInstructions - Erase any dead instructions +static void EraseDeadInstructions(Value *V) { + Instruction *I = dyn_cast(V); + if (!I || !I->use_empty()) return; + + SmallPtrSet Insts; + Insts.insert(I); + + while (!Insts.empty()) { + I = *Insts.begin(); + Insts.erase(I); + if (isInstructionTriviallyDead(I)) { + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (Instruction *U = dyn_cast(I->getOperand(i))) + Insts.insert(U); + I->eraseFromParent(); + } + } +} + + +/// ExtAddrMode - This is an extended version of TargetLowering::AddrMode which +/// holds actual Value*'s for register values. +struct ExtAddrMode : public TargetLowering::AddrMode { + Value *BaseReg; + Value *ScaledReg; + ExtAddrMode() : BaseReg(0), ScaledReg(0) {} + void dump() const; +}; + +static std::ostream &operator<<(std::ostream &OS, const ExtAddrMode &AM) { + bool NeedPlus = false; + OS << "["; + if (AM.BaseGV) + OS << (NeedPlus ? " + " : "") + << "GV:%" << AM.BaseGV->getName(), NeedPlus = true; + + if (AM.BaseOffs) + OS << (NeedPlus ? " + " : "") << AM.BaseOffs, NeedPlus = true; + + if (AM.BaseReg) + OS << (NeedPlus ? " + " : "") + << "Base:%" << AM.BaseReg->getName(), NeedPlus = true; + if (AM.Scale) + OS << (NeedPlus ? 
" + " : "") + << AM.Scale << "*%" << AM.ScaledReg->getName(), NeedPlus = true; + + return OS << "]"; +} + +void ExtAddrMode::dump() const { + cerr << *this << "\n"; +} + +static bool TryMatchingScaledValue(Value *ScaleReg, int64_t Scale, + const Type *AccessTy, ExtAddrMode &AddrMode, + SmallVector &AddrModeInsts, + const TargetLowering &TLI, unsigned Depth); + +/// FindMaximalLegalAddressingMode - If we can, try to merge the computation of +/// Addr into the specified addressing mode. If Addr can't be added to AddrMode +/// this returns false. This assumes that Addr is either a pointer type or +/// intptr_t for the target. +static bool FindMaximalLegalAddressingMode(Value *Addr, const Type *AccessTy, + ExtAddrMode &AddrMode, + SmallVector &AddrModeInsts, + const TargetLowering &TLI, + unsigned Depth) { + + // If this is a global variable, fold it into the addressing mode if possible. + if (GlobalValue *GV = dyn_cast(Addr)) { + if (AddrMode.BaseGV == 0) { + AddrMode.BaseGV = GV; + if (TLI.isLegalAddressingMode(AddrMode, AccessTy)) + return true; + AddrMode.BaseGV = 0; + } + } else if (ConstantInt *CI = dyn_cast(Addr)) { + AddrMode.BaseOffs += CI->getSExtValue(); + if (TLI.isLegalAddressingMode(AddrMode, AccessTy)) + return true; + AddrMode.BaseOffs -= CI->getSExtValue(); + } else if (isa(Addr)) { + return true; + } + + // Look through constant exprs and instructions. + unsigned Opcode = ~0U; + User *AddrInst = 0; + if (Instruction *I = dyn_cast(Addr)) { + Opcode = I->getOpcode(); + AddrInst = I; + } else if (ConstantExpr *CE = dyn_cast(Addr)) { + Opcode = CE->getOpcode(); + AddrInst = CE; + } + + // Limit recursion to avoid exponential behavior. + if (Depth == 5) { AddrInst = 0; Opcode = ~0U; } + + // If this is really an instruction, add it to our list of related + // instructions. 
+ if (Instruction *I = dyn_cast_or_null(AddrInst)) + AddrModeInsts.push_back(I); + + switch (Opcode) { + case Instruction::PtrToInt: + // PtrToInt is always a noop, as we know that the int type is pointer sized. + if (FindMaximalLegalAddressingMode(AddrInst->getOperand(0), AccessTy, + AddrMode, AddrModeInsts, TLI, Depth)) + return true; + break; + case Instruction::IntToPtr: + // This inttoptr is a no-op if the integer type is pointer sized. + if (TLI.getValueType(AddrInst->getOperand(0)->getType()) == + TLI.getPointerTy()) { + if (FindMaximalLegalAddressingMode(AddrInst->getOperand(0), AccessTy, + AddrMode, AddrModeInsts, TLI, Depth)) + return true; + } + break; + case Instruction::Add: { + // Check to see if we can merge in the RHS then the LHS. If so, we win. + ExtAddrMode BackupAddrMode = AddrMode; + unsigned OldSize = AddrModeInsts.size(); + if (FindMaximalLegalAddressingMode(AddrInst->getOperand(1), AccessTy, + AddrMode, AddrModeInsts, TLI, Depth+1) && + FindMaximalLegalAddressingMode(AddrInst->getOperand(0), AccessTy, + AddrMode, AddrModeInsts, TLI, Depth+1)) + return true; + + // Restore the old addr mode info. + AddrMode = BackupAddrMode; + AddrModeInsts.resize(OldSize); + + // Otherwise this was over-aggressive. Try merging in the LHS then the RHS. + if (FindMaximalLegalAddressingMode(AddrInst->getOperand(0), AccessTy, + AddrMode, AddrModeInsts, TLI, Depth+1) && + FindMaximalLegalAddressingMode(AddrInst->getOperand(1), AccessTy, + AddrMode, AddrModeInsts, TLI, Depth+1)) + return true; + + // Otherwise we definitely can't merge the ADD in. + AddrMode = BackupAddrMode; + AddrModeInsts.resize(OldSize); + break; + } + case Instruction::Or: { + ConstantInt *RHS = dyn_cast(AddrInst->getOperand(1)); + if (!RHS) break; + // TODO: We can handle "Or Val, Imm" iff this OR is equivalent to an ADD. + break; + } + case Instruction::Mul: + case Instruction::Shl: { + // Can only handle X*C and X << C, and can only handle this when the scale + // field is available. 
+ ConstantInt *RHS = dyn_cast(AddrInst->getOperand(1)); + if (!RHS) break; + int64_t Scale = RHS->getSExtValue(); + if (Opcode == Instruction::Shl) + Scale = 1 << Scale; + + if (TryMatchingScaledValue(AddrInst->getOperand(0), Scale, AccessTy, + AddrMode, AddrModeInsts, TLI, Depth)) + return true; + break; + } + case Instruction::GetElementPtr: { + // Scan the GEP. We check it if it contains constant offsets and at most + // one variable offset. + int VariableOperand = -1; + unsigned VariableScale = 0; + + int64_t ConstantOffset = 0; + const TargetData *TD = TLI.getTargetData(); + gep_type_iterator GTI = gep_type_begin(AddrInst); + for (unsigned i = 1, e = AddrInst->getNumOperands(); i != e; ++i, ++GTI) { + if (const StructType *STy = dyn_cast(*GTI)) { + const StructLayout *SL = TD->getStructLayout(STy); + unsigned Idx = + cast(AddrInst->getOperand(i))->getZExtValue(); + ConstantOffset += SL->getElementOffset(Idx); + } else { + uint64_t TypeSize = TD->getTypeSize(GTI.getIndexedType()); + if (ConstantInt *CI = dyn_cast(AddrInst->getOperand(i))) { + ConstantOffset += CI->getSExtValue()*TypeSize; + } else if (TypeSize) { // Scales of zero don't do anything. + // We only allow one variable index at the moment. + if (VariableOperand != -1) { + VariableOperand = -2; + break; + } + + // Remember the variable index. + VariableOperand = i; + VariableScale = TypeSize; + } + } + } + + // If the GEP had multiple variable indices, punt. + if (VariableOperand == -2) + break; + + // A common case is for the GEP to only do a constant offset. In this case, + // just add it to the disp field and check validity. + if (VariableOperand == -1) { + AddrMode.BaseOffs += ConstantOffset; + if (ConstantOffset == 0 || TLI.isLegalAddressingMode(AddrMode, AccessTy)){ + // Check to see if we can fold the base pointer in too. 
+ if (FindMaximalLegalAddressingMode(AddrInst->getOperand(0), AccessTy, + AddrMode, AddrModeInsts, TLI, + Depth+1)) + return true; + } + AddrMode.BaseOffs -= ConstantOffset; + } else { + // Check that this has no base reg yet. If so, we won't have a place to + // put the base of the GEP (assuming it is not a null ptr). + bool SetBaseReg = false; + if (AddrMode.HasBaseReg) { + if (!isa(AddrInst->getOperand(0))) + break; + } else { + AddrMode.HasBaseReg = true; + AddrMode.BaseReg = AddrInst->getOperand(0); + SetBaseReg = true; + } + + // See if the scale amount is valid for this target. + AddrMode.BaseOffs += ConstantOffset; + if (TryMatchingScaledValue(AddrInst->getOperand(VariableOperand), + VariableScale, AccessTy, AddrMode, + AddrModeInsts, TLI, Depth)) { + if (!SetBaseReg) return true; + + // If this match succeeded, we know that we can form an address with the + // GepBase as the basereg. See if we can match *more*. + AddrMode.HasBaseReg = false; + AddrMode.BaseReg = 0; + if (FindMaximalLegalAddressingMode(AddrInst->getOperand(0), AccessTy, + AddrMode, AddrModeInsts, TLI, + Depth+1)) + return true; + // Strange, shouldn't happen. Restore the base reg and succeed the easy + // way. + AddrMode.HasBaseReg = true; + AddrMode.BaseReg = AddrInst->getOperand(0); + return true; + } + + AddrMode.BaseOffs -= ConstantOffset; + if (SetBaseReg) { + AddrMode.HasBaseReg = false; + AddrMode.BaseReg = 0; + } + } + break; + } + } + + if (Instruction *I = dyn_cast_or_null(AddrInst)) { + assert(AddrModeInsts.back() == I && "Stack imbalance"); + AddrModeInsts.pop_back(); + } + + // Worse case, the target should support [reg] addressing modes. :) + if (!AddrMode.HasBaseReg) { + AddrMode.HasBaseReg = true; + // Still check for legality in case the target supports [imm] but not [i+r]. 
+ if (TLI.isLegalAddressingMode(AddrMode, AccessTy)) { + AddrMode.BaseReg = Addr; + return true; + } + AddrMode.HasBaseReg = false; + } + + // If the base register is already taken, see if we can do [r+r]. + if (AddrMode.Scale == 0) { + AddrMode.Scale = 1; + if (TLI.isLegalAddressingMode(AddrMode, AccessTy)) { + AddrMode.ScaledReg = Addr; + return true; + } + AddrMode.Scale = 0; + } + // Couldn't match. + return false; +} + +/// TryMatchingScaledValue - Try adding ScaleReg*Scale to the specified +/// addressing mode. Return true if this addr mode is legal for the target, +/// false if not. +static bool TryMatchingScaledValue(Value *ScaleReg, int64_t Scale, + const Type *AccessTy, ExtAddrMode &AddrMode, + SmallVector &AddrModeInsts, + const TargetLowering &TLI, unsigned Depth) { + // If we already have a scale of this value, we can add to it, otherwise, we + // need an available scale field. + if (AddrMode.Scale != 0 && AddrMode.ScaledReg != ScaleReg) + return false; + + ExtAddrMode InputAddrMode = AddrMode; + + // Add scale to turn X*4+X*3 -> X*7. This could also do things like + // [A+B + A*7] -> [B+A*8]. + AddrMode.Scale += Scale; + AddrMode.ScaledReg = ScaleReg; + + if (TLI.isLegalAddressingMode(AddrMode, AccessTy)) { + // Okay, we decided that we can add ScaleReg+Scale to AddrMode. Check now + // to see if ScaleReg is actually X+C. If so, we can turn this into adding + // X*Scale + C*Scale to addr mode. + BinaryOperator *BinOp = dyn_cast(ScaleReg); + if (BinOp && BinOp->getOpcode() == Instruction::Add && + isa(BinOp->getOperand(1)) && InputAddrMode.ScaledReg ==0) { + + InputAddrMode.Scale = Scale; + InputAddrMode.ScaledReg = BinOp->getOperand(0); + InputAddrMode.BaseOffs += + cast(BinOp->getOperand(1))->getSExtValue()*Scale; + if (TLI.isLegalAddressingMode(InputAddrMode, AccessTy)) { + AddrModeInsts.push_back(BinOp); + AddrMode = InputAddrMode; + return true; + } + } + + // Otherwise, not (x+c)*scale, just return what we have. 
+ return true; + } + + // Otherwise, back this attempt out. + AddrMode.Scale -= Scale; + if (AddrMode.Scale == 0) AddrMode.ScaledReg = 0; + + return false; +} + + +/// IsNonLocalValue - Return true if the specified values are defined in a +/// different basic block than BB. +static bool IsNonLocalValue(Value *V, BasicBlock *BB) { + if (Instruction *I = dyn_cast(V)) + return I->getParent() != BB; + return false; +} + +/// OptimizeLoadStoreInst - Load and Store Instructions have often have +/// addressing modes that can do significant amounts of computation. As such, +/// instruction selection will try to get the load or store to do as much +/// computation as possible for the program. The problem is that isel can only +/// see within a single block. As such, we sink as much legal addressing mode +/// stuff into the block as possible. +bool CodeGenPrepare::OptimizeLoadStoreInst(Instruction *LdStInst, Value *Addr, + const Type *AccessTy, + DenseMap &SunkAddrs) { + // Figure out what addressing mode will be built up for this operation. + SmallVector AddrModeInsts; + ExtAddrMode AddrMode; + bool Success = FindMaximalLegalAddressingMode(Addr, AccessTy, AddrMode, + AddrModeInsts, *TLI, 0); + Success = Success; assert(Success && "Couldn't select *anything*?"); + + // Check to see if any of the instructions supersumed by this addr mode are + // non-local to I's BB. + bool AnyNonLocal = false; + for (unsigned i = 0, e = AddrModeInsts.size(); i != e; ++i) { + if (IsNonLocalValue(AddrModeInsts[i], LdStInst->getParent())) { + AnyNonLocal = true; + break; + } + } + + // If all the instructions matched are already in this BB, don't do anything. + if (!AnyNonLocal) { + DEBUG(cerr << "CGP: Found local addrmode: " << AddrMode << "\n"); + return false; + } + + // Insert this computation right after this user. Since our caller is + // scanning from the top of the BB to the bottom, reuse of the expr are + // guaranteed to happen later. 
+ BasicBlock::iterator InsertPt = LdStInst; + + // Now that we determined the addressing expression we want to use and know + // that we have to sink it into this block. Check to see if we have already + // done this for some other load/store instr in this block. If so, reuse the + // computation. + Value *&SunkAddr = SunkAddrs[Addr]; + if (SunkAddr) { + DEBUG(cerr << "CGP: Reusing nonlocal addrmode: " << AddrMode << "\n"); + if (SunkAddr->getType() != Addr->getType()) + SunkAddr = new BitCastInst(SunkAddr, Addr->getType(), "tmp", InsertPt); + } else { + DEBUG(cerr << "CGP: SINKING nonlocal addrmode: " << AddrMode << "\n"); + const Type *IntPtrTy = TLI->getTargetData()->getIntPtrType(); + + Value *Result = 0; + // Start with the scale value. + if (AddrMode.Scale) { + Value *V = AddrMode.ScaledReg; + if (V->getType() == IntPtrTy) { + // done. + } else if (isa(V->getType())) { + V = new PtrToIntInst(V, IntPtrTy, "sunkaddr", InsertPt); + } else if (cast(IntPtrTy)->getBitWidth() < + cast(V->getType())->getBitWidth()) { + V = new TruncInst(V, IntPtrTy, "sunkaddr", InsertPt); + } else { + V = new SExtInst(V, IntPtrTy, "sunkaddr", InsertPt); + } + if (AddrMode.Scale != 1) + V = BinaryOperator::createMul(V, ConstantInt::get(IntPtrTy, + AddrMode.Scale), + "sunkaddr", InsertPt); + Result = V; + } + + // Add in the base register. + if (AddrMode.BaseReg) { + Value *V = AddrMode.BaseReg; + if (V->getType() != IntPtrTy) + V = new PtrToIntInst(V, IntPtrTy, "sunkaddr", InsertPt); + if (Result) + Result = BinaryOperator::createAdd(Result, V, "sunkaddr", InsertPt); + else + Result = V; + } + + // Add in the BaseGV if present. + if (AddrMode.BaseGV) { + Value *V = new PtrToIntInst(AddrMode.BaseGV, IntPtrTy, "sunkaddr", + InsertPt); + if (Result) + Result = BinaryOperator::createAdd(Result, V, "sunkaddr", InsertPt); + else + Result = V; + } + + // Add in the Base Offset if present. 
+ if (AddrMode.BaseOffs) { + Value *V = ConstantInt::get(IntPtrTy, AddrMode.BaseOffs); + if (Result) + Result = BinaryOperator::createAdd(Result, V, "sunkaddr", InsertPt); + else + Result = V; + } + + if (Result == 0) + SunkAddr = Constant::getNullValue(Addr->getType()); + else + SunkAddr = new IntToPtrInst(Result, Addr->getType(), "sunkaddr",InsertPt); + } + + LdStInst->replaceUsesOfWith(Addr, SunkAddr); + + if (Addr->use_empty()) + EraseDeadInstructions(Addr); + return true; +} + +// In this pass we look for GEP and cast instructions that are used +// across basic blocks and rewrite them to improve basic-block-at-a-time +// selection. +bool CodeGenPrepare::OptimizeBlock(BasicBlock &BB) { + bool MadeChange = false; + + // Split all critical edges where the dest block has a PHI and where the phi + // has shared immediate operands. + TerminatorInst *BBTI = BB.getTerminator(); + if (BBTI->getNumSuccessors() > 1) { + for (unsigned i = 0, e = BBTI->getNumSuccessors(); i != e; ++i) + if (isa(BBTI->getSuccessor(i)->begin()) && + isCriticalEdge(BBTI, i, true)) + SplitEdgeNicely(BBTI, i, this); + } + + + // Keep track of non-local addresses that have been sunk into this block. + // This allows us to avoid inserting duplicate code for blocks with multiple + // load/stores of the same address. + DenseMap SunkAddrs; + + for (BasicBlock::iterator BBI = BB.begin(), E = BB.end(); BBI != E; ) { + Instruction *I = BBI++; + + if (CastInst *CI = dyn_cast(I)) { + // If the source of the cast is a constant, then this should have + // already been constant folded. The only reason NOT to constant fold + // it is if something (e.g. LSR) was careful to place the constant + // evaluation in a block other than then one that uses it (e.g. to hoist + // the address of globals out of a loop). If this is the case, we don't + // want to forward-subst the cast. 
+ if (isa(CI->getOperand(0))) + continue; + + if (TLI) + MadeChange |= OptimizeNoopCopyExpression(CI, *TLI); + } else if (CmpInst *CI = dyn_cast(I)) { + MadeChange |= OptimizeCmpExpression(CI); + } else if (LoadInst *LI = dyn_cast(I)) { + if (TLI) + MadeChange |= OptimizeLoadStoreInst(I, I->getOperand(0), LI->getType(), + SunkAddrs); + } else if (StoreInst *SI = dyn_cast(I)) { + if (TLI) + MadeChange |= OptimizeLoadStoreInst(I, SI->getOperand(1), + SI->getOperand(0)->getType(), + SunkAddrs); + } else if (GetElementPtrInst *GEPI = dyn_cast(I)) { + if (GEPI->hasAllZeroIndices()) { + /// The GEP operand must be a pointer, so must its result -> BitCast + Instruction *NC = new BitCastInst(GEPI->getOperand(0), GEPI->getType(), + GEPI->getName(), GEPI); + GEPI->replaceAllUsesWith(NC); + GEPI->eraseFromParent(); + MadeChange = true; + BBI = NC; + } + } else if (CallInst *CI = dyn_cast(I)) { + // If we found an inline asm expession, and if the target knows how to + // lower it to normal LLVM code, do so now. + if (TLI && isa(CI->getCalledValue())) + if (const TargetAsmInfo *TAI = + TLI->getTargetMachine().getTargetAsmInfo()) { + if (TAI->ExpandInlineAsm(CI)) + BBI = BB.begin(); + } + } + } + + return MadeChange; +} + diff --git a/lib/Transforms/Scalar/CondPropagate.cpp b/lib/Transforms/Scalar/CondPropagate.cpp new file mode 100644 index 0000000..d4c583f --- /dev/null +++ b/lib/Transforms/Scalar/CondPropagate.cpp @@ -0,0 +1,219 @@ +//===-- CondPropagate.cpp - Propagate Conditional Expressions -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass propagates information about conditional expressions through the +// program, allowing it to eliminate conditional branches in some cases. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "condprop" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Constants.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/Type.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Streams.h" +using namespace llvm; + +STATISTIC(NumBrThread, "Number of CFG edges threaded through branches"); +STATISTIC(NumSwThread, "Number of CFG edges threaded through switches"); + +namespace { + struct VISIBILITY_HIDDEN CondProp : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + CondProp() : FunctionPass((intptr_t)&ID) {} + + virtual bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredID(BreakCriticalEdgesID); + //AU.addRequired(); + } + + private: + bool MadeChange; + void SimplifyBlock(BasicBlock *BB); + void SimplifyPredecessors(BranchInst *BI); + void SimplifyPredecessors(SwitchInst *SI); + void RevectorBlockTo(BasicBlock *FromBB, BasicBlock *ToBB); + }; + + char CondProp::ID = 0; + RegisterPass X("condprop", "Conditional Propagation"); +} + +FunctionPass *llvm::createCondPropagationPass() { + return new CondProp(); +} + +bool CondProp::runOnFunction(Function &F) { + bool EverMadeChange = false; + + // While we are simplifying blocks, keep iterating. + do { + MadeChange = false; + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + SimplifyBlock(BB); + EverMadeChange = MadeChange; + } while (MadeChange); + return EverMadeChange; +} + +void CondProp::SimplifyBlock(BasicBlock *BB) { + if (BranchInst *BI = dyn_cast(BB->getTerminator())) { + // If this is a conditional branch based on a phi node that is defined in + // this block, see if we can simplify predecessors of this block. 
+ if (BI->isConditional() && isa(BI->getCondition()) && + cast(BI->getCondition())->getParent() == BB) + SimplifyPredecessors(BI); + + } else if (SwitchInst *SI = dyn_cast(BB->getTerminator())) { + if (isa(SI->getCondition()) && + cast(SI->getCondition())->getParent() == BB) + SimplifyPredecessors(SI); + } + + // If possible, simplify the terminator of this block. + if (ConstantFoldTerminator(BB)) + MadeChange = true; + + // If this block ends with an unconditional branch and the only successor has + // only this block as a predecessor, merge the two blocks together. + if (BranchInst *BI = dyn_cast(BB->getTerminator())) + if (BI->isUnconditional() && BI->getSuccessor(0)->getSinglePredecessor() && + BB != BI->getSuccessor(0)) { + BasicBlock *Succ = BI->getSuccessor(0); + + // If Succ has any PHI nodes, they are all single-entry PHI's. + while (PHINode *PN = dyn_cast(Succ->begin())) { + assert(PN->getNumIncomingValues() == 1 && + "PHI doesn't match parent block"); + PN->replaceAllUsesWith(PN->getIncomingValue(0)); + PN->eraseFromParent(); + } + + // Remove BI. + BI->eraseFromParent(); + + // Move over all of the instructions. + BB->getInstList().splice(BB->end(), Succ->getInstList()); + + // Any phi nodes that had entries for Succ now have entries from BB. + Succ->replaceAllUsesWith(BB); + + // Succ is now dead, but we cannot delete it without potentially + // invalidating iterators elsewhere. Just insert an unreachable + // instruction in it. + new UnreachableInst(Succ); + MadeChange = true; + } +} + +// SimplifyPredecessors(branches) - We know that BI is a conditional branch +// based on a PHI node defined in this block. If the phi node contains constant +// operands, then the blocks corresponding to those operands can be modified to +// jump directly to the destination instead of going through this block. 
+void CondProp::SimplifyPredecessors(BranchInst *BI) { + // TODO: We currently only handle the most trival case, where the PHI node has + // one use (the branch), and is the only instruction besides the branch in the + // block. + PHINode *PN = cast(BI->getCondition()); + if (!PN->hasOneUse()) return; + + BasicBlock *BB = BI->getParent(); + if (&*BB->begin() != PN || &*next(BB->begin()) != BI) + return; + + // Ok, we have this really simple case, walk the PHI operands, looking for + // constants. Walk from the end to remove operands from the end when + // possible, and to avoid invalidating "i". + for (unsigned i = PN->getNumIncomingValues(); i != 0; --i) + if (ConstantInt *CB = dyn_cast(PN->getIncomingValue(i-1))) { + // If we have a constant, forward the edge from its current to its + // ultimate destination. + bool PHIGone = PN->getNumIncomingValues() == 2; + RevectorBlockTo(PN->getIncomingBlock(i-1), + BI->getSuccessor(CB->isZero())); + ++NumBrThread; + + // If there were two predecessors before this simplification, the PHI node + // will be deleted. Don't iterate through it the last time. + if (PHIGone) return; + } +} + +// SimplifyPredecessors(switch) - We know that SI is switch based on a PHI node +// defined in this block. If the phi node contains constant operands, then the +// blocks corresponding to those operands can be modified to jump directly to +// the destination instead of going through this block. +void CondProp::SimplifyPredecessors(SwitchInst *SI) { + // TODO: We currently only handle the most trival case, where the PHI node has + // one use (the branch), and is the only instruction besides the branch in the + // block. + PHINode *PN = cast(SI->getCondition()); + if (!PN->hasOneUse()) return; + + BasicBlock *BB = SI->getParent(); + if (&*BB->begin() != PN || &*next(BB->begin()) != SI) + return; + + bool RemovedPreds = false; + + // Ok, we have this really simple case, walk the PHI operands, looking for + // constants. 
Walk from the end to remove operands from the end when + // possible, and to avoid invalidating "i". + for (unsigned i = PN->getNumIncomingValues(); i != 0; --i) + if (ConstantInt *CI = dyn_cast(PN->getIncomingValue(i-1))) { + // If we have a constant, forward the edge from its current to its + // ultimate destination. + bool PHIGone = PN->getNumIncomingValues() == 2; + unsigned DestCase = SI->findCaseValue(CI); + RevectorBlockTo(PN->getIncomingBlock(i-1), + SI->getSuccessor(DestCase)); + ++NumSwThread; + RemovedPreds = true; + + // If there were two predecessors before this simplification, the PHI node + // will be deleted. Don't iterate through it the last time. + if (PHIGone) return; + } +} + + +// RevectorBlockTo - Revector the unconditional branch at the end of FromBB to +// the ToBB block, which is one of the successors of its current successor. +void CondProp::RevectorBlockTo(BasicBlock *FromBB, BasicBlock *ToBB) { + BranchInst *FromBr = cast(FromBB->getTerminator()); + assert(FromBr->isUnconditional() && "FromBB should end with uncond br!"); + + // Get the old block we are threading through. + BasicBlock *OldSucc = FromBr->getSuccessor(0); + + // OldSucc had multiple successors. If ToBB has multiple predecessors, then + // the edge between them would be critical, which we already took care of. + // If ToBB has single operand PHI node then take care of it here. + while (PHINode *PN = dyn_cast(ToBB->begin())) { + assert(PN->getNumIncomingValues() == 1 && "Critical Edge Found!"); + PN->replaceAllUsesWith(PN->getIncomingValue(0)); + PN->eraseFromParent(); + } + + // Update PHI nodes in OldSucc to know that FromBB no longer branches to it. + OldSucc->removePredecessor(FromBB); + + // Change FromBr to branch to the new destination. 
+ FromBr->setSuccessor(0, ToBB); + + MadeChange = true; +} diff --git a/lib/Transforms/Scalar/ConstantProp.cpp b/lib/Transforms/Scalar/ConstantProp.cpp new file mode 100644 index 0000000..3308e33 --- /dev/null +++ b/lib/Transforms/Scalar/ConstantProp.cpp @@ -0,0 +1,90 @@ +//===- ConstantProp.cpp - Code to perform Simple Constant Propagation -----===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements constant propagation and merging: +// +// Specifically, this: +// * Converts instructions like "add int 1, 2" into 3 +// +// Notice that: +// * This pass has a habit of making definitions be dead. It is a good idea +// to to run a DIE pass sometime after running this pass. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "constprop" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Constant.h" +#include "llvm/Instruction.h" +#include "llvm/Pass.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/ADT/Statistic.h" +#include +using namespace llvm; + +STATISTIC(NumInstKilled, "Number of instructions killed"); + +namespace { + struct VISIBILITY_HIDDEN ConstantPropagation : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + ConstantPropagation() : FunctionPass((intptr_t)&ID) {} + + bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + } + }; + + char ConstantPropagation::ID = 0; + RegisterPass X("constprop", + "Simple constant propagation"); +} + +FunctionPass *llvm::createConstantPropagationPass() { + return new ConstantPropagation(); +} + + +bool 
ConstantPropagation::runOnFunction(Function &F) { + // Initialize the worklist to all of the instructions ready to process... + std::set WorkList; + for(inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) { + WorkList.insert(&*i); + } + bool Changed = false; + + while (!WorkList.empty()) { + Instruction *I = *WorkList.begin(); + WorkList.erase(WorkList.begin()); // Get an element from the worklist... + + if (!I->use_empty()) // Don't muck with dead instructions... + if (Constant *C = ConstantFoldInstruction(I)) { + // Add all of the users of this instruction to the worklist, they might + // be constant propagatable now... + for (Value::use_iterator UI = I->use_begin(), UE = I->use_end(); + UI != UE; ++UI) + WorkList.insert(cast(*UI)); + + // Replace all of the uses of a variable with uses of the constant. + I->replaceAllUsesWith(C); + + // Remove the dead instruction. + WorkList.erase(I); + I->getParent()->getInstList().erase(I); + + // We made a change to the function... + Changed = true; + ++NumInstKilled; + } + } + return Changed; +} diff --git a/lib/Transforms/Scalar/CorrelatedExprs.cpp b/lib/Transforms/Scalar/CorrelatedExprs.cpp new file mode 100644 index 0000000..655f9eb --- /dev/null +++ b/lib/Transforms/Scalar/CorrelatedExprs.cpp @@ -0,0 +1,1487 @@ +//===- CorrelatedExprs.cpp - Pass to detect and eliminated c.e.'s ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Correlated Expression Elimination propagates information from conditional +// branches to blocks dominated by destinations of the branch. It propagates +// information from the condition check itself into the body of the branch, +// allowing transformations like these for example: +// +// if (i == 7) +// ... 
4*i; // constant propagation +// +// M = i+1; N = j+1; +// if (i == j) +// X = M-N; // = M-M == 0; +// +// This is called Correlated Expression Elimination because we eliminate or +// simplify expressions that are correlated with the direction of a branch. In +// this way we use static information to give us some information about the +// dynamic value of a variable. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "cee" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/Pass.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Type.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/ConstantRange.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/Statistic.h" +#include +using namespace llvm; + +STATISTIC(NumCmpRemoved, "Number of cmp instruction eliminated"); +STATISTIC(NumOperandsCann, "Number of operands canonicalized"); +STATISTIC(BranchRevectors, "Number of branches revectored"); + +namespace { + class ValueInfo; + class VISIBILITY_HIDDEN Relation { + Value *Val; // Relation to what value? + unsigned Rel; // SetCC or ICmp relation, or Add if no information + public: + Relation(Value *V) : Val(V), Rel(Instruction::Add) {} + bool operator<(const Relation &R) const { return Val < R.Val; } + Value *getValue() const { return Val; } + unsigned getRelation() const { return Rel; } + + // contradicts - Return true if the relationship specified by the operand + // contradicts already known information. + // + bool contradicts(unsigned Rel, const ValueInfo &VI) const; + + // incorporate - Incorporate information in the argument into this relation + // entry. 
This assumes that the information doesn't contradict itself. If + // any new information is gained, true is returned, otherwise false is + // returned to indicate that nothing was updated. + // + bool incorporate(unsigned Rel, ValueInfo &VI); + + // KnownResult - Whether or not this condition determines the result of a + // setcc or icmp in the program. False & True are intentionally 0 & 1 + // so we can convert to bool by casting after checking for unknown. + // + enum KnownResult { KnownFalse = 0, KnownTrue = 1, Unknown = 2 }; + + // getImpliedResult - If this relationship between two values implies that + // the specified relationship is true or false, return that. If we cannot + // determine the result required, return Unknown. + // + KnownResult getImpliedResult(unsigned Rel) const; + + // print - Output this relation to the specified stream + void print(std::ostream &OS) const; + void dump() const; + }; + + + // ValueInfo - One instance of this record exists for every value with + // relationships between other values. It keeps track of all of the + // relationships to other values in the program (specified with Relation) that + // are known to be valid in a region. + // + class VISIBILITY_HIDDEN ValueInfo { + // RelationShips - this value is know to have the specified relationships to + // other values. There can only be one entry per value, and this list is + // kept sorted by the Val field. + std::vector Relationships; + + // If information about this value is known or propagated from constant + // expressions, this range contains the possible values this value may hold. + ConstantRange Bounds; + + // If we find that this value is equal to another value that has a lower + // rank, this value is used as it's replacement. + // + Value *Replacement; + public: + ValueInfo(const Type *Ty) + : Bounds(Ty->isInteger() ? cast(Ty)->getBitWidth() : 32), + Replacement(0) {} + + // getBounds() - Return the constant bounds of the value... 
+ const ConstantRange &getBounds() const { return Bounds; } + ConstantRange &getBounds() { return Bounds; } + + const std::vector &getRelationships() { return Relationships; } + + // getReplacement - Return the value this value is to be replaced with if it + // exists, otherwise return null. + // + Value *getReplacement() const { return Replacement; } + + // setReplacement - Used by the replacement calculation pass to figure out + // what to replace this value with, if anything. + // + void setReplacement(Value *Repl) { Replacement = Repl; } + + // getRelation - return the relationship entry for the specified value. + // This can invalidate references to other Relations, so use it carefully. + // + Relation &getRelation(Value *V) { + // Binary search for V's entry... + std::vector::iterator I = + std::lower_bound(Relationships.begin(), Relationships.end(), + Relation(V)); + + // If we found the entry, return it... + if (I != Relationships.end() && I->getValue() == V) + return *I; + + // Insert and return the new relationship... + return *Relationships.insert(I, V); + } + + const Relation *requestRelation(Value *V) const { + // Binary search for V's entry... + std::vector::const_iterator I = + std::lower_bound(Relationships.begin(), Relationships.end(), + Relation(V)); + if (I != Relationships.end() && I->getValue() == V) + return &*I; + return 0; + } + + // print - Output information about this value relation... + void print(std::ostream &OS, Value *V) const; + void dump() const; + }; + + // RegionInfo - Keeps track of all of the value relationships for a region. A + // region is the are dominated by a basic block. RegionInfo's keep track of + // the RegionInfo for their dominator, because anything known in a dominator + // is known to be true in a dominated block as well. 
+ // + class VISIBILITY_HIDDEN RegionInfo { + BasicBlock *BB; + + // ValueMap - Tracks the ValueInformation known for this region + typedef std::map ValueMapTy; + ValueMapTy ValueMap; + public: + RegionInfo(BasicBlock *bb) : BB(bb) {} + + // getEntryBlock - Return the block that dominates all of the members of + // this region. + BasicBlock *getEntryBlock() const { return BB; } + + // empty - return true if this region has no information known about it. + bool empty() const { return ValueMap.empty(); } + + const RegionInfo &operator=(const RegionInfo &RI) { + ValueMap = RI.ValueMap; + return *this; + } + + // print - Output information about this region... + void print(std::ostream &OS) const; + void dump() const; + + // Allow external access. + typedef ValueMapTy::iterator iterator; + iterator begin() { return ValueMap.begin(); } + iterator end() { return ValueMap.end(); } + + ValueInfo &getValueInfo(Value *V) { + ValueMapTy::iterator I = ValueMap.lower_bound(V); + if (I != ValueMap.end() && I->first == V) return I->second; + return ValueMap.insert(I, std::make_pair(V, V->getType()))->second; + } + + const ValueInfo *requestValueInfo(Value *V) const { + ValueMapTy::const_iterator I = ValueMap.find(V); + if (I != ValueMap.end()) return &I->second; + return 0; + } + + /// removeValueInfo - Remove anything known about V from our records. This + /// works whether or not we know anything about V. 
+ /// + void removeValueInfo(Value *V) { + ValueMap.erase(V); + } + }; + + /// CEE - Correlated Expression Elimination + class VISIBILITY_HIDDEN CEE : public FunctionPass { + std::map RankMap; + std::map RegionInfoMap; + DominatorTree *DT; + public: + static char ID; // Pass identification, replacement for typeid + CEE() : FunctionPass((intptr_t)&ID) {} + + virtual bool runOnFunction(Function &F); + + // We don't modify the program, so we preserve all analyses + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + AU.addRequiredID(BreakCriticalEdgesID); + }; + + // print - Implement the standard print form to print out analysis + // information. + virtual void print(std::ostream &O, const Module *M) const; + + private: + RegionInfo &getRegionInfo(BasicBlock *BB) { + std::map::iterator I + = RegionInfoMap.lower_bound(BB); + if (I != RegionInfoMap.end() && I->first == BB) return I->second; + return RegionInfoMap.insert(I, std::make_pair(BB, BB))->second; + } + + void BuildRankMap(Function &F); + unsigned getRank(Value *V) const { + if (isa(V)) return 0; + std::map::const_iterator I = RankMap.find(V); + if (I != RankMap.end()) return I->second; + return 0; // Must be some other global thing + } + + bool TransformRegion(BasicBlock *BB, std::set &VisitedBlocks); + + bool ForwardCorrelatedEdgeDestination(TerminatorInst *TI, unsigned SuccNo, + RegionInfo &RI); + + void ForwardSuccessorTo(TerminatorInst *TI, unsigned Succ, BasicBlock *D, + RegionInfo &RI); + void ReplaceUsesOfValueInRegion(Value *Orig, Value *New, + BasicBlock *RegionDominator); + void CalculateRegionExitBlocks(BasicBlock *BB, BasicBlock *OldSucc, + std::vector &RegionExitBlocks); + void InsertRegionExitMerges(PHINode *NewPHI, Instruction *OldVal, + const std::vector &RegionExitBlocks); + + void PropagateBranchInfo(BranchInst *BI); + void PropagateSwitchInfo(SwitchInst *SI); + void PropagateEquality(Value *Op0, Value *Op1, RegionInfo &RI); + void PropagateRelation(unsigned Opcode, 
Value *Op0, + Value *Op1, RegionInfo &RI); + void UpdateUsersOfValue(Value *V, RegionInfo &RI); + void IncorporateInstruction(Instruction *Inst, RegionInfo &RI); + void ComputeReplacements(RegionInfo &RI); + + // getCmpResult - Given a icmp instruction, determine if the result is + // determined by facts we already know about the region under analysis. + // Return KnownTrue, KnownFalse, or UnKnown based on what we can determine. + Relation::KnownResult getCmpResult(CmpInst *ICI, const RegionInfo &RI); + + bool SimplifyBasicBlock(BasicBlock &BB, const RegionInfo &RI); + bool SimplifyInstruction(Instruction *Inst, const RegionInfo &RI); + }; + + char CEE::ID = 0; + RegisterPass X("cee", "Correlated Expression Elimination"); +} + +FunctionPass *llvm::createCorrelatedExpressionEliminationPass() { + return new CEE(); +} + + +bool CEE::runOnFunction(Function &F) { + // Build a rank map for the function... + BuildRankMap(F); + + // Traverse the dominator tree, computing information for each node in the + // tree. Note that our traversal will not even touch unreachable basic + // blocks. + DT = &getAnalysis(); + + std::set VisitedBlocks; + bool Changed = TransformRegion(&F.getEntryBlock(), VisitedBlocks); + + RegionInfoMap.clear(); + RankMap.clear(); + return Changed; +} + +// TransformRegion - Transform the region starting with BB according to the +// calculated region information for the block. Transforming the region +// involves analyzing any information this block provides to successors, +// propagating the information to successors, and finally transforming +// successors. +// +// This method processes the function in depth first order, which guarantees +// that we process the immediate dominator of a block before the block itself. +// Because we are passing information from immediate dominators down to +// dominatees, we obviously have to process the information source before the +// information consumer. 
+// +bool CEE::TransformRegion(BasicBlock *BB, std::set &VisitedBlocks){ + // Prevent infinite recursion... + if (VisitedBlocks.count(BB)) return false; + VisitedBlocks.insert(BB); + + // Get the computed region information for this block... + RegionInfo &RI = getRegionInfo(BB); + + // Compute the replacement information for this block... + ComputeReplacements(RI); + + // If debugging, print computed region information... + DEBUG(RI.print(*cerr.stream())); + + // Simplify the contents of this block... + bool Changed = SimplifyBasicBlock(*BB, RI); + + // Get the terminator of this basic block... + TerminatorInst *TI = BB->getTerminator(); + + // Loop over all of the blocks that this block is the immediate dominator for. + // Because all information known in this region is also known in all of the + // blocks that are dominated by this one, we can safely propagate the + // information down now. + // + DomTreeNode *BBDom = DT->getNode(BB); + if (!RI.empty()) { // Time opt: only propagate if we can change something + for (std::vector::iterator DI = BBDom->begin(), + E = BBDom->end(); DI != E; ++DI) { + BasicBlock *ChildBB = (*DI)->getBlock(); + assert(RegionInfoMap.find(ChildBB) == RegionInfoMap.end() && + "RegionInfo should be calculated in dominanace order!"); + getRegionInfo(ChildBB) = RI; + } + } + + // Now that all of our successors have information if they deserve it, + // propagate any information our terminator instruction finds to our + // successors. + if (BranchInst *BI = dyn_cast(TI)) { + if (BI->isConditional()) + PropagateBranchInfo(BI); + } else if (SwitchInst *SI = dyn_cast(TI)) { + PropagateSwitchInfo(SI); + } + + // If this is a branch to a block outside our region that simply performs + // another conditional branch, one whose outcome is known inside of this + // region, then vector this outgoing edge directly to the known destination. 
+ // + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + while (ForwardCorrelatedEdgeDestination(TI, i, RI)) { + ++BranchRevectors; + Changed = true; + } + + // Now that all of our successors have information, recursively process them. + for (std::vector::iterator DI = BBDom->begin(), + E = BBDom->end(); DI != E; ++DI) { + BasicBlock *ChildBB = (*DI)->getBlock(); + Changed |= TransformRegion(ChildBB, VisitedBlocks); + } + + return Changed; +} + +// isBlockSimpleEnoughForCheck to see if the block is simple enough for us to +// revector the conditional branch in the bottom of the block, do so now. +// +static bool isBlockSimpleEnough(BasicBlock *BB) { + assert(isa(BB->getTerminator())); + BranchInst *BI = cast(BB->getTerminator()); + assert(BI->isConditional()); + + // Check the common case first: empty block, or block with just a setcc. + if (BB->size() == 1 || + (BB->size() == 2 && &BB->front() == BI->getCondition() && + BI->getCondition()->hasOneUse())) + return true; + + // Check the more complex case now... + BasicBlock::iterator I = BB->begin(); + + // FIXME: This should be reenabled once the regression with SIM is fixed! +#if 0 + // PHI Nodes are ok, just skip over them... + while (isa(*I)) ++I; +#endif + + // Accept the setcc instruction... + if (&*I == BI->getCondition()) + ++I; + + // Nothing else is acceptable here yet. We must not revector... unless we are + // at the terminator instruction. + if (&*I == BI) + return true; + + return false; +} + + +bool CEE::ForwardCorrelatedEdgeDestination(TerminatorInst *TI, unsigned SuccNo, + RegionInfo &RI) { + // If this successor is a simple block not in the current region, which + // contains only a conditional branch, we decide if the outcome of the branch + // can be determined from information inside of the region. Instead of going + // to this block, we can instead go to the destination we know is the right + // target. + // + + // Check to see if we dominate the block. 
If so, this block will get the + // condition turned to a constant anyway. + // + //if (EF->dominates(RI.getEntryBlock(), BB)) + // return 0; + + BasicBlock *BB = TI->getParent(); + + // Get the destination block of this edge... + BasicBlock *OldSucc = TI->getSuccessor(SuccNo); + + // Make sure that the block ends with a conditional branch and is simple + // enough for use to be able to revector over. + BranchInst *BI = dyn_cast(OldSucc->getTerminator()); + if (BI == 0 || !BI->isConditional() || !isBlockSimpleEnough(OldSucc)) + return false; + + // We can only forward the branch over the block if the block ends with a + // cmp we can determine the outcome for. + // + // FIXME: we can make this more generic. Code below already handles more + // generic case. + if (!isa(BI->getCondition())) + return false; + + // Make a new RegionInfo structure so that we can simulate the effect of the + // PHI nodes in the block we are skipping over... + // + RegionInfo NewRI(RI); + + // Remove value information for all of the values we are simulating... to make + // sure we don't have any stale information. + for (BasicBlock::iterator I = OldSucc->begin(), E = OldSucc->end(); I!=E; ++I) + if (I->getType() != Type::VoidTy) + NewRI.removeValueInfo(I); + + // Put the newly discovered information into the RegionInfo... + for (BasicBlock::iterator I = OldSucc->begin(), E = OldSucc->end(); I!=E; ++I) + if (PHINode *PN = dyn_cast(I)) { + int OpNum = PN->getBasicBlockIndex(BB); + assert(OpNum != -1 && "PHI doesn't have incoming edge for predecessor!?"); + PropagateEquality(PN, PN->getIncomingValue(OpNum), NewRI); + } else if (CmpInst *CI = dyn_cast(I)) { + Relation::KnownResult Res = getCmpResult(CI, NewRI); + if (Res == Relation::Unknown) return false; + PropagateEquality(CI, ConstantInt::get(Type::Int1Ty, Res), NewRI); + } else { + assert(isa(*I) && "Unexpected instruction type!"); + } + + // Compute the facts implied by what we have discovered... 
+ ComputeReplacements(NewRI); + + ValueInfo &PredicateVI = NewRI.getValueInfo(BI->getCondition()); + if (PredicateVI.getReplacement() && + isa(PredicateVI.getReplacement()) && + !isa(PredicateVI.getReplacement())) { + ConstantInt *CB = cast(PredicateVI.getReplacement()); + + // Forward to the successor that corresponds to the branch we will take. + ForwardSuccessorTo(TI, SuccNo, + BI->getSuccessor(!CB->getZExtValue()), NewRI); + return true; + } + + return false; +} + +static Value *getReplacementOrValue(Value *V, RegionInfo &RI) { + if (const ValueInfo *VI = RI.requestValueInfo(V)) + if (Value *Repl = VI->getReplacement()) + return Repl; + return V; +} + +/// ForwardSuccessorTo - We have found that we can forward successor # 'SuccNo' +/// of Terminator 'TI' to the 'Dest' BasicBlock. This method performs the +/// mechanics of updating SSA information and revectoring the branch. +/// +void CEE::ForwardSuccessorTo(TerminatorInst *TI, unsigned SuccNo, + BasicBlock *Dest, RegionInfo &RI) { + // If there are any PHI nodes in the Dest BB, we must duplicate the entry + // in the PHI node for the old successor to now include an entry from the + // current basic block. + // + BasicBlock *OldSucc = TI->getSuccessor(SuccNo); + BasicBlock *BB = TI->getParent(); + + DOUT << "Forwarding branch in basic block %" << BB->getName() + << " from block %" << OldSucc->getName() << " to block %" + << Dest->getName() << "\n" + << "Before forwarding: " << *BB->getParent(); + + // Because we know that there cannot be critical edges in the flow graph, and + // that OldSucc has multiple outgoing edges, this means that Dest cannot have + // multiple incoming edges. + // +#ifndef NDEBUG + pred_iterator DPI = pred_begin(Dest); ++DPI; + assert(DPI == pred_end(Dest) && "Critical edge found!!"); +#endif + + // Loop over any PHI nodes in the destination, eliminating them, because they + // may only have one input. 
+ // + while (PHINode *PN = dyn_cast(&Dest->front())) { + assert(PN->getNumIncomingValues() == 1 && "Crit edge found!"); + // Eliminate the PHI node + PN->replaceAllUsesWith(PN->getIncomingValue(0)); + Dest->getInstList().erase(PN); + } + + // If there are values defined in the "OldSucc" basic block, we need to insert + // PHI nodes in the regions we are dealing with to emulate them. This can + // insert dead phi nodes, but it is more trouble to see if they are used than + // to just blindly insert them. + // + if (DT->dominates(OldSucc, Dest)) { + // RegionExitBlocks - Find all of the blocks that are not dominated by Dest, + // but have predecessors that are. Additionally, prune down the set to only + // include blocks that are dominated by OldSucc as well. + // + std::vector RegionExitBlocks; + CalculateRegionExitBlocks(Dest, OldSucc, RegionExitBlocks); + + for (BasicBlock::iterator I = OldSucc->begin(), E = OldSucc->end(); + I != E; ++I) + if (I->getType() != Type::VoidTy) { + // Create and insert the PHI node into the top of Dest. + PHINode *NewPN = new PHINode(I->getType(), I->getName()+".fw_merge", + Dest->begin()); + // There is definitely an edge from OldSucc... add the edge now + NewPN->addIncoming(I, OldSucc); + + // There is also an edge from BB now, add the edge with the calculated + // value from the RI. + NewPN->addIncoming(getReplacementOrValue(I, RI), BB); + + // Make everything in the Dest region use the new PHI node now... + ReplaceUsesOfValueInRegion(I, NewPN, Dest); + + // Make sure that exits out of the region dominated by NewPN get PHI + // nodes that merge the values as appropriate. + InsertRegionExitMerges(NewPN, I, RegionExitBlocks); + } + } + + // If there were PHI nodes in OldSucc, we need to remove the entry for this + // edge from the PHI node, and we need to replace any references to the PHI + // node with a new value. 
+ // + for (BasicBlock::iterator I = OldSucc->begin(); isa(I); ) { + PHINode *PN = cast(I); + + // Get the value flowing across the old edge and remove the PHI node entry + // for this edge: we are about to remove the edge! Don't remove the PHI + // node yet though if this is the last edge into it. + Value *EdgeValue = PN->removeIncomingValue(BB, false); + + // Make sure that anything that used to use PN now refers to EdgeValue + ReplaceUsesOfValueInRegion(PN, EdgeValue, Dest); + + // If there is only one value left coming into the PHI node, replace the PHI + // node itself with the one incoming value left. + // + if (PN->getNumIncomingValues() == 1) { + assert(PN->getNumIncomingValues() == 1); + PN->replaceAllUsesWith(PN->getIncomingValue(0)); + PN->getParent()->getInstList().erase(PN); + I = OldSucc->begin(); + } else if (PN->getNumIncomingValues() == 0) { // Nuke the PHI + // If we removed the last incoming value to this PHI, nuke the PHI node + // now. + PN->replaceAllUsesWith(Constant::getNullValue(PN->getType())); + PN->getParent()->getInstList().erase(PN); + I = OldSucc->begin(); + } else { + ++I; // Otherwise, move on to the next PHI node + } + } + + // Actually revector the branch now... + TI->setSuccessor(SuccNo, Dest); + + // If we just introduced a critical edge in the flow graph, make sure to break + // it right away... + SplitCriticalEdge(TI, SuccNo, this); + + // Make sure that we don't introduce critical edges from oldsucc now! + for (unsigned i = 0, e = OldSucc->getTerminator()->getNumSuccessors(); + i != e; ++i) + SplitCriticalEdge(OldSucc->getTerminator(), i, this); + + // Since we invalidated the CFG, recalculate the dominator set so that it is + // useful for later processing! + // FIXME: This is much worse than it really should be! + //EF->recalculate(); + + DOUT << "After forwarding: " << *BB->getParent(); +} + +/// ReplaceUsesOfValueInRegion - This method replaces all uses of Orig with uses +/// of New. 
It only affects instructions that are defined in basic blocks that +/// are dominated by Head. +/// +void CEE::ReplaceUsesOfValueInRegion(Value *Orig, Value *New, + BasicBlock *RegionDominator) { + assert(Orig != New && "Cannot replace value with itself"); + std::vector InstsToChange; + std::vector PHIsToChange; + InstsToChange.reserve(Orig->getNumUses()); + + // Loop over instructions adding them to InstsToChange vector, this allows us + // an easy way to avoid invalidating the use_iterator at a bad time. + for (Value::use_iterator I = Orig->use_begin(), E = Orig->use_end(); + I != E; ++I) + if (Instruction *User = dyn_cast(*I)) + if (DT->dominates(RegionDominator, User->getParent())) + InstsToChange.push_back(User); + else if (PHINode *PN = dyn_cast(User)) { + PHIsToChange.push_back(PN); + } + + // PHIsToChange contains PHI nodes that use Orig that do not live in blocks + // dominated by orig. If the block the value flows in from is dominated by + // RegionDominator, then we rewrite the PHI + for (unsigned i = 0, e = PHIsToChange.size(); i != e; ++i) { + PHINode *PN = PHIsToChange[i]; + for (unsigned j = 0, e = PN->getNumIncomingValues(); j != e; ++j) + if (PN->getIncomingValue(j) == Orig && + DT->dominates(RegionDominator, PN->getIncomingBlock(j))) + PN->setIncomingValue(j, New); + } + + // Loop over the InstsToChange list, replacing all uses of Orig with uses of + // New. This list contains all of the instructions in our region that use + // Orig. + for (unsigned i = 0, e = InstsToChange.size(); i != e; ++i) + if (PHINode *PN = dyn_cast(InstsToChange[i])) { + // PHINodes must be handled carefully. If the PHI node itself is in the + // region, we have to make sure to only do the replacement for incoming + // values that correspond to basic blocks in the region. 
+ for (unsigned j = 0, e = PN->getNumIncomingValues(); j != e; ++j) + if (PN->getIncomingValue(j) == Orig && + DT->dominates(RegionDominator, PN->getIncomingBlock(j))) + PN->setIncomingValue(j, New); + + } else { + InstsToChange[i]->replaceUsesOfWith(Orig, New); + } +} + +static void CalcRegionExitBlocks(BasicBlock *Header, BasicBlock *BB, + std::set &Visited, + DominatorTree &DT, + std::vector &RegionExitBlocks) { + if (Visited.count(BB)) return; + Visited.insert(BB); + + if (DT.dominates(Header, BB)) { // Block in the region, recursively traverse + for (succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) + CalcRegionExitBlocks(Header, *I, Visited, DT, RegionExitBlocks); + } else { + // Header does not dominate this block, but we have a predecessor that does + // dominate us. Add ourself to the list. + RegionExitBlocks.push_back(BB); + } +} + +/// CalculateRegionExitBlocks - Find all of the blocks that are not dominated by +/// BB, but have predecessors that are. Additionally, prune down the set to +/// only include blocks that are dominated by OldSucc as well. +/// +void CEE::CalculateRegionExitBlocks(BasicBlock *BB, BasicBlock *OldSucc, + std::vector &RegionExitBlocks){ + std::set Visited; // Don't infinite loop + + // Recursively calculate blocks we are interested in... + CalcRegionExitBlocks(BB, BB, Visited, *DT, RegionExitBlocks); + + // Filter out blocks that are not dominated by OldSucc... + for (unsigned i = 0; i != RegionExitBlocks.size(); ) { + if (DT->dominates(OldSucc, RegionExitBlocks[i])) + ++i; // Block is ok, keep it. + else { + // Move to end of list... 
+ std::swap(RegionExitBlocks[i], RegionExitBlocks.back()); + RegionExitBlocks.pop_back(); // Nuke the end + } + } +} + +void CEE::InsertRegionExitMerges(PHINode *BBVal, Instruction *OldVal, + const std::vector &RegionExitBlocks) { + assert(BBVal->getType() == OldVal->getType() && "Should be derived values!"); + BasicBlock *BB = BBVal->getParent(); + + // Loop over all of the blocks we have to place PHIs in, doing it. + for (unsigned i = 0, e = RegionExitBlocks.size(); i != e; ++i) { + BasicBlock *FBlock = RegionExitBlocks[i]; // Block on the frontier + + // Create the new PHI node + PHINode *NewPN = new PHINode(BBVal->getType(), + OldVal->getName()+".fw_frontier", + FBlock->begin()); + + // Add an incoming value for every predecessor of the block... + for (pred_iterator PI = pred_begin(FBlock), PE = pred_end(FBlock); + PI != PE; ++PI) { + // If the incoming edge is from the region dominated by BB, use BBVal, + // otherwise use OldVal. + NewPN->addIncoming(DT->dominates(BB, *PI) ? BBVal : OldVal, *PI); + } + + // Now make everyone dominated by this block use this new value! + ReplaceUsesOfValueInRegion(OldVal, NewPN, FBlock); + } +} + + + +// BuildRankMap - This method builds the rank map data structure which gives +// each instruction/value in the function a value based on how early it appears +// in the function. We give constants and globals rank 0, arguments are +// numbered starting at one, and instructions are numbered in reverse post-order +// from where the arguments leave off. This gives instructions in loops higher +// values than instructions not in loops. +// +void CEE::BuildRankMap(Function &F) { + unsigned Rank = 1; // Skip rank zero. + + // Number the arguments... + for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) + RankMap[I] = Rank++; + + // Number the instructions in reverse post order... 
+ ReversePostOrderTraversal RPOT(&F); + for (ReversePostOrderTraversal::rpo_iterator I = RPOT.begin(), + E = RPOT.end(); I != E; ++I) + for (BasicBlock::iterator BBI = (*I)->begin(), E = (*I)->end(); + BBI != E; ++BBI) + if (BBI->getType() != Type::VoidTy) + RankMap[BBI] = Rank++; +} + + +// PropagateBranchInfo - When this method is invoked, we need to propagate +// information derived from the branch condition into the true and false +// branches of BI. Since we know that there aren't any critical edges in the +// flow graph, this can proceed unconditionally. +// +void CEE::PropagateBranchInfo(BranchInst *BI) { + assert(BI->isConditional() && "Must be a conditional branch!"); + + // Propagate information into the true block... + // + PropagateEquality(BI->getCondition(), ConstantInt::getTrue(), + getRegionInfo(BI->getSuccessor(0))); + + // Propagate information into the false block... + // + PropagateEquality(BI->getCondition(), ConstantInt::getFalse(), + getRegionInfo(BI->getSuccessor(1))); +} + + +// PropagateSwitchInfo - We need to propagate the value tested by the +// switch statement through each case block. +// +void CEE::PropagateSwitchInfo(SwitchInst *SI) { + // Propagate information down each of our non-default case labels. We + // don't yet propagate information down the default label, because a + // potentially large number of inequality constraints provide less + // benefit per unit work than a single equality constraint. + // + Value *cond = SI->getCondition(); + for (unsigned i = 1; i < SI->getNumSuccessors(); ++i) + PropagateEquality(cond, SI->getSuccessorValue(i), + getRegionInfo(SI->getSuccessor(i))); +} + + +// PropagateEquality - If we discover that two values are equal to each other in +// a specified region, propagate this knowledge recursively. +// +void CEE::PropagateEquality(Value *Op0, Value *Op1, RegionInfo &RI) { + if (Op0 == Op1) return; // Gee whiz. Are these really equal each other? 
+ + if (isa(Op0)) // Make sure the constant is always Op1 + std::swap(Op0, Op1); + + // Make sure we don't already know these are equal, to avoid infinite loops... + ValueInfo &VI = RI.getValueInfo(Op0); + + // Get information about the known relationship between Op0 & Op1 + Relation &KnownRelation = VI.getRelation(Op1); + + // If we already know they're equal, don't reprocess... + if (KnownRelation.getRelation() == FCmpInst::FCMP_OEQ || + KnownRelation.getRelation() == ICmpInst::ICMP_EQ) + return; + + // If this is boolean, check to see if one of the operands is a constant. If + // it's a constant, then see if the other one is one of a setcc instruction, + // an AND, OR, or XOR instruction. + // + ConstantInt *CB = dyn_cast(Op1); + if (CB && Op1->getType() == Type::Int1Ty) { + if (Instruction *Inst = dyn_cast(Op0)) { + // If we know that this instruction is an AND instruction, and the + // result is true, this means that both operands to the OR are known + // to be true as well. + // + if (CB->getZExtValue() && Inst->getOpcode() == Instruction::And) { + PropagateEquality(Inst->getOperand(0), CB, RI); + PropagateEquality(Inst->getOperand(1), CB, RI); + } + + // If we know that this instruction is an OR instruction, and the result + // is false, this means that both operands to the OR are know to be + // false as well. + // + if (!CB->getZExtValue() && Inst->getOpcode() == Instruction::Or) { + PropagateEquality(Inst->getOperand(0), CB, RI); + PropagateEquality(Inst->getOperand(1), CB, RI); + } + + // If we know that this instruction is a NOT instruction, we know that + // the operand is known to be the inverse of whatever the current + // value is. 
+ // + if (BinaryOperator *BOp = dyn_cast(Inst)) + if (BinaryOperator::isNot(BOp)) + PropagateEquality(BinaryOperator::getNotArgument(BOp), + ConstantInt::get(Type::Int1Ty, + !CB->getZExtValue()), RI); + + // If we know the value of a FCmp instruction, propagate the information + // about the relation into this region as well. + // + if (FCmpInst *FCI = dyn_cast(Inst)) { + if (CB->getZExtValue()) { // If we know the condition is true... + // Propagate info about the LHS to the RHS & RHS to LHS + PropagateRelation(FCI->getPredicate(), FCI->getOperand(0), + FCI->getOperand(1), RI); + PropagateRelation(FCI->getSwappedPredicate(), + FCI->getOperand(1), FCI->getOperand(0), RI); + + } else { // If we know the condition is false... + // We know the opposite of the condition is true... + FCmpInst::Predicate C = FCI->getInversePredicate(); + + PropagateRelation(C, FCI->getOperand(0), FCI->getOperand(1), RI); + PropagateRelation(FCmpInst::getSwappedPredicate(C), + FCI->getOperand(1), FCI->getOperand(0), RI); + } + } + + // If we know the value of a ICmp instruction, propagate the information + // about the relation into this region as well. + // + if (ICmpInst *ICI = dyn_cast(Inst)) { + if (CB->getZExtValue()) { // If we know the condition is true... + // Propagate info about the LHS to the RHS & RHS to LHS + PropagateRelation(ICI->getPredicate(), ICI->getOperand(0), + ICI->getOperand(1), RI); + PropagateRelation(ICI->getSwappedPredicate(), ICI->getOperand(1), + ICI->getOperand(1), RI); + + } else { // If we know the condition is false ... + // We know the opposite of the condition is true... 
+ ICmpInst::Predicate C = ICI->getInversePredicate(); + + PropagateRelation(C, ICI->getOperand(0), ICI->getOperand(1), RI); + PropagateRelation(ICmpInst::getSwappedPredicate(C), + ICI->getOperand(1), ICI->getOperand(0), RI); + } + } + } + } + + // Propagate information about Op0 to Op1 & visa versa + PropagateRelation(ICmpInst::ICMP_EQ, Op0, Op1, RI); + PropagateRelation(ICmpInst::ICMP_EQ, Op1, Op0, RI); + PropagateRelation(FCmpInst::FCMP_OEQ, Op0, Op1, RI); + PropagateRelation(FCmpInst::FCMP_OEQ, Op1, Op0, RI); +} + + +// PropagateRelation - We know that the specified relation is true in all of the +// blocks in the specified region. Propagate the information about Op0 and +// anything derived from it into this region. +// +void CEE::PropagateRelation(unsigned Opcode, Value *Op0, + Value *Op1, RegionInfo &RI) { + assert(Op0->getType() == Op1->getType() && "Equal types expected!"); + + // Constants are already pretty well understood. We will apply information + // about the constant to Op1 in another call to PropagateRelation. + // + if (isa(Op0)) return; + + // Get the region information for this block to update... + ValueInfo &VI = RI.getValueInfo(Op0); + + // Get information about the known relationship between Op0 & Op1 + Relation &Op1R = VI.getRelation(Op1); + + // Quick bailout for common case if we are reprocessing an instruction... + if (Op1R.getRelation() == Opcode) + return; + + // If we already have information that contradicts the current information we + // are propagating, ignore this info. Something bad must have happened! + // + if (Op1R.contradicts(Opcode, VI)) { + Op1R.contradicts(Opcode, VI); + cerr << "Contradiction found for opcode: " + << ((isa(Op0)||isa(Op1)) ? + Instruction::getOpcodeName(Instruction::ICmp) : + Instruction::getOpcodeName(Opcode)) + << "\n"; + Op1R.print(*cerr.stream()); + return; + } + + // If the information propagated is new, then we want process the uses of this + // instruction to propagate the information down to them. 
+ // + if (Op1R.incorporate(Opcode, VI)) + UpdateUsersOfValue(Op0, RI); +} + + +// UpdateUsersOfValue - The information about V in this region has been updated. +// Propagate this to all consumers of the value. +// +void CEE::UpdateUsersOfValue(Value *V, RegionInfo &RI) { + for (Value::use_iterator I = V->use_begin(), E = V->use_end(); + I != E; ++I) + if (Instruction *Inst = dyn_cast(*I)) { + // If this is an instruction using a value that we know something about, + // try to propagate information to the value produced by the + // instruction. We can only do this if it is an instruction we can + // propagate information for (a setcc for example), and we only WANT to + // do this if the instruction dominates this region. + // + // If the instruction doesn't dominate this region, then it cannot be + // used in this region and we don't care about it. If the instruction + // is IN this region, then we will simplify the instruction before we + // get to uses of it anyway, so there is no reason to bother with it + // here. This check is also effectively checking to make sure that Inst + // is in the same function as our region (in case V is a global f.e.). + // + if (DT->properlyDominates(Inst->getParent(), RI.getEntryBlock())) + IncorporateInstruction(Inst, RI); + } +} + +// IncorporateInstruction - We just updated the information about one of the +// operands to the specified instruction. Update the information about the +// value produced by this instruction +// +void CEE::IncorporateInstruction(Instruction *Inst, RegionInfo &RI) { + if (CmpInst *CI = dyn_cast(Inst)) { + // See if we can figure out a result for this instruction... + Relation::KnownResult Result = getCmpResult(CI, RI); + if (Result != Relation::Unknown) { + PropagateEquality(CI, ConstantInt::get(Type::Int1Ty, Result != 0), RI); + } + } +} + + +// ComputeReplacements - Some values are known to be equal to other values in a +// region. 
For example if there is a comparison of equality between a variable +// X and a constant C, we can replace all uses of X with C in the region we are +// interested in. We generalize this replacement to replace variables with +// other variables if they are equal and there is a variable with lower rank +// than the current one. This offers a canonicalizing property that exposes +// more redundancies for later transformations to take advantage of. +// +void CEE::ComputeReplacements(RegionInfo &RI) { + // Loop over all of the values in the region info map... + for (RegionInfo::iterator I = RI.begin(), E = RI.end(); I != E; ++I) { + ValueInfo &VI = I->second; + + // If we know that this value is a particular constant, set Replacement to + // the constant... + Value *Replacement = 0; + const APInt * Rplcmnt = VI.getBounds().getSingleElement(); + if (Rplcmnt) + Replacement = ConstantInt::get(*Rplcmnt); + + // If this value is not known to be some constant, figure out the lowest + // rank value that it is known to be equal to (if anything). + // + if (Replacement == 0) { + // Find out if there are any equality relationships with values of lower + // rank than VI itself... + unsigned MinRank = getRank(I->first); + + // Loop over the relationships known about Op0. + const std::vector &Relationships = VI.getRelationships(); + for (unsigned i = 0, e = Relationships.size(); i != e; ++i) + if (Relationships[i].getRelation() == FCmpInst::FCMP_OEQ) { + unsigned R = getRank(Relationships[i].getValue()); + if (R < MinRank) { + MinRank = R; + Replacement = Relationships[i].getValue(); + } + } + else if (Relationships[i].getRelation() == ICmpInst::ICMP_EQ) { + unsigned R = getRank(Relationships[i].getValue()); + if (R < MinRank) { + MinRank = R; + Replacement = Relationships[i].getValue(); + } + } + } + + // If we found something to replace this value with, keep track of it. 
+ if (Replacement) + VI.setReplacement(Replacement); + } +} + +// SimplifyBasicBlock - Given information about values in region RI, simplify +// the instructions in the specified basic block. +// +bool CEE::SimplifyBasicBlock(BasicBlock &BB, const RegionInfo &RI) { + bool Changed = false; + for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ) { + Instruction *Inst = I++; + + // Convert instruction arguments to canonical forms... + Changed |= SimplifyInstruction(Inst, RI); + + if (CmpInst *CI = dyn_cast(Inst)) { + // Try to simplify a setcc instruction based on inherited information + Relation::KnownResult Result = getCmpResult(CI, RI); + if (Result != Relation::Unknown) { + DEBUG(cerr << "Replacing icmp with " << Result + << " constant: " << *CI); + + CI->replaceAllUsesWith(ConstantInt::get(Type::Int1Ty, (bool)Result)); + // The instruction is now dead, remove it from the program. + CI->getParent()->getInstList().erase(CI); + ++NumCmpRemoved; + Changed = true; + } + } + } + + return Changed; +} + +// SimplifyInstruction - Inspect the operands of the instruction, converting +// them to their canonical form if possible. This takes care of, for example, +// replacing a value 'X' with a constant 'C' if the instruction in question is +// dominated by a true seteq 'X', 'C'. +// +bool CEE::SimplifyInstruction(Instruction *I, const RegionInfo &RI) { + bool Changed = false; + + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (const ValueInfo *VI = RI.requestValueInfo(I->getOperand(i))) + if (Value *Repl = VI->getReplacement()) { + // If we know if a replacement with lower rank than Op0, make the + // replacement now. + DOUT << "In Inst: " << *I << " Replacing operand #" << i + << " with " << *Repl << "\n"; + I->setOperand(i, Repl); + Changed = true; + ++NumOperandsCann; + } + + return Changed; +} + +// getCmpResult - Try to simplify a cmp instruction based on information +// inherited from a dominating icmp instruction. 
V is one of the operands to +// the icmp instruction, and VI is the set of information known about it. We +// take two cases into consideration here. If the comparison is against a +// constant value, we can use the constant range to see if the comparison is +// possible to succeed. If it is not a comparison against a constant, we check +// to see if there is a known relationship between the two values. If so, we +// may be able to eliminate the check. +// +Relation::KnownResult CEE::getCmpResult(CmpInst *CI, + const RegionInfo &RI) { + Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1); + unsigned short predicate = CI->getPredicate(); + + if (isa(Op0)) { + if (isa(Op1)) { + if (Constant *Result = ConstantFoldInstruction(CI)) { + // Wow, this is easy, directly eliminate the ICmpInst. + DEBUG(cerr << "Replacing cmp with constant fold: " << *CI); + return cast(Result)->getZExtValue() + ? Relation::KnownTrue : Relation::KnownFalse; + } + } else { + // We want to swap this instruction so that operand #0 is the constant. + std::swap(Op0, Op1); + if (isa(CI)) + predicate = cast(CI)->getSwappedPredicate(); + else + predicate = cast(CI)->getSwappedPredicate(); + } + } + + // Try to figure out what the result of this comparison will be... + Relation::KnownResult Result = Relation::Unknown; + + // We have to know something about the relationship to prove anything... + if (const ValueInfo *Op0VI = RI.requestValueInfo(Op0)) { + + // At this point, we know that if we have a constant argument that it is in + // Op1. Check to see if we know anything about comparing value with a + // constant, and if we can use this info to fold the icmp. + // + if (ConstantInt *C = dyn_cast(Op1)) { + // Check to see if we already know the result of this comparison... 
+ ICmpInst::Predicate ipred = ICmpInst::Predicate(predicate); + ConstantRange R = ICmpInst::makeConstantRange(ipred, C->getValue()); + ConstantRange Int = R.intersectWith(Op0VI->getBounds()); + + // If the intersection of the two ranges is empty, then the condition + // could never be true! + // + if (Int.isEmptySet()) { + Result = Relation::KnownFalse; + + // Otherwise, if VI.getBounds() (the possible values) is a subset of R + // (the allowed values) then we know that the condition must always be + // true! + // + } else if (Int == Op0VI->getBounds()) { + Result = Relation::KnownTrue; + } + } else { + // If we are here, we know that the second argument is not a constant + // integral. See if we know anything about Op0 & Op1 that allows us to + // fold this anyway. + // + // Do we have value information about Op0 and a relation to Op1? + if (const Relation *Op2R = Op0VI->requestRelation(Op1)) + Result = Op2R->getImpliedResult(predicate); + } + } + return Result; +} + +//===----------------------------------------------------------------------===// +// Relation Implementation +//===----------------------------------------------------------------------===// + +// contradicts - Return true if the relationship specified by the operand +// contradicts already known information. +// +bool Relation::contradicts(unsigned Op, + const ValueInfo &VI) const { + assert (Op != Instruction::Add && "Invalid relation argument!"); + + // If this is a relationship with a constant, make sure that this relationship + // does not contradict properties known about the bounds of the constant. 
+ // + if (ConstantInt *C = dyn_cast(Val)) + if (Op >= ICmpInst::FIRST_ICMP_PREDICATE && + Op <= ICmpInst::LAST_ICMP_PREDICATE) { + ICmpInst::Predicate ipred = ICmpInst::Predicate(Op); + if (ICmpInst::makeConstantRange(ipred, C->getValue()) + .intersectWith(VI.getBounds()).isEmptySet()) + return true; + } + + switch (Rel) { + default: assert(0 && "Unknown Relationship code!"); + case Instruction::Add: return false; // Nothing known, nothing contradicts + case ICmpInst::ICMP_EQ: + return Op == ICmpInst::ICMP_ULT || Op == ICmpInst::ICMP_SLT || + Op == ICmpInst::ICMP_UGT || Op == ICmpInst::ICMP_SGT || + Op == ICmpInst::ICMP_NE; + case ICmpInst::ICMP_NE: return Op == ICmpInst::ICMP_EQ; + case ICmpInst::ICMP_ULE: + case ICmpInst::ICMP_SLE: return Op == ICmpInst::ICMP_UGT || + Op == ICmpInst::ICMP_SGT; + case ICmpInst::ICMP_UGE: + case ICmpInst::ICMP_SGE: return Op == ICmpInst::ICMP_ULT || + Op == ICmpInst::ICMP_SLT; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: + return Op == ICmpInst::ICMP_EQ || Op == ICmpInst::ICMP_UGT || + Op == ICmpInst::ICMP_SGT || Op == ICmpInst::ICMP_UGE || + Op == ICmpInst::ICMP_SGE; + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: + return Op == ICmpInst::ICMP_EQ || Op == ICmpInst::ICMP_ULT || + Op == ICmpInst::ICMP_SLT || Op == ICmpInst::ICMP_ULE || + Op == ICmpInst::ICMP_SLE; + case FCmpInst::FCMP_OEQ: + return Op == FCmpInst::FCMP_OLT || Op == FCmpInst::FCMP_OGT || + Op == FCmpInst::FCMP_ONE; + case FCmpInst::FCMP_ONE: return Op == FCmpInst::FCMP_OEQ; + case FCmpInst::FCMP_OLE: return Op == FCmpInst::FCMP_OGT; + case FCmpInst::FCMP_OGE: return Op == FCmpInst::FCMP_OLT; + case FCmpInst::FCMP_OLT: + return Op == FCmpInst::FCMP_OEQ || Op == FCmpInst::FCMP_OGT || + Op == FCmpInst::FCMP_OGE; + case FCmpInst::FCMP_OGT: + return Op == FCmpInst::FCMP_OEQ || Op == FCmpInst::FCMP_OLT || + Op == FCmpInst::FCMP_OLE; + } +} + +// incorporate - Incorporate information in the argument into this relation +// entry. 
This assumes that the information doesn't contradict itself.  If any
// new information is gained, true is returned, otherwise false is returned to
// indicate that nothing was updated.
//
bool Relation::incorporate(unsigned Op, ValueInfo &VI) {
  assert(!contradicts(Op, VI) &&
         "Cannot incorporate contradictory information!");

  // If this is a relationship with a constant, make sure that we update the
  // range that is possible for the value to have...
  //
  if (ConstantInt *C = dyn_cast<ConstantInt>(Val))
    if (Op >= ICmpInst::FIRST_ICMP_PREDICATE &&
        Op <= ICmpInst::LAST_ICMP_PREDICATE) {
      ICmpInst::Predicate ipred = ICmpInst::Predicate(Op);
      VI.getBounds() =
        ICmpInst::makeConstantRange(ipred, C->getValue())
          .intersectWith(VI.getBounds());
    }

  switch (Rel) {
  default: assert(0 && "Unknown prior value!");
  case Instruction::Add: Rel = Op; return true;  // No prior info: take Op.
  case ICmpInst::ICMP_EQ:
  case ICmpInst::ICMP_NE:
  case ICmpInst::ICMP_ULT:
  case ICmpInst::ICMP_SLT:
  case ICmpInst::ICMP_UGT:
  case ICmpInst::ICMP_SGT: return false;  // Nothing is more precise
  case ICmpInst::ICMP_ULE:
  case ICmpInst::ICMP_SLE:
    // <= can be sharpened to ==, <, or (via !=) to strict <.
    if (Op == ICmpInst::ICMP_EQ || Op == ICmpInst::ICMP_ULT ||
        Op == ICmpInst::ICMP_SLT) {
      Rel = Op;
      return true;
    } else if (Op == ICmpInst::ICMP_NE) {
      Rel = Rel == ICmpInst::ICMP_ULE ? ICmpInst::ICMP_ULT :
                                        ICmpInst::ICMP_SLT;
      return true;
    }
    return false;
  case ICmpInst::ICMP_UGE:
  case ICmpInst::ICMP_SGE:
    // >= can be sharpened to ==, >, or (via !=) to strict >.
    // BUGFIX: the second comparison was missing 'Op ==' (it read
    // '|| ICmpInst::ICMP_UGT', a non-zero constant that is always true),
    // so ANY Op here incorrectly overwrote Rel.
    if (Op == ICmpInst::ICMP_EQ || Op == ICmpInst::ICMP_UGT ||
        Op == ICmpInst::ICMP_SGT) {
      Rel = Op;
      return true;
    } else if (Op == ICmpInst::ICMP_NE) {
      Rel = Rel == ICmpInst::ICMP_UGE ? ICmpInst::ICMP_UGT :
                                        ICmpInst::ICMP_SGT;
      return true;
    }
    return false;
  case FCmpInst::FCMP_OEQ: return false;  // Nothing is more precise
  case FCmpInst::FCMP_ONE: return false;  // Nothing is more precise
  case FCmpInst::FCMP_OLT: return false;  // Nothing is more precise
  case FCmpInst::FCMP_OGT: return false;  // Nothing is more precise
  case FCmpInst::FCMP_OLE:
    if (Op == FCmpInst::FCMP_OEQ || Op == FCmpInst::FCMP_OLT) {
      Rel = Op;
      return true;
    } else if (Op == FCmpInst::FCMP_ONE) {
      Rel = FCmpInst::FCMP_OLT;
      return true;
    }
    return false;
  case FCmpInst::FCMP_OGE:
    // BUGFIX: a stray 'return Op == FCmpInst::FCMP_OLT;' preceded this block,
    // making the refinement below unreachable (compare the symmetric
    // FCMP_OLE case above, which has no such return).
    if (Op == FCmpInst::FCMP_OEQ || Op == FCmpInst::FCMP_OGT) {
      Rel = Op;
      return true;
    } else if (Op == FCmpInst::FCMP_ONE) {
      Rel = FCmpInst::FCMP_OGT;
      return true;
    }
    return false;
  }
}

// getImpliedResult - If this relationship between two values implies that
// the specified relationship is true or false, return that.  If we cannot
// determine the result required, return Unknown.
+// +Relation::KnownResult +Relation::getImpliedResult(unsigned Op) const { + if (Rel == Op) return KnownTrue; + if (Op >= ICmpInst::FIRST_ICMP_PREDICATE && + Op <= ICmpInst::LAST_ICMP_PREDICATE) { + if (Rel == unsigned(ICmpInst::getInversePredicate(ICmpInst::Predicate(Op)))) + return KnownFalse; + } else if (Op <= FCmpInst::LAST_FCMP_PREDICATE) { + if (Rel == unsigned(FCmpInst::getInversePredicate(FCmpInst::Predicate(Op)))) + return KnownFalse; + } + + switch (Rel) { + default: assert(0 && "Unknown prior value!"); + case ICmpInst::ICMP_EQ: + if (Op == ICmpInst::ICMP_ULE || Op == ICmpInst::ICMP_SLE || + Op == ICmpInst::ICMP_UGE || Op == ICmpInst::ICMP_SGE) return KnownTrue; + if (Op == ICmpInst::ICMP_ULT || Op == ICmpInst::ICMP_SLT || + Op == ICmpInst::ICMP_UGT || Op == ICmpInst::ICMP_SGT) return KnownFalse; + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: + if (Op == ICmpInst::ICMP_ULE || Op == ICmpInst::ICMP_SLE || + Op == ICmpInst::ICMP_NE) return KnownTrue; + if (Op == ICmpInst::ICMP_EQ) return KnownFalse; + break; + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: + if (Op == ICmpInst::ICMP_UGE || Op == ICmpInst::ICMP_SGE || + Op == ICmpInst::ICMP_NE) return KnownTrue; + if (Op == ICmpInst::ICMP_EQ) return KnownFalse; + break; + case FCmpInst::FCMP_OEQ: + if (Op == FCmpInst::FCMP_OLE || Op == FCmpInst::FCMP_OGE) return KnownTrue; + if (Op == FCmpInst::FCMP_OLT || Op == FCmpInst::FCMP_OGT) return KnownFalse; + break; + case FCmpInst::FCMP_OLT: + if (Op == FCmpInst::FCMP_ONE || Op == FCmpInst::FCMP_OLE) return KnownTrue; + if (Op == FCmpInst::FCMP_OEQ) return KnownFalse; + break; + case FCmpInst::FCMP_OGT: + if (Op == FCmpInst::FCMP_ONE || Op == FCmpInst::FCMP_OGE) return KnownTrue; + if (Op == FCmpInst::FCMP_OEQ) return KnownFalse; + break; + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_SLE: + case ICmpInst::ICMP_ULE: + case ICmpInst::ICMP_UGE: + case ICmpInst::ICMP_SGE: + case FCmpInst::FCMP_ONE: + case FCmpInst::FCMP_OLE: + case 
FCmpInst::FCMP_OGE: + case FCmpInst::FCMP_FALSE: + case FCmpInst::FCMP_ORD: + case FCmpInst::FCMP_UNO: + case FCmpInst::FCMP_UEQ: + case FCmpInst::FCMP_UGT: + case FCmpInst::FCMP_UGE: + case FCmpInst::FCMP_ULT: + case FCmpInst::FCMP_ULE: + case FCmpInst::FCMP_UNE: + case FCmpInst::FCMP_TRUE: + break; + } + return Unknown; +} + + +//===----------------------------------------------------------------------===// +// Printing Support... +//===----------------------------------------------------------------------===// + +// print - Implement the standard print form to print out analysis information. +void CEE::print(std::ostream &O, const Module *M) const { + O << "\nPrinting Correlated Expression Info:\n"; + for (std::map::const_iterator I = + RegionInfoMap.begin(), E = RegionInfoMap.end(); I != E; ++I) + I->second.print(O); +} + +// print - Output information about this region... +void RegionInfo::print(std::ostream &OS) const { + if (ValueMap.empty()) return; + + OS << " RegionInfo for basic block: " << BB->getName() << "\n"; + for (std::map::const_iterator + I = ValueMap.begin(), E = ValueMap.end(); I != E; ++I) + I->second.print(OS, I->first); + OS << "\n"; +} + +// print - Output information about this value relation... 
void ValueInfo::print(std::ostream &OS, Value *V) const {
  // Nothing to say about a value with no recorded relationships.
  if (Relationships.empty()) return;

  // V is the value this info describes; it may be null (e.g. from dump()),
  // in which case the header line is omitted.
  if (V) {
    OS << " ValueInfo for: ";
    WriteAsOperand(OS, V);
  }
  OS << "\n Bounds = " << Bounds << "\n";
  if (Replacement) {
    OS << " Replacement = ";
    WriteAsOperand(OS, Replacement);
    OS << "\n";
  }
  // Print each relation known between V and some other value.
  for (unsigned i = 0, e = Relationships.size(); i != e; ++i)
    Relationships[i].print(OS);
}

// print - Output this relation to the specified stream
void Relation::print(std::ostream &OS) const {
  OS << " is ";
  // Collapse the predicate into one of six human-readable comparison glyphs;
  // ordered/unordered FP variants print the same glyph as their integer
  // counterparts.
  switch (Rel) {
  default: OS << "*UNKNOWN*"; break;
  case ICmpInst::ICMP_EQ:
  case FCmpInst::FCMP_ORD:
  case FCmpInst::FCMP_UEQ:
  case FCmpInst::FCMP_OEQ: OS << "== "; break;
  case ICmpInst::ICMP_NE:
  case FCmpInst::FCMP_UNO:
  case FCmpInst::FCMP_UNE:
  case FCmpInst::FCMP_ONE: OS << "!= "; break;
  case ICmpInst::ICMP_ULT:
  case ICmpInst::ICMP_SLT:
  case FCmpInst::FCMP_ULT:
  case FCmpInst::FCMP_OLT: OS << "< "; break;
  case ICmpInst::ICMP_UGT:
  case ICmpInst::ICMP_SGT:
  case FCmpInst::FCMP_UGT:
  case FCmpInst::FCMP_OGT: OS << "> "; break;
  case ICmpInst::ICMP_ULE:
  case ICmpInst::ICMP_SLE:
  case FCmpInst::FCMP_ULE:
  case FCmpInst::FCMP_OLE: OS << "<= "; break;
  case ICmpInst::ICMP_UGE:
  case ICmpInst::ICMP_SGE:
  case FCmpInst::FCMP_UGE:
  case FCmpInst::FCMP_OGE: OS << ">= "; break;
  }

  // Val is the right-hand value of the relation.
  WriteAsOperand(OS, Val);
  OS << "\n";
}

// Don't inline these methods or else we won't be able to call them from GDB!
+void Relation::dump() const { print(*cerr.stream()); } +void ValueInfo::dump() const { print(*cerr.stream(), 0); } +void RegionInfo::dump() const { print(*cerr.stream()); } diff --git a/lib/Transforms/Scalar/DCE.cpp b/lib/Transforms/Scalar/DCE.cpp new file mode 100644 index 0000000..163c2b0 --- /dev/null +++ b/lib/Transforms/Scalar/DCE.cpp @@ -0,0 +1,130 @@ +//===- DCE.cpp - Code to perform dead code elimination --------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements dead inst elimination and dead code elimination. +// +// Dead Inst Elimination performs a single pass over the function removing +// instructions that are obviously dead. Dead Code Elimination is similar, but +// it rechecks instructions that were used by removed instructions to see if +// they are newly dead. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "dce" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Instruction.h" +#include "llvm/Pass.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/ADT/Statistic.h" +#include +using namespace llvm; + +STATISTIC(DIEEliminated, "Number of insts removed by DIE pass"); +STATISTIC(DCEEliminated, "Number of insts removed"); + +namespace { + //===--------------------------------------------------------------------===// + // DeadInstElimination pass implementation + // + struct VISIBILITY_HIDDEN DeadInstElimination : public BasicBlockPass { + static char ID; // Pass identification, replacement for typeid + DeadInstElimination() : BasicBlockPass(intptr_t(&ID)) {} + virtual bool runOnBasicBlock(BasicBlock &BB) { + bool Changed = false; + for (BasicBlock::iterator DI = BB.begin(); DI != BB.end(); ) + if (dceInstruction(DI)) { + Changed = true; + ++DIEEliminated; + } else + ++DI; + return Changed; + } + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + } + }; + + char DeadInstElimination::ID = 0; + RegisterPass X("die", "Dead Instruction Elimination"); +} + +Pass *llvm::createDeadInstEliminationPass() { + return new DeadInstElimination(); +} + + +namespace { + //===--------------------------------------------------------------------===// + // DeadCodeElimination pass implementation + // + struct DCE : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + DCE() : FunctionPass((intptr_t)&ID) {} + + virtual bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + } + }; + + char DCE::ID = 0; + RegisterPass Y("dce", "Dead Code Elimination"); +} + +bool DCE::runOnFunction(Function &F) { + // Start out with all of the instructions in the worklist... 
+ std::vector WorkList; + for (inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) + WorkList.push_back(&*i); + + // Loop over the worklist finding instructions that are dead. If they are + // dead make them drop all of their uses, making other instructions + // potentially dead, and work until the worklist is empty. + // + bool MadeChange = false; + while (!WorkList.empty()) { + Instruction *I = WorkList.back(); + WorkList.pop_back(); + + if (isInstructionTriviallyDead(I)) { // If the instruction is dead. + // Loop over all of the values that the instruction uses, if there are + // instructions being used, add them to the worklist, because they might + // go dead after this one is removed. + // + for (User::op_iterator OI = I->op_begin(), E = I->op_end(); OI != E; ++OI) + if (Instruction *Used = dyn_cast(*OI)) + WorkList.push_back(Used); + + // Remove the instruction. + I->eraseFromParent(); + + // Remove the instruction from the worklist if it still exists in it. + for (std::vector::iterator WI = WorkList.begin(), + E = WorkList.end(); WI != E; ++WI) + if (*WI == I) { + WorkList.erase(WI); + --E; + --WI; + } + + MadeChange = true; + ++DCEEliminated; + } + } + return MadeChange; +} + +FunctionPass *llvm::createDeadCodeEliminationPass() { + return new DCE(); +} + diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp new file mode 100644 index 0000000..665d538 --- /dev/null +++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -0,0 +1,179 @@ +//===- DeadStoreElimination.cpp - Dead Store Elimination ------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This file implements a trivial dead store elimination that only considers +// basic-block local redundant stores. +// +// FIXME: This should eventually be extended to be a post-dominator tree +// traversal. Doing so would be pretty trivial. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "dse" +#include "llvm/Transforms/Scalar.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Compiler.h" +using namespace llvm; + +STATISTIC(NumStores, "Number of stores deleted"); +STATISTIC(NumOther , "Number of other instrs removed"); + +namespace { + struct VISIBILITY_HIDDEN DSE : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + DSE() : FunctionPass((intptr_t)&ID) {} + + virtual bool runOnFunction(Function &F) { + bool Changed = false; + for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) + Changed |= runOnBasicBlock(*I); + return Changed; + } + + bool runOnBasicBlock(BasicBlock &BB); + + void DeleteDeadInstructionChains(Instruction *I, + SetVector &DeadInsts); + + // getAnalysisUsage - We require post dominance frontiers (aka Control + // Dependence Graph) + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequired(); + AU.addRequired(); + AU.addPreserved(); + } + }; + char DSE::ID = 0; + RegisterPass X("dse", "Dead Store Elimination"); +} + +FunctionPass *llvm::createDeadStoreEliminationPass() { return new DSE(); } + +bool DSE::runOnBasicBlock(BasicBlock &BB) { + TargetData &TD = getAnalysis(); + AliasAnalysis &AA = getAnalysis(); + 
AliasSetTracker KillLocs(AA); + + // If this block ends in a return, unwind, unreachable, and eventually + // tailcall, then all allocas are dead at its end. + if (BB.getTerminator()->getNumSuccessors() == 0) { + BasicBlock *Entry = BB.getParent()->begin(); + for (BasicBlock::iterator I = Entry->begin(), E = Entry->end(); I != E; ++I) + if (AllocaInst *AI = dyn_cast(I)) { + unsigned Size = ~0U; + if (!AI->isArrayAllocation() && + AI->getType()->getElementType()->isSized()) + Size = (unsigned)TD.getTypeSize(AI->getType()->getElementType()); + KillLocs.add(AI, Size); + } + } + + // PotentiallyDeadInsts - Deleting dead stores from the program can make other + // instructions die if they were only used as operands to stores. Keep track + // of the operands to stores so that we can try deleting them at the end of + // the traversal. + SetVector PotentiallyDeadInsts; + + bool MadeChange = false; + for (BasicBlock::iterator BBI = BB.end(); BBI != BB.begin(); ) { + Instruction *I = --BBI; // Keep moving iterator backwards + + // If this is a free instruction, it makes the free'd location dead! + if (FreeInst *FI = dyn_cast(I)) { + // Free instructions make any stores to the free'd location dead. + KillLocs.add(FI); + continue; + } + + if (!isa(I) || cast(I)->isVolatile()) { + // If this is a vaarg instruction, it reads its operand. We don't model + // it correctly, so just conservatively remove all entries. + if (isa(I)) { + KillLocs.clear(); + continue; + } + + // If this is a non-store instruction, it makes everything referenced no + // longer killed. Remove anything aliased from the alias set tracker. + KillLocs.remove(I); + continue; + } + + // If this is a non-volatile store instruction, and if it is already in + // the stored location is already in the tracker, then this is a dead + // store. We can just delete it here, but while we're at it, we also + // delete any trivially dead expression chains. 
+ unsigned ValSize = (unsigned)TD.getTypeSize(I->getOperand(0)->getType()); + Value *Ptr = I->getOperand(1); + + if (AliasSet *AS = KillLocs.getAliasSetForPointerIfExists(Ptr, ValSize)) + for (AliasSet::iterator ASI = AS->begin(), E = AS->end(); ASI != E; ++ASI) + if (ASI.getSize() >= ValSize && // Overwriting all of this store. + AA.alias(ASI.getPointer(), ASI.getSize(), Ptr, ValSize) + == AliasAnalysis::MustAlias) { + // If we found a must alias in the killed set, then this store really + // is dead. Remember that the various operands of the store now have + // fewer users. At the end we will see if we can delete any values + // that are dead as part of the store becoming dead. + if (Instruction *Op = dyn_cast(I->getOperand(0))) + PotentiallyDeadInsts.insert(Op); + if (Instruction *Op = dyn_cast(Ptr)) + PotentiallyDeadInsts.insert(Op); + + // Delete it now. + ++BBI; // Don't invalidate iterator. + BB.getInstList().erase(I); // Nuke the store! + ++NumStores; + MadeChange = true; + goto BigContinue; + } + + // Otherwise, this is a non-dead store just add it to the set of dead + // locations. + KillLocs.add(cast(I)); + BigContinue:; + } + + while (!PotentiallyDeadInsts.empty()) { + Instruction *I = PotentiallyDeadInsts.back(); + PotentiallyDeadInsts.pop_back(); + DeleteDeadInstructionChains(I, PotentiallyDeadInsts); + } + return MadeChange; +} + +void DSE::DeleteDeadInstructionChains(Instruction *I, + SetVector &DeadInsts) { + // Instruction must be dead. + if (!I->use_empty() || !isInstructionTriviallyDead(I)) return; + + // Let the alias analysis know that we have nuked a value. + getAnalysis().deleteValue(I); + + // See if this made any operands dead. We do it this way in case the + // instruction uses the same operand twice. We don't want to delete a + // value then reference it. + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { + if (Instruction *Op = dyn_cast(I->getOperand(i))) + DeadInsts.insert(Op); // Attempt to nuke it later. 
+ I->setOperand(i, 0); // Drop from the operand list. + } + + I->eraseFromParent(); + ++NumOther; +} diff --git a/lib/Transforms/Scalar/FastDSE.cpp b/lib/Transforms/Scalar/FastDSE.cpp new file mode 100644 index 0000000..72857b9 --- /dev/null +++ b/lib/Transforms/Scalar/FastDSE.cpp @@ -0,0 +1,387 @@ +//===- DeadStoreElimination.cpp - Dead Store Elimination ------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by Owen Anderson and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a trivial dead store elimination that only considers +// basic-block local redundant stores. +// +// FIXME: This should eventually be extended to be a post-dominator tree +// traversal. Doing so would be pretty trivial. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "fdse" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/Compiler.h" +using namespace llvm; + +STATISTIC(NumFastStores, "Number of stores deleted"); +STATISTIC(NumFastOther , "Number of other instrs removed"); + +namespace { + struct VISIBILITY_HIDDEN FDSE : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + FDSE() : FunctionPass((intptr_t)&ID) {} + + virtual bool runOnFunction(Function &F) { + bool Changed = false; + for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) + Changed |= runOnBasicBlock(*I); + return Changed; + 
} + + bool runOnBasicBlock(BasicBlock &BB); + bool handleFreeWithNonTrivialDependency(FreeInst* F, Instruction* dependency, + SetVector& possiblyDead); + bool handleEndBlock(BasicBlock& BB, SetVector& possiblyDead); + bool RemoveUndeadPointers(Value* pointer, unsigned pointerSize, + BasicBlock::iterator& BBI, + SmallPtrSet& deadPointers, + SetVector& possiblyDead); + void DeleteDeadInstructionChains(Instruction *I, + SetVector &DeadInsts); + void TranslatePointerBitCasts(Value*& v) { + assert(isa(v->getType()) && "Translating a non-pointer type?"); + + // See through pointer-to-pointer bitcasts + while (isa(v) || isa(v)) + if (BitCastInst* C = dyn_cast(v)) + v = C->getOperand(0); + else if (GetElementPtrInst* G = dyn_cast(v)) + v = G->getOperand(0); + } + + // getAnalysisUsage - We require post dominance frontiers (aka Control + // Dependence Graph) + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addPreserved(); + AU.addPreserved(); + } + }; + char FDSE::ID = 0; + RegisterPass X("fdse", "Fast Dead Store Elimination"); +} + +FunctionPass *llvm::createFastDeadStoreEliminationPass() { return new FDSE(); } + +bool FDSE::runOnBasicBlock(BasicBlock &BB) { + MemoryDependenceAnalysis& MD = getAnalysis(); + + // Record the last-seen store to this pointer + DenseMap lastStore; + // Record instructions possibly made dead by deleting a store + SetVector possiblyDead; + + bool MadeChange = false; + + // Do a top-down walk on the BB + for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE; ++BBI) { + // If we find a store or a free... + if (isa(BBI) || isa(BBI)) { + Value* pointer = 0; + if (StoreInst* S = dyn_cast(BBI)) + pointer = S->getPointerOperand(); + else if (FreeInst* F = dyn_cast(BBI)) + pointer = F->getPointerOperand(); + + assert(pointer && "Not a free or a store?"); + + StoreInst*& last = lastStore[pointer]; + bool deletedStore = false; + + // ... 
to a pointer that has been stored to before... + if (last) { + + Instruction* dep = MD.getDependency(BBI); + + // ... and no other memory dependencies are between them.... + while (dep != MemoryDependenceAnalysis::None && + dep != MemoryDependenceAnalysis::NonLocal && + isa(dep)) { + if (dep == last) { + + // Remove it! + MD.removeInstruction(last); + + // DCE instructions only used to calculate that store + if (Instruction* D = dyn_cast(last->getOperand(0))) + possiblyDead.insert(D); + if (Instruction* D = dyn_cast(last->getOperand(1))) + possiblyDead.insert(D); + + last->eraseFromParent(); + NumFastStores++; + deletedStore = true; + MadeChange = true; + + break; + } else { + dep = MD.getDependency(BBI, dep); + } + } + } + + // Handle frees whose dependencies are non-trivial + if (FreeInst* F = dyn_cast(BBI)) + if (!deletedStore) + MadeChange |= handleFreeWithNonTrivialDependency(F, MD.getDependency(F), + possiblyDead); + + // Update our most-recent-store map + if (StoreInst* S = dyn_cast(BBI)) + last = S; + else + last = 0; + } + } + + // If this block ends in a return, unwind, unreachable, and eventually + // tailcall, then all allocas are dead at its end. 
+ if (BB.getTerminator()->getNumSuccessors() == 0) + MadeChange |= handleEndBlock(BB, possiblyDead); + + // Do a trivial DCE + while (!possiblyDead.empty()) { + Instruction *I = possiblyDead.back(); + possiblyDead.pop_back(); + DeleteDeadInstructionChains(I, possiblyDead); + } + + return MadeChange; +} + +/// handleFreeWithNonTrivialDependency - Handle frees of entire structures whose +/// dependency is a store to a field of that structure +bool FDSE::handleFreeWithNonTrivialDependency(FreeInst* F, Instruction* dep, + SetVector& possiblyDead) { + TargetData &TD = getAnalysis(); + AliasAnalysis &AA = getAnalysis(); + MemoryDependenceAnalysis& MD = getAnalysis(); + + if (dep == MemoryDependenceAnalysis::None || + dep == MemoryDependenceAnalysis::NonLocal) + return false; + + StoreInst* dependency = dyn_cast(dep); + if (!dependency) + return false; + + Value* depPointer = dependency->getPointerOperand(); + unsigned depPointerSize = TD.getTypeSize(dependency->getOperand(0)->getType()); + + // Check for aliasing + AliasAnalysis::AliasResult A = AA.alias(F->getPointerOperand(), ~0UL, + depPointer, depPointerSize); + + if (A == AliasAnalysis::MustAlias) { + // Remove it! 
+ MD.removeInstruction(dependency); + + // DCE instructions only used to calculate that store + if (Instruction* D = dyn_cast(dependency->getOperand(0))) + possiblyDead.insert(D); + if (Instruction* D = dyn_cast(dependency->getOperand(1))) + possiblyDead.insert(D); + + dependency->eraseFromParent(); + NumFastStores++; + return true; + } + + return false; +} + +/// handleEndBlock - Remove dead stores to stack-allocated locations in the function +/// end block +bool FDSE::handleEndBlock(BasicBlock& BB, SetVector& possiblyDead) { + TargetData &TD = getAnalysis(); + AliasAnalysis &AA = getAnalysis(); + MemoryDependenceAnalysis& MD = getAnalysis(); + + bool MadeChange = false; + + // Pointers alloca'd in this function are dead in the end block + SmallPtrSet deadPointers; + + // Find all of the alloca'd pointers in the entry block + BasicBlock *Entry = BB.getParent()->begin(); + for (BasicBlock::iterator I = Entry->begin(), E = Entry->end(); I != E; ++I) + if (AllocaInst *AI = dyn_cast(I)) + deadPointers.insert(AI); + + // Scan the basic block backwards + for (BasicBlock::iterator BBI = BB.end(); BBI != BB.begin(); ){ + --BBI; + + if (deadPointers.empty()) + break; + + Value* killPointer = 0; + unsigned killPointerSize = 0; + + // If we find a store whose pointer is dead... + if (StoreInst* S = dyn_cast(BBI)) { + Value* pointerOperand = S->getPointerOperand(); + // See through pointer-to-pointer bitcasts + TranslatePointerBitCasts(pointerOperand); + + if (deadPointers.count(pointerOperand)){ + // Remove it! 
+ MD.removeInstruction(S); + + // DCE instructions only used to calculate that store + if (Instruction* D = dyn_cast(S->getOperand(0))) + possiblyDead.insert(D); + if (Instruction* D = dyn_cast(S->getOperand(1))) + possiblyDead.insert(D); + + BBI++; + S->eraseFromParent(); + NumFastStores++; + MadeChange = true; + } + + // If we encounter a use of the pointer, it is no longer considered dead + } else if (LoadInst* L = dyn_cast(BBI)) { + killPointer = L->getPointerOperand(); + killPointerSize = TD.getTypeSize(L->getType()); + } else if (VAArgInst* V = dyn_cast(BBI)) { + killPointer = V->getOperand(0); + killPointerSize = TD.getTypeSize(V->getType()); + } else if (FreeInst* F = dyn_cast(BBI)) { + killPointer = F->getPointerOperand(); + killPointerSize = ~0UL; + } else if (AllocaInst* A = dyn_cast(BBI)) { + deadPointers.erase(A); + continue; + } else if (CallSite::get(BBI).getInstruction() != 0) { + // Remove any pointers made undead by the call from the dead set + std::vector dead; + for (SmallPtrSet::iterator I = deadPointers.begin(), + E = deadPointers.end(); I != E; ++I) { + // Get size information for the alloca + unsigned pointerSize = ~0UL; + if (ConstantInt* C = dyn_cast((*I)->getArraySize())) + pointerSize = C->getZExtValue() * TD.getTypeSize((*I)->getAllocatedType()); + + // See if the call site touches it + AliasAnalysis::ModRefResult A = AA.getModRefInfo(CallSite::get(BBI), + *I, pointerSize); + if (A == AliasAnalysis::ModRef || A == AliasAnalysis::Ref) + dead.push_back(*I); + } + + for (std::vector::iterator I = dead.begin(), E = dead.end(); + I != E; ++I) + deadPointers.erase(*I); + + continue; + } + + if (!killPointer) + continue; + + // Deal with undead pointers + MadeChange |= RemoveUndeadPointers(killPointer, killPointerSize, BBI, + deadPointers, possiblyDead); + } + + return MadeChange; +} + +bool FDSE::RemoveUndeadPointers(Value* killPointer, unsigned killPointerSize, + BasicBlock::iterator& BBI, + SmallPtrSet& deadPointers, + SetVector& 
possiblyDead) { + TargetData &TD = getAnalysis(); + AliasAnalysis &AA = getAnalysis(); + MemoryDependenceAnalysis& MD = getAnalysis(); + + bool MadeChange = false; + + std::vector undead; + + for (SmallPtrSet::iterator I = deadPointers.begin(), + E = deadPointers.end(); I != E; ++I) { + // Get size information for the alloca + unsigned pointerSize = ~0UL; + if (ConstantInt* C = dyn_cast((*I)->getArraySize())) + pointerSize = C->getZExtValue() * TD.getTypeSize((*I)->getAllocatedType()); + + // See if this pointer could alias it + AliasAnalysis::AliasResult A = AA.alias(*I, pointerSize, killPointer, killPointerSize); + + // If it must-alias and a store, we can delete it + if (isa(BBI) && A == AliasAnalysis::MustAlias) { + StoreInst* S = cast(BBI); + + // Remove it! + MD.removeInstruction(S); + + // DCE instructions only used to calculate that store + if (Instruction* D = dyn_cast(S->getOperand(0))) + possiblyDead.insert(D); + if (Instruction* D = dyn_cast(S->getOperand(1))) + possiblyDead.insert(D); + + BBI++; + S->eraseFromParent(); + NumFastStores++; + MadeChange = true; + + continue; + + // Otherwise, it is undead + } else if (A != AliasAnalysis::NoAlias) + undead.push_back(*I); + } + + for (std::vector::iterator I = undead.begin(), E = undead.end(); + I != E; ++I) + deadPointers.erase(*I); + + return MadeChange; +} + +void FDSE::DeleteDeadInstructionChains(Instruction *I, + SetVector &DeadInsts) { + // Instruction must be dead. + if (!I->use_empty() || !isInstructionTriviallyDead(I)) return; + + // Let the memory dependence know + getAnalysis().removeInstruction(I); + + // See if this made any operands dead. We do it this way in case the + // instruction uses the same operand twice. We don't want to delete a + // value then reference it. + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { + if (I->getOperand(i)->hasOneUse()) + if (Instruction* Op = dyn_cast(I->getOperand(i))) + DeadInsts.insert(Op); // Attempt to nuke it later. 
+ + I->setOperand(i, 0); // Drop from the operand list. + } + + I->eraseFromParent(); + ++NumFastOther; +} diff --git a/lib/Transforms/Scalar/GCSE.cpp b/lib/Transforms/Scalar/GCSE.cpp new file mode 100644 index 0000000..93ed8c4 --- /dev/null +++ b/lib/Transforms/Scalar/GCSE.cpp @@ -0,0 +1,201 @@ +//===-- GCSE.cpp - SSA-based Global Common Subexpression Elimination ------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass is designed to be a very quick global transformation that +// eliminates global common subexpressions from a function. It does this by +// using an existing value numbering implementation to identify the common +// subexpressions, eliminating them when possible. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "gcse" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Instructions.h" +#include "llvm/Function.h" +#include "llvm/Type.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/ValueNumbering.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Compiler.h" +#include +using namespace llvm; + +STATISTIC(NumInstRemoved, "Number of instructions removed"); +STATISTIC(NumLoadRemoved, "Number of loads removed"); +STATISTIC(NumCallRemoved, "Number of calls removed"); +STATISTIC(NumNonInsts , "Number of instructions removed due " + "to non-instruction values"); +STATISTIC(NumArgsRepl , "Number of function arguments replaced " + "with constant values"); +namespace { + struct VISIBILITY_HIDDEN GCSE : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + GCSE() : FunctionPass((intptr_t)&ID) {} + + 
virtual bool runOnFunction(Function &F); + + private: + void ReplaceInstructionWith(Instruction *I, Value *V); + + // This transformation requires dominator and immediate dominator info + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequired(); + AU.addRequired(); + } + }; + + char GCSE::ID = 0; + RegisterPass X("gcse", "Global Common Subexpression Elimination"); +} + +// createGCSEPass - The public interface to this file... +FunctionPass *llvm::createGCSEPass() { return new GCSE(); } + +// GCSE::runOnFunction - This is the main transformation entry point for a +// function. +// +bool GCSE::runOnFunction(Function &F) { + bool Changed = false; + + // Get pointers to the analysis results that we will be using... + DominatorTree &DT = getAnalysis(); + ValueNumbering &VN = getAnalysis(); + + std::vector EqualValues; + + // Check for value numbers of arguments. If the value numbering + // implementation can prove that an incoming argument is a constant or global + // value address, substitute it, making the argument dead. + for (Function::arg_iterator AI = F.arg_begin(), E = F.arg_end(); AI != E;++AI) + if (!AI->use_empty()) { + VN.getEqualNumberNodes(AI, EqualValues); + if (!EqualValues.empty()) { + for (unsigned i = 0, e = EqualValues.size(); i != e; ++i) + if (isa(EqualValues[i])) { + AI->replaceAllUsesWith(EqualValues[i]); + ++NumArgsRepl; + Changed = true; + break; + } + EqualValues.clear(); + } + } + + // Traverse the CFG of the function in dominator order, so that we see each + // instruction after we see its operands. + for (df_iterator DI = df_begin(DT.getRootNode()), + E = df_end(DT.getRootNode()); DI != E; ++DI) { + BasicBlock *BB = DI->getBlock(); + + // Remember which instructions we've seen in this basic block as we scan. 
+ std::set BlockInsts; + + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ) { + Instruction *Inst = I++; + + if (Constant *C = ConstantFoldInstruction(Inst)) { + ReplaceInstructionWith(Inst, C); + } else if (Inst->getType() != Type::VoidTy) { + // If this instruction computes a value, try to fold together common + // instructions that compute it. + // + VN.getEqualNumberNodes(Inst, EqualValues); + + // If this instruction computes a value that is already computed + // elsewhere, try to recycle the old value. + if (!EqualValues.empty()) { + if (Inst == &*BB->begin()) + I = BB->end(); + else { + I = Inst; --I; + } + + // First check to see if we were able to value number this instruction + // to a non-instruction value. If so, prefer that value over other + // instructions which may compute the same thing. + for (unsigned i = 0, e = EqualValues.size(); i != e; ++i) + if (!isa(EqualValues[i])) { + ++NumNonInsts; // Keep track of # of insts repl with values + + // Change all users of Inst to use the replacement and remove it + // from the program. + ReplaceInstructionWith(Inst, EqualValues[i]); + Inst = 0; + EqualValues.clear(); // don't enter the next loop + break; + } + + // If there were no non-instruction values that this instruction + // produces, find a dominating instruction that produces the same + // value. If we find one, use it's value instead of ours. + for (unsigned i = 0, e = EqualValues.size(); i != e; ++i) { + Instruction *OtherI = cast(EqualValues[i]); + bool Dominates = false; + if (OtherI->getParent() == BB) + Dominates = BlockInsts.count(OtherI); + else + Dominates = DT.dominates(OtherI->getParent(), BB); + + if (Dominates) { + // Okay, we found an instruction with the same value as this one + // and that dominates this one. Replace this instruction with the + // specified one. 
+ ReplaceInstructionWith(Inst, OtherI); + Inst = 0; + break; + } + } + + EqualValues.clear(); + + if (Inst) { + I = Inst; ++I; // Deleted no instructions + } else if (I == BB->end()) { // Deleted first instruction + I = BB->begin(); + } else { // Deleted inst in middle of block. + ++I; + } + } + + if (Inst) + BlockInsts.insert(Inst); + } + } + } + + // When the worklist is empty, return whether or not we changed anything... + return Changed; +} + + +void GCSE::ReplaceInstructionWith(Instruction *I, Value *V) { + if (isa(I)) + ++NumLoadRemoved; // Keep track of loads eliminated + if (isa(I)) + ++NumCallRemoved; // Keep track of calls eliminated + ++NumInstRemoved; // Keep track of number of insts eliminated + + // Update value numbering + getAnalysis().deleteValue(I); + + I->replaceAllUsesWith(V); + + if (InvokeInst *II = dyn_cast(I)) { + // Removing an invoke instruction requires adding a branch to the normal + // destination and removing PHI node entries in the exception destination. + new BranchInst(II->getNormalDest(), II); + II->getUnwindDest()->removePredecessor(II->getParent()); + } + + // Erase the instruction from the program. + I->getParent()->getInstList().erase(I); +} diff --git a/lib/Transforms/Scalar/GVNPRE.cpp b/lib/Transforms/Scalar/GVNPRE.cpp new file mode 100644 index 0000000..e625fc2 --- /dev/null +++ b/lib/Transforms/Scalar/GVNPRE.cpp @@ -0,0 +1,1819 @@ +//===- GVNPRE.cpp - Eliminate redundant values and expressions ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the Owen Anderson and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs a hybrid of global value numbering and partial redundancy +// elimination, known as GVN-PRE. 
It performs partial redundancy elimination on +// values, rather than lexical expressions, allowing a more comprehensive view +// the optimization. It replaces redundant values with uses of earlier +// occurences of the same value. While this is beneficial in that it eliminates +// unneeded computation, it also increases register pressure by creating large +// live ranges, and should be used with caution on platforms that are very +// sensitive to register pressure. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "gvnpre" +#include "llvm/Value.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Instructions.h" +#include "llvm/Function.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include +#include +#include +#include +#include +using namespace llvm; + +//===----------------------------------------------------------------------===// +// ValueTable Class +//===----------------------------------------------------------------------===// + +/// This class holds the mapping between values and value numbers. It is used +/// as an efficient mechanism to determine the expression-wise equivalence of +/// two values. 
+ +namespace { + class VISIBILITY_HIDDEN ValueTable { + public: + struct Expression { + enum ExpressionOpcode { ADD, SUB, MUL, UDIV, SDIV, FDIV, UREM, SREM, + FREM, SHL, LSHR, ASHR, AND, OR, XOR, ICMPEQ, + ICMPNE, ICMPUGT, ICMPUGE, ICMPULT, ICMPULE, + ICMPSGT, ICMPSGE, ICMPSLT, ICMPSLE, FCMPOEQ, + FCMPOGT, FCMPOGE, FCMPOLT, FCMPOLE, FCMPONE, + FCMPORD, FCMPUNO, FCMPUEQ, FCMPUGT, FCMPUGE, + FCMPULT, FCMPULE, FCMPUNE, EXTRACT, INSERT, + SHUFFLE, SELECT, TRUNC, ZEXT, SEXT, FPTOUI, + FPTOSI, UITOFP, SITOFP, FPTRUNC, FPEXT, + PTRTOINT, INTTOPTR, BITCAST, GEP}; + + ExpressionOpcode opcode; + const Type* type; + uint32_t firstVN; + uint32_t secondVN; + uint32_t thirdVN; + std::vector varargs; + + bool operator< (const Expression& other) const { + if (opcode < other.opcode) + return true; + else if (opcode > other.opcode) + return false; + else if (type < other.type) + return true; + else if (type > other.type) + return false; + else if (firstVN < other.firstVN) + return true; + else if (firstVN > other.firstVN) + return false; + else if (secondVN < other.secondVN) + return true; + else if (secondVN > other.secondVN) + return false; + else if (thirdVN < other.thirdVN) + return true; + else if (thirdVN > other.thirdVN) + return false; + else { + if (varargs.size() < other.varargs.size()) + return true; + else if (varargs.size() > other.varargs.size()) + return false; + + for (size_t i = 0; i < varargs.size(); ++i) + if (varargs[i] < other.varargs[i]) + return true; + else if (varargs[i] > other.varargs[i]) + return false; + + return false; + } + } + }; + + private: + DenseMap valueNumbering; + std::map expressionNumbering; + + uint32_t nextValueNumber; + + Expression::ExpressionOpcode getOpcode(BinaryOperator* BO); + Expression::ExpressionOpcode getOpcode(CmpInst* C); + Expression::ExpressionOpcode getOpcode(CastInst* C); + Expression create_expression(BinaryOperator* BO); + Expression create_expression(CmpInst* C); + Expression create_expression(ShuffleVectorInst* V); + 
Expression create_expression(ExtractElementInst* C); + Expression create_expression(InsertElementInst* V); + Expression create_expression(SelectInst* V); + Expression create_expression(CastInst* C); + Expression create_expression(GetElementPtrInst* G); + public: + ValueTable() { nextValueNumber = 1; } + uint32_t lookup_or_add(Value* V); + uint32_t lookup(Value* V) const; + void add(Value* V, uint32_t num); + void clear(); + void erase(Value* v); + unsigned size(); + }; +} + +//===----------------------------------------------------------------------===// +// ValueTable Internal Functions +//===----------------------------------------------------------------------===// +ValueTable::Expression::ExpressionOpcode + ValueTable::getOpcode(BinaryOperator* BO) { + switch(BO->getOpcode()) { + case Instruction::Add: + return Expression::ADD; + case Instruction::Sub: + return Expression::SUB; + case Instruction::Mul: + return Expression::MUL; + case Instruction::UDiv: + return Expression::UDIV; + case Instruction::SDiv: + return Expression::SDIV; + case Instruction::FDiv: + return Expression::FDIV; + case Instruction::URem: + return Expression::UREM; + case Instruction::SRem: + return Expression::SREM; + case Instruction::FRem: + return Expression::FREM; + case Instruction::Shl: + return Expression::SHL; + case Instruction::LShr: + return Expression::LSHR; + case Instruction::AShr: + return Expression::ASHR; + case Instruction::And: + return Expression::AND; + case Instruction::Or: + return Expression::OR; + case Instruction::Xor: + return Expression::XOR; + + // THIS SHOULD NEVER HAPPEN + default: + assert(0 && "Binary operator with unknown opcode?"); + return Expression::ADD; + } +} + +ValueTable::Expression::ExpressionOpcode ValueTable::getOpcode(CmpInst* C) { + if (C->getOpcode() == Instruction::ICmp) { + switch (C->getPredicate()) { + case ICmpInst::ICMP_EQ: + return Expression::ICMPEQ; + case ICmpInst::ICMP_NE: + return Expression::ICMPNE; + case ICmpInst::ICMP_UGT: + 
return Expression::ICMPUGT; + case ICmpInst::ICMP_UGE: + return Expression::ICMPUGE; + case ICmpInst::ICMP_ULT: + return Expression::ICMPULT; + case ICmpInst::ICMP_ULE: + return Expression::ICMPULE; + case ICmpInst::ICMP_SGT: + return Expression::ICMPSGT; + case ICmpInst::ICMP_SGE: + return Expression::ICMPSGE; + case ICmpInst::ICMP_SLT: + return Expression::ICMPSLT; + case ICmpInst::ICMP_SLE: + return Expression::ICMPSLE; + + // THIS SHOULD NEVER HAPPEN + default: + assert(0 && "Comparison with unknown predicate?"); + return Expression::ICMPEQ; + } + } else { + switch (C->getPredicate()) { + case FCmpInst::FCMP_OEQ: + return Expression::FCMPOEQ; + case FCmpInst::FCMP_OGT: + return Expression::FCMPOGT; + case FCmpInst::FCMP_OGE: + return Expression::FCMPOGE; + case FCmpInst::FCMP_OLT: + return Expression::FCMPOLT; + case FCmpInst::FCMP_OLE: + return Expression::FCMPOLE; + case FCmpInst::FCMP_ONE: + return Expression::FCMPONE; + case FCmpInst::FCMP_ORD: + return Expression::FCMPORD; + case FCmpInst::FCMP_UNO: + return Expression::FCMPUNO; + case FCmpInst::FCMP_UEQ: + return Expression::FCMPUEQ; + case FCmpInst::FCMP_UGT: + return Expression::FCMPUGT; + case FCmpInst::FCMP_UGE: + return Expression::FCMPUGE; + case FCmpInst::FCMP_ULT: + return Expression::FCMPULT; + case FCmpInst::FCMP_ULE: + return Expression::FCMPULE; + case FCmpInst::FCMP_UNE: + return Expression::FCMPUNE; + + // THIS SHOULD NEVER HAPPEN + default: + assert(0 && "Comparison with unknown predicate?"); + return Expression::FCMPOEQ; + } + } +} + +ValueTable::Expression::ExpressionOpcode + ValueTable::getOpcode(CastInst* C) { + switch(C->getOpcode()) { + case Instruction::Trunc: + return Expression::TRUNC; + case Instruction::ZExt: + return Expression::ZEXT; + case Instruction::SExt: + return Expression::SEXT; + case Instruction::FPToUI: + return Expression::FPTOUI; + case Instruction::FPToSI: + return Expression::FPTOSI; + case Instruction::UIToFP: + return Expression::UITOFP; + case 
Instruction::SIToFP: + return Expression::SITOFP; + case Instruction::FPTrunc: + return Expression::FPTRUNC; + case Instruction::FPExt: + return Expression::FPEXT; + case Instruction::PtrToInt: + return Expression::PTRTOINT; + case Instruction::IntToPtr: + return Expression::INTTOPTR; + case Instruction::BitCast: + return Expression::BITCAST; + + // THIS SHOULD NEVER HAPPEN + default: + assert(0 && "Cast operator with unknown opcode?"); + return Expression::BITCAST; + } +} + +ValueTable::Expression ValueTable::create_expression(BinaryOperator* BO) { + Expression e; + + e.firstVN = lookup_or_add(BO->getOperand(0)); + e.secondVN = lookup_or_add(BO->getOperand(1)); + e.thirdVN = 0; + e.type = BO->getType(); + e.opcode = getOpcode(BO); + + return e; +} + +ValueTable::Expression ValueTable::create_expression(CmpInst* C) { + Expression e; + + e.firstVN = lookup_or_add(C->getOperand(0)); + e.secondVN = lookup_or_add(C->getOperand(1)); + e.thirdVN = 0; + e.type = C->getType(); + e.opcode = getOpcode(C); + + return e; +} + +ValueTable::Expression ValueTable::create_expression(CastInst* C) { + Expression e; + + e.firstVN = lookup_or_add(C->getOperand(0)); + e.secondVN = 0; + e.thirdVN = 0; + e.type = C->getType(); + e.opcode = getOpcode(C); + + return e; +} + +ValueTable::Expression ValueTable::create_expression(ShuffleVectorInst* S) { + Expression e; + + e.firstVN = lookup_or_add(S->getOperand(0)); + e.secondVN = lookup_or_add(S->getOperand(1)); + e.thirdVN = lookup_or_add(S->getOperand(2)); + e.type = S->getType(); + e.opcode = Expression::SHUFFLE; + + return e; +} + +ValueTable::Expression ValueTable::create_expression(ExtractElementInst* E) { + Expression e; + + e.firstVN = lookup_or_add(E->getOperand(0)); + e.secondVN = lookup_or_add(E->getOperand(1)); + e.thirdVN = 0; + e.type = E->getType(); + e.opcode = Expression::EXTRACT; + + return e; +} + +ValueTable::Expression ValueTable::create_expression(InsertElementInst* I) { + Expression e; + + e.firstVN = 
lookup_or_add(I->getOperand(0)); + e.secondVN = lookup_or_add(I->getOperand(1)); + e.thirdVN = lookup_or_add(I->getOperand(2)); + e.type = I->getType(); + e.opcode = Expression::INSERT; + + return e; +} + +ValueTable::Expression ValueTable::create_expression(SelectInst* I) { + Expression e; + + e.firstVN = lookup_or_add(I->getCondition()); + e.secondVN = lookup_or_add(I->getTrueValue()); + e.thirdVN = lookup_or_add(I->getFalseValue()); + e.type = I->getType(); + e.opcode = Expression::SELECT; + + return e; +} + +ValueTable::Expression ValueTable::create_expression(GetElementPtrInst* G) { + Expression e; + + e.firstVN = lookup_or_add(G->getPointerOperand()); + e.secondVN = 0; + e.thirdVN = 0; + e.type = G->getType(); + e.opcode = Expression::SELECT; + + for (GetElementPtrInst::op_iterator I = G->idx_begin(), E = G->idx_end(); + I != E; ++I) + e.varargs.push_back(lookup_or_add(*I)); + + return e; +} + +//===----------------------------------------------------------------------===// +// ValueTable External Functions +//===----------------------------------------------------------------------===// + +/// lookup_or_add - Returns the value number for the specified value, assigning +/// it a new number if it did not have one before. 
+uint32_t ValueTable::lookup_or_add(Value* V) { + DenseMap::iterator VI = valueNumbering.find(V); + if (VI != valueNumbering.end()) + return VI->second; + + + if (BinaryOperator* BO = dyn_cast(V)) { + Expression e = create_expression(BO); + + std::map::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (CmpInst* C = dyn_cast(V)) { + Expression e = create_expression(C); + + std::map::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (ShuffleVectorInst* U = dyn_cast(V)) { + Expression e = create_expression(U); + + std::map::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (ExtractElementInst* U = dyn_cast(V)) { + Expression e = create_expression(U); + + std::map::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (InsertElementInst* U = dyn_cast(V)) { + Expression e = create_expression(U); + + 
std::map::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (SelectInst* U = dyn_cast(V)) { + Expression e = create_expression(U); + + std::map::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (CastInst* U = dyn_cast(V)) { + Expression e = create_expression(U); + + std::map::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (GetElementPtrInst* U = dyn_cast(V)) { + Expression e = create_expression(U); + + std::map::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else { + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + return nextValueNumber++; + } +} + +/// lookup - Returns the value number of the specified value. Fails if +/// the value has not yet been numbered. 
+uint32_t ValueTable::lookup(Value* V) const { + DenseMap::iterator VI = valueNumbering.find(V); + if (VI != valueNumbering.end()) + return VI->second; + else + assert(0 && "Value not numbered?"); + + return 0; +} + +/// add - Add the specified value with the given value number, removing +/// its old number, if any +void ValueTable::add(Value* V, uint32_t num) { + DenseMap::iterator VI = valueNumbering.find(V); + if (VI != valueNumbering.end()) + valueNumbering.erase(VI); + valueNumbering.insert(std::make_pair(V, num)); +} + +/// clear - Remove all entries from the ValueTable +void ValueTable::clear() { + valueNumbering.clear(); + expressionNumbering.clear(); + nextValueNumber = 1; +} + +/// erase - Remove a value from the value numbering +void ValueTable::erase(Value* V) { + valueNumbering.erase(V); +} + +/// size - Return the number of assigned value numbers +unsigned ValueTable::size() { + // NOTE: zero is never assigned + return nextValueNumber; +} + +//===----------------------------------------------------------------------===// +// ValueNumberedSet Class +//===----------------------------------------------------------------------===// + +class ValueNumberedSet { + private: + SmallPtrSet contents; + BitVector numbers; + public: + ValueNumberedSet() { numbers.resize(1); } + ValueNumberedSet(const ValueNumberedSet& other) { + numbers = other.numbers; + contents = other.contents; + } + + typedef SmallPtrSet::iterator iterator; + + iterator begin() { return contents.begin(); } + iterator end() { return contents.end(); } + + bool insert(Value* v) { return contents.insert(v); } + void insert(iterator I, iterator E) { contents.insert(I, E); } + void erase(Value* v) { contents.erase(v); } + unsigned count(Value* v) { return contents.count(v); } + size_t size() { return contents.size(); } + + void set(unsigned i) { + if (i >= numbers.size()) + numbers.resize(i+1); + + numbers.set(i); + } + + void operator=(const ValueNumberedSet& other) { + contents = other.contents; 
+ numbers = other.numbers; + } + + void reset(unsigned i) { + if (i < numbers.size()) + numbers.reset(i); + } + + bool test(unsigned i) { + if (i >= numbers.size()) + return false; + + return numbers.test(i); + } + + void clear() { + contents.clear(); + numbers.clear(); + } +}; + +//===----------------------------------------------------------------------===// +// GVNPRE Pass +//===----------------------------------------------------------------------===// + +namespace { + + class VISIBILITY_HIDDEN GVNPRE : public FunctionPass { + bool runOnFunction(Function &F); + public: + static char ID; // Pass identification, replacement for typeid + GVNPRE() : FunctionPass((intptr_t)&ID) { } + + private: + ValueTable VN; + std::vector createdExpressions; + + DenseMap availableOut; + DenseMap anticipatedIn; + DenseMap generatedPhis; + + // This transformation requires dominator postdominator info + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequiredID(BreakCriticalEdgesID); + AU.addRequired(); + AU.addRequired(); + } + + // Helper fuctions + // FIXME: eliminate or document these better + void dump(ValueNumberedSet& s) const ; + void clean(ValueNumberedSet& set) ; + Value* find_leader(ValueNumberedSet& vals, uint32_t v) ; + Value* phi_translate(Value* V, BasicBlock* pred, BasicBlock* succ) ; + void phi_translate_set(ValueNumberedSet& anticIn, BasicBlock* pred, + BasicBlock* succ, ValueNumberedSet& out) ; + + void topo_sort(ValueNumberedSet& set, + std::vector& vec) ; + + void cleanup() ; + bool elimination() ; + + void val_insert(ValueNumberedSet& s, Value* v) ; + void val_replace(ValueNumberedSet& s, Value* v) ; + bool dependsOnInvoke(Value* V) ; + void buildsets_availout(BasicBlock::iterator I, + ValueNumberedSet& currAvail, + ValueNumberedSet& currPhis, + ValueNumberedSet& currExps, + SmallPtrSet& currTemps) ; + bool buildsets_anticout(BasicBlock* BB, + ValueNumberedSet& anticOut, + std::set& visited) ; + unsigned 
buildsets_anticin(BasicBlock* BB, + ValueNumberedSet& anticOut, + ValueNumberedSet& currExps, + SmallPtrSet& currTemps, + std::set& visited) ; + void buildsets(Function& F) ; + + void insertion_pre(Value* e, BasicBlock* BB, + std::map& avail, + std::map& new_set) ; + unsigned insertion_mergepoint(std::vector& workList, + df_iterator& D, + std::map& new_set) ; + bool insertion(Function& F) ; + + }; + + char GVNPRE::ID = 0; + +} + +// createGVNPREPass - The public interface to this file... +FunctionPass *llvm::createGVNPREPass() { return new GVNPRE(); } + +static RegisterPass X("gvnpre", + "Global Value Numbering/Partial Redundancy Elimination"); + + +STATISTIC(NumInsertedVals, "Number of values inserted"); +STATISTIC(NumInsertedPhis, "Number of PHI nodes inserted"); +STATISTIC(NumEliminated, "Number of redundant instructions eliminated"); + +/// find_leader - Given a set and a value number, return the first +/// element of the set with that value number, or 0 if no such element +/// is present +Value* GVNPRE::find_leader(ValueNumberedSet& vals, uint32_t v) { + if (!vals.test(v)) + return 0; + + for (ValueNumberedSet::iterator I = vals.begin(), E = vals.end(); + I != E; ++I) + if (v == VN.lookup(*I)) + return *I; + + assert(0 && "No leader found, but present bit is set?"); + return 0; +} + +/// val_insert - Insert a value into a set only if there is not a value +/// with the same value number already in the set +void GVNPRE::val_insert(ValueNumberedSet& s, Value* v) { + uint32_t num = VN.lookup(v); + if (!s.test(num)) + s.insert(v); +} + +/// val_replace - Insert a value into a set, replacing any values already in +/// the set that have the same value number +void GVNPRE::val_replace(ValueNumberedSet& s, Value* v) { + uint32_t num = VN.lookup(v); + Value* leader = find_leader(s, num); + if (leader != 0) + s.erase(leader); + s.insert(v); + s.set(num); +} + +/// phi_translate - Given a value, its parent block, and a predecessor of its +/// parent, translate the value 
into legal for the predecessor block. This +/// means translating its operands (and recursively, their operands) through +/// any phi nodes in the parent into values available in the predecessor +Value* GVNPRE::phi_translate(Value* V, BasicBlock* pred, BasicBlock* succ) { + if (V == 0) + return 0; + + // Unary Operations + if (CastInst* U = dyn_cast(V)) { + Value* newOp1 = 0; + if (isa(U->getOperand(0))) + newOp1 = phi_translate(U->getOperand(0), pred, succ); + else + newOp1 = U->getOperand(0); + + if (newOp1 == 0) + return 0; + + if (newOp1 != U->getOperand(0)) { + Instruction* newVal = 0; + if (CastInst* C = dyn_cast(U)) + newVal = CastInst::create(C->getOpcode(), + newOp1, C->getType(), + C->getName()+".expr"); + + uint32_t v = VN.lookup_or_add(newVal); + + Value* leader = find_leader(availableOut[pred], v); + if (leader == 0) { + createdExpressions.push_back(newVal); + return newVal; + } else { + VN.erase(newVal); + delete newVal; + return leader; + } + } + + // Binary Operations + } if (isa(V) || isa(V) || + isa(V)) { + User* U = cast(V); + + Value* newOp1 = 0; + if (isa(U->getOperand(0))) + newOp1 = phi_translate(U->getOperand(0), pred, succ); + else + newOp1 = U->getOperand(0); + + if (newOp1 == 0) + return 0; + + Value* newOp2 = 0; + if (isa(U->getOperand(1))) + newOp2 = phi_translate(U->getOperand(1), pred, succ); + else + newOp2 = U->getOperand(1); + + if (newOp2 == 0) + return 0; + + if (newOp1 != U->getOperand(0) || newOp2 != U->getOperand(1)) { + Instruction* newVal = 0; + if (BinaryOperator* BO = dyn_cast(U)) + newVal = BinaryOperator::create(BO->getOpcode(), + newOp1, newOp2, + BO->getName()+".expr"); + else if (CmpInst* C = dyn_cast(U)) + newVal = CmpInst::create(C->getOpcode(), + C->getPredicate(), + newOp1, newOp2, + C->getName()+".expr"); + else if (ExtractElementInst* E = dyn_cast(U)) + newVal = new ExtractElementInst(newOp1, newOp2, E->getName()+".expr"); + + uint32_t v = VN.lookup_or_add(newVal); + + Value* leader = 
find_leader(availableOut[pred], v); + if (leader == 0) { + createdExpressions.push_back(newVal); + return newVal; + } else { + VN.erase(newVal); + delete newVal; + return leader; + } + } + + // Ternary Operations + } else if (isa(V) || isa(V) || + isa(V)) { + User* U = cast(V); + + Value* newOp1 = 0; + if (isa(U->getOperand(0))) + newOp1 = phi_translate(U->getOperand(0), pred, succ); + else + newOp1 = U->getOperand(0); + + if (newOp1 == 0) + return 0; + + Value* newOp2 = 0; + if (isa(U->getOperand(1))) + newOp2 = phi_translate(U->getOperand(1), pred, succ); + else + newOp2 = U->getOperand(1); + + if (newOp2 == 0) + return 0; + + Value* newOp3 = 0; + if (isa(U->getOperand(2))) + newOp3 = phi_translate(U->getOperand(2), pred, succ); + else + newOp3 = U->getOperand(2); + + if (newOp3 == 0) + return 0; + + if (newOp1 != U->getOperand(0) || + newOp2 != U->getOperand(1) || + newOp3 != U->getOperand(2)) { + Instruction* newVal = 0; + if (ShuffleVectorInst* S = dyn_cast(U)) + newVal = new ShuffleVectorInst(newOp1, newOp2, newOp3, + S->getName()+".expr"); + else if (InsertElementInst* I = dyn_cast(U)) + newVal = new InsertElementInst(newOp1, newOp2, newOp3, + I->getName()+".expr"); + else if (SelectInst* I = dyn_cast(U)) + newVal = new SelectInst(newOp1, newOp2, newOp3, I->getName()+".expr"); + + uint32_t v = VN.lookup_or_add(newVal); + + Value* leader = find_leader(availableOut[pred], v); + if (leader == 0) { + createdExpressions.push_back(newVal); + return newVal; + } else { + VN.erase(newVal); + delete newVal; + return leader; + } + } + + // Varargs operators + } else if (GetElementPtrInst* U = dyn_cast(V)) { + Value* newOp1 = 0; + if (isa(U->getPointerOperand())) + newOp1 = phi_translate(U->getPointerOperand(), pred, succ); + else + newOp1 = U->getPointerOperand(); + + if (newOp1 == 0) + return 0; + + bool changed_idx = false; + std::vector newIdx; + for (GetElementPtrInst::op_iterator I = U->idx_begin(), E = U->idx_end(); + I != E; ++I) + if (isa(*I)) { + Value* newVal 
= phi_translate(*I, pred, succ); + newIdx.push_back(newVal); + if (newVal != *I) + changed_idx = true; + } else { + newIdx.push_back(*I); + } + + if (newOp1 != U->getPointerOperand() || changed_idx) { + Instruction* newVal = new GetElementPtrInst(newOp1, + &newIdx[0], newIdx.size(), + U->getName()+".expr"); + + uint32_t v = VN.lookup_or_add(newVal); + + Value* leader = find_leader(availableOut[pred], v); + if (leader == 0) { + createdExpressions.push_back(newVal); + return newVal; + } else { + VN.erase(newVal); + delete newVal; + return leader; + } + } + + // PHI Nodes + } else if (PHINode* P = dyn_cast(V)) { + if (P->getParent() == succ) + return P->getIncomingValueForBlock(pred); + } + + return V; +} + +/// phi_translate_set - Perform phi translation on every element of a set +void GVNPRE::phi_translate_set(ValueNumberedSet& anticIn, + BasicBlock* pred, BasicBlock* succ, + ValueNumberedSet& out) { + for (ValueNumberedSet::iterator I = anticIn.begin(), + E = anticIn.end(); I != E; ++I) { + Value* V = phi_translate(*I, pred, succ); + if (V != 0 && !out.test(VN.lookup_or_add(V))) { + out.insert(V); + out.set(VN.lookup(V)); + } + } +} + +/// dependsOnInvoke - Test if a value has an phi node as an operand, any of +/// whose inputs is an invoke instruction. If this is true, we cannot safely +/// PRE the instruction or anything that depends on it. 
+bool GVNPRE::dependsOnInvoke(Value* V) { + if (PHINode* p = dyn_cast(V)) { + for (PHINode::op_iterator I = p->op_begin(), E = p->op_end(); I != E; ++I) + if (isa(*I)) + return true; + return false; + } else { + return false; + } +} + +/// clean - Remove all non-opaque values from the set whose operands are not +/// themselves in the set, as well as all values that depend on invokes (see +/// above) +void GVNPRE::clean(ValueNumberedSet& set) { + std::vector worklist; + worklist.reserve(set.size()); + topo_sort(set, worklist); + + for (unsigned i = 0; i < worklist.size(); ++i) { + Value* v = worklist[i]; + + // Handle unary ops + if (CastInst* U = dyn_cast(v)) { + bool lhsValid = !isa(U->getOperand(0)); + lhsValid |= set.test(VN.lookup(U->getOperand(0))); + if (lhsValid) + lhsValid = !dependsOnInvoke(U->getOperand(0)); + + if (!lhsValid) { + set.erase(U); + set.reset(VN.lookup(U)); + } + + // Handle binary ops + } else if (isa(v) || isa(v) || + isa(v)) { + User* U = cast(v); + + bool lhsValid = !isa(U->getOperand(0)); + lhsValid |= set.test(VN.lookup(U->getOperand(0))); + if (lhsValid) + lhsValid = !dependsOnInvoke(U->getOperand(0)); + + bool rhsValid = !isa(U->getOperand(1)); + rhsValid |= set.test(VN.lookup(U->getOperand(1))); + if (rhsValid) + rhsValid = !dependsOnInvoke(U->getOperand(1)); + + if (!lhsValid || !rhsValid) { + set.erase(U); + set.reset(VN.lookup(U)); + } + + // Handle ternary ops + } else if (isa(v) || isa(v) || + isa(v)) { + User* U = cast(v); + + bool lhsValid = !isa(U->getOperand(0)); + lhsValid |= set.test(VN.lookup(U->getOperand(0))); + if (lhsValid) + lhsValid = !dependsOnInvoke(U->getOperand(0)); + + bool rhsValid = !isa(U->getOperand(1)); + rhsValid |= set.test(VN.lookup(U->getOperand(1))); + if (rhsValid) + rhsValid = !dependsOnInvoke(U->getOperand(1)); + + bool thirdValid = !isa(U->getOperand(2)); + thirdValid |= set.test(VN.lookup(U->getOperand(2))); + if (thirdValid) + thirdValid = !dependsOnInvoke(U->getOperand(2)); + + if (!lhsValid 
|| !rhsValid || !thirdValid) { + set.erase(U); + set.reset(VN.lookup(U)); + } + + // Handle varargs ops + } else if (GetElementPtrInst* U = dyn_cast(v)) { + bool ptrValid = !isa(U->getPointerOperand()); + ptrValid |= set.test(VN.lookup(U->getPointerOperand())); + if (ptrValid) + ptrValid = !dependsOnInvoke(U->getPointerOperand()); + + bool varValid = true; + for (GetElementPtrInst::op_iterator I = U->idx_begin(), E = U->idx_end(); + I != E; ++I) + if (varValid) { + varValid &= !isa(*I) || set.test(VN.lookup(*I)); + varValid &= !dependsOnInvoke(*I); + } + + if (!ptrValid || !varValid) { + set.erase(U); + set.reset(VN.lookup(U)); + } + } + } +} + +/// topo_sort - Given a set of values, sort them by topological +/// order into the provided vector. +void GVNPRE::topo_sort(ValueNumberedSet& set, std::vector& vec) { + SmallPtrSet visited; + std::vector stack; + for (ValueNumberedSet::iterator I = set.begin(), E = set.end(); + I != E; ++I) { + if (visited.count(*I) == 0) + stack.push_back(*I); + + while (!stack.empty()) { + Value* e = stack.back(); + + // Handle unary ops + if (CastInst* U = dyn_cast(e)) { + Value* l = find_leader(set, VN.lookup(U->getOperand(0))); + + if (l != 0 && isa(l) && + visited.count(l) == 0) + stack.push_back(l); + else { + vec.push_back(e); + visited.insert(e); + stack.pop_back(); + } + + // Handle binary ops + } else if (isa(e) || isa(e) || + isa(e)) { + User* U = cast(e); + Value* l = find_leader(set, VN.lookup(U->getOperand(0))); + Value* r = find_leader(set, VN.lookup(U->getOperand(1))); + + if (l != 0 && isa(l) && + visited.count(l) == 0) + stack.push_back(l); + else if (r != 0 && isa(r) && + visited.count(r) == 0) + stack.push_back(r); + else { + vec.push_back(e); + visited.insert(e); + stack.pop_back(); + } + + // Handle ternary ops + } else if (isa(e) || isa(e) || + isa(e)) { + User* U = cast(e); + Value* l = find_leader(set, VN.lookup(U->getOperand(0))); + Value* r = find_leader(set, VN.lookup(U->getOperand(1))); + Value* m = 
find_leader(set, VN.lookup(U->getOperand(2))); + + if (l != 0 && isa(l) && + visited.count(l) == 0) + stack.push_back(l); + else if (r != 0 && isa(r) && + visited.count(r) == 0) + stack.push_back(r); + else if (m != 0 && isa(m) && + visited.count(m) == 0) + stack.push_back(m); + else { + vec.push_back(e); + visited.insert(e); + stack.pop_back(); + } + + // Handle vararg ops + } else if (GetElementPtrInst* U = dyn_cast(e)) { + Value* p = find_leader(set, VN.lookup(U->getPointerOperand())); + + if (p != 0 && isa(p) && + visited.count(p) == 0) + stack.push_back(p); + else { + bool push_va = false; + for (GetElementPtrInst::op_iterator I = U->idx_begin(), + E = U->idx_end(); I != E; ++I) { + Value * v = find_leader(set, VN.lookup(*I)); + if (v != 0 && isa(v) && visited.count(v) == 0) { + stack.push_back(v); + push_va = true; + } + } + + if (!push_va) { + vec.push_back(e); + visited.insert(e); + stack.pop_back(); + } + } + + // Handle opaque ops + } else { + visited.insert(e); + vec.push_back(e); + stack.pop_back(); + } + } + + stack.clear(); + } +} + +/// dump - Dump a set of values to standard error +void GVNPRE::dump(ValueNumberedSet& s) const { + DOUT << "{ "; + for (ValueNumberedSet::iterator I = s.begin(), E = s.end(); + I != E; ++I) { + DOUT << "" << VN.lookup(*I) << ": "; + DEBUG((*I)->dump()); + } + DOUT << "}\n\n"; +} + +/// elimination - Phase 3 of the main algorithm. Perform full redundancy +/// elimination by walking the dominator tree and removing any instruction that +/// is dominated by another instruction with the same value number. 
+bool GVNPRE::elimination() { + bool changed_function = false; + + std::vector > replace; + std::vector erase; + + DominatorTree& DT = getAnalysis(); + + for (df_iterator DI = df_begin(DT.getRootNode()), + E = df_end(DT.getRootNode()); DI != E; ++DI) { + BasicBlock* BB = DI->getBlock(); + + for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); + BI != BE; ++BI) { + + if (isa(BI) || isa(BI) || + isa(BI) || isa(BI) || + isa(BI) || isa(BI) || + isa(BI) || isa(BI)) { + + if (availableOut[BB].test(VN.lookup(BI)) && !availableOut[BB].count(BI)) { + Value *leader = find_leader(availableOut[BB], VN.lookup(BI)); + if (Instruction* Instr = dyn_cast(leader)) + if (Instr->getParent() != 0 && Instr != BI) { + replace.push_back(std::make_pair(BI, leader)); + erase.push_back(BI); + ++NumEliminated; + } + } + } + } + } + + while (!replace.empty()) { + std::pair rep = replace.back(); + replace.pop_back(); + rep.first->replaceAllUsesWith(rep.second); + changed_function = true; + } + + for (std::vector::iterator I = erase.begin(), E = erase.end(); + I != E; ++I) + (*I)->eraseFromParent(); + + return changed_function; +} + +/// cleanup - Delete any extraneous values that were created to represent +/// expressions without leaders. 
+void GVNPRE::cleanup() { + while (!createdExpressions.empty()) { + Instruction* I = createdExpressions.back(); + createdExpressions.pop_back(); + + delete I; + } +} + +/// buildsets_availout - When calculating availability, handle an instruction +/// by inserting it into the appropriate sets +void GVNPRE::buildsets_availout(BasicBlock::iterator I, + ValueNumberedSet& currAvail, + ValueNumberedSet& currPhis, + ValueNumberedSet& currExps, + SmallPtrSet& currTemps) { + // Handle PHI nodes + if (PHINode* p = dyn_cast(I)) { + unsigned num = VN.lookup_or_add(p); + + currPhis.insert(p); + currPhis.set(num); + + // Handle unary ops + } else if (CastInst* U = dyn_cast(I)) { + Value* leftValue = U->getOperand(0); + + unsigned num = VN.lookup_or_add(U); + + if (isa(leftValue)) + if (!currExps.test(VN.lookup(leftValue))) { + currExps.insert(leftValue); + currExps.set(VN.lookup(leftValue)); + } + + if (!currExps.test(num)) { + currExps.insert(U); + currExps.set(num); + } + + // Handle binary ops + } else if (isa(I) || isa(I) || + isa(I)) { + User* U = cast(I); + Value* leftValue = U->getOperand(0); + Value* rightValue = U->getOperand(1); + + unsigned num = VN.lookup_or_add(U); + + if (isa(leftValue)) + if (!currExps.test(VN.lookup(leftValue))) { + currExps.insert(leftValue); + currExps.set(VN.lookup(leftValue)); + } + + if (isa(rightValue)) + if (!currExps.test(VN.lookup(rightValue))) { + currExps.insert(rightValue); + currExps.set(VN.lookup(rightValue)); + } + + if (!currExps.test(num)) { + currExps.insert(U); + currExps.set(num); + } + + // Handle ternary ops + } else if (isa(I) || isa(I) || + isa(I)) { + User* U = cast(I); + Value* leftValue = U->getOperand(0); + Value* rightValue = U->getOperand(1); + Value* thirdValue = U->getOperand(2); + + VN.lookup_or_add(U); + + unsigned num = VN.lookup_or_add(U); + + if (isa(leftValue)) + if (!currExps.test(VN.lookup(leftValue))) { + currExps.insert(leftValue); + currExps.set(VN.lookup(leftValue)); + } + if (isa(rightValue)) + if 
(!currExps.test(VN.lookup(rightValue))) { + currExps.insert(rightValue); + currExps.set(VN.lookup(rightValue)); + } + if (isa(thirdValue)) + if (!currExps.test(VN.lookup(thirdValue))) { + currExps.insert(thirdValue); + currExps.set(VN.lookup(thirdValue)); + } + + if (!currExps.test(num)) { + currExps.insert(U); + currExps.set(num); + } + + // Handle vararg ops + } else if (GetElementPtrInst* U = dyn_cast(I)) { + Value* ptrValue = U->getPointerOperand(); + + VN.lookup_or_add(U); + + unsigned num = VN.lookup_or_add(U); + + if (isa(ptrValue)) + if (!currExps.test(VN.lookup(ptrValue))) { + currExps.insert(ptrValue); + currExps.set(VN.lookup(ptrValue)); + } + + for (GetElementPtrInst::op_iterator OI = U->idx_begin(), OE = U->idx_end(); + OI != OE; ++OI) + if (isa(*OI) && !currExps.test(VN.lookup(*OI))) { + currExps.insert(*OI); + currExps.set(VN.lookup(*OI)); + } + + if (!currExps.test(VN.lookup(U))) { + currExps.insert(U); + currExps.set(num); + } + + // Handle opaque ops + } else if (!I->isTerminator()){ + VN.lookup_or_add(I); + + currTemps.insert(I); + } + + if (!I->isTerminator()) + if (!currAvail.test(VN.lookup(I))) { + currAvail.insert(I); + currAvail.set(VN.lookup(I)); + } +} + +/// buildsets_anticout - When walking the postdom tree, calculate the ANTIC_OUT +/// set as a function of the ANTIC_IN set of the block's predecessors +bool GVNPRE::buildsets_anticout(BasicBlock* BB, + ValueNumberedSet& anticOut, + std::set& visited) { + if (BB->getTerminator()->getNumSuccessors() == 1) { + if (BB->getTerminator()->getSuccessor(0) != BB && + visited.count(BB->getTerminator()->getSuccessor(0)) == 0) { + return true; + } + else { + phi_translate_set(anticipatedIn[BB->getTerminator()->getSuccessor(0)], + BB, BB->getTerminator()->getSuccessor(0), anticOut); + } + } else if (BB->getTerminator()->getNumSuccessors() > 1) { + BasicBlock* first = BB->getTerminator()->getSuccessor(0); + for (ValueNumberedSet::iterator I = anticipatedIn[first].begin(), + E = 
anticipatedIn[first].end(); I != E; ++I) { + anticOut.insert(*I); + anticOut.set(VN.lookup(*I)); + } + + for (unsigned i = 1; i < BB->getTerminator()->getNumSuccessors(); ++i) { + BasicBlock* currSucc = BB->getTerminator()->getSuccessor(i); + ValueNumberedSet& succAnticIn = anticipatedIn[currSucc]; + + std::vector temp; + + for (ValueNumberedSet::iterator I = anticOut.begin(), + E = anticOut.end(); I != E; ++I) + if (!succAnticIn.test(VN.lookup(*I))) + temp.push_back(*I); + + for (std::vector::iterator I = temp.begin(), E = temp.end(); + I != E; ++I) { + anticOut.erase(*I); + anticOut.reset(VN.lookup(*I)); + } + } + } + + return false; +} + +/// buildsets_anticin - Walk the postdom tree, calculating ANTIC_OUT for +/// each block. ANTIC_IN is then a function of ANTIC_OUT and the GEN +/// sets populated in buildsets_availout +unsigned GVNPRE::buildsets_anticin(BasicBlock* BB, + ValueNumberedSet& anticOut, + ValueNumberedSet& currExps, + SmallPtrSet& currTemps, + std::set& visited) { + ValueNumberedSet& anticIn = anticipatedIn[BB]; + unsigned old = anticIn.size(); + + bool defer = buildsets_anticout(BB, anticOut, visited); + if (defer) + return 0; + + anticIn.clear(); + + for (ValueNumberedSet::iterator I = anticOut.begin(), + E = anticOut.end(); I != E; ++I) { + anticIn.insert(*I); + anticIn.set(VN.lookup(*I)); + } + for (ValueNumberedSet::iterator I = currExps.begin(), + E = currExps.end(); I != E; ++I) { + if (!anticIn.test(VN.lookup(*I))) { + anticIn.insert(*I); + anticIn.set(VN.lookup(*I)); + } + } + + for (SmallPtrSet::iterator I = currTemps.begin(), + E = currTemps.end(); I != E; ++I) { + anticIn.erase(*I); + anticIn.reset(VN.lookup(*I)); + } + + clean(anticIn); + anticOut.clear(); + + if (old != anticIn.size()) + return 2; + else + return 1; +} + +/// buildsets - Phase 1 of the main algorithm. Construct the AVAIL_OUT +/// and the ANTIC_IN sets. 
+void GVNPRE::buildsets(Function& F) { + std::map generatedExpressions; + std::map > generatedTemporaries; + + DominatorTree &DT = getAnalysis(); + + // Phase 1, Part 1: calculate AVAIL_OUT + + // Top-down walk of the dominator tree + for (df_iterator DI = df_begin(DT.getRootNode()), + E = df_end(DT.getRootNode()); DI != E; ++DI) { + + // Get the sets to update for this block + ValueNumberedSet& currExps = generatedExpressions[DI->getBlock()]; + ValueNumberedSet& currPhis = generatedPhis[DI->getBlock()]; + SmallPtrSet& currTemps = generatedTemporaries[DI->getBlock()]; + ValueNumberedSet& currAvail = availableOut[DI->getBlock()]; + + BasicBlock* BB = DI->getBlock(); + + // A block inherits AVAIL_OUT from its dominator + if (DI->getIDom() != 0) + currAvail = availableOut[DI->getIDom()->getBlock()]; + + for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); + BI != BE; ++BI) + buildsets_availout(BI, currAvail, currPhis, currExps, + currTemps); + + } + + // Phase 1, Part 2: calculate ANTIC_IN + + std::set visited; + SmallPtrSet block_changed; + for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE; ++FI) + block_changed.insert(FI); + + bool changed = true; + unsigned iterations = 0; + + while (changed) { + changed = false; + ValueNumberedSet anticOut; + + // Postorder walk of the CFG + for (po_iterator BBI = po_begin(&F.getEntryBlock()), + BBE = po_end(&F.getEntryBlock()); BBI != BBE; ++BBI) { + BasicBlock* BB = *BBI; + + if (block_changed.count(BB) != 0) { + unsigned ret = buildsets_anticin(BB, anticOut,generatedExpressions[BB], + generatedTemporaries[BB], visited); + + if (ret == 0) { + changed = true; + continue; + } else { + visited.insert(BB); + + if (ret == 2) + for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); + PI != PE; ++PI) { + block_changed.insert(*PI); + } + else + block_changed.erase(BB); + + changed |= (ret == 2); + } + } + } + + iterations++; + } +} + +/// insertion_pre - When a partial redundancy has been identified, eliminate it 
+/// by inserting appropriate values into the predecessors and a phi node in +/// the main block +void GVNPRE::insertion_pre(Value* e, BasicBlock* BB, + std::map& avail, + std::map& new_sets) { + for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); PI != PE; ++PI) { + Value* e2 = avail[*PI]; + if (!availableOut[*PI].test(VN.lookup(e2))) { + User* U = cast(e2); + + Value* s1 = 0; + if (isa(U->getOperand(0)) || + isa(U->getOperand(0)) || + isa(U->getOperand(0)) || + isa(U->getOperand(0)) || + isa(U->getOperand(0)) || + isa(U->getOperand(0)) || + isa(U->getOperand(0)) || + isa(U->getOperand(0))) + s1 = find_leader(availableOut[*PI], VN.lookup(U->getOperand(0))); + else + s1 = U->getOperand(0); + + Value* s2 = 0; + + if (isa(U) || + isa(U) || + isa(U) || + isa(U) || + isa(U) || + isa(U)) + if (isa(U->getOperand(1)) || + isa(U->getOperand(1)) || + isa(U->getOperand(1)) || + isa(U->getOperand(1)) || + isa(U->getOperand(1)) || + isa(U->getOperand(1)) || + isa(U->getOperand(1)) || + isa(U->getOperand(1))) { + s2 = find_leader(availableOut[*PI], VN.lookup(U->getOperand(1))); + } else { + s2 = U->getOperand(1); + } + + // Ternary Operators + Value* s3 = 0; + if (isa(U) || + isa(U) || + isa(U)) + if (isa(U->getOperand(2)) || + isa(U->getOperand(2)) || + isa(U->getOperand(2)) || + isa(U->getOperand(2)) || + isa(U->getOperand(2)) || + isa(U->getOperand(2)) || + isa(U->getOperand(2)) || + isa(U->getOperand(2))) { + s3 = find_leader(availableOut[*PI], VN.lookup(U->getOperand(2))); + } else { + s3 = U->getOperand(2); + } + + // Vararg operators + std::vector sVarargs; + if (GetElementPtrInst* G = dyn_cast(U)) { + for (GetElementPtrInst::op_iterator OI = G->idx_begin(), + OE = G->idx_end(); OI != OE; ++OI) { + if (isa(*OI) || + isa(*OI) || + isa(*OI) || + isa(*OI) || + isa(*OI) || + isa(*OI) || + isa(*OI) || + isa(*OI)) { + sVarargs.push_back(find_leader(availableOut[*PI], + VN.lookup(*OI))); + } else { + sVarargs.push_back(*OI); + } + } + } + + Value* newVal = 0; + if 
(BinaryOperator* BO = dyn_cast(U)) + newVal = BinaryOperator::create(BO->getOpcode(), s1, s2, + BO->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (CmpInst* C = dyn_cast(U)) + newVal = CmpInst::create(C->getOpcode(), C->getPredicate(), s1, s2, + C->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (ShuffleVectorInst* S = dyn_cast(U)) + newVal = new ShuffleVectorInst(s1, s2, s3, S->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (InsertElementInst* S = dyn_cast(U)) + newVal = new InsertElementInst(s1, s2, s3, S->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (ExtractElementInst* S = dyn_cast(U)) + newVal = new ExtractElementInst(s1, s2, S->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (SelectInst* S = dyn_cast(U)) + newVal = new SelectInst(s1, s2, s3, S->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (CastInst* C = dyn_cast(U)) + newVal = CastInst::create(C->getOpcode(), s1, C->getType(), + C->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (GetElementPtrInst* G = dyn_cast(U)) + newVal = new GetElementPtrInst(s1, &sVarargs[0], sVarargs.size(), + G->getName()+".gvnpre", + (*PI)->getTerminator()); + + + VN.add(newVal, VN.lookup(U)); + + ValueNumberedSet& predAvail = availableOut[*PI]; + val_replace(predAvail, newVal); + val_replace(new_sets[*PI], newVal); + predAvail.set(VN.lookup(newVal)); + + std::map::iterator av = avail.find(*PI); + if (av != avail.end()) + avail.erase(av); + avail.insert(std::make_pair(*PI, newVal)); + + ++NumInsertedVals; + } + } + + PHINode* p = 0; + + for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); PI != PE; ++PI) { + if (p == 0) + p = new PHINode(avail[*PI]->getType(), "gvnpre-join", BB->begin()); + + p->addIncoming(avail[*PI], *PI); + } + + VN.add(p, VN.lookup(e)); + val_replace(availableOut[BB], p); + availableOut[BB].set(VN.lookup(e)); + generatedPhis[BB].insert(p); + generatedPhis[BB].set(VN.lookup(e)); + new_sets[BB].insert(p); + 
new_sets[BB].set(VN.lookup(e)); + + ++NumInsertedPhis; +} + +/// insertion_mergepoint - When walking the dom tree, check at each merge +/// block for the possibility of a partial redundancy. If present, eliminate it +unsigned GVNPRE::insertion_mergepoint(std::vector& workList, + df_iterator& D, + std::map& new_sets) { + bool changed_function = false; + bool new_stuff = false; + + BasicBlock* BB = D->getBlock(); + for (unsigned i = 0; i < workList.size(); ++i) { + Value* e = workList[i]; + + if (isa(e) || isa(e) || + isa(e) || isa(e) || + isa(e) || isa(e) || isa(e) || + isa(e)) { + if (availableOut[D->getIDom()->getBlock()].test(VN.lookup(e))) + continue; + + std::map avail; + bool by_some = false; + bool all_same = true; + Value * first_s = 0; + + for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); PI != PE; + ++PI) { + Value *e2 = phi_translate(e, *PI, BB); + Value *e3 = find_leader(availableOut[*PI], VN.lookup(e2)); + + if (e3 == 0) { + std::map::iterator av = avail.find(*PI); + if (av != avail.end()) + avail.erase(av); + avail.insert(std::make_pair(*PI, e2)); + all_same = false; + } else { + std::map::iterator av = avail.find(*PI); + if (av != avail.end()) + avail.erase(av); + avail.insert(std::make_pair(*PI, e3)); + + by_some = true; + if (first_s == 0) + first_s = e3; + else if (first_s != e3) + all_same = false; + } + } + + if (by_some && !all_same && + !generatedPhis[BB].test(VN.lookup(e))) { + insertion_pre(e, BB, avail, new_sets); + + changed_function = true; + new_stuff = true; + } + } + } + + unsigned retval = 0; + if (changed_function) + retval += 1; + if (new_stuff) + retval += 2; + + return retval; +} + +/// insert - Phase 2 of the main algorithm. Walk the dominator tree looking for +/// merge points. When one is found, check for a partial redundancy. If one is +/// present, eliminate it. Repeat this walk until no changes are made. 
+bool GVNPRE::insertion(Function& F) { + bool changed_function = false; + + DominatorTree &DT = getAnalysis(); + + std::map new_sets; + bool new_stuff = true; + while (new_stuff) { + new_stuff = false; + for (df_iterator DI = df_begin(DT.getRootNode()), + E = df_end(DT.getRootNode()); DI != E; ++DI) { + BasicBlock* BB = DI->getBlock(); + + if (BB == 0) + continue; + + ValueNumberedSet& availOut = availableOut[BB]; + ValueNumberedSet& anticIn = anticipatedIn[BB]; + + // Replace leaders with leaders inherited from dominator + if (DI->getIDom() != 0) { + ValueNumberedSet& dom_set = new_sets[DI->getIDom()->getBlock()]; + for (ValueNumberedSet::iterator I = dom_set.begin(), + E = dom_set.end(); I != E; ++I) { + val_replace(new_sets[BB], *I); + val_replace(availOut, *I); + } + } + + // If there is more than one predecessor... + if (pred_begin(BB) != pred_end(BB) && ++pred_begin(BB) != pred_end(BB)) { + std::vector workList; + workList.reserve(anticIn.size()); + topo_sort(anticIn, workList); + + unsigned result = insertion_mergepoint(workList, DI, new_sets); + if (result & 1) + changed_function = true; + if (result & 2) + new_stuff = true; + } + } + } + + return changed_function; +} + +// GVNPRE::runOnFunction - This is the main transformation entry point for a +// function. 
+// +bool GVNPRE::runOnFunction(Function &F) { + // Clean out global sets from any previous functions + VN.clear(); + createdExpressions.clear(); + availableOut.clear(); + anticipatedIn.clear(); + generatedPhis.clear(); + + bool changed_function = false; + + // Phase 1: BuildSets + // This phase calculates the AVAIL_OUT and ANTIC_IN sets + buildsets(F); + + // Phase 2: Insert + // This phase inserts values to make partially redundant values + // fully redundant + changed_function |= insertion(F); + + // Phase 3: Eliminate + // This phase performs trivial full redundancy elimination + changed_function |= elimination(); + + // Phase 4: Cleanup + // This phase cleans up values that were created solely + // as leaders for expressions + cleanup(); + + return changed_function; +} diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp new file mode 100644 index 0000000..01b7481 --- /dev/null +++ b/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -0,0 +1,604 @@ +//===- IndVarSimplify.cpp - Induction Variable Elimination ----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This transformation analyzes and transforms the induction variables (and +// computations derived from them) into simpler forms suitable for subsequent +// analysis and transformation. +// +// This transformation makes the following changes to each loop with an +// identifiable induction variable: +// 1. All loops are transformed to have a SINGLE canonical induction variable +// which starts at zero and steps by one. +// 2. The canonical induction variable is guaranteed to be the first PHI node +// in the loop header block. +// 3. Any pointer arithmetic recurrences are raised to use array subscripts. 
+// +// If the trip count of a loop is computable, this pass also makes the following +// changes: +// 1. The exit condition for the loop is canonicalized to compare the +// induction value against the exit value. This turns loops like: +// 'for (i = 7; i*i < 1000; ++i)' into 'for (i = 0; i != 25; ++i)' +// 2. Any use outside of the loop of an expression derived from the indvar +// is changed to compute the derived value outside of the loop, eliminating +// the dependence on the exit value of the induction variable. If the only +// purpose of the loop is to compute the exit value of some derived +// expression, this transformation will make the loop dead. +// +// This transformation should be followed by strength reduction after all of the +// desired loop transformations have been performed. Additionally, on targets +// where it is profitable, the loop could be transformed to count down to zero +// (the "do loop" optimization). +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "indvars" +#include "llvm/Transforms/Scalar.h" +#include "llvm/BasicBlock.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/Type.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +using namespace llvm; + +STATISTIC(NumRemoved , "Number of aux indvars removed"); +STATISTIC(NumPointer , "Number of pointer indvars promoted"); +STATISTIC(NumInserted, "Number of canonical indvars added"); +STATISTIC(NumReplaced, "Number of exit values replaced"); +STATISTIC(NumLFTR , "Number of loop exit tests replaced"); + +namespace { + class 
VISIBILITY_HIDDEN IndVarSimplify : public LoopPass { + LoopInfo *LI; + ScalarEvolution *SE; + bool Changed; + public: + + static char ID; // Pass identification, replacement for typeid + IndVarSimplify() : LoopPass((intptr_t)&ID) {} + + bool runOnLoop(Loop *L, LPPassManager &LPM); + bool doInitialization(Loop *L, LPPassManager &LPM); + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredID(LCSSAID); + AU.addRequiredID(LoopSimplifyID); + AU.addRequired(); + AU.addRequired(); + AU.addPreservedID(LoopSimplifyID); + AU.addPreservedID(LCSSAID); + AU.setPreservesCFG(); + } + + private: + + void EliminatePointerRecurrence(PHINode *PN, BasicBlock *Preheader, + std::set &DeadInsts); + Instruction *LinearFunctionTestReplace(Loop *L, SCEV *IterationCount, + SCEVExpander &RW); + void RewriteLoopExitValues(Loop *L); + + void DeleteTriviallyDeadInstructions(std::set &Insts); + }; + + char IndVarSimplify::ID = 0; + RegisterPass X("indvars", "Canonicalize Induction Variables"); +} + +LoopPass *llvm::createIndVarSimplifyPass() { + return new IndVarSimplify(); +} + +/// DeleteTriviallyDeadInstructions - If any of the instructions is the +/// specified set are trivially dead, delete them and see if this makes any of +/// their operands subsequently dead. +void IndVarSimplify:: +DeleteTriviallyDeadInstructions(std::set &Insts) { + while (!Insts.empty()) { + Instruction *I = *Insts.begin(); + Insts.erase(Insts.begin()); + if (isInstructionTriviallyDead(I)) { + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (Instruction *U = dyn_cast(I->getOperand(i))) + Insts.insert(U); + SE->deleteValueFromRecords(I); + DOUT << "INDVARS: Deleting: " << *I; + I->eraseFromParent(); + Changed = true; + } + } +} + + +/// EliminatePointerRecurrence - Check to see if this is a trivial GEP pointer +/// recurrence. If so, change it into an integer recurrence, permitting +/// analysis by the SCEV routines. 
+void IndVarSimplify::EliminatePointerRecurrence(PHINode *PN, + BasicBlock *Preheader, + std::set &DeadInsts) { + assert(PN->getNumIncomingValues() == 2 && "Noncanonicalized loop!"); + unsigned PreheaderIdx = PN->getBasicBlockIndex(Preheader); + unsigned BackedgeIdx = PreheaderIdx^1; + if (GetElementPtrInst *GEPI = + dyn_cast(PN->getIncomingValue(BackedgeIdx))) + if (GEPI->getOperand(0) == PN) { + assert(GEPI->getNumOperands() == 2 && "GEP types must match!"); + DOUT << "INDVARS: Eliminating pointer recurrence: " << *GEPI; + + // Okay, we found a pointer recurrence. Transform this pointer + // recurrence into an integer recurrence. Compute the value that gets + // added to the pointer at every iteration. + Value *AddedVal = GEPI->getOperand(1); + + // Insert a new integer PHI node into the top of the block. + PHINode *NewPhi = new PHINode(AddedVal->getType(), + PN->getName()+".rec", PN); + NewPhi->addIncoming(Constant::getNullValue(NewPhi->getType()), Preheader); + + // Create the new add instruction. + Value *NewAdd = BinaryOperator::createAdd(NewPhi, AddedVal, + GEPI->getName()+".rec", GEPI); + NewPhi->addIncoming(NewAdd, PN->getIncomingBlock(BackedgeIdx)); + + // Update the existing GEP to use the recurrence. + GEPI->setOperand(0, PN->getIncomingValue(PreheaderIdx)); + + // Update the GEP to use the new recurrence we just inserted. + GEPI->setOperand(1, NewAdd); + + // If the incoming value is a constant expr GEP, try peeling out the array + // 0 index if possible to make things simpler. + if (ConstantExpr *CE = dyn_cast(GEPI->getOperand(0))) + if (CE->getOpcode() == Instruction::GetElementPtr) { + unsigned NumOps = CE->getNumOperands(); + assert(NumOps > 1 && "CE folding didn't work!"); + if (CE->getOperand(NumOps-1)->isNullValue()) { + // Check to make sure the last index really is an array index. 
+ gep_type_iterator GTI = gep_type_begin(CE); + for (unsigned i = 1, e = CE->getNumOperands()-1; + i != e; ++i, ++GTI) + /*empty*/; + if (isa(*GTI)) { + // Pull the last index out of the constant expr GEP. + SmallVector CEIdxs(CE->op_begin()+1, CE->op_end()-1); + Constant *NCE = ConstantExpr::getGetElementPtr(CE->getOperand(0), + &CEIdxs[0], + CEIdxs.size()); + GetElementPtrInst *NGEPI = new GetElementPtrInst( + NCE, Constant::getNullValue(Type::Int32Ty), NewAdd, + GEPI->getName(), GEPI); + SE->deleteValueFromRecords(GEPI); + GEPI->replaceAllUsesWith(NGEPI); + GEPI->eraseFromParent(); + GEPI = NGEPI; + } + } + } + + + // Finally, if there are any other users of the PHI node, we must + // insert a new GEP instruction that uses the pre-incremented version + // of the induction amount. + if (!PN->use_empty()) { + BasicBlock::iterator InsertPos = PN; ++InsertPos; + while (isa(InsertPos)) ++InsertPos; + Value *PreInc = + new GetElementPtrInst(PN->getIncomingValue(PreheaderIdx), + NewPhi, "", InsertPos); + PreInc->takeName(PN); + PN->replaceAllUsesWith(PreInc); + } + + // Delete the old PHI for sure, and the GEP if its otherwise unused. + DeadInsts.insert(PN); + + ++NumPointer; + Changed = true; + } +} + +/// LinearFunctionTestReplace - This method rewrites the exit condition of the +/// loop to be a canonical != comparison against the incremented loop induction +/// variable. This pass is able to rewrite the exit tests of any loop where the +/// SCEV analysis can determine a loop-invariant trip count of the loop, which +/// is actually a much broader range than just linear tests. +/// +/// This method returns a "potentially dead" instruction whose computation chain +/// should be deleted when convenient. +Instruction *IndVarSimplify::LinearFunctionTestReplace(Loop *L, + SCEV *IterationCount, + SCEVExpander &RW) { + // Find the exit block for the loop. We can currently only handle loops with + // a single exit. 
+ std::vector ExitBlocks; + L->getExitBlocks(ExitBlocks); + if (ExitBlocks.size() != 1) return 0; + BasicBlock *ExitBlock = ExitBlocks[0]; + + // Make sure there is only one predecessor block in the loop. + BasicBlock *ExitingBlock = 0; + for (pred_iterator PI = pred_begin(ExitBlock), PE = pred_end(ExitBlock); + PI != PE; ++PI) + if (L->contains(*PI)) { + if (ExitingBlock == 0) + ExitingBlock = *PI; + else + return 0; // Multiple exits from loop to this block. + } + assert(ExitingBlock && "Loop info is broken"); + + if (!isa(ExitingBlock->getTerminator())) + return 0; // Can't rewrite non-branch yet + BranchInst *BI = cast(ExitingBlock->getTerminator()); + assert(BI->isConditional() && "Must be conditional to be part of loop!"); + + Instruction *PotentiallyDeadInst = dyn_cast(BI->getCondition()); + + // If the exiting block is not the same as the backedge block, we must compare + // against the preincremented value, otherwise we prefer to compare against + // the post-incremented value. + BasicBlock *Header = L->getHeader(); + pred_iterator HPI = pred_begin(Header); + assert(HPI != pred_end(Header) && "Loop with zero preds???"); + if (!L->contains(*HPI)) ++HPI; + assert(HPI != pred_end(Header) && L->contains(*HPI) && + "No backedge in loop?"); + + SCEVHandle TripCount = IterationCount; + Value *IndVar; + if (*HPI == ExitingBlock) { + // The IterationCount expression contains the number of times that the + // backedge actually branches to the loop header. This is one less than the + // number of times the loop executes, so add one to it. + ConstantInt *OneC = ConstantInt::get(IterationCount->getType(), 1); + TripCount = SCEVAddExpr::get(IterationCount, SCEVConstant::get(OneC)); + IndVar = L->getCanonicalInductionVariableIncrement(); + } else { + // We have to use the preincremented value... 
+ IndVar = L->getCanonicalInductionVariable(); + } + + DOUT << "INDVARS: LFTR: TripCount = " << *TripCount + << " IndVar = " << *IndVar << "\n"; + + // Expand the code for the iteration count into the preheader of the loop. + BasicBlock *Preheader = L->getLoopPreheader(); + Value *ExitCnt = RW.expandCodeFor(TripCount, Preheader->getTerminator()); + + // Insert a new icmp_ne or icmp_eq instruction before the branch. + ICmpInst::Predicate Opcode; + if (L->contains(BI->getSuccessor(0))) + Opcode = ICmpInst::ICMP_NE; + else + Opcode = ICmpInst::ICMP_EQ; + + Value *Cond = new ICmpInst(Opcode, IndVar, ExitCnt, "exitcond", BI); + BI->setCondition(Cond); + ++NumLFTR; + Changed = true; + return PotentiallyDeadInst; +} + + +/// RewriteLoopExitValues - Check to see if this loop has a computable +/// loop-invariant execution count. If so, this means that we can compute the +/// final value of any expressions that are recurrent in the loop, and +/// substitute the exit values from the loop into any instructions outside of +/// the loop that use the final values of the current expressions. +void IndVarSimplify::RewriteLoopExitValues(Loop *L) { + BasicBlock *Preheader = L->getLoopPreheader(); + + // Scan all of the instructions in the loop, looking at those that have + // extra-loop users and which are recurrences. + SCEVExpander Rewriter(*SE, *LI); + + // We insert the code into the preheader of the loop if the loop contains + // multiple exit blocks, or in the exit block if there is exactly one. 
+ BasicBlock *BlockToInsertInto; + std::vector ExitBlocks; + L->getUniqueExitBlocks(ExitBlocks); + if (ExitBlocks.size() == 1) + BlockToInsertInto = ExitBlocks[0]; + else + BlockToInsertInto = Preheader; + BasicBlock::iterator InsertPt = BlockToInsertInto->begin(); + while (isa(InsertPt)) ++InsertPt; + + bool HasConstantItCount = isa(SE->getIterationCount(L)); + + std::set InstructionsToDelete; + std::map ExitValues; + + // Find all values that are computed inside the loop, but used outside of it. + // Because of LCSSA, these values will only occur in LCSSA PHI Nodes. Scan + // the exit blocks of the loop to find them. + for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) { + BasicBlock *ExitBB = ExitBlocks[i]; + + // If there are no PHI nodes in this exit block, then no values defined + // inside the loop are used on this path, skip it. + PHINode *PN = dyn_cast(ExitBB->begin()); + if (!PN) continue; + + unsigned NumPreds = PN->getNumIncomingValues(); + + // Iterate over all of the PHI nodes. + BasicBlock::iterator BBI = ExitBB->begin(); + while ((PN = dyn_cast(BBI++))) { + + // Iterate over all of the values in all the PHI nodes. + for (unsigned i = 0; i != NumPreds; ++i) { + // If the value being merged in is not integer or is not defined + // in the loop, skip it. + Value *InVal = PN->getIncomingValue(i); + if (!isa(InVal) || + // SCEV only supports integer expressions for now. + !isa(InVal->getType())) + continue; + + // If this pred is for a subloop, not L itself, skip it. + if (LI->getLoopFor(PN->getIncomingBlock(i)) != L) + continue; // The Block is in a subloop, skip it. + + // Check that InVal is defined in the loop. + Instruction *Inst = cast(InVal); + if (!L->contains(Inst->getParent())) + continue; + + // We require that this value either have a computable evolution or that + // the loop have a constant iteration count. 
In the case where the loop + // has a constant iteration count, we can sometimes force evaluation of + // the exit value through brute force. + SCEVHandle SH = SE->getSCEV(Inst); + if (!SH->hasComputableLoopEvolution(L) && !HasConstantItCount) + continue; // Cannot get exit evolution for the loop value. + + // Okay, this instruction has a user outside of the current loop + // and varies predictably *inside* the loop. Evaluate the value it + // contains when the loop exits, if possible. + SCEVHandle ExitValue = SE->getSCEVAtScope(Inst, L->getParentLoop()); + if (isa(ExitValue) || + !ExitValue->isLoopInvariant(L)) + continue; + + Changed = true; + ++NumReplaced; + + // See if we already computed the exit value for the instruction, if so, + // just reuse it. + Value *&ExitVal = ExitValues[Inst]; + if (!ExitVal) + ExitVal = Rewriter.expandCodeFor(ExitValue, InsertPt); + + DOUT << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal + << " LoopVal = " << *Inst << "\n"; + + PN->setIncomingValue(i, ExitVal); + + // If this instruction is dead now, schedule it to be removed. + if (Inst->use_empty()) + InstructionsToDelete.insert(Inst); + + // See if this is a single-entry LCSSA PHI node. If so, we can (and + // have to) remove + // the PHI entirely. This is safe, because the NewVal won't be variant + // in the loop, so we don't need an LCSSA phi node anymore. + if (NumPreds == 1) { + SE->deleteValueFromRecords(PN); + PN->replaceAllUsesWith(ExitVal); + PN->eraseFromParent(); + break; + } + } + } + } + + DeleteTriviallyDeadInstructions(InstructionsToDelete); +} + +bool IndVarSimplify::doInitialization(Loop *L, LPPassManager &LPM) { + + Changed = false; + // First step. Check to see if there are any trivial GEP pointer recurrences. + // If there are, change them into integer recurrences, permitting analysis by + // the SCEV routines. 
+ // + BasicBlock *Header = L->getHeader(); + BasicBlock *Preheader = L->getLoopPreheader(); + SE = &LPM.getAnalysis(); + + std::set DeadInsts; + for (BasicBlock::iterator I = Header->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + if (isa(PN->getType())) + EliminatePointerRecurrence(PN, Preheader, DeadInsts); + } + + if (!DeadInsts.empty()) + DeleteTriviallyDeadInstructions(DeadInsts); + + return Changed; +} + +bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { + + + LI = &getAnalysis(); + SE = &getAnalysis(); + + Changed = false; + BasicBlock *Header = L->getHeader(); + std::set DeadInsts; + + // Verify the input to the pass in already in LCSSA form. + assert(L->isLCSSAForm()); + + // Check to see if this loop has a computable loop-invariant execution count. + // If so, this means that we can compute the final value of any expressions + // that are recurrent in the loop, and substitute the exit values from the + // loop into any instructions outside of the loop that use the final values of + // the current expressions. + // + SCEVHandle IterationCount = SE->getIterationCount(L); + if (!isa(IterationCount)) + RewriteLoopExitValues(L); + + // Next, analyze all of the induction variables in the loop, canonicalizing + // auxillary induction variables. + std::vector > IndVars; + + for (BasicBlock::iterator I = Header->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + if (PN->getType()->isInteger()) { // FIXME: when we have fast-math, enable! + SCEVHandle SCEV = SE->getSCEV(PN); + if (SCEV->hasComputableLoopEvolution(L)) + // FIXME: It is an extremely bad idea to indvar substitute anything more + // complex than affine induction variables. Doing so will put expensive + // polynomial evaluations inside of the loop, and the str reduction pass + // currently can only reduce affine polynomials. For now just disable + // indvar subst on anything more complex than an affine addrec. 
+ if (SCEVAddRecExpr *AR = dyn_cast(SCEV)) + if (AR->isAffine()) + IndVars.push_back(std::make_pair(PN, SCEV)); + } + } + + // If there are no induction variables in the loop, there is nothing more to + // do. + if (IndVars.empty()) { + // Actually, if we know how many times the loop iterates, lets insert a + // canonical induction variable to help subsequent passes. + if (!isa(IterationCount)) { + SCEVExpander Rewriter(*SE, *LI); + Rewriter.getOrInsertCanonicalInductionVariable(L, + IterationCount->getType()); + if (Instruction *I = LinearFunctionTestReplace(L, IterationCount, + Rewriter)) { + std::set InstructionsToDelete; + InstructionsToDelete.insert(I); + DeleteTriviallyDeadInstructions(InstructionsToDelete); + } + } + return Changed; + } + + // Compute the type of the largest recurrence expression. + // + const Type *LargestType = IndVars[0].first->getType(); + bool DifferingSizes = false; + for (unsigned i = 1, e = IndVars.size(); i != e; ++i) { + const Type *Ty = IndVars[i].first->getType(); + DifferingSizes |= + Ty->getPrimitiveSizeInBits() != LargestType->getPrimitiveSizeInBits(); + if (Ty->getPrimitiveSizeInBits() > LargestType->getPrimitiveSizeInBits()) + LargestType = Ty; + } + + // Create a rewriter object which we'll use to transform the code with. + SCEVExpander Rewriter(*SE, *LI); + + // Now that we know the largest of of the induction variables in this loop, + // insert a canonical induction variable of the largest size. 
+ Value *IndVar = Rewriter.getOrInsertCanonicalInductionVariable(L,LargestType); + ++NumInserted; + Changed = true; + DOUT << "INDVARS: New CanIV: " << *IndVar; + + if (!isa(IterationCount)) { + if (IterationCount->getType()->getPrimitiveSizeInBits() < + LargestType->getPrimitiveSizeInBits()) + IterationCount = SCEVZeroExtendExpr::get(IterationCount, LargestType); + else if (IterationCount->getType() != LargestType) + IterationCount = SCEVTruncateExpr::get(IterationCount, LargestType); + if (Instruction *DI = LinearFunctionTestReplace(L, IterationCount,Rewriter)) + DeadInsts.insert(DI); + } + + // Now that we have a canonical induction variable, we can rewrite any + // recurrences in terms of the induction variable. Start with the auxillary + // induction variables, and recursively rewrite any of their uses. + BasicBlock::iterator InsertPt = Header->begin(); + while (isa(InsertPt)) ++InsertPt; + + // If there were induction variables of other sizes, cast the primary + // induction variable to the right size for them, avoiding the need for the + // code evaluation methods to insert induction variables of different sizes. + if (DifferingSizes) { + SmallVector InsertedSizes; + InsertedSizes.push_back(LargestType->getPrimitiveSizeInBits()); + for (unsigned i = 0, e = IndVars.size(); i != e; ++i) { + unsigned ithSize = IndVars[i].first->getType()->getPrimitiveSizeInBits(); + if (std::find(InsertedSizes.begin(), InsertedSizes.end(), ithSize) + == InsertedSizes.end()) { + PHINode *PN = IndVars[i].first; + InsertedSizes.push_back(ithSize); + Instruction *New = new TruncInst(IndVar, PN->getType(), "indvar", + InsertPt); + Rewriter.addInsertedValue(New, SE->getSCEV(New)); + DOUT << "INDVARS: Made trunc IV for " << *PN + << " NewVal = " << *New << "\n"; + } + } + } + + // Rewrite all induction variables in terms of the canonical induction + // variable. 
+ std::map InsertedSizes; + while (!IndVars.empty()) { + PHINode *PN = IndVars.back().first; + Value *NewVal = Rewriter.expandCodeFor(IndVars.back().second, InsertPt); + DOUT << "INDVARS: Rewrote IV '" << *IndVars.back().second << "' " << *PN + << " into = " << *NewVal << "\n"; + NewVal->takeName(PN); + + // Replace the old PHI Node with the inserted computation. + PN->replaceAllUsesWith(NewVal); + DeadInsts.insert(PN); + IndVars.pop_back(); + ++NumRemoved; + Changed = true; + } + +#if 0 + // Now replace all derived expressions in the loop body with simpler + // expressions. + for (unsigned i = 0, e = L->getBlocks().size(); i != e; ++i) + if (LI->getLoopFor(L->getBlocks()[i]) == L) { // Not in a subloop... + BasicBlock *BB = L->getBlocks()[i]; + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) + if (I->getType()->isInteger() && // Is an integer instruction + !I->use_empty() && + !Rewriter.isInsertedInstruction(I)) { + SCEVHandle SH = SE->getSCEV(I); + Value *V = Rewriter.expandCodeFor(SH, I, I->getType()); + if (V != I) { + if (isa(V)) + V->takeName(I); + I->replaceAllUsesWith(V); + DeadInsts.insert(I); + ++NumRemoved; + Changed = true; + } + } + } +#endif + + DeleteTriviallyDeadInstructions(DeadInsts); + + assert(L->isLCSSAForm()); + return Changed; +} diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp new file mode 100644 index 0000000..816a1c6 --- /dev/null +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -0,0 +1,10090 @@ +//===- InstructionCombining.cpp - Combine multiple instructions -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// InstructionCombining - Combine instructions to form fewer, simple +// instructions. This pass does not modify the CFG This pass is where algebraic +// simplification happens. +// +// This pass combines things like: +// %Y = add i32 %X, 1 +// %Z = add i32 %Y, 1 +// into: +// %Z = add i32 %X, 2 +// +// This is a simple worklist driven algorithm. +// +// This pass guarantees that the following canonicalizations are performed on +// the program: +// 1. If a binary operator has a constant operand, it is moved to the RHS +// 2. Bitwise operators with constant operands are always grouped so that +// shifts are performed first, then or's, then and's, then xor's. +// 3. Compare instructions are converted from <,>,<=,>= to ==,!= if possible +// 4. All cmp instructions on boolean values are replaced with logical ops +// 5. add X, X is represented as (X*2) => (X << 1) +// 6. Multiplies with a power-of-two constant argument are transformed into +// shifts. +// ... etc. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Scalar.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Pass.h" +#include "llvm/DerivedTypes.h" +#include "llvm/GlobalVariable.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/CallSite.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +#include "llvm/Support/InstVisitor.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/PatternMatch.h" +#include "llvm/Support/Compiler.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include +#include +using namespace llvm; +using namespace llvm::PatternMatch; + +STATISTIC(NumCombined , "Number of insts combined"); +STATISTIC(NumConstProp, "Number of constant folds"); +STATISTIC(NumDeadInst , "Number of dead inst eliminated"); +STATISTIC(NumDeadStore, "Number of dead stores eliminated"); +STATISTIC(NumSunkInst , "Number of instructions sunk"); + +namespace { + class VISIBILITY_HIDDEN InstCombiner + : public FunctionPass, + public InstVisitor { + // Worklist of all of the instructions that need to be simplified. + std::vector Worklist; + DenseMap WorklistMap; + TargetData *TD; + bool MustPreserveLCSSA; + public: + static char ID; // Pass identification, replacement for typeid + InstCombiner() : FunctionPass((intptr_t)&ID) {} + + /// AddToWorkList - Add the specified instruction to the worklist if it + /// isn't already in it. + void AddToWorkList(Instruction *I) { + if (WorklistMap.insert(std::make_pair(I, Worklist.size()))) + Worklist.push_back(I); + } + + // RemoveFromWorkList - remove I from the worklist if it exists. 
+ void RemoveFromWorkList(Instruction *I) { + DenseMap::iterator It = WorklistMap.find(I); + if (It == WorklistMap.end()) return; // Not in worklist. + + // Don't bother moving everything down, just null out the slot. + Worklist[It->second] = 0; + + WorklistMap.erase(It); + } + + Instruction *RemoveOneFromWorkList() { + Instruction *I = Worklist.back(); + Worklist.pop_back(); + WorklistMap.erase(I); + return I; + } + + + /// AddUsersToWorkList - When an instruction is simplified, add all users of + /// the instruction to the work lists because they might get more simplified + /// now. + /// + void AddUsersToWorkList(Value &I) { + for (Value::use_iterator UI = I.use_begin(), UE = I.use_end(); + UI != UE; ++UI) + AddToWorkList(cast(*UI)); + } + + /// AddUsesToWorkList - When an instruction is simplified, add operands to + /// the work lists because they might get more simplified now. + /// + void AddUsesToWorkList(Instruction &I) { + for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) + if (Instruction *Op = dyn_cast(I.getOperand(i))) + AddToWorkList(Op); + } + + /// AddSoonDeadInstToWorklist - The specified instruction is about to become + /// dead. Add all of its operands to the worklist, turning them into + /// undef's to reduce the number of uses of those instructions. + /// + /// Return the specified operand before it is turned into an undef. + /// + Value *AddSoonDeadInstToWorklist(Instruction &I, unsigned op) { + Value *R = I.getOperand(op); + + for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) + if (Instruction *Op = dyn_cast(I.getOperand(i))) { + AddToWorkList(Op); + // Set the operand to undef to drop the use. 
+ I.setOperand(i, UndefValue::get(Op->getType())); + } + + return R; + } + + public: + virtual bool runOnFunction(Function &F); + + bool DoOneIteration(Function &F, unsigned ItNum); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + AU.addPreservedID(LCSSAID); + AU.setPreservesCFG(); + } + + TargetData &getTargetData() const { return *TD; } + + // Visitation implementation - Implement instruction combining for different + // instruction types. The semantics are as follows: + // Return Value: + // null - No change was made + // I - Change was made, I is still valid, I may be dead though + // otherwise - Change was made, replace I with returned instruction + // + Instruction *visitAdd(BinaryOperator &I); + Instruction *visitSub(BinaryOperator &I); + Instruction *visitMul(BinaryOperator &I); + Instruction *visitURem(BinaryOperator &I); + Instruction *visitSRem(BinaryOperator &I); + Instruction *visitFRem(BinaryOperator &I); + Instruction *commonRemTransforms(BinaryOperator &I); + Instruction *commonIRemTransforms(BinaryOperator &I); + Instruction *commonDivTransforms(BinaryOperator &I); + Instruction *commonIDivTransforms(BinaryOperator &I); + Instruction *visitUDiv(BinaryOperator &I); + Instruction *visitSDiv(BinaryOperator &I); + Instruction *visitFDiv(BinaryOperator &I); + Instruction *visitAnd(BinaryOperator &I); + Instruction *visitOr (BinaryOperator &I); + Instruction *visitXor(BinaryOperator &I); + Instruction *visitShl(BinaryOperator &I); + Instruction *visitAShr(BinaryOperator &I); + Instruction *visitLShr(BinaryOperator &I); + Instruction *commonShiftTransforms(BinaryOperator &I); + Instruction *visitFCmpInst(FCmpInst &I); + Instruction *visitICmpInst(ICmpInst &I); + Instruction *visitICmpInstWithCastAndCast(ICmpInst &ICI); + Instruction *visitICmpInstWithInstAndIntCst(ICmpInst &ICI, + Instruction *LHS, + ConstantInt *RHS); + Instruction *FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, + ConstantInt *DivRHS); + + 
Instruction *FoldGEPICmp(User *GEPLHS, Value *RHS, + ICmpInst::Predicate Cond, Instruction &I); + Instruction *FoldShiftByConstant(Value *Op0, ConstantInt *Op1, + BinaryOperator &I); + Instruction *commonCastTransforms(CastInst &CI); + Instruction *commonIntCastTransforms(CastInst &CI); + Instruction *commonPointerCastTransforms(CastInst &CI); + Instruction *visitTrunc(TruncInst &CI); + Instruction *visitZExt(ZExtInst &CI); + Instruction *visitSExt(SExtInst &CI); + Instruction *visitFPTrunc(CastInst &CI); + Instruction *visitFPExt(CastInst &CI); + Instruction *visitFPToUI(CastInst &CI); + Instruction *visitFPToSI(CastInst &CI); + Instruction *visitUIToFP(CastInst &CI); + Instruction *visitSIToFP(CastInst &CI); + Instruction *visitPtrToInt(CastInst &CI); + Instruction *visitIntToPtr(CastInst &CI); + Instruction *visitBitCast(BitCastInst &CI); + Instruction *FoldSelectOpOp(SelectInst &SI, Instruction *TI, + Instruction *FI); + Instruction *visitSelectInst(SelectInst &CI); + Instruction *visitCallInst(CallInst &CI); + Instruction *visitInvokeInst(InvokeInst &II); + Instruction *visitPHINode(PHINode &PN); + Instruction *visitGetElementPtrInst(GetElementPtrInst &GEP); + Instruction *visitAllocationInst(AllocationInst &AI); + Instruction *visitFreeInst(FreeInst &FI); + Instruction *visitLoadInst(LoadInst &LI); + Instruction *visitStoreInst(StoreInst &SI); + Instruction *visitBranchInst(BranchInst &BI); + Instruction *visitSwitchInst(SwitchInst &SI); + Instruction *visitInsertElementInst(InsertElementInst &IE); + Instruction *visitExtractElementInst(ExtractElementInst &EI); + Instruction *visitShuffleVectorInst(ShuffleVectorInst &SVI); + + // visitInstruction - Specify what to return for unhandled instructions... 
+ Instruction *visitInstruction(Instruction &I) { return 0; } + + private: + Instruction *visitCallSite(CallSite CS); + bool transformConstExprCastCall(CallSite CS); + + public: + // InsertNewInstBefore - insert an instruction New before instruction Old + // in the program. Add the new instruction to the worklist. + // + Instruction *InsertNewInstBefore(Instruction *New, Instruction &Old) { + assert(New && New->getParent() == 0 && + "New instruction already inserted into a basic block!"); + BasicBlock *BB = Old.getParent(); + BB->getInstList().insert(&Old, New); // Insert inst + AddToWorkList(New); + return New; + } + + /// InsertCastBefore - Insert a cast of V to TY before the instruction POS. + /// This also adds the cast to the worklist. Finally, this returns the + /// cast. + Value *InsertCastBefore(Instruction::CastOps opc, Value *V, const Type *Ty, + Instruction &Pos) { + if (V->getType() == Ty) return V; + + if (Constant *CV = dyn_cast(V)) + return ConstantExpr::getCast(opc, CV, Ty); + + Instruction *C = CastInst::create(opc, V, Ty, V->getName(), &Pos); + AddToWorkList(C); + return C; + } + + // ReplaceInstUsesWith - This method is to be used when an instruction is + // found to be dead, replacable with another preexisting expression. Here + // we add all uses of I to the worklist, replace all uses of I with the new + // value, then return I, so that the inst combiner will know that I was + // modified. + // + Instruction *ReplaceInstUsesWith(Instruction &I, Value *V) { + AddUsersToWorkList(I); // Add all modified instrs to worklist + if (&I != V) { + I.replaceAllUsesWith(V); + return &I; + } else { + // If we are replacing the instruction with itself, this must be in a + // segment of unreachable code, so just clobber the instruction. 
+ I.replaceAllUsesWith(UndefValue::get(I.getType())); + return &I; + } + } + + // UpdateValueUsesWith - This method is to be used when an value is + // found to be replacable with another preexisting expression or was + // updated. Here we add all uses of I to the worklist, replace all uses of + // I with the new value (unless the instruction was just updated), then + // return true, so that the inst combiner will know that I was modified. + // + bool UpdateValueUsesWith(Value *Old, Value *New) { + AddUsersToWorkList(*Old); // Add all modified instrs to worklist + if (Old != New) + Old->replaceAllUsesWith(New); + if (Instruction *I = dyn_cast(Old)) + AddToWorkList(I); + if (Instruction *I = dyn_cast(New)) + AddToWorkList(I); + return true; + } + + // EraseInstFromFunction - When dealing with an instruction that has side + // effects or produces a void value, we can't rely on DCE to delete the + // instruction. Instead, visit methods should return the value returned by + // this function. + Instruction *EraseInstFromFunction(Instruction &I) { + assert(I.use_empty() && "Cannot erase instruction that is used!"); + AddUsesToWorkList(I); + RemoveFromWorkList(&I); + I.eraseFromParent(); + return 0; // Don't do anything with FI + } + + private: + /// InsertOperandCastBefore - This inserts a cast of V to DestTy before the + /// InsertBefore instruction. This is specialized a bit to avoid inserting + /// casts that are known to not do anything... + /// + Value *InsertOperandCastBefore(Instruction::CastOps opcode, + Value *V, const Type *DestTy, + Instruction *InsertBefore); + + /// SimplifyCommutative - This performs a few simplifications for + /// commutative operators. + bool SimplifyCommutative(BinaryOperator &I); + + /// SimplifyCompare - This reorders the operands of a CmpInst to get them in + /// most-complex to least-complex order. 
+ bool SimplifyCompare(CmpInst &I); + + /// SimplifyDemandedBits - Attempts to replace V with a simpler value based + /// on the demanded bits. + bool SimplifyDemandedBits(Value *V, APInt DemandedMask, + APInt& KnownZero, APInt& KnownOne, + unsigned Depth = 0); + + Value *SimplifyDemandedVectorElts(Value *V, uint64_t DemandedElts, + uint64_t &UndefElts, unsigned Depth = 0); + + // FoldOpIntoPhi - Given a binary operator or cast instruction which has a + // PHI node as operand #0, see if we can fold the instruction into the PHI + // (which is only possible if all operands to the PHI are constants). + Instruction *FoldOpIntoPhi(Instruction &I); + + // FoldPHIArgOpIntoPHI - If all operands to a PHI node are the same "unary" + // operator and they all are only used by the PHI, PHI together their + // inputs, and do the operation once, to the result of the PHI. + Instruction *FoldPHIArgOpIntoPHI(PHINode &PN); + Instruction *FoldPHIArgBinOpIntoPHI(PHINode &PN); + + + Instruction *OptAndOp(Instruction *Op, ConstantInt *OpRHS, + ConstantInt *AndRHS, BinaryOperator &TheAnd); + + Value *FoldLogicalPlusAnd(Value *LHS, Value *RHS, ConstantInt *Mask, + bool isSub, Instruction &I); + Instruction *InsertRangeTest(Value *V, Constant *Lo, Constant *Hi, + bool isSigned, bool Inside, Instruction &IB); + Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocationInst &AI); + Instruction *MatchBSwap(BinaryOperator &I); + bool SimplifyStoreAtEndOfBlock(StoreInst &SI); + + Value *EvaluateInDifferentType(Value *V, const Type *Ty, bool isSigned); + }; + + char InstCombiner::ID = 0; + RegisterPass X("instcombine", "Combine redundant instructions"); +} + +// getComplexity: Assign a complexity or rank value to LLVM Values... +// 0 -> undef, 1 -> Const, 2 -> Other, 3 -> Arg, 3 -> Unary, 4 -> OtherInst +static unsigned getComplexity(Value *V) { + if (isa(V)) { + if (BinaryOperator::isNeg(V) || BinaryOperator::isNot(V)) + return 3; + return 4; + } + if (isa(V)) return 3; + return isa(V) ? 
(isa(V) ? 0 : 1) : 2; +} + +// isOnlyUse - Return true if this instruction will be deleted if we stop using +// it. +static bool isOnlyUse(Value *V) { + return V->hasOneUse() || isa(V); +} + +// getPromotedType - Return the specified type promoted as it would be to pass +// though a va_arg area... +static const Type *getPromotedType(const Type *Ty) { + if (const IntegerType* ITy = dyn_cast(Ty)) { + if (ITy->getBitWidth() < 32) + return Type::Int32Ty; + } + return Ty; +} + +/// getBitCastOperand - If the specified operand is a CastInst or a constant +/// expression bitcast, return the operand value, otherwise return null. +static Value *getBitCastOperand(Value *V) { + if (BitCastInst *I = dyn_cast(V)) + return I->getOperand(0); + else if (ConstantExpr *CE = dyn_cast(V)) + if (CE->getOpcode() == Instruction::BitCast) + return CE->getOperand(0); + return 0; +} + +/// This function is a wrapper around CastInst::isEliminableCastPair. It +/// simply extracts arguments and returns what that function returns. +static Instruction::CastOps +isEliminableCastPair( + const CastInst *CI, ///< The first cast instruction + unsigned opcode, ///< The opcode of the second cast instruction + const Type *DstTy, ///< The target type for the second cast instruction + TargetData *TD ///< The target data for pointer size +) { + + const Type *SrcTy = CI->getOperand(0)->getType(); // A from above + const Type *MidTy = CI->getType(); // B from above + + // Get the opcodes of the two Cast instructions + Instruction::CastOps firstOp = Instruction::CastOps(CI->getOpcode()); + Instruction::CastOps secondOp = Instruction::CastOps(opcode); + + return Instruction::CastOps( + CastInst::isEliminableCastPair(firstOp, secondOp, SrcTy, MidTy, + DstTy, TD->getIntPtrType())); +} + +/// ValueRequiresCast - Return true if the cast from "V to Ty" actually results +/// in any code being generated. It does not require codegen if V is simple +/// enough or if the cast can be folded into other casts. 
+static bool ValueRequiresCast(Instruction::CastOps opcode, const Value *V, + const Type *Ty, TargetData *TD) { + if (V->getType() == Ty || isa(V)) return false; + + // If this is another cast that can be eliminated, it isn't codegen either. + if (const CastInst *CI = dyn_cast(V)) + if (isEliminableCastPair(CI, opcode, Ty, TD)) + return false; + return true; +} + +/// InsertOperandCastBefore - This inserts a cast of V to DestTy before the +/// InsertBefore instruction. This is specialized a bit to avoid inserting +/// casts that are known to not do anything... +/// +Value *InstCombiner::InsertOperandCastBefore(Instruction::CastOps opcode, + Value *V, const Type *DestTy, + Instruction *InsertBefore) { + if (V->getType() == DestTy) return V; + if (Constant *C = dyn_cast(V)) + return ConstantExpr::getCast(opcode, C, DestTy); + + return InsertCastBefore(opcode, V, DestTy, *InsertBefore); +} + +// SimplifyCommutative - This performs a few simplifications for commutative +// operators: +// +// 1. Order operands such that they are listed from right (least complex) to +// left (most complex). This puts constants before unary operators before +// binary operators. +// +// 2. Transform: (op (op V, C1), C2) ==> (op V, (op C1, C2)) +// 3. 
Transform: (op (op V1, C1), (op V2, C2)) ==> (op (op V1, V2), (op C1,C2)) +// +bool InstCombiner::SimplifyCommutative(BinaryOperator &I) { + bool Changed = false; + if (getComplexity(I.getOperand(0)) < getComplexity(I.getOperand(1))) + Changed = !I.swapOperands(); + + if (!I.isAssociative()) return Changed; + Instruction::BinaryOps Opcode = I.getOpcode(); + if (BinaryOperator *Op = dyn_cast(I.getOperand(0))) + if (Op->getOpcode() == Opcode && isa(Op->getOperand(1))) { + if (isa(I.getOperand(1))) { + Constant *Folded = ConstantExpr::get(I.getOpcode(), + cast(I.getOperand(1)), + cast(Op->getOperand(1))); + I.setOperand(0, Op->getOperand(0)); + I.setOperand(1, Folded); + return true; + } else if (BinaryOperator *Op1=dyn_cast(I.getOperand(1))) + if (Op1->getOpcode() == Opcode && isa(Op1->getOperand(1)) && + isOnlyUse(Op) && isOnlyUse(Op1)) { + Constant *C1 = cast(Op->getOperand(1)); + Constant *C2 = cast(Op1->getOperand(1)); + + // Fold (op (op V1, C1), (op V2, C2)) ==> (op (op V1, V2), (op C1,C2)) + Constant *Folded = ConstantExpr::get(I.getOpcode(), C1, C2); + Instruction *New = BinaryOperator::create(Opcode, Op->getOperand(0), + Op1->getOperand(0), + Op1->getName(), &I); + AddToWorkList(New); + I.setOperand(0, New); + I.setOperand(1, Folded); + return true; + } + } + return Changed; +} + +/// SimplifyCompare - For a CmpInst this function just orders the operands +/// so that theyare listed from right (least complex) to left (most complex). +/// This puts constants before unary operators before binary operators. +bool InstCombiner::SimplifyCompare(CmpInst &I) { + if (getComplexity(I.getOperand(0)) >= getComplexity(I.getOperand(1))) + return false; + I.swapOperands(); + // Compare instructions are not associative so there's nothing else we can do. + return true; +} + +// dyn_castNegVal - Given a 'sub' instruction, return the RHS of the instruction +// if the LHS is a constant zero (which is the 'negate' form). 
+// +static inline Value *dyn_castNegVal(Value *V) { + if (BinaryOperator::isNeg(V)) + return BinaryOperator::getNegArgument(V); + + // Constants can be considered to be negated values if they can be folded. + if (ConstantInt *C = dyn_cast(V)) + return ConstantExpr::getNeg(C); + return 0; +} + +static inline Value *dyn_castNotVal(Value *V) { + if (BinaryOperator::isNot(V)) + return BinaryOperator::getNotArgument(V); + + // Constants can be considered to be not'ed values... + if (ConstantInt *C = dyn_cast(V)) + return ConstantInt::get(~C->getValue()); + return 0; +} + +// dyn_castFoldableMul - If this value is a multiply that can be folded into +// other computations (because it has a constant operand), return the +// non-constant operand of the multiply, and set CST to point to the multiplier. +// Otherwise, return null. +// +static inline Value *dyn_castFoldableMul(Value *V, ConstantInt *&CST) { + if (V->hasOneUse() && V->getType()->isInteger()) + if (Instruction *I = dyn_cast(V)) { + if (I->getOpcode() == Instruction::Mul) + if ((CST = dyn_cast(I->getOperand(1)))) + return I->getOperand(0); + if (I->getOpcode() == Instruction::Shl) + if ((CST = dyn_cast(I->getOperand(1)))) { + // The multiplier is really 1 << CST. + uint32_t BitWidth = cast(V->getType())->getBitWidth(); + uint32_t CSTVal = CST->getLimitedValue(BitWidth); + CST = ConstantInt::get(APInt(BitWidth, 1).shl(CSTVal)); + return I->getOperand(0); + } + } + return 0; +} + +/// dyn_castGetElementPtr - If this is a getelementptr instruction or constant +/// expression, return it. 
+static User *dyn_castGetElementPtr(Value *V) { + if (isa(V)) return cast(V); + if (ConstantExpr *CE = dyn_cast(V)) + if (CE->getOpcode() == Instruction::GetElementPtr) + return cast(V); + return false; +} + +/// AddOne - Add one to a ConstantInt +static ConstantInt *AddOne(ConstantInt *C) { + APInt Val(C->getValue()); + return ConstantInt::get(++Val); +} +/// SubOne - Subtract one from a ConstantInt +static ConstantInt *SubOne(ConstantInt *C) { + APInt Val(C->getValue()); + return ConstantInt::get(--Val); +} +/// Add - Add two ConstantInts together +static ConstantInt *Add(ConstantInt *C1, ConstantInt *C2) { + return ConstantInt::get(C1->getValue() + C2->getValue()); +} +/// And - Bitwise AND two ConstantInts together +static ConstantInt *And(ConstantInt *C1, ConstantInt *C2) { + return ConstantInt::get(C1->getValue() & C2->getValue()); +} +/// Subtract - Subtract one ConstantInt from another +static ConstantInt *Subtract(ConstantInt *C1, ConstantInt *C2) { + return ConstantInt::get(C1->getValue() - C2->getValue()); +} +/// Multiply - Multiply two ConstantInts together +static ConstantInt *Multiply(ConstantInt *C1, ConstantInt *C2) { + return ConstantInt::get(C1->getValue() * C2->getValue()); +} + +/// ComputeMaskedBits - Determine which of the bits specified in Mask are +/// known to be either zero or one and return them in the KnownZero/KnownOne +/// bit sets. This code only analyzes bits in Mask, in order to short-circuit +/// processing. +/// NOTE: we cannot consider 'undef' to be "IsZero" here. The problem is that +/// we cannot optimize based on the assumption that it is zero without changing +/// it to be an explicit zero. If we don't change it to zero, other code could +/// optimized based on the contradictory assumption that it is non-zero. +/// Because instcombine aggressively folds operations with undef args anyway, +/// this won't lose us code quality. 
+static void ComputeMaskedBits(Value *V, const APInt &Mask, APInt& KnownZero, + APInt& KnownOne, unsigned Depth = 0) { + assert(V && "No Value?"); + assert(Depth <= 6 && "Limit Search Depth"); + uint32_t BitWidth = Mask.getBitWidth(); + assert(cast(V->getType())->getBitWidth() == BitWidth && + KnownZero.getBitWidth() == BitWidth && + KnownOne.getBitWidth() == BitWidth && + "V, Mask, KnownOne and KnownZero should have same BitWidth"); + if (ConstantInt *CI = dyn_cast(V)) { + // We know all of the bits for a constant! + KnownOne = CI->getValue() & Mask; + KnownZero = ~KnownOne & Mask; + return; + } + + if (Depth == 6 || Mask == 0) + return; // Limit search depth. + + Instruction *I = dyn_cast(V); + if (!I) return; + + KnownZero.clear(); KnownOne.clear(); // Don't know anything. + APInt KnownZero2(KnownZero), KnownOne2(KnownOne); + + switch (I->getOpcode()) { + case Instruction::And: { + // If either the LHS or the RHS are Zero, the result is zero. + ComputeMaskedBits(I->getOperand(1), Mask, KnownZero, KnownOne, Depth+1); + APInt Mask2(Mask & ~KnownZero); + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero2, KnownOne2, Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); + + // Output known-1 bits are only known if set in both the LHS & RHS. + KnownOne &= KnownOne2; + // Output known-0 are known to be clear if zero in either the LHS | RHS. + KnownZero |= KnownZero2; + return; + } + case Instruction::Or: { + ComputeMaskedBits(I->getOperand(1), Mask, KnownZero, KnownOne, Depth+1); + APInt Mask2(Mask & ~KnownOne); + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero2, KnownOne2, Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); + + // Output known-0 bits are only known if clear in both the LHS & RHS. 
+ KnownZero &= KnownZero2; + // Output known-1 are known to be set if set in either the LHS | RHS. + KnownOne |= KnownOne2; + return; + } + case Instruction::Xor: { + ComputeMaskedBits(I->getOperand(1), Mask, KnownZero, KnownOne, Depth+1); + ComputeMaskedBits(I->getOperand(0), Mask, KnownZero2, KnownOne2, Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); + + // Output known-0 bits are known if clear or set in both the LHS & RHS. + APInt KnownZeroOut = (KnownZero & KnownZero2) | (KnownOne & KnownOne2); + // Output known-1 are known to be set if set in only one of the LHS, RHS. + KnownOne = (KnownZero & KnownOne2) | (KnownOne & KnownZero2); + KnownZero = KnownZeroOut; + return; + } + case Instruction::Select: + ComputeMaskedBits(I->getOperand(2), Mask, KnownZero, KnownOne, Depth+1); + ComputeMaskedBits(I->getOperand(1), Mask, KnownZero2, KnownOne2, Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); + + // Only known if known in both the LHS and RHS. 
+ KnownOne &= KnownOne2; + KnownZero &= KnownZero2; + return; + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::SIToFP: + case Instruction::PtrToInt: + case Instruction::UIToFP: + case Instruction::IntToPtr: + return; // Can't work with floating point or pointers + case Instruction::Trunc: { + // All these have integer operands + uint32_t SrcBitWidth = + cast(I->getOperand(0)->getType())->getBitWidth(); + APInt MaskIn(Mask); + MaskIn.zext(SrcBitWidth); + KnownZero.zext(SrcBitWidth); + KnownOne.zext(SrcBitWidth); + ComputeMaskedBits(I->getOperand(0), MaskIn, KnownZero, KnownOne, Depth+1); + KnownZero.trunc(BitWidth); + KnownOne.trunc(BitWidth); + return; + } + case Instruction::BitCast: { + const Type *SrcTy = I->getOperand(0)->getType(); + if (SrcTy->isInteger()) { + ComputeMaskedBits(I->getOperand(0), Mask, KnownZero, KnownOne, Depth+1); + return; + } + break; + } + case Instruction::ZExt: { + // Compute the bits in the result that are not present in the input. + const IntegerType *SrcTy = cast(I->getOperand(0)->getType()); + uint32_t SrcBitWidth = SrcTy->getBitWidth(); + + APInt MaskIn(Mask); + MaskIn.trunc(SrcBitWidth); + KnownZero.trunc(SrcBitWidth); + KnownOne.trunc(SrcBitWidth); + ComputeMaskedBits(I->getOperand(0), MaskIn, KnownZero, KnownOne, Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + // The top bits are known to be zero. + KnownZero.zext(BitWidth); + KnownOne.zext(BitWidth); + KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + return; + } + case Instruction::SExt: { + // Compute the bits in the result that are not present in the input. 
+ const IntegerType *SrcTy = cast(I->getOperand(0)->getType()); + uint32_t SrcBitWidth = SrcTy->getBitWidth(); + + APInt MaskIn(Mask); + MaskIn.trunc(SrcBitWidth); + KnownZero.trunc(SrcBitWidth); + KnownOne.trunc(SrcBitWidth); + ComputeMaskedBits(I->getOperand(0), MaskIn, KnownZero, KnownOne, Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + KnownZero.zext(BitWidth); + KnownOne.zext(BitWidth); + + // If the sign bit of the input is known set or clear, then we know the + // top bits of the result. + if (KnownZero[SrcBitWidth-1]) // Input sign bit known zero + KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + else if (KnownOne[SrcBitWidth-1]) // Input sign bit known set + KnownOne |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + return; + } + case Instruction::Shl: + // (shl X, C1) & C2 == 0 iff (X & C2 >>u C1) == 0 + if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { + uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); + APInt Mask2(Mask.lshr(ShiftAmt)); + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero, KnownOne, Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + KnownZero <<= ShiftAmt; + KnownOne <<= ShiftAmt; + KnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt); // low bits known 0 + return; + } + break; + case Instruction::LShr: + // (ushr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 + if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { + // Compute the new bits that are at the top now. + uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); + + // Unsigned shift right. + APInt Mask2(Mask.shl(ShiftAmt)); + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero,KnownOne,Depth+1); + assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); + KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); + KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); + // high bits known zero. 
+ KnownZero |= APInt::getHighBitsSet(BitWidth, ShiftAmt); + return; + } + break; + case Instruction::AShr: + // (ashr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 + if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { + // Compute the new bits that are at the top now. + uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); + + // Signed shift right. + APInt Mask2(Mask.shl(ShiftAmt)); + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero,KnownOne,Depth+1); + assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); + KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); + KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); + + APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); + if (KnownZero[BitWidth-ShiftAmt-1]) // New bits are known zero. + KnownZero |= HighBits; + else if (KnownOne[BitWidth-ShiftAmt-1]) // New bits are known one. + KnownOne |= HighBits; + return; + } + break; + } +} + +/// MaskedValueIsZero - Return true if 'V & Mask' is known to be zero. We use +/// this predicate to simplify operations downstream. Mask is known to be zero +/// for bits that V cannot have. +static bool MaskedValueIsZero(Value *V, const APInt& Mask, unsigned Depth = 0) { + APInt KnownZero(Mask.getBitWidth(), 0), KnownOne(Mask.getBitWidth(), 0); + ComputeMaskedBits(V, Mask, KnownZero, KnownOne, Depth); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + return (KnownZero & Mask) == Mask; +} + +/// ShrinkDemandedConstant - Check to see if the specified operand of the +/// specified instruction is a constant integer. If so, check to see if there +/// are any bits set in the constant that are not demanded. If so, shrink the +/// constant and return true. +static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, + APInt Demanded) { + assert(I && "No instruction?"); + assert(OpNo < I->getNumOperands() && "Operand index too large"); + + // If the operand is not a constant integer, nothing to do. 
+ ConstantInt *OpC = dyn_cast(I->getOperand(OpNo)); + if (!OpC) return false; + + // If there are no bits set that aren't demanded, nothing to do. + Demanded.zextOrTrunc(OpC->getValue().getBitWidth()); + if ((~Demanded & OpC->getValue()) == 0) + return false; + + // This instruction is producing bits that are not demanded. Shrink the RHS. + Demanded &= OpC->getValue(); + I->setOperand(OpNo, ConstantInt::get(Demanded)); + return true; +} + +// ComputeSignedMinMaxValuesFromKnownBits - Given a signed integer type and a +// set of known zero and one bits, compute the maximum and minimum values that +// could have the specified known zero and known one bits, returning them in +// min/max. +static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty, + const APInt& KnownZero, + const APInt& KnownOne, + APInt& Min, APInt& Max) { + uint32_t BitWidth = cast(Ty)->getBitWidth(); + assert(KnownZero.getBitWidth() == BitWidth && + KnownOne.getBitWidth() == BitWidth && + Min.getBitWidth() == BitWidth && Max.getBitWidth() == BitWidth && + "Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth."); + APInt UnknownBits = ~(KnownZero|KnownOne); + + // The minimum value is when all unknown bits are zeros, EXCEPT for the sign + // bit if it is unknown. + Min = KnownOne; + Max = KnownOne|UnknownBits; + + if (UnknownBits[BitWidth-1]) { // Sign bit is unknown + Min.set(BitWidth-1); + Max.clear(BitWidth-1); + } +} + +// ComputeUnsignedMinMaxValuesFromKnownBits - Given an unsigned integer type and +// a set of known zero and one bits, compute the maximum and minimum values that +// could have the specified known zero and known one bits, returning them in +// min/max. 
+static void ComputeUnsignedMinMaxValuesFromKnownBits(const Type *Ty, + const APInt& KnownZero, + const APInt& KnownOne, + APInt& Min, + APInt& Max) { + uint32_t BitWidth = cast(Ty)->getBitWidth(); + assert(KnownZero.getBitWidth() == BitWidth && + KnownOne.getBitWidth() == BitWidth && + Min.getBitWidth() == BitWidth && Max.getBitWidth() && + "Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth."); + APInt UnknownBits = ~(KnownZero|KnownOne); + + // The minimum value is when the unknown bits are all zeros. + Min = KnownOne; + // The maximum value is when the unknown bits are all ones. + Max = KnownOne|UnknownBits; +} + +/// SimplifyDemandedBits - This function attempts to replace V with a simpler +/// value based on the demanded bits. When this function is called, it is known +/// that only the bits set in DemandedMask of the result of V are ever used +/// downstream. Consequently, depending on the mask and V, it may be possible +/// to replace V with a constant or one of its operands. In such cases, this +/// function does the replacement and returns true. In all other cases, it +/// returns false after analyzing the expression and setting KnownOne and known +/// to be one in the expression. KnownZero contains all the bits that are known +/// to be zero in the expression. These are provided to potentially allow the +/// caller (which might recursively be SimplifyDemandedBits itself) to simplify +/// the expression. KnownOne and KnownZero always follow the invariant that +/// KnownOne & KnownZero == 0. That is, a bit can't be both 1 and 0. Note that +/// the bits in KnownOne and KnownZero may only be accurate for those bits set +/// in DemandedMask. Note also that the bitwidth of V, DemandedMask, KnownZero +/// and KnownOne must all be the same. 
+bool InstCombiner::SimplifyDemandedBits(Value *V, APInt DemandedMask, + APInt& KnownZero, APInt& KnownOne, + unsigned Depth) { + assert(V != 0 && "Null pointer of Value???"); + assert(Depth <= 6 && "Limit Search Depth"); + uint32_t BitWidth = DemandedMask.getBitWidth(); + const IntegerType *VTy = cast(V->getType()); + assert(VTy->getBitWidth() == BitWidth && + KnownZero.getBitWidth() == BitWidth && + KnownOne.getBitWidth() == BitWidth && + "Value *V, DemandedMask, KnownZero and KnownOne \ + must have same BitWidth"); + if (ConstantInt *CI = dyn_cast(V)) { + // We know all of the bits for a constant! + KnownOne = CI->getValue() & DemandedMask; + KnownZero = ~KnownOne & DemandedMask; + return false; + } + + KnownZero.clear(); + KnownOne.clear(); + if (!V->hasOneUse()) { // Other users may use these bits. + if (Depth != 0) { // Not at the root. + // Just compute the KnownZero/KnownOne bits to simplify things downstream. + ComputeMaskedBits(V, DemandedMask, KnownZero, KnownOne, Depth); + return false; + } + // If this is the root being simplified, allow it to have multiple uses, + // just set the DemandedMask to all bits. + DemandedMask = APInt::getAllOnesValue(BitWidth); + } else if (DemandedMask == 0) { // Not demanding any bits from V. + if (V != UndefValue::get(VTy)) + return UpdateValueUsesWith(V, UndefValue::get(VTy)); + return false; + } else if (Depth == 6) { // Limit search depth. + return false; + } + + Instruction *I = dyn_cast(V); + if (!I) return false; // Only analyze instructions. + + APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); + APInt &RHSKnownZero = KnownZero, &RHSKnownOne = KnownOne; + switch (I->getOpcode()) { + default: break; + case Instruction::And: + // If either the LHS or the RHS are Zero, the result is zero. 
+ if (SimplifyDemandedBits(I->getOperand(1), DemandedMask, + RHSKnownZero, RHSKnownOne, Depth+1)) + return true; + assert((RHSKnownZero & RHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + + // If something is known zero on the RHS, the bits aren't demanded on the + // LHS. + if (SimplifyDemandedBits(I->getOperand(0), DemandedMask & ~RHSKnownZero, + LHSKnownZero, LHSKnownOne, Depth+1)) + return true; + assert((LHSKnownZero & LHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + + // If all of the demanded bits are known 1 on one side, return the other. + // These bits cannot contribute to the result of the 'and'. + if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) == + (DemandedMask & ~LHSKnownZero)) + return UpdateValueUsesWith(I, I->getOperand(0)); + if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) == + (DemandedMask & ~RHSKnownZero)) + return UpdateValueUsesWith(I, I->getOperand(1)); + + // If all of the demanded bits in the inputs are known zeros, return zero. + if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask) + return UpdateValueUsesWith(I, Constant::getNullValue(VTy)); + + // If the RHS is a constant, see if we can simplify it. + if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnownZero)) + return UpdateValueUsesWith(I, I); + + // Output known-1 bits are only known if set in both the LHS & RHS. + RHSKnownOne &= LHSKnownOne; + // Output known-0 are known to be clear if zero in either the LHS | RHS. + RHSKnownZero |= LHSKnownZero; + break; + case Instruction::Or: + // If either the LHS or the RHS are One, the result is One. + if (SimplifyDemandedBits(I->getOperand(1), DemandedMask, + RHSKnownZero, RHSKnownOne, Depth+1)) + return true; + assert((RHSKnownZero & RHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + // If something is known one on the RHS, the bits aren't demanded on the + // LHS. 
+ if (SimplifyDemandedBits(I->getOperand(0), DemandedMask & ~RHSKnownOne, + LHSKnownZero, LHSKnownOne, Depth+1)) + return true; + assert((LHSKnownZero & LHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + + // If all of the demanded bits are known zero on one side, return the other. + // These bits cannot contribute to the result of the 'or'. + if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) == + (DemandedMask & ~LHSKnownOne)) + return UpdateValueUsesWith(I, I->getOperand(0)); + if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) == + (DemandedMask & ~RHSKnownOne)) + return UpdateValueUsesWith(I, I->getOperand(1)); + + // If all of the potentially set bits on one side are known to be set on + // the other side, just use the 'other' side. + if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) == + (DemandedMask & (~RHSKnownZero))) + return UpdateValueUsesWith(I, I->getOperand(0)); + if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) == + (DemandedMask & (~LHSKnownZero))) + return UpdateValueUsesWith(I, I->getOperand(1)); + + // If the RHS is a constant, see if we can simplify it. + if (ShrinkDemandedConstant(I, 1, DemandedMask)) + return UpdateValueUsesWith(I, I); + + // Output known-0 bits are only known if clear in both the LHS & RHS. + RHSKnownZero &= LHSKnownZero; + // Output known-1 are known to be set if set in either the LHS | RHS. + RHSKnownOne |= LHSKnownOne; + break; + case Instruction::Xor: { + if (SimplifyDemandedBits(I->getOperand(1), DemandedMask, + RHSKnownZero, RHSKnownOne, Depth+1)) + return true; + assert((RHSKnownZero & RHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + if (SimplifyDemandedBits(I->getOperand(0), DemandedMask, + LHSKnownZero, LHSKnownOne, Depth+1)) + return true; + assert((LHSKnownZero & LHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + + // If all of the demanded bits are known zero on one side, return the other. + // These bits cannot contribute to the result of the 'xor'. 
+ if ((DemandedMask & RHSKnownZero) == DemandedMask) + return UpdateValueUsesWith(I, I->getOperand(0)); + if ((DemandedMask & LHSKnownZero) == DemandedMask) + return UpdateValueUsesWith(I, I->getOperand(1)); + + // Output known-0 bits are known if clear or set in both the LHS & RHS. + APInt KnownZeroOut = (RHSKnownZero & LHSKnownZero) | + (RHSKnownOne & LHSKnownOne); + // Output known-1 are known to be set if set in only one of the LHS, RHS. + APInt KnownOneOut = (RHSKnownZero & LHSKnownOne) | + (RHSKnownOne & LHSKnownZero); + + // If all of the demanded bits are known to be zero on one side or the + // other, turn this into an *inclusive* or. + // e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0 + if ((DemandedMask & ~RHSKnownZero & ~LHSKnownZero) == 0) { + Instruction *Or = + BinaryOperator::createOr(I->getOperand(0), I->getOperand(1), + I->getName()); + InsertNewInstBefore(Or, *I); + return UpdateValueUsesWith(I, Or); + } + + // If all of the demanded bits on one side are known, and all of the set + // bits on that side are also known to be set on the other side, turn this + // into an AND, as we know the bits will be cleared. + // e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2 + if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) { + // all known + if ((RHSKnownOne & LHSKnownOne) == RHSKnownOne) { + Constant *AndC = ConstantInt::get(~RHSKnownOne & DemandedMask); + Instruction *And = + BinaryOperator::createAnd(I->getOperand(0), AndC, "tmp"); + InsertNewInstBefore(And, *I); + return UpdateValueUsesWith(I, And); + } + } + + // If the RHS is a constant, see if we can simplify it. + // FIXME: for XOR, we prefer to force bits to 1 if they will make a -1. 
+ if (ShrinkDemandedConstant(I, 1, DemandedMask)) + return UpdateValueUsesWith(I, I); + + RHSKnownZero = KnownZeroOut; + RHSKnownOne = KnownOneOut; + break; + } + case Instruction::Select: + if (SimplifyDemandedBits(I->getOperand(2), DemandedMask, + RHSKnownZero, RHSKnownOne, Depth+1)) + return true; + if (SimplifyDemandedBits(I->getOperand(1), DemandedMask, + LHSKnownZero, LHSKnownOne, Depth+1)) + return true; + assert((RHSKnownZero & RHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + assert((LHSKnownZero & LHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + + // If the operands are constants, see if we can simplify them. + if (ShrinkDemandedConstant(I, 1, DemandedMask)) + return UpdateValueUsesWith(I, I); + if (ShrinkDemandedConstant(I, 2, DemandedMask)) + return UpdateValueUsesWith(I, I); + + // Only known if known in both the LHS and RHS. + RHSKnownOne &= LHSKnownOne; + RHSKnownZero &= LHSKnownZero; + break; + case Instruction::Trunc: { + uint32_t truncBf = + cast(I->getOperand(0)->getType())->getBitWidth(); + DemandedMask.zext(truncBf); + RHSKnownZero.zext(truncBf); + RHSKnownOne.zext(truncBf); + if (SimplifyDemandedBits(I->getOperand(0), DemandedMask, + RHSKnownZero, RHSKnownOne, Depth+1)) + return true; + DemandedMask.trunc(BitWidth); + RHSKnownZero.trunc(BitWidth); + RHSKnownOne.trunc(BitWidth); + assert((RHSKnownZero & RHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + break; + } + case Instruction::BitCast: + if (!I->getOperand(0)->getType()->isInteger()) + return false; + + if (SimplifyDemandedBits(I->getOperand(0), DemandedMask, + RHSKnownZero, RHSKnownOne, Depth+1)) + return true; + assert((RHSKnownZero & RHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + break; + case Instruction::ZExt: { + // Compute the bits in the result that are not present in the input. 
+ const IntegerType *SrcTy = cast(I->getOperand(0)->getType()); + uint32_t SrcBitWidth = SrcTy->getBitWidth(); + + DemandedMask.trunc(SrcBitWidth); + RHSKnownZero.trunc(SrcBitWidth); + RHSKnownOne.trunc(SrcBitWidth); + if (SimplifyDemandedBits(I->getOperand(0), DemandedMask, + RHSKnownZero, RHSKnownOne, Depth+1)) + return true; + DemandedMask.zext(BitWidth); + RHSKnownZero.zext(BitWidth); + RHSKnownOne.zext(BitWidth); + assert((RHSKnownZero & RHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + // The top bits are known to be zero. + RHSKnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + break; + } + case Instruction::SExt: { + // Compute the bits in the result that are not present in the input. + const IntegerType *SrcTy = cast(I->getOperand(0)->getType()); + uint32_t SrcBitWidth = SrcTy->getBitWidth(); + + APInt InputDemandedBits = DemandedMask & + APInt::getLowBitsSet(BitWidth, SrcBitWidth); + + APInt NewBits(APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth)); + // If any of the sign extended bits are demanded, we know that the sign + // bit is demanded. + if ((NewBits & DemandedMask) != 0) + InputDemandedBits.set(SrcBitWidth-1); + + InputDemandedBits.trunc(SrcBitWidth); + RHSKnownZero.trunc(SrcBitWidth); + RHSKnownOne.trunc(SrcBitWidth); + if (SimplifyDemandedBits(I->getOperand(0), InputDemandedBits, + RHSKnownZero, RHSKnownOne, Depth+1)) + return true; + InputDemandedBits.zext(BitWidth); + RHSKnownZero.zext(BitWidth); + RHSKnownOne.zext(BitWidth); + assert((RHSKnownZero & RHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + + // If the sign bit of the input is known set or clear, then we know the + // top bits of the result. + + // If the input sign bit is known zero, or if the NewBits are not demanded + // convert this into a zero extension. 
+ if (RHSKnownZero[SrcBitWidth-1] || (NewBits & ~DemandedMask) == NewBits) + { + // Convert to ZExt cast + CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy, I->getName(), I); + return UpdateValueUsesWith(I, NewCast); + } else if (RHSKnownOne[SrcBitWidth-1]) { // Input sign bit known set + RHSKnownOne |= NewBits; + } + break; + } + case Instruction::Add: { + // Figure out what the input bits are. If the top bits of the and result + // are not demanded, then the add doesn't demand them from its input + // either. + uint32_t NLZ = DemandedMask.countLeadingZeros(); + + // If there is a constant on the RHS, there are a variety of xformations + // we can do. + if (ConstantInt *RHS = dyn_cast(I->getOperand(1))) { + // If null, this should be simplified elsewhere. Some of the xforms here + // won't work if the RHS is zero. + if (RHS->isZero()) + break; + + // If the top bit of the output is demanded, demand everything from the + // input. Otherwise, we demand all the input bits except NLZ top bits. + APInt InDemandedBits(APInt::getLowBitsSet(BitWidth, BitWidth - NLZ)); + + // Find information about known zero/one bits in the input. + if (SimplifyDemandedBits(I->getOperand(0), InDemandedBits, + LHSKnownZero, LHSKnownOne, Depth+1)) + return true; + + // If the RHS of the add has bits set that can't affect the input, reduce + // the constant. + if (ShrinkDemandedConstant(I, 1, InDemandedBits)) + return UpdateValueUsesWith(I, I); + + // Avoid excess work. + if (LHSKnownZero == 0 && LHSKnownOne == 0) + break; + + // Turn it into OR if input bits are zero. + if ((LHSKnownZero & RHS->getValue()) == RHS->getValue()) { + Instruction *Or = + BinaryOperator::createOr(I->getOperand(0), I->getOperand(1), + I->getName()); + InsertNewInstBefore(Or, *I); + return UpdateValueUsesWith(I, Or); + } + + // We can say something about the output known-zero and known-one bits, + // depending on potential carries from the input constant and the + // unknowns. 
For example if the LHS is known to have at most the 0x0F0F0 + // bits set and the RHS constant is 0x01001, then we know we have a known + // one mask of 0x00001 and a known zero mask of 0xE0F0E. + + // To compute this, we first compute the potential carry bits. These are + // the bits which may be modified. I'm not aware of a better way to do + // this scan. + const APInt& RHSVal = RHS->getValue(); + APInt CarryBits((~LHSKnownZero + RHSVal) ^ (~LHSKnownZero ^ RHSVal)); + + // Now that we know which bits have carries, compute the known-1/0 sets. + + // Bits are known one if they are known zero in one operand and one in the + // other, and there is no input carry. + RHSKnownOne = ((LHSKnownZero & RHSVal) | + (LHSKnownOne & ~RHSVal)) & ~CarryBits; + + // Bits are known zero if they are known zero in both operands and there + // is no input carry. + RHSKnownZero = LHSKnownZero & ~RHSVal & ~CarryBits; + } else { + // If the high-bits of this ADD are not demanded, then it does not demand + // the high bits of its LHS or RHS. + if (DemandedMask[BitWidth-1] == 0) { + // Right fill the mask of bits for this ADD to demand the most + // significant bit and all those below it. + APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); + if (SimplifyDemandedBits(I->getOperand(0), DemandedFromOps, + LHSKnownZero, LHSKnownOne, Depth+1)) + return true; + if (SimplifyDemandedBits(I->getOperand(1), DemandedFromOps, + LHSKnownZero, LHSKnownOne, Depth+1)) + return true; + } + } + break; + } + case Instruction::Sub: + // If the high-bits of this SUB are not demanded, then it does not demand + // the high bits of its LHS or RHS. + if (DemandedMask[BitWidth-1] == 0) { + // Right fill the mask of bits for this SUB to demand the most + // significant bit and all those below it. 
+ uint32_t NLZ = DemandedMask.countLeadingZeros(); + APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); + if (SimplifyDemandedBits(I->getOperand(0), DemandedFromOps, + LHSKnownZero, LHSKnownOne, Depth+1)) + return true; + if (SimplifyDemandedBits(I->getOperand(1), DemandedFromOps, + LHSKnownZero, LHSKnownOne, Depth+1)) + return true; + } + break; + case Instruction::Shl: + if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { + uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); + APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt)); + if (SimplifyDemandedBits(I->getOperand(0), DemandedMaskIn, + RHSKnownZero, RHSKnownOne, Depth+1)) + return true; + assert((RHSKnownZero & RHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + RHSKnownZero <<= ShiftAmt; + RHSKnownOne <<= ShiftAmt; + // low bits known zero. + if (ShiftAmt) + RHSKnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt); + } + break; + case Instruction::LShr: + // For a logical shift right + if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { + uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); + + // Unsigned shift right. + APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); + if (SimplifyDemandedBits(I->getOperand(0), DemandedMaskIn, + RHSKnownZero, RHSKnownOne, Depth+1)) + return true; + assert((RHSKnownZero & RHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + RHSKnownZero = APIntOps::lshr(RHSKnownZero, ShiftAmt); + RHSKnownOne = APIntOps::lshr(RHSKnownOne, ShiftAmt); + if (ShiftAmt) { + // Compute the new bits that are at the top now. + APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); + RHSKnownZero |= HighBits; // high bits known zero. + } + } + break; + case Instruction::AShr: + // If this is an arithmetic shift right and only the low-bit is set, we can + // always convert this into a logical shr, even if the shift amount is + // variable. 
The low bit of the shift cannot be an input sign bit unless + // the shift amount is >= the size of the datatype, which is undefined. + if (DemandedMask == 1) { + // Perform the logical shift right. + Value *NewVal = BinaryOperator::createLShr( + I->getOperand(0), I->getOperand(1), I->getName()); + InsertNewInstBefore(cast(NewVal), *I); + return UpdateValueUsesWith(I, NewVal); + } + + // If the sign bit is the only bit demanded by this ashr, then there is no + // need to do it, the shift doesn't change the high bit. + if (DemandedMask.isSignBit()) + return UpdateValueUsesWith(I, I->getOperand(0)); + + if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { + uint32_t ShiftAmt = SA->getLimitedValue(BitWidth); + + // Signed shift right. + APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); + // If any of the "high bits" are demanded, we should set the sign bit as + // demanded. + if (DemandedMask.countLeadingZeros() <= ShiftAmt) + DemandedMaskIn.set(BitWidth-1); + if (SimplifyDemandedBits(I->getOperand(0), + DemandedMaskIn, + RHSKnownZero, RHSKnownOne, Depth+1)) + return true; + assert((RHSKnownZero & RHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + // Compute the new bits that are at the top now. + APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); + RHSKnownZero = APIntOps::lshr(RHSKnownZero, ShiftAmt); + RHSKnownOne = APIntOps::lshr(RHSKnownOne, ShiftAmt); + + // Handle the sign bits. + APInt SignBit(APInt::getSignBit(BitWidth)); + // Adjust to where it is now in the mask. + SignBit = APIntOps::lshr(SignBit, ShiftAmt); + + // If the input sign bit is known to be zero, or if none of the top bits + // are demanded, turn this into an unsigned shift right. + if (RHSKnownZero[BitWidth-ShiftAmt-1] || + (HighBits & ~DemandedMask) == HighBits) { + // Perform the logical shift right. 
+ Value *NewVal = BinaryOperator::createLShr( + I->getOperand(0), SA, I->getName()); + InsertNewInstBefore(cast(NewVal), *I); + return UpdateValueUsesWith(I, NewVal); + } else if ((RHSKnownOne & SignBit) != 0) { // New bits are known one. + RHSKnownOne |= HighBits; + } + } + break; + } + + // If the client is only demanding bits that we know, return the known + // constant. + if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) + return UpdateValueUsesWith(I, ConstantInt::get(RHSKnownOne)); + return false; +} + + +/// SimplifyDemandedVectorElts - The specified value producecs a vector with +/// 64 or fewer elements. DemandedElts contains the set of elements that are +/// actually used by the caller. This method analyzes which elements of the +/// operand are undef and returns that information in UndefElts. +/// +/// If the information about demanded elements can be used to simplify the +/// operation, the operation is simplified, then the resultant value is +/// returned. This returns null if no change was made. +Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, uint64_t DemandedElts, + uint64_t &UndefElts, + unsigned Depth) { + unsigned VWidth = cast(V->getType())->getNumElements(); + assert(VWidth <= 64 && "Vector too wide to analyze!"); + uint64_t EltMask = ~0ULL >> (64-VWidth); + assert(DemandedElts != EltMask && (DemandedElts & ~EltMask) == 0 && + "Invalid DemandedElts!"); + + if (isa(V)) { + // If the entire vector is undefined, just return this info. + UndefElts = EltMask; + return 0; + } else if (DemandedElts == 0) { // If nothing is demanded, provide undef. + UndefElts = EltMask; + return UndefValue::get(V->getType()); + } + + UndefElts = 0; + if (ConstantVector *CP = dyn_cast(V)) { + const Type *EltTy = cast(V->getType())->getElementType(); + Constant *Undef = UndefValue::get(EltTy); + + std::vector Elts; + for (unsigned i = 0; i != VWidth; ++i) + if (!(DemandedElts & (1ULL << i))) { // If not demanded, set to undef. 
+ Elts.push_back(Undef); + UndefElts |= (1ULL << i); + } else if (isa(CP->getOperand(i))) { // Already undef. + Elts.push_back(Undef); + UndefElts |= (1ULL << i); + } else { // Otherwise, defined. + Elts.push_back(CP->getOperand(i)); + } + + // If we changed the constant, return it. + Constant *NewCP = ConstantVector::get(Elts); + return NewCP != CP ? NewCP : 0; + } else if (isa(V)) { + // Simplify the CAZ to a ConstantVector where the non-demanded elements are + // set to undef. + const Type *EltTy = cast(V->getType())->getElementType(); + Constant *Zero = Constant::getNullValue(EltTy); + Constant *Undef = UndefValue::get(EltTy); + std::vector Elts; + for (unsigned i = 0; i != VWidth; ++i) + Elts.push_back((DemandedElts & (1ULL << i)) ? Zero : Undef); + UndefElts = DemandedElts ^ EltMask; + return ConstantVector::get(Elts); + } + + if (!V->hasOneUse()) { // Other users may use these bits. + if (Depth != 0) { // Not at the root. + // TODO: Just compute the UndefElts information recursively. + return false; + } + return false; + } else if (Depth == 10) { // Limit search depth. + return false; + } + + Instruction *I = dyn_cast(V); + if (!I) return false; // Only analyze instructions. + + bool MadeChange = false; + uint64_t UndefElts2; + Value *TmpV; + switch (I->getOpcode()) { + default: break; + + case Instruction::InsertElement: { + // If this is a variable index, we don't know which element it overwrites. + // demand exactly the same input as we produce. + ConstantInt *Idx = dyn_cast(I->getOperand(2)); + if (Idx == 0) { + // Note that we can't propagate undef elt info, because we don't know + // which elt is getting updated. + TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts, + UndefElts2, Depth+1); + if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } + break; + } + + // If this is inserting an element that isn't demanded, remove this + // insertelement. 
+ unsigned IdxNo = Idx->getZExtValue(); + if (IdxNo >= VWidth || (DemandedElts & (1ULL << IdxNo)) == 0) + return AddSoonDeadInstToWorklist(*I, 0); + + // Otherwise, the element inserted overwrites whatever was there, so the + // input demanded set is simpler than the output set. + TmpV = SimplifyDemandedVectorElts(I->getOperand(0), + DemandedElts & ~(1ULL << IdxNo), + UndefElts, Depth+1); + if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } + + // The inserted element is defined. + UndefElts |= 1ULL << IdxNo; + break; + } + case Instruction::BitCast: { + // Vector->vector casts only. + const VectorType *VTy = dyn_cast(I->getOperand(0)->getType()); + if (!VTy) break; + unsigned InVWidth = VTy->getNumElements(); + uint64_t InputDemandedElts = 0; + unsigned Ratio; + + if (VWidth == InVWidth) { + // If we are converting from <4 x i32> -> <4 x f32>, we demand the same + // elements as are demanded of us. + Ratio = 1; + InputDemandedElts = DemandedElts; + } else if (VWidth > InVWidth) { + // Untested so far. + break; + + // If there are more elements in the result than there are in the source, + // then an input element is live if any of the corresponding output + // elements are live. + Ratio = VWidth/InVWidth; + for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) { + if (DemandedElts & (1ULL << OutIdx)) + InputDemandedElts |= 1ULL << (OutIdx/Ratio); + } + } else { + // Untested so far. + break; + + // If there are more elements in the source than there are in the result, + // then an input element is live if the corresponding output element is + // live. + Ratio = InVWidth/VWidth; + for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx) + if (DemandedElts & (1ULL << InIdx/Ratio)) + InputDemandedElts |= 1ULL << InIdx; + } + + // div/rem demand all inputs, because they don't want divide by zero. 
+ TmpV = SimplifyDemandedVectorElts(I->getOperand(0), InputDemandedElts, + UndefElts2, Depth+1); + if (TmpV) { + I->setOperand(0, TmpV); + MadeChange = true; + } + + UndefElts = UndefElts2; + if (VWidth > InVWidth) { + assert(0 && "Unimp"); + // If there are more elements in the result than there are in the source, + // then an output element is undef if the corresponding input element is + // undef. + for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) + if (UndefElts2 & (1ULL << (OutIdx/Ratio))) + UndefElts |= 1ULL << OutIdx; + } else if (VWidth < InVWidth) { + assert(0 && "Unimp"); + // If there are more elements in the source than there are in the result, + // then a result element is undef if all of the corresponding input + // elements are undef. + UndefElts = ~0ULL >> (64-VWidth); // Start out all undef. + for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx) + if ((UndefElts2 & (1ULL << InIdx)) == 0) // Not undef? + UndefElts &= ~(1ULL << (InIdx/Ratio)); // Clear undef bit. + } + break; + } + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + // div/rem demand all inputs, because they don't want divide by zero. + TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts, + UndefElts, Depth+1); + if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } + TmpV = SimplifyDemandedVectorElts(I->getOperand(1), DemandedElts, + UndefElts2, Depth+1); + if (TmpV) { I->setOperand(1, TmpV); MadeChange = true; } + + // Output elements are undefined if both are undefined. Consider things + // like undef&0. The result is known zero, not undef. + UndefElts &= UndefElts2; + break; + + case Instruction::Call: { + IntrinsicInst *II = dyn_cast(I); + if (!II) break; + switch (II->getIntrinsicID()) { + default: break; + + // Binary vector operations that work column-wise. A dest element is a + // function of the corresponding input elements from the two inputs. 
+ case Intrinsic::x86_sse_sub_ss: + case Intrinsic::x86_sse_mul_ss: + case Intrinsic::x86_sse_min_ss: + case Intrinsic::x86_sse_max_ss: + case Intrinsic::x86_sse2_sub_sd: + case Intrinsic::x86_sse2_mul_sd: + case Intrinsic::x86_sse2_min_sd: + case Intrinsic::x86_sse2_max_sd: + TmpV = SimplifyDemandedVectorElts(II->getOperand(1), DemandedElts, + UndefElts, Depth+1); + if (TmpV) { II->setOperand(1, TmpV); MadeChange = true; } + TmpV = SimplifyDemandedVectorElts(II->getOperand(2), DemandedElts, + UndefElts2, Depth+1); + if (TmpV) { II->setOperand(2, TmpV); MadeChange = true; } + + // If only the low elt is demanded and this is a scalarizable intrinsic, + // scalarize it now. + if (DemandedElts == 1) { + switch (II->getIntrinsicID()) { + default: break; + case Intrinsic::x86_sse_sub_ss: + case Intrinsic::x86_sse_mul_ss: + case Intrinsic::x86_sse2_sub_sd: + case Intrinsic::x86_sse2_mul_sd: + // TODO: Lower MIN/MAX/ABS/etc + Value *LHS = II->getOperand(1); + Value *RHS = II->getOperand(2); + // Extract the element as scalars. + LHS = InsertNewInstBefore(new ExtractElementInst(LHS, 0U,"tmp"), *II); + RHS = InsertNewInstBefore(new ExtractElementInst(RHS, 0U,"tmp"), *II); + + switch (II->getIntrinsicID()) { + default: assert(0 && "Case stmts out of sync!"); + case Intrinsic::x86_sse_sub_ss: + case Intrinsic::x86_sse2_sub_sd: + TmpV = InsertNewInstBefore(BinaryOperator::createSub(LHS, RHS, + II->getName()), *II); + break; + case Intrinsic::x86_sse_mul_ss: + case Intrinsic::x86_sse2_mul_sd: + TmpV = InsertNewInstBefore(BinaryOperator::createMul(LHS, RHS, + II->getName()), *II); + break; + } + + Instruction *New = + new InsertElementInst(UndefValue::get(II->getType()), TmpV, 0U, + II->getName()); + InsertNewInstBefore(New, *II); + AddSoonDeadInstToWorklist(*II, 0); + return New; + } + } + + // Output elements are undefined if both are undefined. Consider things + // like undef&0. The result is known zero, not undef. 
+ UndefElts &= UndefElts2; + break; + } + break; + } + } + return MadeChange ? I : 0; +} + +/// @returns true if the specified compare instruction is +/// true when both operands are equal... +/// @brief Determine if the ICmpInst returns true if both operands are equal +static bool isTrueWhenEqual(ICmpInst &ICI) { + ICmpInst::Predicate pred = ICI.getPredicate(); + return pred == ICmpInst::ICMP_EQ || pred == ICmpInst::ICMP_UGE || + pred == ICmpInst::ICMP_SGE || pred == ICmpInst::ICMP_ULE || + pred == ICmpInst::ICMP_SLE; +} + +/// AssociativeOpt - Perform an optimization on an associative operator. This +/// function is designed to check a chain of associative operators for a +/// potential to apply a certain optimization. Since the optimization may be +/// applicable if the expression was reassociated, this checks the chain, then +/// reassociates the expression as necessary to expose the optimization +/// opportunity. This makes use of a special Functor, which must define +/// 'shouldApply' and 'apply' methods. +/// +template +Instruction *AssociativeOpt(BinaryOperator &Root, const Functor &F) { + unsigned Opcode = Root.getOpcode(); + Value *LHS = Root.getOperand(0); + + // Quick check, see if the immediate LHS matches... + if (F.shouldApply(LHS)) + return F.apply(Root); + + // Otherwise, if the LHS is not of the same opcode as the root, return. + Instruction *LHSI = dyn_cast(LHS); + while (LHSI && LHSI->getOpcode() == Opcode && LHSI->hasOneUse()) { + // Should we apply this transform to the RHS? + bool ShouldApply = F.shouldApply(LHSI->getOperand(1)); + + // If not to the RHS, check to see if we should apply to the LHS... + if (!ShouldApply && F.shouldApply(LHSI->getOperand(0))) { + cast(LHSI)->swapOperands(); // Make the LHS the RHS + ShouldApply = true; + } + + // If the functor wants to apply the optimization to the RHS of LHSI, + // reassociate the expression from ((? op A) op B) to (? 
op (A op B)) + if (ShouldApply) { + BasicBlock *BB = Root.getParent(); + + // Now all of the instructions are in the current basic block, go ahead + // and perform the reassociation. + Instruction *TmpLHSI = cast(Root.getOperand(0)); + + // First move the selected RHS to the LHS of the root... + Root.setOperand(0, LHSI->getOperand(1)); + + // Make what used to be the LHS of the root be the user of the root... + Value *ExtraOperand = TmpLHSI->getOperand(1); + if (&Root == TmpLHSI) { + Root.replaceAllUsesWith(Constant::getNullValue(TmpLHSI->getType())); + return 0; + } + Root.replaceAllUsesWith(TmpLHSI); // Users now use TmpLHSI + TmpLHSI->setOperand(1, &Root); // TmpLHSI now uses the root + TmpLHSI->getParent()->getInstList().remove(TmpLHSI); + BasicBlock::iterator ARI = &Root; ++ARI; + BB->getInstList().insert(ARI, TmpLHSI); // Move TmpLHSI to after Root + ARI = Root; + + // Now propagate the ExtraOperand down the chain of instructions until we + // get to LHSI. + while (TmpLHSI != LHSI) { + Instruction *NextLHSI = cast(TmpLHSI->getOperand(0)); + // Move the instruction to immediately before the chain we are + // constructing to avoid breaking dominance properties. + NextLHSI->getParent()->getInstList().remove(NextLHSI); + BB->getInstList().insert(ARI, NextLHSI); + ARI = NextLHSI; + + Value *NextOp = NextLHSI->getOperand(1); + NextLHSI->setOperand(1, ExtraOperand); + TmpLHSI = NextLHSI; + ExtraOperand = NextOp; + } + + // Now that the instructions are reassociated, have the functor perform + // the transformation... 
+ return F.apply(Root); + } + + LHSI = dyn_cast(LHSI->getOperand(0)); + } + return 0; +} + + +// AddRHS - Implements: X + X --> X << 1 +struct AddRHS { + Value *RHS; + AddRHS(Value *rhs) : RHS(rhs) {} + bool shouldApply(Value *LHS) const { return LHS == RHS; } + Instruction *apply(BinaryOperator &Add) const { + return BinaryOperator::createShl(Add.getOperand(0), + ConstantInt::get(Add.getType(), 1)); + } +}; + +// AddMaskingAnd - Implements (A & C1)+(B & C2) --> (A & C1)|(B & C2) +// iff C1&C2 == 0 +struct AddMaskingAnd { + Constant *C2; + AddMaskingAnd(Constant *c) : C2(c) {} + bool shouldApply(Value *LHS) const { + ConstantInt *C1; + return match(LHS, m_And(m_Value(), m_ConstantInt(C1))) && + ConstantExpr::getAnd(C1, C2)->isNullValue(); + } + Instruction *apply(BinaryOperator &Add) const { + return BinaryOperator::createOr(Add.getOperand(0), Add.getOperand(1)); + } +}; + +static Value *FoldOperationIntoSelectOperand(Instruction &I, Value *SO, + InstCombiner *IC) { + if (CastInst *CI = dyn_cast(&I)) { + if (Constant *SOC = dyn_cast(SO)) + return ConstantExpr::getCast(CI->getOpcode(), SOC, I.getType()); + + return IC->InsertNewInstBefore(CastInst::create( + CI->getOpcode(), SO, I.getType(), SO->getName() + ".cast"), I); + } + + // Figure out if the constant is the left or the right argument. 
+ bool ConstIsRHS = isa(I.getOperand(1)); + Constant *ConstOperand = cast(I.getOperand(ConstIsRHS)); + + if (Constant *SOC = dyn_cast(SO)) { + if (ConstIsRHS) + return ConstantExpr::get(I.getOpcode(), SOC, ConstOperand); + return ConstantExpr::get(I.getOpcode(), ConstOperand, SOC); + } + + Value *Op0 = SO, *Op1 = ConstOperand; + if (!ConstIsRHS) + std::swap(Op0, Op1); + Instruction *New; + if (BinaryOperator *BO = dyn_cast(&I)) + New = BinaryOperator::create(BO->getOpcode(), Op0, Op1,SO->getName()+".op"); + else if (CmpInst *CI = dyn_cast(&I)) + New = CmpInst::create(CI->getOpcode(), CI->getPredicate(), Op0, Op1, + SO->getName()+".cmp"); + else { + assert(0 && "Unknown binary instruction type!"); + abort(); + } + return IC->InsertNewInstBefore(New, I); +} + +// FoldOpIntoSelect - Given an instruction with a select as one operand and a +// constant as the other operand, try to fold the binary operator into the +// select arguments. This also works for Cast instructions, which obviously do +// not have a second operand. +static Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI, + InstCombiner *IC) { + // Don't modify shared select instructions + if (!SI->hasOneUse()) return 0; + Value *TV = SI->getOperand(1); + Value *FV = SI->getOperand(2); + + if (isa(TV) || isa(FV)) { + // Bool selects with constant operands can be folded to logical ops. + if (SI->getType() == Type::Int1Ty) return 0; + + Value *SelectTrueVal = FoldOperationIntoSelectOperand(Op, TV, IC); + Value *SelectFalseVal = FoldOperationIntoSelectOperand(Op, FV, IC); + + return new SelectInst(SI->getCondition(), SelectTrueVal, + SelectFalseVal); + } + return 0; +} + + +/// FoldOpIntoPhi - Given a binary operator or cast instruction which has a PHI +/// node as operand #0, see if we can fold the instruction into the PHI (which +/// is only possible if all operands to the PHI are constants). 
+Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { + PHINode *PN = cast(I.getOperand(0)); + unsigned NumPHIValues = PN->getNumIncomingValues(); + if (!PN->hasOneUse() || NumPHIValues == 0) return 0; + + // Check to see if all of the operands of the PHI are constants. If there is + // one non-constant value, remember the BB it is. If there is more than one + // or if *it* is a PHI, bail out. + BasicBlock *NonConstBB = 0; + for (unsigned i = 0; i != NumPHIValues; ++i) + if (!isa(PN->getIncomingValue(i))) { + if (NonConstBB) return 0; // More than one non-const value. + if (isa(PN->getIncomingValue(i))) return 0; // Itself a phi. + NonConstBB = PN->getIncomingBlock(i); + + // If the incoming non-constant value is in I's block, we have an infinite + // loop. + if (NonConstBB == I.getParent()) + return 0; + } + + // If there is exactly one non-constant value, we can insert a copy of the + // operation in that block. However, if this is a critical edge, we would be + // inserting the computation one some other paths (e.g. inside a loop). Only + // do this if the pred block is unconditionally branching into the phi block. + if (NonConstBB) { + BranchInst *BI = dyn_cast(NonConstBB->getTerminator()); + if (!BI || !BI->isUnconditional()) return 0; + } + + // Okay, we can do the transformation: create the new PHI node. + PHINode *NewPN = new PHINode(I.getType(), ""); + NewPN->reserveOperandSpace(PN->getNumOperands()/2); + InsertNewInstBefore(NewPN, *PN); + NewPN->takeName(PN); + + // Next, add all of the operands to the PHI. 
+ if (I.getNumOperands() == 2) { + Constant *C = cast(I.getOperand(1)); + for (unsigned i = 0; i != NumPHIValues; ++i) { + Value *InV; + if (Constant *InC = dyn_cast(PN->getIncomingValue(i))) { + if (CmpInst *CI = dyn_cast(&I)) + InV = ConstantExpr::getCompare(CI->getPredicate(), InC, C); + else + InV = ConstantExpr::get(I.getOpcode(), InC, C); + } else { + assert(PN->getIncomingBlock(i) == NonConstBB); + if (BinaryOperator *BO = dyn_cast(&I)) + InV = BinaryOperator::create(BO->getOpcode(), + PN->getIncomingValue(i), C, "phitmp", + NonConstBB->getTerminator()); + else if (CmpInst *CI = dyn_cast(&I)) + InV = CmpInst::create(CI->getOpcode(), + CI->getPredicate(), + PN->getIncomingValue(i), C, "phitmp", + NonConstBB->getTerminator()); + else + assert(0 && "Unknown binop!"); + + AddToWorkList(cast(InV)); + } + NewPN->addIncoming(InV, PN->getIncomingBlock(i)); + } + } else { + CastInst *CI = cast(&I); + const Type *RetTy = CI->getType(); + for (unsigned i = 0; i != NumPHIValues; ++i) { + Value *InV; + if (Constant *InC = dyn_cast(PN->getIncomingValue(i))) { + InV = ConstantExpr::getCast(CI->getOpcode(), InC, RetTy); + } else { + assert(PN->getIncomingBlock(i) == NonConstBB); + InV = CastInst::create(CI->getOpcode(), PN->getIncomingValue(i), + I.getType(), "phitmp", + NonConstBB->getTerminator()); + AddToWorkList(cast(InV)); + } + NewPN->addIncoming(InV, PN->getIncomingBlock(i)); + } + } + return ReplaceInstUsesWith(I, NewPN); +} + +Instruction *InstCombiner::visitAdd(BinaryOperator &I) { + bool Changed = SimplifyCommutative(I); + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + + if (Constant *RHSC = dyn_cast(RHS)) { + // X + undef -> undef + if (isa(RHS)) + return ReplaceInstUsesWith(I, RHS); + + // X + 0 --> X + if (!I.getType()->isFPOrFPVector()) { // NOTE: -0 + +0 = +0. 
+ if (RHSC->isNullValue()) + return ReplaceInstUsesWith(I, LHS); + } else if (ConstantFP *CFP = dyn_cast(RHSC)) { + if (CFP->isExactlyValue(-0.0)) + return ReplaceInstUsesWith(I, LHS); + } + + if (ConstantInt *CI = dyn_cast(RHSC)) { + // X + (signbit) --> X ^ signbit + const APInt& Val = CI->getValue(); + uint32_t BitWidth = Val.getBitWidth(); + if (Val == APInt::getSignBit(BitWidth)) + return BinaryOperator::createXor(LHS, RHS); + + // See if SimplifyDemandedBits can simplify this. This handles stuff like + // (X & 254)+1 -> (X&254)|1 + if (!isa(I.getType())) { + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + if (SimplifyDemandedBits(&I, APInt::getAllOnesValue(BitWidth), + KnownZero, KnownOne)) + return &I; + } + } + + if (isa(LHS)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + + ConstantInt *XorRHS = 0; + Value *XorLHS = 0; + if (isa(RHSC) && + match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) { + uint32_t TySizeBits = I.getType()->getPrimitiveSizeInBits(); + const APInt& RHSVal = cast(RHSC)->getValue(); + + uint32_t Size = TySizeBits / 2; + APInt C0080Val(APInt(TySizeBits, 1ULL).shl(Size - 1)); + APInt CFF80Val(-C0080Val); + do { + if (TySizeBits > Size) { + // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext. + // If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext. + if ((RHSVal == CFF80Val && XorRHS->getValue() == C0080Val) || + (RHSVal == C0080Val && XorRHS->getValue() == CFF80Val)) { + // This is a sign extend if the top bits are known zero. + if (!MaskedValueIsZero(XorLHS, + APInt::getHighBitsSet(TySizeBits, TySizeBits - Size))) + Size = 0; // Not a sign ext, but can't be any others either. + break; + } + } + Size >>= 1; + C0080Val = APIntOps::lshr(C0080Val, Size); + CFF80Val = APIntOps::ashr(CFF80Val, Size); + } while (Size >= 1); + + // FIXME: This shouldn't be necessary. When the backends can handle types + // with funny bit widths then this whole cascade of if statements should + // be removed. 
It is just here to get the size of the "middle" type back + // up to something that the back ends can handle. + const Type *MiddleType = 0; + switch (Size) { + default: break; + case 32: MiddleType = Type::Int32Ty; break; + case 16: MiddleType = Type::Int16Ty; break; + case 8: MiddleType = Type::Int8Ty; break; + } + if (MiddleType) { + Instruction *NewTrunc = new TruncInst(XorLHS, MiddleType, "sext"); + InsertNewInstBefore(NewTrunc, I); + return new SExtInst(NewTrunc, I.getType(), I.getName()); + } + } + } + + // X + X --> X << 1 + if (I.getType()->isInteger() && I.getType() != Type::Int1Ty) { + if (Instruction *Result = AssociativeOpt(I, AddRHS(RHS))) return Result; + + if (Instruction *RHSI = dyn_cast(RHS)) { + if (RHSI->getOpcode() == Instruction::Sub) + if (LHS == RHSI->getOperand(1)) // A + (B - A) --> B + return ReplaceInstUsesWith(I, RHSI->getOperand(0)); + } + if (Instruction *LHSI = dyn_cast(LHS)) { + if (LHSI->getOpcode() == Instruction::Sub) + if (RHS == LHSI->getOperand(1)) // (B - A) + A --> B + return ReplaceInstUsesWith(I, LHSI->getOperand(0)); + } + } + + // -A + B --> B - A + if (Value *V = dyn_castNegVal(LHS)) + return BinaryOperator::createSub(RHS, V); + + // A + -B --> A - B + if (!isa(RHS)) + if (Value *V = dyn_castNegVal(RHS)) + return BinaryOperator::createSub(LHS, V); + + + ConstantInt *C2; + if (Value *X = dyn_castFoldableMul(LHS, C2)) { + if (X == RHS) // X*C + X --> X * (C+1) + return BinaryOperator::createMul(RHS, AddOne(C2)); + + // X*C1 + X*C2 --> X * (C1+C2) + ConstantInt *C1; + if (X == dyn_castFoldableMul(RHS, C1)) + return BinaryOperator::createMul(X, Add(C1, C2)); + } + + // X + X*C --> X * (C+1) + if (dyn_castFoldableMul(RHS, C2) == LHS) + return BinaryOperator::createMul(LHS, AddOne(C2)); + + // X + ~X --> -1 since ~X = -X-1 + if (dyn_castNotVal(LHS) == RHS || dyn_castNotVal(RHS) == LHS) + return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType())); + + + // (A & C1)+(B & C2) --> (A & C1)|(B & C2) iff C1&C2 == 0 + if 
(match(RHS, m_And(m_Value(), m_ConstantInt(C2)))) + if (Instruction *R = AssociativeOpt(I, AddMaskingAnd(C2))) + return R; + + if (ConstantInt *CRHS = dyn_cast(RHS)) { + Value *X = 0; + if (match(LHS, m_Not(m_Value(X)))) // ~X + C --> (C-1) - X + return BinaryOperator::createSub(SubOne(CRHS), X); + + // (X & FF00) + xx00 -> (X+xx00) & FF00 + if (LHS->hasOneUse() && match(LHS, m_And(m_Value(X), m_ConstantInt(C2)))) { + Constant *Anded = And(CRHS, C2); + if (Anded == CRHS) { + // See if all bits from the first bit set in the Add RHS up are included + // in the mask. First, get the rightmost bit. + const APInt& AddRHSV = CRHS->getValue(); + + // Form a mask of all bits from the lowest bit added through the top. + APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1)); + + // See if the and mask includes all of these bits. + APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue()); + + if (AddRHSHighBits == AddRHSHighBitsAnd) { + // Okay, the xform is safe. Insert the new add pronto. + Value *NewAdd = InsertNewInstBefore(BinaryOperator::createAdd(X, CRHS, + LHS->getName()), I); + return BinaryOperator::createAnd(NewAdd, C2); + } + } + } + + // Try to fold constant add into select arguments. + if (SelectInst *SI = dyn_cast(LHS)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + } + + // add (cast *A to intptrtype) B -> + // cast (GEP (cast *A to sbyte*) B) -> + // intptrtype + { + CastInst *CI = dyn_cast(LHS); + Value *Other = RHS; + if (!CI) { + CI = dyn_cast(RHS); + Other = LHS; + } + if (CI && CI->getType()->isSized() && + (CI->getType()->getPrimitiveSizeInBits() == + TD->getIntPtrType()->getPrimitiveSizeInBits()) + && isa(CI->getOperand(0)->getType())) { + Value *I2 = InsertCastBefore(Instruction::BitCast, CI->getOperand(0), + PointerType::get(Type::Int8Ty), I); + I2 = InsertNewInstBefore(new GetElementPtrInst(I2, Other, "ctg2"), I); + return new PtrToIntInst(I2, CI->getType()); + } + } + + return Changed ? 
&I : 0; +} + +// isSignBit - Return true if the value represented by the constant only has the +// highest order bit set. +static bool isSignBit(ConstantInt *CI) { + uint32_t NumBits = CI->getType()->getPrimitiveSizeInBits(); + return CI->getValue() == APInt::getSignBit(NumBits); +} + +Instruction *InstCombiner::visitSub(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (Op0 == Op1) // sub X, X -> 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + // If this is a 'B = x-(-A)', change to B = x+A... + if (Value *V = dyn_castNegVal(Op1)) + return BinaryOperator::createAdd(Op0, V); + + if (isa(Op0)) + return ReplaceInstUsesWith(I, Op0); // undef - X -> undef + if (isa(Op1)) + return ReplaceInstUsesWith(I, Op1); // X - undef -> undef + + if (ConstantInt *C = dyn_cast(Op0)) { + // Replace (-1 - A) with (~A)... + if (C->isAllOnesValue()) + return BinaryOperator::createNot(Op1); + + // C - ~X == X + (1+C) + Value *X = 0; + if (match(Op1, m_Not(m_Value(X)))) + return BinaryOperator::createAdd(X, AddOne(C)); + + // -(X >>u 31) -> (X >>s 31) + // -(X >>s 31) -> (X >>u 31) + if (C->isZero()) { + if (BinaryOperator *SI = dyn_cast(Op1)) + if (SI->getOpcode() == Instruction::LShr) { + if (ConstantInt *CU = dyn_cast(SI->getOperand(1))) { + // Check to see if we are shifting out everything but the sign bit. + if (CU->getLimitedValue(SI->getType()->getPrimitiveSizeInBits()) == + SI->getType()->getPrimitiveSizeInBits()-1) { + // Ok, the transformation is safe. Insert AShr. + return BinaryOperator::create(Instruction::AShr, + SI->getOperand(0), CU, SI->getName()); + } + } + } + else if (SI->getOpcode() == Instruction::AShr) { + if (ConstantInt *CU = dyn_cast(SI->getOperand(1))) { + // Check to see if we are shifting out everything but the sign bit. + if (CU->getLimitedValue(SI->getType()->getPrimitiveSizeInBits()) == + SI->getType()->getPrimitiveSizeInBits()-1) { + // Ok, the transformation is safe. Insert LShr. 
+ return BinaryOperator::createLShr( + SI->getOperand(0), CU, SI->getName()); + } + } + } + } + + // Try to fold constant sub into select arguments. + if (SelectInst *SI = dyn_cast(Op1)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + + if (isa(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } + + if (BinaryOperator *Op1I = dyn_cast(Op1)) { + if (Op1I->getOpcode() == Instruction::Add && + !Op0->getType()->isFPOrFPVector()) { + if (Op1I->getOperand(0) == Op0) // X-(X+Y) == -Y + return BinaryOperator::createNeg(Op1I->getOperand(1), I.getName()); + else if (Op1I->getOperand(1) == Op0) // X-(Y+X) == -Y + return BinaryOperator::createNeg(Op1I->getOperand(0), I.getName()); + else if (ConstantInt *CI1 = dyn_cast(I.getOperand(0))) { + if (ConstantInt *CI2 = dyn_cast(Op1I->getOperand(1))) + // C1-(X+C2) --> (C1-C2)-X + return BinaryOperator::createSub(Subtract(CI1, CI2), + Op1I->getOperand(0)); + } + } + + if (Op1I->hasOneUse()) { + // Replace (x - (y - z)) with (x + (z - y)) if the (y - z) subexpression + // is not used by anyone else... + // + if (Op1I->getOpcode() == Instruction::Sub && + !Op1I->getType()->isFPOrFPVector()) { + // Swap the two operands of the subexpr... + Value *IIOp0 = Op1I->getOperand(0), *IIOp1 = Op1I->getOperand(1); + Op1I->setOperand(0, IIOp1); + Op1I->setOperand(1, IIOp0); + + // Create the new top level add instruction... + return BinaryOperator::createAdd(Op0, Op1); + } + + // Replace (A - (A & B)) with (A & ~B) if this is the only use of (A&B)... 
+ // + if (Op1I->getOpcode() == Instruction::And && + (Op1I->getOperand(0) == Op0 || Op1I->getOperand(1) == Op0)) { + Value *OtherOp = Op1I->getOperand(Op1I->getOperand(0) == Op0); + + Value *NewNot = + InsertNewInstBefore(BinaryOperator::createNot(OtherOp, "B.not"), I); + return BinaryOperator::createAnd(Op0, NewNot); + } + + // 0 - (X sdiv C) -> (X sdiv -C) + if (Op1I->getOpcode() == Instruction::SDiv) + if (ConstantInt *CSI = dyn_cast(Op0)) + if (CSI->isZero()) + if (Constant *DivRHS = dyn_cast(Op1I->getOperand(1))) + return BinaryOperator::createSDiv(Op1I->getOperand(0), + ConstantExpr::getNeg(DivRHS)); + + // X - X*C --> X * (1-C) + ConstantInt *C2 = 0; + if (dyn_castFoldableMul(Op1I, C2) == Op0) { + Constant *CP1 = Subtract(ConstantInt::get(I.getType(), 1), C2); + return BinaryOperator::createMul(Op0, CP1); + } + } + } + + if (!Op0->getType()->isFPOrFPVector()) + if (BinaryOperator *Op0I = dyn_cast(Op0)) + if (Op0I->getOpcode() == Instruction::Add) { + if (Op0I->getOperand(0) == Op1) // (Y+X)-Y == X + return ReplaceInstUsesWith(I, Op0I->getOperand(1)); + else if (Op0I->getOperand(1) == Op1) // (X+Y)-Y == X + return ReplaceInstUsesWith(I, Op0I->getOperand(0)); + } else if (Op0I->getOpcode() == Instruction::Sub) { + if (Op0I->getOperand(0) == Op1) // (X-Y)-X == -Y + return BinaryOperator::createNeg(Op0I->getOperand(1), I.getName()); + } + + ConstantInt *C1; + if (Value *X = dyn_castFoldableMul(Op0, C1)) { + if (X == Op1) // X*C - X --> X * (C-1) + return BinaryOperator::createMul(Op1, SubOne(C1)); + + ConstantInt *C2; // X*C1 - X*C2 -> X * (C1-C2) + if (X == dyn_castFoldableMul(Op1, C2)) + return BinaryOperator::createMul(Op1, Subtract(C1, C2)); + } + return 0; +} + +/// isSignBitCheck - Given an exploded icmp instruction, return true if the +/// comparison only checks the sign bit. If it only checks the sign bit, set +/// TrueIfSigned if the result of the comparison is true when the input value is +/// signed. 
+static bool isSignBitCheck(ICmpInst::Predicate pred, ConstantInt *RHS, + bool &TrueIfSigned) { + switch (pred) { + case ICmpInst::ICMP_SLT: // True if LHS s< 0 + TrueIfSigned = true; + return RHS->isZero(); + case ICmpInst::ICMP_SLE: // True if LHS s<= RHS and RHS == -1 + TrueIfSigned = true; + return RHS->isAllOnesValue(); + case ICmpInst::ICMP_SGT: // True if LHS s> -1 + TrueIfSigned = false; + return RHS->isAllOnesValue(); + case ICmpInst::ICMP_UGT: + // True if LHS u> RHS and RHS == high-bit-mask - 1 + TrueIfSigned = true; + return RHS->getValue() == + APInt::getSignedMaxValue(RHS->getType()->getPrimitiveSizeInBits()); + case ICmpInst::ICMP_UGE: + // True if LHS u>= RHS and RHS == high-bit-mask (2^7, 2^15, 2^31, etc) + TrueIfSigned = true; + return RHS->getValue() == + APInt::getSignBit(RHS->getType()->getPrimitiveSizeInBits()); + default: + return false; + } +} + +Instruction *InstCombiner::visitMul(BinaryOperator &I) { + bool Changed = SimplifyCommutative(I); + Value *Op0 = I.getOperand(0); + + if (isa(I.getOperand(1))) // undef * X -> 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + // Simplify mul instructions with a constant RHS... 
+ if (Constant *Op1 = dyn_cast(I.getOperand(1))) { + if (ConstantInt *CI = dyn_cast(Op1)) { + + // ((X << C1)*C2) == (X * (C2 << C1)) + if (BinaryOperator *SI = dyn_cast(Op0)) + if (SI->getOpcode() == Instruction::Shl) + if (Constant *ShOp = dyn_cast(SI->getOperand(1))) + return BinaryOperator::createMul(SI->getOperand(0), + ConstantExpr::getShl(CI, ShOp)); + + if (CI->isZero()) + return ReplaceInstUsesWith(I, Op1); // X * 0 == 0 + if (CI->equalsInt(1)) // X * 1 == X + return ReplaceInstUsesWith(I, Op0); + if (CI->isAllOnesValue()) // X * -1 == 0 - X + return BinaryOperator::createNeg(Op0, I.getName()); + + const APInt& Val = cast(CI)->getValue(); + if (Val.isPowerOf2()) { // Replace X*(2^C) with X << C + return BinaryOperator::createShl(Op0, + ConstantInt::get(Op0->getType(), Val.logBase2())); + } + } else if (ConstantFP *Op1F = dyn_cast(Op1)) { + if (Op1F->isNullValue()) + return ReplaceInstUsesWith(I, Op1); + + // "In IEEE floating point, x*1 is not equivalent to x for nans. However, + // ANSI says we can drop signals, so we can do this anyway." (from GCC) + if (Op1F->getValue() == 1.0) + return ReplaceInstUsesWith(I, Op0); // Eliminate 'mul double %X, 1.0' + } + + if (BinaryOperator *Op0I = dyn_cast(Op0)) + if (Op0I->getOpcode() == Instruction::Add && Op0I->hasOneUse() && + isa(Op0I->getOperand(1))) { + // Canonicalize (X+C1)*C2 -> X*C2+C1*C2. + Instruction *Add = BinaryOperator::createMul(Op0I->getOperand(0), + Op1, "tmp"); + InsertNewInstBefore(Add, I); + Value *C1C2 = ConstantExpr::getMul(Op1, + cast(Op0I->getOperand(1))); + return BinaryOperator::createAdd(Add, C1C2); + + } + + // Try to fold constant mul into select arguments. 
+ if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + + if (isa(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } + + if (Value *Op0v = dyn_castNegVal(Op0)) // -X * -Y = X*Y + if (Value *Op1v = dyn_castNegVal(I.getOperand(1))) + return BinaryOperator::createMul(Op0v, Op1v); + + // If one of the operands of the multiply is a cast from a boolean value, then + // we know the bool is either zero or one, so this is a 'masking' multiply. + // See if we can simplify things based on how the boolean was originally + // formed. + CastInst *BoolCast = 0; + if (ZExtInst *CI = dyn_cast(I.getOperand(0))) + if (CI->getOperand(0)->getType() == Type::Int1Ty) + BoolCast = CI; + if (!BoolCast) + if (ZExtInst *CI = dyn_cast(I.getOperand(1))) + if (CI->getOperand(0)->getType() == Type::Int1Ty) + BoolCast = CI; + if (BoolCast) { + if (ICmpInst *SCI = dyn_cast(BoolCast->getOperand(0))) { + Value *SCIOp0 = SCI->getOperand(0), *SCIOp1 = SCI->getOperand(1); + const Type *SCOpTy = SCIOp0->getType(); + bool TIS = false; + + // If the icmp is true iff the sign bit of X is set, then convert this + // multiply into a shift/and combination. + if (isa(SCIOp1) && + isSignBitCheck(SCI->getPredicate(), cast(SCIOp1), TIS) && + TIS) { + // Shift the X value right to turn it into "all signbits". + Constant *Amt = ConstantInt::get(SCIOp0->getType(), + SCOpTy->getPrimitiveSizeInBits()-1); + Value *V = + InsertNewInstBefore( + BinaryOperator::create(Instruction::AShr, SCIOp0, Amt, + BoolCast->getOperand(0)->getName()+ + ".mask"), I); + + // If the multiply type is not the same as the source type, sign extend + // or truncate to the multiply type. + if (I.getType() != V->getType()) { + uint32_t SrcBits = V->getType()->getPrimitiveSizeInBits(); + uint32_t DstBits = I.getType()->getPrimitiveSizeInBits(); + Instruction::CastOps opcode = + (SrcBits == DstBits ? Instruction::BitCast : + (SrcBits < DstBits ? 
Instruction::SExt : Instruction::Trunc)); + V = InsertCastBefore(opcode, V, I.getType(), I); + } + + Value *OtherOp = Op0 == BoolCast ? I.getOperand(1) : Op0; + return BinaryOperator::createAnd(V, OtherOp); + } + } + } + + return Changed ? &I : 0; +} + +/// This function implements the transforms on div instructions that work +/// regardless of the kind of div instruction it is (udiv, sdiv, or fdiv). It is +/// used by the visitors to those instructions. +/// @brief Transforms common to all three div instructions +Instruction *InstCombiner::commonDivTransforms(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // undef / X -> 0 + if (isa(Op0)) + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + // X / undef -> undef + if (isa(Op1)) + return ReplaceInstUsesWith(I, Op1); + + // Handle cases involving: div X, (select Cond, Y, Z) + if (SelectInst *SI = dyn_cast(Op1)) { + // div X, (Cond ? 0 : Y) -> div X, Y. If the div and the select are in the + // same basic block, then we replace the select with Y, and the condition + // of the select with false (if the cond value is in the same BB). If the + // select has uses other than the div, this allows them to be simplified + // also. Note that div X, Y is just as good as div X, 0 (undef) + if (Constant *ST = dyn_cast(SI->getOperand(1))) + if (ST->isNullValue()) { + Instruction *CondI = dyn_cast(SI->getOperand(0)); + if (CondI && CondI->getParent() == I.getParent()) + UpdateValueUsesWith(CondI, ConstantInt::getFalse()); + else if (I.getParent() != SI->getParent() || SI->hasOneUse()) + I.setOperand(1, SI->getOperand(2)); + else + UpdateValueUsesWith(SI, SI->getOperand(2)); + return &I; + } + + // Likewise for: div X, (Cond ? 
Y : 0) -> div X, Y + if (Constant *ST = dyn_cast(SI->getOperand(2))) + if (ST->isNullValue()) { + Instruction *CondI = dyn_cast(SI->getOperand(0)); + if (CondI && CondI->getParent() == I.getParent()) + UpdateValueUsesWith(CondI, ConstantInt::getTrue()); + else if (I.getParent() != SI->getParent() || SI->hasOneUse()) + I.setOperand(1, SI->getOperand(1)); + else + UpdateValueUsesWith(SI, SI->getOperand(1)); + return &I; + } + } + + return 0; +} + +/// This function implements the transforms common to both integer division +/// instructions (udiv and sdiv). It is called by the visitors to those integer +/// division instructions. +/// @brief Common integer divide transforms +Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (Instruction *Common = commonDivTransforms(I)) + return Common; + + if (ConstantInt *RHS = dyn_cast(Op1)) { + // div X, 1 == X + if (RHS->equalsInt(1)) + return ReplaceInstUsesWith(I, Op0); + + // (X / C1) / C2 -> X / (C1*C2) + if (Instruction *LHS = dyn_cast(Op0)) + if (Instruction::BinaryOps(LHS->getOpcode()) == I.getOpcode()) + if (ConstantInt *LHSRHS = dyn_cast(LHS->getOperand(1))) { + return BinaryOperator::create(I.getOpcode(), LHS->getOperand(0), + Multiply(RHS, LHSRHS)); + } + + if (!RHS->isZero()) { // avoid X udiv 0 + if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + if (isa(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } + } + + // 0 / X == 0, we don't need to preserve faults! 
+ if (ConstantInt *LHS = dyn_cast(Op0)) + if (LHS->equalsInt(0)) + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + return 0; +} + +Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // Handle the integer div common cases + if (Instruction *Common = commonIDivTransforms(I)) + return Common; + + // X udiv C^2 -> X >> C + // Check to see if this is an unsigned division with an exact power of 2, + // if so, convert to a right shift. + if (ConstantInt *C = dyn_cast(Op1)) { + if (C->getValue().isPowerOf2()) // 0 not included in isPowerOf2 + return BinaryOperator::createLShr(Op0, + ConstantInt::get(Op0->getType(), C->getValue().logBase2())); + } + + // X udiv (C1 << N), where C1 is "1< X >> (N+C2) + if (BinaryOperator *RHSI = dyn_cast(I.getOperand(1))) { + if (RHSI->getOpcode() == Instruction::Shl && + isa(RHSI->getOperand(0))) { + const APInt& C1 = cast(RHSI->getOperand(0))->getValue(); + if (C1.isPowerOf2()) { + Value *N = RHSI->getOperand(1); + const Type *NTy = N->getType(); + if (uint32_t C2 = C1.logBase2()) { + Constant *C2V = ConstantInt::get(NTy, C2); + N = InsertNewInstBefore(BinaryOperator::createAdd(N, C2V, "tmp"), I); + } + return BinaryOperator::createLShr(Op0, N); + } + } + } + + // udiv X, (Select Cond, C1, C2) --> Select Cond, (shr X, C1), (shr X, C2) + // where C1&C2 are powers of two. 
+ if (SelectInst *SI = dyn_cast(Op1)) + if (ConstantInt *STO = dyn_cast(SI->getOperand(1))) + if (ConstantInt *SFO = dyn_cast(SI->getOperand(2))) { + const APInt &TVA = STO->getValue(), &FVA = SFO->getValue(); + if (TVA.isPowerOf2() && FVA.isPowerOf2()) { + // Compute the shift amounts + uint32_t TSA = TVA.logBase2(), FSA = FVA.logBase2(); + // Construct the "on true" case of the select + Constant *TC = ConstantInt::get(Op0->getType(), TSA); + Instruction *TSI = BinaryOperator::createLShr( + Op0, TC, SI->getName()+".t"); + TSI = InsertNewInstBefore(TSI, I); + + // Construct the "on false" case of the select + Constant *FC = ConstantInt::get(Op0->getType(), FSA); + Instruction *FSI = BinaryOperator::createLShr( + Op0, FC, SI->getName()+".f"); + FSI = InsertNewInstBefore(FSI, I); + + // construct the select instruction and return it. + return new SelectInst(SI->getOperand(0), TSI, FSI, SI->getName()); + } + } + return 0; +} + +Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // Handle the integer div common cases + if (Instruction *Common = commonIDivTransforms(I)) + return Common; + + if (ConstantInt *RHS = dyn_cast(Op1)) { + // sdiv X, -1 == -X + if (RHS->isAllOnesValue()) + return BinaryOperator::createNeg(Op0); + + // -X/C -> X/-C + if (Value *LHSNeg = dyn_castNegVal(Op0)) + return BinaryOperator::createSDiv(LHSNeg, ConstantExpr::getNeg(RHS)); + } + + // If the sign bits of both operands are zero (i.e. we can prove they are + // unsigned inputs), turn this into a udiv. 
+ if (I.getType()->isInteger()) { + APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); + if (MaskedValueIsZero(Op1, Mask) && MaskedValueIsZero(Op0, Mask)) { + return BinaryOperator::createUDiv(Op0, Op1, I.getName()); + } + } + + return 0; +} + +Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { + return commonDivTransforms(I); +} + +/// GetFactor - If we can prove that the specified value is at least a multiple +/// of some factor, return that factor. +static Constant *GetFactor(Value *V) { + if (ConstantInt *CI = dyn_cast(V)) + return CI; + + // Unless we can be tricky, we know this is a multiple of 1. + Constant *Result = ConstantInt::get(V->getType(), 1); + + Instruction *I = dyn_cast(V); + if (!I) return Result; + + if (I->getOpcode() == Instruction::Mul) { + // Handle multiplies by a constant, etc. + return ConstantExpr::getMul(GetFactor(I->getOperand(0)), + GetFactor(I->getOperand(1))); + } else if (I->getOpcode() == Instruction::Shl) { + // (X< X * (1 << C) + if (Constant *ShRHS = dyn_cast(I->getOperand(1))) { + ShRHS = ConstantExpr::getShl(Result, ShRHS); + return ConstantExpr::getMul(GetFactor(I->getOperand(0)), ShRHS); + } + } else if (I->getOpcode() == Instruction::And) { + if (ConstantInt *RHS = dyn_cast(I->getOperand(1))) { + // X & 0xFFF0 is known to be a multiple of 16. + uint32_t Zeros = RHS->getValue().countTrailingZeros(); + if (Zeros != V->getType()->getPrimitiveSizeInBits()) + return ConstantExpr::getShl(Result, + ConstantInt::get(Result->getType(), Zeros)); + } + } else if (CastInst *CI = dyn_cast(I)) { + // Only handle int->int casts. + if (!CI->isIntegerCast()) + return Result; + Value *Op = CI->getOperand(0); + return ConstantExpr::getCast(CI->getOpcode(), GetFactor(Op), V->getType()); + } + return Result; +} + +/// This function implements the transforms on rem instructions that work +/// regardless of the kind of rem instruction it is (urem, srem, or frem). It +/// is used by the visitors to those instructions. 
+/// @brief Transforms common to all three rem instructions +Instruction *InstCombiner::commonRemTransforms(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // 0 % X == 0, we don't need to preserve faults! + if (Constant *LHS = dyn_cast(Op0)) + if (LHS->isNullValue()) + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + if (isa(Op0)) // undef % X -> 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + if (isa(Op1)) + return ReplaceInstUsesWith(I, Op1); // X % undef -> undef + + // Handle cases involving: rem X, (select Cond, Y, Z) + if (SelectInst *SI = dyn_cast(Op1)) { + // rem X, (Cond ? 0 : Y) -> rem X, Y. If the rem and the select are in + // the same basic block, then we replace the select with Y, and the + // condition of the select with false (if the cond value is in the same + // BB). If the select has uses other than the div, this allows them to be + // simplified also. + if (Constant *ST = dyn_cast(SI->getOperand(1))) + if (ST->isNullValue()) { + Instruction *CondI = dyn_cast(SI->getOperand(0)); + if (CondI && CondI->getParent() == I.getParent()) + UpdateValueUsesWith(CondI, ConstantInt::getFalse()); + else if (I.getParent() != SI->getParent() || SI->hasOneUse()) + I.setOperand(1, SI->getOperand(2)); + else + UpdateValueUsesWith(SI, SI->getOperand(2)); + return &I; + } + // Likewise for: rem X, (Cond ? Y : 0) -> rem X, Y + if (Constant *ST = dyn_cast(SI->getOperand(2))) + if (ST->isNullValue()) { + Instruction *CondI = dyn_cast(SI->getOperand(0)); + if (CondI && CondI->getParent() == I.getParent()) + UpdateValueUsesWith(CondI, ConstantInt::getTrue()); + else if (I.getParent() != SI->getParent() || SI->hasOneUse()) + I.setOperand(1, SI->getOperand(1)); + else + UpdateValueUsesWith(SI, SI->getOperand(1)); + return &I; + } + } + + return 0; +} + +/// This function implements the transforms common to both integer remainder +/// instructions (urem and srem). 
It is called by the visitors to those integer +/// remainder instructions. +/// @brief Common integer remainder transforms +Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (Instruction *common = commonRemTransforms(I)) + return common; + + if (ConstantInt *RHS = dyn_cast(Op1)) { + // X % 0 == undef, we don't need to preserve faults! + if (RHS->equalsInt(0)) + return ReplaceInstUsesWith(I, UndefValue::get(I.getType())); + + if (RHS->equalsInt(1)) // X % 1 == 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + if (Instruction *Op0I = dyn_cast(Op0)) { + if (SelectInst *SI = dyn_cast(Op0I)) { + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + } else if (isa(Op0I)) { + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } + // (X * C1) % C2 --> 0 iff C1 % C2 == 0 + if (ConstantExpr::getSRem(GetFactor(Op0I), RHS)->isNullValue()) + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + } + } + + return 0; +} + +Instruction *InstCombiner::visitURem(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (Instruction *common = commonIRemTransforms(I)) + return common; + + if (ConstantInt *RHS = dyn_cast(Op1)) { + // X urem C^2 -> X and C + // Check to see if this is an unsigned remainder with an exact power of 2, + // if so, convert to a bitwise and. 
+ if (ConstantInt *C = dyn_cast(RHS)) + if (C->getValue().isPowerOf2()) + return BinaryOperator::createAnd(Op0, SubOne(C)); + } + + if (Instruction *RHSI = dyn_cast(I.getOperand(1))) { + // Turn A % (C << N), where C is 2^k, into A & ((C << N)-1) + if (RHSI->getOpcode() == Instruction::Shl && + isa(RHSI->getOperand(0))) { + if (cast(RHSI->getOperand(0))->getValue().isPowerOf2()) { + Constant *N1 = ConstantInt::getAllOnesValue(I.getType()); + Value *Add = InsertNewInstBefore(BinaryOperator::createAdd(RHSI, N1, + "tmp"), I); + return BinaryOperator::createAnd(Op0, Add); + } + } + } + + // urem X, (select Cond, 2^C1, 2^C2) --> select Cond, (and X, C1), (and X, C2) + // where C1&C2 are powers of two. + if (SelectInst *SI = dyn_cast(Op1)) { + if (ConstantInt *STO = dyn_cast(SI->getOperand(1))) + if (ConstantInt *SFO = dyn_cast(SI->getOperand(2))) { + // STO == 0 and SFO == 0 handled above. + if ((STO->getValue().isPowerOf2()) && + (SFO->getValue().isPowerOf2())) { + Value *TrueAnd = InsertNewInstBefore( + BinaryOperator::createAnd(Op0, SubOne(STO), SI->getName()+".t"), I); + Value *FalseAnd = InsertNewInstBefore( + BinaryOperator::createAnd(Op0, SubOne(SFO), SI->getName()+".f"), I); + return new SelectInst(SI->getOperand(0), TrueAnd, FalseAnd); + } + } + } + + return 0; +} + +Instruction *InstCombiner::visitSRem(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (Instruction *common = commonIRemTransforms(I)) + return common; + + if (Value *RHSNeg = dyn_castNegVal(Op1)) + if (!isa(RHSNeg) || + cast(RHSNeg)->getValue().isStrictlyPositive()) { + // X % -Y -> X % Y + AddUsesToWorkList(I); + I.setOperand(1, RHSNeg); + return &I; + } + + // If the top bits of both operands are zero (i.e. we can prove they are + // unsigned inputs), turn this into a urem. 
+ APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); + if (MaskedValueIsZero(Op1, Mask) && MaskedValueIsZero(Op0, Mask)) { + // X srem Y -> X urem Y, iff X and Y don't have sign bit set + return BinaryOperator::createURem(Op0, Op1, I.getName()); + } + + return 0; +} + +Instruction *InstCombiner::visitFRem(BinaryOperator &I) { + return commonRemTransforms(I); +} + +// isMaxValueMinusOne - return true if this is Max-1 +static bool isMaxValueMinusOne(const ConstantInt *C, bool isSigned) { + uint32_t TypeBits = C->getType()->getPrimitiveSizeInBits(); + if (!isSigned) + return C->getValue() == APInt::getAllOnesValue(TypeBits) - 1; + return C->getValue() == APInt::getSignedMaxValue(TypeBits)-1; +} + +// isMinValuePlusOne - return true if this is Min+1 +static bool isMinValuePlusOne(const ConstantInt *C, bool isSigned) { + if (!isSigned) + return C->getValue() == 1; // unsigned + + // Calculate 1111111111000000000000 + uint32_t TypeBits = C->getType()->getPrimitiveSizeInBits(); + return C->getValue() == APInt::getSignedMinValue(TypeBits)+1; +} + +// isOneBitSet - Return true if there is exactly one bit set in the specified +// constant. +static bool isOneBitSet(const ConstantInt *CI) { + return CI->getValue().isPowerOf2(); +} + +// isHighOnes - Return true if the constant is of the form 1+0+. +// This is the same as lowones(~X). +static bool isHighOnes(const ConstantInt *CI) { + return (~CI->getValue() + 1).isPowerOf2(); +} + +/// getICmpCode - Encode a icmp predicate into a three bit mask. These bits +/// are carefully arranged to allow folding of expressions such as: +/// +/// (A < B) | (A > B) --> (A != B) +/// +/// Note that this is only valid if the first and second predicates have the +/// same sign. 
Is illegal to do: (A u< B) | (A s> B) +/// +/// Three bits are used to represent the condition, as follows: +/// 0 A > B +/// 1 A == B +/// 2 A < B +/// +/// <=> Value Definition +/// 000 0 Always false +/// 001 1 A > B +/// 010 2 A == B +/// 011 3 A >= B +/// 100 4 A < B +/// 101 5 A != B +/// 110 6 A <= B +/// 111 7 Always true +/// +static unsigned getICmpCode(const ICmpInst *ICI) { + switch (ICI->getPredicate()) { + // False -> 0 + case ICmpInst::ICMP_UGT: return 1; // 001 + case ICmpInst::ICMP_SGT: return 1; // 001 + case ICmpInst::ICMP_EQ: return 2; // 010 + case ICmpInst::ICMP_UGE: return 3; // 011 + case ICmpInst::ICMP_SGE: return 3; // 011 + case ICmpInst::ICMP_ULT: return 4; // 100 + case ICmpInst::ICMP_SLT: return 4; // 100 + case ICmpInst::ICMP_NE: return 5; // 101 + case ICmpInst::ICMP_ULE: return 6; // 110 + case ICmpInst::ICMP_SLE: return 6; // 110 + // True -> 7 + default: + assert(0 && "Invalid ICmp predicate!"); + return 0; + } +} + +/// getICmpValue - This is the complement of getICmpCode, which turns an +/// opcode and two operands into either a constant true or false, or a brand +/// new /// ICmp instruction. The sign is passed in to determine which kind +/// of predicate to use in new icmp instructions. 
+static Value *getICmpValue(bool sign, unsigned code, Value *LHS, Value *RHS) { + switch (code) { + default: assert(0 && "Illegal ICmp code!"); + case 0: return ConstantInt::getFalse(); + case 1: + if (sign) + return new ICmpInst(ICmpInst::ICMP_SGT, LHS, RHS); + else + return new ICmpInst(ICmpInst::ICMP_UGT, LHS, RHS); + case 2: return new ICmpInst(ICmpInst::ICMP_EQ, LHS, RHS); + case 3: + if (sign) + return new ICmpInst(ICmpInst::ICMP_SGE, LHS, RHS); + else + return new ICmpInst(ICmpInst::ICMP_UGE, LHS, RHS); + case 4: + if (sign) + return new ICmpInst(ICmpInst::ICMP_SLT, LHS, RHS); + else + return new ICmpInst(ICmpInst::ICMP_ULT, LHS, RHS); + case 5: return new ICmpInst(ICmpInst::ICMP_NE, LHS, RHS); + case 6: + if (sign) + return new ICmpInst(ICmpInst::ICMP_SLE, LHS, RHS); + else + return new ICmpInst(ICmpInst::ICMP_ULE, LHS, RHS); + case 7: return ConstantInt::getTrue(); + } +} + +static bool PredicatesFoldable(ICmpInst::Predicate p1, ICmpInst::Predicate p2) { + return (ICmpInst::isSignedPredicate(p1) == ICmpInst::isSignedPredicate(p2)) || + (ICmpInst::isSignedPredicate(p1) && + (p2 == ICmpInst::ICMP_EQ || p2 == ICmpInst::ICMP_NE)) || + (ICmpInst::isSignedPredicate(p2) && + (p1 == ICmpInst::ICMP_EQ || p1 == ICmpInst::ICMP_NE)); +} + +namespace { +// FoldICmpLogical - Implements (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) +struct FoldICmpLogical { + InstCombiner &IC; + Value *LHS, *RHS; + ICmpInst::Predicate pred; + FoldICmpLogical(InstCombiner &ic, ICmpInst *ICI) + : IC(ic), LHS(ICI->getOperand(0)), RHS(ICI->getOperand(1)), + pred(ICI->getPredicate()) {} + bool shouldApply(Value *V) const { + if (ICmpInst *ICI = dyn_cast(V)) + if (PredicatesFoldable(pred, ICI->getPredicate())) + return (ICI->getOperand(0) == LHS && ICI->getOperand(1) == RHS || + ICI->getOperand(0) == RHS && ICI->getOperand(1) == LHS); + return false; + } + Instruction *apply(Instruction &Log) const { + ICmpInst *ICI = cast(Log.getOperand(0)); + if (ICI->getOperand(0) != LHS) { + 
assert(ICI->getOperand(1) == LHS); + ICI->swapOperands(); // Swap the LHS and RHS of the ICmp + } + + ICmpInst *RHSICI = cast(Log.getOperand(1)); + unsigned LHSCode = getICmpCode(ICI); + unsigned RHSCode = getICmpCode(RHSICI); + unsigned Code; + switch (Log.getOpcode()) { + case Instruction::And: Code = LHSCode & RHSCode; break; + case Instruction::Or: Code = LHSCode | RHSCode; break; + case Instruction::Xor: Code = LHSCode ^ RHSCode; break; + default: assert(0 && "Illegal logical opcode!"); return 0; + } + + bool isSigned = ICmpInst::isSignedPredicate(RHSICI->getPredicate()) || + ICmpInst::isSignedPredicate(ICI->getPredicate()); + + Value *RV = getICmpValue(isSigned, Code, LHS, RHS); + if (Instruction *I = dyn_cast(RV)) + return I; + // Otherwise, it's a constant boolean value... + return IC.ReplaceInstUsesWith(Log, RV); + } +}; +} // end anonymous namespace + +// OptAndOp - This handles expressions of the form ((val OP C1) & C2). Where +// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. Op is +// guaranteed to be a binary operator. 
+Instruction *InstCombiner::OptAndOp(Instruction *Op, + ConstantInt *OpRHS, + ConstantInt *AndRHS, + BinaryOperator &TheAnd) { + Value *X = Op->getOperand(0); + Constant *Together = 0; + if (!Op->isShift()) + Together = And(AndRHS, OpRHS); + + switch (Op->getOpcode()) { + case Instruction::Xor: + if (Op->hasOneUse()) { + // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2) + Instruction *And = BinaryOperator::createAnd(X, AndRHS); + InsertNewInstBefore(And, TheAnd); + And->takeName(Op); + return BinaryOperator::createXor(And, Together); + } + break; + case Instruction::Or: + if (Together == AndRHS) // (X | C) & C --> C + return ReplaceInstUsesWith(TheAnd, AndRHS); + + if (Op->hasOneUse() && Together != OpRHS) { + // (X | C1) & C2 --> (X | (C1&C2)) & C2 + Instruction *Or = BinaryOperator::createOr(X, Together); + InsertNewInstBefore(Or, TheAnd); + Or->takeName(Op); + return BinaryOperator::createAnd(Or, AndRHS); + } + break; + case Instruction::Add: + if (Op->hasOneUse()) { + // Adding a one to a single bit bit-field should be turned into an XOR + // of the bit. First thing to check is to see if this AND is with a + // single bit constant. + const APInt& AndRHSV = cast(AndRHS)->getValue(); + + // If there is only one bit set... + if (isOneBitSet(cast(AndRHS))) { + // Ok, at this point, we know that we are masking the result of the + // ADD down to exactly one bit. If the constant we are adding has + // no bits set below this bit, then we can eliminate the ADD. + const APInt& AddRHS = cast(OpRHS)->getValue(); + + // Check to see if any bits below the one bit set in AndRHSV are set. + if ((AddRHS & (AndRHSV-1)) == 0) { + // If not, the only thing that can effect the output of the AND is + // the bit specified by AndRHSV. If that bit is set, the effect of + // the XOR is to toggle the bit. If it is clear, then the ADD has + // no effect. + if ((AddRHS & AndRHSV) == 0) { // Bit is not set, noop + TheAnd.setOperand(0, X); + return &TheAnd; + } else { + // Pull the XOR out of the AND. 
+ Instruction *NewAnd = BinaryOperator::createAnd(X, AndRHS); + InsertNewInstBefore(NewAnd, TheAnd); + NewAnd->takeName(Op); + return BinaryOperator::createXor(NewAnd, AndRHS); + } + } + } + } + break; + + case Instruction::Shl: { + // We know that the AND will not produce any of the bits shifted in, so if + // the anded constant includes them, clear them now! + // + uint32_t BitWidth = AndRHS->getType()->getBitWidth(); + uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); + APInt ShlMask(APInt::getHighBitsSet(BitWidth, BitWidth-OpRHSVal)); + ConstantInt *CI = ConstantInt::get(AndRHS->getValue() & ShlMask); + + if (CI->getValue() == ShlMask) { + // Masking out bits that the shift already masks + return ReplaceInstUsesWith(TheAnd, Op); // No need for the and. + } else if (CI != AndRHS) { // Reducing bits set in and. + TheAnd.setOperand(1, CI); + return &TheAnd; + } + break; + } + case Instruction::LShr: + { + // We know that the AND will not produce any of the bits shifted in, so if + // the anded constant includes them, clear them now! This only applies to + // unsigned shifts, because a signed shr may bring in set bits! + // + uint32_t BitWidth = AndRHS->getType()->getBitWidth(); + uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); + APInt ShrMask(APInt::getLowBitsSet(BitWidth, BitWidth - OpRHSVal)); + ConstantInt *CI = ConstantInt::get(AndRHS->getValue() & ShrMask); + + if (CI->getValue() == ShrMask) { + // Masking out bits that the shift already masks. + return ReplaceInstUsesWith(TheAnd, Op); + } else if (CI != AndRHS) { + TheAnd.setOperand(1, CI); // Reduce bits set in and cst. + return &TheAnd; + } + break; + } + case Instruction::AShr: + // Signed shr. + // See if this is shifting in some sign extension, then masking it out + // with an and. 
+ if (Op->hasOneUse()) { + uint32_t BitWidth = AndRHS->getType()->getBitWidth(); + uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); + APInt ShrMask(APInt::getLowBitsSet(BitWidth, BitWidth - OpRHSVal)); + Constant *C = ConstantInt::get(AndRHS->getValue() & ShrMask); + if (C == AndRHS) { // Masking out bits shifted in. + // (Val ashr C1) & C2 -> (Val lshr C1) & C2 + // Make the argument unsigned. + Value *ShVal = Op->getOperand(0); + ShVal = InsertNewInstBefore( + BinaryOperator::createLShr(ShVal, OpRHS, + Op->getName()), TheAnd); + return BinaryOperator::createAnd(ShVal, AndRHS, TheAnd.getName()); + } + } + break; + } + return 0; +} + + +/// InsertRangeTest - Emit a computation of: (V >= Lo && V < Hi) if Inside is +/// true, otherwise (V < Lo || V >= Hi). In pratice, we emit the more efficient +/// (V-Lo) (ConstantExpr::getICmp((isSigned ? + ICmpInst::ICMP_SLE:ICmpInst::ICMP_ULE), Lo, Hi))->getZExtValue() && + "Lo is not <= Hi in range emission code!"); + + if (Inside) { + if (Lo == Hi) // Trivially false. + return new ICmpInst(ICmpInst::ICMP_NE, V, V); + + // V >= Min && V < Hi --> V < Hi + if (cast(Lo)->isMinValue(isSigned)) { + ICmpInst::Predicate pred = (isSigned ? + ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT); + return new ICmpInst(pred, V, Hi); + } + + // Emit V-Lo getName()+".off"); + InsertNewInstBefore(Add, IB); + Constant *UpperBound = ConstantExpr::getAdd(NegLo, Hi); + return new ICmpInst(ICmpInst::ICMP_ULT, Add, UpperBound); + } + + if (Lo == Hi) // Trivially true. + return new ICmpInst(ICmpInst::ICMP_EQ, V, V); + + // V < Min || V >= Hi -> V > Hi-1 + Hi = SubOne(cast(Hi)); + if (cast(Lo)->isMinValue(isSigned)) { + ICmpInst::Predicate pred = (isSigned ? + ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT); + return new ICmpInst(pred, V, Hi); + } + + // Emit V-Lo >u Hi-1-Lo + // Note that Hi has already had one subtracted from it, above. 
+ ConstantInt *NegLo = cast(ConstantExpr::getNeg(Lo)); + Instruction *Add = BinaryOperator::createAdd(V, NegLo, V->getName()+".off"); + InsertNewInstBefore(Add, IB); + Constant *LowerBound = ConstantExpr::getAdd(NegLo, Hi); + return new ICmpInst(ICmpInst::ICMP_UGT, Add, LowerBound); +} + +// isRunOfOnes - Returns true iff Val consists of one contiguous run of 1s with +// any number of 0s on either side. The 1s are allowed to wrap from LSB to +// MSB, so 0x000FFF0, 0x0000FFFF, and 0xFF0000FF are all runs. 0x0F0F0000 is +// not, since all 1s are not contiguous. +static bool isRunOfOnes(ConstantInt *Val, uint32_t &MB, uint32_t &ME) { + const APInt& V = Val->getValue(); + uint32_t BitWidth = Val->getType()->getBitWidth(); + if (!APIntOps::isShiftedMask(BitWidth, V)) return false; + + // look for the first zero bit after the run of ones + MB = BitWidth - ((V - 1) ^ V).countLeadingZeros(); + // look for the first non-zero bit + ME = V.getActiveBits(); + return true; +} + +/// FoldLogicalPlusAnd - This is part of an expression (LHS +/- RHS) & Mask, +/// where isSub determines whether the operator is a sub. If we can fold one of +/// the following xforms: +/// +/// ((A & N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == Mask +/// ((A | N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0 +/// ((A ^ N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0 +/// +/// return (A +/- B). +/// +Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS, + ConstantInt *Mask, bool isSub, + Instruction &I) { + Instruction *LHSI = dyn_cast(LHS); + if (!LHSI || LHSI->getNumOperands() != 2 || + !isa(LHSI->getOperand(1))) return 0; + + ConstantInt *N = cast(LHSI->getOperand(1)); + + switch (LHSI->getOpcode()) { + default: return 0; + case Instruction::And: + if (And(N, Mask) == Mask) { + // If the AndRHS is a power of two minus one (0+1+), this is simple. 
+ if ((Mask->getValue().countLeadingZeros() + + Mask->getValue().countPopulation()) == + Mask->getValue().getBitWidth()) + break; + + // Otherwise, if Mask is 0+1+0+, and if B is known to have the low 0+ + // part, we don't need any explicit masks to take them out of A. If that + // is all N is, ignore it. + uint32_t MB = 0, ME = 0; + if (isRunOfOnes(Mask, MB, ME)) { // begin/end bit of run, inclusive + uint32_t BitWidth = cast(RHS->getType())->getBitWidth(); + APInt Mask(APInt::getLowBitsSet(BitWidth, MB-1)); + if (MaskedValueIsZero(RHS, Mask)) + break; + } + } + return 0; + case Instruction::Or: + case Instruction::Xor: + // If the AndRHS is a power of two minus one (0+1+), and N&Mask == 0 + if ((Mask->getValue().countLeadingZeros() + + Mask->getValue().countPopulation()) == Mask->getValue().getBitWidth() + && And(N, Mask)->isZero()) + break; + return 0; + } + + Instruction *New; + if (isSub) + New = BinaryOperator::createSub(LHSI->getOperand(0), RHS, "fold"); + else + New = BinaryOperator::createAdd(LHSI->getOperand(0), RHS, "fold"); + return InsertNewInstBefore(New, I); +} + +Instruction *InstCombiner::visitAnd(BinaryOperator &I) { + bool Changed = SimplifyCommutative(I); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (isa(Op1)) // X & undef -> 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + // and X, X = X + if (Op0 == Op1) + return ReplaceInstUsesWith(I, Op1); + + // See if we can simplify any instructions used by the instruction whose sole + // purpose is to compute bits we don't care about. 
+ if (!isa(I.getType())) { + uint32_t BitWidth = cast(I.getType())->getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + if (SimplifyDemandedBits(&I, APInt::getAllOnesValue(BitWidth), + KnownZero, KnownOne)) + return &I; + } else { + if (ConstantVector *CP = dyn_cast(Op1)) { + if (CP->isAllOnesValue()) // X & <-1,-1> -> X + return ReplaceInstUsesWith(I, I.getOperand(0)); + } else if (isa(Op1)) { + return ReplaceInstUsesWith(I, Op1); // X & <0,0> -> <0,0> + } + } + + if (ConstantInt *AndRHS = dyn_cast(Op1)) { + const APInt& AndRHSMask = AndRHS->getValue(); + APInt NotAndRHS(~AndRHSMask); + + // Optimize a variety of ((val OP C1) & C2) combinations... + if (isa(Op0)) { + Instruction *Op0I = cast(Op0); + Value *Op0LHS = Op0I->getOperand(0); + Value *Op0RHS = Op0I->getOperand(1); + switch (Op0I->getOpcode()) { + case Instruction::Xor: + case Instruction::Or: + // If the mask is only needed on one incoming arm, push it up. + if (Op0I->hasOneUse()) { + if (MaskedValueIsZero(Op0LHS, NotAndRHS)) { + // Not masking anything out for the LHS, move to RHS. + Instruction *NewRHS = BinaryOperator::createAnd(Op0RHS, AndRHS, + Op0RHS->getName()+".masked"); + InsertNewInstBefore(NewRHS, I); + return BinaryOperator::create( + cast(Op0I)->getOpcode(), Op0LHS, NewRHS); + } + if (!isa(Op0RHS) && + MaskedValueIsZero(Op0RHS, NotAndRHS)) { + // Not masking anything out for the RHS, move to LHS. + Instruction *NewLHS = BinaryOperator::createAnd(Op0LHS, AndRHS, + Op0LHS->getName()+".masked"); + InsertNewInstBefore(NewLHS, I); + return BinaryOperator::create( + cast(Op0I)->getOpcode(), NewLHS, Op0RHS); + } + } + + break; + case Instruction::Add: + // ((A & N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == AndRHS. 
+ // ((A | N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 + // ((A ^ N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 + if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, false, I)) + return BinaryOperator::createAnd(V, AndRHS); + if (Value *V = FoldLogicalPlusAnd(Op0RHS, Op0LHS, AndRHS, false, I)) + return BinaryOperator::createAnd(V, AndRHS); // Add commutes + break; + + case Instruction::Sub: + // ((A & N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == AndRHS. + // ((A | N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 + // ((A ^ N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 + if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, true, I)) + return BinaryOperator::createAnd(V, AndRHS); + break; + } + + if (ConstantInt *Op0CI = dyn_cast(Op0I->getOperand(1))) + if (Instruction *Res = OptAndOp(Op0I, Op0CI, AndRHS, I)) + return Res; + } else if (CastInst *CI = dyn_cast(Op0)) { + // If this is an integer truncation or change from signed-to-unsigned, and + // if the source is an and/or with immediate, transform it. This + // frequently occurs for bitfield accesses. + if (Instruction *CastOp = dyn_cast(CI->getOperand(0))) { + if ((isa(CI) || isa(CI)) && + CastOp->getNumOperands() == 2) + if (ConstantInt *AndCI = dyn_cast(CastOp->getOperand(1))) + if (CastOp->getOpcode() == Instruction::And) { + // Change: and (cast (and X, C1) to T), C2 + // into : and (cast X to T), trunc_or_bitcast(C1)&C2 + // This will fold the two constants together, which may allow + // other simplifications. 
+ Instruction *NewCast = CastInst::createTruncOrBitCast( + CastOp->getOperand(0), I.getType(), + CastOp->getName()+".shrunk"); + NewCast = InsertNewInstBefore(NewCast, I); + // trunc_or_bitcast(C1)&C2 + Constant *C3 = ConstantExpr::getTruncOrBitCast(AndCI,I.getType()); + C3 = ConstantExpr::getAnd(C3, AndRHS); + return BinaryOperator::createAnd(NewCast, C3); + } else if (CastOp->getOpcode() == Instruction::Or) { + // Change: and (cast (or X, C1) to T), C2 + // into : trunc(C1)&C2 iff trunc(C1)&C2 == C2 + Constant *C3 = ConstantExpr::getTruncOrBitCast(AndCI,I.getType()); + if (ConstantExpr::getAnd(C3, AndRHS) == AndRHS) // trunc(C1)&C2 + return ReplaceInstUsesWith(I, AndRHS); + } + } + } + + // Try to fold constant and into select arguments. + if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + if (isa(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } + + Value *Op0NotVal = dyn_castNotVal(Op0); + Value *Op1NotVal = dyn_castNotVal(Op1); + + if (Op0NotVal == Op1 || Op1NotVal == Op0) // A & ~A == ~A & A == 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + // (~A & ~B) == (~(A | B)) - De Morgan's Law + if (Op0NotVal && Op1NotVal && isOnlyUse(Op0) && isOnlyUse(Op1)) { + Instruction *Or = BinaryOperator::createOr(Op0NotVal, Op1NotVal, + I.getName()+".demorgan"); + InsertNewInstBefore(Or, I); + return BinaryOperator::createNot(Or); + } + + { + Value *A = 0, *B = 0, *C = 0, *D = 0; + if (match(Op0, m_Or(m_Value(A), m_Value(B)))) { + if (A == Op1 || B == Op1) // (A | ?) & A --> A + return ReplaceInstUsesWith(I, Op1); + + // (A|B) & ~(A&B) -> A^B + if (match(Op1, m_Not(m_And(m_Value(C), m_Value(D))))) { + if ((A == C && B == D) || (A == D && B == C)) + return BinaryOperator::createXor(A, B); + } + } + + if (match(Op1, m_Or(m_Value(A), m_Value(B)))) { + if (A == Op0 || B == Op0) // A & (A | ?) 
--> A + return ReplaceInstUsesWith(I, Op0); + + // ~(A&B) & (A|B) -> A^B + if (match(Op0, m_Not(m_And(m_Value(C), m_Value(D))))) { + if ((A == C && B == D) || (A == D && B == C)) + return BinaryOperator::createXor(A, B); + } + } + + if (Op0->hasOneUse() && + match(Op0, m_Xor(m_Value(A), m_Value(B)))) { + if (A == Op1) { // (A^B)&A -> A&(A^B) + I.swapOperands(); // Simplify below + std::swap(Op0, Op1); + } else if (B == Op1) { // (A^B)&B -> B&(B^A) + cast(Op0)->swapOperands(); + I.swapOperands(); // Simplify below + std::swap(Op0, Op1); + } + } + if (Op1->hasOneUse() && + match(Op1, m_Xor(m_Value(A), m_Value(B)))) { + if (B == Op0) { // B&(A^B) -> B&(B^A) + cast(Op1)->swapOperands(); + std::swap(A, B); + } + if (A == Op0) { // A&(A^B) -> A & ~B + Instruction *NotB = BinaryOperator::createNot(B, "tmp"); + InsertNewInstBefore(NotB, I); + return BinaryOperator::createAnd(A, NotB); + } + } + } + + if (ICmpInst *RHS = dyn_cast(Op1)) { + // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) + if (Instruction *R = AssociativeOpt(I, FoldICmpLogical(*this, RHS))) + return R; + + Value *LHSVal, *RHSVal; + ConstantInt *LHSCst, *RHSCst; + ICmpInst::Predicate LHSCC, RHSCC; + if (match(Op0, m_ICmp(LHSCC, m_Value(LHSVal), m_ConstantInt(LHSCst)))) + if (match(RHS, m_ICmp(RHSCC, m_Value(RHSVal), m_ConstantInt(RHSCst)))) + if (LHSVal == RHSVal && // Found (X icmp C1) & (X icmp C2) + // ICMP_[GL]E X, CST is folded to ICMP_[GL]T elsewhere. + LHSCC != ICmpInst::ICMP_UGE && LHSCC != ICmpInst::ICMP_ULE && + RHSCC != ICmpInst::ICMP_UGE && RHSCC != ICmpInst::ICMP_ULE && + LHSCC != ICmpInst::ICMP_SGE && LHSCC != ICmpInst::ICMP_SLE && + RHSCC != ICmpInst::ICMP_SGE && RHSCC != ICmpInst::ICMP_SLE) { + // Ensure that the larger constant is on the RHS. + ICmpInst::Predicate GT = ICmpInst::isSignedPredicate(LHSCC) ? 
+ ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; + Constant *Cmp = ConstantExpr::getICmp(GT, LHSCst, RHSCst); + ICmpInst *LHS = cast(Op0); + if (cast(Cmp)->getZExtValue()) { + std::swap(LHS, RHS); + std::swap(LHSCst, RHSCst); + std::swap(LHSCC, RHSCC); + } + + // At this point, we know we have have two icmp instructions + // comparing a value against two constants and and'ing the result + // together. Because of the above check, we know that we only have + // icmp eq, icmp ne, icmp [su]lt, and icmp [SU]gt here. We also know + // (from the FoldICmpLogical check above), that the two constants + // are not equal and that the larger constant is on the RHS + assert(LHSCst != RHSCst && "Compares not folded above?"); + + switch (LHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X == 13 & X == 15) -> false + case ICmpInst::ICMP_UGT: // (X == 13 & X > 15) -> false + case ICmpInst::ICMP_SGT: // (X == 13 & X > 15) -> false + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + case ICmpInst::ICMP_NE: // (X == 13 & X != 15) -> X == 13 + case ICmpInst::ICMP_ULT: // (X == 13 & X < 15) -> X == 13 + case ICmpInst::ICMP_SLT: // (X == 13 & X < 15) -> X == 13 + return ReplaceInstUsesWith(I, LHS); + } + case ICmpInst::ICMP_NE: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_ULT: + if (LHSCst == SubOne(RHSCst)) // (X != 13 & X u< 14) -> X < 13 + return new ICmpInst(ICmpInst::ICMP_ULT, LHSVal, LHSCst); + break; // (X != 13 & X u< 15) -> no change + case ICmpInst::ICMP_SLT: + if (LHSCst == SubOne(RHSCst)) // (X != 13 & X s< 14) -> X < 13 + return new ICmpInst(ICmpInst::ICMP_SLT, LHSVal, LHSCst); + break; // (X != 13 & X s< 15) -> no change + case ICmpInst::ICMP_EQ: // (X != 13 & X == 15) -> X == 15 + case ICmpInst::ICMP_UGT: // (X != 13 & X u> 15) -> X u> 15 + case ICmpInst::ICMP_SGT: // 
(X != 13 & X s> 15) -> X s> 15 + return ReplaceInstUsesWith(I, RHS); + case ICmpInst::ICMP_NE: + if (LHSCst == SubOne(RHSCst)){// (X != 13 & X != 14) -> X-13 >u 1 + Constant *AddCST = ConstantExpr::getNeg(LHSCst); + Instruction *Add = BinaryOperator::createAdd(LHSVal, AddCST, + LHSVal->getName()+".off"); + InsertNewInstBefore(Add, I); + return new ICmpInst(ICmpInst::ICMP_UGT, Add, + ConstantInt::get(Add->getType(), 1)); + } + break; // (X != 13 & X != 15) -> no change + } + break; + case ICmpInst::ICMP_ULT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u< 13 & X == 15) -> false + case ICmpInst::ICMP_UGT: // (X u< 13 & X u> 15) -> false + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + case ICmpInst::ICMP_SGT: // (X u< 13 & X s> 15) -> no change + break; + case ICmpInst::ICMP_NE: // (X u< 13 & X != 15) -> X u< 13 + case ICmpInst::ICMP_ULT: // (X u< 13 & X u< 15) -> X u< 13 + return ReplaceInstUsesWith(I, LHS); + case ICmpInst::ICMP_SLT: // (X u< 13 & X s< 15) -> no change + break; + } + break; + case ICmpInst::ICMP_SLT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s< 13 & X == 15) -> false + case ICmpInst::ICMP_SGT: // (X s< 13 & X s> 15) -> false + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + case ICmpInst::ICMP_UGT: // (X s< 13 & X u> 15) -> no change + break; + case ICmpInst::ICMP_NE: // (X s< 13 & X != 15) -> X < 13 + case ICmpInst::ICMP_SLT: // (X s< 13 & X s< 15) -> X < 13 + return ReplaceInstUsesWith(I, LHS); + case ICmpInst::ICMP_ULT: // (X s< 13 & X u< 15) -> no change + break; + } + break; + case ICmpInst::ICMP_UGT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u> 13 & X == 15) -> X > 13 + return ReplaceInstUsesWith(I, LHS); + case ICmpInst::ICMP_UGT: // (X u> 13 & X u> 15) -> X u> 15 + return ReplaceInstUsesWith(I, RHS); + case 
ICmpInst::ICMP_SGT: // (X u> 13 & X s> 15) -> no change + break; + case ICmpInst::ICMP_NE: + if (RHSCst == AddOne(LHSCst)) // (X u> 13 & X != 14) -> X u> 14 + return new ICmpInst(LHSCC, LHSVal, RHSCst); + break; // (X u> 13 & X != 15) -> no change + case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) ->(X-14) 13 & X s< 15) -> no change + break; + } + break; + case ICmpInst::ICMP_SGT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s> 13 & X == 15) -> X s> 13 + return ReplaceInstUsesWith(I, LHS); + case ICmpInst::ICMP_SGT: // (X s> 13 & X s> 15) -> X s> 15 + return ReplaceInstUsesWith(I, RHS); + case ICmpInst::ICMP_UGT: // (X s> 13 & X u> 15) -> no change + break; + case ICmpInst::ICMP_NE: + if (RHSCst == AddOne(LHSCst)) // (X s> 13 & X != 14) -> X s> 14 + return new ICmpInst(LHSCC, LHSVal, RHSCst); + break; // (X s> 13 & X != 15) -> no change + case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) ->(X-14) s< 1 + return InsertRangeTest(LHSVal, AddOne(LHSCst), RHSCst, true, + true, I); + case ICmpInst::ICMP_ULT: // (X s> 13 & X u< 15) -> no change + break; + } + break; + } + } + } + + // fold (and (cast A), (cast B)) -> (cast (and A, B)) + if (CastInst *Op0C = dyn_cast(Op0)) + if (CastInst *Op1C = dyn_cast(Op1)) + if (Op0C->getOpcode() == Op1C->getOpcode()) { // same cast kind ? + const Type *SrcTy = Op0C->getOperand(0)->getType(); + if (SrcTy == Op1C->getOperand(0)->getType() && SrcTy->isInteger() && + // Only do this if the casts both really cause code to be generated. + ValueRequiresCast(Op0C->getOpcode(), Op0C->getOperand(0), + I.getType(), TD) && + ValueRequiresCast(Op1C->getOpcode(), Op1C->getOperand(0), + I.getType(), TD)) { + Instruction *NewOp = BinaryOperator::createAnd(Op0C->getOperand(0), + Op1C->getOperand(0), + I.getName()); + InsertNewInstBefore(NewOp, I); + return CastInst::create(Op0C->getOpcode(), NewOp, I.getType()); + } + } + + // (X >> Z) & (Y >> Z) -> (X&Y) >> Z for all shifts. 
+ if (BinaryOperator *SI1 = dyn_cast(Op1)) { + if (BinaryOperator *SI0 = dyn_cast(Op0)) + if (SI0->isShift() && SI0->getOpcode() == SI1->getOpcode() && + SI0->getOperand(1) == SI1->getOperand(1) && + (SI0->hasOneUse() || SI1->hasOneUse())) { + Instruction *NewOp = + InsertNewInstBefore(BinaryOperator::createAnd(SI0->getOperand(0), + SI1->getOperand(0), + SI0->getName()), I); + return BinaryOperator::create(SI1->getOpcode(), NewOp, + SI1->getOperand(1)); + } + } + + return Changed ? &I : 0; +} + +/// CollectBSwapParts - Look to see if the specified value defines a single byte +/// in the result. If it does, and if the specified byte hasn't been filled in +/// yet, fill it in and return false. +static bool CollectBSwapParts(Value *V, SmallVector &ByteValues) { + Instruction *I = dyn_cast(V); + if (I == 0) return true; + + // If this is an or instruction, it is an inner node of the bswap. + if (I->getOpcode() == Instruction::Or) + return CollectBSwapParts(I->getOperand(0), ByteValues) || + CollectBSwapParts(I->getOperand(1), ByteValues); + + uint32_t BitWidth = I->getType()->getPrimitiveSizeInBits(); + // If this is a shift by a constant int, and it is "24", then its operand + // defines a byte. We only handle unsigned types here. + if (I->isShift() && isa(I->getOperand(1))) { + // Not shifting the entire input by N-1 bytes? + if (cast(I->getOperand(1))->getLimitedValue(BitWidth) != + 8*(ByteValues.size()-1)) + return true; + + unsigned DestNo; + if (I->getOpcode() == Instruction::Shl) { + // X << 24 defines the top byte with the lowest of the input bytes. + DestNo = ByteValues.size()-1; + } else { + // X >>u 24 defines the low byte with the highest of the input bytes. + DestNo = 0; + } + + // If the destination byte value is already defined, the values are or'd + // together, which isn't a bswap (unless it's an or of the same bits). 
+ if (ByteValues[DestNo] && ByteValues[DestNo] != I->getOperand(0)) + return true; + ByteValues[DestNo] = I->getOperand(0); + return false; + } + + // Otherwise, we can only handle and(shift X, imm), imm). Bail out of if we + // don't have this. + Value *Shift = 0, *ShiftLHS = 0; + ConstantInt *AndAmt = 0, *ShiftAmt = 0; + if (!match(I, m_And(m_Value(Shift), m_ConstantInt(AndAmt))) || + !match(Shift, m_Shift(m_Value(ShiftLHS), m_ConstantInt(ShiftAmt)))) + return true; + Instruction *SI = cast(Shift); + + // Make sure that the shift amount is by a multiple of 8 and isn't too big. + if (ShiftAmt->getLimitedValue(BitWidth) & 7 || + ShiftAmt->getLimitedValue(BitWidth) > 8*ByteValues.size()) + return true; + + // Turn 0xFF -> 0, 0xFF00 -> 1, 0xFF0000 -> 2, etc. + unsigned DestByte; + if (AndAmt->getValue().getActiveBits() > 64) + return true; + uint64_t AndAmtVal = AndAmt->getZExtValue(); + for (DestByte = 0; DestByte != ByteValues.size(); ++DestByte) + if (AndAmtVal == uint64_t(0xFF) << 8*DestByte) + break; + // Unknown mask for bswap. + if (DestByte == ByteValues.size()) return true; + + unsigned ShiftBytes = ShiftAmt->getZExtValue()/8; + unsigned SrcByte; + if (SI->getOpcode() == Instruction::Shl) + SrcByte = DestByte - ShiftBytes; + else + SrcByte = DestByte + ShiftBytes; + + // If the SrcByte isn't a bswapped value from the DestByte, reject it. + if (SrcByte != ByteValues.size()-DestByte-1) + return true; + + // If the destination byte value is already defined, the values are or'd + // together, which isn't a bswap (unless it's an or of the same bits). + if (ByteValues[DestByte] && ByteValues[DestByte] != SI->getOperand(0)) + return true; + ByteValues[DestByte] = SI->getOperand(0); + return false; +} + +/// MatchBSwap - Given an OR instruction, check to see if this is a bswap idiom. +/// If so, insert the new bswap intrinsic and return it. 
+Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) { + const IntegerType *ITy = dyn_cast(I.getType()); + if (!ITy || ITy->getBitWidth() % 16) + return 0; // Can only bswap pairs of bytes. Can't do vectors. + + /// ByteValues - For each byte of the result, we keep track of which value + /// defines each byte. + SmallVector ByteValues; + ByteValues.resize(ITy->getBitWidth()/8); + + // Try to find all the pieces corresponding to the bswap. + if (CollectBSwapParts(I.getOperand(0), ByteValues) || + CollectBSwapParts(I.getOperand(1), ByteValues)) + return 0; + + // Check to see if all of the bytes come from the same value. + Value *V = ByteValues[0]; + if (V == 0) return 0; // Didn't find a byte? Must be zero. + + // Check to make sure that all of the bytes come from the same value. + for (unsigned i = 1, e = ByteValues.size(); i != e; ++i) + if (ByteValues[i] != V) + return 0; + const Type *Tys[] = { ITy, ITy }; + Module *M = I.getParent()->getParent()->getParent(); + Function *F = Intrinsic::getDeclaration(M, Intrinsic::bswap, Tys, 2); + return new CallInst(F, V); +} + + +Instruction *InstCombiner::visitOr(BinaryOperator &I) { + bool Changed = SimplifyCommutative(I); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (isa(Op1)) // X | undef -> -1 + return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType())); + + // or X, X = X + if (Op0 == Op1) + return ReplaceInstUsesWith(I, Op0); + + // See if we can simplify any instructions used by the instruction whose sole + // purpose is to compute bits we don't care about. 
+ if (!isa(I.getType())) { + uint32_t BitWidth = cast(I.getType())->getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + if (SimplifyDemandedBits(&I, APInt::getAllOnesValue(BitWidth), + KnownZero, KnownOne)) + return &I; + } else if (isa(Op1)) { + return ReplaceInstUsesWith(I, Op0); // X | <0,0> -> X + } else if (ConstantVector *CP = dyn_cast(Op1)) { + if (CP->isAllOnesValue()) // X | <-1,-1> -> <-1,-1> + return ReplaceInstUsesWith(I, I.getOperand(1)); + } + + + + // or X, -1 == -1 + if (ConstantInt *RHS = dyn_cast(Op1)) { + ConstantInt *C1 = 0; Value *X = 0; + // (X & C1) | C2 --> (X | C2) & (C1|C2) + if (match(Op0, m_And(m_Value(X), m_ConstantInt(C1))) && isOnlyUse(Op0)) { + Instruction *Or = BinaryOperator::createOr(X, RHS); + InsertNewInstBefore(Or, I); + Or->takeName(Op0); + return BinaryOperator::createAnd(Or, + ConstantInt::get(RHS->getValue() | C1->getValue())); + } + + // (X ^ C1) | C2 --> (X | C2) ^ (C1&~C2) + if (match(Op0, m_Xor(m_Value(X), m_ConstantInt(C1))) && isOnlyUse(Op0)) { + Instruction *Or = BinaryOperator::createOr(X, RHS); + InsertNewInstBefore(Or, I); + Or->takeName(Op0); + return BinaryOperator::createXor(Or, + ConstantInt::get(C1->getValue() & ~RHS->getValue())); + } + + // Try to fold constant and into select arguments. + if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + if (isa(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } + + Value *A = 0, *B = 0; + ConstantInt *C1 = 0, *C2 = 0; + + if (match(Op0, m_And(m_Value(A), m_Value(B)))) + if (A == Op1 || B == Op1) // (A & ?) | A --> A + return ReplaceInstUsesWith(I, Op1); + if (match(Op1, m_And(m_Value(A), m_Value(B)))) + if (A == Op0 || B == Op0) // A | (A & ?) --> A + return ReplaceInstUsesWith(I, Op0); + + // (A | B) | C and A | (B | C) -> bswap if possible. + // (A >> B) | (C << D) and (A << B) | (B >> C) -> bswap if possible. 
+ if (match(Op0, m_Or(m_Value(), m_Value())) || + match(Op1, m_Or(m_Value(), m_Value())) || + (match(Op0, m_Shift(m_Value(), m_Value())) && + match(Op1, m_Shift(m_Value(), m_Value())))) { + if (Instruction *BSwap = MatchBSwap(I)) + return BSwap; + } + + // (X^C)|Y -> (X|Y)^C iff Y&C == 0 + if (Op0->hasOneUse() && match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) && + MaskedValueIsZero(Op1, C1->getValue())) { + Instruction *NOr = BinaryOperator::createOr(A, Op1); + InsertNewInstBefore(NOr, I); + NOr->takeName(Op0); + return BinaryOperator::createXor(NOr, C1); + } + + // Y|(X^C) -> (X|Y)^C iff Y&C == 0 + if (Op1->hasOneUse() && match(Op1, m_Xor(m_Value(A), m_ConstantInt(C1))) && + MaskedValueIsZero(Op0, C1->getValue())) { + Instruction *NOr = BinaryOperator::createOr(A, Op0); + InsertNewInstBefore(NOr, I); + NOr->takeName(Op0); + return BinaryOperator::createXor(NOr, C1); + } + + // (A & C)|(B & D) + Value *C = 0, *D = 0; + if (match(Op0, m_And(m_Value(A), m_Value(C))) && + match(Op1, m_And(m_Value(B), m_Value(D)))) { + Value *V1 = 0, *V2 = 0, *V3 = 0; + C1 = dyn_cast(C); + C2 = dyn_cast(D); + if (C1 && C2) { // (A & C1)|(B & C2) + // If we have: ((V + N) & C1) | (V & C2) + // .. and C2 = ~C1 and C2 is 0+1+ and (N & C2) == 0 + // replace with V+N. + if (C1->getValue() == ~C2->getValue()) { + if ((C2->getValue() & (C2->getValue()+1)) == 0 && // C2 == 0+1+ + match(A, m_Add(m_Value(V1), m_Value(V2)))) { + // Add commutes, try both ways. + if (V1 == B && MaskedValueIsZero(V2, C2->getValue())) + return ReplaceInstUsesWith(I, A); + if (V2 == B && MaskedValueIsZero(V1, C2->getValue())) + return ReplaceInstUsesWith(I, A); + } + // Or commutes, try both ways. + if ((C1->getValue() & (C1->getValue()+1)) == 0 && + match(B, m_Add(m_Value(V1), m_Value(V2)))) { + // Add commutes, try both ways. 
+ if (V1 == A && MaskedValueIsZero(V2, C1->getValue())) + return ReplaceInstUsesWith(I, B); + if (V2 == A && MaskedValueIsZero(V1, C1->getValue())) + return ReplaceInstUsesWith(I, B); + } + } + V1 = 0; V2 = 0; V3 = 0; + } + + // Check to see if we have any common things being and'ed. If so, find the + // terms for V1 & (V2|V3). + if (isOnlyUse(Op0) || isOnlyUse(Op1)) { + if (A == B) // (A & C)|(A & D) == A & (C|D) + V1 = A, V2 = C, V3 = D; + else if (A == D) // (A & C)|(B & A) == A & (B|C) + V1 = A, V2 = B, V3 = C; + else if (C == B) // (A & C)|(C & D) == C & (A|D) + V1 = C, V2 = A, V3 = D; + else if (C == D) // (A & C)|(B & C) == C & (A|B) + V1 = C, V2 = A, V3 = B; + + if (V1) { + Value *Or = + InsertNewInstBefore(BinaryOperator::createOr(V2, V3, "tmp"), I); + return BinaryOperator::createAnd(V1, Or); + } + + // (V1 & V3)|(V2 & ~V3) -> ((V1 ^ V2) & V3) ^ V2 + if (isOnlyUse(Op0) && isOnlyUse(Op1)) { + // Try all combination of terms to find V3 and ~V3. + if (A->hasOneUse() && match(A, m_Not(m_Value(V3)))) { + if (V3 == B) + V1 = D, V2 = C; + else if (V3 == D) + V1 = B, V2 = C; + } + if (B->hasOneUse() && match(B, m_Not(m_Value(V3)))) { + if (V3 == A) + V1 = C, V2 = D; + else if (V3 == C) + V1 = A, V2 = D; + } + if (C->hasOneUse() && match(C, m_Not(m_Value(V3)))) { + if (V3 == B) + V1 = D, V2 = A; + else if (V3 == D) + V1 = B, V2 = A; + } + if (D->hasOneUse() && match(D, m_Not(m_Value(V3)))) { + if (V3 == A) + V1 = C, V2 = B; + else if (V3 == C) + V1 = A, V2 = B; + } + if (V1) { + A = InsertNewInstBefore(BinaryOperator::createXor(V1, V2, "tmp"), I); + A = InsertNewInstBefore(BinaryOperator::createAnd(A, V3, "tmp"), I); + return BinaryOperator::createXor(A, V2); + } + } + } + } + + // (X >> Z) | (Y >> Z) -> (X|Y) >> Z for all shifts. 
+ if (BinaryOperator *SI1 = dyn_cast(Op1)) { + if (BinaryOperator *SI0 = dyn_cast(Op0)) + if (SI0->isShift() && SI0->getOpcode() == SI1->getOpcode() && + SI0->getOperand(1) == SI1->getOperand(1) && + (SI0->hasOneUse() || SI1->hasOneUse())) { + Instruction *NewOp = + InsertNewInstBefore(BinaryOperator::createOr(SI0->getOperand(0), + SI1->getOperand(0), + SI0->getName()), I); + return BinaryOperator::create(SI1->getOpcode(), NewOp, + SI1->getOperand(1)); + } + } + + if (match(Op0, m_Not(m_Value(A)))) { // ~A | Op1 + if (A == Op1) // ~A | A == -1 + return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType())); + } else { + A = 0; + } + // Note, A is still live here! + if (match(Op1, m_Not(m_Value(B)))) { // Op0 | ~B + if (Op0 == B) + return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType())); + + // (~A | ~B) == (~(A & B)) - De Morgan's Law + if (A && isOnlyUse(Op0) && isOnlyUse(Op1)) { + Value *And = InsertNewInstBefore(BinaryOperator::createAnd(A, B, + I.getName()+".demorgan"), I); + return BinaryOperator::createNot(And); + } + } + + // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B) + if (ICmpInst *RHS = dyn_cast(I.getOperand(1))) { + if (Instruction *R = AssociativeOpt(I, FoldICmpLogical(*this, RHS))) + return R; + + Value *LHSVal, *RHSVal; + ConstantInt *LHSCst, *RHSCst; + ICmpInst::Predicate LHSCC, RHSCC; + if (match(Op0, m_ICmp(LHSCC, m_Value(LHSVal), m_ConstantInt(LHSCst)))) + if (match(RHS, m_ICmp(RHSCC, m_Value(RHSVal), m_ConstantInt(RHSCst)))) + if (LHSVal == RHSVal && // Found (X icmp C1) | (X icmp C2) + // icmp [us][gl]e x, cst is folded to icmp [us][gl]t elsewhere. + LHSCC != ICmpInst::ICMP_UGE && LHSCC != ICmpInst::ICMP_ULE && + RHSCC != ICmpInst::ICMP_UGE && RHSCC != ICmpInst::ICMP_ULE && + LHSCC != ICmpInst::ICMP_SGE && LHSCC != ICmpInst::ICMP_SLE && + RHSCC != ICmpInst::ICMP_SGE && RHSCC != ICmpInst::ICMP_SLE && + // We can't fold (ugt x, C) | (sgt x, C2). 
+ PredicatesFoldable(LHSCC, RHSCC)) { + // Ensure that the larger constant is on the RHS. + ICmpInst *LHS = cast(Op0); + bool NeedsSwap; + if (ICmpInst::isSignedPredicate(LHSCC)) + NeedsSwap = LHSCst->getValue().sgt(RHSCst->getValue()); + else + NeedsSwap = LHSCst->getValue().ugt(RHSCst->getValue()); + + if (NeedsSwap) { + std::swap(LHS, RHS); + std::swap(LHSCst, RHSCst); + std::swap(LHSCC, RHSCC); + } + + // At this point, we know we have have two icmp instructions + // comparing a value against two constants and or'ing the result + // together. Because of the above check, we know that we only have + // ICMP_EQ, ICMP_NE, ICMP_LT, and ICMP_GT here. We also know (from the + // FoldICmpLogical check above), that the two constants are not + // equal. + assert(LHSCst != RHSCst && "Compares not folded above?"); + + switch (LHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: + if (LHSCst == SubOne(RHSCst)) {// (X == 13 | X == 14) -> X-13 getName()+".off"); + InsertNewInstBefore(Add, I); + AddCST = Subtract(AddOne(RHSCst), LHSCst); + return new ICmpInst(ICmpInst::ICMP_ULT, Add, AddCST); + } + break; // (X == 13 | X == 15) -> no change + case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change + case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change + break; + case ICmpInst::ICMP_NE: // (X == 13 | X != 15) -> X != 15 + case ICmpInst::ICMP_ULT: // (X == 13 | X u< 15) -> X u< 15 + case ICmpInst::ICMP_SLT: // (X == 13 | X s< 15) -> X s< 15 + return ReplaceInstUsesWith(I, RHS); + } + break; + case ICmpInst::ICMP_NE: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X != 13 | X == 15) -> X != 13 + case ICmpInst::ICMP_UGT: // (X != 13 | X u> 15) -> X != 13 + case ICmpInst::ICMP_SGT: // (X != 13 | X s> 15) -> X != 13 + return ReplaceInstUsesWith(I, LHS); + case 
ICmpInst::ICMP_NE: // (X != 13 | X != 15) -> true + case ICmpInst::ICMP_ULT: // (X != 13 | X u< 15) -> true + case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + } + break; + case ICmpInst::ICMP_ULT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change + break; + case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) ->(X-13) u> 2 + return InsertRangeTest(LHSVal, LHSCst, AddOne(RHSCst), false, + false, I); + case ICmpInst::ICMP_SGT: // (X u< 13 | X s> 15) -> no change + break; + case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15 + case ICmpInst::ICMP_ULT: // (X u< 13 | X u< 15) -> X u< 15 + return ReplaceInstUsesWith(I, RHS); + case ICmpInst::ICMP_SLT: // (X u< 13 | X s< 15) -> no change + break; + } + break; + case ICmpInst::ICMP_SLT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change + break; + case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) ->(X-13) s> 2 + return InsertRangeTest(LHSVal, LHSCst, AddOne(RHSCst), true, + false, I); + case ICmpInst::ICMP_UGT: // (X s< 13 | X u> 15) -> no change + break; + case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15 + case ICmpInst::ICMP_SLT: // (X s< 13 | X s< 15) -> X s< 15 + return ReplaceInstUsesWith(I, RHS); + case ICmpInst::ICMP_ULT: // (X s< 13 | X u< 15) -> no change + break; + } + break; + case ICmpInst::ICMP_UGT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u> 13 | X == 15) -> X u> 13 + case ICmpInst::ICMP_UGT: // (X u> 13 | X u> 15) -> X u> 13 + return ReplaceInstUsesWith(I, LHS); + case ICmpInst::ICMP_SGT: // (X u> 13 | X s> 15) -> no change + break; + case ICmpInst::ICMP_NE: // (X u> 13 | X != 15) -> true + case ICmpInst::ICMP_ULT: // (X u> 13 | X u< 15) -> true + return ReplaceInstUsesWith(I, 
ConstantInt::getTrue()); + case ICmpInst::ICMP_SLT: // (X u> 13 | X s< 15) -> no change + break; + } + break; + case ICmpInst::ICMP_SGT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s> 13 | X == 15) -> X > 13 + case ICmpInst::ICMP_SGT: // (X s> 13 | X s> 15) -> X > 13 + return ReplaceInstUsesWith(I, LHS); + case ICmpInst::ICMP_UGT: // (X s> 13 | X u> 15) -> no change + break; + case ICmpInst::ICMP_NE: // (X s> 13 | X != 15) -> true + case ICmpInst::ICMP_SLT: // (X s> 13 | X s< 15) -> true + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + case ICmpInst::ICMP_ULT: // (X s> 13 | X u< 15) -> no change + break; + } + break; + } + } + } + + // fold (or (cast A), (cast B)) -> (cast (or A, B)) + if (CastInst *Op0C = dyn_cast(Op0)) + if (CastInst *Op1C = dyn_cast(Op1)) + if (Op0C->getOpcode() == Op1C->getOpcode()) {// same cast kind ? + const Type *SrcTy = Op0C->getOperand(0)->getType(); + if (SrcTy == Op1C->getOperand(0)->getType() && SrcTy->isInteger() && + // Only do this if the casts both really cause code to be generated. + ValueRequiresCast(Op0C->getOpcode(), Op0C->getOperand(0), + I.getType(), TD) && + ValueRequiresCast(Op1C->getOpcode(), Op1C->getOperand(0), + I.getType(), TD)) { + Instruction *NewOp = BinaryOperator::createOr(Op0C->getOperand(0), + Op1C->getOperand(0), + I.getName()); + InsertNewInstBefore(NewOp, I); + return CastInst::create(Op0C->getOpcode(), NewOp, I.getType()); + } + } + + + return Changed ? 
&I : 0; +} + +// XorSelf - Implements: X ^ X --> 0 +struct XorSelf { + Value *RHS; + XorSelf(Value *rhs) : RHS(rhs) {} + bool shouldApply(Value *LHS) const { return LHS == RHS; } + Instruction *apply(BinaryOperator &Xor) const { + return &Xor; + } +}; + + +Instruction *InstCombiner::visitXor(BinaryOperator &I) { + bool Changed = SimplifyCommutative(I); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (isa(Op1)) + return ReplaceInstUsesWith(I, Op1); // X ^ undef -> undef + + // xor X, X = 0, even if X is nested in a sequence of Xor's. + if (Instruction *Result = AssociativeOpt(I, XorSelf(Op1))) { + assert(Result == &I && "AssociativeOpt didn't work?"); + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + } + + // See if we can simplify any instructions used by the instruction whose sole + // purpose is to compute bits we don't care about. + if (!isa(I.getType())) { + uint32_t BitWidth = cast(I.getType())->getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + if (SimplifyDemandedBits(&I, APInt::getAllOnesValue(BitWidth), + KnownZero, KnownOne)) + return &I; + } else if (isa(Op1)) { + return ReplaceInstUsesWith(I, Op0); // X ^ <0,0> -> X + } + + // Is this a ~ operation? 
+ if (Value *NotOp = dyn_castNotVal(&I)) { + // ~(~X & Y) --> (X | ~Y) - De Morgan's Law + // ~(~X | Y) === (X & ~Y) - De Morgan's Law + if (BinaryOperator *Op0I = dyn_cast(NotOp)) { + if (Op0I->getOpcode() == Instruction::And || + Op0I->getOpcode() == Instruction::Or) { + if (dyn_castNotVal(Op0I->getOperand(1))) Op0I->swapOperands(); + if (Value *Op0NotVal = dyn_castNotVal(Op0I->getOperand(0))) { + Instruction *NotY = + BinaryOperator::createNot(Op0I->getOperand(1), + Op0I->getOperand(1)->getName()+".not"); + InsertNewInstBefore(NotY, I); + if (Op0I->getOpcode() == Instruction::And) + return BinaryOperator::createOr(Op0NotVal, NotY); + else + return BinaryOperator::createAnd(Op0NotVal, NotY); + } + } + } + } + + + if (ConstantInt *RHS = dyn_cast(Op1)) { + // xor (icmp A, B), true = not (icmp A, B) = !icmp A, B + if (ICmpInst *ICI = dyn_cast(Op0)) + if (RHS == ConstantInt::getTrue() && ICI->hasOneUse()) + return new ICmpInst(ICI->getInversePredicate(), + ICI->getOperand(0), ICI->getOperand(1)); + + if (BinaryOperator *Op0I = dyn_cast(Op0)) { + // ~(c-X) == X-c-1 == X+(-c-1) + if (Op0I->getOpcode() == Instruction::Sub && RHS->isAllOnesValue()) + if (Constant *Op0I0C = dyn_cast(Op0I->getOperand(0))) { + Constant *NegOp0I0C = ConstantExpr::getNeg(Op0I0C); + Constant *ConstantRHS = ConstantExpr::getSub(NegOp0I0C, + ConstantInt::get(I.getType(), 1)); + return BinaryOperator::createAdd(Op0I->getOperand(1), ConstantRHS); + } + + if (ConstantInt *Op0CI = dyn_cast(Op0I->getOperand(1))) + if (Op0I->getOpcode() == Instruction::Add) { + // ~(X-c) --> (-c-1)-X + if (RHS->isAllOnesValue()) { + Constant *NegOp0CI = ConstantExpr::getNeg(Op0CI); + return BinaryOperator::createSub( + ConstantExpr::getSub(NegOp0CI, + ConstantInt::get(I.getType(), 1)), + Op0I->getOperand(0)); + } else if (RHS->getValue().isSignBit()) { + // (X + C) ^ signbit -> (X + C + signbit) + Constant *C = ConstantInt::get(RHS->getValue() + Op0CI->getValue()); + return 
BinaryOperator::createAdd(Op0I->getOperand(0), C); + + } + } else if (Op0I->getOpcode() == Instruction::Or) { + // (X|C1)^C2 -> X^(C1|C2) iff X&~C1 == 0 + if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue())) { + Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHS); + // Anything in both C1 and C2 is known to be zero, remove it from + // NewRHS. + Constant *CommonBits = And(Op0CI, RHS); + NewRHS = ConstantExpr::getAnd(NewRHS, + ConstantExpr::getNot(CommonBits)); + AddToWorkList(Op0I); + I.setOperand(0, Op0I->getOperand(0)); + I.setOperand(1, NewRHS); + return &I; + } + } + } + + // Try to fold constant and into select arguments. + if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + if (isa(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } + + if (Value *X = dyn_castNotVal(Op0)) // ~A ^ A == -1 + if (X == Op1) + return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType())); + + if (Value *X = dyn_castNotVal(Op1)) // A ^ ~A == -1 + if (X == Op0) + return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType())); + + + BinaryOperator *Op1I = dyn_cast(Op1); + if (Op1I) { + Value *A, *B; + if (match(Op1I, m_Or(m_Value(A), m_Value(B)))) { + if (A == Op0) { // B^(B|A) == (A|B)^B + Op1I->swapOperands(); + I.swapOperands(); + std::swap(Op0, Op1); + } else if (B == Op0) { // B^(A|B) == (A|B)^B + I.swapOperands(); // Simplified below. + std::swap(Op0, Op1); + } + } else if (match(Op1I, m_Xor(m_Value(A), m_Value(B)))) { + if (Op0 == A) // A^(A^B) == B + return ReplaceInstUsesWith(I, B); + else if (Op0 == B) // A^(B^A) == B + return ReplaceInstUsesWith(I, A); + } else if (match(Op1I, m_And(m_Value(A), m_Value(B))) && Op1I->hasOneUse()){ + if (A == Op0) { // A^(A&B) -> A^(B&A) + Op1I->swapOperands(); + std::swap(A, B); + } + if (B == Op0) { // A^(B&A) -> (B&A)^A + I.swapOperands(); // Simplified below. 
+ std::swap(Op0, Op1); + } + } + } + + BinaryOperator *Op0I = dyn_cast(Op0); + if (Op0I) { + Value *A, *B; + if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && Op0I->hasOneUse()) { + if (A == Op1) // (B|A)^B == (A|B)^B + std::swap(A, B); + if (B == Op1) { // (A|B)^B == A & ~B + Instruction *NotB = + InsertNewInstBefore(BinaryOperator::createNot(Op1, "tmp"), I); + return BinaryOperator::createAnd(A, NotB); + } + } else if (match(Op0I, m_Xor(m_Value(A), m_Value(B)))) { + if (Op1 == A) // (A^B)^A == B + return ReplaceInstUsesWith(I, B); + else if (Op1 == B) // (B^A)^A == B + return ReplaceInstUsesWith(I, A); + } else if (match(Op0I, m_And(m_Value(A), m_Value(B))) && Op0I->hasOneUse()){ + if (A == Op1) // (A&B)^A -> (B&A)^A + std::swap(A, B); + if (B == Op1 && // (B&A)^A == ~B & A + !isa(Op1)) { // Canonical form is (B&C)^C + Instruction *N = + InsertNewInstBefore(BinaryOperator::createNot(A, "tmp"), I); + return BinaryOperator::createAnd(N, Op1); + } + } + } + + // (X >> Z) ^ (Y >> Z) -> (X^Y) >> Z for all shifts. 
+ if (Op0I && Op1I && Op0I->isShift() && + Op0I->getOpcode() == Op1I->getOpcode() && + Op0I->getOperand(1) == Op1I->getOperand(1) && + (Op1I->hasOneUse() || Op1I->hasOneUse())) { + Instruction *NewOp = + InsertNewInstBefore(BinaryOperator::createXor(Op0I->getOperand(0), + Op1I->getOperand(0), + Op0I->getName()), I); + return BinaryOperator::create(Op1I->getOpcode(), NewOp, + Op1I->getOperand(1)); + } + + if (Op0I && Op1I) { + Value *A, *B, *C, *D; + // (A & B)^(A | B) -> A ^ B + if (match(Op0I, m_And(m_Value(A), m_Value(B))) && + match(Op1I, m_Or(m_Value(C), m_Value(D)))) { + if ((A == C && B == D) || (A == D && B == C)) + return BinaryOperator::createXor(A, B); + } + // (A | B)^(A & B) -> A ^ B + if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && + match(Op1I, m_And(m_Value(C), m_Value(D)))) { + if ((A == C && B == D) || (A == D && B == C)) + return BinaryOperator::createXor(A, B); + } + + // (A & B)^(C & D) + if ((Op0I->hasOneUse() || Op1I->hasOneUse()) && + match(Op0I, m_And(m_Value(A), m_Value(B))) && + match(Op1I, m_And(m_Value(C), m_Value(D)))) { + // (X & Y)^(X & Y) -> (Y^Z) & X + Value *X = 0, *Y = 0, *Z = 0; + if (A == C) + X = A, Y = B, Z = D; + else if (A == D) + X = A, Y = B, Z = C; + else if (B == C) + X = B, Y = A, Z = D; + else if (B == D) + X = B, Y = A, Z = C; + + if (X) { + Instruction *NewOp = + InsertNewInstBefore(BinaryOperator::createXor(Y, Z, Op0->getName()), I); + return BinaryOperator::createAnd(NewOp, X); + } + } + } + + // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B) + if (ICmpInst *RHS = dyn_cast(I.getOperand(1))) + if (Instruction *R = AssociativeOpt(I, FoldICmpLogical(*this, RHS))) + return R; + + // fold (xor (cast A), (cast B)) -> (cast (xor A, B)) + if (CastInst *Op0C = dyn_cast(Op0)) + if (CastInst *Op1C = dyn_cast(Op1)) + if (Op0C->getOpcode() == Op1C->getOpcode()) { // same cast kind? 
+ const Type *SrcTy = Op0C->getOperand(0)->getType(); + if (SrcTy == Op1C->getOperand(0)->getType() && SrcTy->isInteger() && + // Only do this if the casts both really cause code to be generated. + ValueRequiresCast(Op0C->getOpcode(), Op0C->getOperand(0), + I.getType(), TD) && + ValueRequiresCast(Op1C->getOpcode(), Op1C->getOperand(0), + I.getType(), TD)) { + Instruction *NewOp = BinaryOperator::createXor(Op0C->getOperand(0), + Op1C->getOperand(0), + I.getName()); + InsertNewInstBefore(NewOp, I); + return CastInst::create(Op0C->getOpcode(), NewOp, I.getType()); + } + } + + return Changed ? &I : 0; +} + +/// AddWithOverflow - Compute Result = In1+In2, returning true if the result +/// overflowed for this type. +static bool AddWithOverflow(ConstantInt *&Result, ConstantInt *In1, + ConstantInt *In2, bool IsSigned = false) { + Result = cast(Add(In1, In2)); + + if (IsSigned) + if (In2->getValue().isNegative()) + return Result->getValue().sgt(In1->getValue()); + else + return Result->getValue().slt(In1->getValue()); + else + return Result->getValue().ult(In1->getValue()); +} + +/// EmitGEPOffset - Given a getelementptr instruction/constantexpr, emit the +/// code necessary to compute the offset from the base pointer (without adding +/// in the base pointer). Return the result as a signed integer of intptr size. +static Value *EmitGEPOffset(User *GEP, Instruction &I, InstCombiner &IC) { + TargetData &TD = IC.getTargetData(); + gep_type_iterator GTI = gep_type_begin(GEP); + const Type *IntPtrTy = TD.getIntPtrType(); + Value *Result = Constant::getNullValue(IntPtrTy); + + // Build a mask for high order bits. 
+ unsigned IntPtrWidth = TD.getPointerSize()*8; + uint64_t PtrSizeMask = ~0ULL >> (64-IntPtrWidth); + + for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) { + Value *Op = GEP->getOperand(i); + uint64_t Size = TD.getTypeSize(GTI.getIndexedType()) & PtrSizeMask; + if (ConstantInt *OpC = dyn_cast(Op)) { + if (OpC->isZero()) continue; + + // Handle a struct index, which adds its field offset to the pointer. + if (const StructType *STy = dyn_cast(*GTI)) { + Size = TD.getStructLayout(STy)->getElementOffset(OpC->getZExtValue()); + + if (ConstantInt *RC = dyn_cast(Result)) + Result = ConstantInt::get(RC->getValue() + APInt(IntPtrWidth, Size)); + else + Result = IC.InsertNewInstBefore( + BinaryOperator::createAdd(Result, + ConstantInt::get(IntPtrTy, Size), + GEP->getName()+".offs"), I); + continue; + } + + Constant *Scale = ConstantInt::get(IntPtrTy, Size); + Constant *OC = ConstantExpr::getIntegerCast(OpC, IntPtrTy, true /*SExt*/); + Scale = ConstantExpr::getMul(OC, Scale); + if (Constant *RC = dyn_cast(Result)) + Result = ConstantExpr::getAdd(RC, Scale); + else { + // Emit an add instruction. + Result = IC.InsertNewInstBefore( + BinaryOperator::createAdd(Result, Scale, + GEP->getName()+".offs"), I); + } + continue; + } + // Convert to correct type. + if (Op->getType() != IntPtrTy) { + if (Constant *OpC = dyn_cast(Op)) + Op = ConstantExpr::getSExt(OpC, IntPtrTy); + else + Op = IC.InsertNewInstBefore(new SExtInst(Op, IntPtrTy, + Op->getName()+".c"), I); + } + if (Size != 1) { + Constant *Scale = ConstantInt::get(IntPtrTy, Size); + if (Constant *OpC = dyn_cast(Op)) + Op = ConstantExpr::getMul(OpC, Scale); + else // We'll let instcombine(mul) convert this to a shl if possible. + Op = IC.InsertNewInstBefore(BinaryOperator::createMul(Op, Scale, + GEP->getName()+".idx"), I); + } + + // Emit an add instruction. 
+ if (isa(Op) && isa(Result)) + Result = ConstantExpr::getAdd(cast(Op), + cast(Result)); + else + Result = IC.InsertNewInstBefore(BinaryOperator::createAdd(Op, Result, + GEP->getName()+".offs"), I); + } + return Result; +} + +/// FoldGEPICmp - Fold comparisons between a GEP instruction and something +/// else. At this point we know that the GEP is on the LHS of the comparison. +Instruction *InstCombiner::FoldGEPICmp(User *GEPLHS, Value *RHS, + ICmpInst::Predicate Cond, + Instruction &I) { + assert(dyn_castGetElementPtr(GEPLHS) && "LHS is not a getelementptr!"); + + if (CastInst *CI = dyn_cast(RHS)) + if (isa(CI->getOperand(0)->getType())) + RHS = CI->getOperand(0); + + Value *PtrBase = GEPLHS->getOperand(0); + if (PtrBase == RHS) { + // As an optimization, we don't actually have to compute the actual value of + // OFFSET if this is a icmp_eq or icmp_ne comparison, just return whether + // each index is zero or not. + if (Cond == ICmpInst::ICMP_EQ || Cond == ICmpInst::ICMP_NE) { + Instruction *InVal = 0; + gep_type_iterator GTI = gep_type_begin(GEPLHS); + for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i, ++GTI) { + bool EmitIt = true; + if (Constant *C = dyn_cast(GEPLHS->getOperand(i))) { + if (isa(C)) // undef index -> undef. + return ReplaceInstUsesWith(I, UndefValue::get(I.getType())); + if (C->isNullValue()) + EmitIt = false; + else if (TD->getTypeSize(GTI.getIndexedType()) == 0) { + EmitIt = false; // This is indexing into a zero sized array? + } else if (isa(C)) + return ReplaceInstUsesWith(I, // No comparison is needed here. 
+ ConstantInt::get(Type::Int1Ty, + Cond == ICmpInst::ICMP_NE)); + } + + if (EmitIt) { + Instruction *Comp = + new ICmpInst(Cond, GEPLHS->getOperand(i), + Constant::getNullValue(GEPLHS->getOperand(i)->getType())); + if (InVal == 0) + InVal = Comp; + else { + InVal = InsertNewInstBefore(InVal, I); + InsertNewInstBefore(Comp, I); + if (Cond == ICmpInst::ICMP_NE) // True if any are unequal + InVal = BinaryOperator::createOr(InVal, Comp); + else // True if all are equal + InVal = BinaryOperator::createAnd(InVal, Comp); + } + } + } + + if (InVal) + return InVal; + else + // No comparison is needed here, all indexes = 0 + ReplaceInstUsesWith(I, ConstantInt::get(Type::Int1Ty, + Cond == ICmpInst::ICMP_EQ)); + } + + // Only lower this if the icmp is the only user of the GEP or if we expect + // the result to fold to a constant! + if (isa(GEPLHS) || GEPLHS->hasOneUse()) { + // ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0). + Value *Offset = EmitGEPOffset(GEPLHS, I, *this); + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset, + Constant::getNullValue(Offset->getType())); + } + } else if (User *GEPRHS = dyn_castGetElementPtr(RHS)) { + // If the base pointers are different, but the indices are the same, just + // compare the base pointer. + if (PtrBase != GEPRHS->getOperand(0)) { + bool IndicesTheSame = GEPLHS->getNumOperands()==GEPRHS->getNumOperands(); + IndicesTheSame &= GEPLHS->getOperand(0)->getType() == + GEPRHS->getOperand(0)->getType(); + if (IndicesTheSame) + for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i) + if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) { + IndicesTheSame = false; + break; + } + + // If all indices are the same, just compare the base pointers. + if (IndicesTheSame) + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), + GEPLHS->getOperand(0), GEPRHS->getOperand(0)); + + // Otherwise, the base pointers are different and the indices are + // different, bail out. 
+ return 0; + } + + // If one of the GEPs has all zero indices, recurse. + bool AllZeros = true; + for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i) + if (!isa(GEPLHS->getOperand(i)) || + !cast(GEPLHS->getOperand(i))->isNullValue()) { + AllZeros = false; + break; + } + if (AllZeros) + return FoldGEPICmp(GEPRHS, GEPLHS->getOperand(0), + ICmpInst::getSwappedPredicate(Cond), I); + + // If the other GEP has all zero indices, recurse. + AllZeros = true; + for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i) + if (!isa(GEPRHS->getOperand(i)) || + !cast(GEPRHS->getOperand(i))->isNullValue()) { + AllZeros = false; + break; + } + if (AllZeros) + return FoldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); + + if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands()) { + // If the GEPs only differ by one index, compare it. + unsigned NumDifferences = 0; // Keep track of # differences. + unsigned DiffOperand = 0; // The operand that differs. + for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i) + if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) { + if (GEPLHS->getOperand(i)->getType()->getPrimitiveSizeInBits() != + GEPRHS->getOperand(i)->getType()->getPrimitiveSizeInBits()) { + // Irreconcilable differences. + NumDifferences = 2; + break; + } else { + if (NumDifferences++) break; + DiffOperand = i; + } + } + + if (NumDifferences == 0) // SAME GEP? + return ReplaceInstUsesWith(I, // No comparison is needed here. + ConstantInt::get(Type::Int1Ty, + Cond == ICmpInst::ICMP_EQ)); + else if (NumDifferences == 1) { + Value *LHSV = GEPLHS->getOperand(DiffOperand); + Value *RHSV = GEPRHS->getOperand(DiffOperand); + // Make sure we do a signed comparison here. + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), LHSV, RHSV); + } + } + + // Only lower this if the icmp is the only user of the GEP or if we expect + // the result to fold to a constant! 
+ if ((isa(GEPLHS) || GEPLHS->hasOneUse()) && + (isa(GEPRHS) || GEPRHS->hasOneUse())) { + // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2) + Value *L = EmitGEPOffset(GEPLHS, I, *this); + Value *R = EmitGEPOffset(GEPRHS, I, *this); + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), L, R); + } + } + return 0; +} + +Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { + bool Changed = SimplifyCompare(I); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // Fold trivial predicates. + if (I.getPredicate() == FCmpInst::FCMP_FALSE) + return ReplaceInstUsesWith(I, Constant::getNullValue(Type::Int1Ty)); + if (I.getPredicate() == FCmpInst::FCMP_TRUE) + return ReplaceInstUsesWith(I, ConstantInt::get(Type::Int1Ty, 1)); + + // Simplify 'fcmp pred X, X' + if (Op0 == Op1) { + switch (I.getPredicate()) { + default: assert(0 && "Unknown predicate!"); + case FCmpInst::FCMP_UEQ: // True if unordered or equal + case FCmpInst::FCMP_UGE: // True if unordered, greater than, or equal + case FCmpInst::FCMP_ULE: // True if unordered, less than, or equal + return ReplaceInstUsesWith(I, ConstantInt::get(Type::Int1Ty, 1)); + case FCmpInst::FCMP_OGT: // True if ordered and greater than + case FCmpInst::FCMP_OLT: // True if ordered and less than + case FCmpInst::FCMP_ONE: // True if ordered and operands are unequal + return ReplaceInstUsesWith(I, ConstantInt::get(Type::Int1Ty, 0)); + + case FCmpInst::FCMP_UNO: // True if unordered: isnan(X) | isnan(Y) + case FCmpInst::FCMP_ULT: // True if unordered or less than + case FCmpInst::FCMP_UGT: // True if unordered or greater than + case FCmpInst::FCMP_UNE: // True if unordered or not equal + // Canonicalize these to be 'fcmp uno %X, 0.0'. 
+ I.setPredicate(FCmpInst::FCMP_UNO); + I.setOperand(1, Constant::getNullValue(Op0->getType())); + return &I; + + case FCmpInst::FCMP_ORD: // True if ordered (no nans) + case FCmpInst::FCMP_OEQ: // True if ordered and equal + case FCmpInst::FCMP_OGE: // True if ordered and greater than or equal + case FCmpInst::FCMP_OLE: // True if ordered and less than or equal + // Canonicalize these to be 'fcmp ord %X, 0.0'. + I.setPredicate(FCmpInst::FCMP_ORD); + I.setOperand(1, Constant::getNullValue(Op0->getType())); + return &I; + } + } + + if (isa(Op1)) // fcmp pred X, undef -> undef + return ReplaceInstUsesWith(I, UndefValue::get(Type::Int1Ty)); + + // Handle fcmp with constant RHS + if (Constant *RHSC = dyn_cast(Op1)) { + if (Instruction *LHSI = dyn_cast(Op0)) + switch (LHSI->getOpcode()) { + case Instruction::PHI: + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + break; + case Instruction::Select: + // If either operand of the select is a constant, we can fold the + // comparison into the select arms, which will cause one to be + // constant folded and the select turned into a bitwise or. + Value *Op1 = 0, *Op2 = 0; + if (LHSI->hasOneUse()) { + if (Constant *C = dyn_cast(LHSI->getOperand(1))) { + // Fold the known value into the constant operand. + Op1 = ConstantExpr::getCompare(I.getPredicate(), C, RHSC); + // Insert a new FCmp of the other select operand. + Op2 = InsertNewInstBefore(new FCmpInst(I.getPredicate(), + LHSI->getOperand(2), RHSC, + I.getName()), I); + } else if (Constant *C = dyn_cast(LHSI->getOperand(2))) { + // Fold the known value into the constant operand. + Op2 = ConstantExpr::getCompare(I.getPredicate(), C, RHSC); + // Insert a new FCmp of the other select operand. + Op1 = InsertNewInstBefore(new FCmpInst(I.getPredicate(), + LHSI->getOperand(1), RHSC, + I.getName()), I); + } + } + + if (Op1) + return new SelectInst(LHSI->getOperand(0), Op1, Op2); + break; + } + } + + return Changed ? 
&I : 0; +} + +Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { + bool Changed = SimplifyCompare(I); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + const Type *Ty = Op0->getType(); + + // icmp X, X + if (Op0 == Op1) + return ReplaceInstUsesWith(I, ConstantInt::get(Type::Int1Ty, + isTrueWhenEqual(I))); + + if (isa(Op1)) // X icmp undef -> undef + return ReplaceInstUsesWith(I, UndefValue::get(Type::Int1Ty)); + + // icmp of GlobalValues can never equal each other as long as they aren't + // external weak linkage type. + if (GlobalValue *GV0 = dyn_cast(Op0)) + if (GlobalValue *GV1 = dyn_cast(Op1)) + if (!GV0->hasExternalWeakLinkage() || !GV1->hasExternalWeakLinkage()) + return ReplaceInstUsesWith(I, ConstantInt::get(Type::Int1Ty, + !isTrueWhenEqual(I))); + + // icmp , - Global/Stack value + // addresses never equal each other! We already know that Op0 != Op1. + if ((isa(Op0) || isa(Op0) || + isa(Op0)) && + (isa(Op1) || isa(Op1) || + isa(Op1))) + return ReplaceInstUsesWith(I, ConstantInt::get(Type::Int1Ty, + !isTrueWhenEqual(I))); + + // icmp's with boolean values can always be turned into bitwise operations + if (Ty == Type::Int1Ty) { + switch (I.getPredicate()) { + default: assert(0 && "Invalid icmp instruction!"); + case ICmpInst::ICMP_EQ: { // icmp eq bool %A, %B -> ~(A^B) + Instruction *Xor = BinaryOperator::createXor(Op0, Op1, I.getName()+"tmp"); + InsertNewInstBefore(Xor, I); + return BinaryOperator::createNot(Xor); + } + case ICmpInst::ICMP_NE: // icmp eq bool %A, %B -> A^B + return BinaryOperator::createXor(Op0, Op1); + + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: + std::swap(Op0, Op1); // Change icmp gt -> icmp lt + // FALL THROUGH + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: { // icmp lt bool A, B -> ~X & Y + Instruction *Not = BinaryOperator::createNot(Op0, I.getName()+"tmp"); + InsertNewInstBefore(Not, I); + return BinaryOperator::createAnd(Not, Op1); + } + case ICmpInst::ICMP_UGE: + case ICmpInst::ICMP_SGE: + 
std::swap(Op0, Op1); // Change icmp ge -> icmp le + // FALL THROUGH + case ICmpInst::ICMP_ULE: + case ICmpInst::ICMP_SLE: { // icmp le bool %A, %B -> ~A | B + Instruction *Not = BinaryOperator::createNot(Op0, I.getName()+"tmp"); + InsertNewInstBefore(Not, I); + return BinaryOperator::createOr(Not, Op1); + } + } + } + + // See if we are doing a comparison between a constant and an instruction that + // can be folded into the comparison. + if (ConstantInt *CI = dyn_cast(Op1)) { + switch (I.getPredicate()) { + default: break; + case ICmpInst::ICMP_ULT: // A FALSE + if (CI->isMinValue(false)) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + if (CI->isMaxValue(false)) // A A != MAX + return new ICmpInst(ICmpInst::ICMP_NE, Op0,Op1); + if (isMinValuePlusOne(CI,false)) // A A == MIN + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI)); + // (x (x >s -1) -> true if sign bit clear + if (CI->isMinValue(true)) + return new ICmpInst(ICmpInst::ICMP_SGT, Op0, + ConstantInt::getAllOnesValue(Op0->getType())); + + break; + + case ICmpInst::ICMP_SLT: + if (CI->isMinValue(true)) // A FALSE + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + if (CI->isMaxValue(true)) // A A != MAX + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + if (isMinValuePlusOne(CI,true)) // A A == MIN + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI)); + break; + + case ICmpInst::ICMP_UGT: + if (CI->isMaxValue(false)) // A >u MAX -> FALSE + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + if (CI->isMinValue(false)) // A >u MIN -> A != MIN + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + if (isMaxValueMinusOne(CI, false)) // A >u MAX-1 -> A == MAX + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI)); + + // (x >u 2147483647) -> (x true if sign bit set + if (CI->isMaxValue(true)) + return new ICmpInst(ICmpInst::ICMP_SLT, Op0, + ConstantInt::getNullValue(Op0->getType())); + break; + + case ICmpInst::ICMP_SGT: + if (CI->isMaxValue(true)) // A >s MAX -> FALSE + 
return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + if (CI->isMinValue(true)) // A >s MIN -> A != MIN + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + if (isMaxValueMinusOne(CI, true)) // A >s MAX-1 -> A == MAX + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI)); + break; + + case ICmpInst::ICMP_ULE: + if (CI->isMaxValue(false)) // A <=u MAX -> TRUE + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + if (CI->isMinValue(false)) // A <=u MIN -> A == MIN + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + if (isMaxValueMinusOne(CI,false)) // A <=u MAX-1 -> A != MAX + return new ICmpInst(ICmpInst::ICMP_NE, Op0, AddOne(CI)); + break; + + case ICmpInst::ICMP_SLE: + if (CI->isMaxValue(true)) // A <=s MAX -> TRUE + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + if (CI->isMinValue(true)) // A <=s MIN -> A == MIN + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + if (isMaxValueMinusOne(CI,true)) // A <=s MAX-1 -> A != MAX + return new ICmpInst(ICmpInst::ICMP_NE, Op0, AddOne(CI)); + break; + + case ICmpInst::ICMP_UGE: + if (CI->isMinValue(false)) // A >=u MIN -> TRUE + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + if (CI->isMaxValue(false)) // A >=u MAX -> A == MAX + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + if (isMinValuePlusOne(CI,false)) // A >=u MIN-1 -> A != MIN + return new ICmpInst(ICmpInst::ICMP_NE, Op0, SubOne(CI)); + break; + + case ICmpInst::ICMP_SGE: + if (CI->isMinValue(true)) // A >=s MIN -> TRUE + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + if (CI->isMaxValue(true)) // A >=s MAX -> A == MAX + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + if (isMinValuePlusOne(CI,true)) // A >=s MIN-1 -> A != MIN + return new ICmpInst(ICmpInst::ICMP_NE, Op0, SubOne(CI)); + break; + } + + // If we still have a icmp le or icmp ge instruction, turn it into the + // appropriate icmp lt or icmp gt instruction. Since the border cases have + // already been handled above, this requires little checking. 
+ // + switch (I.getPredicate()) { + default: break; + case ICmpInst::ICMP_ULE: + return new ICmpInst(ICmpInst::ICMP_ULT, Op0, AddOne(CI)); + case ICmpInst::ICMP_SLE: + return new ICmpInst(ICmpInst::ICMP_SLT, Op0, AddOne(CI)); + case ICmpInst::ICMP_UGE: + return new ICmpInst( ICmpInst::ICMP_UGT, Op0, SubOne(CI)); + case ICmpInst::ICMP_SGE: + return new ICmpInst(ICmpInst::ICMP_SGT, Op0, SubOne(CI)); + } + + // See if we can fold the comparison based on bits known to be zero or one + // in the input. If this comparison is a normal comparison, it demands all + // bits, if it is a sign bit comparison, it only demands the sign bit. + + bool UnusedBit; + bool isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit); + + uint32_t BitWidth = cast(Ty)->getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + if (SimplifyDemandedBits(Op0, + isSignBit ? APInt::getSignBit(BitWidth) + : APInt::getAllOnesValue(BitWidth), + KnownZero, KnownOne, 0)) + return &I; + + // Given the known and unknown bits, compute a range that the LHS could be + // in. + if ((KnownOne | KnownZero) != 0) { + // Compute the Min, Max and RHS values based on the known bits. For the + // EQ and NE we use unsigned values. + APInt Min(BitWidth, 0), Max(BitWidth, 0); + const APInt& RHSVal = CI->getValue(); + if (ICmpInst::isSignedPredicate(I.getPredicate())) { + ComputeSignedMinMaxValuesFromKnownBits(Ty, KnownZero, KnownOne, Min, + Max); + } else { + ComputeUnsignedMinMaxValuesFromKnownBits(Ty, KnownZero, KnownOne, Min, + Max); + } + switch (I.getPredicate()) { // LE/GE have been folded already. 
+ default: assert(0 && "Unknown icmp opcode!"); + case ICmpInst::ICMP_EQ: + if (Max.ult(RHSVal) || Min.ugt(RHSVal)) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + break; + case ICmpInst::ICMP_NE: + if (Max.ult(RHSVal) || Min.ugt(RHSVal)) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + break; + case ICmpInst::ICMP_ULT: + if (Max.ult(RHSVal)) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + if (Min.uge(RHSVal)) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + break; + case ICmpInst::ICMP_UGT: + if (Min.ugt(RHSVal)) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + if (Max.ule(RHSVal)) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + break; + case ICmpInst::ICMP_SLT: + if (Max.slt(RHSVal)) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + if (Min.sgt(RHSVal)) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + break; + case ICmpInst::ICMP_SGT: + if (Min.sgt(RHSVal)) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + if (Max.sle(RHSVal)) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + break; + } + } + + // Since the RHS is a ConstantInt (CI), if the left hand side is an + // instruction, see if that instruction also has constants so that the + // instruction can be folded into the icmp + if (Instruction *LHSI = dyn_cast(Op0)) + if (Instruction *Res = visitICmpInstWithInstAndIntCst(I, LHSI, CI)) + return Res; + } + + // Handle icmp with constant (but not simple integer constant) RHS + if (Constant *RHSC = dyn_cast(Op1)) { + if (Instruction *LHSI = dyn_cast(Op0)) + switch (LHSI->getOpcode()) { + case Instruction::GetElementPtr: + if (RHSC->isNullValue()) { + // icmp pred GEP (P, int 0, int 0, int 0), null -> icmp pred P, null + bool isAllZeros = true; + for (unsigned i = 1, e = LHSI->getNumOperands(); i != e; ++i) + if (!isa(LHSI->getOperand(i)) || + !cast(LHSI->getOperand(i))->isNullValue()) { + isAllZeros = false; + break; + } + if (isAllZeros) + return new 
ICmpInst(I.getPredicate(), LHSI->getOperand(0), + Constant::getNullValue(LHSI->getOperand(0)->getType())); + } + break; + + case Instruction::PHI: + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + break; + case Instruction::Select: { + // If either operand of the select is a constant, we can fold the + // comparison into the select arms, which will cause one to be + // constant folded and the select turned into a bitwise or. + Value *Op1 = 0, *Op2 = 0; + if (LHSI->hasOneUse()) { + if (Constant *C = dyn_cast(LHSI->getOperand(1))) { + // Fold the known value into the constant operand. + Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + // Insert a new ICmp of the other select operand. + Op2 = InsertNewInstBefore(new ICmpInst(I.getPredicate(), + LHSI->getOperand(2), RHSC, + I.getName()), I); + } else if (Constant *C = dyn_cast(LHSI->getOperand(2))) { + // Fold the known value into the constant operand. + Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + // Insert a new ICmp of the other select operand. + Op1 = InsertNewInstBefore(new ICmpInst(I.getPredicate(), + LHSI->getOperand(1), RHSC, + I.getName()), I); + } + } + + if (Op1) + return new SelectInst(LHSI->getOperand(0), Op1, Op2); + break; + } + case Instruction::Malloc: + // If we have (malloc != null), and if the malloc has a single use, we + // can assume it is successful and remove the malloc. + if (LHSI->hasOneUse() && isa(RHSC)) { + AddToWorkList(LHSI); + return ReplaceInstUsesWith(I, ConstantInt::get(Type::Int1Ty, + !isTrueWhenEqual(I))); + } + break; + } + } + + // If we can optimize a 'icmp GEP, P' or 'icmp P, GEP', do so now. 
+ if (User *GEP = dyn_castGetElementPtr(Op0)) + if (Instruction *NI = FoldGEPICmp(GEP, Op1, I.getPredicate(), I)) + return NI; + if (User *GEP = dyn_castGetElementPtr(Op1)) + if (Instruction *NI = FoldGEPICmp(GEP, Op0, + ICmpInst::getSwappedPredicate(I.getPredicate()), I)) + return NI; + + // Test to see if the operands of the icmp are casted versions of other + // values. If the ptr->ptr cast can be stripped off both arguments, we do so + // now. + if (BitCastInst *CI = dyn_cast(Op0)) { + if (isa(Op0->getType()) && + (isa(Op1) || isa(Op1))) { + // We keep moving the cast from the left operand over to the right + // operand, where it can often be eliminated completely. + Op0 = CI->getOperand(0); + + // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast + // so eliminate it as well. + if (BitCastInst *CI2 = dyn_cast(Op1)) + Op1 = CI2->getOperand(0); + + // If Op1 is a constant, we can fold the cast into the constant. + if (Op0->getType() != Op1->getType()) + if (Constant *Op1C = dyn_cast(Op1)) { + Op1 = ConstantExpr::getBitCast(Op1C, Op0->getType()); + } else { + // Otherwise, cast the RHS right before the icmp + Op1 = InsertCastBefore(Instruction::BitCast, Op1, Op0->getType(), I); + } + return new ICmpInst(I.getPredicate(), Op0, Op1); + } + } + + if (isa(Op0)) { + // Handle the special case of: icmp (cast bool to X), + // This comes up when you have code like + // int X = A < B; + // if (X) ... + // For generality, we handle any zero-extension of any operand comparison + // with a constant or another cast from the same type. + if (isa(Op1) || isa(Op1)) + if (Instruction *R = visitICmpInstWithCastAndCast(I)) + return R; + } + + if (I.isEquality()) { + Value *A, *B, *C, *D; + if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { + if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 + Value *OtherVal = A == Op1 ? 
B : A; + return new ICmpInst(I.getPredicate(), OtherVal, + Constant::getNullValue(A->getType())); + } + + if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) { + // A^c1 == C^c2 --> A == C^(c1^c2) + if (ConstantInt *C1 = dyn_cast(B)) + if (ConstantInt *C2 = dyn_cast(D)) + if (Op1->hasOneUse()) { + Constant *NC = ConstantInt::get(C1->getValue() ^ C2->getValue()); + Instruction *Xor = BinaryOperator::createXor(C, NC, "tmp"); + return new ICmpInst(I.getPredicate(), A, + InsertNewInstBefore(Xor, I)); + } + + // A^B == A^D -> B == D + if (A == C) return new ICmpInst(I.getPredicate(), B, D); + if (A == D) return new ICmpInst(I.getPredicate(), B, C); + if (B == C) return new ICmpInst(I.getPredicate(), A, D); + if (B == D) return new ICmpInst(I.getPredicate(), A, C); + } + } + + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && + (A == Op0 || B == Op0)) { + // A == (A^B) -> B == 0 + Value *OtherVal = A == Op0 ? B : A; + return new ICmpInst(I.getPredicate(), OtherVal, + Constant::getNullValue(A->getType())); + } + if (match(Op0, m_Sub(m_Value(A), m_Value(B))) && A == Op1) { + // (A-B) == A -> B == 0 + return new ICmpInst(I.getPredicate(), B, + Constant::getNullValue(B->getType())); + } + if (match(Op1, m_Sub(m_Value(A), m_Value(B))) && A == Op0) { + // A == (A-B) -> B == 0 + return new ICmpInst(I.getPredicate(), B, + Constant::getNullValue(B->getType())); + } + + // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 + if (Op0->hasOneUse() && Op1->hasOneUse() && + match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_And(m_Value(C), m_Value(D)))) { + Value *X = 0, *Y = 0, *Z = 0; + + if (A == C) { + X = B; Y = D; Z = A; + } else if (A == D) { + X = B; Y = C; Z = A; + } else if (B == C) { + X = A; Y = D; Z = B; + } else if (B == D) { + X = A; Y = C; Z = B; + } + + if (X) { // Build (X^Y) & Z + Op1 = InsertNewInstBefore(BinaryOperator::createXor(X, Y, "tmp"), I); + Op1 = InsertNewInstBefore(BinaryOperator::createAnd(Op1, Z, "tmp"), I); + I.setOperand(0, Op1); + I.setOperand(1, 
Constant::getNullValue(Op1->getType())); + return &I; + } + } + } + return Changed ? &I : 0; +} + + +/// FoldICmpDivCst - Fold "icmp pred, ([su]div X, DivRHS), CmpRHS" where DivRHS +/// and CmpRHS are both known to be integer constants. +Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, + ConstantInt *DivRHS) { + ConstantInt *CmpRHS = cast(ICI.getOperand(1)); + const APInt &CmpRHSV = CmpRHS->getValue(); + + // FIXME: If the operand types don't match the type of the divide + // then don't attempt this transform. The code below doesn't have the + // logic to deal with a signed divide and an unsigned compare (and + // vice versa). This is because (x /s C1) getOpcode() == Instruction::SDiv; + if (!ICI.isEquality() && DivIsSigned != ICI.isSignedPredicate()) + return 0; + if (DivRHS->isZero()) + return 0; // The ProdOV computation fails on divide by zero. + + // Compute Prod = CI * DivRHS. We are essentially solving an equation + // of form X/C1=C2. We solve for X by multiplying C1 (DivRHS) and + // C2 (CI). By solving for X we can turn this into a range check + // instead of computing a divide. + ConstantInt *Prod = Multiply(CmpRHS, DivRHS); + + // Determine if the product overflows by seeing if the product is + // not equal to the divide. Make sure we do the same kind of divide + // as in the LHS instruction that we're folding. + bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS) : + ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS; + + // Get the ICmp opcode + ICmpInst::Predicate Pred = ICI.getPredicate(); + + // Figure out the interval that is being checked. For example, a comparison + // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). + // Compute this interval based on the constants involved and the signedness of + // the compare/divide. This computes a half-open interval, keeping track of + // whether either value in the interval overflows. 
After analysis each + // overflow variable is set to 0 if it's corresponding bound variable is valid + // -1 if overflowed off the bottom end, or +1 if overflowed off the top end. + int LoOverflow = 0, HiOverflow = 0; + ConstantInt *LoBound = 0, *HiBound = 0; + + + if (!DivIsSigned) { // udiv + // e.g. X/5 op 3 --> [15, 20) + LoBound = Prod; + HiOverflow = LoOverflow = ProdOV; + if (!HiOverflow) + HiOverflow = AddWithOverflow(HiBound, LoBound, DivRHS, false); + } else if (DivRHS->getValue().isPositive()) { // Divisor is > 0. + if (CmpRHSV == 0) { // (X / pos) op 0 + // Can't overflow. e.g. X/2 op 0 --> [-1, 2) + LoBound = cast(ConstantExpr::getNeg(SubOne(DivRHS))); + HiBound = DivRHS; + } else if (CmpRHSV.isPositive()) { // (X / pos) op pos + LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) + HiOverflow = LoOverflow = ProdOV; + if (!HiOverflow) + HiOverflow = AddWithOverflow(HiBound, Prod, DivRHS, true); + } else { // (X / pos) op neg + // e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14) + Constant *DivRHSH = ConstantExpr::getNeg(SubOne(DivRHS)); + LoOverflow = AddWithOverflow(LoBound, Prod, + cast(DivRHSH), true) ? -1 : 0; + HiBound = AddOne(Prod); + HiOverflow = ProdOV ? -1 : 0; + } + } else { // Divisor is < 0. + if (CmpRHSV == 0) { // (X / neg) op 0 + // e.g. X/-5 op 0 --> [-4, 5) + LoBound = AddOne(DivRHS); + HiBound = cast(ConstantExpr::getNeg(DivRHS)); + if (HiBound == DivRHS) { // -INTMIN = INTMIN + HiOverflow = 1; // [INTMIN+1, overflow) + HiBound = 0; // e.g. X/INTMIN = 0 --> X > INTMIN + } + } else if (CmpRHSV.isPositive()) { // (X / neg) op pos + // e.g. X/-5 op 3 --> [-19, -14) + HiOverflow = LoOverflow = ProdOV ? -1 : 0; + if (!LoOverflow) + LoOverflow = AddWithOverflow(LoBound, Prod, AddOne(DivRHS), true) ?-1:0; + HiBound = AddOne(Prod); + } else { // (X / neg) op neg + // e.g. X/-5 op -3 --> [15, 20) + LoBound = Prod; + LoOverflow = HiOverflow = ProdOV ? 1 : 0; + HiBound = Subtract(Prod, DivRHS); + } + + // Dividing by a negative swaps the condition. 
LT <-> GT + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + Value *X = DivI->getOperand(0); + switch (Pred) { + default: assert(0 && "Unhandled icmp opcode!"); + case ICmpInst::ICMP_EQ: + if (LoOverflow && HiOverflow) + return ReplaceInstUsesWith(ICI, ConstantInt::getFalse()); + else if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : + ICmpInst::ICMP_UGE, X, LoBound); + else if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : + ICmpInst::ICMP_ULT, X, HiBound); + else + return InsertRangeTest(X, LoBound, HiBound, DivIsSigned, true, ICI); + case ICmpInst::ICMP_NE: + if (LoOverflow && HiOverflow) + return ReplaceInstUsesWith(ICI, ConstantInt::getTrue()); + else if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : + ICmpInst::ICMP_ULT, X, LoBound); + else if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : + ICmpInst::ICMP_UGE, X, HiBound); + else + return InsertRangeTest(X, LoBound, HiBound, DivIsSigned, false, ICI); + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: + if (LoOverflow == +1) // Low bound is greater than input range. + return ReplaceInstUsesWith(ICI, ConstantInt::getTrue()); + if (LoOverflow == -1) // Low bound is less than input range. + return ReplaceInstUsesWith(ICI, ConstantInt::getFalse()); + return new ICmpInst(Pred, X, LoBound); + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: + if (HiOverflow == +1) // High bound greater than input range. + return ReplaceInstUsesWith(ICI, ConstantInt::getFalse()); + else if (HiOverflow == -1) // High bound less than input range. + return ReplaceInstUsesWith(ICI, ConstantInt::getTrue()); + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); + else + return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); + } +} + + +/// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)". 
+/// +Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, + Instruction *LHSI, + ConstantInt *RHS) { + const APInt &RHSV = RHS->getValue(); + + switch (LHSI->getOpcode()) { + case Instruction::Xor: // (icmp pred (xor X, XorCST), CI) + if (ConstantInt *XorCST = dyn_cast(LHSI->getOperand(1))) { + // If this is a comparison that tests the signbit (X < 0) or (x > -1), + // fold the xor. + if (ICI.getPredicate() == ICmpInst::ICMP_SLT && RHSV == 0 || + ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV.isAllOnesValue()) { + Value *CompareVal = LHSI->getOperand(0); + + // If the sign bit of the XorCST is not set, there is no change to + // the operation, just stop using the Xor. + if (!XorCST->getValue().isNegative()) { + ICI.setOperand(0, CompareVal); + AddToWorkList(LHSI); + return &ICI; + } + + // Was the old condition true if the operand is positive? + bool isTrueIfPositive = ICI.getPredicate() == ICmpInst::ICMP_SGT; + + // If so, the new one isn't. + isTrueIfPositive ^= true; + + if (isTrueIfPositive) + return new ICmpInst(ICmpInst::ICMP_SGT, CompareVal, SubOne(RHS)); + else + return new ICmpInst(ICmpInst::ICMP_SLT, CompareVal, AddOne(RHS)); + } + } + break; + case Instruction::And: // (icmp pred (and X, AndCST), RHS) + if (LHSI->hasOneUse() && isa(LHSI->getOperand(1)) && + LHSI->getOperand(0)->hasOneUse()) { + ConstantInt *AndCST = cast(LHSI->getOperand(1)); + + // If the LHS is an AND of a truncating cast, we can widen the + // and/compare to be the input width without changing the value + // produced, eliminating a cast. + if (TruncInst *Cast = dyn_cast(LHSI->getOperand(0))) { + // We can do this transformation if either the AND constant does not + // have its sign bit set or if it is an equality comparison. + // Extending a relational comparison when we're checking the sign + // bit would not work. 
+ if (Cast->hasOneUse() && + (ICI.isEquality() || AndCST->getValue().isPositive() && + RHSV.isPositive())) { + uint32_t BitWidth = + cast(Cast->getOperand(0)->getType())->getBitWidth(); + APInt NewCST = AndCST->getValue(); + NewCST.zext(BitWidth); + APInt NewCI = RHSV; + NewCI.zext(BitWidth); + Instruction *NewAnd = + BinaryOperator::createAnd(Cast->getOperand(0), + ConstantInt::get(NewCST),LHSI->getName()); + InsertNewInstBefore(NewAnd, ICI); + return new ICmpInst(ICI.getPredicate(), NewAnd, + ConstantInt::get(NewCI)); + } + } + + // If this is: (X >> C1) & C2 != C3 (where any shift and any compare + // could exist), turn it into (X & (C2 << C1)) != (C3 << C1). This + // happens a LOT in code produced by the C front-end, for bitfield + // access. + BinaryOperator *Shift = dyn_cast(LHSI->getOperand(0)); + if (Shift && !Shift->isShift()) + Shift = 0; + + ConstantInt *ShAmt; + ShAmt = Shift ? dyn_cast(Shift->getOperand(1)) : 0; + const Type *Ty = Shift ? Shift->getType() : 0; // Type of the shift. + const Type *AndTy = AndCST->getType(); // Type of the and. + + // We can fold this as long as we can't shift unknown bits + // into the mask. This can only happen with signed shift + // rights, as they sign-extend. + if (ShAmt) { + bool CanFold = Shift->isLogicalShift(); + if (!CanFold) { + // To test for the bad case of the signed shr, see if any + // of the bits shifted in could be tested after the mask. + uint32_t TyBits = Ty->getPrimitiveSizeInBits(); + int ShAmtVal = TyBits - ShAmt->getLimitedValue(TyBits); + + uint32_t BitWidth = AndTy->getPrimitiveSizeInBits(); + if ((APInt::getHighBitsSet(BitWidth, BitWidth-ShAmtVal) & + AndCST->getValue()) == 0) + CanFold = true; + } + + if (CanFold) { + Constant *NewCst; + if (Shift->getOpcode() == Instruction::Shl) + NewCst = ConstantExpr::getLShr(RHS, ShAmt); + else + NewCst = ConstantExpr::getShl(RHS, ShAmt); + + // Check to see if we are shifting out any of the bits being + // compared. 
+ if (ConstantExpr::get(Shift->getOpcode(), NewCst, ShAmt) != RHS) { + // If we shifted bits out, the fold is not going to work out. + // As a special case, check to see if this means that the + // result is always true or false now. + if (ICI.getPredicate() == ICmpInst::ICMP_EQ) + return ReplaceInstUsesWith(ICI, ConstantInt::getFalse()); + if (ICI.getPredicate() == ICmpInst::ICMP_NE) + return ReplaceInstUsesWith(ICI, ConstantInt::getTrue()); + } else { + ICI.setOperand(1, NewCst); + Constant *NewAndCST; + if (Shift->getOpcode() == Instruction::Shl) + NewAndCST = ConstantExpr::getLShr(AndCST, ShAmt); + else + NewAndCST = ConstantExpr::getShl(AndCST, ShAmt); + LHSI->setOperand(1, NewAndCST); + LHSI->setOperand(0, Shift->getOperand(0)); + AddToWorkList(Shift); // Shift is dead. + AddUsesToWorkList(ICI); + return &ICI; + } + } + } + + // Turn ((X >> Y) & C) == 0 into (X & (C << Y)) == 0. The later is + // preferable because it allows the C<hasOneUse() && RHSV == 0 && + ICI.isEquality() && !Shift->isArithmeticShift() && + isa(Shift->getOperand(0))) { + // Compute C << Y. + Value *NS; + if (Shift->getOpcode() == Instruction::LShr) { + NS = BinaryOperator::createShl(AndCST, + Shift->getOperand(1), "tmp"); + } else { + // Insert a logical shift. + NS = BinaryOperator::createLShr(AndCST, + Shift->getOperand(1), "tmp"); + } + InsertNewInstBefore(cast(NS), ICI); + + // Compute X & (C << Y). + Instruction *NewAnd = + BinaryOperator::createAnd(Shift->getOperand(0), NS, LHSI->getName()); + InsertNewInstBefore(NewAnd, ICI); + + ICI.setOperand(0, NewAnd); + return &ICI; + } + } + break; + + case Instruction::Shl: { // (icmp pred (shl X, ShAmt), CI) + ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1)); + if (!ShAmt) break; + + uint32_t TypeBits = RHSV.getBitWidth(); + + // Check that the shift amount is in range. If not, don't perform + // undefined shifts. When the shift is visited it will be + // simplified. 
+ if (ShAmt->uge(TypeBits)) + break; + + if (ICI.isEquality()) { + // If we are comparing against bits always shifted out, the + // comparison cannot succeed. + Constant *Comp = + ConstantExpr::getShl(ConstantExpr::getLShr(RHS, ShAmt), ShAmt); + if (Comp != RHS) {// Comparing against a bit that we know is zero. + bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; + Constant *Cst = ConstantInt::get(Type::Int1Ty, IsICMP_NE); + return ReplaceInstUsesWith(ICI, Cst); + } + + if (LHSI->hasOneUse()) { + // Otherwise strength reduce the shift into an and. + uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); + Constant *Mask = + ConstantInt::get(APInt::getLowBitsSet(TypeBits, TypeBits-ShAmtVal)); + + Instruction *AndI = + BinaryOperator::createAnd(LHSI->getOperand(0), + Mask, LHSI->getName()+".mask"); + Value *And = InsertNewInstBefore(AndI, ICI); + return new ICmpInst(ICI.getPredicate(), And, + ConstantInt::get(RHSV.lshr(ShAmtVal))); + } + } + + // Otherwise, if this is a comparison of the sign bit, simplify to and/test. + bool TrueIfSigned = false; + if (LHSI->hasOneUse() && + isSignBitCheck(ICI.getPredicate(), RHS, TrueIfSigned)) { + // (X << 31) (X&1) != 0 + Constant *Mask = ConstantInt::get(APInt(TypeBits, 1) << + (TypeBits-ShAmt->getZExtValue()-1)); + Instruction *AndI = + BinaryOperator::createAnd(LHSI->getOperand(0), + Mask, LHSI->getName()+".mask"); + Value *And = InsertNewInstBefore(AndI, ICI); + + return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, + And, Constant::getNullValue(And->getType())); + } + break; + } + + case Instruction::LShr: // (icmp pred (shr X, ShAmt), CI) + case Instruction::AShr: { + ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1)); + if (!ShAmt) break; + + if (ICI.isEquality()) { + // Check that the shift amount is in range. If not, don't perform + // undefined shifts. When the shift is visited it will be + // simplified. 
+ uint32_t TypeBits = RHSV.getBitWidth(); + if (ShAmt->uge(TypeBits)) + break; + uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); + + // If we are comparing against bits always shifted out, the + // comparison cannot succeed. + APInt Comp = RHSV << ShAmtVal; + if (LHSI->getOpcode() == Instruction::LShr) + Comp = Comp.lshr(ShAmtVal); + else + Comp = Comp.ashr(ShAmtVal); + + if (Comp != RHSV) { // Comparing against a bit that we know is zero. + bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; + Constant *Cst = ConstantInt::get(Type::Int1Ty, IsICMP_NE); + return ReplaceInstUsesWith(ICI, Cst); + } + + if (LHSI->hasOneUse() || RHSV == 0) { + // Otherwise strength reduce the shift into an and. + APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); + Constant *Mask = ConstantInt::get(Val); + + Instruction *AndI = + BinaryOperator::createAnd(LHSI->getOperand(0), + Mask, LHSI->getName()+".mask"); + Value *And = InsertNewInstBefore(AndI, ICI); + return new ICmpInst(ICI.getPredicate(), And, + ConstantExpr::getShl(RHS, ShAmt)); + } + } + break; + } + + case Instruction::SDiv: + case Instruction::UDiv: + // Fold: icmp pred ([us]div X, C1), C2 -> range test + // Fold this div into the comparison, producing a range check. + // Determine, based on the divide type, what the range is being + // checked. If there is an overflow on the low or high side, remember + // it, otherwise compute the range [low, hi) bounding the new value. + // See: InsertRangeTest above for the kinds of replacements possible. + if (ConstantInt *DivRHS = dyn_cast(LHSI->getOperand(1))) + if (Instruction *R = FoldICmpDivCst(ICI, cast(LHSI), + DivRHS)) + return R; + break; + } + + // Simplify icmp_eq and icmp_ne instructions with integer constant RHS. + if (ICI.isEquality()) { + bool isICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; + + // If the first operand is (add|sub|and|or|xor|rem) with a constant, and + // the second operand is a constant, simplify a bit. 
+ if (BinaryOperator *BO = dyn_cast(LHSI)) { + switch (BO->getOpcode()) { + case Instruction::SRem: + // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. + if (RHSV == 0 && isa(BO->getOperand(1)) &&BO->hasOneUse()){ + const APInt &V = cast(BO->getOperand(1))->getValue(); + if (V.sgt(APInt(V.getBitWidth(), 1)) && V.isPowerOf2()) { + Instruction *NewRem = + BinaryOperator::createURem(BO->getOperand(0), BO->getOperand(1), + BO->getName()); + InsertNewInstBefore(NewRem, ICI); + return new ICmpInst(ICI.getPredicate(), NewRem, + Constant::getNullValue(BO->getType())); + } + } + break; + case Instruction::Add: + // Replace ((add A, B) != C) with (A != C-B) if B & C are constants. + if (ConstantInt *BOp1C = dyn_cast(BO->getOperand(1))) { + if (BO->hasOneUse()) + return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), + Subtract(RHS, BOp1C)); + } else if (RHSV == 0) { + // Replace ((add A, B) != 0) with (A != -B) if A or B is + // efficiently invertible, or if the add has just this one use. + Value *BOp0 = BO->getOperand(0), *BOp1 = BO->getOperand(1); + + if (Value *NegVal = dyn_castNegVal(BOp1)) + return new ICmpInst(ICI.getPredicate(), BOp0, NegVal); + else if (Value *NegVal = dyn_castNegVal(BOp0)) + return new ICmpInst(ICI.getPredicate(), NegVal, BOp1); + else if (BO->hasOneUse()) { + Instruction *Neg = BinaryOperator::createNeg(BOp1); + InsertNewInstBefore(Neg, ICI); + Neg->takeName(BO); + return new ICmpInst(ICI.getPredicate(), BOp0, Neg); + } + } + break; + case Instruction::Xor: + // For the xor case, we can xor two constants together, eliminating + // the explicit xor. 
+ if (Constant *BOC = dyn_cast(BO->getOperand(1))) + return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), + ConstantExpr::getXor(RHS, BOC)); + + // FALLTHROUGH + case Instruction::Sub: + // Replace (([sub|xor] A, B) != 0) with (A != B) + if (RHSV == 0) + return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), + BO->getOperand(1)); + break; + + case Instruction::Or: + // If bits are being or'd in that are not present in the constant we + // are comparing against, then the comparison could never succeed! + if (Constant *BOC = dyn_cast(BO->getOperand(1))) { + Constant *NotCI = ConstantExpr::getNot(RHS); + if (!ConstantExpr::getAnd(BOC, NotCI)->isNullValue()) + return ReplaceInstUsesWith(ICI, ConstantInt::get(Type::Int1Ty, + isICMP_NE)); + } + break; + + case Instruction::And: + if (ConstantInt *BOC = dyn_cast(BO->getOperand(1))) { + // If bits are being compared against that are and'd out, then the + // comparison can never succeed! + if ((RHSV & ~BOC->getValue()) != 0) + return ReplaceInstUsesWith(ICI, ConstantInt::get(Type::Int1Ty, + isICMP_NE)); + + // If we have ((X & C) == C), turn it into ((X & C) != 0). + if (RHS == BOC && RHSV.isPowerOf2()) + return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : + ICmpInst::ICMP_NE, LHSI, + Constant::getNullValue(RHS->getType())); + + // Replace (and X, (1 << size(X)-1) != 0) with x s< 0 + if (isSignBit(BOC)) { + Value *X = BO->getOperand(0); + Constant *Zero = Constant::getNullValue(X->getType()); + ICmpInst::Predicate pred = isICMP_NE ? + ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; + return new ICmpInst(pred, X, Zero); + } + + // ((X & ~7) == 0) --> X < 8 + if (RHSV == 0 && isHighOnes(BOC)) { + Value *X = BO->getOperand(0); + Constant *NegX = ConstantExpr::getNeg(BOC); + ICmpInst::Predicate pred = isICMP_NE ? + ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; + return new ICmpInst(pred, X, NegX); + } + } + default: break; + } + } else if (IntrinsicInst *II = dyn_cast(LHSI)) { + // Handle icmp {eq|ne} , intcst. 
+ if (II->getIntrinsicID() == Intrinsic::bswap) { + AddToWorkList(II); + ICI.setOperand(0, II->getOperand(1)); + ICI.setOperand(1, ConstantInt::get(RHSV.byteSwap())); + return &ICI; + } + } + } else { // Not a ICMP_EQ/ICMP_NE + // If the LHS is a cast from an integral value of the same size, + // then since we know the RHS is a constant, try to simlify. + if (CastInst *Cast = dyn_cast(LHSI)) { + Value *CastOp = Cast->getOperand(0); + const Type *SrcTy = CastOp->getType(); + uint32_t SrcTySize = SrcTy->getPrimitiveSizeInBits(); + if (SrcTy->isInteger() && + SrcTySize == Cast->getType()->getPrimitiveSizeInBits()) { + // If this is an unsigned comparison, try to make the comparison use + // smaller constant values. + if (ICI.getPredicate() == ICmpInst::ICMP_ULT && RHSV.isSignBit()) { + // X u< 128 => X s> -1 + return new ICmpInst(ICmpInst::ICMP_SGT, CastOp, + ConstantInt::get(APInt::getAllOnesValue(SrcTySize))); + } else if (ICI.getPredicate() == ICmpInst::ICMP_UGT && + RHSV == APInt::getSignedMaxValue(SrcTySize)) { + // X u> 127 => X s< 0 + return new ICmpInst(ICmpInst::ICMP_SLT, CastOp, + Constant::getNullValue(SrcTy)); + } + } + } + } + return 0; +} + +/// visitICmpInstWithCastAndCast - Handle icmp (cast x to y), (cast/cst). +/// We only handle extending casts so far. +/// +Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { + const CastInst *LHSCI = cast(ICI.getOperand(0)); + Value *LHSCIOp = LHSCI->getOperand(0); + const Type *SrcTy = LHSCIOp->getType(); + const Type *DestTy = LHSCI->getType(); + Value *RHSCIOp; + + // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the + // integer type is the same size as the pointer type. 
+ if (LHSCI->getOpcode() == Instruction::PtrToInt && + getTargetData().getPointerSizeInBits() == + cast(DestTy)->getBitWidth()) { + Value *RHSOp = 0; + if (Constant *RHSC = dyn_cast(ICI.getOperand(1))) { + RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy); + } else if (PtrToIntInst *RHSC = dyn_cast(ICI.getOperand(1))) { + RHSOp = RHSC->getOperand(0); + // If the pointer types don't match, insert a bitcast. + if (LHSCIOp->getType() != RHSOp->getType()) + RHSOp = InsertCastBefore(Instruction::BitCast, RHSOp, + LHSCIOp->getType(), ICI); + } + + if (RHSOp) + return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSOp); + } + + // The code below only handles extension cast instructions, so far. + // Enforce this. + if (LHSCI->getOpcode() != Instruction::ZExt && + LHSCI->getOpcode() != Instruction::SExt) + return 0; + + bool isSignedExt = LHSCI->getOpcode() == Instruction::SExt; + bool isSignedCmp = ICI.isSignedPredicate(); + + if (CastInst *CI = dyn_cast(ICI.getOperand(1))) { + // Not an extension from the same type? + RHSCIOp = CI->getOperand(0); + if (RHSCIOp->getType() != LHSCIOp->getType()) + return 0; + + // If the signedness of the two compares doesn't agree (i.e. one is a sext + // and the other is a zext), then we can't handle this. + if (CI->getOpcode() != LHSCI->getOpcode()) + return 0; + + // Likewise, if the signedness of the [sz]exts and the compare don't match, + // then we can't handle this. + if (isSignedExt != isSignedCmp && !ICI.isEquality()) + return 0; + + // Okay, just insert a compare of the reduced operands now! + return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSCIOp); + } + + // If we aren't dealing with a constant on the RHS, exit early + ConstantInt *CI = dyn_cast(ICI.getOperand(1)); + if (!CI) + return 0; + + // Compute the constant that would happen if we truncated to SrcTy then + // reextended to DestTy. 
+ Constant *Res1 = ConstantExpr::getTrunc(CI, SrcTy); + Constant *Res2 = ConstantExpr::getCast(LHSCI->getOpcode(), Res1, DestTy); + + // If the re-extended constant didn't change... + if (Res2 == CI) { + // Make sure that sign of the Cmp and the sign of the Cast are the same. + // For example, we might have: + // %A = sext short %X to uint + // %B = icmp ugt uint %A, 1330 + // It is incorrect to transform this into + // %B = icmp ugt short %X, 1330 + // because %A may have negative value. + // + // However, it is OK if SrcTy is bool (See cast-set.ll testcase) + // OR operation is EQ/NE. + if (isSignedExt == isSignedCmp || SrcTy == Type::Int1Ty || ICI.isEquality()) + return new ICmpInst(ICI.getPredicate(), LHSCIOp, Res1); + else + return 0; + } + + // The re-extended constant changed so the constant cannot be represented + // in the shorter type. Consequently, we cannot emit a simple comparison. + + // First, handle some easy cases. We know the result cannot be equal at this + // point so handle the ICI.isEquality() cases + if (ICI.getPredicate() == ICmpInst::ICMP_EQ) + return ReplaceInstUsesWith(ICI, ConstantInt::getFalse()); + if (ICI.getPredicate() == ICmpInst::ICMP_NE) + return ReplaceInstUsesWith(ICI, ConstantInt::getTrue()); + + // Evaluate the comparison for LT (we invert for GT below). LE and GE cases + // should have been folded away previously and not enter in here. + Value *Result; + if (isSignedCmp) { + // We're performing a signed comparison. + if (cast(CI)->getValue().isNegative()) + Result = ConstantInt::getFalse(); // X < (small) --> false + else + Result = ConstantInt::getTrue(); // X < (large) --> true + } else { + // We're performing an unsigned comparison. + if (isSignedExt) { + // We're performing an unsigned comp with a sign extended value. + // This is true if the input is >= 0. 
[aka >s -1] + Constant *NegOne = ConstantInt::getAllOnesValue(SrcTy); + Result = InsertNewInstBefore(new ICmpInst(ICmpInst::ICMP_SGT, LHSCIOp, + NegOne, ICI.getName()), ICI); + } else { + // Unsigned extend & unsigned compare -> always true. + Result = ConstantInt::getTrue(); + } + } + + // Finally, return the value computed. + if (ICI.getPredicate() == ICmpInst::ICMP_ULT || + ICI.getPredicate() == ICmpInst::ICMP_SLT) { + return ReplaceInstUsesWith(ICI, Result); + } else { + assert((ICI.getPredicate()==ICmpInst::ICMP_UGT || + ICI.getPredicate()==ICmpInst::ICMP_SGT) && + "ICmp should be folded!"); + if (Constant *CI = dyn_cast(Result)) + return ReplaceInstUsesWith(ICI, ConstantExpr::getNot(CI)); + else + return BinaryOperator::createNot(Result); + } +} + +Instruction *InstCombiner::visitShl(BinaryOperator &I) { + return commonShiftTransforms(I); +} + +Instruction *InstCombiner::visitLShr(BinaryOperator &I) { + return commonShiftTransforms(I); +} + +Instruction *InstCombiner::visitAShr(BinaryOperator &I) { + return commonShiftTransforms(I); +} + +Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { + assert(I.getOperand(1)->getType() == I.getOperand(0)->getType()); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // shl X, 0 == X and shr X, 0 == X + // shl 0, X == 0 and shr 0, X == 0 + if (Op1 == Constant::getNullValue(Op1->getType()) || + Op0 == Constant::getNullValue(Op0->getType())) + return ReplaceInstUsesWith(I, Op0); + + if (isa(Op0)) { + if (I.getOpcode() == Instruction::AShr) // undef >>s X -> undef + return ReplaceInstUsesWith(I, Op0); + else // undef << X -> 0, undef >>u X -> 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + } + if (isa(Op1)) { + if (I.getOpcode() == Instruction::AShr) // X >>s undef -> X + return ReplaceInstUsesWith(I, Op0); + else // X << undef, X >>u undef -> 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + } + + // ashr int -1, X = -1 (for any arithmetic shift 
rights of ~0) + if (I.getOpcode() == Instruction::AShr) + if (ConstantInt *CSI = dyn_cast(Op0)) + if (CSI->isAllOnesValue()) + return ReplaceInstUsesWith(I, CSI); + + // Try to fold constant and into select arguments. + if (isa(Op0)) + if (SelectInst *SI = dyn_cast(Op1)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + + // See if we can turn a signed shr into an unsigned shr. + if (I.isArithmeticShift()) { + if (MaskedValueIsZero(Op0, + APInt::getSignBit(I.getType()->getPrimitiveSizeInBits()))) { + return BinaryOperator::createLShr(Op0, Op1, I.getName()); + } + } + + if (ConstantInt *CUI = dyn_cast(Op1)) + if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) + return Res; + return 0; +} + +Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, + BinaryOperator &I) { + bool isLeftShift = I.getOpcode() == Instruction::Shl; + + // See if we can simplify any instructions used by the instruction whose sole + // purpose is to compute bits we don't care about. + uint32_t TypeBits = Op0->getType()->getPrimitiveSizeInBits(); + APInt KnownZero(TypeBits, 0), KnownOne(TypeBits, 0); + if (SimplifyDemandedBits(&I, APInt::getAllOnesValue(TypeBits), + KnownZero, KnownOne)) + return &I; + + // shl uint X, 32 = 0 and shr ubyte Y, 9 = 0, ... just don't eliminate shr + // of a signed value. + // + if (Op1->uge(TypeBits)) { + if (I.getOpcode() != Instruction::AShr) + return ReplaceInstUsesWith(I, Constant::getNullValue(Op0->getType())); + else { + I.setOperand(1, ConstantInt::get(I.getType(), TypeBits-1)); + return &I; + } + } + + // ((X*C1) << C2) == (X * (C1 << C2)) + if (BinaryOperator *BO = dyn_cast(Op0)) + if (BO->getOpcode() == Instruction::Mul && isLeftShift) + if (Constant *BOOp = dyn_cast(BO->getOperand(1))) + return BinaryOperator::createMul(BO->getOperand(0), + ConstantExpr::getShl(BOOp, Op1)); + + // Try to fold constant and into select arguments. 
+ if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + if (isa(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + + if (Op0->hasOneUse()) { + if (BinaryOperator *Op0BO = dyn_cast(Op0)) { + // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) + Value *V1, *V2; + ConstantInt *CC; + switch (Op0BO->getOpcode()) { + default: break; + case Instruction::Add: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + // These operators commute. + // Turn (Y + (X >> C)) << C -> (X + (Y << C)) & (~0 << C) + if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() && + match(Op0BO->getOperand(1), + m_Shr(m_Value(V1), m_ConstantInt(CC))) && CC == Op1) { + Instruction *YS = BinaryOperator::createShl( + Op0BO->getOperand(0), Op1, + Op0BO->getName()); + InsertNewInstBefore(YS, I); // (Y << C) + Instruction *X = + BinaryOperator::create(Op0BO->getOpcode(), YS, V1, + Op0BO->getOperand(1)->getName()); + InsertNewInstBefore(X, I); // (X + (Y << C)) + uint32_t Op1Val = Op1->getLimitedValue(TypeBits); + return BinaryOperator::createAnd(X, ConstantInt::get( + APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val))); + } + + // Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C)) + Value *Op0BOOp1 = Op0BO->getOperand(1); + if (isLeftShift && Op0BOOp1->hasOneUse() && + match(Op0BOOp1, + m_And(m_Shr(m_Value(V1), m_Value(V2)),m_ConstantInt(CC))) && + cast(Op0BOOp1)->getOperand(0)->hasOneUse() && + V2 == Op1) { + Instruction *YS = BinaryOperator::createShl( + Op0BO->getOperand(0), Op1, + Op0BO->getName()); + InsertNewInstBefore(YS, I); // (Y << C) + Instruction *XM = + BinaryOperator::createAnd(V1, ConstantExpr::getShl(CC, Op1), + V1->getName()+".mask"); + InsertNewInstBefore(XM, I); // X & (CC << C) + + return BinaryOperator::create(Op0BO->getOpcode(), YS, XM); + } + } + + // FALL THROUGH. 
+ case Instruction::Sub: { + // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) + if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && + match(Op0BO->getOperand(0), + m_Shr(m_Value(V1), m_ConstantInt(CC))) && CC == Op1) { + Instruction *YS = BinaryOperator::createShl( + Op0BO->getOperand(1), Op1, + Op0BO->getName()); + InsertNewInstBefore(YS, I); // (Y << C) + Instruction *X = + BinaryOperator::create(Op0BO->getOpcode(), V1, YS, + Op0BO->getOperand(0)->getName()); + InsertNewInstBefore(X, I); // (X + (Y << C)) + uint32_t Op1Val = Op1->getLimitedValue(TypeBits); + return BinaryOperator::createAnd(X, ConstantInt::get( + APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val))); + } + + // Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C) + if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && + match(Op0BO->getOperand(0), + m_And(m_Shr(m_Value(V1), m_Value(V2)), + m_ConstantInt(CC))) && V2 == Op1 && + cast(Op0BO->getOperand(0)) + ->getOperand(0)->hasOneUse()) { + Instruction *YS = BinaryOperator::createShl( + Op0BO->getOperand(1), Op1, + Op0BO->getName()); + InsertNewInstBefore(YS, I); // (Y << C) + Instruction *XM = + BinaryOperator::createAnd(V1, ConstantExpr::getShl(CC, Op1), + V1->getName()+".mask"); + InsertNewInstBefore(XM, I); // X & (CC << C) + + return BinaryOperator::create(Op0BO->getOpcode(), XM, YS); + } + + break; + } + } + + + // If the operand is an bitwise operator with a constant RHS, and the + // shift is the only use, we can pull it out of the shift. + if (ConstantInt *Op0C = dyn_cast(Op0BO->getOperand(1))) { + bool isValid = true; // Valid only for And, Or, Xor + bool highBitSet = false; // Transform if high bit of constant set? + + switch (Op0BO->getOpcode()) { + default: isValid = false; break; // Do not perform transform! 
+ case Instruction::Add: + isValid = isLeftShift; + break; + case Instruction::Or: + case Instruction::Xor: + highBitSet = false; + break; + case Instruction::And: + highBitSet = true; + break; + } + + // If this is a signed shift right, and the high bit is modified + // by the logical operation, do not perform the transformation. + // The highBitSet boolean indicates the value of the high bit of + // the constant which would cause it to be modified for this + // operation. + // + if (isValid && !isLeftShift && I.getOpcode() == Instruction::AShr) { + isValid = Op0C->getValue()[TypeBits-1] == highBitSet; + } + + if (isValid) { + Constant *NewRHS = ConstantExpr::get(I.getOpcode(), Op0C, Op1); + + Instruction *NewShift = + BinaryOperator::create(I.getOpcode(), Op0BO->getOperand(0), Op1); + InsertNewInstBefore(NewShift, I); + NewShift->takeName(Op0BO); + + return BinaryOperator::create(Op0BO->getOpcode(), NewShift, + NewRHS); + } + } + } + } + + // Find out if this is a shift of a shift by a constant. + BinaryOperator *ShiftOp = dyn_cast(Op0); + if (ShiftOp && !ShiftOp->isShift()) + ShiftOp = 0; + + if (ShiftOp && isa(ShiftOp->getOperand(1))) { + ConstantInt *ShiftAmt1C = cast(ShiftOp->getOperand(1)); + uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); + uint32_t ShiftAmt2 = Op1->getLimitedValue(TypeBits); + assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); + if (ShiftAmt1 == 0) return 0; // Will be simplified in the future. + Value *X = ShiftOp->getOperand(0); + + uint32_t AmtSum = ShiftAmt1+ShiftAmt2; // Fold into one big shift. 
+ if (AmtSum > TypeBits) + AmtSum = TypeBits; + + const IntegerType *Ty = cast(I.getType()); + + // Check for (X << c1) << c2 and (X >> c1) >> c2 + if (I.getOpcode() == ShiftOp->getOpcode()) { + return BinaryOperator::create(I.getOpcode(), X, + ConstantInt::get(Ty, AmtSum)); + } else if (ShiftOp->getOpcode() == Instruction::LShr && + I.getOpcode() == Instruction::AShr) { + // ((X >>u C1) >>s C2) -> (X >>u (C1+C2)) since C1 != 0. + return BinaryOperator::createLShr(X, ConstantInt::get(Ty, AmtSum)); + } else if (ShiftOp->getOpcode() == Instruction::AShr && + I.getOpcode() == Instruction::LShr) { + // ((X >>s C1) >>u C2) -> ((X >>s (C1+C2)) & mask) since C1 != 0. + Instruction *Shift = + BinaryOperator::createAShr(X, ConstantInt::get(Ty, AmtSum)); + InsertNewInstBefore(Shift, I); + + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::createAnd(Shift, ConstantInt::get(Mask)); + } + + // Okay, if we get here, one shift must be left, and the other shift must be + // right. See if the amounts are equal. + if (ShiftAmt1 == ShiftAmt2) { + // If we have ((X >>? C) << C), turn this into X & (-1 << C). + if (I.getOpcode() == Instruction::Shl) { + APInt Mask(APInt::getHighBitsSet(TypeBits, TypeBits - ShiftAmt1)); + return BinaryOperator::createAnd(X, ConstantInt::get(Mask)); + } + // If we have ((X << C) >>u C), turn this into X & (-1 >>u C). + if (I.getOpcode() == Instruction::LShr) { + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1)); + return BinaryOperator::createAnd(X, ConstantInt::get(Mask)); + } + // We can simplify ((X << C) >>s C) into a trunc + sext. + // NOTE: we could do this for any C, but that would make 'unusual' integer + // types. For now, just stick to ones well-supported by the code + // generators. 
+ const Type *SExtType = 0; + switch (Ty->getBitWidth() - ShiftAmt1) { + case 1 : + case 8 : + case 16 : + case 32 : + case 64 : + case 128: + SExtType = IntegerType::get(Ty->getBitWidth() - ShiftAmt1); + break; + default: break; + } + if (SExtType) { + Instruction *NewTrunc = new TruncInst(X, SExtType, "sext"); + InsertNewInstBefore(NewTrunc, I); + return new SExtInst(NewTrunc, Ty); + } + // Otherwise, we can't handle it yet. + } else if (ShiftAmt1 < ShiftAmt2) { + uint32_t ShiftDiff = ShiftAmt2-ShiftAmt1; + + // (X >>? C1) << C2 --> X << (C2-C1) & (-1 << C2) + if (I.getOpcode() == Instruction::Shl) { + assert(ShiftOp->getOpcode() == Instruction::LShr || + ShiftOp->getOpcode() == Instruction::AShr); + Instruction *Shift = + BinaryOperator::createShl(X, ConstantInt::get(Ty, ShiftDiff)); + InsertNewInstBefore(Shift, I); + + APInt Mask(APInt::getHighBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::createAnd(Shift, ConstantInt::get(Mask)); + } + + // (X << C1) >>u C2 --> X >>u (C2-C1) & (-1 >> C2) + if (I.getOpcode() == Instruction::LShr) { + assert(ShiftOp->getOpcode() == Instruction::Shl); + Instruction *Shift = + BinaryOperator::createLShr(X, ConstantInt::get(Ty, ShiftDiff)); + InsertNewInstBefore(Shift, I); + + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::createAnd(Shift, ConstantInt::get(Mask)); + } + + // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. + } else { + assert(ShiftAmt2 < ShiftAmt1); + uint32_t ShiftDiff = ShiftAmt1-ShiftAmt2; + + // (X >>? C1) << C2 --> X >>? 
(C1-C2) & (-1 << C2) + if (I.getOpcode() == Instruction::Shl) { + assert(ShiftOp->getOpcode() == Instruction::LShr || + ShiftOp->getOpcode() == Instruction::AShr); + Instruction *Shift = + BinaryOperator::create(ShiftOp->getOpcode(), X, + ConstantInt::get(Ty, ShiftDiff)); + InsertNewInstBefore(Shift, I); + + APInt Mask(APInt::getHighBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::createAnd(Shift, ConstantInt::get(Mask)); + } + + // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2) + if (I.getOpcode() == Instruction::LShr) { + assert(ShiftOp->getOpcode() == Instruction::Shl); + Instruction *Shift = + BinaryOperator::createShl(X, ConstantInt::get(Ty, ShiftDiff)); + InsertNewInstBefore(Shift, I); + + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::createAnd(Shift, ConstantInt::get(Mask)); + } + + // We can't handle (X << C1) >>a C2, it shifts arbitrary bits in. + } + } + return 0; +} + + +/// DecomposeSimpleLinearExpr - Analyze 'Val', seeing if it is a simple linear +/// expression. If so, decompose it, returning some value X, such that Val is +/// X*Scale+Offset. +/// +static Value *DecomposeSimpleLinearExpr(Value *Val, unsigned &Scale, + int &Offset) { + assert(Val->getType() == Type::Int32Ty && "Unexpected allocation size type!"); + if (ConstantInt *CI = dyn_cast(Val)) { + Offset = CI->getZExtValue(); + Scale = 1; + return ConstantInt::get(Type::Int32Ty, 0); + } else if (Instruction *I = dyn_cast(Val)) { + if (I->getNumOperands() == 2) { + if (ConstantInt *CUI = dyn_cast(I->getOperand(1))) { + if (I->getOpcode() == Instruction::Shl) { + // This is a value scaled by '1 << the shift amt'. + Scale = 1U << CUI->getZExtValue(); + Offset = 0; + return I->getOperand(0); + } else if (I->getOpcode() == Instruction::Mul) { + // This value is scaled by 'CUI'. + Scale = CUI->getZExtValue(); + Offset = 0; + return I->getOperand(0); + } else if (I->getOpcode() == Instruction::Add) { + // We have X+C. 
Check to see if we really have (X*C2)+C1, + // where C1 is divisible by C2. + unsigned SubScale; + Value *SubVal = + DecomposeSimpleLinearExpr(I->getOperand(0), SubScale, Offset); + Offset += CUI->getZExtValue(); + if (SubScale > 1 && (Offset % SubScale == 0)) { + Scale = SubScale; + return SubVal; + } + } + } + } + } + + // Otherwise, we can't look past this. + Scale = 1; + Offset = 0; + return Val; +} + + +/// PromoteCastOfAllocation - If we find a cast of an allocation instruction, +/// try to eliminate the cast by moving the type information into the alloc. +Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, + AllocationInst &AI) { + const PointerType *PTy = cast(CI.getType()); + + // Remove any uses of AI that are dead. + assert(!CI.use_empty() && "Dead instructions should be removed earlier!"); + + for (Value::use_iterator UI = AI.use_begin(), E = AI.use_end(); UI != E; ) { + Instruction *User = cast(*UI++); + if (isInstructionTriviallyDead(User)) { + while (UI != E && *UI == User) + ++UI; // If this instruction uses AI more than once, don't break UI. + + ++NumDeadInst; + DOUT << "IC: DCE: " << *User; + EraseInstFromFunction(*User); + } + } + + // Get the type really allocated and the type casted to. + const Type *AllocElTy = AI.getAllocatedType(); + const Type *CastElTy = PTy->getElementType(); + if (!AllocElTy->isSized() || !CastElTy->isSized()) return 0; + + unsigned AllocElTyAlign = TD->getABITypeAlignment(AllocElTy); + unsigned CastElTyAlign = TD->getABITypeAlignment(CastElTy); + if (CastElTyAlign < AllocElTyAlign) return 0; + + // If the allocation has multiple uses, only promote it if we are strictly + // increasing the alignment of the resultant allocation. If we keep it the + // same, we open the door to infinite loops of various kinds. 
+ if (!AI.hasOneUse() && CastElTyAlign == AllocElTyAlign) return 0; + + uint64_t AllocElTySize = TD->getTypeSize(AllocElTy); + uint64_t CastElTySize = TD->getTypeSize(CastElTy); + if (CastElTySize == 0 || AllocElTySize == 0) return 0; + + // See if we can satisfy the modulus by pulling a scale out of the array + // size argument. + unsigned ArraySizeScale; + int ArrayOffset; + Value *NumElements = // See if the array size is a decomposable linear expr. + DecomposeSimpleLinearExpr(AI.getOperand(0), ArraySizeScale, ArrayOffset); + + // If we can now satisfy the modulus, by using a non-1 scale, we really can + // do the xform. + if ((AllocElTySize*ArraySizeScale) % CastElTySize != 0 || + (AllocElTySize*ArrayOffset ) % CastElTySize != 0) return 0; + + unsigned Scale = (AllocElTySize*ArraySizeScale)/CastElTySize; + Value *Amt = 0; + if (Scale == 1) { + Amt = NumElements; + } else { + // If the allocation size is constant, form a constant mul expression + Amt = ConstantInt::get(Type::Int32Ty, Scale); + if (isa(NumElements)) + Amt = Multiply(cast(NumElements), cast(Amt)); + // otherwise multiply the amount and the number of elements + else if (Scale != 1) { + Instruction *Tmp = BinaryOperator::createMul(Amt, NumElements, "tmp"); + Amt = InsertNewInstBefore(Tmp, AI); + } + } + + if (int Offset = (AllocElTySize*ArrayOffset)/CastElTySize) { + Value *Off = ConstantInt::get(Type::Int32Ty, Offset, true); + Instruction *Tmp = BinaryOperator::createAdd(Amt, Off, "tmp"); + Amt = InsertNewInstBefore(Tmp, AI); + } + + AllocationInst *New; + if (isa(AI)) + New = new MallocInst(CastElTy, Amt, AI.getAlignment()); + else + New = new AllocaInst(CastElTy, Amt, AI.getAlignment()); + InsertNewInstBefore(New, AI); + New->takeName(&AI); + + // If the allocation has multiple uses, insert a cast and change all things + // that used it to use the new cast. This will also hack on CI, but it will + // die soon. 
+ if (!AI.hasOneUse()) { + AddUsesToWorkList(AI); + // New is the allocation instruction, pointer typed. AI is the original + // allocation instruction, also pointer typed. Thus, cast to use is BitCast. + CastInst *NewCast = new BitCastInst(New, AI.getType(), "tmpcast"); + InsertNewInstBefore(NewCast, AI); + AI.replaceAllUsesWith(NewCast); + } + return ReplaceInstUsesWith(CI, New); +} + +/// CanEvaluateInDifferentType - Return true if we can take the specified value +/// and return it as type Ty without inserting any new casts and without +/// changing the computed value. This is used by code that tries to decide +/// whether promoting or shrinking integer operations to wider or smaller types +/// will allow us to eliminate a truncate or extend. +/// +/// This is a truncation operation if Ty is smaller than V->getType(), or an +/// extension operation if Ty is larger. +static bool CanEvaluateInDifferentType(Value *V, const IntegerType *Ty, + int &NumCastsRemoved) { + // We can always evaluate constants in another type. + if (isa(V)) + return true; + + Instruction *I = dyn_cast(V); + if (!I) return false; + + const IntegerType *OrigTy = cast(V->getType()); + + switch (I->getOpcode()) { + case Instruction::Add: + case Instruction::Sub: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + if (!I->hasOneUse()) return false; + // These operators can all arbitrarily be extended or truncated. + return CanEvaluateInDifferentType(I->getOperand(0), Ty, NumCastsRemoved) && + CanEvaluateInDifferentType(I->getOperand(1), Ty, NumCastsRemoved); + + case Instruction::Shl: + if (!I->hasOneUse()) return false; + // If we are truncating the result of this SHL, and if it's a shift of a + // constant amount, we can always perform a SHL in a smaller type. 
+ if (ConstantInt *CI = dyn_cast(I->getOperand(1))) { + uint32_t BitWidth = Ty->getBitWidth(); + if (BitWidth < OrigTy->getBitWidth() && + CI->getLimitedValue(BitWidth) < BitWidth) + return CanEvaluateInDifferentType(I->getOperand(0), Ty,NumCastsRemoved); + } + break; + case Instruction::LShr: + if (!I->hasOneUse()) return false; + // If this is a truncate of a logical shr, we can truncate it to a smaller + // lshr iff we know that the bits we would otherwise be shifting in are + // already zeros. + if (ConstantInt *CI = dyn_cast(I->getOperand(1))) { + uint32_t OrigBitWidth = OrigTy->getBitWidth(); + uint32_t BitWidth = Ty->getBitWidth(); + if (BitWidth < OrigBitWidth && + MaskedValueIsZero(I->getOperand(0), + APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) && + CI->getLimitedValue(BitWidth) < BitWidth) { + return CanEvaluateInDifferentType(I->getOperand(0), Ty,NumCastsRemoved); + } + } + break; + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + // If this is a cast from the destination type, we can trivially eliminate + // it, and this will remove a cast overall. + if (I->getOperand(0)->getType() == Ty) { + // If the first operand is itself a cast, and is eliminable, do not count + // this as an eliminable cast. We would prefer to eliminate those two + // casts first. + if (isa(I->getOperand(0))) + return true; + + ++NumCastsRemoved; + return true; + } + break; + default: + // TODO: Can handle more cases here. + break; + } + + return false; +} + +/// EvaluateInDifferentType - Given an expression that +/// CanEvaluateInDifferentType returns true for, actually insert the code to +/// evaluate the expression. +Value *InstCombiner::EvaluateInDifferentType(Value *V, const Type *Ty, + bool isSigned) { + if (Constant *C = dyn_cast(V)) + return ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/); + + // Otherwise, it must be an instruction. 
+ Instruction *I = cast(V); + Instruction *Res = 0; + switch (I->getOpcode()) { + case Instruction::Add: + case Instruction::Sub: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::AShr: + case Instruction::LShr: + case Instruction::Shl: { + Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned); + Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned); + Res = BinaryOperator::create((Instruction::BinaryOps)I->getOpcode(), + LHS, RHS, I->getName()); + break; + } + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::BitCast: + // If the source type of the cast is the type we're trying for then we can + // just return the source. There's no need to insert it because its not new. + if (I->getOperand(0)->getType() == Ty) + return I->getOperand(0); + + // Some other kind of cast, which shouldn't happen, so just .. + // FALL THROUGH + default: + // TODO: Can handle more cases here. + assert(0 && "Unreachable!"); + break; + } + + return InsertNewInstBefore(Res, *I); +} + +/// @brief Implement the transforms common to all CastInst visitors. +Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { + Value *Src = CI.getOperand(0); + + // Casting undef to anything results in undef so might as just replace it and + // get rid of the cast. + if (isa(Src)) // cast undef -> undef + return ReplaceInstUsesWith(CI, UndefValue::get(CI.getType())); + + // Many cases of "cast of a cast" are eliminable. If it's eliminable we just + // eliminate it now. + if (CastInst *CSrc = dyn_cast(Src)) { // A->B->C cast + if (Instruction::CastOps opc = + isEliminableCastPair(CSrc, CI.getOpcode(), CI.getType(), TD)) { + // The first cast (CSrc) is eliminable so we need to fix up or replace + // the second cast (CI). CSrc will then have a good chance of being dead. 
+ return CastInst::create(opc, CSrc->getOperand(0), CI.getType()); + } + } + + // If we are casting a select then fold the cast into the select + if (SelectInst *SI = dyn_cast(Src)) + if (Instruction *NV = FoldOpIntoSelect(CI, SI, this)) + return NV; + + // If we are casting a PHI then fold the cast into the PHI + if (isa(Src)) + if (Instruction *NV = FoldOpIntoPhi(CI)) + return NV; + + return 0; +} + +/// @brief Implement the transforms for cast of pointer (bitcast/ptrtoint) +Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) { + Value *Src = CI.getOperand(0); + + if (GetElementPtrInst *GEP = dyn_cast(Src)) { + // If casting the result of a getelementptr instruction with no offset, turn + // this into a cast of the original pointer! + if (GEP->hasAllZeroIndices()) { + // Changing the cast operand is usually not a good idea but it is safe + // here because the pointer operand is being replaced with another + // pointer operand so the opcode doesn't need to change. + AddToWorkList(GEP); + CI.setOperand(0, GEP->getOperand(0)); + return &CI; + } + + // If the GEP has a single use, and the base pointer is a bitcast, and the + // GEP computes a constant offset, see if we can convert these three + // instructions into fewer. This typically happens with unions and other + // non-type-safe code. + if (GEP->hasOneUse() && isa(GEP->getOperand(0))) { + if (GEP->hasAllConstantIndices()) { + // We are guaranteed to get a constant from EmitGEPOffset. + ConstantInt *OffsetV = cast(EmitGEPOffset(GEP, CI, *this)); + int64_t Offset = OffsetV->getSExtValue(); + + // Get the base pointer input of the bitcast, and the type it points to. + Value *OrigBase = cast(GEP->getOperand(0))->getOperand(0); + const Type *GEPIdxTy = + cast(OrigBase->getType())->getElementType(); + if (GEPIdxTy->isSized()) { + SmallVector NewIndices; + + // Start with the index over the outer type. 
Note that the type size + // might be zero (even if the offset isn't zero) if the indexed type + // is something like [0 x {int, int}] + const Type *IntPtrTy = TD->getIntPtrType(); + int64_t FirstIdx = 0; + if (int64_t TySize = TD->getTypeSize(GEPIdxTy)) { + FirstIdx = Offset/TySize; + Offset %= TySize; + + // Handle silly modulus not returning values values [0..TySize). + if (Offset < 0) { + --FirstIdx; + Offset += TySize; + assert(Offset >= 0); + } + assert((uint64_t)Offset < (uint64_t)TySize &&"Out of range offset"); + } + + NewIndices.push_back(ConstantInt::get(IntPtrTy, FirstIdx)); + + // Index into the types. If we fail, set OrigBase to null. + while (Offset) { + if (const StructType *STy = dyn_cast(GEPIdxTy)) { + const StructLayout *SL = TD->getStructLayout(STy); + if (Offset < (int64_t)SL->getSizeInBytes()) { + unsigned Elt = SL->getElementContainingOffset(Offset); + NewIndices.push_back(ConstantInt::get(Type::Int32Ty, Elt)); + + Offset -= SL->getElementOffset(Elt); + GEPIdxTy = STy->getElementType(Elt); + } else { + // Otherwise, we can't index into this, bail out. + Offset = 0; + OrigBase = 0; + } + } else if (isa(GEPIdxTy) || isa(GEPIdxTy)) { + const SequentialType *STy = cast(GEPIdxTy); + if (uint64_t EltSize = TD->getTypeSize(STy->getElementType())) { + NewIndices.push_back(ConstantInt::get(IntPtrTy,Offset/EltSize)); + Offset %= EltSize; + } else { + NewIndices.push_back(ConstantInt::get(IntPtrTy, 0)); + } + GEPIdxTy = STy->getElementType(); + } else { + // Otherwise, we can't index into this, bail out. + Offset = 0; + OrigBase = 0; + } + } + if (OrigBase) { + // If we were able to index down into an element, create the GEP + // and bitcast the result. This eliminates one bitcast, potentially + // two. 
+ Instruction *NGEP = new GetElementPtrInst(OrigBase, &NewIndices[0], + NewIndices.size(), ""); + InsertNewInstBefore(NGEP, CI); + NGEP->takeName(GEP); + + if (isa(CI)) + return new BitCastInst(NGEP, CI.getType()); + assert(isa(CI)); + return new PtrToIntInst(NGEP, CI.getType()); + } + } + } + } + } + + return commonCastTransforms(CI); +} + + + +/// Only the TRUNC, ZEXT, SEXT, and BITCAST can both operand and result as +/// integer types. This function implements the common transforms for all those +/// cases. +/// @brief Implement the transforms common to CastInst with integer operands +Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) { + if (Instruction *Result = commonCastTransforms(CI)) + return Result; + + Value *Src = CI.getOperand(0); + const Type *SrcTy = Src->getType(); + const Type *DestTy = CI.getType(); + uint32_t SrcBitSize = SrcTy->getPrimitiveSizeInBits(); + uint32_t DestBitSize = DestTy->getPrimitiveSizeInBits(); + + // See if we can simplify any instructions used by the LHS whose sole + // purpose is to compute bits we don't care about. + APInt KnownZero(DestBitSize, 0), KnownOne(DestBitSize, 0); + if (SimplifyDemandedBits(&CI, APInt::getAllOnesValue(DestBitSize), + KnownZero, KnownOne)) + return &CI; + + // If the source isn't an instruction or has more than one use then we + // can't do anything more. + Instruction *SrcI = dyn_cast(Src); + if (!SrcI || !Src->hasOneUse()) + return 0; + + // Attempt to propagate the cast into the instruction for int->int casts. + int NumCastsRemoved = 0; + if (!isa(CI) && + CanEvaluateInDifferentType(SrcI, cast(DestTy), + NumCastsRemoved)) { + // If this cast is a truncate, evaluting in a different type always + // eliminates the cast, so it is always a win. If this is a noop-cast + // this just removes a noop cast which isn't pointful, but simplifies + // the code. 
If this is a zero-extension, we need to do an AND to + // maintain the clear top-part of the computation, so we require that + // the input have eliminated at least one cast. If this is a sign + // extension, we insert two new casts (to do the extension) so we + // require that two casts have been eliminated. + bool DoXForm; + switch (CI.getOpcode()) { + default: + // All the others use floating point so we shouldn't actually + // get here because of the check above. + assert(0 && "Unknown cast type"); + case Instruction::Trunc: + DoXForm = true; + break; + case Instruction::ZExt: + DoXForm = NumCastsRemoved >= 1; + break; + case Instruction::SExt: + DoXForm = NumCastsRemoved >= 2; + break; + case Instruction::BitCast: + DoXForm = false; + break; + } + + if (DoXForm) { + Value *Res = EvaluateInDifferentType(SrcI, DestTy, + CI.getOpcode() == Instruction::SExt); + assert(Res->getType() == DestTy); + switch (CI.getOpcode()) { + default: assert(0 && "Unknown cast type!"); + case Instruction::Trunc: + case Instruction::BitCast: + // Just replace this cast with the result. + return ReplaceInstUsesWith(CI, Res); + case Instruction::ZExt: { + // We need to emit an AND to clear the high bits. + assert(SrcBitSize < DestBitSize && "Not a zext?"); + Constant *C = ConstantInt::get(APInt::getLowBitsSet(DestBitSize, + SrcBitSize)); + return BinaryOperator::createAnd(Res, C); + } + case Instruction::SExt: + // We need to emit a cast to truncate, then a cast to sext. + return CastInst::create(Instruction::SExt, + InsertCastBefore(Instruction::Trunc, Res, Src->getType(), + CI), DestTy); + } + } + } + + Value *Op0 = SrcI->getNumOperands() > 0 ? SrcI->getOperand(0) : 0; + Value *Op1 = SrcI->getNumOperands() > 1 ? SrcI->getOperand(1) : 0; + + switch (SrcI->getOpcode()) { + case Instruction::Add: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + // If we are discarding information, rewrite. 
+ if (DestBitSize <= SrcBitSize && DestBitSize != 1) { + // Don't insert two casts if they cannot be eliminated. We allow + // two casts to be inserted if the sizes are the same. This could + // only be converting signedness, which is a noop. + if (DestBitSize == SrcBitSize || + !ValueRequiresCast(CI.getOpcode(), Op1, DestTy,TD) || + !ValueRequiresCast(CI.getOpcode(), Op0, DestTy, TD)) { + Instruction::CastOps opcode = CI.getOpcode(); + Value *Op0c = InsertOperandCastBefore(opcode, Op0, DestTy, SrcI); + Value *Op1c = InsertOperandCastBefore(opcode, Op1, DestTy, SrcI); + return BinaryOperator::create( + cast(SrcI)->getOpcode(), Op0c, Op1c); + } + } + + // cast (xor bool X, true) to int --> xor (cast bool X to int), 1 + if (isa(CI) && SrcBitSize == 1 && + SrcI->getOpcode() == Instruction::Xor && + Op1 == ConstantInt::getTrue() && + (!Op0->hasOneUse() || !isa(Op0))) { + Value *New = InsertOperandCastBefore(Instruction::ZExt, Op0, DestTy, &CI); + return BinaryOperator::createXor(New, ConstantInt::get(CI.getType(), 1)); + } + break; + case Instruction::SDiv: + case Instruction::UDiv: + case Instruction::SRem: + case Instruction::URem: + // If we are just changing the sign, rewrite. + if (DestBitSize == SrcBitSize) { + // Don't insert two casts if they cannot be eliminated. We allow + // two casts to be inserted if the sizes are the same. This could + // only be converting signedness, which is a noop. + if (!ValueRequiresCast(CI.getOpcode(), Op1, DestTy, TD) || + !ValueRequiresCast(CI.getOpcode(), Op0, DestTy, TD)) { + Value *Op0c = InsertOperandCastBefore(Instruction::BitCast, + Op0, DestTy, SrcI); + Value *Op1c = InsertOperandCastBefore(Instruction::BitCast, + Op1, DestTy, SrcI); + return BinaryOperator::create( + cast(SrcI)->getOpcode(), Op0c, Op1c); + } + } + break; + + case Instruction::Shl: + // Allow changing the sign of the source operand. Do not allow + // changing the size of the shift, UNLESS the shift amount is a + // constant. 
We must not change variable sized shifts to a smaller + // size, because it is undefined to shift more bits out than exist + // in the value. + if (DestBitSize == SrcBitSize || + (DestBitSize < SrcBitSize && isa(Op1))) { + Instruction::CastOps opcode = (DestBitSize == SrcBitSize ? + Instruction::BitCast : Instruction::Trunc); + Value *Op0c = InsertOperandCastBefore(opcode, Op0, DestTy, SrcI); + Value *Op1c = InsertOperandCastBefore(opcode, Op1, DestTy, SrcI); + return BinaryOperator::createShl(Op0c, Op1c); + } + break; + case Instruction::AShr: + // If this is a signed shr, and if all bits shifted in are about to be + // truncated off, turn it into an unsigned shr to allow greater + // simplifications. + if (DestBitSize < SrcBitSize && + isa(Op1)) { + uint32_t ShiftAmt = cast(Op1)->getLimitedValue(SrcBitSize); + if (SrcBitSize > ShiftAmt && SrcBitSize-ShiftAmt >= DestBitSize) { + // Insert the new logical shift right. + return BinaryOperator::createLShr(Op0, Op1); + } + } + break; + } + return 0; +} + +Instruction *InstCombiner::visitTrunc(TruncInst &CI) { + if (Instruction *Result = commonIntCastTransforms(CI)) + return Result; + + Value *Src = CI.getOperand(0); + const Type *Ty = CI.getType(); + uint32_t DestBitWidth = Ty->getPrimitiveSizeInBits(); + uint32_t SrcBitWidth = cast(Src->getType())->getBitWidth(); + + if (Instruction *SrcI = dyn_cast(Src)) { + switch (SrcI->getOpcode()) { + default: break; + case Instruction::LShr: + // We can shrink lshr to something smaller if we know the bits shifted in + // are already zeros. + if (ConstantInt *ShAmtV = dyn_cast(SrcI->getOperand(1))) { + uint32_t ShAmt = ShAmtV->getLimitedValue(SrcBitWidth); + + // Get a mask for the bits shifting in. + APInt Mask(APInt::getLowBitsSet(SrcBitWidth, ShAmt).shl(DestBitWidth)); + Value* SrcIOp0 = SrcI->getOperand(0); + if (SrcI->hasOneUse() && MaskedValueIsZero(SrcIOp0, Mask)) { + if (ShAmt >= DestBitWidth) // All zeros. 
+ return ReplaceInstUsesWith(CI, Constant::getNullValue(Ty)); + + // Okay, we can shrink this. Truncate the input, then return a new + // shift. + Value *V1 = InsertCastBefore(Instruction::Trunc, SrcIOp0, Ty, CI); + Value *V2 = InsertCastBefore(Instruction::Trunc, SrcI->getOperand(1), + Ty, CI); + return BinaryOperator::createLShr(V1, V2); + } + } else { // This is a variable shr. + + // Turn 'trunc (lshr X, Y) to bool' into '(X & (1 << Y)) != 0'. This is + // more LLVM instructions, but allows '1 << Y' to be hoisted if + // loop-invariant and CSE'd. + if (CI.getType() == Type::Int1Ty && SrcI->hasOneUse()) { + Value *One = ConstantInt::get(SrcI->getType(), 1); + + Value *V = InsertNewInstBefore( + BinaryOperator::createShl(One, SrcI->getOperand(1), + "tmp"), CI); + V = InsertNewInstBefore(BinaryOperator::createAnd(V, + SrcI->getOperand(0), + "tmp"), CI); + Value *Zero = Constant::getNullValue(V->getType()); + return new ICmpInst(ICmpInst::ICMP_NE, V, Zero); + } + } + break; + } + } + + return 0; +} + +Instruction *InstCombiner::visitZExt(ZExtInst &CI) { + // If one of the common conversion will work .. + if (Instruction *Result = commonIntCastTransforms(CI)) + return Result; + + Value *Src = CI.getOperand(0); + + // If this is a cast of a cast + if (CastInst *CSrc = dyn_cast(Src)) { // A->B->C cast + // If this is a TRUNC followed by a ZEXT then we are dealing with integral + // types and if the sizes are just right we can convert this into a logical + // 'and' which will be much cheaper than the pair of casts. + if (isa(CSrc)) { + // Get the sizes of the types involved + Value *A = CSrc->getOperand(0); + uint32_t SrcSize = A->getType()->getPrimitiveSizeInBits(); + uint32_t MidSize = CSrc->getType()->getPrimitiveSizeInBits(); + uint32_t DstSize = CI.getType()->getPrimitiveSizeInBits(); + // If we're actually extending zero bits and the trunc is a no-op + if (MidSize < DstSize && SrcSize == DstSize) { + // Replace both of the casts with an And of the type mask. 
+ APInt AndValue(APInt::getLowBitsSet(SrcSize, MidSize)); + Constant *AndConst = ConstantInt::get(AndValue); + Instruction *And = + BinaryOperator::createAnd(CSrc->getOperand(0), AndConst); + // Unfortunately, if the type changed, we need to cast it back. + if (And->getType() != CI.getType()) { + And->setName(CSrc->getName()+".mask"); + InsertNewInstBefore(And, CI); + And = CastInst::createIntegerCast(And, CI.getType(), false/*ZExt*/); + } + return And; + } + } + } + + if (ICmpInst *ICI = dyn_cast(Src)) { + // If we are just checking for a icmp eq of a single bit and zext'ing it + // to an integer, then shift the bit to the appropriate place and then + // cast to integer to avoid the comparison. + if (ConstantInt *Op1C = dyn_cast(ICI->getOperand(1))) { + const APInt &Op1CV = Op1C->getValue(); + + // zext (x x>>u31 true if signbit set. + // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. + if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV == 0) || + (ICI->getPredicate() == ICmpInst::ICMP_SGT &&Op1CV.isAllOnesValue())){ + Value *In = ICI->getOperand(0); + Value *Sh = ConstantInt::get(In->getType(), + In->getType()->getPrimitiveSizeInBits()-1); + In = InsertNewInstBefore(BinaryOperator::createLShr(In, Sh, + In->getName()+".lobit"), + CI); + if (In->getType() != CI.getType()) + In = CastInst::createIntegerCast(In, CI.getType(), + false/*ZExt*/, "tmp", &CI); + + if (ICI->getPredicate() == ICmpInst::ICMP_SGT) { + Constant *One = ConstantInt::get(In->getType(), 1); + In = InsertNewInstBefore(BinaryOperator::createXor(In, One, + In->getName()+".not"), + CI); + } + + return ReplaceInstUsesWith(CI, In); + } + + + + // zext (X == 0) to i32 --> X^1 iff X has only the low bit set. + // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. + // zext (X == 1) to i32 --> X iff X has only the low bit set. + // zext (X == 2) to i32 --> X>>1 iff X has only the 2nd bit set. + // zext (X != 0) to i32 --> X iff X has only the low bit set. 
+ // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. + // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. + // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. + if ((Op1CV == 0 || Op1CV.isPowerOf2()) && + // This only works for EQ and NE + ICI->isEquality()) { + // If Op1C some other power of two, convert: + uint32_t BitWidth = Op1C->getType()->getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + APInt TypeMask(APInt::getAllOnesValue(BitWidth)); + ComputeMaskedBits(ICI->getOperand(0), TypeMask, KnownZero, KnownOne); + + APInt KnownZeroMask(~KnownZero); + if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? + bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE; + if (Op1CV != 0 && (Op1CV != KnownZeroMask)) { + // (X&4) == 2 --> false + // (X&4) != 2 --> true + Constant *Res = ConstantInt::get(Type::Int1Ty, isNE); + Res = ConstantExpr::getZExt(Res, CI.getType()); + return ReplaceInstUsesWith(CI, Res); + } + + uint32_t ShiftAmt = KnownZeroMask.logBase2(); + Value *In = ICI->getOperand(0); + if (ShiftAmt) { + // Perform a logical shr by shiftamt. + // Insert the shift to put the result in the low bit. + In = InsertNewInstBefore( + BinaryOperator::createLShr(In, + ConstantInt::get(In->getType(), ShiftAmt), + In->getName()+".lobit"), CI); + } + + if ((Op1CV != 0) == isNE) { // Toggle the low bit. 
+ Constant *One = ConstantInt::get(In->getType(), 1); + In = BinaryOperator::createXor(In, One, "tmp"); + InsertNewInstBefore(cast(In), CI); + } + + if (CI.getType() == In->getType()) + return ReplaceInstUsesWith(CI, In); + else + return CastInst::createIntegerCast(In, CI.getType(), false/*ZExt*/); + } + } + } + } + return 0; +} + +Instruction *InstCombiner::visitSExt(SExtInst &CI) { + if (Instruction *I = commonIntCastTransforms(CI)) + return I; + + Value *Src = CI.getOperand(0); + + // sext (x ashr x, 31 -> all ones if signed + // sext (x >s -1) -> ashr x, 31 -> all ones if not signed + if (ICmpInst *ICI = dyn_cast(Src)) { + // If we are just checking for a icmp eq of a single bit and zext'ing it + // to an integer, then shift the bit to the appropriate place and then + // cast to integer to avoid the comparison. + if (ConstantInt *Op1C = dyn_cast(ICI->getOperand(1))) { + const APInt &Op1CV = Op1C->getValue(); + + // sext (x x>>s31 true if signbit set. + // sext (x >s -1) to i32 --> (x>>s31)^-1 true if signbit clear. 
+      if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV == 0) ||
+          (ICI->getPredicate() == ICmpInst::ICMP_SGT &&Op1CV.isAllOnesValue())){
+        Value *In = ICI->getOperand(0);
+        Value *Sh = ConstantInt::get(In->getType(),
+                                     In->getType()->getPrimitiveSizeInBits()-1);
+        In = InsertNewInstBefore(BinaryOperator::createAShr(In, Sh,
+                                                        In->getName()+".lobit"),
+                                 CI);
+        if (In->getType() != CI.getType())
+          In = CastInst::createIntegerCast(In, CI.getType(),
+                                           true/*SExt*/, "tmp", &CI);
+
+        if (ICI->getPredicate() == ICmpInst::ICMP_SGT)
+          In = InsertNewInstBefore(BinaryOperator::createNot(In,
+                                     In->getName()+".not"), CI);
+
+        return ReplaceInstUsesWith(CI, In);
+      }
+    }
+  }
+
+  return 0;
+}
+
+// The FP and int/ptr conversion visitors below have no cast-specific folds;
+// they only run the shared cast transforms.
+Instruction *InstCombiner::visitFPTrunc(CastInst &CI) {
+  return commonCastTransforms(CI);
+}
+
+Instruction *InstCombiner::visitFPExt(CastInst &CI) {
+  return commonCastTransforms(CI);
+}
+
+Instruction *InstCombiner::visitFPToUI(CastInst &CI) {
+  return commonCastTransforms(CI);
+}
+
+Instruction *InstCombiner::visitFPToSI(CastInst &CI) {
+  return commonCastTransforms(CI);
+}
+
+Instruction *InstCombiner::visitUIToFP(CastInst &CI) {
+  return commonCastTransforms(CI);
+}
+
+Instruction *InstCombiner::visitSIToFP(CastInst &CI) {
+  return commonCastTransforms(CI);
+}
+
+Instruction *InstCombiner::visitPtrToInt(CastInst &CI) {
+  return commonPointerCastTransforms(CI);
+}
+
+Instruction *InstCombiner::visitIntToPtr(CastInst &CI) {
+  return commonCastTransforms(CI);
+}
+
+/// visitBitCast - Simplify a bitcast: eliminate no-op casts, rewrite casts of
+/// allocations, turn pointer bitcasts into GEPs, and fold casts of shuffles.
+Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
+  // If the operands are integer typed then apply the integer transforms,
+  // otherwise just apply the common ones.
+  Value *Src = CI.getOperand(0);
+  const Type *SrcTy = Src->getType();
+  const Type *DestTy = CI.getType();
+
+  if (SrcTy->isInteger() && DestTy->isInteger()) {
+    if (Instruction *Result = commonIntCastTransforms(CI))
+      return Result;
+  } else if (isa<PointerType>(SrcTy)) {
+    if (Instruction *I = commonPointerCastTransforms(CI))
+      return I;
+  } else {
+    if (Instruction *Result = commonCastTransforms(CI))
+      return Result;
+  }
+
+
+  // Get rid of casts from one type to the same type. These are useless and can
+  // be replaced by the operand.
+  if (DestTy == Src->getType())
+    return ReplaceInstUsesWith(CI, Src);
+
+  if (const PointerType *DstPTy = dyn_cast<PointerType>(DestTy)) {
+    const PointerType *SrcPTy = cast<PointerType>(SrcTy);
+    const Type *DstElTy = DstPTy->getElementType();
+    const Type *SrcElTy = SrcPTy->getElementType();
+
+    // If we are casting a malloc or alloca to a pointer to a type of the same
+    // size, rewrite the allocation instruction to allocate the "right" type.
+    if (AllocationInst *AI = dyn_cast<AllocationInst>(Src))
+      if (Instruction *V = PromoteCastOfAllocation(CI, *AI))
+        return V;
+
+    // If the source and destination are pointers, and this cast is equivalent
+    // to a getelementptr X, 0, 0, 0...  turn it into the appropriate gep.
+    // This can enhance SROA and other transforms that want type-safe pointers.
+    Constant *ZeroUInt = Constant::getNullValue(Type::Int32Ty);
+    unsigned NumZeros = 0;
+    while (SrcElTy != DstElTy &&
+           isa<CompositeType>(SrcElTy) && !isa<PointerType>(SrcElTy) &&
+           SrcElTy->getNumContainedTypes() /* not "{}" */) {
+      SrcElTy = cast<CompositeType>(SrcElTy)->getTypeAtIndex(ZeroUInt);
+      ++NumZeros;
+    }
+
+    // If we found a path from the src to dest, create the getelementptr now.
+    if (SrcElTy == DstElTy) {
+      SmallVector<Value*, 8> Idxs(NumZeros+1, ZeroUInt);
+      return new GetElementPtrInst(Src, &Idxs[0], Idxs.size());
+    }
+  }
+
+  if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(Src)) {
+    if (SVI->hasOneUse()) {
+      // Okay, we have (bitconvert (shuffle ..)).  Check to see if this is
+      // a bitconvert to a vector with the same # elts.
+      if (isa<VectorType>(DestTy) &&
+          cast<VectorType>(DestTy)->getNumElements() ==
+                SVI->getType()->getNumElements()) {
+        CastInst *Tmp;
+        // If either of the operands is a cast from CI.getType(), then
+        // evaluating the shuffle in the casted destination's type will allow
+        // us to eliminate at least one cast.
+        if (((Tmp = dyn_cast<CastInst>(SVI->getOperand(0))) &&
+             Tmp->getOperand(0)->getType() == DestTy) ||
+            ((Tmp = dyn_cast<CastInst>(SVI->getOperand(1))) &&
+             Tmp->getOperand(0)->getType() == DestTy)) {
+          Value *LHS = InsertOperandCastBefore(Instruction::BitCast,
+                                               SVI->getOperand(0), DestTy, &CI);
+          Value *RHS = InsertOperandCastBefore(Instruction::BitCast,
+                                               SVI->getOperand(1), DestTy, &CI);
+          // Return a new shuffle vector.  Use the same element ID's, as we
+          // know the vector types match #elts.
+          return new ShuffleVectorInst(LHS, RHS, SVI->getOperand(2));
+        }
+      }
+    }
+  }
+  return 0;
+}
+
+/// GetSelectFoldableOperands - We want to turn code that looks like this:
+///   %C = or %A, %B
+///   %D = select %cond, %C, %A
+/// into:
+///   %C = select %cond, %B, 0
+///   %D = or %A, %C
+///
+/// Assuming that the specified instruction is an operand to the select, return
+/// a bitmask indicating which operands of this instruction are foldable if they
+/// equal the other incoming value of the select.
+///
+static unsigned GetSelectFoldableOperands(Instruction *I) {
+  switch (I->getOpcode()) {
+  case Instruction::Add:
+  case Instruction::Mul:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+    return 3;              // Can fold through either operand.
+  case Instruction::Sub:   // Can only fold on the amount subtracted.
+  case Instruction::Shl:   // Can only fold on the shift amount.
+  case Instruction::LShr:
+  case Instruction::AShr:
+    return 1;
+  default:
+    return 0;              // Cannot fold
+  }
+}
+
+/// GetSelectFoldableConstant - For the same transformation as the previous
+/// function, return the identity constant that goes into the select.
+static Constant *GetSelectFoldableConstant(Instruction *I) {
+  switch (I->getOpcode()) {
+  default: assert(0 && "This cannot happen!"); abort();
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Or:
+  case Instruction::Xor:
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+    return Constant::getNullValue(I->getType());
+  case Instruction::And:
+    return Constant::getAllOnesValue(I->getType());
+  case Instruction::Mul:
+    return ConstantInt::get(I->getType(), 1);
+  }
+}
+
+/// FoldSelectOpOp - Here we have (select c, TI, FI), and we know that TI and FI
+/// have the same opcode and only one use each.  Try to simplify this.
+Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI,
+                                          Instruction *FI) {
+  if (TI->getNumOperands() == 1) {
+    // If this is a non-volatile load or a cast from the same type,
+    // merge.
+    if (TI->isCast()) {
+      if (TI->getOperand(0)->getType() != FI->getOperand(0)->getType())
+        return 0;
+    } else {
+      return 0;  // unknown unary op.
+    }
+
+    // Fold this by inserting a select from the input values.
+    SelectInst *NewSI = new SelectInst(SI.getCondition(), TI->getOperand(0),
+                                       FI->getOperand(0), SI.getName()+".v");
+    InsertNewInstBefore(NewSI, SI);
+    return CastInst::create(Instruction::CastOps(TI->getOpcode()), NewSI,
+                            TI->getType());
+  }
+
+  // Only handle binary operators here.
+  if (!isa<BinaryOperator>(TI))
+    return 0;
+
+  // Figure out if the operations have any operands in common.
+  Value *MatchOp, *OtherOpT, *OtherOpF;
+  bool MatchIsOpZero;
+  if (TI->getOperand(0) == FI->getOperand(0)) {
+    MatchOp  = TI->getOperand(0);
+    OtherOpT = TI->getOperand(1);
+    OtherOpF = FI->getOperand(1);
+    MatchIsOpZero = true;
+  } else if (TI->getOperand(1) == FI->getOperand(1)) {
+    MatchOp  = TI->getOperand(1);
+    OtherOpT = TI->getOperand(0);
+    OtherOpF = FI->getOperand(0);
+    MatchIsOpZero = false;
+  } else if (!TI->isCommutative()) {
+    return 0;
+  } else if (TI->getOperand(0) == FI->getOperand(1)) {
+    MatchOp  = TI->getOperand(0);
+    OtherOpT = TI->getOperand(1);
+    OtherOpF = FI->getOperand(0);
+    MatchIsOpZero = true;
+  } else if (TI->getOperand(1) == FI->getOperand(0)) {
+    MatchOp  = TI->getOperand(1);
+    OtherOpT = TI->getOperand(0);
+    OtherOpF = FI->getOperand(1);
+    MatchIsOpZero = true;
+  } else {
+    return 0;
+  }
+
+  // If we reach here, they do have operations in common.
+  SelectInst *NewSI = new SelectInst(SI.getCondition(), OtherOpT,
+                                     OtherOpF, SI.getName()+".v");
+  InsertNewInstBefore(NewSI, SI);
+
+  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(TI)) {
+    if (MatchIsOpZero)
+      return BinaryOperator::create(BO->getOpcode(), MatchOp, NewSI);
+    else
+      return BinaryOperator::create(BO->getOpcode(), NewSI, MatchOp);
+  }
+  assert(0 && "Shouldn't get here");
+  return 0;
+}
+
+Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
+  Value *CondVal = SI.getCondition();
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+
+  // select true, X, Y  -> X
+  // select false, X, Y -> Y
+  if (ConstantInt *C = dyn_cast<ConstantInt>(CondVal))
+    return ReplaceInstUsesWith(SI, C->getZExtValue() ?
+                                    TrueVal : FalseVal);
+
+  // select C, X, X -> X
+  if (TrueVal == FalseVal)
+    return ReplaceInstUsesWith(SI, TrueVal);
+
+  if (isa<UndefValue>(TrueVal))   // select C, undef, X -> X
+    return ReplaceInstUsesWith(SI, FalseVal);
+  if (isa<UndefValue>(FalseVal))   // select C, X, undef -> X
+    return ReplaceInstUsesWith(SI, TrueVal);
+  if (isa<UndefValue>(CondVal)) {  // select undef, X, Y -> X or Y
+    if (isa<Constant>(TrueVal))
+      return ReplaceInstUsesWith(SI, TrueVal);
+    else
+      return ReplaceInstUsesWith(SI, FalseVal);
+  }
+
+  if (SI.getType() == Type::Int1Ty) {
+    if (ConstantInt *C = dyn_cast<ConstantInt>(TrueVal)) {
+      if (C->getZExtValue()) {
+        // Change: A = select B, true, C --> A = or B, C
+        return BinaryOperator::createOr(CondVal, FalseVal);
+      } else {
+        // Change: A = select B, false, C --> A = and !B, C
+        Value *NotCond =
+          InsertNewInstBefore(BinaryOperator::createNot(CondVal,
+                                             "not."+CondVal->getName()), SI);
+        return BinaryOperator::createAnd(NotCond, FalseVal);
+      }
+    } else if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) {
+      if (C->getZExtValue() == false) {
+        // Change: A = select B, C, false --> A = and B, C
+        return BinaryOperator::createAnd(CondVal, TrueVal);
+      } else {
+        // Change: A = select B, C, true --> A = or !B, C
+        Value *NotCond =
+          InsertNewInstBefore(BinaryOperator::createNot(CondVal,
+                                             "not."+CondVal->getName()), SI);
+        return BinaryOperator::createOr(NotCond, TrueVal);
+      }
+    }
+  }
+
+  // Selecting between two integer constants?
+  if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal))
+    if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) {
+      // select C, 1, 0 -> zext C to int
+      if (FalseValC->isZero() && TrueValC->getValue() == 1) {
+        return CastInst::create(Instruction::ZExt, CondVal, SI.getType());
+      } else if (TrueValC->isZero() && FalseValC->getValue() == 1) {
+        // select C, 0, 1 -> zext !C to int
+        Value *NotCond =
+          InsertNewInstBefore(BinaryOperator::createNot(CondVal,
+                                               "not."+CondVal->getName()), SI);
+        return CastInst::create(Instruction::ZExt, NotCond, SI.getType());
+      }
+
+      // FIXME: Turn select 0/-1 and -1/0 into sext from condition!
+
+      if (ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition())) {
+
+        // (x <s 0) ? -1 : 0 -> ashr x, 31
+        if (TrueValC->isAllOnesValue() && FalseValC->isZero())
+          if (ConstantInt *CmpCst = dyn_cast<ConstantInt>(IC->getOperand(1))) {
+            if (IC->getPredicate() == ICmpInst::ICMP_SLT && CmpCst->isZero()) {
+              // The comparison constant and the result are not neccessarily the
+              // same width. Make an all-ones value by inserting a AShr.
+              Value *X = IC->getOperand(0);
+              uint32_t Bits = X->getType()->getPrimitiveSizeInBits();
+              Constant *ShAmt = ConstantInt::get(X->getType(), Bits-1);
+              Instruction *SRA = BinaryOperator::create(Instruction::AShr, X,
+                                                        ShAmt, "ones");
+              InsertNewInstBefore(SRA, SI);
+
+              // Finally, convert to the type of the select RHS.  We figure out
+              // if this requires a SExt, Trunc or BitCast based on the sizes.
+              Instruction::CastOps opc = Instruction::BitCast;
+              uint32_t SRASize = SRA->getType()->getPrimitiveSizeInBits();
+              uint32_t SISize  = SI.getType()->getPrimitiveSizeInBits();
+              if (SRASize < SISize)
+                opc = Instruction::SExt;
+              else if (SRASize > SISize)
+                opc = Instruction::Trunc;
+              return CastInst::create(opc, SRA, SI.getType());
+            }
+          }
+
+
+        // If one of the constants is zero (we know they can't both be) and we
+        // have an icmp instruction with zero, and we have an 'and' with the
+        // non-constant value, eliminate this whole mess.  This corresponds to
+        // cases like this: ((X & 27) ? 27 : 0)
+        if (TrueValC->isZero() || FalseValC->isZero())
+          if (IC->isEquality() && isa<Constant>(IC->getOperand(1)) &&
+              cast<Constant>(IC->getOperand(1))->isNullValue())
+            if (Instruction *ICA = dyn_cast<Instruction>(IC->getOperand(0)))
+              if (ICA->getOpcode() == Instruction::And &&
+                  isa<ConstantInt>(ICA->getOperand(1)) &&
+                  (ICA->getOperand(1) == TrueValC ||
+                   ICA->getOperand(1) == FalseValC) &&
+                  isOneBitSet(cast<ConstantInt>(ICA->getOperand(1)))) {
+                // Okay, now we know that everything is set up, we just don't
+                // know whether we have a icmp_ne or icmp_eq and whether the
+                // true or false val is the zero.
+                bool ShouldNotVal = !TrueValC->isZero();
+                ShouldNotVal ^= IC->getPredicate() == ICmpInst::ICMP_NE;
+                Value *V = ICA;
+                if (ShouldNotVal)
+                  V = InsertNewInstBefore(BinaryOperator::create(
+                                  Instruction::Xor, V, ICA->getOperand(1)), SI);
+                return ReplaceInstUsesWith(SI, V);
+              }
+      }
+    }
+
+  // See if we are selecting two values based on a comparison of the two values.
+  if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) {
+    if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) {
+      // Transform (X == Y) ? X : Y  -> Y
+      if (FCI->getPredicate() == FCmpInst::FCMP_OEQ)
+        return ReplaceInstUsesWith(SI, FalseVal);
+      // Transform (X != Y) ? X : Y  -> X
+      if (FCI->getPredicate() == FCmpInst::FCMP_ONE)
+        return ReplaceInstUsesWith(SI, TrueVal);
+      // NOTE: if we wanted to, this is where to detect MIN/MAX/ABS/etc.
+
+    } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){
+      // Transform (X == Y) ? Y : X  -> X
+      if (FCI->getPredicate() == FCmpInst::FCMP_OEQ)
+        return ReplaceInstUsesWith(SI, FalseVal);
+      // Transform (X != Y) ? Y : X  -> Y
+      if (FCI->getPredicate() == FCmpInst::FCMP_ONE)
+        return ReplaceInstUsesWith(SI, TrueVal);
+      // NOTE: if we wanted to, this is where to detect MIN/MAX/ABS/etc.
+    }
+  }
+
+  // See if we are selecting two values based on a comparison of the two values.
+ if (ICmpInst *ICI = dyn_cast(CondVal)) { + if (ICI->getOperand(0) == TrueVal && ICI->getOperand(1) == FalseVal) { + // Transform (X == Y) ? X : Y -> Y + if (ICI->getPredicate() == ICmpInst::ICMP_EQ) + return ReplaceInstUsesWith(SI, FalseVal); + // Transform (X != Y) ? X : Y -> X + if (ICI->getPredicate() == ICmpInst::ICMP_NE) + return ReplaceInstUsesWith(SI, TrueVal); + // NOTE: if we wanted to, this is where to detect MIN/MAX/ABS/etc. + + } else if (ICI->getOperand(0) == FalseVal && ICI->getOperand(1) == TrueVal){ + // Transform (X == Y) ? Y : X -> X + if (ICI->getPredicate() == ICmpInst::ICMP_EQ) + return ReplaceInstUsesWith(SI, FalseVal); + // Transform (X != Y) ? Y : X -> Y + if (ICI->getPredicate() == ICmpInst::ICMP_NE) + return ReplaceInstUsesWith(SI, TrueVal); + // NOTE: if we wanted to, this is where to detect MIN/MAX/ABS/etc. + } + } + + if (Instruction *TI = dyn_cast(TrueVal)) + if (Instruction *FI = dyn_cast(FalseVal)) + if (TI->hasOneUse() && FI->hasOneUse()) { + Instruction *AddOp = 0, *SubOp = 0; + + // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) + if (TI->getOpcode() == FI->getOpcode()) + if (Instruction *IV = FoldSelectOpOp(SI, TI, FI)) + return IV; + + // Turn select C, (X+Y), (X-Y) --> (X+(select C, Y, (-Y))). This is + // even legal for FP. 
+ if (TI->getOpcode() == Instruction::Sub && + FI->getOpcode() == Instruction::Add) { + AddOp = FI; SubOp = TI; + } else if (FI->getOpcode() == Instruction::Sub && + TI->getOpcode() == Instruction::Add) { + AddOp = TI; SubOp = FI; + } + + if (AddOp) { + Value *OtherAddOp = 0; + if (SubOp->getOperand(0) == AddOp->getOperand(0)) { + OtherAddOp = AddOp->getOperand(1); + } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) { + OtherAddOp = AddOp->getOperand(0); + } + + if (OtherAddOp) { + // So at this point we know we have (Y -> OtherAddOp): + // select C, (add X, Y), (sub X, Z) + Value *NegVal; // Compute -Z + if (Constant *C = dyn_cast(SubOp->getOperand(1))) { + NegVal = ConstantExpr::getNeg(C); + } else { + NegVal = InsertNewInstBefore( + BinaryOperator::createNeg(SubOp->getOperand(1), "tmp"), SI); + } + + Value *NewTrueOp = OtherAddOp; + Value *NewFalseOp = NegVal; + if (AddOp != TI) + std::swap(NewTrueOp, NewFalseOp); + Instruction *NewSel = + new SelectInst(CondVal, NewTrueOp,NewFalseOp,SI.getName()+".p"); + + NewSel = InsertNewInstBefore(NewSel, SI); + return BinaryOperator::createAdd(SubOp->getOperand(0), NewSel); + } + } + } + + // See if we can fold the select into one of our operands. + if (SI.getType()->isInteger()) { + // See the comment above GetSelectFoldableOperands for a description of the + // transformation we are doing here. 
+ if (Instruction *TVI = dyn_cast(TrueVal)) + if (TVI->hasOneUse() && TVI->getNumOperands() == 2 && + !isa(FalseVal)) + if (unsigned SFO = GetSelectFoldableOperands(TVI)) { + unsigned OpToFold = 0; + if ((SFO & 1) && FalseVal == TVI->getOperand(0)) { + OpToFold = 1; + } else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) { + OpToFold = 2; + } + + if (OpToFold) { + Constant *C = GetSelectFoldableConstant(TVI); + Instruction *NewSel = + new SelectInst(SI.getCondition(), TVI->getOperand(2-OpToFold), C); + InsertNewInstBefore(NewSel, SI); + NewSel->takeName(TVI); + if (BinaryOperator *BO = dyn_cast(TVI)) + return BinaryOperator::create(BO->getOpcode(), FalseVal, NewSel); + else { + assert(0 && "Unknown instruction!!"); + } + } + } + + if (Instruction *FVI = dyn_cast(FalseVal)) + if (FVI->hasOneUse() && FVI->getNumOperands() == 2 && + !isa(TrueVal)) + if (unsigned SFO = GetSelectFoldableOperands(FVI)) { + unsigned OpToFold = 0; + if ((SFO & 1) && TrueVal == FVI->getOperand(0)) { + OpToFold = 1; + } else if ((SFO & 2) && TrueVal == FVI->getOperand(1)) { + OpToFold = 2; + } + + if (OpToFold) { + Constant *C = GetSelectFoldableConstant(FVI); + Instruction *NewSel = + new SelectInst(SI.getCondition(), C, FVI->getOperand(2-OpToFold)); + InsertNewInstBefore(NewSel, SI); + NewSel->takeName(FVI); + if (BinaryOperator *BO = dyn_cast(FVI)) + return BinaryOperator::create(BO->getOpcode(), TrueVal, NewSel); + else + assert(0 && "Unknown instruction!!"); + } + } + } + + if (BinaryOperator::isNot(CondVal)) { + SI.setOperand(0, BinaryOperator::getNotArgument(CondVal)); + SI.setOperand(1, FalseVal); + SI.setOperand(2, TrueVal); + return &SI; + } + + return 0; +} + +/// GetKnownAlignment - If the specified pointer has an alignment that we can +/// determine, return it, otherwise return 0. 
+static unsigned GetKnownAlignment(Value *V, TargetData *TD) {
+  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
+    unsigned Align = GV->getAlignment();
+    if (Align == 0 && TD)
+      Align = TD->getPrefTypeAlignment(GV->getType()->getElementType());
+    return Align;
+  } else if (AllocationInst *AI = dyn_cast<AllocationInst>(V)) {
+    unsigned Align = AI->getAlignment();
+    if (Align == 0 && TD) {
+      if (isa<AllocaInst>(AI))
+        Align = TD->getPrefTypeAlignment(AI->getType()->getElementType());
+      else if (isa<MallocInst>(AI)) {
+        // Malloc returns maximally aligned memory.
+        Align = TD->getABITypeAlignment(AI->getType()->getElementType());
+        Align =
+          std::max(Align,
+                   (unsigned)TD->getABITypeAlignment(Type::DoubleTy));
+        Align =
+          std::max(Align,
+                   (unsigned)TD->getABITypeAlignment(Type::Int64Ty));
+      }
+    }
+    return Align;
+  } else if (isa<BitCastInst>(V) ||
+             (isa<ConstantExpr>(V) &&
+              cast<ConstantExpr>(V)->getOpcode() == Instruction::BitCast)) {
+    // A pointer-to-pointer bitcast preserves the alignment of the source.
+    User *CI = cast<User>(V);
+    if (isa<PointerType>(CI->getOperand(0)->getType()))
+      return GetKnownAlignment(CI->getOperand(0), TD);
+    return 0;
+  } else if (User *GEPI = dyn_castGetElementPtr(V)) {
+    unsigned BaseAlignment = GetKnownAlignment(GEPI->getOperand(0), TD);
+    if (BaseAlignment == 0) return 0;
+
+    // If all indexes are zero, it is just the alignment of the base pointer.
+    bool AllZeroOperands = true;
+    for (unsigned i = 1, e = GEPI->getNumOperands(); i != e; ++i)
+      if (!isa<Constant>(GEPI->getOperand(i)) ||
+          !cast<Constant>(GEPI->getOperand(i))->isNullValue()) {
+        AllZeroOperands = false;
+        break;
+      }
+    if (AllZeroOperands)
+      return BaseAlignment;
+
+    // Otherwise, if the base alignment is >= the alignment we expect for the
+    // base pointer type, then we know that the resultant pointer is aligned at
+    // least as much as its type requires.
+    if (!TD) return 0;
+
+    const Type *BasePtrTy = GEPI->getOperand(0)->getType();
+    const PointerType *PtrTy = cast<PointerType>(BasePtrTy);
+    if (TD->getABITypeAlignment(PtrTy->getElementType())
+        <= BaseAlignment) {
+      const Type *GEPTy = GEPI->getType();
+      const PointerType *GEPPtrTy = cast<PointerType>(GEPTy);
+      return TD->getABITypeAlignment(GEPPtrTy->getElementType());
+    }
+    return 0;
+  }
+  return 0;
+}
+
+
+/// visitCallInst - CallInst simplification.  This mostly only handles folding
+/// of intrinsic instructions.  For normal calls, it allows visitCallSite to do
+/// the heavy lifting.
+///
+Instruction *InstCombiner::visitCallInst(CallInst &CI) {
+  IntrinsicInst *II = dyn_cast<IntrinsicInst>(&CI);
+  if (!II) return visitCallSite(&CI);
+
+  // Intrinsics cannot occur in an invoke, so handle them here instead of in
+  // visitCallSite.
+  if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(II)) {
+    bool Changed = false;
+
+    // memmove/cpy/set of zero bytes is a noop.
+    if (Constant *NumBytes = dyn_cast<Constant>(MI->getLength())) {
+      if (NumBytes->isNullValue()) return EraseInstFromFunction(CI);
+
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(NumBytes))
+        if (CI->getZExtValue() == 1) {
+          // Replace the instruction with just byte operations.  We would
+          // transform other cases to loads/stores, but we don't know if
+          // alignment is sufficient.
+        }
+    }
+
+    // If we have a memmove and the source operation is a constant global,
+    // then the source and dest pointers can't alias, so we can change this
+    // into a call to memcpy.
+ if (MemMoveInst *MMI = dyn_cast(II)) { + if (GlobalVariable *GVSrc = dyn_cast(MMI->getSource())) + if (GVSrc->isConstant()) { + Module *M = CI.getParent()->getParent()->getParent(); + const char *Name; + if (CI.getCalledFunction()->getFunctionType()->getParamType(2) == + Type::Int32Ty) + Name = "llvm.memcpy.i32"; + else + Name = "llvm.memcpy.i64"; + Constant *MemCpy = M->getOrInsertFunction(Name, + CI.getCalledFunction()->getFunctionType()); + CI.setOperand(0, MemCpy); + Changed = true; + } + } + + // If we can determine a pointer alignment that is bigger than currently + // set, update the alignment. + if (isa(MI) || isa(MI)) { + unsigned Alignment1 = GetKnownAlignment(MI->getOperand(1), TD); + unsigned Alignment2 = GetKnownAlignment(MI->getOperand(2), TD); + unsigned Align = std::min(Alignment1, Alignment2); + if (MI->getAlignment()->getZExtValue() < Align) { + MI->setAlignment(ConstantInt::get(Type::Int32Ty, Align)); + Changed = true; + } + } else if (isa(MI)) { + unsigned Alignment = GetKnownAlignment(MI->getDest(), TD); + if (MI->getAlignment()->getZExtValue() < Alignment) { + MI->setAlignment(ConstantInt::get(Type::Int32Ty, Alignment)); + Changed = true; + } + } + + if (Changed) return II; + } else { + switch (II->getIntrinsicID()) { + default: break; + case Intrinsic::ppc_altivec_lvx: + case Intrinsic::ppc_altivec_lvxl: + case Intrinsic::x86_sse_loadu_ps: + case Intrinsic::x86_sse2_loadu_pd: + case Intrinsic::x86_sse2_loadu_dq: + // Turn PPC lvx -> load if the pointer is known aligned. + // Turn X86 loadups -> load if the pointer is known aligned. + if (GetKnownAlignment(II->getOperand(1), TD) >= 16) { + Value *Ptr = InsertCastBefore(Instruction::BitCast, II->getOperand(1), + PointerType::get(II->getType()), CI); + return new LoadInst(Ptr); + } + break; + case Intrinsic::ppc_altivec_stvx: + case Intrinsic::ppc_altivec_stvxl: + // Turn stvx -> store if the pointer is known aligned. 
+ if (GetKnownAlignment(II->getOperand(2), TD) >= 16) { + const Type *OpPtrTy = PointerType::get(II->getOperand(1)->getType()); + Value *Ptr = InsertCastBefore(Instruction::BitCast, II->getOperand(2), + OpPtrTy, CI); + return new StoreInst(II->getOperand(1), Ptr); + } + break; + case Intrinsic::x86_sse_storeu_ps: + case Intrinsic::x86_sse2_storeu_pd: + case Intrinsic::x86_sse2_storeu_dq: + case Intrinsic::x86_sse2_storel_dq: + // Turn X86 storeu -> store if the pointer is known aligned. + if (GetKnownAlignment(II->getOperand(1), TD) >= 16) { + const Type *OpPtrTy = PointerType::get(II->getOperand(2)->getType()); + Value *Ptr = InsertCastBefore(Instruction::BitCast, II->getOperand(1), + OpPtrTy, CI); + return new StoreInst(II->getOperand(2), Ptr); + } + break; + + case Intrinsic::x86_sse_cvttss2si: { + // These intrinsics only demands the 0th element of its input vector. If + // we can simplify the input based on that, do so now. + uint64_t UndefElts; + if (Value *V = SimplifyDemandedVectorElts(II->getOperand(1), 1, + UndefElts)) { + II->setOperand(1, V); + return II; + } + break; + } + + case Intrinsic::ppc_altivec_vperm: + // Turn vperm(V1,V2,mask) -> shuffle(V1,V2,mask) if mask is a constant. + if (ConstantVector *Mask = dyn_cast(II->getOperand(3))) { + assert(Mask->getNumOperands() == 16 && "Bad type for intrinsic!"); + + // Check that all of the elements are integer constants or undefs. + bool AllEltsOk = true; + for (unsigned i = 0; i != 16; ++i) { + if (!isa(Mask->getOperand(i)) && + !isa(Mask->getOperand(i))) { + AllEltsOk = false; + break; + } + } + + if (AllEltsOk) { + // Cast the input vectors to byte vectors. + Value *Op0 = InsertCastBefore(Instruction::BitCast, + II->getOperand(1), Mask->getType(), CI); + Value *Op1 = InsertCastBefore(Instruction::BitCast, + II->getOperand(2), Mask->getType(), CI); + Value *Result = UndefValue::get(Op0->getType()); + + // Only extract each element once. 
+ Value *ExtractedElts[32]; + memset(ExtractedElts, 0, sizeof(ExtractedElts)); + + for (unsigned i = 0; i != 16; ++i) { + if (isa(Mask->getOperand(i))) + continue; + unsigned Idx=cast(Mask->getOperand(i))->getZExtValue(); + Idx &= 31; // Match the hardware behavior. + + if (ExtractedElts[Idx] == 0) { + Instruction *Elt = + new ExtractElementInst(Idx < 16 ? Op0 : Op1, Idx&15, "tmp"); + InsertNewInstBefore(Elt, CI); + ExtractedElts[Idx] = Elt; + } + + // Insert this value into the result vector. + Result = new InsertElementInst(Result, ExtractedElts[Idx], i,"tmp"); + InsertNewInstBefore(cast(Result), CI); + } + return CastInst::create(Instruction::BitCast, Result, CI.getType()); + } + } + break; + + case Intrinsic::stackrestore: { + // If the save is right next to the restore, remove the restore. This can + // happen when variable allocas are DCE'd. + if (IntrinsicInst *SS = dyn_cast(II->getOperand(1))) { + if (SS->getIntrinsicID() == Intrinsic::stacksave) { + BasicBlock::iterator BI = SS; + if (&*++BI == II) + return EraseInstFromFunction(CI); + } + } + + // If the stack restore is in a return/unwind block and if there are no + // allocas or calls between the restore and the return, nuke the restore. + TerminatorInst *TI = II->getParent()->getTerminator(); + if (isa(TI) || isa(TI)) { + BasicBlock::iterator BI = II; + bool CannotRemove = false; + for (++BI; &*BI != TI; ++BI) { + if (isa(BI) || + (isa(BI) && !isa(BI))) { + CannotRemove = true; + break; + } + } + if (!CannotRemove) + return EraseInstFromFunction(CI); + } + break; + } + } + } + + return visitCallSite(II); +} + +// InvokeInst simplification +// +Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) { + return visitCallSite(&II); +} + +// visitCallSite - Improvements for call and invoke instructions. 
+//
+Instruction *InstCombiner::visitCallSite(CallSite CS) {
+  bool Changed = false;
+
+  // If the callee is a constexpr cast of a function, attempt to move the cast
+  // to the arguments of the call/invoke.
+  if (transformConstExprCastCall(CS)) return 0;
+
+  Value *Callee = CS.getCalledValue();
+
+  if (Function *CalleeF = dyn_cast<Function>(Callee))
+    if (CalleeF->getCallingConv() != CS.getCallingConv()) {
+      Instruction *OldCall = CS.getInstruction();
+      // If the call and callee calling conventions don't match, this call must
+      // be unreachable, as the call is undefined.
+      new StoreInst(ConstantInt::getTrue(),
+                    UndefValue::get(PointerType::get(Type::Int1Ty)), OldCall);
+      if (!OldCall->use_empty())
+        OldCall->replaceAllUsesWith(UndefValue::get(OldCall->getType()));
+      if (isa<CallInst>(OldCall))   // Not worth removing an invoke here.
+        return EraseInstFromFunction(*OldCall);
+      return 0;
+    }
+
+  if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) {
+    // This instruction is not reachable, just remove it.  We insert a store to
+    // undef so that we know that this code is not reachable, despite the fact
+    // that we can't modify the CFG here.
+    new StoreInst(ConstantInt::getTrue(),
+                  UndefValue::get(PointerType::get(Type::Int1Ty)),
+                  CS.getInstruction());
+
+    if (!CS.getInstruction()->use_empty())
+      CS.getInstruction()->
+        replaceAllUsesWith(UndefValue::get(CS.getInstruction()->getType()));
+
+    if (InvokeInst *II = dyn_cast<InvokeInst>(CS.getInstruction())) {
+      // Don't break the CFG, insert a dummy cond branch.
+      new BranchInst(II->getNormalDest(), II->getUnwindDest(),
+                     ConstantInt::getTrue(), II);
+    }
+    return EraseInstFromFunction(*CS.getInstruction());
+  }
+
+  const PointerType *PTy = cast<PointerType>(Callee->getType());
+  const FunctionType *FTy = cast<FunctionType>(PTy->getElementType());
+  if (FTy->isVarArg()) {
+    // See if we can optimize any arguments passed through the varargs area of
+    // the call.
+    for (CallSite::arg_iterator I = CS.arg_begin()+FTy->getNumParams(),
+           E = CS.arg_end(); I != E; ++I)
+      if (CastInst *CI = dyn_cast<CastInst>(*I)) {
+        // If this cast does not effect the value passed through the varargs
+        // area, we can eliminate the use of the cast.
+        Value *Op = CI->getOperand(0);
+        if (CI->isLosslessCast()) {
+          *I = Op;
+          Changed = true;
+        }
+      }
+  }
+
+  return Changed ? CS.getInstruction() : 0;
+}
+
+// transformConstExprCastCall - If the callee is a constexpr cast of a function,
+// attempt to move the cast to the arguments of the call/invoke.
+//
+bool InstCombiner::transformConstExprCastCall(CallSite CS) {
+  if (!isa<ConstantExpr>(CS.getCalledValue())) return false;
+  ConstantExpr *CE = cast<ConstantExpr>(CS.getCalledValue());
+  if (CE->getOpcode() != Instruction::BitCast ||
+      !isa<Function>(CE->getOperand(0)))
+    return false;
+  Function *Callee = cast<Function>(CE->getOperand(0));
+  Instruction *Caller = CS.getInstruction();
+
+  // Okay, this is a cast from a function to a different type.  Unless doing so
+  // would cause a type conversion of one of our arguments, change this call to
+  // be a direct call with arguments casted to the appropriate types.
+  //
+  const FunctionType *FT = Callee->getFunctionType();
+  const Type *OldRetTy = Caller->getType();
+
+  const FunctionType *ActualFT =
+    cast<FunctionType>(cast<PointerType>(CE->getType())->getElementType());
+
+  // If the parameter attributes don't match up, don't do the xform.  We don't
+  // want to lose an sret attribute or something.
+  if (FT->getParamAttrs() != ActualFT->getParamAttrs())
+    return false;
+
+  // Check to see if we are changing the return type...
+  if (OldRetTy != FT->getReturnType()) {
+    if (Callee->isDeclaration() && !Caller->use_empty() &&
+        // Conversion is ok if changing from pointer to int of same size.
+        !(isa<PointerType>(FT->getReturnType()) &&
+          TD->getIntPtrType() == OldRetTy))
+      return false;   // Cannot transform this return value.
+ + // If the callsite is an invoke instruction, and the return value is used by + // a PHI node in a successor, we cannot change the return type of the call + // because there is no place to put the cast instruction (without breaking + // the critical edge). Bail out in this case. + if (!Caller->use_empty()) + if (InvokeInst *II = dyn_cast(Caller)) + for (Value::use_iterator UI = II->use_begin(), E = II->use_end(); + UI != E; ++UI) + if (PHINode *PN = dyn_cast(*UI)) + if (PN->getParent() == II->getNormalDest() || + PN->getParent() == II->getUnwindDest()) + return false; + } + + unsigned NumActualArgs = unsigned(CS.arg_end()-CS.arg_begin()); + unsigned NumCommonArgs = std::min(FT->getNumParams(), NumActualArgs); + + CallSite::arg_iterator AI = CS.arg_begin(); + for (unsigned i = 0, e = NumCommonArgs; i != e; ++i, ++AI) { + const Type *ParamTy = FT->getParamType(i); + const Type *ActTy = (*AI)->getType(); + ConstantInt *c = dyn_cast(*AI); + //Some conversions are safe even if we do not have a body. + //Either we can cast directly, or we can upconvert the argument + bool isConvertible = ActTy == ParamTy || + (isa(ParamTy) && isa(ActTy)) || + (ParamTy->isInteger() && ActTy->isInteger() && + ParamTy->getPrimitiveSizeInBits() >= ActTy->getPrimitiveSizeInBits()) || + (c && ParamTy->getPrimitiveSizeInBits() >= ActTy->getPrimitiveSizeInBits() + && c->getValue().isStrictlyPositive()); + if (Callee->isDeclaration() && !isConvertible) return false; + + // Most other conversions can be done if we have a body, even if these + // lose information, e.g. int->short. + // Some conversions cannot be done at all, e.g. float to pointer. + // Logic here parallels CastInst::getCastOpcode (the design there + // requires legality checks like this be done before calling it). 
+ if (ParamTy->isInteger()) { + if (const VectorType *VActTy = dyn_cast(ActTy)) { + if (VActTy->getBitWidth() != ParamTy->getPrimitiveSizeInBits()) + return false; + } + if (!ActTy->isInteger() && !ActTy->isFloatingPoint() && + !isa(ActTy)) + return false; + } else if (ParamTy->isFloatingPoint()) { + if (const VectorType *VActTy = dyn_cast(ActTy)) { + if (VActTy->getBitWidth() != ParamTy->getPrimitiveSizeInBits()) + return false; + } + if (!ActTy->isInteger() && !ActTy->isFloatingPoint()) + return false; + } else if (const VectorType *VParamTy = dyn_cast(ParamTy)) { + if (const VectorType *VActTy = dyn_cast(ActTy)) { + if (VActTy->getBitWidth() != VParamTy->getBitWidth()) + return false; + } + if (VParamTy->getBitWidth() != ActTy->getPrimitiveSizeInBits()) + return false; + } else if (isa(ParamTy)) { + if (!ActTy->isInteger() && !isa(ActTy)) + return false; + } else { + return false; + } + } + + if (FT->getNumParams() < NumActualArgs && !FT->isVarArg() && + Callee->isDeclaration()) + return false; // Do not delete arguments unless we have a function body... + + // Okay, we decided that this is a safe thing to do: go ahead and start + // inserting cast instructions as necessary... + std::vector Args; + Args.reserve(NumActualArgs); + + AI = CS.arg_begin(); + for (unsigned i = 0; i != NumCommonArgs; ++i, ++AI) { + const Type *ParamTy = FT->getParamType(i); + if ((*AI)->getType() == ParamTy) { + Args.push_back(*AI); + } else { + Instruction::CastOps opcode = CastInst::getCastOpcode(*AI, + false, ParamTy, false); + CastInst *NewCast = CastInst::create(opcode, *AI, ParamTy, "tmp"); + Args.push_back(InsertNewInstBefore(NewCast, *Caller)); + } + } + + // If the function takes more arguments than the call was taking, add them + // now... + for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i) + Args.push_back(Constant::getNullValue(FT->getParamType(i))); + + // If we are removing arguments to the function, emit an obnoxious warning... 
+ if (FT->getNumParams() < NumActualArgs) + if (!FT->isVarArg()) { + cerr << "WARNING: While resolving call to function '" + << Callee->getName() << "' arguments were dropped!\n"; + } else { + // Add all of the arguments in their promoted form to the arg list... + for (unsigned i = FT->getNumParams(); i != NumActualArgs; ++i, ++AI) { + const Type *PTy = getPromotedType((*AI)->getType()); + if (PTy != (*AI)->getType()) { + // Must promote to pass through va_arg area! + Instruction::CastOps opcode = CastInst::getCastOpcode(*AI, false, + PTy, false); + Instruction *Cast = CastInst::create(opcode, *AI, PTy, "tmp"); + InsertNewInstBefore(Cast, *Caller); + Args.push_back(Cast); + } else { + Args.push_back(*AI); + } + } + } + + if (FT->getReturnType() == Type::VoidTy) + Caller->setName(""); // Void type should not have a name. + + Instruction *NC; + if (InvokeInst *II = dyn_cast(Caller)) { + NC = new InvokeInst(Callee, II->getNormalDest(), II->getUnwindDest(), + &Args[0], Args.size(), Caller->getName(), Caller); + cast(II)->setCallingConv(II->getCallingConv()); + } else { + NC = new CallInst(Callee, &Args[0], Args.size(), Caller->getName(), Caller); + if (cast(Caller)->isTailCall()) + cast(NC)->setTailCall(); + cast(NC)->setCallingConv(cast(Caller)->getCallingConv()); + } + + // Insert a cast of the return type as necessary. + Value *NV = NC; + if (Caller->getType() != NV->getType() && !Caller->use_empty()) { + if (NV->getType() != Type::VoidTy) { + const Type *CallerTy = Caller->getType(); + Instruction::CastOps opcode = CastInst::getCastOpcode(NC, false, + CallerTy, false); + NV = NC = CastInst::create(opcode, NC, CallerTy, "tmp"); + + // If this is an invoke instruction, we should insert it after the first + // non-phi, instruction in the normal successor block. 
+ if (InvokeInst *II = dyn_cast(Caller)) { + BasicBlock::iterator I = II->getNormalDest()->begin(); + while (isa(I)) ++I; + InsertNewInstBefore(NC, *I); + } else { + // Otherwise, it's a call, just insert cast right after the call instr + InsertNewInstBefore(NC, *Caller); + } + AddUsersToWorkList(*Caller); + } else { + NV = UndefValue::get(Caller->getType()); + } + } + + if (Caller->getType() != Type::VoidTy && !Caller->use_empty()) + Caller->replaceAllUsesWith(NV); + Caller->eraseFromParent(); + RemoveFromWorkList(Caller); + return true; +} + +/// FoldPHIArgBinOpIntoPHI - If we have something like phi [add (a,b), add(c,d)] +/// and if a/b/c/d and the add's all have a single use, turn this into two phi's +/// and a single binop. +Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { + Instruction *FirstInst = cast(PN.getIncomingValue(0)); + assert(isa(FirstInst) || isa(FirstInst) || + isa(FirstInst)); + unsigned Opc = FirstInst->getOpcode(); + Value *LHSVal = FirstInst->getOperand(0); + Value *RHSVal = FirstInst->getOperand(1); + + const Type *LHSType = LHSVal->getType(); + const Type *RHSType = RHSVal->getType(); + + // Scan to see if all operands are the same opcode, all have one use, and all + // kill their operands (i.e. the operands have one use). + for (unsigned i = 0; i != PN.getNumIncomingValues(); ++i) { + Instruction *I = dyn_cast(PN.getIncomingValue(i)); + if (!I || I->getOpcode() != Opc || !I->hasOneUse() || + // Verify type of the LHS matches so we don't fold cmp's of different + // types or GEP's with different index types. + I->getOperand(0)->getType() != LHSType || + I->getOperand(1)->getType() != RHSType) + return 0; + + // If they are CmpInst instructions, check their predicates + if (Opc == Instruction::ICmp || Opc == Instruction::FCmp) + if (cast(I)->getPredicate() != + cast(FirstInst)->getPredicate()) + return 0; + + // Keep track of which operand needs a phi node. 
+ if (I->getOperand(0) != LHSVal) LHSVal = 0; + if (I->getOperand(1) != RHSVal) RHSVal = 0; + } + + // Otherwise, this is safe to transform, determine if it is profitable. + + // If this is a GEP, and if the index (not the pointer) needs a PHI, bail out. + // Indexes are often folded into load/store instructions, so we don't want to + // hide them behind a phi. + if (isa(FirstInst) && RHSVal == 0) + return 0; + + Value *InLHS = FirstInst->getOperand(0); + Value *InRHS = FirstInst->getOperand(1); + PHINode *NewLHS = 0, *NewRHS = 0; + if (LHSVal == 0) { + NewLHS = new PHINode(LHSType, FirstInst->getOperand(0)->getName()+".pn"); + NewLHS->reserveOperandSpace(PN.getNumOperands()/2); + NewLHS->addIncoming(InLHS, PN.getIncomingBlock(0)); + InsertNewInstBefore(NewLHS, PN); + LHSVal = NewLHS; + } + + if (RHSVal == 0) { + NewRHS = new PHINode(RHSType, FirstInst->getOperand(1)->getName()+".pn"); + NewRHS->reserveOperandSpace(PN.getNumOperands()/2); + NewRHS->addIncoming(InRHS, PN.getIncomingBlock(0)); + InsertNewInstBefore(NewRHS, PN); + RHSVal = NewRHS; + } + + // Add all operands to the new PHIs. + for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) { + if (NewLHS) { + Value *NewInLHS =cast(PN.getIncomingValue(i))->getOperand(0); + NewLHS->addIncoming(NewInLHS, PN.getIncomingBlock(i)); + } + if (NewRHS) { + Value *NewInRHS =cast(PN.getIncomingValue(i))->getOperand(1); + NewRHS->addIncoming(NewInRHS, PN.getIncomingBlock(i)); + } + } + + if (BinaryOperator *BinOp = dyn_cast(FirstInst)) + return BinaryOperator::create(BinOp->getOpcode(), LHSVal, RHSVal); + else if (CmpInst *CIOp = dyn_cast(FirstInst)) + return CmpInst::create(CIOp->getOpcode(), CIOp->getPredicate(), LHSVal, + RHSVal); + else { + assert(isa(FirstInst)); + return new GetElementPtrInst(LHSVal, RHSVal); + } +} + +/// isSafeToSinkLoad - Return true if we know that it is safe sink the load out +/// of the block that defines it. 
This means that it must be obvious the value +/// of the load is not changed from the point of the load to the end of the +/// block it is in. +/// +/// Finally, it is safe, but not profitable, to sink a load targetting a +/// non-address-taken alloca. Doing so will cause us to not promote the alloca +/// to a register. +static bool isSafeToSinkLoad(LoadInst *L) { + BasicBlock::iterator BBI = L, E = L->getParent()->end(); + + for (++BBI; BBI != E; ++BBI) + if (BBI->mayWriteToMemory()) + return false; + + // Check for non-address taken alloca. If not address-taken already, it isn't + // profitable to do this xform. + if (AllocaInst *AI = dyn_cast(L->getOperand(0))) { + bool isAddressTaken = false; + for (Value::use_iterator UI = AI->use_begin(), E = AI->use_end(); + UI != E; ++UI) { + if (isa(UI)) continue; + if (StoreInst *SI = dyn_cast(*UI)) { + // If storing TO the alloca, then the address isn't taken. + if (SI->getOperand(1) == AI) continue; + } + isAddressTaken = true; + break; + } + + if (!isAddressTaken) + return false; + } + + return true; +} + + +// FoldPHIArgOpIntoPHI - If all operands to a PHI node are the same "unary" +// operator and they all are only used by the PHI, PHI together their +// inputs, and do the operation once, to the result of the PHI. +Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { + Instruction *FirstInst = cast(PN.getIncomingValue(0)); + + // Scan the instruction, looking for input operations that can be folded away. + // If all input operands to the phi are the same instruction (e.g. a cast from + // the same type or "+42") we can pull the operation through the PHI, reducing + // code size and simplifying code. 
+ Constant *ConstantOp = 0; + const Type *CastSrcTy = 0; + bool isVolatile = false; + if (isa(FirstInst)) { + CastSrcTy = FirstInst->getOperand(0)->getType(); + } else if (isa(FirstInst) || isa(FirstInst)) { + // Can fold binop, compare or shift here if the RHS is a constant, + // otherwise call FoldPHIArgBinOpIntoPHI. + ConstantOp = dyn_cast(FirstInst->getOperand(1)); + if (ConstantOp == 0) + return FoldPHIArgBinOpIntoPHI(PN); + } else if (LoadInst *LI = dyn_cast(FirstInst)) { + isVolatile = LI->isVolatile(); + // We can't sink the load if the loaded value could be modified between the + // load and the PHI. + if (LI->getParent() != PN.getIncomingBlock(0) || + !isSafeToSinkLoad(LI)) + return 0; + } else if (isa(FirstInst)) { + if (FirstInst->getNumOperands() == 2) + return FoldPHIArgBinOpIntoPHI(PN); + // Can't handle general GEPs yet. + return 0; + } else { + return 0; // Cannot fold this operation. + } + + // Check to see if all arguments are the same operation. + for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) { + if (!isa(PN.getIncomingValue(i))) return 0; + Instruction *I = cast(PN.getIncomingValue(i)); + if (!I->hasOneUse() || !I->isSameOperationAs(FirstInst)) + return 0; + if (CastSrcTy) { + if (I->getOperand(0)->getType() != CastSrcTy) + return 0; // Cast operation must match. + } else if (LoadInst *LI = dyn_cast(I)) { + // We can't sink the load if the loaded value could be modified between + // the load and the PHI. + if (LI->isVolatile() != isVolatile || + LI->getParent() != PN.getIncomingBlock(i) || + !isSafeToSinkLoad(LI)) + return 0; + } else if (I->getOperand(1) != ConstantOp) { + return 0; + } + } + + // Okay, they are all the same operation. Create a new PHI node of the + // correct type, and PHI together all of the LHS's of the instructions. 
+ PHINode *NewPN = new PHINode(FirstInst->getOperand(0)->getType(), + PN.getName()+".in"); + NewPN->reserveOperandSpace(PN.getNumOperands()/2); + + Value *InVal = FirstInst->getOperand(0); + NewPN->addIncoming(InVal, PN.getIncomingBlock(0)); + + // Add all operands to the new PHI. + for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) { + Value *NewInVal = cast(PN.getIncomingValue(i))->getOperand(0); + if (NewInVal != InVal) + InVal = 0; + NewPN->addIncoming(NewInVal, PN.getIncomingBlock(i)); + } + + Value *PhiVal; + if (InVal) { + // The new PHI unions all of the same values together. This is really + // common, so we handle it intelligently here for compile-time speed. + PhiVal = InVal; + delete NewPN; + } else { + InsertNewInstBefore(NewPN, PN); + PhiVal = NewPN; + } + + // Insert and return the new operation. + if (CastInst* FirstCI = dyn_cast(FirstInst)) + return CastInst::create(FirstCI->getOpcode(), PhiVal, PN.getType()); + else if (isa(FirstInst)) + return new LoadInst(PhiVal, "", isVolatile); + else if (BinaryOperator *BinOp = dyn_cast(FirstInst)) + return BinaryOperator::create(BinOp->getOpcode(), PhiVal, ConstantOp); + else if (CmpInst *CIOp = dyn_cast(FirstInst)) + return CmpInst::create(CIOp->getOpcode(), CIOp->getPredicate(), + PhiVal, ConstantOp); + else + assert(0 && "Unknown operation"); + return 0; +} + +/// DeadPHICycle - Return true if this PHI node is only used by a PHI node cycle +/// that is dead. +static bool DeadPHICycle(PHINode *PN, + SmallPtrSet &PotentiallyDeadPHIs) { + if (PN->use_empty()) return true; + if (!PN->hasOneUse()) return false; + + // Remember this node, and if we find the cycle, return. 
+ if (!PotentiallyDeadPHIs.insert(PN)) + return true; + + if (PHINode *PU = dyn_cast(PN->use_back())) + return DeadPHICycle(PU, PotentiallyDeadPHIs); + + return false; +} + +// PHINode simplification +// +Instruction *InstCombiner::visitPHINode(PHINode &PN) { + // If LCSSA is around, don't mess with Phi nodes + if (MustPreserveLCSSA) return 0; + + if (Value *V = PN.hasConstantValue()) + return ReplaceInstUsesWith(PN, V); + + // If all PHI operands are the same operation, pull them through the PHI, + // reducing code size. + if (isa(PN.getIncomingValue(0)) && + PN.getIncomingValue(0)->hasOneUse()) + if (Instruction *Result = FoldPHIArgOpIntoPHI(PN)) + return Result; + + // If this is a trivial cycle in the PHI node graph, remove it. Basically, if + // this PHI only has a single use (a PHI), and if that PHI only has one use (a + // PHI)... break the cycle. + if (PN.hasOneUse()) { + Instruction *PHIUser = cast(PN.use_back()); + if (PHINode *PU = dyn_cast(PHIUser)) { + SmallPtrSet PotentiallyDeadPHIs; + PotentiallyDeadPHIs.insert(&PN); + if (DeadPHICycle(PU, PotentiallyDeadPHIs)) + return ReplaceInstUsesWith(PN, UndefValue::get(PN.getType())); + } + + // If this phi has a single use, and if that use just computes a value for + // the next iteration of a loop, delete the phi. This occurs with unused + // induction variables, e.g. "for (int j = 0; ; ++j);". Detecting this + // common case here is good because the only other things that catch this + // are induction variable analysis (sometimes) and ADCE, which is only run + // late. 
+ if (PHIUser->hasOneUse() && + (isa(PHIUser) || isa(PHIUser)) && + PHIUser->use_back() == &PN) { + return ReplaceInstUsesWith(PN, UndefValue::get(PN.getType())); + } + } + + return 0; +} + +static Value *InsertCastToIntPtrTy(Value *V, const Type *DTy, + Instruction *InsertPoint, + InstCombiner *IC) { + unsigned PtrSize = DTy->getPrimitiveSizeInBits(); + unsigned VTySize = V->getType()->getPrimitiveSizeInBits(); + // We must cast correctly to the pointer type. Ensure that we + // sign extend the integer value if it is smaller as this is + // used for address computation. + Instruction::CastOps opcode = + (VTySize < PtrSize ? Instruction::SExt : + (VTySize == PtrSize ? Instruction::BitCast : Instruction::Trunc)); + return IC->InsertCastBefore(opcode, V, DTy, *InsertPoint); +} + + +Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { + Value *PtrOp = GEP.getOperand(0); + // Is it 'getelementptr %P, i32 0' or 'getelementptr %P' + // If so, eliminate the noop. + if (GEP.getNumOperands() == 1) + return ReplaceInstUsesWith(GEP, PtrOp); + + if (isa(GEP.getOperand(0))) + return ReplaceInstUsesWith(GEP, UndefValue::get(GEP.getType())); + + bool HasZeroPointerIndex = false; + if (Constant *C = dyn_cast(GEP.getOperand(1))) + HasZeroPointerIndex = C->isNullValue(); + + if (GEP.getNumOperands() == 2 && HasZeroPointerIndex) + return ReplaceInstUsesWith(GEP, PtrOp); + + // Eliminate unneeded casts for indices. + bool MadeChange = false; + + gep_type_iterator GTI = gep_type_begin(GEP); + for (unsigned i = 1, e = GEP.getNumOperands(); i != e; ++i, ++GTI) { + if (isa(*GTI)) { + if (CastInst *CI = dyn_cast(GEP.getOperand(i))) { + if (CI->getOpcode() == Instruction::ZExt || + CI->getOpcode() == Instruction::SExt) { + const Type *SrcTy = CI->getOperand(0)->getType(); + // We can eliminate a cast from i32 to i64 iff the target + // is a 32-bit pointer target. 
+ if (SrcTy->getPrimitiveSizeInBits() >= TD->getPointerSizeInBits()) { + MadeChange = true; + GEP.setOperand(i, CI->getOperand(0)); + } + } + } + // If we are using a wider index than needed for this platform, shrink it + // to what we need. If the incoming value needs a cast instruction, + // insert it. This explicit cast can make subsequent optimizations more + // obvious. + Value *Op = GEP.getOperand(i); + if (TD->getTypeSize(Op->getType()) > TD->getPointerSize()) + if (Constant *C = dyn_cast(Op)) { + GEP.setOperand(i, ConstantExpr::getTrunc(C, TD->getIntPtrType())); + MadeChange = true; + } else { + Op = InsertCastBefore(Instruction::Trunc, Op, TD->getIntPtrType(), + GEP); + GEP.setOperand(i, Op); + MadeChange = true; + } + } + } + if (MadeChange) return &GEP; + + // If this GEP instruction doesn't move the pointer, and if the input operand + // is a bitcast of another pointer, just replace the GEP with a bitcast of the + // real input to the dest type. + if (GEP.hasAllZeroIndices() && isa(GEP.getOperand(0))) + return new BitCastInst(cast(GEP.getOperand(0))->getOperand(0), + GEP.getType()); + + // Combine Indices - If the source pointer to this getelementptr instruction + // is a getelementptr instruction, combine the indices of the two + // getelementptr instructions into a single instruction. + // + SmallVector SrcGEPOperands; + if (User *Src = dyn_castGetElementPtr(PtrOp)) + SrcGEPOperands.append(Src->op_begin(), Src->op_end()); + + if (!SrcGEPOperands.empty()) { + // Note that if our source is a gep chain itself that we wait for that + // chain to be resolved before we perform this transformation. This + // avoids us creating a TON of code in some cases. + // + if (isa(SrcGEPOperands[0]) && + cast(SrcGEPOperands[0])->getNumOperands() == 2) + return 0; // Wait until our source is folded to completion. + + SmallVector Indices; + + // Find out whether the last index in the source GEP is a sequential idx. 
+ bool EndsWithSequential = false; + for (gep_type_iterator I = gep_type_begin(*cast(PtrOp)), + E = gep_type_end(*cast(PtrOp)); I != E; ++I) + EndsWithSequential = !isa(*I); + + // Can we combine the two pointer arithmetics offsets? + if (EndsWithSequential) { + // Replace: gep (gep %P, long B), long A, ... + // With: T = long A+B; gep %P, T, ... + // + Value *Sum, *SO1 = SrcGEPOperands.back(), *GO1 = GEP.getOperand(1); + if (SO1 == Constant::getNullValue(SO1->getType())) { + Sum = GO1; + } else if (GO1 == Constant::getNullValue(GO1->getType())) { + Sum = SO1; + } else { + // If they aren't the same type, convert both to an integer of the + // target's pointer size. + if (SO1->getType() != GO1->getType()) { + if (Constant *SO1C = dyn_cast(SO1)) { + SO1 = ConstantExpr::getIntegerCast(SO1C, GO1->getType(), true); + } else if (Constant *GO1C = dyn_cast(GO1)) { + GO1 = ConstantExpr::getIntegerCast(GO1C, SO1->getType(), true); + } else { + unsigned PS = TD->getPointerSize(); + if (TD->getTypeSize(SO1->getType()) == PS) { + // Convert GO1 to SO1's type. + GO1 = InsertCastToIntPtrTy(GO1, SO1->getType(), &GEP, this); + + } else if (TD->getTypeSize(GO1->getType()) == PS) { + // Convert SO1 to GO1's type. + SO1 = InsertCastToIntPtrTy(SO1, GO1->getType(), &GEP, this); + } else { + const Type *PT = TD->getIntPtrType(); + SO1 = InsertCastToIntPtrTy(SO1, PT, &GEP, this); + GO1 = InsertCastToIntPtrTy(GO1, PT, &GEP, this); + } + } + } + if (isa(SO1) && isa(GO1)) + Sum = ConstantExpr::getAdd(cast(SO1), cast(GO1)); + else { + Sum = BinaryOperator::createAdd(SO1, GO1, PtrOp->getName()+".sum"); + InsertNewInstBefore(cast(Sum), GEP); + } + } + + // Recycle the GEP we already have if possible. 
+ if (SrcGEPOperands.size() == 2) { + GEP.setOperand(0, SrcGEPOperands[0]); + GEP.setOperand(1, Sum); + return &GEP; + } else { + Indices.insert(Indices.end(), SrcGEPOperands.begin()+1, + SrcGEPOperands.end()-1); + Indices.push_back(Sum); + Indices.insert(Indices.end(), GEP.op_begin()+2, GEP.op_end()); + } + } else if (isa(*GEP.idx_begin()) && + cast(*GEP.idx_begin())->isNullValue() && + SrcGEPOperands.size() != 1) { + // Otherwise we can do the fold if the first index of the GEP is a zero + Indices.insert(Indices.end(), SrcGEPOperands.begin()+1, + SrcGEPOperands.end()); + Indices.insert(Indices.end(), GEP.idx_begin()+1, GEP.idx_end()); + } + + if (!Indices.empty()) + return new GetElementPtrInst(SrcGEPOperands[0], &Indices[0], + Indices.size(), GEP.getName()); + + } else if (GlobalValue *GV = dyn_cast(PtrOp)) { + // GEP of global variable. If all of the indices for this GEP are + // constants, we can promote this to a constexpr instead of an instruction. + + // Scan for nonconstants... + SmallVector Indices; + User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end(); + for (; I != E && isa(*I); ++I) + Indices.push_back(cast(*I)); + + if (I == E) { // If they are all constants... + Constant *CE = ConstantExpr::getGetElementPtr(GV, + &Indices[0],Indices.size()); + + // Replace all uses of the GEP with the new constexpr... + return ReplaceInstUsesWith(GEP, CE); + } + } else if (Value *X = getBitCastOperand(PtrOp)) { // Is the operand a cast? + if (!isa(X->getType())) { + // Not interesting. Source pointer must be a cast from pointer. + } else if (HasZeroPointerIndex) { + // transform: GEP (cast [10 x ubyte]* X to [0 x ubyte]*), long 0, ... + // into : GEP [10 x ubyte]* X, long 0, ... 
+ // + // This occurs when the program declares an array extern like "int X[];" + // + const PointerType *CPTy = cast(PtrOp->getType()); + const PointerType *XTy = cast(X->getType()); + if (const ArrayType *XATy = + dyn_cast(XTy->getElementType())) + if (const ArrayType *CATy = + dyn_cast(CPTy->getElementType())) + if (CATy->getElementType() == XATy->getElementType()) { + // At this point, we know that the cast source type is a pointer + // to an array of the same type as the destination pointer + // array. Because the array type is never stepped over (there + // is a leading zero) we can fold the cast into this GEP. + GEP.setOperand(0, X); + return &GEP; + } + } else if (GEP.getNumOperands() == 2) { + // Transform things like: + // %t = getelementptr ubyte* cast ([2 x int]* %str to uint*), uint %V + // into: %t1 = getelementptr [2 x int*]* %str, int 0, uint %V; cast + const Type *SrcElTy = cast(X->getType())->getElementType(); + const Type *ResElTy=cast(PtrOp->getType())->getElementType(); + if (isa(SrcElTy) && + TD->getTypeSize(cast(SrcElTy)->getElementType()) == + TD->getTypeSize(ResElTy)) { + Value *V = InsertNewInstBefore( + new GetElementPtrInst(X, Constant::getNullValue(Type::Int32Ty), + GEP.getOperand(1), GEP.getName()), GEP); + // V and GEP are both pointer types --> BitCast + return new BitCastInst(V, GEP.getType()); + } + + // Transform things like: + // getelementptr sbyte* cast ([100 x double]* X to sbyte*), int %tmp + // (where tmp = 8*tmp2) into: + // getelementptr [100 x double]* %arr, int 0, int %tmp.2 + + if (isa(SrcElTy) && + (ResElTy == Type::Int8Ty || ResElTy == Type::Int8Ty)) { + uint64_t ArrayEltSize = + TD->getTypeSize(cast(SrcElTy)->getElementType()); + + // Check to see if "tmp" is a scale by a multiple of ArrayEltSize. We + // allow either a mul, shift, or constant here. 
+ Value *NewIdx = 0; + ConstantInt *Scale = 0; + if (ArrayEltSize == 1) { + NewIdx = GEP.getOperand(1); + Scale = ConstantInt::get(NewIdx->getType(), 1); + } else if (ConstantInt *CI = dyn_cast(GEP.getOperand(1))) { + NewIdx = ConstantInt::get(CI->getType(), 1); + Scale = CI; + } else if (Instruction *Inst =dyn_cast(GEP.getOperand(1))){ + if (Inst->getOpcode() == Instruction::Shl && + isa(Inst->getOperand(1))) { + ConstantInt *ShAmt = cast(Inst->getOperand(1)); + uint32_t ShAmtVal = ShAmt->getLimitedValue(64); + Scale = ConstantInt::get(Inst->getType(), 1ULL << ShAmtVal); + NewIdx = Inst->getOperand(0); + } else if (Inst->getOpcode() == Instruction::Mul && + isa(Inst->getOperand(1))) { + Scale = cast(Inst->getOperand(1)); + NewIdx = Inst->getOperand(0); + } + } + + // If the index will be to exactly the right offset with the scale taken + // out, perform the transformation. + if (Scale && Scale->getZExtValue() % ArrayEltSize == 0) { + if (isa(Scale)) + Scale = ConstantInt::get(Scale->getType(), + Scale->getZExtValue() / ArrayEltSize); + if (Scale->getZExtValue() != 1) { + Constant *C = ConstantExpr::getIntegerCast(Scale, NewIdx->getType(), + true /*SExt*/); + Instruction *Sc = BinaryOperator::createMul(NewIdx, C, "idxscale"); + NewIdx = InsertNewInstBefore(Sc, GEP); + } + + // Insert the new GEP instruction. 
+ Instruction *NewGEP = + new GetElementPtrInst(X, Constant::getNullValue(Type::Int32Ty), + NewIdx, GEP.getName()); + NewGEP = InsertNewInstBefore(NewGEP, GEP); + // The NewGEP must be pointer typed, so must the old one -> BitCast + return new BitCastInst(NewGEP, GEP.getType()); + } + } + } + } + + return 0; +} + +Instruction *InstCombiner::visitAllocationInst(AllocationInst &AI) { + // Convert: malloc Ty, C - where C is a constant != 1 into: malloc [C x Ty], 1 + if (AI.isArrayAllocation()) // Check C != 1 + if (const ConstantInt *C = dyn_cast(AI.getArraySize())) { + const Type *NewTy = + ArrayType::get(AI.getAllocatedType(), C->getZExtValue()); + AllocationInst *New = 0; + + // Create and insert the replacement instruction... + if (isa(AI)) + New = new MallocInst(NewTy, 0, AI.getAlignment(), AI.getName()); + else { + assert(isa(AI) && "Unknown type of allocation inst!"); + New = new AllocaInst(NewTy, 0, AI.getAlignment(), AI.getName()); + } + + InsertNewInstBefore(New, AI); + + // Scan to the end of the allocation instructions, to skip over a block of + // allocas if possible... + // + BasicBlock::iterator It = New; + while (isa(*It)) ++It; + + // Now that I is pointing to the first non-allocation-inst in the block, + // insert our getelementptr instruction... + // + Value *NullIdx = Constant::getNullValue(Type::Int32Ty); + Value *V = new GetElementPtrInst(New, NullIdx, NullIdx, + New->getName()+".sub", It); + + // Now make everything use the getelementptr instead of the original + // allocation. + return ReplaceInstUsesWith(AI, V); + } else if (isa(AI.getArraySize())) { + return ReplaceInstUsesWith(AI, Constant::getNullValue(AI.getType())); + } + + // If alloca'ing a zero byte object, replace the alloca with a null pointer. + // Note that we only do this for alloca's, because malloc should allocate and + // return a unique pointer, even for a zero byte allocation. 
+ if (isa(AI) && AI.getAllocatedType()->isSized() && + TD->getTypeSize(AI.getAllocatedType()) == 0) + return ReplaceInstUsesWith(AI, Constant::getNullValue(AI.getType())); + + return 0; +} + +Instruction *InstCombiner::visitFreeInst(FreeInst &FI) { + Value *Op = FI.getOperand(0); + + // free undef -> unreachable. + if (isa(Op)) { + // Insert a new store to null because we cannot modify the CFG here. + new StoreInst(ConstantInt::getTrue(), + UndefValue::get(PointerType::get(Type::Int1Ty)), &FI); + return EraseInstFromFunction(FI); + } + + // If we have 'free null' delete the instruction. This can happen in stl code + // when lots of inlining happens. + if (isa(Op)) + return EraseInstFromFunction(FI); + + // Change free * (cast * X to *) into free * X + if (BitCastInst *CI = dyn_cast(Op)) { + FI.setOperand(0, CI->getOperand(0)); + return &FI; + } + + // Change free (gep X, 0,0,0,0) into free(X) + if (GetElementPtrInst *GEPI = dyn_cast(Op)) { + if (GEPI->hasAllZeroIndices()) { + AddToWorkList(GEPI); + FI.setOperand(0, GEPI->getOperand(0)); + return &FI; + } + } + + // Change free(malloc) into nothing, if the malloc has a single use. + if (MallocInst *MI = dyn_cast(Op)) + if (MI->hasOneUse()) { + EraseInstFromFunction(FI); + return EraseInstFromFunction(*MI); + } + + return 0; +} + + +/// InstCombineLoadCast - Fold 'load (cast P)' -> cast (load P)' when possible. +static Instruction *InstCombineLoadCast(InstCombiner &IC, LoadInst &LI) { + User *CI = cast(LI.getOperand(0)); + Value *CastOp = CI->getOperand(0); + + const Type *DestPTy = cast(CI->getType())->getElementType(); + if (const PointerType *SrcTy = dyn_cast(CastOp->getType())) { + const Type *SrcPTy = SrcTy->getElementType(); + + if (DestPTy->isInteger() || isa(DestPTy) || + isa(DestPTy)) { + // If the source is an array, the code below will not succeed. Check to + // see if a trivial 'gep P, 0, 0' will help matters. Only do this for + // constants. 
+ if (const ArrayType *ASrcTy = dyn_cast(SrcPTy)) + if (Constant *CSrc = dyn_cast(CastOp)) + if (ASrcTy->getNumElements() != 0) { + Value *Idxs[2]; + Idxs[0] = Idxs[1] = Constant::getNullValue(Type::Int32Ty); + CastOp = ConstantExpr::getGetElementPtr(CSrc, Idxs, 2); + SrcTy = cast(CastOp->getType()); + SrcPTy = SrcTy->getElementType(); + } + + if ((SrcPTy->isInteger() || isa(SrcPTy) || + isa(SrcPTy)) && + // Do not allow turning this into a load of an integer, which is then + // casted to a pointer, this pessimizes pointer analysis a lot. + (isa(SrcPTy) == isa(LI.getType())) && + IC.getTargetData().getTypeSizeInBits(SrcPTy) == + IC.getTargetData().getTypeSizeInBits(DestPTy)) { + + // Okay, we are casting from one integer or pointer type to another of + // the same size. Instead of casting the pointer before the load, cast + // the result of the loaded value. + Value *NewLoad = IC.InsertNewInstBefore(new LoadInst(CastOp, + CI->getName(), + LI.isVolatile()),LI); + // Now cast the result of the load. + return new BitCastInst(NewLoad, LI.getType()); + } + } + } + return 0; +} + +/// isSafeToLoadUnconditionally - Return true if we know that executing a load +/// from this value cannot trap. If it is not obviously safe to load from the +/// specified pointer, we do a quick local scan of the basic block containing +/// ScanFrom, to determine if the address is already accessed. +static bool isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom) { + // If it is an alloca or global variable, it is always safe to load from. + if (isa(V) || isa(V)) return true; + + // Otherwise, be a little bit agressive by scanning the local block where we + // want to check to see if the pointer is already being loaded or stored + // from/to. If so, the previous load or store would have already trapped, + // so there is no harm doing an extra load (also, CSE will later eliminate + // the load entirely). 
+ BasicBlock::iterator BBI = ScanFrom, E = ScanFrom->getParent()->begin(); + + while (BBI != E) { + --BBI; + + if (LoadInst *LI = dyn_cast(BBI)) { + if (LI->getOperand(0) == V) return true; + } else if (StoreInst *SI = dyn_cast(BBI)) + if (SI->getOperand(1) == V) return true; + + } + return false; +} + +Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { + Value *Op = LI.getOperand(0); + + // load (cast X) --> cast (load X) iff safe + if (isa(Op)) + if (Instruction *Res = InstCombineLoadCast(*this, LI)) + return Res; + + // None of the following transforms are legal for volatile loads. + if (LI.isVolatile()) return 0; + + if (&LI.getParent()->front() != &LI) { + BasicBlock::iterator BBI = &LI; --BBI; + // If the instruction immediately before this is a store to the same + // address, do a simple form of store->load forwarding. + if (StoreInst *SI = dyn_cast(BBI)) + if (SI->getOperand(1) == LI.getOperand(0)) + return ReplaceInstUsesWith(LI, SI->getOperand(0)); + if (LoadInst *LIB = dyn_cast(BBI)) + if (LIB->getOperand(0) == LI.getOperand(0)) + return ReplaceInstUsesWith(LI, LIB); + } + + if (GetElementPtrInst *GEPI = dyn_cast(Op)) + if (isa(GEPI->getOperand(0))) { + // Insert a new store to null instruction before the load to indicate + // that this code is not reachable. We do this instead of inserting + // an unreachable instruction directly because we cannot modify the + // CFG. + new StoreInst(UndefValue::get(LI.getType()), + Constant::getNullValue(Op->getType()), &LI); + return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType())); + } + + if (Constant *C = dyn_cast(Op)) { + // load null/undef -> undef + if ((C->isNullValue() || isa(C))) { + // Insert a new store to null instruction before the load to indicate that + // this code is not reachable. We do this instead of inserting an + // unreachable instruction directly because we cannot modify the CFG. 
+ new StoreInst(UndefValue::get(LI.getType()), + Constant::getNullValue(Op->getType()), &LI); + return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType())); + } + + // Instcombine load (constant global) into the value loaded. + if (GlobalVariable *GV = dyn_cast(Op)) + if (GV->isConstant() && !GV->isDeclaration()) + return ReplaceInstUsesWith(LI, GV->getInitializer()); + + // Instcombine load (constantexpr_GEP global, 0, ...) into the value loaded. + if (ConstantExpr *CE = dyn_cast(Op)) + if (CE->getOpcode() == Instruction::GetElementPtr) { + if (GlobalVariable *GV = dyn_cast(CE->getOperand(0))) + if (GV->isConstant() && !GV->isDeclaration()) + if (Constant *V = + ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE)) + return ReplaceInstUsesWith(LI, V); + if (CE->getOperand(0)->isNullValue()) { + // Insert a new store to null instruction before the load to indicate + // that this code is not reachable. We do this instead of inserting + // an unreachable instruction directly because we cannot modify the + // CFG. + new StoreInst(UndefValue::get(LI.getType()), + Constant::getNullValue(Op->getType()), &LI); + return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType())); + } + + } else if (CE->isCast()) { + if (Instruction *Res = InstCombineLoadCast(*this, LI)) + return Res; + } + } + + if (Op->hasOneUse()) { + // Change select and PHI nodes to select values instead of addresses: this + // helps alias analysis out a lot, allows many others simplifications, and + // exposes redundancy in the code. + // + // Note that we cannot do the transformation unless we know that the + // introduced loads cannot trap! Something like this is valid as long as + // the condition is always false: load (select bool %C, int* null, int* %G), + // but it would not be valid if we transformed it to load from null + // unconditionally. + // + if (SelectInst *SI = dyn_cast(Op)) { + // load (select (Cond, &V1, &V2)) --> select(Cond, load &V1, load &V2). 
+ if (isSafeToLoadUnconditionally(SI->getOperand(1), SI) && + isSafeToLoadUnconditionally(SI->getOperand(2), SI)) { + Value *V1 = InsertNewInstBefore(new LoadInst(SI->getOperand(1), + SI->getOperand(1)->getName()+".val"), LI); + Value *V2 = InsertNewInstBefore(new LoadInst(SI->getOperand(2), + SI->getOperand(2)->getName()+".val"), LI); + return new SelectInst(SI->getCondition(), V1, V2); + } + + // load (select (cond, null, P)) -> load P + if (Constant *C = dyn_cast(SI->getOperand(1))) + if (C->isNullValue()) { + LI.setOperand(0, SI->getOperand(2)); + return &LI; + } + + // load (select (cond, P, null)) -> load P + if (Constant *C = dyn_cast(SI->getOperand(2))) + if (C->isNullValue()) { + LI.setOperand(0, SI->getOperand(1)); + return &LI; + } + } + } + return 0; +} + +/// InstCombineStoreToCast - Fold store V, (cast P) -> store (cast V), P +/// when possible. +static Instruction *InstCombineStoreToCast(InstCombiner &IC, StoreInst &SI) { + User *CI = cast(SI.getOperand(1)); + Value *CastOp = CI->getOperand(0); + + const Type *DestPTy = cast(CI->getType())->getElementType(); + if (const PointerType *SrcTy = dyn_cast(CastOp->getType())) { + const Type *SrcPTy = SrcTy->getElementType(); + + if (DestPTy->isInteger() || isa(DestPTy)) { + // If the source is an array, the code below will not succeed. Check to + // see if a trivial 'gep P, 0, 0' will help matters. Only do this for + // constants. + if (const ArrayType *ASrcTy = dyn_cast(SrcPTy)) + if (Constant *CSrc = dyn_cast(CastOp)) + if (ASrcTy->getNumElements() != 0) { + Value* Idxs[2]; + Idxs[0] = Idxs[1] = Constant::getNullValue(Type::Int32Ty); + CastOp = ConstantExpr::getGetElementPtr(CSrc, Idxs, 2); + SrcTy = cast(CastOp->getType()); + SrcPTy = SrcTy->getElementType(); + } + + if ((SrcPTy->isInteger() || isa(SrcPTy)) && + IC.getTargetData().getTypeSizeInBits(SrcPTy) == + IC.getTargetData().getTypeSizeInBits(DestPTy)) { + + // Okay, we are casting from one integer or pointer type to another of + // the same size. 
Instead of casting the pointer before + // the store, cast the value to be stored. + Value *NewCast; + Value *SIOp0 = SI.getOperand(0); + Instruction::CastOps opcode = Instruction::BitCast; + const Type* CastSrcTy = SIOp0->getType(); + const Type* CastDstTy = SrcPTy; + if (isa(CastDstTy)) { + if (CastSrcTy->isInteger()) + opcode = Instruction::IntToPtr; + } else if (isa(CastDstTy)) { + if (isa(SIOp0->getType())) + opcode = Instruction::PtrToInt; + } + if (Constant *C = dyn_cast(SIOp0)) + NewCast = ConstantExpr::getCast(opcode, C, CastDstTy); + else + NewCast = IC.InsertNewInstBefore( + CastInst::create(opcode, SIOp0, CastDstTy, SIOp0->getName()+".c"), + SI); + return new StoreInst(NewCast, CastOp); + } + } + } + return 0; +} + +Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { + Value *Val = SI.getOperand(0); + Value *Ptr = SI.getOperand(1); + + if (isa(Ptr)) { // store X, undef -> noop (even if volatile) + EraseInstFromFunction(SI); + ++NumCombined; + return 0; + } + + // If the RHS is an alloca with a single use, zapify the store, making the + // alloca dead. + if (Ptr->hasOneUse()) { + if (isa(Ptr)) { + EraseInstFromFunction(SI); + ++NumCombined; + return 0; + } + + if (GetElementPtrInst *GEP = dyn_cast(Ptr)) + if (isa(GEP->getOperand(0)) && + GEP->getOperand(0)->hasOneUse()) { + EraseInstFromFunction(SI); + ++NumCombined; + return 0; + } + } + + // Do really simple DSE, to catch cases where there are several consequtive + // stores to the same location, separated by a few arithmetic operations. This + // situation often occurs with bitfield accesses. + BasicBlock::iterator BBI = &SI; + for (unsigned ScanInsts = 6; BBI != SI.getParent()->begin() && ScanInsts; + --ScanInsts) { + --BBI; + + if (StoreInst *PrevSI = dyn_cast(BBI)) { + // Prev store isn't volatile, and stores to the same location? 
+ if (!PrevSI->isVolatile() && PrevSI->getOperand(1) == SI.getOperand(1)) { + ++NumDeadStore; + ++BBI; + EraseInstFromFunction(*PrevSI); + continue; + } + break; + } + + // If this is a load, we have to stop. However, if the loaded value is from + // the pointer we're loading and is producing the pointer we're storing, + // then *this* store is dead (X = load P; store X -> P). + if (LoadInst *LI = dyn_cast(BBI)) { + if (LI == Val && LI->getOperand(0) == Ptr) { + EraseInstFromFunction(SI); + ++NumCombined; + return 0; + } + // Otherwise, this is a load from some other location. Stores before it + // may not be dead. + break; + } + + // Don't skip over loads or things that can modify memory. + if (BBI->mayWriteToMemory()) + break; + } + + + if (SI.isVolatile()) return 0; // Don't hack volatile stores. + + // store X, null -> turns into 'unreachable' in SimplifyCFG + if (isa(Ptr)) { + if (!isa(Val)) { + SI.setOperand(0, UndefValue::get(Val->getType())); + if (Instruction *U = dyn_cast(Val)) + AddToWorkList(U); // Dropped a use. + ++NumCombined; + } + return 0; // Do not modify these! + } + + // store undef, Ptr -> noop + if (isa(Val)) { + EraseInstFromFunction(SI); + ++NumCombined; + return 0; + } + + // If the pointer destination is a cast, see if we can fold the cast into the + // source instead. + if (isa(Ptr)) + if (Instruction *Res = InstCombineStoreToCast(*this, SI)) + return Res; + if (ConstantExpr *CE = dyn_cast(Ptr)) + if (CE->isCast()) + if (Instruction *Res = InstCombineStoreToCast(*this, SI)) + return Res; + + + // If this store is the last instruction in the basic block, and if the block + // ends with an unconditional branch, try to move it to the successor block. + BBI = &SI; ++BBI; + if (BranchInst *BI = dyn_cast(BBI)) + if (BI->isUnconditional()) + if (SimplifyStoreAtEndOfBlock(SI)) + return 0; // xform done! 
+ + return 0; +} + +/// SimplifyStoreAtEndOfBlock - Turn things like: +/// if () { *P = v1; } else { *P = v2 } +/// into a phi node with a store in the successor. +/// +/// Simplify things like: +/// *P = v1; if () { *P = v2; } +/// into a phi node with a store in the successor. +/// +bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { + BasicBlock *StoreBB = SI.getParent(); + + // Check to see if the successor block has exactly two incoming edges. If + // so, see if the other predecessor contains a store to the same location. + // if so, insert a PHI node (if needed) and move the stores down. + BasicBlock *DestBB = StoreBB->getTerminator()->getSuccessor(0); + + // Determine whether Dest has exactly two predecessors and, if so, compute + // the other predecessor. + pred_iterator PI = pred_begin(DestBB); + BasicBlock *OtherBB = 0; + if (*PI != StoreBB) + OtherBB = *PI; + ++PI; + if (PI == pred_end(DestBB)) + return false; + + if (*PI != StoreBB) { + if (OtherBB) + return false; + OtherBB = *PI; + } + if (++PI != pred_end(DestBB)) + return false; + + + // Verify that the other block ends in a branch and is not otherwise empty. + BasicBlock::iterator BBI = OtherBB->getTerminator(); + BranchInst *OtherBr = dyn_cast(BBI); + if (!OtherBr || BBI == OtherBB->begin()) + return false; + + // If the other block ends in an unconditional branch, check for the 'if then + // else' case. there is an instruction before the branch. + StoreInst *OtherStore = 0; + if (OtherBr->isUnconditional()) { + // If this isn't a store, or isn't a store to the same location, bail out. + --BBI; + OtherStore = dyn_cast(BBI); + if (!OtherStore || OtherStore->getOperand(1) != SI.getOperand(1)) + return false; + } else { + // Otherwise, the other block ended with a conditional branch. If one of the + // destinations is StoreBB, then we have the if/then case. 
+ if (OtherBr->getSuccessor(0) != StoreBB && + OtherBr->getSuccessor(1) != StoreBB) + return false; + + // Okay, we know that OtherBr now goes to Dest and StoreBB, so this is an + // if/then triangle. See if there is a store to the same ptr as SI that + // lives in OtherBB. + for (;; --BBI) { + // Check to see if we find the matching store. + if ((OtherStore = dyn_cast(BBI))) { + if (OtherStore->getOperand(1) != SI.getOperand(1)) + return false; + break; + } + // If we find something that may be using the stored value, or if we run + // out of instructions, we can't do the xform. + if (isa(BBI) || BBI->mayWriteToMemory() || + BBI == OtherBB->begin()) + return false; + } + + // In order to eliminate the store in OtherBr, we have to + // make sure nothing reads the stored value in StoreBB. + for (BasicBlock::iterator I = StoreBB->begin(); &*I != &SI; ++I) { + // FIXME: This should really be AA driven. + if (isa(I) || I->mayWriteToMemory()) + return false; + } + } + + // Insert a PHI node now if we need it. + Value *MergedVal = OtherStore->getOperand(0); + if (MergedVal != SI.getOperand(0)) { + PHINode *PN = new PHINode(MergedVal->getType(), "storemerge"); + PN->reserveOperandSpace(2); + PN->addIncoming(SI.getOperand(0), SI.getParent()); + PN->addIncoming(OtherStore->getOperand(0), OtherBB); + MergedVal = InsertNewInstBefore(PN, DestBB->front()); + } + + // Advance to a place where it is safe to insert the new store and + // insert it. + BBI = DestBB->begin(); + while (isa(BBI)) ++BBI; + InsertNewInstBefore(new StoreInst(MergedVal, SI.getOperand(1), + OtherStore->isVolatile()), *BBI); + + // Nuke the old stores. 
+ EraseInstFromFunction(SI); + EraseInstFromFunction(*OtherStore); + ++NumCombined; + return true; +} + + +Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { + // Change br (not X), label True, label False to: br X, label False, True + Value *X = 0; + BasicBlock *TrueDest; + BasicBlock *FalseDest; + if (match(&BI, m_Br(m_Not(m_Value(X)), TrueDest, FalseDest)) && + !isa(X)) { + // Swap Destinations and condition... + BI.setCondition(X); + BI.setSuccessor(0, FalseDest); + BI.setSuccessor(1, TrueDest); + return &BI; + } + + // Cannonicalize fcmp_one -> fcmp_oeq + FCmpInst::Predicate FPred; Value *Y; + if (match(&BI, m_Br(m_FCmp(FPred, m_Value(X), m_Value(Y)), + TrueDest, FalseDest))) + if ((FPred == FCmpInst::FCMP_ONE || FPred == FCmpInst::FCMP_OLE || + FPred == FCmpInst::FCMP_OGE) && BI.getCondition()->hasOneUse()) { + FCmpInst *I = cast(BI.getCondition()); + FCmpInst::Predicate NewPred = FCmpInst::getInversePredicate(FPred); + Instruction *NewSCC = new FCmpInst(NewPred, X, Y, "", I); + NewSCC->takeName(I); + // Swap Destinations and condition... + BI.setCondition(NewSCC); + BI.setSuccessor(0, FalseDest); + BI.setSuccessor(1, TrueDest); + RemoveFromWorkList(I); + I->eraseFromParent(); + AddToWorkList(NewSCC); + return &BI; + } + + // Cannonicalize icmp_ne -> icmp_eq + ICmpInst::Predicate IPred; + if (match(&BI, m_Br(m_ICmp(IPred, m_Value(X), m_Value(Y)), + TrueDest, FalseDest))) + if ((IPred == ICmpInst::ICMP_NE || IPred == ICmpInst::ICMP_ULE || + IPred == ICmpInst::ICMP_SLE || IPred == ICmpInst::ICMP_UGE || + IPred == ICmpInst::ICMP_SGE) && BI.getCondition()->hasOneUse()) { + ICmpInst *I = cast(BI.getCondition()); + ICmpInst::Predicate NewPred = ICmpInst::getInversePredicate(IPred); + Instruction *NewSCC = new ICmpInst(NewPred, X, Y, "", I); + NewSCC->takeName(I); + // Swap Destinations and condition... 
+ BI.setCondition(NewSCC); + BI.setSuccessor(0, FalseDest); + BI.setSuccessor(1, TrueDest); + RemoveFromWorkList(I); + I->eraseFromParent();; + AddToWorkList(NewSCC); + return &BI; + } + + return 0; +} + +Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { + Value *Cond = SI.getCondition(); + if (Instruction *I = dyn_cast(Cond)) { + if (I->getOpcode() == Instruction::Add) + if (ConstantInt *AddRHS = dyn_cast(I->getOperand(1))) { + // change 'switch (X+4) case 1:' into 'switch (X) case -3' + for (unsigned i = 2, e = SI.getNumOperands(); i != e; i += 2) + SI.setOperand(i,ConstantExpr::getSub(cast(SI.getOperand(i)), + AddRHS)); + SI.setOperand(0, I->getOperand(0)); + AddToWorkList(I); + return &SI; + } + } + return 0; +} + +/// CheapToScalarize - Return true if the value is cheaper to scalarize than it +/// is to leave as a vector operation. +static bool CheapToScalarize(Value *V, bool isConstant) { + if (isa(V)) + return true; + if (ConstantVector *C = dyn_cast(V)) { + if (isConstant) return true; + // If all elts are the same, we can extract. + Constant *Op0 = C->getOperand(0); + for (unsigned i = 1; i < C->getNumOperands(); ++i) + if (C->getOperand(i) != Op0) + return false; + return true; + } + Instruction *I = dyn_cast(V); + if (!I) return false; + + // Insert element gets simplified to the inserted element or is deleted if + // this is constant idx extract element and its a constant idx insertelt. 
+ if (I->getOpcode() == Instruction::InsertElement && isConstant && + isa(I->getOperand(2))) + return true; + if (I->getOpcode() == Instruction::Load && I->hasOneUse()) + return true; + if (BinaryOperator *BO = dyn_cast(I)) + if (BO->hasOneUse() && + (CheapToScalarize(BO->getOperand(0), isConstant) || + CheapToScalarize(BO->getOperand(1), isConstant))) + return true; + if (CmpInst *CI = dyn_cast(I)) + if (CI->hasOneUse() && + (CheapToScalarize(CI->getOperand(0), isConstant) || + CheapToScalarize(CI->getOperand(1), isConstant))) + return true; + + return false; +} + +/// Read and decode a shufflevector mask. +/// +/// It turns undef elements into values that are larger than the number of +/// elements in the input. +static std::vector getShuffleMask(const ShuffleVectorInst *SVI) { + unsigned NElts = SVI->getType()->getNumElements(); + if (isa(SVI->getOperand(2))) + return std::vector(NElts, 0); + if (isa(SVI->getOperand(2))) + return std::vector(NElts, 2*NElts); + + std::vector Result; + const ConstantVector *CP = cast(SVI->getOperand(2)); + for (unsigned i = 0, e = CP->getNumOperands(); i != e; ++i) + if (isa(CP->getOperand(i))) + Result.push_back(NElts*2); // undef -> 8 + else + Result.push_back(cast(CP->getOperand(i))->getZExtValue()); + return Result; +} + +/// FindScalarElement - Given a vector and an element number, see if the scalar +/// value is already around as a register, for example if it were inserted then +/// extracted from the vector. +static Value *FindScalarElement(Value *V, unsigned EltNo) { + assert(isa(V->getType()) && "Not looking at a vector?"); + const VectorType *PTy = cast(V->getType()); + unsigned Width = PTy->getNumElements(); + if (EltNo >= Width) // Out of range access. 
+ return UndefValue::get(PTy->getElementType()); + + if (isa(V)) + return UndefValue::get(PTy->getElementType()); + else if (isa(V)) + return Constant::getNullValue(PTy->getElementType()); + else if (ConstantVector *CP = dyn_cast(V)) + return CP->getOperand(EltNo); + else if (InsertElementInst *III = dyn_cast(V)) { + // If this is an insert to a variable element, we don't know what it is. + if (!isa(III->getOperand(2))) + return 0; + unsigned IIElt = cast(III->getOperand(2))->getZExtValue(); + + // If this is an insert to the element we are looking for, return the + // inserted value. + if (EltNo == IIElt) + return III->getOperand(1); + + // Otherwise, the insertelement doesn't modify the value, recurse on its + // vector input. + return FindScalarElement(III->getOperand(0), EltNo); + } else if (ShuffleVectorInst *SVI = dyn_cast(V)) { + unsigned InEl = getShuffleMask(SVI)[EltNo]; + if (InEl < Width) + return FindScalarElement(SVI->getOperand(0), InEl); + else if (InEl < Width*2) + return FindScalarElement(SVI->getOperand(1), InEl - Width); + else + return UndefValue::get(PTy->getElementType()); + } + + // Otherwise, we don't know. + return 0; +} + +Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { + + // If vector val is undef, replace extract with scalar undef. + if (isa(EI.getOperand(0))) + return ReplaceInstUsesWith(EI, UndefValue::get(EI.getType())); + + // If vector val is constant 0, replace extract with scalar 0. 
+ if (isa(EI.getOperand(0))) + return ReplaceInstUsesWith(EI, Constant::getNullValue(EI.getType())); + + if (ConstantVector *C = dyn_cast(EI.getOperand(0))) { + // If vector val is constant with uniform operands, replace EI + // with that operand + Constant *op0 = C->getOperand(0); + for (unsigned i = 1; i < C->getNumOperands(); ++i) + if (C->getOperand(i) != op0) { + op0 = 0; + break; + } + if (op0) + return ReplaceInstUsesWith(EI, op0); + } + + // If extracting a specified index from the vector, see if we can recursively + // find a previously computed scalar that was inserted into the vector. + if (ConstantInt *IdxC = dyn_cast(EI.getOperand(1))) { + unsigned IndexVal = IdxC->getZExtValue(); + unsigned VectorWidth = + cast(EI.getOperand(0)->getType())->getNumElements(); + + // If this is extracting an invalid index, turn this into undef, to avoid + // crashing the code below. + if (IndexVal >= VectorWidth) + return ReplaceInstUsesWith(EI, UndefValue::get(EI.getType())); + + // This instruction only demands the single element from the input vector. + // If the input vector has a single use, simplify it based on this use + // property. + if (EI.getOperand(0)->hasOneUse() && VectorWidth != 1) { + uint64_t UndefElts; + if (Value *V = SimplifyDemandedVectorElts(EI.getOperand(0), + 1 << IndexVal, + UndefElts)) { + EI.setOperand(0, V); + return &EI; + } + } + + if (Value *Elt = FindScalarElement(EI.getOperand(0), IndexVal)) + return ReplaceInstUsesWith(EI, Elt); + + // If the this extractelement is directly using a bitcast from a vector of + // the same number of elements, see if we can find the source element from + // it. In this case, we will end up needing to bitcast the scalars. 
+ if (BitCastInst *BCI = dyn_cast(EI.getOperand(0))) { + if (const VectorType *VT = + dyn_cast(BCI->getOperand(0)->getType())) + if (VT->getNumElements() == VectorWidth) + if (Value *Elt = FindScalarElement(BCI->getOperand(0), IndexVal)) + return new BitCastInst(Elt, EI.getType()); + } + } + + if (Instruction *I = dyn_cast(EI.getOperand(0))) { + if (I->hasOneUse()) { + // Push extractelement into predecessor operation if legal and + // profitable to do so + if (BinaryOperator *BO = dyn_cast(I)) { + bool isConstantElt = isa(EI.getOperand(1)); + if (CheapToScalarize(BO, isConstantElt)) { + ExtractElementInst *newEI0 = + new ExtractElementInst(BO->getOperand(0), EI.getOperand(1), + EI.getName()+".lhs"); + ExtractElementInst *newEI1 = + new ExtractElementInst(BO->getOperand(1), EI.getOperand(1), + EI.getName()+".rhs"); + InsertNewInstBefore(newEI0, EI); + InsertNewInstBefore(newEI1, EI); + return BinaryOperator::create(BO->getOpcode(), newEI0, newEI1); + } + } else if (isa(I)) { + Value *Ptr = InsertCastBefore(Instruction::BitCast, I->getOperand(0), + PointerType::get(EI.getType()), EI); + GetElementPtrInst *GEP = + new GetElementPtrInst(Ptr, EI.getOperand(1), I->getName() + ".gep"); + InsertNewInstBefore(GEP, EI); + return new LoadInst(GEP); + } + } + if (InsertElementInst *IE = dyn_cast(I)) { + // Extracting the inserted element? + if (IE->getOperand(2) == EI.getOperand(1)) + return ReplaceInstUsesWith(EI, IE->getOperand(1)); + // If the inserted and extracted elements are constants, they must not + // be the same value, extract from the pre-inserted value instead. + if (isa(IE->getOperand(2)) && + isa(EI.getOperand(1))) { + AddUsesToWorkList(EI); + EI.setOperand(0, IE->getOperand(0)); + return &EI; + } + } else if (ShuffleVectorInst *SVI = dyn_cast(I)) { + // If this is extracting an element from a shufflevector, figure out where + // it came from and extract from the appropriate input element instead. 
+ if (ConstantInt *Elt = dyn_cast(EI.getOperand(1))) { + unsigned SrcIdx = getShuffleMask(SVI)[Elt->getZExtValue()]; + Value *Src; + if (SrcIdx < SVI->getType()->getNumElements()) + Src = SVI->getOperand(0); + else if (SrcIdx < SVI->getType()->getNumElements()*2) { + SrcIdx -= SVI->getType()->getNumElements(); + Src = SVI->getOperand(1); + } else { + return ReplaceInstUsesWith(EI, UndefValue::get(EI.getType())); + } + return new ExtractElementInst(Src, SrcIdx); + } + } + } + return 0; +} + +/// CollectSingleShuffleElements - If V is a shuffle of values that ONLY returns +/// elements from either LHS or RHS, return the shuffle mask and true. +/// Otherwise, return false. +static bool CollectSingleShuffleElements(Value *V, Value *LHS, Value *RHS, + std::vector &Mask) { + assert(V->getType() == LHS->getType() && V->getType() == RHS->getType() && + "Invalid CollectSingleShuffleElements"); + unsigned NumElts = cast(V->getType())->getNumElements(); + + if (isa(V)) { + Mask.assign(NumElts, UndefValue::get(Type::Int32Ty)); + return true; + } else if (V == LHS) { + for (unsigned i = 0; i != NumElts; ++i) + Mask.push_back(ConstantInt::get(Type::Int32Ty, i)); + return true; + } else if (V == RHS) { + for (unsigned i = 0; i != NumElts; ++i) + Mask.push_back(ConstantInt::get(Type::Int32Ty, i+NumElts)); + return true; + } else if (InsertElementInst *IEI = dyn_cast(V)) { + // If this is an insert of an extract from some other vector, include it. + Value *VecOp = IEI->getOperand(0); + Value *ScalarOp = IEI->getOperand(1); + Value *IdxOp = IEI->getOperand(2); + + if (!isa(IdxOp)) + return false; + unsigned InsertedIdx = cast(IdxOp)->getZExtValue(); + + if (isa(ScalarOp)) { // inserting undef into vector. + // Okay, we can handle this if the vector we are insertinting into is + // transitively ok. + if (CollectSingleShuffleElements(VecOp, LHS, RHS, Mask)) { + // If so, update the mask to reflect the inserted undef. 
+ Mask[InsertedIdx] = UndefValue::get(Type::Int32Ty); + return true; + } + } else if (ExtractElementInst *EI = dyn_cast(ScalarOp)){ + if (isa(EI->getOperand(1)) && + EI->getOperand(0)->getType() == V->getType()) { + unsigned ExtractedIdx = + cast(EI->getOperand(1))->getZExtValue(); + + // This must be extracting from either LHS or RHS. + if (EI->getOperand(0) == LHS || EI->getOperand(0) == RHS) { + // Okay, we can handle this if the vector we are insertinting into is + // transitively ok. + if (CollectSingleShuffleElements(VecOp, LHS, RHS, Mask)) { + // If so, update the mask to reflect the inserted value. + if (EI->getOperand(0) == LHS) { + Mask[InsertedIdx & (NumElts-1)] = + ConstantInt::get(Type::Int32Ty, ExtractedIdx); + } else { + assert(EI->getOperand(0) == RHS); + Mask[InsertedIdx & (NumElts-1)] = + ConstantInt::get(Type::Int32Ty, ExtractedIdx+NumElts); + + } + return true; + } + } + } + } + } + // TODO: Handle shufflevector here! + + return false; +} + +/// CollectShuffleElements - We are building a shuffle of V, using RHS as the +/// RHS of the shuffle instruction, if it is not null. Return a shuffle mask +/// that computes V and the LHS value of the shuffle. +static Value *CollectShuffleElements(Value *V, std::vector &Mask, + Value *&RHS) { + assert(isa(V->getType()) && + (RHS == 0 || V->getType() == RHS->getType()) && + "Invalid shuffle!"); + unsigned NumElts = cast(V->getType())->getNumElements(); + + if (isa(V)) { + Mask.assign(NumElts, UndefValue::get(Type::Int32Ty)); + return V; + } else if (isa(V)) { + Mask.assign(NumElts, ConstantInt::get(Type::Int32Ty, 0)); + return V; + } else if (InsertElementInst *IEI = dyn_cast(V)) { + // If this is an insert of an extract from some other vector, include it. 
+ Value *VecOp = IEI->getOperand(0); + Value *ScalarOp = IEI->getOperand(1); + Value *IdxOp = IEI->getOperand(2); + + if (ExtractElementInst *EI = dyn_cast(ScalarOp)) { + if (isa(EI->getOperand(1)) && isa(IdxOp) && + EI->getOperand(0)->getType() == V->getType()) { + unsigned ExtractedIdx = + cast(EI->getOperand(1))->getZExtValue(); + unsigned InsertedIdx = cast(IdxOp)->getZExtValue(); + + // Either the extracted from or inserted into vector must be RHSVec, + // otherwise we'd end up with a shuffle of three inputs. + if (EI->getOperand(0) == RHS || RHS == 0) { + RHS = EI->getOperand(0); + Value *V = CollectShuffleElements(VecOp, Mask, RHS); + Mask[InsertedIdx & (NumElts-1)] = + ConstantInt::get(Type::Int32Ty, NumElts+ExtractedIdx); + return V; + } + + if (VecOp == RHS) { + Value *V = CollectShuffleElements(EI->getOperand(0), Mask, RHS); + // Everything but the extracted element is replaced with the RHS. + for (unsigned i = 0; i != NumElts; ++i) { + if (i != InsertedIdx) + Mask[i] = ConstantInt::get(Type::Int32Ty, NumElts+i); + } + return V; + } + + // If this insertelement is a chain that comes from exactly these two + // vectors, return the vector and the effective shuffle. + if (CollectSingleShuffleElements(IEI, EI->getOperand(0), RHS, Mask)) + return EI->getOperand(0); + + } + } + } + // TODO: Handle shufflevector here! + + // Otherwise, can't do anything fancy. Return an identity vector. + for (unsigned i = 0; i != NumElts; ++i) + Mask.push_back(ConstantInt::get(Type::Int32Ty, i)); + return V; +} + +Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { + Value *VecOp = IE.getOperand(0); + Value *ScalarOp = IE.getOperand(1); + Value *IdxOp = IE.getOperand(2); + + // Inserting an undef or into an undefined place, remove this. 
+ if (isa(ScalarOp) || isa(IdxOp)) + ReplaceInstUsesWith(IE, VecOp); + + // If the inserted element was extracted from some other vector, and if the + // indexes are constant, try to turn this into a shufflevector operation. + if (ExtractElementInst *EI = dyn_cast(ScalarOp)) { + if (isa(EI->getOperand(1)) && isa(IdxOp) && + EI->getOperand(0)->getType() == IE.getType()) { + unsigned NumVectorElts = IE.getType()->getNumElements(); + unsigned ExtractedIdx = + cast(EI->getOperand(1))->getZExtValue(); + unsigned InsertedIdx = cast(IdxOp)->getZExtValue(); + + if (ExtractedIdx >= NumVectorElts) // Out of range extract. + return ReplaceInstUsesWith(IE, VecOp); + + if (InsertedIdx >= NumVectorElts) // Out of range insert. + return ReplaceInstUsesWith(IE, UndefValue::get(IE.getType())); + + // If we are extracting a value from a vector, then inserting it right + // back into the same place, just use the input vector. + if (EI->getOperand(0) == VecOp && ExtractedIdx == InsertedIdx) + return ReplaceInstUsesWith(IE, VecOp); + + // We could theoretically do this for ANY input. However, doing so could + // turn chains of insertelement instructions into a chain of shufflevector + // instructions, and right now we do not merge shufflevectors. As such, + // only do this in a situation where it is clear that there is benefit. + if (isa(VecOp) || isa(VecOp)) { + // Turn this into shuffle(EIOp0, VecOp, Mask). The result has all of + // the values of VecOp, except then one read from EIOp0. + // Build a new shuffle mask. 
+ std::vector Mask; + if (isa(VecOp)) + Mask.assign(NumVectorElts, UndefValue::get(Type::Int32Ty)); + else { + assert(isa(VecOp) && "Unknown thing"); + Mask.assign(NumVectorElts, ConstantInt::get(Type::Int32Ty, + NumVectorElts)); + } + Mask[InsertedIdx] = ConstantInt::get(Type::Int32Ty, ExtractedIdx); + return new ShuffleVectorInst(EI->getOperand(0), VecOp, + ConstantVector::get(Mask)); + } + + // If this insertelement isn't used by some other insertelement, turn it + // (and any insertelements it points to), into one big shuffle. + if (!IE.hasOneUse() || !isa(IE.use_back())) { + std::vector Mask; + Value *RHS = 0; + Value *LHS = CollectShuffleElements(&IE, Mask, RHS); + if (RHS == 0) RHS = UndefValue::get(LHS->getType()); + // We now have a shuffle of LHS, RHS, Mask. + return new ShuffleVectorInst(LHS, RHS, ConstantVector::get(Mask)); + } + } + } + + return 0; +} + + +Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { + Value *LHS = SVI.getOperand(0); + Value *RHS = SVI.getOperand(1); + std::vector Mask = getShuffleMask(&SVI); + + bool MadeChange = false; + + // Undefined shuffle mask -> undefined value. + if (isa(SVI.getOperand(2))) + return ReplaceInstUsesWith(SVI, UndefValue::get(SVI.getType())); + + // If we have shuffle(x, undef, mask) and any elements of mask refer to + // the undef, change them to undefs. + if (isa(SVI.getOperand(1))) { + // Scan to see if there are any references to the RHS. If so, replace them + // with undef element refs and set MadeChange to true. + for (unsigned i = 0, e = Mask.size(); i != e; ++i) { + if (Mask[i] >= e && Mask[i] != 2*e) { + Mask[i] = 2*e; + MadeChange = true; + } + } + + if (MadeChange) { + // Remap any references to RHS to use LHS. 
+ std::vector Elts; + for (unsigned i = 0, e = Mask.size(); i != e; ++i) { + if (Mask[i] == 2*e) + Elts.push_back(UndefValue::get(Type::Int32Ty)); + else + Elts.push_back(ConstantInt::get(Type::Int32Ty, Mask[i])); + } + SVI.setOperand(2, ConstantVector::get(Elts)); + } + } + + // Canonicalize shuffle(x ,x,mask) -> shuffle(x, undef,mask') + // Canonicalize shuffle(undef,x,mask) -> shuffle(x, undef,mask'). + if (LHS == RHS || isa(LHS)) { + if (isa(LHS) && LHS == RHS) { + // shuffle(undef,undef,mask) -> undef. + return ReplaceInstUsesWith(SVI, LHS); + } + + // Remap any references to RHS to use LHS. + std::vector Elts; + for (unsigned i = 0, e = Mask.size(); i != e; ++i) { + if (Mask[i] >= 2*e) + Elts.push_back(UndefValue::get(Type::Int32Ty)); + else { + if ((Mask[i] >= e && isa(RHS)) || + (Mask[i] < e && isa(LHS))) + Mask[i] = 2*e; // Turn into undef. + else + Mask[i] &= (e-1); // Force to LHS. + Elts.push_back(ConstantInt::get(Type::Int32Ty, Mask[i])); + } + } + SVI.setOperand(0, SVI.getOperand(1)); + SVI.setOperand(1, UndefValue::get(RHS->getType())); + SVI.setOperand(2, ConstantVector::get(Elts)); + LHS = SVI.getOperand(0); + RHS = SVI.getOperand(1); + MadeChange = true; + } + + // Analyze the shuffle, are the LHS or RHS and identity shuffles? + bool isLHSID = true, isRHSID = true; + + for (unsigned i = 0, e = Mask.size(); i != e; ++i) { + if (Mask[i] >= e*2) continue; // Ignore undef values. + // Is this an identity shuffle of the LHS value? + isLHSID &= (Mask[i] == i); + + // Is this an identity shuffle of the RHS value? + isRHSID &= (Mask[i]-e == i); + } + + // Eliminate identity shuffles. + if (isLHSID) return ReplaceInstUsesWith(SVI, LHS); + if (isRHSID) return ReplaceInstUsesWith(SVI, RHS); + + // If the LHS is a shufflevector itself, see if we can combine it with this + // one without producing an unusual shuffle. 
Here we are really conservative: + // we are absolutely afraid of producing a shuffle mask not in the input + // program, because the code gen may not be smart enough to turn a merged + // shuffle into two specific shuffles: it may produce worse code. As such, + // we only merge two shuffles if the result is one of the two input shuffle + // masks. In this case, merging the shuffles just removes one instruction, + // which we know is safe. This is good for things like turning: + // (splat(splat)) -> splat. + if (ShuffleVectorInst *LHSSVI = dyn_cast(LHS)) { + if (isa(RHS)) { + std::vector LHSMask = getShuffleMask(LHSSVI); + + std::vector NewMask; + for (unsigned i = 0, e = Mask.size(); i != e; ++i) + if (Mask[i] >= 2*e) + NewMask.push_back(2*e); + else + NewMask.push_back(LHSMask[Mask[i]]); + + // If the result mask is equal to the src shuffle or this shuffle mask, do + // the replacement. + if (NewMask == LHSMask || NewMask == Mask) { + std::vector Elts; + for (unsigned i = 0, e = NewMask.size(); i != e; ++i) { + if (NewMask[i] >= e*2) { + Elts.push_back(UndefValue::get(Type::Int32Ty)); + } else { + Elts.push_back(ConstantInt::get(Type::Int32Ty, NewMask[i])); + } + } + return new ShuffleVectorInst(LHSSVI->getOperand(0), + LHSSVI->getOperand(1), + ConstantVector::get(Elts)); + } + } + } + + return MadeChange ? &SVI : 0; +} + + + + +/// TryToSinkInstruction - Try to move the specified instruction from its +/// current block into the beginning of DestBlock, which can only happen if it's +/// safe to move the instruction past all of the instructions between it and the +/// end of its block. +static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { + assert(I->hasOneUse() && "Invariants didn't hold!"); + + // Cannot move control-flow-involving, volatile loads, vaarg, etc. + if (isa(I) || I->mayWriteToMemory()) return false; + + // Do not sink alloca instructions out of the entry block. 
+ if (isa(I) && I->getParent() == + &DestBlock->getParent()->getEntryBlock()) + return false; + + // We can only sink load instructions if there is nothing between the load and + // the end of block that could change the value. + if (LoadInst *LI = dyn_cast(I)) { + for (BasicBlock::iterator Scan = LI, E = LI->getParent()->end(); + Scan != E; ++Scan) + if (Scan->mayWriteToMemory()) + return false; + } + + BasicBlock::iterator InsertPos = DestBlock->begin(); + while (isa(InsertPos)) ++InsertPos; + + I->moveBefore(InsertPos); + ++NumSunkInst; + return true; +} + + +/// AddReachableCodeToWorklist - Walk the function in depth-first order, adding +/// all reachable code to the worklist. +/// +/// This has a couple of tricks to make the code faster and more powerful. In +/// particular, we constant fold and DCE instructions as we go, to avoid adding +/// them to the worklist (this significantly speeds up instcombine on code where +/// many instructions are dead or constant). Additionally, if we find a branch +/// whose condition is a known constant, we only visit the reachable successors. +/// +static void AddReachableCodeToWorklist(BasicBlock *BB, + SmallPtrSet &Visited, + InstCombiner &IC, + const TargetData *TD) { + std::vector Worklist; + Worklist.push_back(BB); + + while (!Worklist.empty()) { + BB = Worklist.back(); + Worklist.pop_back(); + + // We have now visited this block! If we've already been here, ignore it. + if (!Visited.insert(BB)) continue; + + for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E; ) { + Instruction *Inst = BBI++; + + // DCE instruction if trivially dead. + if (isInstructionTriviallyDead(Inst)) { + ++NumDeadInst; + DOUT << "IC: DCE: " << *Inst; + Inst->eraseFromParent(); + continue; + } + + // ConstantProp instruction if trivially constant. 
+ if (Constant *C = ConstantFoldInstruction(Inst, TD)) { + DOUT << "IC: ConstFold to: " << *C << " from: " << *Inst; + Inst->replaceAllUsesWith(C); + ++NumConstProp; + Inst->eraseFromParent(); + continue; + } + + IC.AddToWorkList(Inst); + } + + // Recursively visit successors. If this is a branch or switch on a + // constant, only visit the reachable successor. + TerminatorInst *TI = BB->getTerminator(); + if (BranchInst *BI = dyn_cast(TI)) { + if (BI->isConditional() && isa(BI->getCondition())) { + bool CondVal = cast(BI->getCondition())->getZExtValue(); + Worklist.push_back(BI->getSuccessor(!CondVal)); + continue; + } + } else if (SwitchInst *SI = dyn_cast(TI)) { + if (ConstantInt *Cond = dyn_cast(SI->getCondition())) { + // See if this is an explicit destination. + for (unsigned i = 1, e = SI->getNumSuccessors(); i != e; ++i) + if (SI->getCaseValue(i) == Cond) { + Worklist.push_back(SI->getSuccessor(i)); + continue; + } + + // Otherwise it is the default destination. + Worklist.push_back(SI->getSuccessor(0)); + continue; + } + } + + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + Worklist.push_back(TI->getSuccessor(i)); + } +} + +bool InstCombiner::DoOneIteration(Function &F, unsigned Iteration) { + bool Changed = false; + TD = &getAnalysis(); + + DEBUG(DOUT << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " + << F.getNameStr() << "\n"); + + { + // Do a depth-first traversal of the function, populate the worklist with + // the reachable instructions. Ignore blocks that are not reachable. Keep + // track of which blocks we visit. + SmallPtrSet Visited; + AddReachableCodeToWorklist(F.begin(), Visited, *this, TD); + + // Do a quick scan over the function. If we find any blocks that are + // unreachable, remove any instructions inside of them. This prevents + // the instcombine code from having to deal with some bad special cases. 
+ for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + if (!Visited.count(BB)) { + Instruction *Term = BB->getTerminator(); + while (Term != BB->begin()) { // Remove instrs bottom-up + BasicBlock::iterator I = Term; --I; + + DOUT << "IC: DCE: " << *I; + ++NumDeadInst; + + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->eraseFromParent(); + } + } + } + + while (!Worklist.empty()) { + Instruction *I = RemoveOneFromWorkList(); + if (I == 0) continue; // skip null values. + + // Check to see if we can DCE the instruction. + if (isInstructionTriviallyDead(I)) { + // Add operands to the worklist. + if (I->getNumOperands() < 4) + AddUsesToWorkList(*I); + ++NumDeadInst; + + DOUT << "IC: DCE: " << *I; + + I->eraseFromParent(); + RemoveFromWorkList(I); + continue; + } + + // Instruction isn't dead, see if we can constant propagate it. + if (Constant *C = ConstantFoldInstruction(I, TD)) { + DOUT << "IC: ConstFold to: " << *C << " from: " << *I; + + // Add operands to the worklist. + AddUsesToWorkList(*I); + ReplaceInstUsesWith(*I, C); + + ++NumConstProp; + I->eraseFromParent(); + RemoveFromWorkList(I); + continue; + } + + // See if we can trivially sink this instruction to a successor basic block. + if (I->hasOneUse()) { + BasicBlock *BB = I->getParent(); + BasicBlock *UserParent = cast(I->use_back())->getParent(); + if (UserParent != BB) { + bool UserIsSuccessor = false; + // See if the user is one of our successors. + for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; ++SI) + if (*SI == UserParent) { + UserIsSuccessor = true; + break; + } + + // If the user is one of our immediate successors, and if that successor + // only has us as a predecessors (we'd have to split the critical edge + // otherwise), we can keep going. + if (UserIsSuccessor && !isa(I->use_back()) && + next(pred_begin(UserParent)) == pred_end(UserParent)) + // Okay, the CFG is simple enough, try to sink this instruction. 
+ Changed |= TryToSinkInstruction(I, UserParent); + } + } + + // Now that we have an instruction, try combining it to simplify it... +#ifndef NDEBUG + std::string OrigI; +#endif + DEBUG(std::ostringstream SS; I->print(SS); OrigI = SS.str();); + if (Instruction *Result = visit(*I)) { + ++NumCombined; + // Should we replace the old instruction with a new one? + if (Result != I) { + DOUT << "IC: Old = " << *I + << " New = " << *Result; + + // Everything uses the new instruction now. + I->replaceAllUsesWith(Result); + + // Push the new instruction and any users onto the worklist. + AddToWorkList(Result); + AddUsersToWorkList(*Result); + + // Move the name to the new instruction first. + Result->takeName(I); + + // Insert the new instruction into the basic block... + BasicBlock *InstParent = I->getParent(); + BasicBlock::iterator InsertPos = I; + + if (!isa(Result)) // If combining a PHI, don't insert + while (isa(InsertPos)) // middle of a block of PHIs. + ++InsertPos; + + InstParent->getInstList().insert(InsertPos, Result); + + // Make sure that we reprocess all operands now that we reduced their + // use counts. + AddUsesToWorkList(*I); + + // Instructions can end up on the worklist more than once. Make sure + // we do not process an instruction that has been deleted. + RemoveFromWorkList(I); + + // Erase the old instruction. + InstParent->getInstList().erase(I); + } else { +#ifndef NDEBUG + DOUT << "IC: Mod = " << OrigI + << " New = " << *I; +#endif + + // If the instruction was modified, it's possible that it is now dead. + // if so, remove it. + if (isInstructionTriviallyDead(I)) { + // Make sure we process all operands now that we are reducing their + // use counts. + AddUsesToWorkList(*I); + + // Instructions may end up in the worklist more than once. Erase all + // occurrences of this instruction. 
+ RemoveFromWorkList(I); + I->eraseFromParent(); + } else { + AddToWorkList(I); + AddUsersToWorkList(*I); + } + } + Changed = true; + } + } + + assert(WorklistMap.empty() && "Worklist empty, but map not?"); + return Changed; +} + + +bool InstCombiner::runOnFunction(Function &F) { + MustPreserveLCSSA = mustPreserveAnalysisID(LCSSAID); + + bool EverMadeChange = false; + + // Iterate while there is work to do. + unsigned Iteration = 0; + while (DoOneIteration(F, Iteration++)) + EverMadeChange = true; + return EverMadeChange; +} + +FunctionPass *llvm::createInstructionCombiningPass() { + return new InstCombiner(); +} + diff --git a/lib/Transforms/Scalar/LICM.cpp b/lib/Transforms/Scalar/LICM.cpp new file mode 100644 index 0000000..77ac563 --- /dev/null +++ b/lib/Transforms/Scalar/LICM.cpp @@ -0,0 +1,797 @@ +//===-- LICM.cpp - Loop Invariant Code Motion Pass ------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs loop invariant code motion, attempting to remove as much +// code from the body of a loop as possible. It does this by either hoisting +// code into the preheader block, or by sinking code to the exit blocks if it is +// safe. This pass also promotes must-aliased memory locations in the loop to +// live in registers, thus hoisting and sinking "invariant" loads and stores. +// +// This pass uses alias analysis for two purposes: +// +// 1. Moving loop invariant loads and calls out of loops. If we can determine +// that a load or call inside of a loop never aliases anything stored to, +// we can hoist it or sink it like any other instruction. +// 2. 
Scalar Promotion of Memory - If there is a store instruction inside of +// the loop, we try to move the store to happen AFTER the loop instead of +// inside of the loop. This can only happen if a few conditions are true: +// A. The pointer stored through is loop invariant +// B. There are no stores or loads in the loop which _may_ alias the +// pointer. There are no calls in the loop which mod/ref the pointer. +// If these conditions are true, we can promote the loads and stores in the +// loop of the pointer to use a temporary alloca'd variable. We then use +// the mem2reg functionality to construct the appropriate SSA form for the +// variable. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "licm" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Transforms/Utils/PromoteMemToReg.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/Statistic.h" +#include +using namespace llvm; + +STATISTIC(NumSunk , "Number of instructions sunk out of loop"); +STATISTIC(NumHoisted , "Number of instructions hoisted out of loop"); +STATISTIC(NumMovedLoads, "Number of load insts hoisted or sunk"); +STATISTIC(NumMovedCalls, "Number of call insts hoisted or sunk"); +STATISTIC(NumPromoted , "Number of memory locations promoted to registers"); + +namespace { + cl::opt + DisablePromotion("disable-licm-promotion", cl::Hidden, + cl::desc("Disable memory promotion in LICM pass")); + + struct VISIBILITY_HIDDEN LICM : public LoopPass { + static char ID; // Pass identification, replacement for typeid + 
LICM() : LoopPass((intptr_t)&ID) {} + + virtual bool runOnLoop(Loop *L, LPPassManager &LPM); + + /// This transformation requires natural loop information & requires that + /// loop preheaders be inserted into the CFG... + /// + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequiredID(LoopSimplifyID); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); // For scalar promotion (mem2reg) + AU.addRequired(); + } + + bool doFinalization() { + LoopToAliasMap.clear(); + return false; + } + + private: + // Various analyses that we use... + AliasAnalysis *AA; // Current AliasAnalysis information + LoopInfo *LI; // Current LoopInfo + DominatorTree *DT; // Dominator Tree for the current Loop... + DominanceFrontier *DF; // Current Dominance Frontier + + // State that is updated as we process loops + bool Changed; // Set to true when we change anything. + BasicBlock *Preheader; // The preheader block of the current loop... + Loop *CurLoop; // The current loop we are working on... + AliasSetTracker *CurAST; // AliasSet information for the current loop... + std::map LoopToAliasMap; + + /// SinkRegion - Walk the specified region of the CFG (defined by all blocks + /// dominated by the specified block, and that are in the current loop) in + /// reverse depth first order w.r.t the DominatorTree. This allows us to + /// visit uses before definitions, allowing us to sink a loop body in one + /// pass without iteration. + /// + void SinkRegion(DomTreeNode *N); + + /// HoistRegion - Walk the specified region of the CFG (defined by all + /// blocks dominated by the specified block, and that are in the current + /// loop) in depth first order w.r.t the DominatorTree. This allows us to + /// visit definitions before uses, allowing us to hoist a loop body in one + /// pass without iteration. 
+ /// + void HoistRegion(DomTreeNode *N); + + /// inSubLoop - Little predicate that returns true if the specified basic + /// block is in a subloop of the current one, not the current one itself. + /// + bool inSubLoop(BasicBlock *BB) { + assert(CurLoop->contains(BB) && "Only valid if BB is IN the loop"); + for (Loop::iterator I = CurLoop->begin(), E = CurLoop->end(); I != E; ++I) + if ((*I)->contains(BB)) + return true; // A subloop actually contains this block! + return false; + } + + /// isExitBlockDominatedByBlockInLoop - This method checks to see if the + /// specified exit block of the loop is dominated by the specified block + /// that is in the body of the loop. We use these constraints to + /// dramatically limit the amount of the dominator tree that needs to be + /// searched. + bool isExitBlockDominatedByBlockInLoop(BasicBlock *ExitBlock, + BasicBlock *BlockInLoop) const { + // If the block in the loop is the loop header, it must be dominated! + BasicBlock *LoopHeader = CurLoop->getHeader(); + if (BlockInLoop == LoopHeader) + return true; + + DomTreeNode *BlockInLoopNode = DT->getNode(BlockInLoop); + DomTreeNode *IDom = DT->getNode(ExitBlock); + + // Because the exit block is not in the loop, we know we have to get _at + // least_ its immediate dominator. + do { + // Get next Immediate Dominator. + IDom = IDom->getIDom(); + + // If we have got to the header of the loop, then the instructions block + // did not dominate the exit node, so we can't hoist it. + if (IDom->getBlock() == LoopHeader) + return false; + + } while (IDom != BlockInLoopNode); + + return true; + } + + /// sink - When an instruction is found to only be used outside of the loop, + /// this function moves it to the exit blocks and patches up SSA form as + /// needed. + /// + void sink(Instruction &I); + + /// hoist - When an instruction is found to only use loop invariant operands + /// that is safe to hoist, this instruction is called to do the dirty work. 
+ /// + void hoist(Instruction &I); + + /// isSafeToExecuteUnconditionally - Only sink or hoist an instruction if it + /// is not a trapping instruction or if it is a trapping instruction and is + /// guaranteed to execute. + /// + bool isSafeToExecuteUnconditionally(Instruction &I); + + /// pointerInvalidatedByLoop - Return true if the body of this loop may + /// store into the memory location pointed to by V. + /// + bool pointerInvalidatedByLoop(Value *V, unsigned Size) { + // Check to see if any of the basic blocks in CurLoop invalidate *V. + return CurAST->getAliasSetForPointer(V, Size).isMod(); + } + + bool canSinkOrHoistInst(Instruction &I); + bool isLoopInvariantInst(Instruction &I); + bool isNotUsedInLoop(Instruction &I); + + /// PromoteValuesInLoop - Look at the stores in the loop and promote as many + /// to scalars as we can. + /// + void PromoteValuesInLoop(); + + /// FindPromotableValuesInLoop - Check the current loop for stores to + /// definite pointers, which are not loaded and stored through may aliases. + /// If these are found, create an alloca for the value, add it to the + /// PromotedValues list, and keep track of the mapping from value to + /// alloca... + /// + void FindPromotableValuesInLoop( + std::vector > &PromotedValues, + std::map &Val2AlMap); + }; + + char LICM::ID = 0; + RegisterPass X("licm", "Loop Invariant Code Motion"); +} + +LoopPass *llvm::createLICMPass() { return new LICM(); } + +/// Hoist expressions out of the specified loop... +/// +bool LICM::runOnLoop(Loop *L, LPPassManager &LPM) { + Changed = false; + + // Get our Loop and Alias Analysis information... 
+ LI = &getAnalysis(); + AA = &getAnalysis(); + DF = &getAnalysis(); + DT = &getAnalysis(); + + CurAST = new AliasSetTracker(*AA); + // Collect Alias info from subloops + for (Loop::iterator LoopItr = L->begin(), LoopItrE = L->end(); + LoopItr != LoopItrE; ++LoopItr) { + Loop *InnerL = *LoopItr; + AliasSetTracker *InnerAST = LoopToAliasMap[InnerL]; + assert (InnerAST && "Where is my AST?"); + + // What if InnerLoop was modified by other passes ? + CurAST->add(*InnerAST); + } + + CurLoop = L; + + // Get the preheader block to move instructions into... + Preheader = L->getLoopPreheader(); + assert(Preheader&&"Preheader insertion pass guarantees we have a preheader!"); + + // Loop over the body of this loop, looking for calls, invokes, and stores. + // Because subloops have already been incorporated into AST, we skip blocks in + // subloops. + // + for (std::vector::const_iterator I = L->getBlocks().begin(), + E = L->getBlocks().end(); I != E; ++I) + if (LI->getLoopFor(*I) == L) // Ignore blocks in subloops... + CurAST->add(**I); // Incorporate the specified basic block + + // We want to visit all of the instructions in this loop... that are not parts + // of our subloops (they have already had their invariants hoisted out of + // their loop, into this loop, so there is no need to process the BODIES of + // the subloops). + // + // Traverse the body of the loop in depth first order on the dominator tree so + // that we are guaranteed to see definitions before we see uses. This allows + // us to sink instructions in one pass, without iteration. AFter sinking + // instructions, we perform another pass to hoist them out of the loop. + // + SinkRegion(DT->getNode(L->getHeader())); + HoistRegion(DT->getNode(L->getHeader())); + + // Now that all loop invariants have been removed from the loop, promote any + // memory references to scalars that we can... 
+ if (!DisablePromotion) + PromoteValuesInLoop(); + + // Clear out loops state information for the next iteration + CurLoop = 0; + Preheader = 0; + + LoopToAliasMap[L] = CurAST; + return Changed; +} + +/// SinkRegion - Walk the specified region of the CFG (defined by all blocks +/// dominated by the specified block, and that are in the current loop) in +/// reverse depth first order w.r.t the DominatorTree. This allows us to visit +/// uses before definitions, allowing us to sink a loop body in one pass without +/// iteration. +/// +void LICM::SinkRegion(DomTreeNode *N) { + assert(N != 0 && "Null dominator tree node?"); + BasicBlock *BB = N->getBlock(); + + // If this subregion is not in the top level loop at all, exit. + if (!CurLoop->contains(BB)) return; + + // We are processing blocks in reverse dfo, so process children first... + const std::vector &Children = N->getChildren(); + for (unsigned i = 0, e = Children.size(); i != e; ++i) + SinkRegion(Children[i]); + + // Only need to process the contents of this block if it is not part of a + // subloop (which would already have been processed). + if (inSubLoop(BB)) return; + + for (BasicBlock::iterator II = BB->end(); II != BB->begin(); ) { + Instruction &I = *--II; + + // Check to see if we can sink this instruction to the exit blocks + // of the loop. We can do this if the all users of the instruction are + // outside of the loop. In this case, it doesn't even matter if the + // operands of the instruction are loop invariant. + // + if (isNotUsedInLoop(I) && canSinkOrHoistInst(I)) { + ++II; + sink(I); + } + } +} + + +/// HoistRegion - Walk the specified region of the CFG (defined by all blocks +/// dominated by the specified block, and that are in the current loop) in depth +/// first order w.r.t the DominatorTree. This allows us to visit definitions +/// before uses, allowing us to hoist a loop body in one pass without iteration. 
+/// +void LICM::HoistRegion(DomTreeNode *N) { + assert(N != 0 && "Null dominator tree node?"); + BasicBlock *BB = N->getBlock(); + + // If this subregion is not in the top level loop at all, exit. + if (!CurLoop->contains(BB)) return; + + // Only need to process the contents of this block if it is not part of a + // subloop (which would already have been processed). + if (!inSubLoop(BB)) + for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E; ) { + Instruction &I = *II++; + + // Try hoisting the instruction out to the preheader. We can only do this + // if all of the operands of the instruction are loop invariant and if it + // is safe to hoist the instruction. + // + if (isLoopInvariantInst(I) && canSinkOrHoistInst(I) && + isSafeToExecuteUnconditionally(I)) + hoist(I); + } + + const std::vector &Children = N->getChildren(); + for (unsigned i = 0, e = Children.size(); i != e; ++i) + HoistRegion(Children[i]); +} + +/// canSinkOrHoistInst - Return true if the hoister and sinker can handle this +/// instruction. +/// +bool LICM::canSinkOrHoistInst(Instruction &I) { + // Loads have extra constraints we have to verify before we can hoist them. + if (LoadInst *LI = dyn_cast(&I)) { + if (LI->isVolatile()) + return false; // Don't hoist volatile loads! + + // Don't hoist loads which have may-aliased stores in loop. + unsigned Size = 0; + if (LI->getType()->isSized()) + Size = AA->getTargetData().getTypeSize(LI->getType()); + return !pointerInvalidatedByLoop(LI->getOperand(0), Size); + } else if (CallInst *CI = dyn_cast(&I)) { + // Handle obvious cases efficiently. + if (Function *Callee = CI->getCalledFunction()) { + AliasAnalysis::ModRefBehavior Behavior =AA->getModRefBehavior(Callee, CI); + if (Behavior == AliasAnalysis::DoesNotAccessMemory) + return true; + else if (Behavior == AliasAnalysis::OnlyReadsMemory) { + // If this call only reads from memory and there are no writes to memory + // in the loop, we can hoist or sink the call as appropriate. 
+ bool FoundMod = false; + for (AliasSetTracker::iterator I = CurAST->begin(), E = CurAST->end(); + I != E; ++I) { + AliasSet &AS = *I; + if (!AS.isForwardingAliasSet() && AS.isMod()) { + FoundMod = true; + break; + } + } + if (!FoundMod) return true; + } + } + + // FIXME: This should use mod/ref information to see if we can hoist or sink + // the call. + + return false; + } + + // Otherwise these instructions are hoistable/sinkable + return isa(I) || isa(I) || + isa(I) || isa(I) || isa(I) || + isa(I) || isa(I) || + isa(I); +} + +/// isNotUsedInLoop - Return true if the only users of this instruction are +/// outside of the loop. If this is true, we can sink the instruction to the +/// exit blocks of the loop. +/// +bool LICM::isNotUsedInLoop(Instruction &I) { + for (Value::use_iterator UI = I.use_begin(), E = I.use_end(); UI != E; ++UI) { + Instruction *User = cast(*UI); + if (PHINode *PN = dyn_cast(User)) { + // PHI node uses occur in predecessor blocks! + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == &I) + if (CurLoop->contains(PN->getIncomingBlock(i))) + return false; + } else if (CurLoop->contains(User->getParent())) { + return false; + } + } + return true; +} + + +/// isLoopInvariantInst - Return true if all operands of this instruction are +/// loop invariant. We also filter out non-hoistable instructions here just for +/// efficiency. +/// +bool LICM::isLoopInvariantInst(Instruction &I) { + // The instruction is loop invariant if all of its operands are loop-invariant + for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) + if (!CurLoop->isLoopInvariant(I.getOperand(i))) + return false; + + // If we got this far, the instruction is loop invariant! + return true; +} + +/// sink - When an instruction is found to only be used outside of the loop, +/// this function moves it to the exit blocks and patches up SSA form as needed. 
+/// This method is guaranteed to remove the original instruction from its +/// position, and may either delete it or move it to outside of the loop. +/// +void LICM::sink(Instruction &I) { + DOUT << "LICM sinking instruction: " << I; + + std::vector ExitBlocks; + CurLoop->getExitBlocks(ExitBlocks); + + if (isa(I)) ++NumMovedLoads; + else if (isa(I)) ++NumMovedCalls; + ++NumSunk; + Changed = true; + + // The case where there is only a single exit node of this loop is common + // enough that we handle it as a special (more efficient) case. It is more + // efficient to handle because there are no PHI nodes that need to be placed. + if (ExitBlocks.size() == 1) { + if (!isExitBlockDominatedByBlockInLoop(ExitBlocks[0], I.getParent())) { + // Instruction is not used, just delete it. + CurAST->deleteValue(&I); + if (!I.use_empty()) // If I has users in unreachable blocks, eliminate. + I.replaceAllUsesWith(UndefValue::get(I.getType())); + I.eraseFromParent(); + } else { + // Move the instruction to the start of the exit block, after any PHI + // nodes in it. + I.removeFromParent(); + + BasicBlock::iterator InsertPt = ExitBlocks[0]->begin(); + while (isa(InsertPt)) ++InsertPt; + ExitBlocks[0]->getInstList().insert(InsertPt, &I); + } + } else if (ExitBlocks.size() == 0) { + // The instruction is actually dead if there ARE NO exit blocks. + CurAST->deleteValue(&I); + if (!I.use_empty()) // If I has users in unreachable blocks, eliminate. + I.replaceAllUsesWith(UndefValue::get(I.getType())); + I.eraseFromParent(); + } else { + // Otherwise, if we have multiple exits, use the PromoteMem2Reg function to + // do all of the hard work of inserting PHI nodes as necessary. We convert + // the value into a stack object to get it to do this. + + // Firstly, we create a stack object to hold the value... 
+ AllocaInst *AI = 0; + + if (I.getType() != Type::VoidTy) { + AI = new AllocaInst(I.getType(), 0, I.getName(), + I.getParent()->getParent()->getEntryBlock().begin()); + CurAST->add(AI); + } + + // Secondly, insert load instructions for each use of the instruction + // outside of the loop. + while (!I.use_empty()) { + Instruction *U = cast(I.use_back()); + + // If the user is a PHI Node, we actually have to insert load instructions + // in all predecessor blocks, not in the PHI block itself! + if (PHINode *UPN = dyn_cast(U)) { + // Only insert into each predecessor once, so that we don't have + // different incoming values from the same block! + std::map InsertedBlocks; + for (unsigned i = 0, e = UPN->getNumIncomingValues(); i != e; ++i) + if (UPN->getIncomingValue(i) == &I) { + BasicBlock *Pred = UPN->getIncomingBlock(i); + Value *&PredVal = InsertedBlocks[Pred]; + if (!PredVal) { + // Insert a new load instruction right before the terminator in + // the predecessor block. + PredVal = new LoadInst(AI, "", Pred->getTerminator()); + CurAST->add(cast(PredVal)); + } + + UPN->setIncomingValue(i, PredVal); + } + + } else { + LoadInst *L = new LoadInst(AI, "", U); + U->replaceUsesOfWith(&I, L); + CurAST->add(L); + } + } + + // Thirdly, insert a copy of the instruction in each exit block of the loop + // that is dominated by the instruction, storing the result into the memory + // location. Be careful not to insert the instruction into any particular + // basic block more than once. + std::set InsertedBlocks; + BasicBlock *InstOrigBB = I.getParent(); + + for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) { + BasicBlock *ExitBlock = ExitBlocks[i]; + + if (isExitBlockDominatedByBlockInLoop(ExitBlock, InstOrigBB)) { + // If we haven't already processed this exit block, do so now. + if (InsertedBlocks.insert(ExitBlock).second) { + // Insert the code after the last PHI node... 
+ BasicBlock::iterator InsertPt = ExitBlock->begin(); + while (isa(InsertPt)) ++InsertPt; + + // If this is the first exit block processed, just move the original + // instruction, otherwise clone the original instruction and insert + // the copy. + Instruction *New; + if (InsertedBlocks.size() == 1) { + I.removeFromParent(); + ExitBlock->getInstList().insert(InsertPt, &I); + New = &I; + } else { + New = I.clone(); + CurAST->copyValue(&I, New); + if (!I.getName().empty()) + New->setName(I.getName()+".le"); + ExitBlock->getInstList().insert(InsertPt, New); + } + + // Now that we have inserted the instruction, store it into the alloca + if (AI) new StoreInst(New, AI, InsertPt); + } + } + } + + // If the instruction doesn't dominate any exit blocks, it must be dead. + if (InsertedBlocks.empty()) { + CurAST->deleteValue(&I); + I.eraseFromParent(); + } + + // Finally, promote the fine value to SSA form. + if (AI) { + std::vector Allocas; + Allocas.push_back(AI); + PromoteMemToReg(Allocas, *DT, *DF, CurAST); + } + } +} + +/// hoist - When an instruction is found to only use loop invariant operands +/// that is safe to hoist, this instruction is called to do the dirty work. +/// +void LICM::hoist(Instruction &I) { + DOUT << "LICM hoisting to " << Preheader->getName() << ": " << I; + + // Remove the instruction from its current basic block... but don't delete the + // instruction. + I.removeFromParent(); + + // Insert the new node in Preheader, before the terminator. + Preheader->getInstList().insert(Preheader->getTerminator(), &I); + + if (isa(I)) ++NumMovedLoads; + else if (isa(I)) ++NumMovedCalls; + ++NumHoisted; + Changed = true; +} + +/// isSafeToExecuteUnconditionally - Only sink or hoist an instruction if it is +/// not a trapping instruction or if it is a trapping instruction and is +/// guaranteed to execute. +/// +bool LICM::isSafeToExecuteUnconditionally(Instruction &Inst) { + // If it is not a trapping instruction, it is always safe to hoist. 
+ if (!Inst.isTrapping()) return true; + + // Otherwise we have to check to make sure that the instruction dominates all + // of the exit blocks. If it doesn't, then there is a path out of the loop + // which does not execute this instruction, so we can't hoist it. + + // If the instruction is in the header block for the loop (which is very + // common), it is always guaranteed to dominate the exit blocks. Since this + // is a common case, and can save some work, check it now. + if (Inst.getParent() == CurLoop->getHeader()) + return true; + + // It's always safe to load from a global or alloca. + if (isa(Inst)) + if (isa(Inst.getOperand(0)) || + isa(Inst.getOperand(0))) + return true; + + // Get the exit blocks for the current loop. + std::vector ExitBlocks; + CurLoop->getExitBlocks(ExitBlocks); + + // For each exit block, get the DT node and walk up the DT until the + // instruction's basic block is found or we exit the loop. + for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) + if (!isExitBlockDominatedByBlockInLoop(ExitBlocks[i], Inst.getParent())) + return false; + + return true; +} + + +/// PromoteValuesInLoop - Try to promote memory values to scalars by sinking +/// stores out of the loop and moving loads to before the loop. We do this by +/// looping over the stores in the loop, looking for stores to Must pointers +/// which are loop invariant. We promote these memory locations to use allocas +/// instead. These allocas can easily be raised to register values by the +/// PromoteMem2Reg functionality. +/// +void LICM::PromoteValuesInLoop() { + // PromotedValues - List of values that are promoted out of the loop. Each + // value has an alloca instruction for it, and a canonical version of the + // pointer. + std::vector > PromotedValues; + std::map ValueToAllocaMap; // Map of ptr to alloca + + FindPromotableValuesInLoop(PromotedValues, ValueToAllocaMap); + if (ValueToAllocaMap.empty()) return; // If there are values to promote. 
+ + Changed = true; + NumPromoted += PromotedValues.size(); + + std::vector PointerValueNumbers; + + // Emit a copy from the value into the alloca'd value in the loop preheader + TerminatorInst *LoopPredInst = Preheader->getTerminator(); + for (unsigned i = 0, e = PromotedValues.size(); i != e; ++i) { + Value *Ptr = PromotedValues[i].second; + + // If we are promoting a pointer value, update alias information for the + // inserted load. + Value *LoadValue = 0; + if (isa(cast(Ptr->getType())->getElementType())) { + // Locate a load or store through the pointer, and assign the same value + // to LI as we are loading or storing. Since we know that the value is + // stored in this loop, this will always succeed. + for (Value::use_iterator UI = Ptr->use_begin(), E = Ptr->use_end(); + UI != E; ++UI) + if (LoadInst *LI = dyn_cast(*UI)) { + LoadValue = LI; + break; + } else if (StoreInst *SI = dyn_cast(*UI)) { + if (SI->getOperand(1) == Ptr) { + LoadValue = SI->getOperand(0); + break; + } + } + assert(LoadValue && "No store through the pointer found!"); + PointerValueNumbers.push_back(LoadValue); // Remember this for later. + } + + // Load from the memory we are promoting. + LoadInst *LI = new LoadInst(Ptr, Ptr->getName()+".promoted", LoopPredInst); + + if (LoadValue) CurAST->copyValue(LoadValue, LI); + + // Store into the temporary alloca. + new StoreInst(LI, PromotedValues[i].first, LoopPredInst); + } + + // Scan the basic blocks in the loop, replacing uses of our pointers with + // uses of the allocas in question. + // + const std::vector &LoopBBs = CurLoop->getBlocks(); + for (std::vector::const_iterator I = LoopBBs.begin(), + E = LoopBBs.end(); I != E; ++I) { + // Rewrite all loads and stores in the block of the pointer... 
+ for (BasicBlock::iterator II = (*I)->begin(), E = (*I)->end(); + II != E; ++II) { + if (LoadInst *L = dyn_cast(II)) { + std::map::iterator + I = ValueToAllocaMap.find(L->getOperand(0)); + if (I != ValueToAllocaMap.end()) + L->setOperand(0, I->second); // Rewrite load instruction... + } else if (StoreInst *S = dyn_cast(II)) { + std::map::iterator + I = ValueToAllocaMap.find(S->getOperand(1)); + if (I != ValueToAllocaMap.end()) + S->setOperand(1, I->second); // Rewrite store instruction... + } + } + } + + // Now that the body of the loop uses the allocas instead of the original + // memory locations, insert code to copy the alloca value back into the + // original memory location on all exits from the loop. Note that we only + // want to insert one copy of the code in each exit block, though the loop may + // exit to the same block more than once. + // + std::set ProcessedBlocks; + + std::vector ExitBlocks; + CurLoop->getExitBlocks(ExitBlocks); + for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) + if (ProcessedBlocks.insert(ExitBlocks[i]).second) { + // Copy all of the allocas into their memory locations. + BasicBlock::iterator BI = ExitBlocks[i]->begin(); + while (isa(*BI)) + ++BI; // Skip over all of the phi nodes in the block. + Instruction *InsertPos = BI; + unsigned PVN = 0; + for (unsigned i = 0, e = PromotedValues.size(); i != e; ++i) { + // Load from the alloca. + LoadInst *LI = new LoadInst(PromotedValues[i].first, "", InsertPos); + + // If this is a pointer type, update alias info appropriately. + if (isa(LI->getType())) + CurAST->copyValue(PointerValueNumbers[PVN++], LI); + + // Store into the memory we promoted. + new StoreInst(LI, PromotedValues[i].second, InsertPos); + } + } + + // Now that we have done the deed, use the mem2reg functionality to promote + // all of the new allocas we just created into real SSA registers. 
+ // + std::vector PromotedAllocas; + PromotedAllocas.reserve(PromotedValues.size()); + for (unsigned i = 0, e = PromotedValues.size(); i != e; ++i) + PromotedAllocas.push_back(PromotedValues[i].first); + PromoteMemToReg(PromotedAllocas, *DT, *DF, CurAST); +} + +/// FindPromotableValuesInLoop - Check the current loop for stores to definite +/// pointers, which are not loaded and stored through may aliases. If these are +/// found, create an alloca for the value, add it to the PromotedValues list, +/// and keep track of the mapping from value to alloca. +/// +void LICM::FindPromotableValuesInLoop( + std::vector > &PromotedValues, + std::map &ValueToAllocaMap) { + Instruction *FnStart = CurLoop->getHeader()->getParent()->begin()->begin(); + + // Loop over all of the alias sets in the tracker object. + for (AliasSetTracker::iterator I = CurAST->begin(), E = CurAST->end(); + I != E; ++I) { + AliasSet &AS = *I; + // We can promote this alias set if it has a store, if it is a "Must" alias + // set, if the pointer is loop invariant, and if we are not eliminating any + // volatile loads or stores. + if (!AS.isForwardingAliasSet() && AS.isMod() && AS.isMustAlias() && + !AS.isVolatile() && CurLoop->isLoopInvariant(AS.begin()->first)) { + assert(AS.begin() != AS.end() && + "Must alias set should have at least one pointer element in it!"); + Value *V = AS.begin()->first; + + // Check that all of the pointers in the alias set have the same type. We + // cannot (yet) promote a memory location that is loaded and stored in + // different sizes. + bool PointerOk = true; + for (AliasSet::iterator I = AS.begin(), E = AS.end(); I != E; ++I) + if (V->getType() != I->first->getType()) { + PointerOk = false; + break; + } + + if (PointerOk) { + const Type *Ty = cast(V->getType())->getElementType(); + AllocaInst *AI = new AllocaInst(Ty, 0, V->getName()+".tmp", FnStart); + PromotedValues.push_back(std::make_pair(AI, V)); + + // Update the AST and alias analysis. 
+ CurAST->copyValue(V, AI); + + for (AliasSet::iterator I = AS.begin(), E = AS.end(); I != E; ++I) + ValueToAllocaMap.insert(std::make_pair(I->first, AI)); + + DOUT << "LICM: Promoting value: " << *V << "\n"; + } + } + } +} diff --git a/lib/Transforms/Scalar/LoopRotation.cpp b/lib/Transforms/Scalar/LoopRotation.cpp new file mode 100644 index 0000000..d35a8ed --- /dev/null +++ b/lib/Transforms/Scalar/LoopRotation.cpp @@ -0,0 +1,579 @@ +//===- LoopRotation.cpp - Loop Rotation Pass ------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by Devang Patel and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements Loop Rotation Pass. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "loop-rotate" + +#include "llvm/Transforms/Scalar.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/SmallVector.h" + +using namespace llvm; + +#define MAX_HEADER_SIZE 16 + +STATISTIC(NumRotated, "Number of loops rotated"); +namespace { + + class VISIBILITY_HIDDEN RenameData { + public: + RenameData(Instruction *O, Value *P, Instruction *H) + : Original(O), PreHeader(P), Header(H) { } + public: + Instruction *Original; // Original instruction + Value *PreHeader; // Original pre-header replacement + Instruction *Header; // New header replacement + }; + + class VISIBILITY_HIDDEN LoopRotate : public LoopPass { + + public: + static char ID; // Pass 
ID, replacement for typeid + LoopRotate() : LoopPass((intptr_t)&ID) {} + + // Rotate Loop L as many times as possible. Return true if + // loop is rotated at least once. + bool runOnLoop(Loop *L, LPPassManager &LPM); + + // LCSSA form makes instruction renaming easier. + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredID(LCSSAID); + AU.addPreservedID(LCSSAID); + AU.addPreserved(); + AU.addPreserved(); + AU.addRequiredID(LoopSimplifyID); + AU.addPreservedID(LoopSimplifyID); + } + + // Helper functions + + /// Do actual work + bool rotateLoop(Loop *L, LPPassManager &LPM); + + /// Initialize local data + void initialize(); + + /// Make sure all Exit block PHINodes have required incoming values. + /// If incoming value is constant or defined outside the loop then + /// PHINode may not have an entry for original pre-header. + void updateExitBlock(); + + /// Return true if this instruction is used outside original header. + bool usedOutsideOriginalHeader(Instruction *In); + + /// Find Replacement information for instruction. Return NULL if it is + /// not available. + const RenameData *findReplacementData(Instruction *I); + + /// After loop rotation, loop pre-header has multiple sucessors. + /// Insert one forwarding basic block to ensure that loop pre-header + /// has only one successor. + void preserveCanonicalLoopForm(LPPassManager &LPM); + + private: + + Loop *L; + BasicBlock *OrigHeader; + BasicBlock *OrigPreHeader; + BasicBlock *OrigLatch; + BasicBlock *NewHeader; + BasicBlock *Exit; + LPPassManager *LPM_Ptr; + SmallVector LoopHeaderInfo; + }; + + char LoopRotate::ID = 0; + RegisterPass X ("loop-rotate", "Rotate Loops"); +} + +LoopPass *llvm::createLoopRotatePass() { return new LoopRotate(); } + +/// Rotate Loop L as many times as possible. Return true if +/// loop is rotated at least once. 
+bool LoopRotate::runOnLoop(Loop *Lp, LPPassManager &LPM) { + + bool RotatedOneLoop = false; + initialize(); + LPM_Ptr = &LPM; + + // One loop can be rotated multiple times. + while (rotateLoop(Lp,LPM)) { + RotatedOneLoop = true; + initialize(); + } + + return RotatedOneLoop; +} + +/// Rotate loop LP. Return true if the loop is rotated. +bool LoopRotate::rotateLoop(Loop *Lp, LPPassManager &LPM) { + + L = Lp; + + OrigHeader = L->getHeader(); + OrigPreHeader = L->getLoopPreheader(); + OrigLatch = L->getLoopLatch(); + + // If loop has only one block then there is not much to rotate. + if (L->getBlocks().size() == 1) + return false; + + assert (OrigHeader && OrigLatch && OrigPreHeader && + "Loop is not in canonical form"); + + // If loop header is not one of the loop exit block then + // either this loop is already rotated or it is not + // suitable for loop rotation transformations. + if (!L->isLoopExit(OrigHeader)) + return false; + + BranchInst *BI = dyn_cast(OrigHeader->getTerminator()); + if (!BI) + return false; + assert (BI->isConditional() && "Branch Instruction is not condiitional"); + + // Updating PHInodes in loops with multiple exits adds complexity. + // Keep it simple, and restrict loop rotation to loops with one exit only. + // In future, lift this restriction and support for multiple exits if + // required. + std::vector ExitBlocks; + L->getExitBlocks(ExitBlocks); + if (ExitBlocks.size() > 1) + return false; + + // Check size of original header and reject + // loop if it is very big. + if (OrigHeader->getInstList().size() > MAX_HEADER_SIZE) + return false; + + // Now, this loop is suitable for rotation. + + // Find new Loop header. NewHeader is a Header's one and only successor + // that is inside loop. Header's other successor is out side the + // loop. Otherwise loop is not suitable for rotation. 
+ Exit = BI->getSuccessor(0); + NewHeader = BI->getSuccessor(1); + if (L->contains(Exit)) + std::swap(Exit, NewHeader); + assert (NewHeader && "Unable to determine new loop header"); + assert(L->contains(NewHeader) && !L->contains(Exit) && + "Unable to determine loop header and exit blocks"); + + // Copy PHI nodes and other instructions from original header + // into original pre-header. Unlike original header, original pre-header is + // not a member of loop. + // + // New loop header is one and only successor of original header that + // is inside the loop. All other original header successors are outside + // the loop. Copy PHI Nodes from original header into new loop header. + // Add second incoming value, from original loop pre-header into these phi + // nodes. If a value defined in original header is used outside original + // header then new loop header will need new phi nodes with two incoming + // values, one definition from original header and second definition is + // from original loop pre-header. + + // Remove terminator from Original pre-header. Original pre-header will + // receive a clone of original header terminator as a new terminator. + OrigPreHeader->getInstList().pop_back(); + BasicBlock::iterator I = OrigHeader->begin(), E = OrigHeader->end(); + PHINode *PN = NULL; + for (; (PN = dyn_cast(I)); ++I) { + Instruction *In = I; + + // PHI nodes are not copied into original pre-header. Instead their values + // are directly propagated. + Value * NPV = PN->getIncomingValueForBlock(OrigPreHeader); + + // Create new PHI node with two incoming values for NewHeader. + // One incoming value is from OrigLatch (through OrigHeader) and + // second incoming value is from original pre-header. + PHINode *NH = new PHINode(In->getType(), In->getName()); + NH->addIncoming(PN->getIncomingValueForBlock(OrigLatch), OrigHeader); + NH->addIncoming(NPV, OrigPreHeader); + NewHeader->getInstList().push_front(NH); + + // "In" can be replaced by NH at various places. 
+ LoopHeaderInfo.push_back(RenameData(In, NPV, NH)); + } + + // Now, handle non-phi instructions. + for (; I != E; ++I) { + Instruction *In = I; + + assert (!isa(In) && "PHINode is not expected here"); + // This is not a PHI instruction. Insert its clone into original pre-header. + // If this instruction is using a value from same basic block then + // update it to use value from cloned instruction. + Instruction *C = In->clone(); + C->setName(In->getName()); + OrigPreHeader->getInstList().push_back(C); + + for (unsigned opi = 0, e = In->getNumOperands(); opi != e; ++opi) { + if (Instruction *OpPhi = dyn_cast(In->getOperand(opi))) { + if (const RenameData *D = findReplacementData(OpPhi)) { + // This is using values from original header PHI node. + // Here, directly used incoming value from original pre-header. + C->setOperand(opi, D->PreHeader); + } + } + else if (Instruction *OpInsn = + dyn_cast(In->getOperand(opi))) { + if (const RenameData *D = findReplacementData(OpInsn)) + C->setOperand(opi, D->PreHeader); + } + } + + + // If this instruction is used outside this basic block then + // create new PHINode for this instruction. + Instruction *NewHeaderReplacement = NULL; + if (usedOutsideOriginalHeader(In)) { + PHINode *PN = new PHINode(In->getType(), In->getName()); + PN->addIncoming(In, OrigHeader); + PN->addIncoming(C, OrigPreHeader); + NewHeader->getInstList().push_front(PN); + NewHeaderReplacement = PN; + } + + // "In" can be replaced by NPH or NH at various places. + LoopHeaderInfo.push_back(RenameData(In, C, NewHeaderReplacement)); + } + + // Rename uses of original header instructions to reflect their new + // definitions (either from original pre-header node or from newly created + // new header PHINodes. + // + // Original header instructions are used in + // 1) Original header: + // + // If instruction is used in non-phi instructions then it is using + // defintion from original heder iteself. 
Do not replace this use + // with definition from new header or original pre-header. + // + // If instruction is used in phi node then it is an incoming + // value. Rename its use to reflect new definition from new-preheader + // or new header. + // + // 2) Inside loop but not in original header + // + // Replace this use to reflect definition from new header. + for(unsigned LHI = 0, LHI_E = LoopHeaderInfo.size(); LHI != LHI_E; ++LHI) { + const RenameData &ILoopHeaderInfo = LoopHeaderInfo[LHI]; + + if (!ILoopHeaderInfo.Header) + continue; + + Instruction *OldPhi = ILoopHeaderInfo.Original; + Instruction *NewPhi = ILoopHeaderInfo.Header; + + // Before replacing uses, collect them first, so that iterator is + // not invalidated. + SmallVector AllUses; + for (Value::use_iterator UI = OldPhi->use_begin(), UE = OldPhi->use_end(); + UI != UE; ++UI) { + Instruction *U = cast(UI); + AllUses.push_back(U); + } + + for (SmallVector::iterator UI = AllUses.begin(), + UE = AllUses.end(); UI != UE; ++UI) { + Instruction *U = *UI; + BasicBlock *Parent = U->getParent(); + + // Used inside original header + if (Parent == OrigHeader) { + // Do not rename uses inside original header non-phi instructions. + PHINode *PU = dyn_cast(U); + if (!PU) + continue; + + // Do not rename uses inside original header phi nodes, if the + // incoming value is for new header. + if (PU->getBasicBlockIndex(NewHeader) != -1 + && PU->getIncomingValueForBlock(NewHeader) == U) + continue; + + U->replaceUsesOfWith(OldPhi, NewPhi); + continue; + } + + // Used inside loop, but not in original header. + if (L->contains(U->getParent())) { + if (U != NewPhi) + U->replaceUsesOfWith(OldPhi, NewPhi); + continue; + } + + // Used inside Exit Block. Since we are in LCSSA form, U must be PHINode. + if (U->getParent() == Exit) { + assert (isa(U) && "Use in Exit Block that is not PHINode"); + + PHINode *UPhi = cast(U); + // UPhi already has one incoming argument from original header. 
+ // Add second incoming argument from new Pre header. + UPhi->addIncoming(ILoopHeaderInfo.PreHeader, OrigPreHeader); + } else { + // Used outside Exit block. Create a new PHI node from exit block + // to receive value from ne new header ane pre header. + PHINode *PN = new PHINode(U->getType(), U->getName()); + PN->addIncoming(ILoopHeaderInfo.PreHeader, OrigPreHeader); + PN->addIncoming(OldPhi, OrigHeader); + Exit->getInstList().push_front(PN); + U->replaceUsesOfWith(OldPhi, PN); + } + } + } + + /// Make sure all Exit block PHINodes have required incoming values. + updateExitBlock(); + + // Update CFG + + // Removing incoming branch from loop preheader to original header. + // Now original header is inside the loop. + for (BasicBlock::iterator I = OrigHeader->begin(), E = OrigHeader->end(); + I != E; ++I) { + Instruction *In = I; + PHINode *PN = dyn_cast(In); + if (!PN) + break; + + PN->removeIncomingValue(OrigPreHeader); + } + + // Make NewHeader as the new header for the loop. + L->moveToHeader(NewHeader); + + preserveCanonicalLoopForm(LPM); + + NumRotated++; + return true; +} + +/// Make sure all Exit block PHINodes have required incoming values. +/// If incoming value is constant or defined outside the loop then +/// PHINode may not have an entry for original pre-header. +void LoopRotate::updateExitBlock() { + + for (BasicBlock::iterator I = Exit->begin(), E = Exit->end(); + I != E; ++I) { + + PHINode *PN = dyn_cast(I); + if (!PN) + break; + + // There is already one incoming value from original pre-header block. 
+ if (PN->getBasicBlockIndex(OrigPreHeader) != -1) + continue; + + const RenameData *ILoopHeaderInfo; + Value *V = PN->getIncomingValueForBlock(OrigHeader); + if (isa(V) && + (ILoopHeaderInfo = findReplacementData(cast(V)))) { + assert(ILoopHeaderInfo->PreHeader && "Missing New Preheader Instruction"); + PN->addIncoming(ILoopHeaderInfo->PreHeader, OrigPreHeader); + } else { + PN->addIncoming(V, OrigPreHeader); + } + } +} + +/// Initialize local data +void LoopRotate::initialize() { + L = NULL; + OrigHeader = NULL; + OrigPreHeader = NULL; + NewHeader = NULL; + Exit = NULL; + + LoopHeaderInfo.clear(); +} + +/// Return true if this instruction is used by any instructions in the loop that +/// aren't in original header. +bool LoopRotate::usedOutsideOriginalHeader(Instruction *In) { + + for (Value::use_iterator UI = In->use_begin(), UE = In->use_end(); + UI != UE; ++UI) { + Instruction *U = cast(UI); + if (U->getParent() != OrigHeader) { + if (L->contains(U->getParent())) + return true; + } + } + + return false; +} + +/// Find Replacement information for instruction. Return NULL if it is +/// not available. +const RenameData *LoopRotate::findReplacementData(Instruction *In) { + + // Since LoopHeaderInfo is small, linear walk is OK. + for(unsigned LHI = 0, LHI_E = LoopHeaderInfo.size(); LHI != LHI_E; ++LHI) { + const RenameData &ILoopHeaderInfo = LoopHeaderInfo[LHI]; + if (ILoopHeaderInfo.Original == In) + return &ILoopHeaderInfo; + } + return NULL; +} + +/// After loop rotation, loop pre-header has multiple sucessors. +/// Insert one forwarding basic block to ensure that loop pre-header +/// has only one successor. +void LoopRotate::preserveCanonicalLoopForm(LPPassManager &LPM) { + + // Right now original pre-header has two successors, new header and + // exit block. Insert new block between original pre-header and + // new header such that loop's new pre-header has only one successor. 
+ BasicBlock *NewPreHeader = new BasicBlock("bb.nph", OrigHeader->getParent(), + NewHeader); + LoopInfo &LI = LPM.getAnalysis(); + if (Loop *PL = LI.getLoopFor(OrigPreHeader)) + PL->addBasicBlockToLoop(NewPreHeader, LI); + new BranchInst(NewHeader, NewPreHeader); + + BranchInst *OrigPH_BI = cast(OrigPreHeader->getTerminator()); + if (OrigPH_BI->getSuccessor(0) == NewHeader) + OrigPH_BI->setSuccessor(0, NewPreHeader); + else { + assert (OrigPH_BI->getSuccessor(1) == NewHeader && + "Unexpected original pre-header terminator"); + OrigPH_BI->setSuccessor(1, NewPreHeader); + } + + for (BasicBlock::iterator I = NewHeader->begin(), E = NewHeader->end(); + I != E; ++I) { + Instruction *In = I; + PHINode *PN = dyn_cast(In); + if (!PN) + break; + + int index = PN->getBasicBlockIndex(OrigPreHeader); + assert (index != -1 && "Expected incoming value from Original PreHeader"); + PN->setIncomingBlock(index, NewPreHeader); + assert (PN->getBasicBlockIndex(OrigPreHeader) == -1 && + "Expected only one incoming value from Original PreHeader"); + } + + if (DominatorTree *DT = getAnalysisToUpdate()) { + DT->addNewBlock(NewPreHeader, OrigPreHeader); + DT->changeImmediateDominator(L->getHeader(), NewPreHeader); + DT->changeImmediateDominator(Exit, OrigPreHeader); + for (Loop::block_iterator BI = L->block_begin(), BE = L->block_end(); + BI != BE; ++BI) { + BasicBlock *B = *BI; + if (L->getHeader() != B) { + DomTreeNode *Node = DT->getNode(B); + if (Node && Node->getBlock() == OrigHeader) + DT->changeImmediateDominator(*BI, L->getHeader()); + } + } + DT->changeImmediateDominator(OrigHeader, OrigLatch); + } + + if(DominanceFrontier *DF = getAnalysisToUpdate()) { + + // New Preheader's dominance frontier is Exit block. 
+ DominanceFrontier::DomSetType NewPHSet; + NewPHSet.insert(Exit); + DF->addBasicBlock(NewPreHeader, NewPHSet); + + // New Header's dominance frontier now includes itself and Exit block + DominanceFrontier::iterator HeadI = DF->find(L->getHeader()); + if (HeadI != DF->end()) { + DominanceFrontier::DomSetType & HeaderSet = HeadI->second; + HeaderSet.clear(); + HeaderSet.insert(L->getHeader()); + HeaderSet.insert(Exit); + } else { + DominanceFrontier::DomSetType HeaderSet; + HeaderSet.insert(L->getHeader()); + HeaderSet.insert(Exit); + DF->addBasicBlock(L->getHeader(), HeaderSet); + } + + // Original header (new Loop Latch)'s dominance frontier is Exit. + DominanceFrontier::iterator LatchI = DF->find(L->getLoopLatch()); + if (LatchI != DF->end()) { + DominanceFrontier::DomSetType &LatchSet = LatchI->second; + LatchSet = LatchI->second; + LatchSet.clear(); + LatchSet.insert(Exit); + } else { + DominanceFrontier::DomSetType LatchSet; + LatchSet.insert(Exit); + DF->addBasicBlock(L->getHeader(), LatchSet); + } + + // If a loop block dominates new loop latch then its frontier is + // new header and Exit. + BasicBlock *NewLatch = L->getLoopLatch(); + DominatorTree *DT = getAnalysisToUpdate(); + for (Loop::block_iterator BI = L->block_begin(), BE = L->block_end(); + BI != BE; ++BI) { + BasicBlock *B = *BI; + if (DT->dominates(B, NewLatch)) { + DominanceFrontier::iterator BDFI = DF->find(B); + if (BDFI != DF->end()) { + DominanceFrontier::DomSetType &BSet = BDFI->second; + BSet = BDFI->second; + BSet.clear(); + BSet.insert(L->getHeader()); + BSet.insert(Exit); + } else { + DominanceFrontier::DomSetType BSet; + BSet.insert(L->getHeader()); + BSet.insert(Exit); + DF->addBasicBlock(B, BSet); + } + } + } + } + + // Preserve canonical loop form, which means Exit block should + // have only one predecessor. + BasicBlock *NExit = SplitEdge(L->getLoopLatch(), Exit, this); + + // Preserve LCSSA. 
+ BasicBlock::iterator I = Exit->begin(), E = Exit->end(); + PHINode *PN = NULL; + for (; (PN = dyn_cast(I)); ++I) { + PHINode *NewPN = new PHINode(PN->getType(), PN->getName()); + unsigned N = PN->getNumIncomingValues(); + for (unsigned index = 0; index < N; ++index) + if (PN->getIncomingBlock(index) == NExit) { + NewPN->addIncoming(PN->getIncomingValue(index), L->getLoopLatch()); + PN->setIncomingValue(index, NewPN); + PN->setIncomingBlock(index, NExit); + NExit->getInstList().push_front(NewPN); + } + } + + assert (NewHeader && L->getHeader() == NewHeader + && "Invalid loop header after loop rotation"); + assert (NewPreHeader && L->getLoopPreheader() == NewPreHeader + && "Invalid loop preheader after loop rotation"); + assert (L->getLoopLatch() + && "Invalid loop latch after loop rotation"); + +} diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp new file mode 100644 index 0000000..9689c12 --- /dev/null +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -0,0 +1,1504 @@ +//===- LoopStrengthReduce.cpp - Strength Reduce GEPs in Loops -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by Nate Begeman and is distributed under the +// University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs a strength reduction on array references inside loops that +// have as one or more of their components the loop induction variable. This is +// accomplished by creating a new Value to hold the initial value of the array +// access for the first iteration, and then creating a new GEP instruction in +// the loop to increment the value by the appropriate amount. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "loop-reduce" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Type.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Target/TargetData.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Target/TargetLowering.h" +#include +#include +using namespace llvm; + +STATISTIC(NumReduced , "Number of GEPs strength reduced"); +STATISTIC(NumInserted, "Number of PHIs inserted"); +STATISTIC(NumVariable, "Number of PHIs with variable strides"); + +namespace { + + struct BasedUser; + + /// IVStrideUse - Keep track of one use of a strided induction variable, where + /// the stride is stored externally. The Offset member keeps track of the + /// offset from the IV, User is the actual user of the operand, and 'Operand' + /// is the operand # of the User that is the use. + struct VISIBILITY_HIDDEN IVStrideUse { + SCEVHandle Offset; + Instruction *User; + Value *OperandValToReplace; + + // isUseOfPostIncrementedValue - True if this should use the + // post-incremented version of this IV, not the preincremented version. + // This can only be set in special cases, such as the terminating setcc + // instruction for a loop or uses dominated by the loop. 
+ bool isUseOfPostIncrementedValue; + + IVStrideUse(const SCEVHandle &Offs, Instruction *U, Value *O) + : Offset(Offs), User(U), OperandValToReplace(O), + isUseOfPostIncrementedValue(false) {} + }; + + /// IVUsersOfOneStride - This structure keeps track of all instructions that + /// have an operand that is based on the trip count multiplied by some stride. + /// The stride for all of these users is common and kept external to this + /// structure. + struct VISIBILITY_HIDDEN IVUsersOfOneStride { + /// Users - Keep track of all of the users of this stride as well as the + /// initial value and the operand that uses the IV. + std::vector Users; + + void addUser(const SCEVHandle &Offset,Instruction *User, Value *Operand) { + Users.push_back(IVStrideUse(Offset, User, Operand)); + } + }; + + /// IVInfo - This structure keeps track of one IV expression inserted during + /// StrengthReduceStridedIVUsers. It contains the stride, the common base, as + /// well as the PHI node and increment value created for rewrite. + struct VISIBILITY_HIDDEN IVExpr { + SCEVHandle Stride; + SCEVHandle Base; + PHINode *PHI; + Value *IncV; + + IVExpr() + : Stride(SCEVUnknown::getIntegerSCEV(0, Type::Int32Ty)), + Base (SCEVUnknown::getIntegerSCEV(0, Type::Int32Ty)) {} + IVExpr(const SCEVHandle &stride, const SCEVHandle &base, PHINode *phi, + Value *incv) + : Stride(stride), Base(base), PHI(phi), IncV(incv) {} + }; + + /// IVsOfOneStride - This structure keeps track of all IV expression inserted + /// during StrengthReduceStridedIVUsers for a particular stride of the IV. 
+ struct VISIBILITY_HIDDEN IVsOfOneStride { + std::vector IVs; + + void addIV(const SCEVHandle &Stride, const SCEVHandle &Base, PHINode *PHI, + Value *IncV) { + IVs.push_back(IVExpr(Stride, Base, PHI, IncV)); + } + }; + + class VISIBILITY_HIDDEN LoopStrengthReduce : public LoopPass { + LoopInfo *LI; + DominatorTree *DT; + ScalarEvolution *SE; + const TargetData *TD; + const Type *UIntPtrTy; + bool Changed; + + /// IVUsesByStride - Keep track of all uses of induction variables that we + /// are interested in. The key of the map is the stride of the access. + std::map IVUsesByStride; + + /// IVsByStride - Keep track of all IVs that have been inserted for a + /// particular stride. + std::map IVsByStride; + + /// StrideOrder - An ordering of the keys in IVUsesByStride that is stable: + /// We use this to iterate over the IVUsesByStride collection without being + /// dependent on random ordering of pointers in the process. + std::vector StrideOrder; + + /// CastedValues - As we need to cast values to uintptr_t, this keeps track + /// of the casted version of each value. This is accessed by + /// getCastedVersionOf. + std::map CastedPointers; + + /// DeadInsts - Keep track of instructions we may have made dead, so that + /// we can remove them after we are done working. + std::set DeadInsts; + + /// TLI - Keep a pointer of a TargetLowering to consult for determining + /// transformation profitability. + const TargetLowering *TLI; + + public: + static char ID; // Pass ID, replacement for typeid + LoopStrengthReduce(const TargetLowering *tli = NULL) : + LoopPass((intptr_t)&ID), TLI(tli) { + } + + bool runOnLoop(Loop *L, LPPassManager &LPM); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + // We split critical edges, so we change the CFG. However, we do update + // many analyses if they are around. 
+ AU.addPreservedID(LoopSimplifyID); + AU.addPreserved(); + AU.addPreserved(); + AU.addPreserved(); + + AU.addRequiredID(LoopSimplifyID); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + } + + /// getCastedVersionOf - Return the specified value casted to uintptr_t. + /// + Value *getCastedVersionOf(Instruction::CastOps opcode, Value *V); +private: + bool AddUsersIfInteresting(Instruction *I, Loop *L, + std::set &Processed); + SCEVHandle GetExpressionSCEV(Instruction *E, Loop *L); + + void OptimizeIndvars(Loop *L); + bool FindIVForUser(ICmpInst *Cond, IVStrideUse *&CondUse, + const SCEVHandle *&CondStride); + + unsigned CheckForIVReuse(const SCEVHandle&, IVExpr&, const Type*, + const std::vector& UsersToProcess); + + bool ValidStride(int64_t, const std::vector& UsersToProcess); + + void StrengthReduceStridedIVUsers(const SCEVHandle &Stride, + IVUsersOfOneStride &Uses, + Loop *L, bool isOnlyStride); + void DeleteTriviallyDeadInstructions(std::set &Insts); + }; + char LoopStrengthReduce::ID = 0; + RegisterPass X("loop-reduce", "Loop Strength Reduction"); +} + +LoopPass *llvm::createLoopStrengthReducePass(const TargetLowering *TLI) { + return new LoopStrengthReduce(TLI); +} + +/// getCastedVersionOf - Return the specified value casted to uintptr_t. This +/// assumes that the Value* V is of integer or pointer type only. +/// +Value *LoopStrengthReduce::getCastedVersionOf(Instruction::CastOps opcode, + Value *V) { + if (V->getType() == UIntPtrTy) return V; + if (Constant *CB = dyn_cast(V)) + return ConstantExpr::getCast(opcode, CB, UIntPtrTy); + + Value *&New = CastedPointers[V]; + if (New) return New; + + New = SCEVExpander::InsertCastOfTo(opcode, V, UIntPtrTy); + DeadInsts.insert(cast(New)); + return New; +} + + +/// DeleteTriviallyDeadInstructions - If any of the instructions is the +/// specified set are trivially dead, delete them and see if this makes any of +/// their operands subsequently dead. 
+void LoopStrengthReduce:: +DeleteTriviallyDeadInstructions(std::set &Insts) { + while (!Insts.empty()) { + Instruction *I = *Insts.begin(); + Insts.erase(Insts.begin()); + if (isInstructionTriviallyDead(I)) { + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (Instruction *U = dyn_cast(I->getOperand(i))) + Insts.insert(U); + SE->deleteValueFromRecords(I); + I->eraseFromParent(); + Changed = true; + } + } +} + + +/// GetExpressionSCEV - Compute and return the SCEV for the specified +/// instruction. +SCEVHandle LoopStrengthReduce::GetExpressionSCEV(Instruction *Exp, Loop *L) { + // Pointer to pointer bitcast instructions return the same value as their + // operand. + if (BitCastInst *BCI = dyn_cast(Exp)) { + if (SE->hasSCEV(BCI) || !isa(BCI->getOperand(0))) + return SE->getSCEV(BCI); + SCEVHandle R = GetExpressionSCEV(cast(BCI->getOperand(0)), L); + SE->setSCEV(BCI, R); + return R; + } + + // Scalar Evolutions doesn't know how to compute SCEV's for GEP instructions. + // If this is a GEP that SE doesn't know about, compute it now and insert it. + // If this is not a GEP, or if we have already done this computation, just let + // SE figure it out. + GetElementPtrInst *GEP = dyn_cast(Exp); + if (!GEP || SE->hasSCEV(GEP)) + return SE->getSCEV(Exp); + + // Analyze all of the subscripts of this getelementptr instruction, looking + // for uses that are determined by the trip count of L. First, skip all + // operands the are not dependent on the IV. + + // Build up the base expression. Insert an LLVM cast of the pointer to + // uintptr_t first. + SCEVHandle GEPVal = SCEVUnknown::get( + getCastedVersionOf(Instruction::PtrToInt, GEP->getOperand(0))); + + gep_type_iterator GTI = gep_type_begin(GEP); + + for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) { + // If this is a use of a recurrence that we can analyze, and it comes before + // Op does in the GEP operand list, we will handle this when we process this + // operand. 
+ if (const StructType *STy = dyn_cast(*GTI)) { + const StructLayout *SL = TD->getStructLayout(STy); + unsigned Idx = cast(GEP->getOperand(i))->getZExtValue(); + uint64_t Offset = SL->getElementOffset(Idx); + GEPVal = SCEVAddExpr::get(GEPVal, + SCEVUnknown::getIntegerSCEV(Offset, UIntPtrTy)); + } else { + unsigned GEPOpiBits = + GEP->getOperand(i)->getType()->getPrimitiveSizeInBits(); + unsigned IntPtrBits = UIntPtrTy->getPrimitiveSizeInBits(); + Instruction::CastOps opcode = (GEPOpiBits < IntPtrBits ? + Instruction::SExt : (GEPOpiBits > IntPtrBits ? Instruction::Trunc : + Instruction::BitCast)); + Value *OpVal = getCastedVersionOf(opcode, GEP->getOperand(i)); + SCEVHandle Idx = SE->getSCEV(OpVal); + + uint64_t TypeSize = TD->getTypeSize(GTI.getIndexedType()); + if (TypeSize != 1) + Idx = SCEVMulExpr::get(Idx, + SCEVConstant::get(ConstantInt::get(UIntPtrTy, + TypeSize))); + GEPVal = SCEVAddExpr::get(GEPVal, Idx); + } + } + + SE->setSCEV(GEP, GEPVal); + return GEPVal; +} + +/// getSCEVStartAndStride - Compute the start and stride of this expression, +/// returning false if the expression is not a start/stride pair, or true if it +/// is. The stride must be a loop invariant expression, but the start may be +/// a mix of loop invariant and loop variant expressions. +static bool getSCEVStartAndStride(const SCEVHandle &SH, Loop *L, + SCEVHandle &Start, SCEVHandle &Stride) { + SCEVHandle TheAddRec = Start; // Initialize to zero. + + // If the outer level is an AddExpr, the operands are all start values except + // for a nested AddRecExpr. + if (SCEVAddExpr *AE = dyn_cast(SH)) { + for (unsigned i = 0, e = AE->getNumOperands(); i != e; ++i) + if (SCEVAddRecExpr *AddRec = + dyn_cast(AE->getOperand(i))) { + if (AddRec->getLoop() == L) + TheAddRec = SCEVAddExpr::get(AddRec, TheAddRec); + else + return false; // Nested IV of some sort? 
+ } else { + Start = SCEVAddExpr::get(Start, AE->getOperand(i)); + } + + } else if (isa(SH)) { + TheAddRec = SH; + } else { + return false; // not analyzable. + } + + SCEVAddRecExpr *AddRec = dyn_cast(TheAddRec); + if (!AddRec || AddRec->getLoop() != L) return false; + + // FIXME: Generalize to non-affine IV's. + if (!AddRec->isAffine()) return false; + + Start = SCEVAddExpr::get(Start, AddRec->getOperand(0)); + + if (!isa(AddRec->getOperand(1))) + DOUT << "[" << L->getHeader()->getName() + << "] Variable stride: " << *AddRec << "\n"; + + Stride = AddRec->getOperand(1); + return true; +} + +/// IVUseShouldUsePostIncValue - We have discovered a "User" of an IV expression +/// and now we need to decide whether the user should use the preinc or post-inc +/// value. If this user should use the post-inc version of the IV, return true. +/// +/// Choosing wrong here can break dominance properties (if we choose to use the +/// post-inc value when we cannot) or it can end up adding extra live-ranges to +/// the loop, resulting in reg-reg copies (if we use the pre-inc value when we +/// should use the post-inc value). +static bool IVUseShouldUsePostIncValue(Instruction *User, Instruction *IV, + Loop *L, DominatorTree *DT, Pass *P) { + // If the user is in the loop, use the preinc value. + if (L->contains(User->getParent())) return false; + + BasicBlock *LatchBlock = L->getLoopLatch(); + + // Ok, the user is outside of the loop. If it is dominated by the latch + // block, use the post-inc value. + if (DT->dominates(LatchBlock, User->getParent())) + return true; + + // There is one case we have to be careful of: PHI nodes. These little guys + // can live in blocks that do not dominate the latch block, but (since their + // uses occur in the predecessor block, not the block the PHI lives in) should + // still use the post-inc value. Check for this case now. + PHINode *PN = dyn_cast(User); + if (!PN) return false; // not a phi, not dominated by latch block. 
+ + // Look at all of the uses of IV by the PHI node. If any use corresponds to + // a block that is not dominated by the latch block, give up and use the + // preincremented value. + unsigned NumUses = 0; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == IV) { + ++NumUses; + if (!DT->dominates(LatchBlock, PN->getIncomingBlock(i))) + return false; + } + + // Okay, all uses of IV by PN are in predecessor blocks that really are + // dominated by the latch block. Split the critical edges and use the + // post-incremented value. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == IV) { + SplitCriticalEdge(PN->getIncomingBlock(i), PN->getParent(), P, + true); + // Splitting the critical edge can reduce the number of entries in this + // PHI. + e = PN->getNumIncomingValues(); + if (--NumUses == 0) break; + } + + return true; +} + + + +/// AddUsersIfInteresting - Inspect the specified instruction. If it is a +/// reducible SCEV, recursively add its users to the IVUsesByStride set and +/// return true. Otherwise, return false. +bool LoopStrengthReduce::AddUsersIfInteresting(Instruction *I, Loop *L, + std::set &Processed) { + if (!I->getType()->isInteger() && !isa(I->getType())) + return false; // Void and FP expressions cannot be reduced. + if (!Processed.insert(I).second) + return true; // Instruction already handled. + + // Get the symbolic expression for this instruction. + SCEVHandle ISE = GetExpressionSCEV(I, L); + if (isa(ISE)) return false; + + // Get the start and stride for this expression. + SCEVHandle Start = SCEVUnknown::getIntegerSCEV(0, ISE->getType()); + SCEVHandle Stride = Start; + if (!getSCEVStartAndStride(ISE, L, Start, Stride)) + return false; // Non-reducible symbolic expression, bail out. + + std::vector IUsers; + // Collect all I uses now because IVUseShouldUsePostIncValue may + // invalidate use_iterator. 
+ for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); UI != E; ++UI) + IUsers.push_back(cast(*UI)); + + for (unsigned iused_index = 0, iused_size = IUsers.size(); + iused_index != iused_size; ++iused_index) { + + Instruction *User = IUsers[iused_index]; + + // Do not infinitely recurse on PHI nodes. + if (isa(User) && Processed.count(User)) + continue; + + // If this is an instruction defined in a nested loop, or outside this loop, + // don't recurse into it. + bool AddUserToIVUsers = false; + if (LI->getLoopFor(User->getParent()) != L) { + DOUT << "FOUND USER in other loop: " << *User + << " OF SCEV: " << *ISE << "\n"; + AddUserToIVUsers = true; + } else if (!AddUsersIfInteresting(User, L, Processed)) { + DOUT << "FOUND USER: " << *User + << " OF SCEV: " << *ISE << "\n"; + AddUserToIVUsers = true; + } + + if (AddUserToIVUsers) { + IVUsersOfOneStride &StrideUses = IVUsesByStride[Stride]; + if (StrideUses.Users.empty()) // First occurance of this stride? + StrideOrder.push_back(Stride); + + // Okay, we found a user that we cannot reduce. Analyze the instruction + // and decide what to do with it. If we are a use inside of the loop, use + // the value before incrementation, otherwise use it after incrementation. + if (IVUseShouldUsePostIncValue(User, I, L, DT, this)) { + // The value used will be incremented by the stride more than we are + // expecting, so subtract this off. + SCEVHandle NewStart = SCEV::getMinusSCEV(Start, Stride); + StrideUses.addUser(NewStart, User, I); + StrideUses.Users.back().isUseOfPostIncrementedValue = true; + DOUT << " USING POSTINC SCEV, START=" << *NewStart<< "\n"; + } else { + StrideUses.addUser(Start, User, I); + } + } + } + return true; +} + +namespace { + /// BasedUser - For a particular base value, keep information about how we've + /// partitioned the expression so far. + struct BasedUser { + /// Base - The Base value for the PHI node that needs to be inserted for + /// this use. 
As the use is processed, information gets moved from this + /// field to the Imm field (below). BasedUser values are sorted by this + /// field. + SCEVHandle Base; + + /// Inst - The instruction using the induction variable. + Instruction *Inst; + + /// OperandValToReplace - The operand value of Inst to replace with the + /// EmittedBase. + Value *OperandValToReplace; + + /// Imm - The immediate value that should be added to the base immediately + /// before Inst, because it will be folded into the imm field of the + /// instruction. + SCEVHandle Imm; + + /// EmittedBase - The actual value* to use for the base value of this + /// operation. This is null if we should just use zero so far. + Value *EmittedBase; + + // isUseOfPostIncrementedValue - True if this should use the + // post-incremented version of this IV, not the preincremented version. + // This can only be set in special cases, such as the terminating setcc + // instruction for a loop and uses outside the loop that are dominated by + // the loop. + bool isUseOfPostIncrementedValue; + + BasedUser(IVStrideUse &IVSU) + : Base(IVSU.Offset), Inst(IVSU.User), + OperandValToReplace(IVSU.OperandValToReplace), + Imm(SCEVUnknown::getIntegerSCEV(0, Base->getType())), EmittedBase(0), + isUseOfPostIncrementedValue(IVSU.isUseOfPostIncrementedValue) {} + + // Once we rewrite the code to insert the new IVs we want, update the + // operands of Inst to use the new expression 'NewBase', with 'Imm' added + // to it. 
+ void RewriteInstructionToUseNewBase(const SCEVHandle &NewBase, + SCEVExpander &Rewriter, Loop *L, + Pass *P); + + Value *InsertCodeForBaseAtPosition(const SCEVHandle &NewBase, + SCEVExpander &Rewriter, + Instruction *IP, Loop *L); + void dump() const; + }; +} + +void BasedUser::dump() const { + cerr << " Base=" << *Base; + cerr << " Imm=" << *Imm; + if (EmittedBase) + cerr << " EB=" << *EmittedBase; + + cerr << " Inst: " << *Inst; +} + +Value *BasedUser::InsertCodeForBaseAtPosition(const SCEVHandle &NewBase, + SCEVExpander &Rewriter, + Instruction *IP, Loop *L) { + // Figure out where we *really* want to insert this code. In particular, if + // the user is inside of a loop that is nested inside of L, we really don't + // want to insert this expression before the user, we'd rather pull it out as + // many loops as possible. + LoopInfo &LI = Rewriter.getLoopInfo(); + Instruction *BaseInsertPt = IP; + + // Figure out the most-nested loop that IP is in. + Loop *InsertLoop = LI.getLoopFor(IP->getParent()); + + // If InsertLoop is not L, and InsertLoop is nested inside of L, figure out + // the preheader of the outer-most loop where NewBase is not loop invariant. + while (InsertLoop && NewBase->isLoopInvariant(InsertLoop)) { + BaseInsertPt = InsertLoop->getLoopPreheader()->getTerminator(); + InsertLoop = InsertLoop->getParentLoop(); + } + + // If there is no immediate value, skip the next part. + if (SCEVConstant *SC = dyn_cast(Imm)) + if (SC->getValue()->isZero()) + return Rewriter.expandCodeFor(NewBase, BaseInsertPt); + + Value *Base = Rewriter.expandCodeFor(NewBase, BaseInsertPt); + + // If we are inserting the base and imm values in the same block, make sure to + // adjust the IP position if insertion reused a result. + if (IP == BaseInsertPt) + IP = Rewriter.getInsertionPoint(); + + // Always emit the immediate (if non-zero) into the same block as the user. 
+ SCEVHandle NewValSCEV = SCEVAddExpr::get(SCEVUnknown::get(Base), Imm); + return Rewriter.expandCodeFor(NewValSCEV, IP); + +} + + +// Once we rewrite the code to insert the new IVs we want, update the +// operands of Inst to use the new expression 'NewBase', with 'Imm' added +// to it. +void BasedUser::RewriteInstructionToUseNewBase(const SCEVHandle &NewBase, + SCEVExpander &Rewriter, + Loop *L, Pass *P) { + if (!isa(Inst)) { + // By default, insert code at the user instruction. + BasicBlock::iterator InsertPt = Inst; + + // However, if the Operand is itself an instruction, the (potentially + // complex) inserted code may be shared by many users. Because of this, we + // want to emit code for the computation of the operand right before its old + // computation. This is usually safe, because we obviously used to use the + // computation when it was computed in its current block. However, in some + // cases (e.g. use of a post-incremented induction variable) the NewBase + // value will be pinned to live somewhere after the original computation. + // In this case, we have to back off. + if (!isUseOfPostIncrementedValue) { + if (Instruction *OpInst = dyn_cast(OperandValToReplace)) { + InsertPt = OpInst; + while (isa(InsertPt)) ++InsertPt; + } + } + Value *NewVal = InsertCodeForBaseAtPosition(NewBase, Rewriter, InsertPt, L); + // Adjust the type back to match the Inst. + if (isa(OperandValToReplace->getType())) { + NewVal = new IntToPtrInst(NewVal, OperandValToReplace->getType(), "cast", + InsertPt); + } + // Replace the use of the operand Value with the new Phi we just created. + Inst->replaceUsesOfWith(OperandValToReplace, NewVal); + DOUT << " CHANGED: IMM =" << *Imm; + DOUT << " \tNEWBASE =" << *NewBase; + DOUT << " \tInst = " << *Inst; + return; + } + + // PHI nodes are more complex. We have to insert one copy of the NewBase+Imm + // expression into each operand block that uses it. Note that PHI nodes can + // have multiple entries for the same predecessor. 
We use a map to make sure + // that a PHI node only has a single Value* for each predecessor (which also + // prevents us from inserting duplicate code in some blocks). + std::map InsertedCode; + PHINode *PN = cast(Inst); + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + if (PN->getIncomingValue(i) == OperandValToReplace) { + // If this is a critical edge, split the edge so that we do not insert the + // code on all predecessor/successor paths. We do this unless this is the + // canonical backedge for this loop, as this can make some inserted code + // be in an illegal position. + BasicBlock *PHIPred = PN->getIncomingBlock(i); + if (e != 1 && PHIPred->getTerminator()->getNumSuccessors() > 1 && + (PN->getParent() != L->getHeader() || !L->contains(PHIPred))) { + + // First step, split the critical edge. + SplitCriticalEdge(PHIPred, PN->getParent(), P, true); + + // Next step: move the basic block. In particular, if the PHI node + // is outside of the loop, and PredTI is in the loop, we want to + // move the block to be immediately before the PHI block, not + // immediately after PredTI. + if (L->contains(PHIPred) && !L->contains(PN->getParent())) { + BasicBlock *NewBB = PN->getIncomingBlock(i); + NewBB->moveBefore(PN->getParent()); + } + + // Splitting the edge can reduce the number of PHI entries we have. + e = PN->getNumIncomingValues(); + } + + Value *&Code = InsertedCode[PN->getIncomingBlock(i)]; + if (!Code) { + // Insert the code into the end of the predecessor block. + Instruction *InsertPt = PN->getIncomingBlock(i)->getTerminator(); + Code = InsertCodeForBaseAtPosition(NewBase, Rewriter, InsertPt, L); + + // Adjust the type back to match the PHI. + if (isa(PN->getType())) { + Code = new IntToPtrInst(Code, PN->getType(), "cast", InsertPt); + } + } + + // Replace the use of the operand Value with the new Phi we just created. 
+ PN->setIncomingValue(i, Code); + Rewriter.clear(); + } + } + DOUT << " CHANGED: IMM =" << *Imm << " Inst = " << *Inst; +} + + +/// isTargetConstant - Return true if the following can be referenced by the +/// immediate field of a target instruction. +static bool isTargetConstant(const SCEVHandle &V, const Type *UseTy, + const TargetLowering *TLI) { + if (SCEVConstant *SC = dyn_cast(V)) { + int64_t VC = SC->getValue()->getSExtValue(); + if (TLI) { + TargetLowering::AddrMode AM; + AM.BaseOffs = VC; + return TLI->isLegalAddressingMode(AM, UseTy); + } else { + // Defaults to PPC. PPC allows a sign-extended 16-bit immediate field. + return (VC > -(1 << 16) && VC < (1 << 16)-1); + } + } + + if (SCEVUnknown *SU = dyn_cast(V)) + if (ConstantExpr *CE = dyn_cast(SU->getValue())) + if (TLI && CE->getOpcode() == Instruction::PtrToInt) { + Constant *Op0 = CE->getOperand(0); + if (GlobalValue *GV = dyn_cast(Op0)) { + TargetLowering::AddrMode AM; + AM.BaseGV = GV; + return TLI->isLegalAddressingMode(AM, UseTy); + } + } + return false; +} + +/// MoveLoopVariantsToImediateField - Move any subexpressions from Val that are +/// loop varying to the Imm operand. +static void MoveLoopVariantsToImediateField(SCEVHandle &Val, SCEVHandle &Imm, + Loop *L) { + if (Val->isLoopInvariant(L)) return; // Nothing to do. + + if (SCEVAddExpr *SAE = dyn_cast(Val)) { + std::vector NewOps; + NewOps.reserve(SAE->getNumOperands()); + + for (unsigned i = 0; i != SAE->getNumOperands(); ++i) + if (!SAE->getOperand(i)->isLoopInvariant(L)) { + // If this is a loop-variant expression, it must stay in the immediate + // field of the expression. + Imm = SCEVAddExpr::get(Imm, SAE->getOperand(i)); + } else { + NewOps.push_back(SAE->getOperand(i)); + } + + if (NewOps.empty()) + Val = SCEVUnknown::getIntegerSCEV(0, Val->getType()); + else + Val = SCEVAddExpr::get(NewOps); + } else if (SCEVAddRecExpr *SARE = dyn_cast(Val)) { + // Try to pull immediates out of the start value of nested addrec's. 
+ SCEVHandle Start = SARE->getStart(); + MoveLoopVariantsToImediateField(Start, Imm, L); + + std::vector Ops(SARE->op_begin(), SARE->op_end()); + Ops[0] = Start; + Val = SCEVAddRecExpr::get(Ops, SARE->getLoop()); + } else { + // Otherwise, all of Val is variant, move the whole thing over. + Imm = SCEVAddExpr::get(Imm, Val); + Val = SCEVUnknown::getIntegerSCEV(0, Val->getType()); + } +} + + +/// MoveImmediateValues - Look at Val, and pull out any additions of constants +/// that can fit into the immediate field of instructions in the target. +/// Accumulate these immediate values into the Imm value. +static void MoveImmediateValues(const TargetLowering *TLI, + Instruction *User, + SCEVHandle &Val, SCEVHandle &Imm, + bool isAddress, Loop *L) { + const Type *UseTy = User->getType(); + if (StoreInst *SI = dyn_cast(User)) + UseTy = SI->getOperand(0)->getType(); + + if (SCEVAddExpr *SAE = dyn_cast(Val)) { + std::vector NewOps; + NewOps.reserve(SAE->getNumOperands()); + + for (unsigned i = 0; i != SAE->getNumOperands(); ++i) { + SCEVHandle NewOp = SAE->getOperand(i); + MoveImmediateValues(TLI, User, NewOp, Imm, isAddress, L); + + if (!NewOp->isLoopInvariant(L)) { + // If this is a loop-variant expression, it must stay in the immediate + // field of the expression. + Imm = SCEVAddExpr::get(Imm, NewOp); + } else { + NewOps.push_back(NewOp); + } + } + + if (NewOps.empty()) + Val = SCEVUnknown::getIntegerSCEV(0, Val->getType()); + else + Val = SCEVAddExpr::get(NewOps); + return; + } else if (SCEVAddRecExpr *SARE = dyn_cast(Val)) { + // Try to pull immediates out of the start value of nested addrec's. 
+ SCEVHandle Start = SARE->getStart(); + MoveImmediateValues(TLI, User, Start, Imm, isAddress, L); + + if (Start != SARE->getStart()) { + std::vector Ops(SARE->op_begin(), SARE->op_end()); + Ops[0] = Start; + Val = SCEVAddRecExpr::get(Ops, SARE->getLoop()); + } + return; + } else if (SCEVMulExpr *SME = dyn_cast(Val)) { + // Transform "8 * (4 + v)" -> "32 + 8*V" if "32" fits in the immed field. + if (isAddress && isTargetConstant(SME->getOperand(0), UseTy, TLI) && + SME->getNumOperands() == 2 && SME->isLoopInvariant(L)) { + + SCEVHandle SubImm = SCEVUnknown::getIntegerSCEV(0, Val->getType()); + SCEVHandle NewOp = SME->getOperand(1); + MoveImmediateValues(TLI, User, NewOp, SubImm, isAddress, L); + + // If we extracted something out of the subexpressions, see if we can + // simplify this! + if (NewOp != SME->getOperand(1)) { + // Scale SubImm up by "8". If the result is a target constant, we are + // good. + SubImm = SCEVMulExpr::get(SubImm, SME->getOperand(0)); + if (isTargetConstant(SubImm, UseTy, TLI)) { + // Accumulate the immediate. + Imm = SCEVAddExpr::get(Imm, SubImm); + + // Update what is left of 'Val'. + Val = SCEVMulExpr::get(SME->getOperand(0), NewOp); + return; + } + } + } + } + + // Loop-variant expressions must stay in the immediate field of the + // expression. + if ((isAddress && isTargetConstant(Val, UseTy, TLI)) || + !Val->isLoopInvariant(L)) { + Imm = SCEVAddExpr::get(Imm, Val); + Val = SCEVUnknown::getIntegerSCEV(0, Val->getType()); + return; + } + + // Otherwise, no immediates to move. +} + + +/// SeparateSubExprs - Decompose Expr into all of the subexpressions that are +/// added together. This is used to reassociate common addition subexprs +/// together for maximal sharing when rewriting bases. 
+static void SeparateSubExprs(std::vector &SubExprs, + SCEVHandle Expr) { + if (SCEVAddExpr *AE = dyn_cast(Expr)) { + for (unsigned j = 0, e = AE->getNumOperands(); j != e; ++j) + SeparateSubExprs(SubExprs, AE->getOperand(j)); + } else if (SCEVAddRecExpr *SARE = dyn_cast(Expr)) { + SCEVHandle Zero = SCEVUnknown::getIntegerSCEV(0, Expr->getType()); + if (SARE->getOperand(0) == Zero) { + SubExprs.push_back(Expr); + } else { + // Compute the addrec with zero as its base. + std::vector Ops(SARE->op_begin(), SARE->op_end()); + Ops[0] = Zero; // Start with zero base. + SubExprs.push_back(SCEVAddRecExpr::get(Ops, SARE->getLoop())); + + + SeparateSubExprs(SubExprs, SARE->getOperand(0)); + } + } else if (!isa(Expr) || + !cast(Expr)->getValue()->isZero()) { + // Do not add zero. + SubExprs.push_back(Expr); + } +} + + +/// RemoveCommonExpressionsFromUseBases - Look through all of the uses in Bases, +/// removing any common subexpressions from it. Anything truly common is +/// removed, accumulated, and returned. This looks for things like (a+b+c) and +/// (a+c+d) -> (a+c). The common expression is *removed* from the Bases. +static SCEVHandle +RemoveCommonExpressionsFromUseBases(std::vector &Uses) { + unsigned NumUses = Uses.size(); + + // Only one use? Use its base, regardless of what it is! + SCEVHandle Zero = SCEVUnknown::getIntegerSCEV(0, Uses[0].Base->getType()); + SCEVHandle Result = Zero; + if (NumUses == 1) { + std::swap(Result, Uses[0].Base); + return Result; + } + + // To find common subexpressions, count how many of Uses use each expression. + // If any subexpressions are used Uses.size() times, they are common. + std::map SubExpressionUseCounts; + + // UniqueSubExprs - Keep track of all of the subexpressions we see in the + // order we see them. + std::vector UniqueSubExprs; + + std::vector SubExprs; + for (unsigned i = 0; i != NumUses; ++i) { + // If the base is zero (which is common), return zero now, there are no + // CSEs we can find. 
+ if (Uses[i].Base == Zero) return Zero; + + // Split the expression into subexprs. + SeparateSubExprs(SubExprs, Uses[i].Base); + // Add one to SubExpressionUseCounts for each subexpr present. + for (unsigned j = 0, e = SubExprs.size(); j != e; ++j) + if (++SubExpressionUseCounts[SubExprs[j]] == 1) + UniqueSubExprs.push_back(SubExprs[j]); + SubExprs.clear(); + } + + // Now that we know how many times each is used, build Result. Iterate over + // UniqueSubexprs so that we have a stable ordering. + for (unsigned i = 0, e = UniqueSubExprs.size(); i != e; ++i) { + std::map::iterator I = + SubExpressionUseCounts.find(UniqueSubExprs[i]); + assert(I != SubExpressionUseCounts.end() && "Entry not found?"); + if (I->second == NumUses) { // Found CSE! + Result = SCEVAddExpr::get(Result, I->first); + } else { + // Remove non-cse's from SubExpressionUseCounts. + SubExpressionUseCounts.erase(I); + } + } + + // If we found no CSE's, return now. + if (Result == Zero) return Result; + + // Otherwise, remove all of the CSE's we found from each of the base values. + for (unsigned i = 0; i != NumUses; ++i) { + // Split the expression into subexprs. + SeparateSubExprs(SubExprs, Uses[i].Base); + + // Remove any common subexpressions. + for (unsigned j = 0, e = SubExprs.size(); j != e; ++j) + if (SubExpressionUseCounts.count(SubExprs[j])) { + SubExprs.erase(SubExprs.begin()+j); + --j; --e; + } + + // Finally, the non-shared expressions together. + if (SubExprs.empty()) + Uses[i].Base = Zero; + else + Uses[i].Base = SCEVAddExpr::get(SubExprs); + SubExprs.clear(); + } + + return Result; +} + +/// isZero - returns true if the scalar evolution expression is zero. +/// +static bool isZero(SCEVHandle &V) { + if (SCEVConstant *SC = dyn_cast(V)) + return SC->getValue()->isZero(); + return false; +} + +/// ValidStride - Check whether the given Scale is valid for all loads and +/// stores in UsersToProcess. 
+/// +bool LoopStrengthReduce::ValidStride(int64_t Scale, + const std::vector& UsersToProcess) { + for (unsigned i=0, e = UsersToProcess.size(); i!=e; ++i) { + // If this is a load or other access, pass the type of the access in. + const Type *AccessTy = Type::VoidTy; + if (StoreInst *SI = dyn_cast(UsersToProcess[i].Inst)) + AccessTy = SI->getOperand(0)->getType(); + else if (LoadInst *LI = dyn_cast(UsersToProcess[i].Inst)) + AccessTy = LI->getType(); + + TargetLowering::AddrMode AM; + if (SCEVConstant *SC = dyn_cast(UsersToProcess[i].Imm)) + AM.BaseOffs = SC->getValue()->getSExtValue(); + AM.Scale = Scale; + + // If load[imm+r*scale] is illegal, bail out. + if (!TLI->isLegalAddressingMode(AM, AccessTy)) + return false; + } + return true; +} + +/// CheckForIVReuse - Returns the multiple if the stride is the multiple +/// of a previous stride and it is a legal value for the target addressing +/// mode scale component. This allows the users of this stride to be rewritten +/// as prev iv * factor. It returns 0 if no reuse is possible. +unsigned LoopStrengthReduce::CheckForIVReuse(const SCEVHandle &Stride, + IVExpr &IV, const Type *Ty, + const std::vector& UsersToProcess) { + if (!TLI) return 0; + + if (SCEVConstant *SC = dyn_cast(Stride)) { + int64_t SInt = SC->getValue()->getSExtValue(); + if (SInt == 1) return 0; + + for (std::map::iterator SI= IVsByStride.begin(), + SE = IVsByStride.end(); SI != SE; ++SI) { + int64_t SSInt = cast(SI->first)->getValue()->getSExtValue(); + if (SInt != -SSInt && + (unsigned(abs(SInt)) < SSInt || (SInt % SSInt) != 0)) + continue; + int64_t Scale = SInt / SSInt; + // Check that this stride is valid for all the types used for loads and + // stores; if it can be used for some and not others, we might as well use + // the original stride everywhere, since we have to create the IV for it + // anyway. 
+ if (ValidStride(Scale, UsersToProcess)) + for (std::vector::iterator II = SI->second.IVs.begin(), + IE = SI->second.IVs.end(); II != IE; ++II) + // FIXME: Only handle base == 0 for now. + // Only reuse previous IV if it would not require a type conversion. + if (isZero(II->Base) && II->Base->getType() == Ty) { + IV = *II; + return Scale; + } + } + } + return 0; +} + +/// PartitionByIsUseOfPostIncrementedValue - Simple boolean predicate that +/// returns true if Val's isUseOfPostIncrementedValue is true. +static bool PartitionByIsUseOfPostIncrementedValue(const BasedUser &Val) { + return Val.isUseOfPostIncrementedValue; +} + +/// isNonConstantNegative - REturn true if the specified scev is negated, but +/// not a constant. +static bool isNonConstantNegative(const SCEVHandle &Expr) { + SCEVMulExpr *Mul = dyn_cast(Expr); + if (!Mul) return false; + + // If there is a constant factor, it will be first. + SCEVConstant *SC = dyn_cast(Mul->getOperand(0)); + if (!SC) return false; + + // Return true if the value is negative, this matches things like (-42 * V). + return SC->getValue()->getValue().isNegative(); +} + +/// StrengthReduceStridedIVUsers - Strength reduce all of the users of a single +/// stride of IV. All of the users may have different starting values, and this +/// may not be the only stride (we know it is if isOnlyStride is true). +void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride, + IVUsersOfOneStride &Uses, + Loop *L, + bool isOnlyStride) { + // Transform our list of users and offsets to a bit more complex table. In + // this new vector, each 'BasedUser' contains 'Base' the base of the + // strided accessas well as the old information from Uses. We progressively + // move information from the Base field to the Imm field, until we eventually + // have the full access expression to rewrite the use. 
+ std::vector UsersToProcess; + UsersToProcess.reserve(Uses.Users.size()); + for (unsigned i = 0, e = Uses.Users.size(); i != e; ++i) { + UsersToProcess.push_back(Uses.Users[i]); + + // Move any loop invariant operands from the offset field to the immediate + // field of the use, so that we don't try to use something before it is + // computed. + MoveLoopVariantsToImediateField(UsersToProcess.back().Base, + UsersToProcess.back().Imm, L); + assert(UsersToProcess.back().Base->isLoopInvariant(L) && + "Base value is not loop invariant!"); + } + + // We now have a whole bunch of uses of like-strided induction variables, but + // they might all have different bases. We want to emit one PHI node for this + // stride which we fold as many common expressions (between the IVs) into as + // possible. Start by identifying the common expressions in the base values + // for the strides (e.g. if we have "A+C+B" and "A+B+D" as our bases, find + // "A+B"), emit it to the preheader, then remove the expression from the + // UsersToProcess base values. + SCEVHandle CommonExprs = + RemoveCommonExpressionsFromUseBases(UsersToProcess); + + // Next, figure out what we can represent in the immediate fields of + // instructions. If we can represent anything there, move it to the imm + // fields of the BasedUsers. We do this so that it increases the commonality + // of the remaining uses. + for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i) { + // If the user is not in the current loop, this means it is using the exit + // value of the IV. Do not put anything in the base, make sure it's all in + // the immediate field to allow as much factoring as possible. + if (!L->contains(UsersToProcess[i].Inst->getParent())) { + UsersToProcess[i].Imm = SCEVAddExpr::get(UsersToProcess[i].Imm, + UsersToProcess[i].Base); + UsersToProcess[i].Base = + SCEVUnknown::getIntegerSCEV(0, UsersToProcess[i].Base->getType()); + } else { + + // Addressing modes can be folded into loads and stores. 
Be careful that + // the store is through the expression, not of the expression though. + bool isAddress = isa(UsersToProcess[i].Inst); + if (StoreInst *SI = dyn_cast(UsersToProcess[i].Inst)) { + if (SI->getOperand(1) == UsersToProcess[i].OperandValToReplace) + isAddress = true; + } else if (IntrinsicInst *II = + dyn_cast(UsersToProcess[i].Inst)) { + // Addressing modes can also be folded into prefetches. + if (II->getIntrinsicID() == Intrinsic::prefetch && + II->getOperand(1) == UsersToProcess[i].OperandValToReplace) + isAddress = true; + } + + MoveImmediateValues(TLI, UsersToProcess[i].Inst, UsersToProcess[i].Base, + UsersToProcess[i].Imm, isAddress, L); + } + } + + // Check if it is possible to reuse a IV with stride that is factor of this + // stride. And the multiple is a number that can be encoded in the scale + // field of the target addressing mode. And we will have a valid + // instruction after this substition, including the immediate field, if any. + PHINode *NewPHI = NULL; + Value *IncV = NULL; + IVExpr ReuseIV; + unsigned RewriteFactor = CheckForIVReuse(Stride, ReuseIV, + CommonExprs->getType(), + UsersToProcess); + if (RewriteFactor != 0) { + DOUT << "BASED ON IV of STRIDE " << *ReuseIV.Stride + << " and BASE " << *ReuseIV.Base << " :\n"; + NewPHI = ReuseIV.PHI; + IncV = ReuseIV.IncV; + } + + const Type *ReplacedTy = CommonExprs->getType(); + + // Now that we know what we need to do, insert the PHI node itself. + // + DOUT << "INSERTING IV of TYPE " << *ReplacedTy << " of STRIDE " + << *Stride << " and BASE " << *CommonExprs << ": "; + + SCEVExpander Rewriter(*SE, *LI); + SCEVExpander PreheaderRewriter(*SE, *LI); + + BasicBlock *Preheader = L->getLoopPreheader(); + Instruction *PreInsertPt = Preheader->getTerminator(); + Instruction *PhiInsertBefore = L->getHeader()->begin(); + + BasicBlock *LatchBlock = L->getLoopLatch(); + + + // Emit the initial base value into the loop preheader. 
+ Value *CommonBaseV + = PreheaderRewriter.expandCodeFor(CommonExprs, PreInsertPt); + + if (RewriteFactor == 0) { + // Create a new Phi for this base, and stick it in the loop header. + NewPHI = new PHINode(ReplacedTy, "iv.", PhiInsertBefore); + ++NumInserted; + + // Add common base to the new Phi node. + NewPHI->addIncoming(CommonBaseV, Preheader); + + // If the stride is negative, insert a sub instead of an add for the + // increment. + bool isNegative = isNonConstantNegative(Stride); + SCEVHandle IncAmount = Stride; + if (isNegative) + IncAmount = SCEV::getNegativeSCEV(Stride); + + // Insert the stride into the preheader. + Value *StrideV = PreheaderRewriter.expandCodeFor(IncAmount, PreInsertPt); + if (!isa(StrideV)) ++NumVariable; + + // Emit the increment of the base value before the terminator of the loop + // latch block, and add it to the Phi node. + SCEVHandle IncExp = SCEVUnknown::get(StrideV); + if (isNegative) + IncExp = SCEV::getNegativeSCEV(IncExp); + IncExp = SCEVAddExpr::get(SCEVUnknown::get(NewPHI), IncExp); + + IncV = Rewriter.expandCodeFor(IncExp, LatchBlock->getTerminator()); + IncV->setName(NewPHI->getName()+".inc"); + NewPHI->addIncoming(IncV, LatchBlock); + + // Remember this in case a later stride is multiple of this. + IVsByStride[Stride].addIV(Stride, CommonExprs, NewPHI, IncV); + + DOUT << " IV=%" << NewPHI->getNameStr() << " INC=%" << IncV->getNameStr(); + } else { + Constant *C = dyn_cast(CommonBaseV); + if (!C || + (!C->isNullValue() && + !isTargetConstant(SCEVUnknown::get(CommonBaseV), ReplacedTy, TLI))) + // We want the common base emitted into the preheader! This is just + // using cast as a copy so BitCast (no-op cast) is appropriate + CommonBaseV = new BitCastInst(CommonBaseV, CommonBaseV->getType(), + "commonbase", PreInsertPt); + } + DOUT << "\n"; + + // We want to emit code for users inside the loop first. 
To do this, we + // rearrange BasedUser so that the entries at the end have + // isUseOfPostIncrementedValue = false, because we pop off the end of the + // vector (so we handle them first). + std::partition(UsersToProcess.begin(), UsersToProcess.end(), + PartitionByIsUseOfPostIncrementedValue); + + // Sort this by base, so that things with the same base are handled + // together. By partitioning first and stable-sorting later, we are + // guaranteed that within each base we will pop off users from within the + // loop before users outside of the loop with a particular base. + // + // We would like to use stable_sort here, but we can't. The problem is that + // SCEVHandle's don't have a deterministic ordering w.r.t to each other, so + // we don't have anything to do a '<' comparison on. Because we think the + // number of uses is small, do a horrible bubble sort which just relies on + // ==. + for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i) { + // Get a base value. + SCEVHandle Base = UsersToProcess[i].Base; + + // Compact everything with this base to be consequetive with this one. + for (unsigned j = i+1; j != e; ++j) { + if (UsersToProcess[j].Base == Base) { + std::swap(UsersToProcess[i+1], UsersToProcess[j]); + ++i; + } + } + } + + // Process all the users now. This outer loop handles all bases, the inner + // loop handles all users of a particular base. + while (!UsersToProcess.empty()) { + SCEVHandle Base = UsersToProcess.back().Base; + + // Emit the code for Base into the preheader. + Value *BaseV = PreheaderRewriter.expandCodeFor(Base, PreInsertPt); + + DOUT << " INSERTING code for BASE = " << *Base << ":"; + if (BaseV->hasName()) + DOUT << " Result value name = %" << BaseV->getNameStr(); + DOUT << "\n"; + + // If BaseV is a constant other than 0, make sure that it gets inserted into + // the preheader, instead of being forward substituted into the uses. 
We do + // this by forcing a BitCast (noop cast) to be inserted into the preheader + // in this case. + if (Constant *C = dyn_cast(BaseV)) { + if (!C->isNullValue() && !isTargetConstant(Base, ReplacedTy, TLI)) { + // We want this constant emitted into the preheader! This is just + // using cast as a copy so BitCast (no-op cast) is appropriate + BaseV = new BitCastInst(BaseV, BaseV->getType(), "preheaderinsert", + PreInsertPt); + } + } + + // Emit the code to add the immediate offset to the Phi value, just before + // the instructions that we identified as using this stride and base. + do { + // FIXME: Use emitted users to emit other users. + BasedUser &User = UsersToProcess.back(); + + // If this instruction wants to use the post-incremented value, move it + // after the post-inc and use its value instead of the PHI. + Value *RewriteOp = NewPHI; + if (User.isUseOfPostIncrementedValue) { + RewriteOp = IncV; + + // If this user is in the loop, make sure it is the last thing in the + // loop to ensure it is dominated by the increment. + if (L->contains(User.Inst->getParent())) + User.Inst->moveBefore(LatchBlock->getTerminator()); + } + if (RewriteOp->getType() != ReplacedTy) { + Instruction::CastOps opcode = Instruction::Trunc; + if (ReplacedTy->getPrimitiveSizeInBits() == + RewriteOp->getType()->getPrimitiveSizeInBits()) + opcode = Instruction::BitCast; + RewriteOp = SCEVExpander::InsertCastOfTo(opcode, RewriteOp, ReplacedTy); + } + + SCEVHandle RewriteExpr = SCEVUnknown::get(RewriteOp); + + // Clear the SCEVExpander's expression map so that we are guaranteed + // to have the code emitted where we expect it. + Rewriter.clear(); + + // If we are reusing the iv, then it must be multiplied by a constant + // factor take advantage of addressing mode scale component. + if (RewriteFactor != 0) { + RewriteExpr = + SCEVMulExpr::get(SCEVUnknown::getIntegerSCEV(RewriteFactor, + RewriteExpr->getType()), + RewriteExpr); + + // The common base is emitted in the loop preheader. 
But since we + // are reusing an IV, it has not been used to initialize the PHI node. + // Add it to the expression used to rewrite the uses. + if (!isa(CommonBaseV) || + !cast(CommonBaseV)->isZero()) + RewriteExpr = SCEVAddExpr::get(RewriteExpr, + SCEVUnknown::get(CommonBaseV)); + } + + // Now that we know what we need to do, insert code before User for the + // immediate and any loop-variant expressions. + if (!isa(BaseV) || !cast(BaseV)->isZero()) + // Add BaseV to the PHI value if needed. + RewriteExpr = SCEVAddExpr::get(RewriteExpr, SCEVUnknown::get(BaseV)); + + User.RewriteInstructionToUseNewBase(RewriteExpr, Rewriter, L, this); + + // Mark old value we replaced as possibly dead, so that it is elminated + // if we just replaced the last use of that value. + DeadInsts.insert(cast(User.OperandValToReplace)); + + UsersToProcess.pop_back(); + ++NumReduced; + + // If there are any more users to process with the same base, process them + // now. We sorted by base above, so we just have to check the last elt. + } while (!UsersToProcess.empty() && UsersToProcess.back().Base == Base); + // TODO: Next, find out which base index is the most common, pull it out. + } + + // IMPORTANT TODO: Figure out how to partition the IV's with this stride, but + // different starting values, into different PHIs. +} + +/// FindIVForUser - If Cond has an operand that is an expression of an IV, +/// set the IV user and stride information and return true, otherwise return +/// false. 
+bool LoopStrengthReduce::FindIVForUser(ICmpInst *Cond, IVStrideUse *&CondUse, + const SCEVHandle *&CondStride) { + for (unsigned Stride = 0, e = StrideOrder.size(); Stride != e && !CondUse; + ++Stride) { + std::map::iterator SI = + IVUsesByStride.find(StrideOrder[Stride]); + assert(SI != IVUsesByStride.end() && "Stride doesn't exist!"); + + for (std::vector::iterator UI = SI->second.Users.begin(), + E = SI->second.Users.end(); UI != E; ++UI) + if (UI->User == Cond) { + // NOTE: we could handle setcc instructions with multiple uses here, but + // InstCombine does it as well for simple uses, it's not clear that it + // occurs enough in real life to handle. + CondUse = &*UI; + CondStride = &SI->first; + return true; + } + } + return false; +} + +// OptimizeIndvars - Now that IVUsesByStride is set up with all of the indvar +// uses in the loop, look to see if we can eliminate some, in favor of using +// common indvars for the different uses. +void LoopStrengthReduce::OptimizeIndvars(Loop *L) { + // TODO: implement optzns here. + + // Finally, get the terminating condition for the loop if possible. If we + // can, we want to change it to use a post-incremented version of its + // induction variable, to allow coalescing the live ranges for the IV into + // one register value. + PHINode *SomePHI = cast(L->getHeader()->begin()); + BasicBlock *Preheader = L->getLoopPreheader(); + BasicBlock *LatchBlock = + SomePHI->getIncomingBlock(SomePHI->getIncomingBlock(0) == Preheader); + BranchInst *TermBr = dyn_cast(LatchBlock->getTerminator()); + if (!TermBr || TermBr->isUnconditional() || + !isa(TermBr->getCondition())) + return; + ICmpInst *Cond = cast(TermBr->getCondition()); + + // Search IVUsesByStride to find Cond's IVUse if there is one. + IVStrideUse *CondUse = 0; + const SCEVHandle *CondStride = 0; + + if (!FindIVForUser(Cond, CondUse, CondStride)) + return; // setcc doesn't use the IV. 
+ + + // It's possible for the setcc instruction to be anywhere in the loop, and + // possible for it to have multiple users. If it is not immediately before + // the latch block branch, move it. + if (&*++BasicBlock::iterator(Cond) != (Instruction*)TermBr) { + if (Cond->hasOneUse()) { // Condition has a single use, just move it. + Cond->moveBefore(TermBr); + } else { + // Otherwise, clone the terminating condition and insert into the loopend. + Cond = cast(Cond->clone()); + Cond->setName(L->getHeader()->getName() + ".termcond"); + LatchBlock->getInstList().insert(TermBr, Cond); + + // Clone the IVUse, as the old use still exists! + IVUsesByStride[*CondStride].addUser(CondUse->Offset, Cond, + CondUse->OperandValToReplace); + CondUse = &IVUsesByStride[*CondStride].Users.back(); + } + } + + // If we get to here, we know that we can transform the setcc instruction to + // use the post-incremented version of the IV, allowing us to coalesce the + // live ranges for the IV correctly. + CondUse->Offset = SCEV::getMinusSCEV(CondUse->Offset, *CondStride); + CondUse->isUseOfPostIncrementedValue = true; +} + +namespace { + // Constant strides come first which in turns are sorted by their absolute + // values. If absolute values are the same, then positive strides comes first. + // e.g. + // 4, -1, X, 1, 2 ==> 1, -1, 2, 4, X + struct StrideCompare { + bool operator()(const SCEVHandle &LHS, const SCEVHandle &RHS) { + SCEVConstant *LHSC = dyn_cast(LHS); + SCEVConstant *RHSC = dyn_cast(RHS); + if (LHSC && RHSC) { + int64_t LV = LHSC->getValue()->getSExtValue(); + int64_t RV = RHSC->getValue()->getSExtValue(); + uint64_t ALV = (LV < 0) ? -LV : LV; + uint64_t ARV = (RV < 0) ? 
-RV : RV; + if (ALV == ARV) + return LV > RV; + else + return ALV < ARV; + } + return (LHSC && !RHSC); + } + }; +} + +bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager &LPM) { + + LI = &getAnalysis(); + DT = &getAnalysis(); + SE = &getAnalysis(); + TD = &getAnalysis(); + UIntPtrTy = TD->getIntPtrType(); + + // Find all uses of induction variables in this loop, and catagorize + // them by stride. Start by finding all of the PHI nodes in the header for + // this loop. If they are induction variables, inspect their uses. + std::set Processed; // Don't reprocess instructions. + for (BasicBlock::iterator I = L->getHeader()->begin(); isa(I); ++I) + AddUsersIfInteresting(I, L, Processed); + + // If we have nothing to do, return. + if (IVUsesByStride.empty()) return false; + + // Optimize induction variables. Some indvar uses can be transformed to use + // strides that will be needed for other purposes. A common example of this + // is the exit test for the loop, which can often be rewritten to use the + // computation of some other indvar to decide when to terminate the loop. + OptimizeIndvars(L); + + + // FIXME: We can widen subreg IV's here for RISC targets. e.g. instead of + // doing computation in byte values, promote to 32-bit values if safe. + + // FIXME: Attempt to reuse values across multiple IV's. In particular, we + // could have something like "for(i) { foo(i*8); bar(i*16) }", which should be + // codegened as "for (j = 0;; j+=8) { foo(j); bar(j+j); }" on X86/PPC. Need + // to be careful that IV's are all the same type. Only works for intptr_t + // indvars. + + // If we only have one stride, we can more aggressively eliminate some things. + bool HasOneStride = IVUsesByStride.size() == 1; + +#ifndef NDEBUG + DOUT << "\nLSR on "; + DEBUG(L->dump()); +#endif + + // IVsByStride keeps IVs for one particular loop. + IVsByStride.clear(); + + // Sort the StrideOrder so we process larger strides first. 
+ std::stable_sort(StrideOrder.begin(), StrideOrder.end(), StrideCompare()); + + // Note: this processes each stride/type pair individually. All users passed + // into StrengthReduceStridedIVUsers have the same type AND stride. Also, + // node that we iterate over IVUsesByStride indirectly by using StrideOrder. + // This extra layer of indirection makes the ordering of strides deterministic + // - not dependent on map order. + for (unsigned Stride = 0, e = StrideOrder.size(); Stride != e; ++Stride) { + std::map::iterator SI = + IVUsesByStride.find(StrideOrder[Stride]); + assert(SI != IVUsesByStride.end() && "Stride doesn't exist!"); + StrengthReduceStridedIVUsers(SI->first, SI->second, L, HasOneStride); + } + + // Clean up after ourselves + if (!DeadInsts.empty()) { + DeleteTriviallyDeadInstructions(DeadInsts); + + BasicBlock::iterator I = L->getHeader()->begin(); + PHINode *PN; + while ((PN = dyn_cast(I))) { + ++I; // Preincrement iterator to avoid invalidating it when deleting PN. + + // At this point, we know that we have killed one or more GEP + // instructions. It is worth checking to see if the cann indvar is also + // dead, so that we can remove it as well. The requirements for the cann + // indvar to be considered dead are: + // 1. the cann indvar has one use + // 2. the use is an add instruction + // 3. the add has one use + // 4. the add is used by the cann indvar + // If all four cases above are true, then we can remove both the add and + // the cann indvar. + // FIXME: this needs to eliminate an induction variable even if it's being + // compared against some value to decide loop termination. + if (PN->hasOneUse()) { + Instruction *BO = dyn_cast(*PN->use_begin()); + if (BO && (isa(BO) || isa(BO))) { + if (BO->hasOneUse() && PN == *(BO->use_begin())) { + DeadInsts.insert(BO); + // Break the cycle, then delete the PHI. 
+ PN->replaceAllUsesWith(UndefValue::get(PN->getType())); + SE->deleteValueFromRecords(PN); + PN->eraseFromParent(); + } + } + } + } + DeleteTriviallyDeadInstructions(DeadInsts); + } + + CastedPointers.clear(); + IVUsesByStride.clear(); + StrideOrder.clear(); + return false; +} diff --git a/lib/Transforms/Scalar/LoopUnroll.cpp b/lib/Transforms/Scalar/LoopUnroll.cpp new file mode 100644 index 0000000..babfc24 --- /dev/null +++ b/lib/Transforms/Scalar/LoopUnroll.cpp @@ -0,0 +1,500 @@ +//===-- LoopUnroll.cpp - Loop unroller pass -------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass implements a simple loop unroller. It works best when loops have +// been canonicalized by the -indvars pass, allowing it to determine the trip +// counts of loops easily. +// +// This pass will multi-block loops only if they contain no non-unrolled +// subloops. The process of unrolling can produce extraneous basic blocks +// linked with unconditional branches. This will be corrected in the future. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "loop-unroll" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/IntrinsicInst.h" +#include +#include +using namespace llvm; + +STATISTIC(NumCompletelyUnrolled, "Number of loops completely unrolled"); +STATISTIC(NumUnrolled, "Number of loops unrolled (completely or otherwise)"); + +namespace { + cl::opt + UnrollThreshold + ("unroll-threshold", cl::init(100), cl::Hidden, + cl::desc("The cut-off point for automatic loop unrolling")); + + cl::opt + UnrollCount + ("unroll-count", cl::init(0), cl::Hidden, + cl::desc("Use this unroll count for all loops, for testing purposes")); + + class VISIBILITY_HIDDEN LoopUnroll : public LoopPass { + LoopInfo *LI; // The current loop information + public: + static char ID; // Pass ID, replacement for typeid + LoopUnroll() : LoopPass((intptr_t)&ID) {} + + /// A magic value for use with the Threshold parameter to indicate + /// that the loop unroll should be performed regardless of how much + /// code expansion would result. + static const unsigned NoThreshold = UINT_MAX; + + bool runOnLoop(Loop *L, LPPassManager &LPM); + bool unrollLoop(Loop *L, unsigned Count, unsigned Threshold); + BasicBlock *FoldBlockIntoPredecessor(BasicBlock *BB); + + /// This transformation requires natural loop information & requires that + /// loop preheaders be inserted into the CFG... 
+ /// + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredID(LoopSimplifyID); + AU.addRequiredID(LCSSAID); + AU.addRequired(); + AU.addPreservedID(LCSSAID); + AU.addPreserved(); + } + }; + char LoopUnroll::ID = 0; + RegisterPass X("loop-unroll", "Unroll loops"); +} + +LoopPass *llvm::createLoopUnrollPass() { return new LoopUnroll(); } + +/// ApproximateLoopSize - Approximate the size of the loop. +static unsigned ApproximateLoopSize(const Loop *L) { + unsigned Size = 0; + for (unsigned i = 0, e = L->getBlocks().size(); i != e; ++i) { + BasicBlock *BB = L->getBlocks()[i]; + Instruction *Term = BB->getTerminator(); + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + if (isa(I) && BB == L->getHeader()) { + // Ignore PHI nodes in the header. + } else if (I->hasOneUse() && I->use_back() == Term) { + // Ignore instructions only used by the loop terminator. + } else if (isa(I)) { + // Ignore debug instructions + } else { + ++Size; + } + + // TODO: Ignore expressions derived from PHI and constants if inval of phi + // is a constant, or if operation is associative. This will get induction + // variables. + } + } + + return Size; +} + +// RemapInstruction - Convert the instruction operands from referencing the +// current values into those specified by ValueMap. +// +static inline void RemapInstruction(Instruction *I, + DenseMap &ValueMap) { + for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) { + Value *Op = I->getOperand(op); + DenseMap::iterator It = ValueMap.find(Op); + if (It != ValueMap.end()) Op = It->second; + I->setOperand(op, Op); + } +} + +// FoldBlockIntoPredecessor - Folds a basic block into its predecessor if it +// only has one predecessor, and that predecessor only has one successor. +// Returns the new combined block. 
+BasicBlock *LoopUnroll::FoldBlockIntoPredecessor(BasicBlock *BB) { + // Merge basic blocks into their predecessor if there is only one distinct + // pred, and if there is only one distinct successor of the predecessor, and + // if there are no PHI nodes. + // + BasicBlock *OnlyPred = BB->getSinglePredecessor(); + if (!OnlyPred) return 0; + + if (OnlyPred->getTerminator()->getNumSuccessors() != 1) + return 0; + + DOUT << "Merging: " << *BB << "into: " << *OnlyPred; + + // Resolve any PHI nodes at the start of the block. They are all + // guaranteed to have exactly one entry if they exist, unless there are + // multiple duplicate (but guaranteed to be equal) entries for the + // incoming edges. This occurs when there are multiple edges from + // OnlyPred to OnlySucc. + // + while (PHINode *PN = dyn_cast(&BB->front())) { + PN->replaceAllUsesWith(PN->getIncomingValue(0)); + BB->getInstList().pop_front(); // Delete the phi node... + } + + // Delete the unconditional branch from the predecessor... + OnlyPred->getInstList().pop_back(); + + // Move all definitions in the successor to the predecessor... + OnlyPred->getInstList().splice(OnlyPred->end(), BB->getInstList()); + + // Make all PHI nodes that referred to BB now refer to Pred as their + // source... + BB->replaceAllUsesWith(OnlyPred); + + std::string OldName = BB->getName(); + + // Erase basic block from the function... + LI->removeBlock(BB); + BB->eraseFromParent(); + + // Inherit predecessor's name if it exists... + if (!OldName.empty() && !OnlyPred->hasName()) + OnlyPred->setName(OldName); + + return OnlyPred; +} + +bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) { + LI = &getAnalysis(); + + // Unroll the loop. + if (!unrollLoop(L, UnrollCount, UnrollThreshold)) + return false; + + // Update the loop information for this loop. + // If we completely unrolled the loop, remove it from the parent. 
+ if (L->getNumBackEdges() == 0) + LPM.deleteLoopFromQueue(L); + + return true; +} + +/// Unroll the given loop by UnrollCount, or by a heuristically-determined +/// value if Count is zero. If Threshold is not NoThreshold, it is a value +/// to limit code size expansion. If the loop size would expand beyond the +/// threshold value, unrolling is suppressed. The return value is true if +/// any transformations are performed. +/// +bool LoopUnroll::unrollLoop(Loop *L, unsigned Count, unsigned Threshold) { + assert(L->isLCSSAForm()); + + BasicBlock *Header = L->getHeader(); + BasicBlock *LatchBlock = L->getLoopLatch(); + BranchInst *BI = dyn_cast(LatchBlock->getTerminator()); + + DOUT << "Loop Unroll: F[" << Header->getParent()->getName() + << "] Loop %" << Header->getName() << "\n"; + + if (!BI || BI->isUnconditional()) { + // The loop-rorate pass can be helpful to avoid this in many cases. + DOUT << " Can't unroll; loop not terminated by a conditional branch.\n"; + return false; + } + + // Determine the trip count and/or trip multiple. A TripCount value of zero + // is used to mean an unknown trip count. The TripMultiple value is the + // greatest known integer multiple of the trip count. + unsigned TripCount = 0; + unsigned TripMultiple = 1; + if (Value *TripCountValue = L->getTripCount()) { + if (ConstantInt *TripCountC = dyn_cast(TripCountValue)) { + // Guard against huge trip counts. This also guards against assertions in + // APInt from the use of getZExtValue, below. 
+ if (TripCountC->getValue().getActiveBits() <= 32) { + TripCount = (unsigned)TripCountC->getZExtValue(); + } + } else if (BinaryOperator *BO = dyn_cast(TripCountValue)) { + switch (BO->getOpcode()) { + case BinaryOperator::Mul: + if (ConstantInt *MultipleC = dyn_cast(BO->getOperand(1))) { + if (MultipleC->getValue().getActiveBits() <= 32) { + TripMultiple = (unsigned)MultipleC->getZExtValue(); + } + } + break; + default: break; + } + } + } + if (TripCount != 0) + DOUT << " Trip Count = " << TripCount << "\n"; + if (TripMultiple != 1) + DOUT << " Trip Multiple = " << TripMultiple << "\n"; + + // Automatically select an unroll count. + if (Count == 0) { + // Conservative heuristic: if we know the trip count, see if we can + // completely unroll (subject to the threshold, checked below); otherwise + // don't unroll. + if (TripCount != 0) { + Count = TripCount; + } else { + return false; + } + } + + // Effectively "DCE" unrolled iterations that are beyond the tripcount + // and will never be executed. + if (TripCount != 0 && Count > TripCount) + Count = TripCount; + + assert(Count > 0); + assert(TripMultiple > 0); + assert(TripCount == 0 || TripCount % TripMultiple == 0); + + // Enforce the threshold. + if (Threshold != NoThreshold) { + unsigned LoopSize = ApproximateLoopSize(L); + DOUT << " Loop Size = " << LoopSize << "\n"; + uint64_t Size = (uint64_t)LoopSize*Count; + if (TripCount != 1 && Size > Threshold) { + DOUT << " TOO LARGE TO UNROLL: " + << Size << ">" << Threshold << "\n"; + return false; + } + } + + // Are we eliminating the loop control altogether? + bool CompletelyUnroll = Count == TripCount; + + // If we know the trip count, we know the multiple... + unsigned BreakoutTrip = 0; + if (TripCount != 0) { + BreakoutTrip = TripCount % Count; + TripMultiple = 0; + } else { + // Figure out what multiple to use. 
+ BreakoutTrip = TripMultiple = + (unsigned)GreatestCommonDivisor64(Count, TripMultiple); + } + + if (CompletelyUnroll) { + DOUT << "COMPLETELY UNROLLING loop %" << Header->getName() + << " with trip count " << TripCount << "!\n"; + } else { + DOUT << "UNROLLING loop %" << Header->getName() + << " by " << Count; + if (TripMultiple == 0 || BreakoutTrip != TripMultiple) { + DOUT << " with a breakout at trip " << BreakoutTrip; + } else if (TripMultiple != 1) { + DOUT << " with " << TripMultiple << " trips per branch"; + } + DOUT << "!\n"; + } + + std::vector LoopBlocks = L->getBlocks(); + + bool ContinueOnTrue = L->contains(BI->getSuccessor(0)); + BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue); + + // For the first iteration of the loop, we should use the precloned values for + // PHI nodes. Insert associations now. + typedef DenseMap ValueMapTy; + ValueMapTy LastValueMap; + std::vector OrigPHINode; + for (BasicBlock::iterator I = Header->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + OrigPHINode.push_back(PN); + if (Instruction *I = + dyn_cast(PN->getIncomingValueForBlock(LatchBlock))) + if (L->contains(I->getParent())) + LastValueMap[I] = I; + } + + std::vector Headers; + std::vector Latches; + Headers.push_back(Header); + Latches.push_back(LatchBlock); + + for (unsigned It = 1; It != Count; ++It) { + char SuffixBuffer[100]; + sprintf(SuffixBuffer, ".%d", It); + + std::vector NewBlocks; + + for (std::vector::iterator BB = LoopBlocks.begin(), + E = LoopBlocks.end(); BB != E; ++BB) { + ValueMapTy ValueMap; + BasicBlock *New = CloneBasicBlock(*BB, ValueMap, SuffixBuffer); + Header->getParent()->getBasicBlockList().push_back(New); + + // Loop over all of the PHI nodes in the block, changing them to use the + // incoming values from the previous block. 
+ if (*BB == Header) + for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) { + PHINode *NewPHI = cast(ValueMap[OrigPHINode[i]]); + Value *InVal = NewPHI->getIncomingValueForBlock(LatchBlock); + if (Instruction *InValI = dyn_cast(InVal)) + if (It > 1 && L->contains(InValI->getParent())) + InVal = LastValueMap[InValI]; + ValueMap[OrigPHINode[i]] = InVal; + New->getInstList().erase(NewPHI); + } + + // Update our running map of newest clones + LastValueMap[*BB] = New; + for (ValueMapTy::iterator VI = ValueMap.begin(), VE = ValueMap.end(); + VI != VE; ++VI) + LastValueMap[VI->first] = VI->second; + + L->addBasicBlockToLoop(New, *LI); + + // Add phi entries for newly created values to all exit blocks except + // the successor of the latch block. The successor of the exit block will + // be updated specially after unrolling all the way. + if (*BB != LatchBlock) + for (Value::use_iterator UI = (*BB)->use_begin(), UE = (*BB)->use_end(); + UI != UE; ++UI) { + Instruction *UseInst = cast(*UI); + if (isa(UseInst) && !L->contains(UseInst->getParent())) { + PHINode *phi = cast(UseInst); + Value *Incoming = phi->getIncomingValueForBlock(*BB); + if (isa(Incoming)) + Incoming = LastValueMap[Incoming]; + + phi->addIncoming(Incoming, New); + } + } + + // Keep track of new headers and latches as we create them, so that + // we can insert the proper branches later. + if (*BB == Header) + Headers.push_back(New); + if (*BB == LatchBlock) { + Latches.push_back(New); + + // Also, clear out the new latch's back edge so that it doesn't look + // like a new loop, so that it's amenable to being merged with adjacent + // blocks later on. 
+ TerminatorInst *Term = New->getTerminator(); + assert(L->contains(Term->getSuccessor(!ContinueOnTrue))); + assert(Term->getSuccessor(ContinueOnTrue) == LoopExit); + Term->setSuccessor(!ContinueOnTrue, NULL); + } + + NewBlocks.push_back(New); + } + + // Remap all instructions in the most recent iteration + for (unsigned i = 0; i < NewBlocks.size(); ++i) + for (BasicBlock::iterator I = NewBlocks[i]->begin(), + E = NewBlocks[i]->end(); I != E; ++I) + RemapInstruction(I, LastValueMap); + } + + // The latch block exits the loop. If there are any PHI nodes in the + // successor blocks, update them to use the appropriate values computed as the + // last iteration of the loop. + if (Count != 1) { + SmallPtrSet Users; + for (Value::use_iterator UI = LatchBlock->use_begin(), + UE = LatchBlock->use_end(); UI != UE; ++UI) + if (PHINode *phi = dyn_cast(*UI)) + Users.insert(phi); + + BasicBlock *LastIterationBB = cast(LastValueMap[LatchBlock]); + for (SmallPtrSet::iterator SI = Users.begin(), SE = Users.end(); + SI != SE; ++SI) { + PHINode *PN = *SI; + Value *InVal = PN->removeIncomingValue(LatchBlock, false); + // If this value was defined in the loop, take the value defined by the + // last iteration of the loop. + if (Instruction *InValI = dyn_cast(InVal)) { + if (L->contains(InValI->getParent())) + InVal = LastValueMap[InVal]; + } + PN->addIncoming(InVal, LastIterationBB); + } + } + + // Now, if we're doing complete unrolling, loop over the PHI nodes in the + // original block, setting them to their incoming values. + if (CompletelyUnroll) { + BasicBlock *Preheader = L->getLoopPreheader(); + for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) { + PHINode *PN = OrigPHINode[i]; + PN->replaceAllUsesWith(PN->getIncomingValueForBlock(Preheader)); + Header->getInstList().erase(PN); + } + } + + // Now that all the basic blocks for the unrolled iterations are in place, + // set up the branches to connect them. 
+ for (unsigned i = 0, e = Latches.size(); i != e; ++i) { + // The original branch was replicated in each unrolled iteration. + BranchInst *Term = cast(Latches[i]->getTerminator()); + + // The branch destination. + unsigned j = (i + 1) % e; + BasicBlock *Dest = Headers[j]; + bool NeedConditional = true; + + // For a complete unroll, make the last iteration end with a branch + // to the exit block. + if (CompletelyUnroll && j == 0) { + Dest = LoopExit; + NeedConditional = false; + } + + // If we know the trip count or a multiple of it, we can safely use an + // unconditional branch for some iterations. + if (j != BreakoutTrip && (TripMultiple == 0 || j % TripMultiple != 0)) { + NeedConditional = false; + } + + if (NeedConditional) { + // Update the conditional branch's successor for the following + // iteration. + Term->setSuccessor(!ContinueOnTrue, Dest); + } else { + Term->setUnconditionalDest(Dest); + // Merge adjacent basic blocks, if possible. + if (BasicBlock *Fold = FoldBlockIntoPredecessor(Dest)) { + std::replace(Latches.begin(), Latches.end(), Dest, Fold); + std::replace(Headers.begin(), Headers.end(), Dest, Fold); + } + } + } + + // At this point, the code is well formed. We now do a quick sweep over the + // inserted code, doing constant propagation and dead code elimination as we + // go. 
+ const std::vector &NewLoopBlocks = L->getBlocks(); + for (std::vector::const_iterator BB = NewLoopBlocks.begin(), + BBE = NewLoopBlocks.end(); BB != BBE; ++BB) + for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end(); I != E; ) { + Instruction *Inst = I++; + + if (isInstructionTriviallyDead(Inst)) + (*BB)->getInstList().erase(Inst); + else if (Constant *C = ConstantFoldInstruction(Inst)) { + Inst->replaceAllUsesWith(C); + (*BB)->getInstList().erase(Inst); + } + } + + NumCompletelyUnrolled += CompletelyUnroll; + ++NumUnrolled; + return true; +} diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp new file mode 100644 index 0000000..c433e63 --- /dev/null +++ b/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -0,0 +1,1074 @@ +//===-- LoopUnswitch.cpp - Hoist loop-invariant conditionals in loop ------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass transforms loops that contain branches on loop-invariant conditions +// to have multiple loops. For example, it turns the left into the right code: +// +// for (...) if (lic) +// A for (...) +// if (lic) A; B; C +// B else +// C for (...) +// A; C +// +// This can increase the size of the code exponentially (doubling it every time +// a loop is unswitched) so we only unswitch if the resultant code will be +// smaller than a threshold. +// +// This pass expects LICM to be run before it to hoist invariant conditions out +// of the loop, to make the unswitching opportunity obvious. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "loop-unswitch" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include +#include +using namespace llvm; + +STATISTIC(NumBranches, "Number of branches unswitched"); +STATISTIC(NumSwitches, "Number of switches unswitched"); +STATISTIC(NumSelects , "Number of selects unswitched"); +STATISTIC(NumTrivial , "Number of unswitches that are trivial"); +STATISTIC(NumSimplify, "Number of simplifications of unswitched code"); + +namespace { + cl::opt + Threshold("loop-unswitch-threshold", cl::desc("Max loop size to unswitch"), + cl::init(10), cl::Hidden); + + class VISIBILITY_HIDDEN LoopUnswitch : public LoopPass { + LoopInfo *LI; // Loop information + LPPassManager *LPM; + + // LoopProcessWorklist - Used to check if second loop needs processing + // after RewriteLoopBodyWithConditionConstant rewrites first loop. + std::vector LoopProcessWorklist; + SmallPtrSet UnswitchedVals; + + bool OptimizeForSize; + public: + static char ID; // Pass ID, replacement for typeid + LoopUnswitch(bool Os = false) : + LoopPass((intptr_t)&ID), OptimizeForSize(Os) {} + + bool runOnLoop(Loop *L, LPPassManager &LPM); + + /// This transformation requires natural loop information & requires that + /// loop preheaders be inserted into the CFG... 
+ /// + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredID(LoopSimplifyID); + AU.addPreservedID(LoopSimplifyID); + AU.addRequired(); + AU.addPreserved(); + AU.addRequiredID(LCSSAID); + } + + private: + /// RemoveLoopFromWorklist - If the specified loop is on the loop worklist, + /// remove it. + void RemoveLoopFromWorklist(Loop *L) { + std::vector::iterator I = std::find(LoopProcessWorklist.begin(), + LoopProcessWorklist.end(), L); + if (I != LoopProcessWorklist.end()) + LoopProcessWorklist.erase(I); + } + + bool UnswitchIfProfitable(Value *LoopCond, Constant *Val,Loop *L); + unsigned getLoopUnswitchCost(Loop *L, Value *LIC); + void UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, + BasicBlock *ExitBlock); + void UnswitchNontrivialCondition(Value *LIC, Constant *OnVal, Loop *L); + + void RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, + Constant *Val, bool isEqual); + + void EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, + BasicBlock *TrueDest, + BasicBlock *FalseDest, + Instruction *InsertPt); + + void SimplifyCode(std::vector &Worklist); + void RemoveBlockIfDead(BasicBlock *BB, + std::vector &Worklist); + void RemoveLoopFromHierarchy(Loop *L); + }; + char LoopUnswitch::ID = 0; + RegisterPass X("loop-unswitch", "Unswitch loops"); +} + +LoopPass *llvm::createLoopUnswitchPass(bool Os) { + return new LoopUnswitch(Os); +} + +/// FindLIVLoopCondition - Cond is a condition that occurs in L. If it is +/// invariant in the loop, or has an invariant piece, return the invariant. +/// Otherwise, return null. +static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { + // Constants should be folded, not unswitched on! + if (isa(Cond)) return false; + + // TODO: Handle: br (VARIANT|INVARIANT). + // TODO: Hoist simple expressions out of loops. 
+ if (L->isLoopInvariant(Cond)) return Cond; + + if (BinaryOperator *BO = dyn_cast(Cond)) + if (BO->getOpcode() == Instruction::And || + BO->getOpcode() == Instruction::Or) { + // If either the left or right side is invariant, we can unswitch on this, + // which will cause the branch to go away in one loop and the condition to + // simplify in the other one. + if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed)) + return LHS; + if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed)) + return RHS; + } + + return 0; +} + +bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { + assert(L->isLCSSAForm()); + LI = &getAnalysis(); + LPM = &LPM_Ref; + bool Changed = false; + + // Loop over all of the basic blocks in the loop. If we find an interior + // block that is branching on a loop-invariant condition, we can unswitch this + // loop. + for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); + I != E; ++I) { + TerminatorInst *TI = (*I)->getTerminator(); + if (BranchInst *BI = dyn_cast(TI)) { + // If this isn't branching on an invariant condition, we can't unswitch + // it. + if (BI->isConditional()) { + // See if this, or some part of it, is loop invariant. If so, we can + // unswitch on it if we desire. + Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), L, Changed); + if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(), + L)) { + ++NumBranches; + return true; + } + } + } else if (SwitchInst *SI = dyn_cast(TI)) { + Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), L, Changed); + if (LoopCond && SI->getNumCases() > 1) { + // Find a value to unswitch on: + // FIXME: this should chose the most expensive case! + Constant *UnswitchVal = SI->getCaseValue(1); + // Do not process same value again and again. 
+ if (!UnswitchedVals.insert(UnswitchVal)) + continue; + + if (UnswitchIfProfitable(LoopCond, UnswitchVal, L)) { + ++NumSwitches; + return true; + } + } + } + + // Scan the instructions to check for unswitchable values. + for (BasicBlock::iterator BBI = (*I)->begin(), E = (*I)->end(); + BBI != E; ++BBI) + if (SelectInst *SI = dyn_cast(BBI)) { + Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), L, Changed); + if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(), + L)) { + ++NumSelects; + return true; + } + } + } + + assert(L->isLCSSAForm()); + + return Changed; +} + +/// isTrivialLoopExitBlock - Check to see if all paths from BB either: +/// 1. Exit the loop with no side effects. +/// 2. Branch to the latch block with no side-effects. +/// +/// If these conditions are true, we return true and set ExitBB to the block we +/// exit through. +/// +static bool isTrivialLoopExitBlockHelper(Loop *L, BasicBlock *BB, + BasicBlock *&ExitBB, + std::set &Visited) { + if (!Visited.insert(BB).second) { + // Already visited and Ok, end of recursion. + return true; + } else if (!L->contains(BB)) { + // Otherwise, this is a loop exit, this is fine so long as this is the + // first exit. + if (ExitBB != 0) return false; + ExitBB = BB; + return true; + } + + // Otherwise, this is an unvisited intra-loop node. Check all successors. + for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; ++SI) { + // Check to see if the successor is a trivial loop exit. + if (!isTrivialLoopExitBlockHelper(L, *SI, ExitBB, Visited)) + return false; + } + + // Okay, everything after this looks good, check to make sure that this block + // doesn't include any side effects. 
+ for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) + if (I->mayWriteToMemory()) + return false; + + return true; +} + +/// isTrivialLoopExitBlock - Return true if the specified block unconditionally +/// leads to an exit from the specified loop, and has no side-effects in the +/// process. If so, return the block that is exited to, otherwise return null. +static BasicBlock *isTrivialLoopExitBlock(Loop *L, BasicBlock *BB) { + std::set Visited; + Visited.insert(L->getHeader()); // Branches to header are ok. + BasicBlock *ExitBB = 0; + if (isTrivialLoopExitBlockHelper(L, BB, ExitBB, Visited)) + return ExitBB; + return 0; +} + +/// IsTrivialUnswitchCondition - Check to see if this unswitch condition is +/// trivial: that is, that the condition controls whether or not the loop does +/// anything at all. If this is a trivial condition, unswitching produces no +/// code duplications (equivalently, it produces a simpler loop and a new empty +/// loop, which gets deleted). +/// +/// If this is a trivial condition, return true, otherwise return false. When +/// returning true, this sets Cond and Val to the condition that controls the +/// trivial condition: when Cond dynamically equals Val, the loop is known to +/// exit. Finally, this sets LoopExit to the BB that the loop exits to when +/// Cond == Val. +/// +static bool IsTrivialUnswitchCondition(Loop *L, Value *Cond, Constant **Val = 0, + BasicBlock **LoopExit = 0) { + BasicBlock *Header = L->getHeader(); + TerminatorInst *HeaderTerm = Header->getTerminator(); + + BasicBlock *LoopExitBB = 0; + if (BranchInst *BI = dyn_cast(HeaderTerm)) { + // If the header block doesn't end with a conditional branch on Cond, we + // can't handle it. + if (!BI->isConditional() || BI->getCondition() != Cond) + return false; + + // Check to see if a successor of the branch is guaranteed to go to the + // latch block or exit through a one exit block without having any + // side-effects. 
If so, determine the value of Cond that causes it to do + // this. + if ((LoopExitBB = isTrivialLoopExitBlock(L, BI->getSuccessor(0)))) { + if (Val) *Val = ConstantInt::getTrue(); + } else if ((LoopExitBB = isTrivialLoopExitBlock(L, BI->getSuccessor(1)))) { + if (Val) *Val = ConstantInt::getFalse(); + } + } else if (SwitchInst *SI = dyn_cast(HeaderTerm)) { + // If this isn't a switch on Cond, we can't handle it. + if (SI->getCondition() != Cond) return false; + + // Check to see if a successor of the switch is guaranteed to go to the + // latch block or exit through a one exit block without having any + // side-effects. If so, determine the value of Cond that causes it to do + // this. Note that we can't trivially unswitch on the default case. + for (unsigned i = 1, e = SI->getNumSuccessors(); i != e; ++i) + if ((LoopExitBB = isTrivialLoopExitBlock(L, SI->getSuccessor(i)))) { + // Okay, we found a trivial case, remember the value that is trivial. + if (Val) *Val = SI->getCaseValue(i); + break; + } + } + + // If we didn't find a single unique LoopExit block, or if the loop exit block + // contains phi nodes, this isn't trivial. + if (!LoopExitBB || isa(LoopExitBB->begin())) + return false; // Can't handle this. + + if (LoopExit) *LoopExit = LoopExitBB; + + // We already know that nothing uses any scalar values defined inside of this + // loop. As such, we just have to check to see if this loop will execute any + // side-effecting instructions (e.g. stores, calls, volatile loads) in the + // part of the loop that the code *would* execute. We already checked the + // tail, check the header now. + for (BasicBlock::iterator I = Header->begin(), E = Header->end(); I != E; ++I) + if (I->mayWriteToMemory()) + return false; + return true; +} + +/// getLoopUnswitchCost - Return the cost (code size growth) that will happen if +/// we choose to unswitch the specified loop on the specified value. 
+/// +unsigned LoopUnswitch::getLoopUnswitchCost(Loop *L, Value *LIC) { + // If the condition is trivial, always unswitch. There is no code growth for + // this case. + if (IsTrivialUnswitchCondition(L, LIC)) + return 0; + + // FIXME: This is really overly conservative. However, more liberal + // estimations have thus far resulted in excessive unswitching, which is bad + // both in compile time and in code size. This should be replaced once + // someone figures out how a good estimation. + return L->getBlocks().size(); + + unsigned Cost = 0; + // FIXME: this is brain dead. It should take into consideration code + // shrinkage. + for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); + I != E; ++I) { + BasicBlock *BB = *I; + // Do not include empty blocks in the cost calculation. This happen due to + // loop canonicalization and will be removed. + if (BB->begin() == BasicBlock::iterator(BB->getTerminator())) + continue; + + // Count basic blocks. + ++Cost; + } + + return Cost; +} + +/// UnswitchIfProfitable - We have found that we can unswitch L when +/// LoopCond == Val to simplify the loop. If we decide that this is profitable, +/// unswitch the loop, reprocess the pieces, then return true. +bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val,Loop *L){ + // Check to see if it would be profitable to unswitch this loop. + unsigned Cost = getLoopUnswitchCost(L, LoopCond); + + // Do not do non-trivial unswitch while optimizing for size. + if (Cost && OptimizeForSize) + return false; + + if (Cost > Threshold) { + // FIXME: this should estimate growth by the amount of code shared by the + // resultant unswitched loops. + // + DOUT << "NOT unswitching loop %" + << L->getHeader()->getName() << ", cost too high: " + << L->getBlocks().size() << "\n"; + return false; + } + + // If this is a trivial condition to unswitch (which results in no code + // duplication), do it now. 
+ Constant *CondVal; + BasicBlock *ExitBlock; + if (IsTrivialUnswitchCondition(L, LoopCond, &CondVal, &ExitBlock)) { + UnswitchTrivialCondition(L, LoopCond, CondVal, ExitBlock); + } else { + UnswitchNontrivialCondition(LoopCond, Val, L); + } + + return true; +} + +// RemapInstruction - Convert the instruction operands from referencing the +// current values into those specified by ValueMap. +// +static inline void RemapInstruction(Instruction *I, + DenseMap &ValueMap) { + for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) { + Value *Op = I->getOperand(op); + DenseMap::iterator It = ValueMap.find(Op); + if (It != ValueMap.end()) Op = It->second; + I->setOperand(op, Op); + } +} + +// CloneDomInfo - NewBB is cloned from Orig basic block. Now clone Dominator Info. +// If Orig is in Loop then find and use Orig dominator's cloned block as NewBB +// dominator. +void CloneDomInfo(BasicBlock *NewBB, BasicBlock *Orig, Loop *L, + DominatorTree *DT, DominanceFrontier *DF, + DenseMap &VM) { + + DomTreeNode *OrigNode = DT->getNode(Orig); + if (!OrigNode) + return; + BasicBlock *OrigIDom = OrigNode->getBlock(); + BasicBlock *NewIDom = OrigIDom; + if (L->contains(OrigIDom)) { + if (!DT->getNode(OrigIDom)) + CloneDomInfo(NewIDom, OrigIDom, L, DT, DF, VM); + NewIDom = cast(VM[OrigIDom]); + } + if (NewBB == NewIDom) { + DT->addNewBlock(NewBB, OrigIDom); + DT->changeImmediateDominator(NewBB, NewIDom); + } else + DT->addNewBlock(NewBB, NewIDom); + + DominanceFrontier::DomSetType NewDFSet; + if (DF) { + DominanceFrontier::iterator DFI = DF->find(Orig); + if ( DFI != DF->end()) { + DominanceFrontier::DomSetType S = DFI->second; + for (DominanceFrontier::DomSetType::iterator I = S.begin(), E = S.end(); + I != E; ++I) { + BasicBlock *BB = *I; + if (L->contains(BB)) + NewDFSet.insert(cast(VM[Orig])); + else + NewDFSet.insert(BB); + } + } + DF->addBasicBlock(NewBB, NewDFSet); + } +} + +/// CloneLoop - Recursively clone the specified loop and all of its children, +/// mapping the 
blocks with the specified map. +static Loop *CloneLoop(Loop *L, Loop *PL, DenseMap &VM, + LoopInfo *LI, LPPassManager *LPM) { + Loop *New = new Loop(); + + LPM->insertLoop(New, PL); + + // Add all of the blocks in L to the new loop. + for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); + I != E; ++I) + if (LI->getLoopFor(*I) == L) + New->addBasicBlockToLoop(cast(VM[*I]), *LI); + + // Add all of the subloops to the new loop. + for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) + CloneLoop(*I, New, VM, LI, LPM); + + return New; +} + +/// EmitPreheaderBranchOnCondition - Emit a conditional branch on two values +/// if LIC == Val, branch to TrueDst, otherwise branch to FalseDest. Insert the +/// code immediately before InsertPt. +void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, + BasicBlock *TrueDest, + BasicBlock *FalseDest, + Instruction *InsertPt) { + // Insert a conditional branch on LIC to the two preheaders. The original + // code is the true version and the new code is the false version. + Value *BranchVal = LIC; + if (!isa(Val) || Val->getType() != Type::Int1Ty) + BranchVal = new ICmpInst(ICmpInst::ICMP_EQ, LIC, Val, "tmp", InsertPt); + else if (Val != ConstantInt::getTrue()) + // We want to enter the new loop when the condition is true. + std::swap(TrueDest, FalseDest); + + // Insert the new branch. + BranchInst *BRI = new BranchInst(TrueDest, FalseDest, BranchVal, InsertPt); + + // Update dominator info. + // BranchVal is a new preheader so it dominates true and false destination + // loop headers. + if (DominatorTree *DT = getAnalysisToUpdate()) { + DT->changeImmediateDominator(TrueDest, BRI->getParent()); + DT->changeImmediateDominator(FalseDest, BRI->getParent()); + } + // No need to update DominanceFrontier. BRI->getParent() dominated TrueDest + // and FalseDest anyway. Now it immediately dominates them. 
+} + + +/// UnswitchTrivialCondition - Given a loop that has a trivial unswitchable +/// condition in it (a cond branch from its header block to its latch block, +/// where the path through the loop that doesn't execute its body has no +/// side-effects), unswitch it. This doesn't involve any code duplication, just +/// moving the conditional branch outside of the loop and updating loop info. +void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, + Constant *Val, + BasicBlock *ExitBlock) { + DOUT << "loop-unswitch: Trivial-Unswitch loop %" + << L->getHeader()->getName() << " [" << L->getBlocks().size() + << " blocks] in Function " << L->getHeader()->getParent()->getName() + << " on cond: " << *Val << " == " << *Cond << "\n"; + + // First step, split the preheader, so that we know that there is a safe place + // to insert the conditional branch. We will change 'OrigPH' to have a + // conditional branch on Cond. + BasicBlock *OrigPH = L->getLoopPreheader(); + BasicBlock *NewPH = SplitEdge(OrigPH, L->getHeader(), this); + + // Now that we have a place to insert the conditional branch, create a place + // to branch to: this is the exit block out of the loop that we should + // short-circuit to. + + // Split this block now, so that the loop maintains its exit block, and so + // that the jump from the preheader can execute the contents of the exit block + // without actually branching to it (the exit block should be dominated by the + // loop header, not the preheader). + assert(!L->contains(ExitBlock) && "Exit block is in the loop?"); + BasicBlock *NewExit = SplitBlock(ExitBlock, ExitBlock->begin(), this); + + // Okay, now we have a position to branch from and a position to branch to, + // insert the new conditional branch. + EmitPreheaderBranchOnCondition(Cond, Val, NewExit, NewPH, + OrigPH->getTerminator()); + OrigPH->getTerminator()->eraseFromParent(); + + // We need to reprocess this loop, it could be unswitched again. 
+ LPM->redoLoop(L); + + // Now that we know that the loop is never entered when this condition is a + // particular value, rewrite the loop with this info. We know that this will + // at least eliminate the old branch. + RewriteLoopBodyWithConditionConstant(L, Cond, Val, false); + ++NumTrivial; +} + +/// VersionLoop - We determined that the loop is profitable to unswitch when LIC +/// equal Val. Split it into loop versions and test the condition outside of +/// either loop. Return the loops created as Out1/Out2. +void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, + Loop *L) { + Function *F = L->getHeader()->getParent(); + DOUT << "loop-unswitch: Unswitching loop %" + << L->getHeader()->getName() << " [" << L->getBlocks().size() + << " blocks] in Function " << F->getName() + << " when '" << *Val << "' == " << *LIC << "\n"; + + // LoopBlocks contains all of the basic blocks of the loop, including the + // preheader of the loop, the body of the loop, and the exit blocks of the + // loop, in that order. + std::vector LoopBlocks; + + // First step, split the preheader and exit blocks, and add these blocks to + // the LoopBlocks list. + BasicBlock *OrigPreheader = L->getLoopPreheader(); + LoopBlocks.push_back(SplitEdge(OrigPreheader, L->getHeader(), this)); + + // We want the loop to come after the preheader, but before the exit blocks. + LoopBlocks.insert(LoopBlocks.end(), L->block_begin(), L->block_end()); + + std::vector ExitBlocks; + L->getUniqueExitBlocks(ExitBlocks); + + // Split all of the edges from inside the loop to their exit blocks. Update + // the appropriate Phi nodes as we do so. 
+ for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) { + BasicBlock *ExitBlock = ExitBlocks[i]; + std::vector Preds(pred_begin(ExitBlock), pred_end(ExitBlock)); + + for (unsigned j = 0, e = Preds.size(); j != e; ++j) { + BasicBlock* MiddleBlock = SplitEdge(Preds[j], ExitBlock, this); + BasicBlock* StartBlock = Preds[j]; + BasicBlock* EndBlock; + if (MiddleBlock->getSinglePredecessor() == ExitBlock) { + EndBlock = MiddleBlock; + MiddleBlock = EndBlock->getSinglePredecessor();; + } else { + EndBlock = ExitBlock; + } + + std::set InsertedPHIs; + PHINode* OldLCSSA = 0; + for (BasicBlock::iterator I = EndBlock->begin(); + (OldLCSSA = dyn_cast(I)); ++I) { + Value* OldValue = OldLCSSA->getIncomingValueForBlock(MiddleBlock); + PHINode* NewLCSSA = new PHINode(OldLCSSA->getType(), + OldLCSSA->getName() + ".us-lcssa", + MiddleBlock->getTerminator()); + NewLCSSA->addIncoming(OldValue, StartBlock); + OldLCSSA->setIncomingValue(OldLCSSA->getBasicBlockIndex(MiddleBlock), + NewLCSSA); + InsertedPHIs.insert(NewLCSSA); + } + + BasicBlock::iterator InsertPt = EndBlock->begin(); + while (dyn_cast(InsertPt)) ++InsertPt; + for (BasicBlock::iterator I = MiddleBlock->begin(); + (OldLCSSA = dyn_cast(I)) && InsertedPHIs.count(OldLCSSA) == 0; + ++I) { + PHINode *NewLCSSA = new PHINode(OldLCSSA->getType(), + OldLCSSA->getName() + ".us-lcssa", + InsertPt); + OldLCSSA->replaceAllUsesWith(NewLCSSA); + NewLCSSA->addIncoming(OldLCSSA, MiddleBlock); + } + } + } + + // The exit blocks may have been changed due to edge splitting, recompute. + ExitBlocks.clear(); + L->getUniqueExitBlocks(ExitBlocks); + + // Add exit blocks to the loop blocks. + LoopBlocks.insert(LoopBlocks.end(), ExitBlocks.begin(), ExitBlocks.end()); + + // Next step, clone all of the basic blocks that make up the loop (including + // the loop preheader and exit blocks), keeping track of the mapping between + // the instructions and blocks. 
+ std::vector NewBlocks; + NewBlocks.reserve(LoopBlocks.size()); + DenseMap ValueMap; + for (unsigned i = 0, e = LoopBlocks.size(); i != e; ++i) { + BasicBlock *New = CloneBasicBlock(LoopBlocks[i], ValueMap, ".us", F); + NewBlocks.push_back(New); + ValueMap[LoopBlocks[i]] = New; // Keep the BB mapping. + } + + // Update dominator info + DominanceFrontier *DF = getAnalysisToUpdate(); + if (DominatorTree *DT = getAnalysisToUpdate()) + for (unsigned i = 0, e = LoopBlocks.size(); i != e; ++i) { + BasicBlock *LBB = LoopBlocks[i]; + BasicBlock *NBB = NewBlocks[i]; + CloneDomInfo(NBB, LBB, L, DT, DF, ValueMap); + } + + // Splice the newly inserted blocks into the function right before the + // original preheader. + F->getBasicBlockList().splice(LoopBlocks[0], F->getBasicBlockList(), + NewBlocks[0], F->end()); + + // Now we create the new Loop object for the versioned loop. + Loop *NewLoop = CloneLoop(L, L->getParentLoop(), ValueMap, LI, LPM); + Loop *ParentLoop = L->getParentLoop(); + if (ParentLoop) { + // Make sure to add the cloned preheader and exit blocks to the parent loop + // as well. + ParentLoop->addBasicBlockToLoop(NewBlocks[0], *LI); + } + + for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) { + BasicBlock *NewExit = cast(ValueMap[ExitBlocks[i]]); + // The new exit block should be in the same loop as the old one. + if (Loop *ExitBBLoop = LI->getLoopFor(ExitBlocks[i])) + ExitBBLoop->addBasicBlockToLoop(NewExit, *LI); + + assert(NewExit->getTerminator()->getNumSuccessors() == 1 && + "Exit block should have been split to have one successor!"); + BasicBlock *ExitSucc = NewExit->getTerminator()->getSuccessor(0); + + // If the successor of the exit block had PHI nodes, add an entry for + // NewExit. 
+ PHINode *PN; + for (BasicBlock::iterator I = ExitSucc->begin(); + (PN = dyn_cast(I)); ++I) { + Value *V = PN->getIncomingValueForBlock(ExitBlocks[i]); + DenseMap::iterator It = ValueMap.find(V); + if (It != ValueMap.end()) V = It->second; + PN->addIncoming(V, NewExit); + } + } + + // Rewrite the code to refer to itself. + for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) + for (BasicBlock::iterator I = NewBlocks[i]->begin(), + E = NewBlocks[i]->end(); I != E; ++I) + RemapInstruction(I, ValueMap); + + // Rewrite the original preheader to select between versions of the loop. + BranchInst *OldBR = cast(OrigPreheader->getTerminator()); + assert(OldBR->isUnconditional() && OldBR->getSuccessor(0) == LoopBlocks[0] && + "Preheader splitting did not work correctly!"); + + // Emit the new branch that selects between the two versions of this loop. + EmitPreheaderBranchOnCondition(LIC, Val, NewBlocks[0], LoopBlocks[0], OldBR); + OldBR->eraseFromParent(); + + LoopProcessWorklist.push_back(NewLoop); + LPM->redoLoop(L); + + // Now we rewrite the original code to know that the condition is true and the + // new code to know that the condition is false. + RewriteLoopBodyWithConditionConstant(L , LIC, Val, false); + + // It's possible that simplifying one loop could cause the other to be + // deleted. If so, don't simplify it. + if (!LoopProcessWorklist.empty() && LoopProcessWorklist.back() == NewLoop) + RewriteLoopBodyWithConditionConstant(NewLoop, LIC, Val, true); +} + +/// RemoveFromWorklist - Remove all instances of I from the worklist vector +/// specified. 
+static void RemoveFromWorklist(Instruction *I, + std::vector &Worklist) { + std::vector::iterator WI = std::find(Worklist.begin(), + Worklist.end(), I); + while (WI != Worklist.end()) { + unsigned Offset = WI-Worklist.begin(); + Worklist.erase(WI); + WI = std::find(Worklist.begin()+Offset, Worklist.end(), I); + } +} + +/// ReplaceUsesOfWith - When we find that I really equals V, remove I from the +/// program, replacing all uses with V and update the worklist. +static void ReplaceUsesOfWith(Instruction *I, Value *V, + std::vector &Worklist) { + DOUT << "Replace with '" << *V << "': " << *I; + + // Add uses to the worklist, which may be dead now. + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (Instruction *Use = dyn_cast(I->getOperand(i))) + Worklist.push_back(Use); + + // Add users to the worklist which may be simplified now. + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); + UI != E; ++UI) + Worklist.push_back(cast(*UI)); + I->replaceAllUsesWith(V); + I->eraseFromParent(); + RemoveFromWorklist(I, Worklist); + ++NumSimplify; +} + +/// RemoveBlockIfDead - If the specified block is dead, remove it, update loop +/// information, and remove any dead successors it has. +/// +void LoopUnswitch::RemoveBlockIfDead(BasicBlock *BB, + std::vector &Worklist) { + if (pred_begin(BB) != pred_end(BB)) { + // This block isn't dead, since an edge to BB was just removed, see if there + // are any easy simplifications we can do now. + if (BasicBlock *Pred = BB->getSinglePredecessor()) { + // If it has one pred, fold phi nodes in BB. + while (isa(BB->begin())) + ReplaceUsesOfWith(BB->begin(), + cast(BB->begin())->getIncomingValue(0), + Worklist); + + // If this is the header of a loop and the only pred is the latch, we now + // have an unreachable loop. 
+ if (Loop *L = LI->getLoopFor(BB)) + if (L->getHeader() == BB && L->contains(Pred)) { + // Remove the branch from the latch to the header block, this makes + // the header dead, which will make the latch dead (because the header + // dominates the latch). + Pred->getTerminator()->eraseFromParent(); + new UnreachableInst(Pred); + + // The loop is now broken, remove it from LI. + RemoveLoopFromHierarchy(L); + + // Reprocess the header, which now IS dead. + RemoveBlockIfDead(BB, Worklist); + return; + } + + // If pred ends in a uncond branch, add uncond branch to worklist so that + // the two blocks will get merged. + if (BranchInst *BI = dyn_cast(Pred->getTerminator())) + if (BI->isUnconditional()) + Worklist.push_back(BI); + } + return; + } + + DOUT << "Nuking dead block: " << *BB; + + // Remove the instructions in the basic block from the worklist. + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + RemoveFromWorklist(I, Worklist); + + // Anything that uses the instructions in this basic block should have their + // uses replaced with undefs. + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + } + + // If this is the edge to the header block for a loop, remove the loop and + // promote all subloops. + if (Loop *BBLoop = LI->getLoopFor(BB)) { + if (BBLoop->getLoopLatch() == BB) + RemoveLoopFromHierarchy(BBLoop); + } + + // Remove the block from the loop info, which removes it from any loops it + // was in. + LI->removeBlock(BB); + + + // Remove phi node entries in successors for this block. + TerminatorInst *TI = BB->getTerminator(); + std::vector Succs; + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + Succs.push_back(TI->getSuccessor(i)); + TI->getSuccessor(i)->removePredecessor(BB); + } + + // Unique the successors, remove anything with multiple uses. 
+ std::sort(Succs.begin(), Succs.end()); + Succs.erase(std::unique(Succs.begin(), Succs.end()), Succs.end()); + + // Remove the basic block, including all of the instructions contained in it. + BB->eraseFromParent(); + + // Remove successor blocks here that are not dead, so that we know we only + // have dead blocks in this list. Nondead blocks have a way of becoming dead, + // then getting removed before we revisit them, which is badness. + // + for (unsigned i = 0; i != Succs.size(); ++i) + if (pred_begin(Succs[i]) != pred_end(Succs[i])) { + // One exception is loop headers. If this block was the preheader for a + // loop, then we DO want to visit the loop so the loop gets deleted. + // We know that if the successor is a loop header, that this loop had to + // be the preheader: the case where this was the latch block was handled + // above and headers can only have two predecessors. + if (!LI->isLoopHeader(Succs[i])) { + Succs.erase(Succs.begin()+i); + --i; + } + } + + for (unsigned i = 0, e = Succs.size(); i != e; ++i) + RemoveBlockIfDead(Succs[i], Worklist); +} + +/// RemoveLoopFromHierarchy - We have discovered that the specified loop has +/// become unwrapped, either because the backedge was deleted, or because the +/// edge into the header was removed. If the edge into the header from the +/// latch block was removed, the loop is unwrapped but subloops are still alive, +/// so they just reparent loops. If the loops are actually dead, they will be +/// removed later. +void LoopUnswitch::RemoveLoopFromHierarchy(Loop *L) { + LPM->deleteLoopFromQueue(L); + RemoveLoopFromWorklist(L); +} + + + +// RewriteLoopBodyWithConditionConstant - We know either that the value LIC has +// the value specified by Val in the specified loop, or we know it does NOT have +// that value. Rewrite any uses of LIC or of properties correlated to it. 
+void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, + Constant *Val, + bool IsEqual) { + assert(!isa(LIC) && "Why are we unswitching on a constant?"); + + // FIXME: Support correlated properties, like: + // for (...) + // if (li1 < li2) + // ... + // if (li1 > li2) + // ... + + // FOLD boolean conditions (X|LIC), (X&LIC). Fold conditional branches, + // selects, switches. + std::vector Users(LIC->use_begin(), LIC->use_end()); + std::vector Worklist; + + // If we know that LIC == Val, or that LIC == NotVal, just replace uses of LIC + // in the loop with the appropriate one directly. + if (IsEqual || (isa(Val) && Val->getType() == Type::Int1Ty)) { + Value *Replacement; + if (IsEqual) + Replacement = Val; + else + Replacement = ConstantInt::get(Type::Int1Ty, + !cast(Val)->getZExtValue()); + + for (unsigned i = 0, e = Users.size(); i != e; ++i) + if (Instruction *U = cast(Users[i])) { + if (!L->contains(U->getParent())) + continue; + U->replaceUsesOfWith(LIC, Replacement); + Worklist.push_back(U); + } + } else { + // Otherwise, we don't know the precise value of LIC, but we do know that it + // is certainly NOT "Val". As such, simplify any uses in the loop that we + // can. This case occurs when we unswitch switch statements. + for (unsigned i = 0, e = Users.size(); i != e; ++i) + if (Instruction *U = cast(Users[i])) { + if (!L->contains(U->getParent())) + continue; + + Worklist.push_back(U); + + // If we know that LIC is not Val, use this info to simplify code. + if (SwitchInst *SI = dyn_cast(U)) { + for (unsigned i = 1, e = SI->getNumCases(); i != e; ++i) { + if (SI->getCaseValue(i) == Val) { + // Found a dead case value. Don't remove PHI nodes in the + // successor if they become single-entry, those PHI nodes may + // be in the Users list. + + // FIXME: This is a hack. We need to keep the successor around + // and hooked up so as to preserve the loop structure, because + // trying to update it is complicated. 
So instead we preserve the + // loop structure and put the block on an dead code path. + + BasicBlock* Old = SI->getParent(); + BasicBlock* Split = SplitBlock(Old, SI, this); + + Instruction* OldTerm = Old->getTerminator(); + new BranchInst(Split, SI->getSuccessor(i), + ConstantInt::getTrue(), OldTerm); + + Old->getTerminator()->eraseFromParent(); + + + PHINode *PN; + for (BasicBlock::iterator II = SI->getSuccessor(i)->begin(); + (PN = dyn_cast(II)); ++II) { + Value *InVal = PN->removeIncomingValue(Split, false); + PN->addIncoming(InVal, Old); + } + + SI->removeCase(i); + break; + } + } + } + + // TODO: We could do other simplifications, for example, turning + // LIC == Val -> false. + } + } + + SimplifyCode(Worklist); +} + +/// SimplifyCode - Okay, now that we have simplified some instructions in the +/// loop, walk over it and constant prop, dce, and fold control flow where +/// possible. Note that this is effectively a very simple loop-structure-aware +/// optimizer. During processing of this loop, L could very well be deleted, so +/// it must not be used. +/// +/// FIXME: When the loop optimizer is more mature, separate this out to a new +/// pass. +/// +void LoopUnswitch::SimplifyCode(std::vector &Worklist) { + while (!Worklist.empty()) { + Instruction *I = Worklist.back(); + Worklist.pop_back(); + + // Simple constant folding. + if (Constant *C = ConstantFoldInstruction(I)) { + ReplaceUsesOfWith(I, C, Worklist); + continue; + } + + // Simple DCE. + if (isInstructionTriviallyDead(I)) { + DOUT << "Remove dead instruction '" << *I; + + // Add uses to the worklist, which may be dead now. + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (Instruction *Use = dyn_cast(I->getOperand(i))) + Worklist.push_back(Use); + I->eraseFromParent(); + RemoveFromWorklist(I, Worklist); + ++NumSimplify; + continue; + } + + // Special case hacks that appear commonly in unswitched code. 
+ switch (I->getOpcode()) { + case Instruction::Select: + if (ConstantInt *CB = dyn_cast(I->getOperand(0))) { + ReplaceUsesOfWith(I, I->getOperand(!CB->getZExtValue()+1), Worklist); + continue; + } + break; + case Instruction::And: + if (isa(I->getOperand(0)) && + I->getOperand(0)->getType() == Type::Int1Ty) // constant -> RHS + cast(I)->swapOperands(); + if (ConstantInt *CB = dyn_cast(I->getOperand(1))) + if (CB->getType() == Type::Int1Ty) { + if (CB->isOne()) // X & 1 -> X + ReplaceUsesOfWith(I, I->getOperand(0), Worklist); + else // X & 0 -> 0 + ReplaceUsesOfWith(I, I->getOperand(1), Worklist); + continue; + } + break; + case Instruction::Or: + if (isa(I->getOperand(0)) && + I->getOperand(0)->getType() == Type::Int1Ty) // constant -> RHS + cast(I)->swapOperands(); + if (ConstantInt *CB = dyn_cast(I->getOperand(1))) + if (CB->getType() == Type::Int1Ty) { + if (CB->isOne()) // X | 1 -> 1 + ReplaceUsesOfWith(I, I->getOperand(1), Worklist); + else // X | 0 -> X + ReplaceUsesOfWith(I, I->getOperand(0), Worklist); + continue; + } + break; + case Instruction::Br: { + BranchInst *BI = cast(I); + if (BI->isUnconditional()) { + // If BI's parent is the only pred of the successor, fold the two blocks + // together. + BasicBlock *Pred = BI->getParent(); + BasicBlock *Succ = BI->getSuccessor(0); + BasicBlock *SinglePred = Succ->getSinglePredecessor(); + if (!SinglePred) continue; // Nothing to do. + assert(SinglePred == Pred && "CFG broken"); + + DOUT << "Merging blocks: " << Pred->getName() << " <- " + << Succ->getName() << "\n"; + + // Resolve any single entry PHI nodes in Succ. + while (PHINode *PN = dyn_cast(Succ->begin())) + ReplaceUsesOfWith(PN, PN->getIncomingValue(0), Worklist); + + // Move all of the successor contents from Succ to Pred. 
+ Pred->getInstList().splice(BI, Succ->getInstList(), Succ->begin(), + Succ->end()); + BI->eraseFromParent(); + RemoveFromWorklist(BI, Worklist); + + // If Succ has any successors with PHI nodes, update them to have + // entries coming from Pred instead of Succ. + Succ->replaceAllUsesWith(Pred); + + // Remove Succ from the loop tree. + LI->removeBlock(Succ); + Succ->eraseFromParent(); + ++NumSimplify; + } else if (ConstantInt *CB = dyn_cast(BI->getCondition())){ + // Conditional branch. Turn it into an unconditional branch, then + // remove dead blocks. + break; // FIXME: Enable. + + DOUT << "Folded branch: " << *BI; + BasicBlock *DeadSucc = BI->getSuccessor(CB->getZExtValue()); + BasicBlock *LiveSucc = BI->getSuccessor(!CB->getZExtValue()); + DeadSucc->removePredecessor(BI->getParent(), true); + Worklist.push_back(new BranchInst(LiveSucc, BI)); + BI->eraseFromParent(); + RemoveFromWorklist(BI, Worklist); + ++NumSimplify; + + RemoveBlockIfDead(DeadSucc, Worklist); + } + break; + } + } + } +} diff --git a/lib/Transforms/Scalar/LowerGC.cpp b/lib/Transforms/Scalar/LowerGC.cpp new file mode 100644 index 0000000..27cccd5 --- /dev/null +++ b/lib/Transforms/Scalar/LowerGC.cpp @@ -0,0 +1,330 @@ +//===-- LowerGC.cpp - Provide GC support for targets that don't -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements lowering for the llvm.gc* intrinsics for targets that do +// not natively support them (which includes the C backend). Note that the code +// generated is not as efficient as it would be for targets that natively +// support the GC intrinsics, but it is useful for getting new targets +// up-and-running quickly. 
+// +// This pass implements the code transformation described in this paper: +// "Accurate Garbage Collection in an Uncooperative Environment" +// Fergus Henderson, ISMM, 2002 +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "lowergc" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/Compiler.h" +using namespace llvm; + +namespace { + class VISIBILITY_HIDDEN LowerGC : public FunctionPass { + /// GCRootInt, GCReadInt, GCWriteInt - The function prototypes for the + /// llvm.gcread/llvm.gcwrite/llvm.gcroot intrinsics. + Function *GCRootInt, *GCReadInt, *GCWriteInt; + + /// GCRead/GCWrite - These are the functions provided by the garbage + /// collector for read/write barriers. + Constant *GCRead, *GCWrite; + + /// RootChain - This is the global linked-list that contains the chain of GC + /// roots. + GlobalVariable *RootChain; + + /// MainRootRecordType - This is the type for a function root entry if it + /// had zero roots. + const Type *MainRootRecordType; + public: + static char ID; // Pass identification, replacement for typeid + LowerGC() : FunctionPass((intptr_t)&ID), + GCRootInt(0), GCReadInt(0), GCWriteInt(0), + GCRead(0), GCWrite(0), RootChain(0), MainRootRecordType(0) {} + virtual bool doInitialization(Module &M); + virtual bool runOnFunction(Function &F); + + private: + const StructType *getRootRecordType(unsigned NumRoots); + }; + + char LowerGC::ID = 0; + RegisterPass + X("lowergc", "Lower GC intrinsics, for GCless code generators"); +} + +/// createLowerGCPass - This function returns an instance of the "lowergc" +/// pass, which lowers garbage collection intrinsics to normal LLVM code. 
+FunctionPass *llvm::createLowerGCPass() { + return new LowerGC(); +} + +/// getRootRecordType - This function creates and returns the type for a root +/// record containing 'NumRoots' roots. +const StructType *LowerGC::getRootRecordType(unsigned NumRoots) { + // Build a struct that is a type used for meta-data/root pairs. + std::vector ST; + ST.push_back(GCRootInt->getFunctionType()->getParamType(0)); + ST.push_back(GCRootInt->getFunctionType()->getParamType(1)); + StructType *PairTy = StructType::get(ST); + + // Build the array of pairs. + ArrayType *PairArrTy = ArrayType::get(PairTy, NumRoots); + + // Now build the recursive list type. + PATypeHolder RootListH = + MainRootRecordType ? (Type*)MainRootRecordType : (Type*)OpaqueType::get(); + ST.clear(); + ST.push_back(PointerType::get(RootListH)); // Prev pointer + ST.push_back(Type::Int32Ty); // NumElements in array + ST.push_back(PairArrTy); // The pairs + StructType *RootList = StructType::get(ST); + if (MainRootRecordType) + return RootList; + + assert(NumRoots == 0 && "The main struct type should have zero entries!"); + cast((Type*)RootListH.get())->refineAbstractTypeTo(RootList); + MainRootRecordType = RootListH; + return cast(RootListH.get()); +} + +/// doInitialization - If this module uses the GC intrinsics, find them now. If +/// not, this pass does not do anything. +bool LowerGC::doInitialization(Module &M) { + GCRootInt = M.getFunction("llvm.gcroot"); + GCReadInt = M.getFunction("llvm.gcread"); + GCWriteInt = M.getFunction("llvm.gcwrite"); + if (!GCRootInt && !GCReadInt && !GCWriteInt) return false; + + PointerType *VoidPtr = PointerType::get(Type::Int8Ty); + PointerType *VoidPtrPtr = PointerType::get(VoidPtr); + + // If the program is using read/write barriers, find the implementations of + // them from the GC runtime library. 
+ if (GCReadInt) // Make: sbyte* %llvm_gc_read(sbyte**) + GCRead = M.getOrInsertFunction("llvm_gc_read", VoidPtr, VoidPtr, VoidPtrPtr, + (Type *)0); + if (GCWriteInt) // Make: void %llvm_gc_write(sbyte*, sbyte**) + GCWrite = M.getOrInsertFunction("llvm_gc_write", Type::VoidTy, + VoidPtr, VoidPtr, VoidPtrPtr, (Type *)0); + + // If the program has GC roots, get or create the global root list. + if (GCRootInt) { + const StructType *RootListTy = getRootRecordType(0); + const Type *PRLTy = PointerType::get(RootListTy); + M.addTypeName("llvm_gc_root_ty", RootListTy); + + // Get the root chain if it already exists. + RootChain = M.getGlobalVariable("llvm_gc_root_chain", PRLTy); + if (RootChain == 0) { + // If the root chain does not exist, insert a new one with linkonce + // linkage! + RootChain = new GlobalVariable(PRLTy, false, + GlobalValue::LinkOnceLinkage, + Constant::getNullValue(PRLTy), + "llvm_gc_root_chain", &M); + } else if (RootChain->hasExternalLinkage() && RootChain->isDeclaration()) { + RootChain->setInitializer(Constant::getNullValue(PRLTy)); + RootChain->setLinkage(GlobalValue::LinkOnceLinkage); + } + } + return true; +} + +/// Coerce - If the specified operand number of the specified instruction does +/// not have the specified type, insert a cast. Note that this only uses BitCast +/// because the types involved are all pointers. +static void Coerce(Instruction *I, unsigned OpNum, Type *Ty) { + if (I->getOperand(OpNum)->getType() != Ty) { + if (Constant *C = dyn_cast(I->getOperand(OpNum))) + I->setOperand(OpNum, ConstantExpr::getBitCast(C, Ty)); + else { + CastInst *CI = new BitCastInst(I->getOperand(OpNum), Ty, "", I); + I->setOperand(OpNum, CI); + } + } +} + +/// runOnFunction - If the program is using GC intrinsics, replace any +/// read/write intrinsics with the appropriate read/write barrier calls, then +/// inline them. 
Finally, build the data structures for +bool LowerGC::runOnFunction(Function &F) { + // Quick exit for programs that are not using GC mechanisms. + if (!GCRootInt && !GCReadInt && !GCWriteInt) return false; + + PointerType *VoidPtr = PointerType::get(Type::Int8Ty); + PointerType *VoidPtrPtr = PointerType::get(VoidPtr); + + // If there are read/write barriers in the program, perform a quick pass over + // the function eliminating them. While we are at it, remember where we see + // calls to llvm.gcroot. + std::vector GCRoots; + std::vector NormalCalls; + + bool MadeChange = false; + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E;) + if (CallInst *CI = dyn_cast(II++)) { + if (!CI->getCalledFunction() || + !CI->getCalledFunction()->getIntrinsicID()) + NormalCalls.push_back(CI); // Remember all normal function calls. + + if (Function *F = CI->getCalledFunction()) + if (F == GCRootInt) + GCRoots.push_back(CI); + else if (F == GCReadInt || F == GCWriteInt) { + if (F == GCWriteInt) { + // Change a llvm.gcwrite call to call llvm_gc_write instead. + CI->setOperand(0, GCWrite); + // Insert casts of the operands as needed. + Coerce(CI, 1, VoidPtr); + Coerce(CI, 2, VoidPtr); + Coerce(CI, 3, VoidPtrPtr); + } else { + Coerce(CI, 1, VoidPtr); + Coerce(CI, 2, VoidPtrPtr); + if (CI->getType() == VoidPtr) { + CI->setOperand(0, GCRead); + } else { + // Create a whole new call to replace the old one. + CallInst *NC = new CallInst(GCRead, CI->getOperand(1), + CI->getOperand(2), + CI->getName(), CI); + // These functions only deal with ptr type results so BitCast + // is the correct kind of cast (no-op cast). + Value *NV = new BitCastInst(NC, CI->getType(), "", CI); + CI->replaceAllUsesWith(NV); + BB->getInstList().erase(CI); + CI = NC; + } + } + + MadeChange = true; + } + } + + // If there are no GC roots in this function, then there is no need to create + // a GC list record for it. 
+ if (GCRoots.empty()) return MadeChange; + + // Okay, there are GC roots in this function. On entry to the function, add a + // record to the llvm_gc_root_chain, and remove it on exit. + + // Create the alloca, and zero it out. + const StructType *RootListTy = getRootRecordType(GCRoots.size()); + AllocaInst *AI = new AllocaInst(RootListTy, 0, "gcroots", F.begin()->begin()); + + // Insert the memset call after all of the allocas in the function. + BasicBlock::iterator IP = AI; + while (isa(IP)) ++IP; + + Constant *Zero = ConstantInt::get(Type::Int32Ty, 0); + Constant *One = ConstantInt::get(Type::Int32Ty, 1); + + // Get a pointer to the prev pointer. + Value *PrevPtrPtr = new GetElementPtrInst(AI, Zero, Zero, "prevptrptr", IP); + + // Load the previous pointer. + Value *PrevPtr = new LoadInst(RootChain, "prevptr", IP); + // Store the previous pointer into the prevptrptr + new StoreInst(PrevPtr, PrevPtrPtr, IP); + + // Set the number of elements in this record. + Value *NumEltsPtr = new GetElementPtrInst(AI, Zero, One, "numeltsptr", IP); + new StoreInst(ConstantInt::get(Type::Int32Ty, GCRoots.size()), NumEltsPtr,IP); + + Value* Par[4]; + Par[0] = Zero; + Par[1] = ConstantInt::get(Type::Int32Ty, 2); + + const PointerType *PtrLocTy = + cast(GCRootInt->getFunctionType()->getParamType(0)); + Constant *Null = ConstantPointerNull::get(PtrLocTy); + + // Initialize all of the gcroot records now, and eliminate them as we go. + for (unsigned i = 0, e = GCRoots.size(); i != e; ++i) { + // Initialize the meta-data pointer. + Par[2] = ConstantInt::get(Type::Int32Ty, i); + Par[3] = One; + Value *MetaDataPtr = new GetElementPtrInst(AI, Par, 4, "MetaDataPtr", IP); + assert(isa(GCRoots[i]->getOperand(2)) && "Must be a constant"); + new StoreInst(GCRoots[i]->getOperand(2), MetaDataPtr, IP); + + // Initialize the root pointer to null on entry to the function. 
+ Par[3] = Zero; + Value *RootPtrPtr = new GetElementPtrInst(AI, Par, 4, "RootEntPtr", IP); + new StoreInst(Null, RootPtrPtr, IP); + + // Each occurrance of the llvm.gcroot intrinsic now turns into an + // initialization of the slot with the address and a zeroing out of the + // address specified. + new StoreInst(Constant::getNullValue(PtrLocTy->getElementType()), + GCRoots[i]->getOperand(1), GCRoots[i]); + new StoreInst(GCRoots[i]->getOperand(1), RootPtrPtr, GCRoots[i]); + GCRoots[i]->getParent()->getInstList().erase(GCRoots[i]); + } + + // Now that the record is all initialized, store the pointer into the global + // pointer. + Value *C = new BitCastInst(AI, PointerType::get(MainRootRecordType), "", IP); + new StoreInst(C, RootChain, IP); + + // On exit from the function we have to remove the entry from the GC root + // chain. Doing this is straight-forward for return and unwind instructions: + // just insert the appropriate copy. + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + if (isa(BB->getTerminator()) || + isa(BB->getTerminator())) { + // We could reuse the PrevPtr loaded on entry to the function, but this + // would make the value live for the whole function, which is probably a + // bad idea. Just reload the value out of our stack entry. + PrevPtr = new LoadInst(PrevPtrPtr, "prevptr", BB->getTerminator()); + new StoreInst(PrevPtr, RootChain, BB->getTerminator()); + } + + // If an exception is thrown from a callee we have to make sure to + // unconditionally take the record off the stack. For this reason, we turn + // all call instructions into invoke whose cleanup pops the entry off the + // stack. We only insert one cleanup block, which is shared by all invokes. + if (!NormalCalls.empty()) { + // Create the shared cleanup block. 
+ BasicBlock *Cleanup = new BasicBlock("gc_cleanup", &F); + UnwindInst *UI = new UnwindInst(Cleanup); + PrevPtr = new LoadInst(PrevPtrPtr, "prevptr", UI); + new StoreInst(PrevPtr, RootChain, UI); + + // Loop over all of the function calls, turning them into invokes. + while (!NormalCalls.empty()) { + CallInst *CI = NormalCalls.back(); + BasicBlock *CBB = CI->getParent(); + NormalCalls.pop_back(); + + // Split the basic block containing the function call. + BasicBlock *NewBB = CBB->splitBasicBlock(CI, CBB->getName()+".cont"); + + // Remove the unconditional branch inserted at the end of the CBB. + CBB->getInstList().pop_back(); + NewBB->getInstList().remove(CI); + + // Create a new invoke instruction. + std::vector Args(CI->op_begin()+1, CI->op_end()); + + Value *II = new InvokeInst(CI->getCalledValue(), NewBB, Cleanup, + &Args[0], Args.size(), CI->getName(), CBB); + CI->replaceAllUsesWith(II); + delete CI; + } + } + + return true; +} diff --git a/lib/Transforms/Scalar/LowerPacked.cpp b/lib/Transforms/Scalar/LowerPacked.cpp new file mode 100644 index 0000000..0530172 --- /dev/null +++ b/lib/Transforms/Scalar/LowerPacked.cpp @@ -0,0 +1,462 @@ +//===- LowerPacked.cpp - Implementation of LowerPacked Transform ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by Brad Jones and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements lowering Packed datatypes into more primitive +// Packed datatypes, and finally to scalar operations. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar.h" +#include "llvm/Argument.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/InstVisitor.h" +#include "llvm/Support/Streams.h" +#include "llvm/ADT/StringExtras.h" +#include +#include +#include +using namespace llvm; + +namespace { + +/// This pass converts packed operators to an +/// equivalent operations on smaller packed data, to possibly +/// scalar operations. Currently it supports lowering +/// to scalar operations. +/// +/// @brief Transforms packed instructions to simpler instructions. +/// +class VISIBILITY_HIDDEN LowerPacked + : public FunctionPass, public InstVisitor { +public: + static char ID; // Pass identification, replacement for typeid + LowerPacked() : FunctionPass((intptr_t)&ID) {} + + /// @brief Lowers packed operations to scalar operations. + /// @param F The fuction to process + virtual bool runOnFunction(Function &F); + + /// @brief Lowers packed load instructions. + /// @param LI the load instruction to convert + void visitLoadInst(LoadInst& LI); + + /// @brief Lowers packed store instructions. + /// @param SI the store instruction to convert + void visitStoreInst(StoreInst& SI); + + /// @brief Lowers packed binary operations. + /// @param BO the binary operator to convert + void visitBinaryOperator(BinaryOperator& BO); + + /// @brief Lowers packed icmp operations. + /// @param CI the icmp operator to convert + void visitICmpInst(ICmpInst& IC); + + /// @brief Lowers packed select instructions. + /// @param SELI the select operator to convert + void visitSelectInst(SelectInst& SELI); + + /// @brief Lowers packed extractelement instructions. 
+ /// @param EI the extractelement operator to convert + void visitExtractElementInst(ExtractElementInst& EE); + + /// @brief Lowers packed insertelement instructions. + /// @param EI the insertelement operator to convert + void visitInsertElementInst(InsertElementInst& IE); + + /// This function asserts if the instruction is a VectorType but + /// is handled by another function. + /// + /// @brief Asserts if VectorType instruction is not handled elsewhere. + /// @param I the unhandled instruction + void visitInstruction(Instruction &I) { + if (isa(I.getType())) + cerr << "Unhandled Instruction with Packed ReturnType: " << I << '\n'; + } +private: + /// @brief Retrieves lowered values for a packed value. + /// @param val the packed value + /// @return the lowered values + std::vector& getValues(Value* val); + + /// @brief Sets lowered values for a packed value. + /// @param val the packed value + /// @param values the corresponding lowered values + void setValues(Value* val,const std::vector& values); + + // Data Members + /// @brief whether we changed the function or not + bool Changed; + + /// @brief a map from old packed values to new smaller packed values + std::map > packedToScalarMap; + + /// Instructions in the source program to get rid of + /// after we do a pass (the old packed instructions) + std::vector instrsToRemove; +}; + +char LowerPacked::ID = 0; +RegisterPass +X("lower-packed", + "lowers packed operations to operations on smaller packed datatypes"); + +} // end namespace + +FunctionPass *llvm::createLowerPackedPass() { return new LowerPacked(); } + + +// This function sets lowered values for a corresponding +// packed value. Note, in the case of a forward reference +// getValues(Value*) will have already been called for +// the packed parameter. This function will then replace +// all references in the in the function of the "dummy" +// value the previous getValues(Value*) call +// returned with actual references. 
+void LowerPacked::setValues(Value* value,const std::vector& values) +{ + std::map >::iterator it = + packedToScalarMap.lower_bound(value); + if (it == packedToScalarMap.end() || it->first != value) { + // there was not a forward reference to this element + packedToScalarMap.insert(it,std::make_pair(value,values)); + } + else { + // replace forward declarations with actual definitions + assert(it->second.size() == values.size() && + "Error forward refences and actual definition differ in size"); + for (unsigned i = 0, e = values.size(); i != e; ++i) { + // replace and get rid of old forward references + it->second[i]->replaceAllUsesWith(values[i]); + delete it->second[i]; + it->second[i] = values[i]; + } + } +} + +// This function will examine the packed value parameter +// and if it is a packed constant or a forward reference +// properly create the lowered values needed. Otherwise +// it will simply retreive values from a +// setValues(Value*,const std::vector&) +// call. Failing both of these cases, it will abort +// the program. 
+std::vector& LowerPacked::getValues(Value* value) +{ + assert(isa(value->getType()) && + "Value must be VectorType"); + + // reject further processing if this one has + // already been handled + std::map >::iterator it = + packedToScalarMap.lower_bound(value); + if (it != packedToScalarMap.end() && it->first == value) { + return it->second; + } + + if (ConstantVector* CP = dyn_cast(value)) { + // non-zero constant case + std::vector results; + results.reserve(CP->getNumOperands()); + for (unsigned i = 0, e = CP->getNumOperands(); i != e; ++i) { + results.push_back(CP->getOperand(i)); + } + return packedToScalarMap.insert(it, + std::make_pair(value,results))->second; + } + else if (ConstantAggregateZero* CAZ = + dyn_cast(value)) { + // zero constant + const VectorType* PKT = cast(CAZ->getType()); + std::vector results; + results.reserve(PKT->getNumElements()); + + Constant* C = Constant::getNullValue(PKT->getElementType()); + for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) { + results.push_back(C); + } + return packedToScalarMap.insert(it, + std::make_pair(value,results))->second; + } + else if (isa(value)) { + // foward reference + const VectorType* PKT = cast(value->getType()); + std::vector results; + results.reserve(PKT->getNumElements()); + + for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) { + results.push_back(new Argument(PKT->getElementType())); + } + return packedToScalarMap.insert(it, + std::make_pair(value,results))->second; + } + else { + // we don't know what it is, and we are trying to retrieve + // a value for it + assert(false && "Unhandled VectorType value"); + abort(); + } +} + +void LowerPacked::visitLoadInst(LoadInst& LI) +{ + // Make sure what we are dealing with is a vector type + if (const VectorType* PKT = dyn_cast(LI.getType())) { + // Initialization, Idx is needed for getelementptr needed later + std::vector Idx(2); + Idx[0] = ConstantInt::get(Type::Int32Ty,0); + + ArrayType* AT = 
ArrayType::get(PKT->getContainedType(0), + PKT->getNumElements()); + PointerType* APT = PointerType::get(AT); + + // Cast the pointer to vector type to an equivalent array + Value* array = new BitCastInst(LI.getPointerOperand(), APT, + LI.getName() + ".a", &LI); + + // Convert this load into num elements number of loads + std::vector values; + values.reserve(PKT->getNumElements()); + + for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) { + // Calculate the second index we will need + Idx[1] = ConstantInt::get(Type::Int32Ty,i); + + // Get the pointer + Value* val = new GetElementPtrInst(array, + &Idx[0], Idx.size(), + LI.getName() + + ".ge." + utostr(i), + &LI); + + // generate the new load and save the result in packedToScalar map + values.push_back(new LoadInst(val, LI.getName()+"."+utostr(i), + LI.isVolatile(), &LI)); + } + + setValues(&LI,values); + Changed = true; + instrsToRemove.push_back(&LI); + } +} + +void LowerPacked::visitBinaryOperator(BinaryOperator& BO) +{ + // Make sure both operands are VectorTypes + if (isa(BO.getOperand(0)->getType())) { + std::vector& op0Vals = getValues(BO.getOperand(0)); + std::vector& op1Vals = getValues(BO.getOperand(1)); + std::vector result; + assert((op0Vals.size() == op1Vals.size()) && + "The two packed operand to scalar maps must be equal in size."); + + result.reserve(op0Vals.size()); + + // generate the new binary op and save the result + for (unsigned i = 0; i != op0Vals.size(); ++i) { + result.push_back(BinaryOperator::create(BO.getOpcode(), + op0Vals[i], + op1Vals[i], + BO.getName() + + "." 
+ utostr(i), + &BO)); + } + + setValues(&BO,result); + Changed = true; + instrsToRemove.push_back(&BO); + } +} + +void LowerPacked::visitICmpInst(ICmpInst& IC) +{ + // Make sure both operands are VectorTypes + if (isa(IC.getOperand(0)->getType())) { + std::vector& op0Vals = getValues(IC.getOperand(0)); + std::vector& op1Vals = getValues(IC.getOperand(1)); + std::vector result; + assert((op0Vals.size() == op1Vals.size()) && + "The two packed operand to scalar maps must be equal in size."); + + result.reserve(op0Vals.size()); + + // generate the new binary op and save the result + for (unsigned i = 0; i != op0Vals.size(); ++i) { + result.push_back(CmpInst::create(IC.getOpcode(), + IC.getPredicate(), + op0Vals[i], + op1Vals[i], + IC.getName() + + "." + utostr(i), + &IC)); + } + + setValues(&IC,result); + Changed = true; + instrsToRemove.push_back(&IC); + } +} + +void LowerPacked::visitStoreInst(StoreInst& SI) +{ + if (const VectorType* PKT = + dyn_cast(SI.getOperand(0)->getType())) { + // We will need this for getelementptr + std::vector Idx(2); + Idx[0] = ConstantInt::get(Type::Int32Ty,0); + + ArrayType* AT = ArrayType::get(PKT->getContainedType(0), + PKT->getNumElements()); + PointerType* APT = PointerType::get(AT); + + // Cast the pointer to packed to an array of equivalent type + Value* array = new BitCastInst(SI.getPointerOperand(), APT, + "store.ge.a.", &SI); + + std::vector& values = getValues(SI.getOperand(0)); + + assert((values.size() == PKT->getNumElements()) && + "Scalar must have the same number of elements as Vector Type"); + + for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) { + // Generate the indices for getelementptr + Idx[1] = ConstantInt::get(Type::Int32Ty,i); + Value* val = new GetElementPtrInst(array, + &Idx[0], Idx.size(), + "store.ge." 
+ + utostr(i) + ".", + &SI); + new StoreInst(values[i], val, SI.isVolatile(),&SI); + } + + Changed = true; + instrsToRemove.push_back(&SI); + } +} + +void LowerPacked::visitSelectInst(SelectInst& SELI) +{ + // Make sure both operands are VectorTypes + if (isa(SELI.getType())) { + std::vector& op0Vals = getValues(SELI.getTrueValue()); + std::vector& op1Vals = getValues(SELI.getFalseValue()); + std::vector result; + + assert((op0Vals.size() == op1Vals.size()) && + "The two packed operand to scalar maps must be equal in size."); + + for (unsigned i = 0; i != op0Vals.size(); ++i) { + result.push_back(new SelectInst(SELI.getCondition(), + op0Vals[i], + op1Vals[i], + SELI.getName()+ "." + utostr(i), + &SELI)); + } + + setValues(&SELI,result); + Changed = true; + instrsToRemove.push_back(&SELI); + } +} + +void LowerPacked::visitExtractElementInst(ExtractElementInst& EI) +{ + std::vector& op0Vals = getValues(EI.getOperand(0)); + const VectorType *PTy = cast(EI.getOperand(0)->getType()); + Value *op1 = EI.getOperand(1); + + if (ConstantInt *C = dyn_cast(op1)) { + EI.replaceAllUsesWith(op0Vals[C->getZExtValue()]); + } else { + AllocaInst *alloca = + new AllocaInst(PTy->getElementType(), + ConstantInt::get(Type::Int32Ty, PTy->getNumElements()), + EI.getName() + ".alloca", + EI.getParent()->getParent()->getEntryBlock().begin()); + for (unsigned i = 0; i < PTy->getNumElements(); ++i) { + GetElementPtrInst *GEP = + new GetElementPtrInst(alloca, ConstantInt::get(Type::Int32Ty, i), + "store.ge", &EI); + new StoreInst(op0Vals[i], GEP, &EI); + } + GetElementPtrInst *GEP = + new GetElementPtrInst(alloca, op1, EI.getName() + ".ge", &EI); + LoadInst *load = new LoadInst(GEP, EI.getName() + ".load", &EI); + EI.replaceAllUsesWith(load); + } + + Changed = true; + instrsToRemove.push_back(&EI); +} + +void LowerPacked::visitInsertElementInst(InsertElementInst& IE) +{ + std::vector& Vals = getValues(IE.getOperand(0)); + Value *Elt = IE.getOperand(1); + Value *Idx = IE.getOperand(2); + 
std::vector result; + result.reserve(Vals.size()); + + if (ConstantInt *C = dyn_cast(Idx)) { + unsigned idxVal = C->getZExtValue(); + for (unsigned i = 0; i != Vals.size(); ++i) { + result.push_back(i == idxVal ? Elt : Vals[i]); + } + } else { + for (unsigned i = 0; i != Vals.size(); ++i) { + ICmpInst *icmp = + new ICmpInst(ICmpInst::ICMP_EQ, Idx, + ConstantInt::get(Type::Int32Ty, i), + "icmp", &IE); + SelectInst *select = + new SelectInst(icmp, Elt, Vals[i], "select", &IE); + result.push_back(select); + } + } + + setValues(&IE, result); + Changed = true; + instrsToRemove.push_back(&IE); +} + +bool LowerPacked::runOnFunction(Function& F) +{ + // initialize + Changed = false; + + // Does three passes: + // Pass 1) Converts Packed Operations to + // new Packed Operations on smaller + // datatypes + visit(F); + + // Pass 2) Drop all references + std::for_each(instrsToRemove.begin(), + instrsToRemove.end(), + std::mem_fun(&Instruction::dropAllReferences)); + + // Pass 3) Delete the Instructions to remove aka packed instructions + for (std::vector::iterator i = instrsToRemove.begin(), + e = instrsToRemove.end(); + i != e; ++i) { + (*i)->getParent()->getInstList().erase(*i); + } + + // clean-up + packedToScalarMap.clear(); + instrsToRemove.clear(); + + return Changed; +} + diff --git a/lib/Transforms/Scalar/Makefile b/lib/Transforms/Scalar/Makefile new file mode 100644 index 0000000..79643c4 --- /dev/null +++ b/lib/Transforms/Scalar/Makefile @@ -0,0 +1,15 @@ +##===- lib/Transforms/Scalar/Makefile ----------------------*- Makefile -*-===## +# +# The LLVM Compiler Infrastructure +# +# This file was developed by the LLVM research group and is distributed under +# the University of Illinois Open Source License. See LICENSE.TXT for details. +# +##===----------------------------------------------------------------------===## + +LEVEL = ../../.. 
+LIBRARYNAME = LLVMScalarOpts +BUILD_ARCHIVE = 1 + +include $(LEVEL)/Makefile.common + diff --git a/lib/Transforms/Scalar/PredicateSimplifier.cpp b/lib/Transforms/Scalar/PredicateSimplifier.cpp new file mode 100644 index 0000000..7b41fb2 --- /dev/null +++ b/lib/Transforms/Scalar/PredicateSimplifier.cpp @@ -0,0 +1,2640 @@ +//===-- PredicateSimplifier.cpp - Path Sensitive Simplifier ---------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by Nick Lewycky and is distributed under the +// University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Path-sensitive optimizer. In a branch where x == y, replace uses of +// x with y. Permits further optimization, such as the elimination of +// the unreachable call: +// +// void test(int *p, int *q) +// { +// if (p != q) +// return; +// +// if (*p != *q) +// foo(); // unreachable +// } +// +//===----------------------------------------------------------------------===// +// +// The InequalityGraph focusses on four properties; equals, not equals, +// less-than and less-than-or-equals-to. The greater-than forms are also held +// just to allow walking from a lesser node to a greater one. These properties +// are stored in a lattice; LE can become LT or EQ, NE can become LT or GT. +// +// These relationships define a graph between values of the same type. Each +// Value is stored in a map table that retrieves the associated Node. This +// is how EQ relationships are stored; the map contains pointers from equal +// Value to the same node. The node contains a most canonical Value* form +// and the list of known relationships with other nodes. +// +// If two nodes are known to be inequal, then they will contain pointers to +// each other with an "NE" relationship. If node getNode(%x) is less than +// getNode(%y), then the %x node will contain <%y, GT> and %y will contain +// <%x, LT>. 
This allows us to tie nodes together into a graph like this: +// +// %a < %b < %c < %d +// +// with four nodes representing the properties. The InequalityGraph provides +// querying with "isRelatedBy" and mutators "addEquality" and "addInequality". +// To find a relationship, we start with one of the nodes any binary search +// through its list to find where the relationships with the second node start. +// Then we iterate through those to find the first relationship that dominates +// our context node. +// +// To create these properties, we wait until a branch or switch instruction +// implies that a particular value is true (or false). The VRPSolver is +// responsible for analyzing the variable and seeing what new inferences +// can be made from each property. For example: +// +// %P = icmp ne i32* %ptr, null +// %a = and i1 %P, %Q +// br i1 %a label %cond_true, label %cond_false +// +// For the true branch, the VRPSolver will start with %a EQ true and look at +// the definition of %a and find that it can infer that %P and %Q are both +// true. From %P being true, it can infer that %ptr NE null. For the false +// branch it can't infer anything from the "and" instruction. +// +// Besides branches, we can also infer properties from instruction that may +// have undefined behaviour in certain cases. For example, the dividend of +// a division may never be zero. After the division instruction, we may assume +// that the dividend is not equal to zero. +// +//===----------------------------------------------------------------------===// +// +// The ValueRanges class stores the known integer bounds of a Value. When we +// encounter i8 %a u< %b, the ValueRanges stores that %a = [1, 255] and +// %b = [0, 254]. Because we store these by Value*, you should always +// canonicalize through the InequalityGraph first. +// +// It never stores an empty range, because that means that the code is +// unreachable. 
It never stores a single-element range since that's an equality +// relationship and better stored in the InequalityGraph, nor an empty range +// since that is better stored in UnreachableBlocks. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "predsimplify" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/ConstantRange.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/InstVisitor.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Transforms/Utils/Local.h" +#include +#include +#include +#include +using namespace llvm; + +STATISTIC(NumVarsReplaced, "Number of argument substitutions"); +STATISTIC(NumInstruction , "Number of instructions removed"); +STATISTIC(NumSimple , "Number of simple replacements"); +STATISTIC(NumBlocks , "Number of blocks marked unreachable"); +STATISTIC(NumSnuggle , "Number of comparisons snuggled"); + +namespace { + class DomTreeDFS { + public: + class Node { + friend class DomTreeDFS; + public: + typedef std::vector::iterator iterator; + typedef std::vector::const_iterator const_iterator; + + unsigned getDFSNumIn() const { return DFSin; } + unsigned getDFSNumOut() const { return DFSout; } + + BasicBlock *getBlock() const { return BB; } + + iterator begin() { return Children.begin(); } + iterator end() { return Children.end(); } + + const_iterator begin() const { return Children.begin(); } + const_iterator end() const { return Children.end(); } + + bool dominates(const Node *N) const { + return DFSin <= N->DFSin && DFSout >= 
N->DFSout; + } + + bool DominatedBy(const Node *N) const { + return N->dominates(this); + } + + /// Sorts by the number of descendants. With this, you can iterate + /// through a sorted list and the first matching entry is the most + /// specific match for your basic block. The order provided is stable; + /// DomTreeDFS::Nodes with the same number of descendants are sorted by + /// DFS in number. + bool operator<(const Node &N) const { + unsigned spread = DFSout - DFSin; + unsigned N_spread = N.DFSout - N.DFSin; + if (spread == N_spread) return DFSin < N.DFSin; + return spread < N_spread; + } + bool operator>(const Node &N) const { return N < *this; } + + private: + unsigned DFSin, DFSout; + BasicBlock *BB; + + std::vector Children; + }; + + // XXX: this may be slow. Instead of using "new" for each node, consider + // putting them in a vector to keep them contiguous. + explicit DomTreeDFS(DominatorTree *DT) { + std::stack > S; + + Entry = new Node; + Entry->BB = DT->getRootNode()->getBlock(); + S.push(std::make_pair(Entry, DT->getRootNode())); + + NodeMap[Entry->BB] = Entry; + + while (!S.empty()) { + std::pair &Pair = S.top(); + Node *N = Pair.first; + DomTreeNode *DTNode = Pair.second; + S.pop(); + + for (DomTreeNode::iterator I = DTNode->begin(), E = DTNode->end(); + I != E; ++I) { + Node *NewNode = new Node; + NewNode->BB = (*I)->getBlock(); + N->Children.push_back(NewNode); + S.push(std::make_pair(NewNode, *I)); + + NodeMap[NewNode->BB] = NewNode; + } + } + + renumber(); + +#ifndef NDEBUG + DEBUG(dump()); +#endif + } + +#ifndef NDEBUG + virtual +#endif + ~DomTreeDFS() { + std::stack S; + + S.push(Entry); + while (!S.empty()) { + Node *N = S.top(); S.pop(); + + for (Node::iterator I = N->begin(), E = N->end(); I != E; ++I) + S.push(*I); + + delete N; + } + } + + /// getRootNode - This returns the entry node for the CFG of the function. + Node *getRootNode() const { return Entry; } + + /// getNodeForBlock - return the node for the specified basic block. 
+ Node *getNodeForBlock(BasicBlock *BB) const { + if (!NodeMap.count(BB)) return 0; + return const_cast(this)->NodeMap[BB]; + } + + /// dominates - returns true if the basic block for I1 dominates that of + /// the basic block for I2. If the instructions belong to the same basic + /// block, the instruction first instruction sequentially in the block is + /// considered dominating. + bool dominates(Instruction *I1, Instruction *I2) { + BasicBlock *BB1 = I1->getParent(), + *BB2 = I2->getParent(); + if (BB1 == BB2) { + if (isa(I1)) return false; + if (isa(I2)) return true; + if ( isa(I1) && !isa(I2)) return true; + if (!isa(I1) && isa(I2)) return false; + + for (BasicBlock::const_iterator I = BB2->begin(), E = BB2->end(); + I != E; ++I) { + if (&*I == I1) return true; + else if (&*I == I2) return false; + } + assert(!"Instructions not found in parent BasicBlock?"); + } else { + Node *Node1 = getNodeForBlock(BB1), + *Node2 = getNodeForBlock(BB2); + return Node1 && Node2 && Node1->dominates(Node2); + } + } + + private: + /// renumber - calculates the depth first search numberings and applies + /// them onto the nodes. 
+ void renumber() { + std::stack > S; + unsigned n = 0; + + Entry->DFSin = ++n; + S.push(std::make_pair(Entry, Entry->begin())); + + while (!S.empty()) { + std::pair &Pair = S.top(); + Node *N = Pair.first; + Node::iterator &I = Pair.second; + + if (I == N->end()) { + N->DFSout = ++n; + S.pop(); + } else { + Node *Next = *I++; + Next->DFSin = ++n; + S.push(std::make_pair(Next, Next->begin())); + } + } + } + +#ifndef NDEBUG + virtual void dump() const { + dump(*cerr.stream()); + } + + void dump(std::ostream &os) const { + os << "Predicate simplifier DomTreeDFS: \n"; + dump(Entry, 0, os); + os << "\n\n"; + } + + void dump(Node *N, int depth, std::ostream &os) const { + ++depth; + for (int i = 0; i < depth; ++i) { os << " "; } + os << "[" << depth << "] "; + + os << N->getBlock()->getName() << " (" << N->getDFSNumIn() + << ", " << N->getDFSNumOut() << ")\n"; + + for (Node::iterator I = N->begin(), E = N->end(); I != E; ++I) + dump(*I, depth, os); + } +#endif + + Node *Entry; + std::map NodeMap; + }; + + // SLT SGT ULT UGT EQ + // 0 1 0 1 0 -- GT 10 + // 0 1 0 1 1 -- GE 11 + // 0 1 1 0 0 -- SGTULT 12 + // 0 1 1 0 1 -- SGEULE 13 + // 0 1 1 1 0 -- SGT 14 + // 0 1 1 1 1 -- SGE 15 + // 1 0 0 1 0 -- SLTUGT 18 + // 1 0 0 1 1 -- SLEUGE 19 + // 1 0 1 0 0 -- LT 20 + // 1 0 1 0 1 -- LE 21 + // 1 0 1 1 0 -- SLT 22 + // 1 0 1 1 1 -- SLE 23 + // 1 1 0 1 0 -- UGT 26 + // 1 1 0 1 1 -- UGE 27 + // 1 1 1 0 0 -- ULT 28 + // 1 1 1 0 1 -- ULE 29 + // 1 1 1 1 0 -- NE 30 + enum LatticeBits { + EQ_BIT = 1, UGT_BIT = 2, ULT_BIT = 4, SGT_BIT = 8, SLT_BIT = 16 + }; + enum LatticeVal { + GT = SGT_BIT | UGT_BIT, + GE = GT | EQ_BIT, + LT = SLT_BIT | ULT_BIT, + LE = LT | EQ_BIT, + NE = SLT_BIT | SGT_BIT | ULT_BIT | UGT_BIT, + SGTULT = SGT_BIT | ULT_BIT, + SGEULE = SGTULT | EQ_BIT, + SLTUGT = SLT_BIT | UGT_BIT, + SLEUGE = SLTUGT | EQ_BIT, + ULT = SLT_BIT | SGT_BIT | ULT_BIT, + UGT = SLT_BIT | SGT_BIT | UGT_BIT, + SLT = SLT_BIT | ULT_BIT | UGT_BIT, + SGT = SGT_BIT | ULT_BIT | UGT_BIT, + SLE = SLT | 
EQ_BIT, + SGE = SGT | EQ_BIT, + ULE = ULT | EQ_BIT, + UGE = UGT | EQ_BIT + }; + + static bool validPredicate(LatticeVal LV) { + switch (LV) { + case GT: case GE: case LT: case LE: case NE: + case SGTULT: case SGT: case SGEULE: + case SLTUGT: case SLT: case SLEUGE: + case ULT: case UGT: + case SLE: case SGE: case ULE: case UGE: + return true; + default: + return false; + } + } + + /// reversePredicate - reverse the direction of the inequality + static LatticeVal reversePredicate(LatticeVal LV) { + unsigned reverse = LV ^ (SLT_BIT|SGT_BIT|ULT_BIT|UGT_BIT); //preserve EQ_BIT + + if ((reverse & (SLT_BIT|SGT_BIT)) == 0) + reverse |= (SLT_BIT|SGT_BIT); + + if ((reverse & (ULT_BIT|UGT_BIT)) == 0) + reverse |= (ULT_BIT|UGT_BIT); + + LatticeVal Rev = static_cast(reverse); + assert(validPredicate(Rev) && "Failed reversing predicate."); + return Rev; + } + + /// ValueNumbering stores the scope-specific value numbers for a given Value. + class VISIBILITY_HIDDEN ValueNumbering { + class VISIBILITY_HIDDEN VNPair { + public: + Value *V; + unsigned index; + DomTreeDFS::Node *Subtree; + + VNPair(Value *V, unsigned index, DomTreeDFS::Node *Subtree) + : V(V), index(index), Subtree(Subtree) {} + + bool operator==(const VNPair &RHS) const { + return V == RHS.V && Subtree == RHS.Subtree; + } + + bool operator<(const VNPair &RHS) const { + if (V != RHS.V) return V < RHS.V; + return *Subtree < *RHS.Subtree; + } + + bool operator<(Value *RHS) const { + return V < RHS; + } + }; + + typedef std::vector VNMapType; + VNMapType VNMap; + + std::vector Values; + + DomTreeDFS *DTDFS; + + public: +#ifndef NDEBUG + virtual ~ValueNumbering() {} + virtual void dump() { + dump(*cerr.stream()); + } + + void dump(std::ostream &os) { + for (unsigned i = 1; i <= Values.size(); ++i) { + os << i << " = "; + WriteAsOperand(os, Values[i-1]); + os << " {"; + for (unsigned j = 0; j < VNMap.size(); ++j) { + if (VNMap[j].index == i) { + WriteAsOperand(os, VNMap[j].V); + os << " (" << 
VNMap[j].Subtree->getDFSNumIn() << ") "; + } + } + os << "}\n"; + } + } +#endif + + /// compare - returns true if V1 is a better canonical value than V2. + bool compare(Value *V1, Value *V2) const { + if (isa(V1)) + return !isa(V2); + else if (isa(V2)) + return false; + else if (isa(V1)) + return !isa(V2); + else if (isa(V2)) + return false; + + Instruction *I1 = dyn_cast(V1); + Instruction *I2 = dyn_cast(V2); + + if (!I1 || !I2) + return V1->getNumUses() < V2->getNumUses(); + + return DTDFS->dominates(I1, I2); + } + + ValueNumbering(DomTreeDFS *DTDFS) : DTDFS(DTDFS) {} + + /// valueNumber - finds the value number for V under the Subtree. If + /// there is no value number, returns zero. + unsigned valueNumber(Value *V, DomTreeDFS::Node *Subtree) { + if (!(isa(V) || isa(V) || isa(V)) + || V->getType() == Type::VoidTy) return 0; + + VNMapType::iterator E = VNMap.end(); + VNPair pair(V, 0, Subtree); + VNMapType::iterator I = std::lower_bound(VNMap.begin(), E, pair); + while (I != E && I->V == V) { + if (I->Subtree->dominates(Subtree)) + return I->index; + ++I; + } + return 0; + } + + /// getOrInsertVN - always returns a value number, creating it if necessary. + unsigned getOrInsertVN(Value *V, DomTreeDFS::Node *Subtree) { + if (unsigned n = valueNumber(V, Subtree)) + return n; + else + return newVN(V); + } + + /// newVN - creates a new value number. Value V must not already have a + /// value number assigned. 
+ unsigned newVN(Value *V) { + assert((isa(V) || isa(V) || isa(V)) && + "Bad Value for value numbering."); + assert(V->getType() != Type::VoidTy && "Won't value number a void value"); + + Values.push_back(V); + + VNPair pair = VNPair(V, Values.size(), DTDFS->getRootNode()); + VNMapType::iterator I = std::lower_bound(VNMap.begin(), VNMap.end(), pair); + assert((I == VNMap.end() || value(I->index) != V) && + "Attempt to create a duplicate value number."); + VNMap.insert(I, pair); + + return Values.size(); + } + + /// value - returns the Value associated with a value number. + Value *value(unsigned index) const { + assert(index != 0 && "Zero index is reserved for not found."); + assert(index <= Values.size() && "Index out of range."); + return Values[index-1]; + } + + /// canonicalize - return a Value that is equal to V under Subtree. + Value *canonicalize(Value *V, DomTreeDFS::Node *Subtree) { + if (isa(V)) return V; + + if (unsigned n = valueNumber(V, Subtree)) + return value(n); + else + return V; + } + + /// addEquality - adds that value V belongs to the set of equivalent + /// values defined by value number n under Subtree. + void addEquality(unsigned n, Value *V, DomTreeDFS::Node *Subtree) { + assert(canonicalize(value(n), Subtree) == value(n) && + "Node's 'canonical' choice isn't best within this subtree."); + + // Suppose that we are given "%x -> node #1 (%y)". The problem is that + // we may already have "%z -> node #2 (%x)" somewhere above us in the + // graph. We need to find those edges and add "%z -> node #1 (%y)" + // to keep the lookups canonical. 
+ + std::vector ToRepoint(1, V); + + if (unsigned Conflict = valueNumber(V, Subtree)) { + for (VNMapType::iterator I = VNMap.begin(), E = VNMap.end(); + I != E; ++I) { + if (I->index == Conflict && I->Subtree->dominates(Subtree)) + ToRepoint.push_back(I->V); + } + } + + for (std::vector::iterator VI = ToRepoint.begin(), + VE = ToRepoint.end(); VI != VE; ++VI) { + Value *V = *VI; + + VNPair pair(V, n, Subtree); + VNMapType::iterator B = VNMap.begin(), E = VNMap.end(); + VNMapType::iterator I = std::lower_bound(B, E, pair); + if (I != E && I->V == V && I->Subtree == Subtree) + I->index = n; // Update best choice + else + VNMap.insert(I, pair); // New Value + + // XXX: we currently don't have to worry about updating values with + // more specific Subtrees, but we will need to for PHI node support. + +#ifndef NDEBUG + Value *V_n = value(n); + if (isa(V) && isa(V_n)) { + assert(V == V_n && "Constant equals different constant?"); + } +#endif + } + } + + /// remove - removes all references to value V. + void remove(Value *V) { + VNMapType::iterator B = VNMap.begin(), E = VNMap.end(); + VNPair pair(V, 0, DTDFS->getRootNode()); + VNMapType::iterator J = std::upper_bound(B, E, pair); + VNMapType::iterator I = J; + + while (I != B && (I == E || I->V == V)) --I; + + VNMap.erase(I, J); + } + }; + + /// The InequalityGraph stores the relationships between values. + /// Each Value in the graph is assigned to a Node. Nodes are pointer + /// comparable for equality. The caller is expected to maintain the logical + /// consistency of the system. + /// + /// The InequalityGraph class may invalidate Node*s after any mutator call. + /// @brief The InequalityGraph stores the relationships between values. 
+ class VISIBILITY_HIDDEN InequalityGraph { + ValueNumbering &VN; + DomTreeDFS::Node *TreeRoot; + + InequalityGraph(); // DO NOT IMPLEMENT + InequalityGraph(InequalityGraph &); // DO NOT IMPLEMENT + public: + InequalityGraph(ValueNumbering &VN, DomTreeDFS::Node *TreeRoot) + : VN(VN), TreeRoot(TreeRoot) {} + + class Node; + + /// An Edge is contained inside a Node making one end of the edge implicit + /// and contains a pointer to the other end. The edge contains a lattice + /// value specifying the relationship and an DomTreeDFS::Node specifying + /// the root in the dominator tree to which this edge applies. + class VISIBILITY_HIDDEN Edge { + public: + Edge(unsigned T, LatticeVal V, DomTreeDFS::Node *ST) + : To(T), LV(V), Subtree(ST) {} + + unsigned To; + LatticeVal LV; + DomTreeDFS::Node *Subtree; + + bool operator<(const Edge &edge) const { + if (To != edge.To) return To < edge.To; + return *Subtree < *edge.Subtree; + } + + bool operator<(unsigned to) const { + return To < to; + } + + bool operator>(unsigned to) const { + return To > to; + } + + friend bool operator<(unsigned to, const Edge &edge) { + return edge.operator>(to); + } + }; + + /// A single node in the InequalityGraph. This stores the canonical Value + /// for the node, as well as the relationships with the neighbours. + /// + /// @brief A single node in the InequalityGraph. + class VISIBILITY_HIDDEN Node { + friend class InequalityGraph; + + typedef SmallVector RelationsType; + RelationsType Relations; + + // TODO: can this idea improve performance? 
+ //friend class std::vector; + //Node(Node &N) { RelationsType.swap(N.RelationsType); } + + public: + typedef RelationsType::iterator iterator; + typedef RelationsType::const_iterator const_iterator; + +#ifndef NDEBUG + virtual ~Node() {} + virtual void dump() const { + dump(*cerr.stream()); + } + private: + void dump(std::ostream &os) const { + static const std::string names[32] = + { "000000", "000001", "000002", "000003", "000004", "000005", + "000006", "000007", "000008", "000009", " >", " >=", + " s>u<", "s>=u<=", " s>", " s>=", "000016", "000017", + " s", "s<=u>=", " <", " <=", " s<", " s<=", + "000024", "000025", " u>", " u>=", " u<", " u<=", + " !=", "000031" }; + for (Node::const_iterator NI = begin(), NE = end(); NI != NE; ++NI) { + os << names[NI->LV] << " " << NI->To + << " (" << NI->Subtree->getDFSNumIn() << "), "; + } + } + public: +#endif + + iterator begin() { return Relations.begin(); } + iterator end() { return Relations.end(); } + const_iterator begin() const { return Relations.begin(); } + const_iterator end() const { return Relations.end(); } + + iterator find(unsigned n, DomTreeDFS::Node *Subtree) { + iterator E = end(); + for (iterator I = std::lower_bound(begin(), E, n); + I != E && I->To == n; ++I) { + if (Subtree->DominatedBy(I->Subtree)) + return I; + } + return E; + } + + const_iterator find(unsigned n, DomTreeDFS::Node *Subtree) const { + const_iterator E = end(); + for (const_iterator I = std::lower_bound(begin(), E, n); + I != E && I->To == n; ++I) { + if (Subtree->DominatedBy(I->Subtree)) + return I; + } + return E; + } + + /// Updates the lattice value for a given node. Create a new entry if + /// one doesn't exist, otherwise it merges the values. The new lattice + /// value must not be inconsistent with any previously existing value. 
+ void update(unsigned n, LatticeVal R, DomTreeDFS::Node *Subtree) { + assert(validPredicate(R) && "Invalid predicate."); + iterator I = find(n, Subtree); + if (I == end()) { + Edge edge(n, R, Subtree); + iterator Insert = std::lower_bound(begin(), end(), edge); + Relations.insert(Insert, edge); + } else { + LatticeVal LV = static_cast(I->LV & R); + assert(validPredicate(LV) && "Invalid union of lattice values."); + if (LV != I->LV) { + if (Subtree != I->Subtree) { + assert(Subtree->DominatedBy(I->Subtree) && + "Find returned subtree that doesn't apply."); + + Edge edge(n, R, Subtree); + iterator Insert = std::lower_bound(begin(), end(), edge); + Relations.insert(Insert, edge); // invalidates I + I = find(n, Subtree); + } + + // Also, we have to tighten any edge that Subtree dominates. + for (iterator B = begin(); I->To == n; --I) { + if (I->Subtree->DominatedBy(Subtree)) { + LatticeVal LV = static_cast(I->LV & R); + assert(validPredicate(LV) && "Invalid union of lattice values"); + I->LV = LV; + } + if (I == B) break; + } + } + } + } + }; + + private: + + std::vector Nodes; + + public: + /// node - returns the node object at a given value number. The pointer + /// returned may be invalidated on the next call to node(). + Node *node(unsigned index) { + assert(VN.value(index)); // This triggers the necessary checks. + if (Nodes.size() < index) Nodes.resize(index); + return &Nodes[index-1]; + } + + /// isRelatedBy - true iff n1 op n2 + bool isRelatedBy(unsigned n1, unsigned n2, DomTreeDFS::Node *Subtree, + LatticeVal LV) { + if (n1 == n2) return LV & EQ_BIT; + + Node *N1 = node(n1); + Node::iterator I = N1->find(n2, Subtree), E = N1->end(); + if (I != E) return (I->LV & LV) == I->LV; + + return false; + } + + // The add* methods assume that your input is logically valid and may + // assertion-fail or infinitely loop if you attempt a contradiction. + + /// addInequality - Sets n1 op n2. + /// It is also an error to call this on an inequality that is already true. 
+ void addInequality(unsigned n1, unsigned n2, DomTreeDFS::Node *Subtree, + LatticeVal LV1) { + assert(n1 != n2 && "A node can't be inequal to itself."); + + if (LV1 != NE) + assert(!isRelatedBy(n1, n2, Subtree, reversePredicate(LV1)) && + "Contradictory inequality."); + + // Suppose we're adding %n1 < %n2. Find all the %a < %n1 and + // add %a < %n2 too. This keeps the graph fully connected. + if (LV1 != NE) { + // Break up the relationship into signed and unsigned comparison parts. + // If the signed parts of %a op1 %n1 match that of %n1 op2 %n2, and + // op1 and op2 aren't NE, then add %a op3 %n2. The new relationship + // should have the EQ_BIT iff it's set for both op1 and op2. + + unsigned LV1_s = LV1 & (SLT_BIT|SGT_BIT); + unsigned LV1_u = LV1 & (ULT_BIT|UGT_BIT); + + for (Node::iterator I = node(n1)->begin(), E = node(n1)->end(); I != E; ++I) { + if (I->LV != NE && I->To != n2) { + + DomTreeDFS::Node *Local_Subtree = NULL; + if (Subtree->DominatedBy(I->Subtree)) + Local_Subtree = Subtree; + else if (I->Subtree->DominatedBy(Subtree)) + Local_Subtree = I->Subtree; + + if (Local_Subtree) { + unsigned new_relationship = 0; + LatticeVal ILV = reversePredicate(I->LV); + unsigned ILV_s = ILV & (SLT_BIT|SGT_BIT); + unsigned ILV_u = ILV & (ULT_BIT|UGT_BIT); + + if (LV1_s != (SLT_BIT|SGT_BIT) && ILV_s == LV1_s) + new_relationship |= ILV_s; + if (LV1_u != (ULT_BIT|UGT_BIT) && ILV_u == LV1_u) + new_relationship |= ILV_u; + + if (new_relationship) { + if ((new_relationship & (SLT_BIT|SGT_BIT)) == 0) + new_relationship |= (SLT_BIT|SGT_BIT); + if ((new_relationship & (ULT_BIT|UGT_BIT)) == 0) + new_relationship |= (ULT_BIT|UGT_BIT); + if ((LV1 & EQ_BIT) && (ILV & EQ_BIT)) + new_relationship |= EQ_BIT; + + LatticeVal NewLV = static_cast(new_relationship); + + node(I->To)->update(n2, NewLV, Local_Subtree); + node(n2)->update(I->To, reversePredicate(NewLV), Local_Subtree); + } + } + } + } + + for (Node::iterator I = node(n2)->begin(), E = node(n2)->end(); I != E; ++I) { + if 
(I->LV != NE && I->To != n1) { + DomTreeDFS::Node *Local_Subtree = NULL; + if (Subtree->DominatedBy(I->Subtree)) + Local_Subtree = Subtree; + else if (I->Subtree->DominatedBy(Subtree)) + Local_Subtree = I->Subtree; + + if (Local_Subtree) { + unsigned new_relationship = 0; + unsigned ILV_s = I->LV & (SLT_BIT|SGT_BIT); + unsigned ILV_u = I->LV & (ULT_BIT|UGT_BIT); + + if (LV1_s != (SLT_BIT|SGT_BIT) && ILV_s == LV1_s) + new_relationship |= ILV_s; + + if (LV1_u != (ULT_BIT|UGT_BIT) && ILV_u == LV1_u) + new_relationship |= ILV_u; + + if (new_relationship) { + if ((new_relationship & (SLT_BIT|SGT_BIT)) == 0) + new_relationship |= (SLT_BIT|SGT_BIT); + if ((new_relationship & (ULT_BIT|UGT_BIT)) == 0) + new_relationship |= (ULT_BIT|UGT_BIT); + if ((LV1 & EQ_BIT) && (I->LV & EQ_BIT)) + new_relationship |= EQ_BIT; + + LatticeVal NewLV = static_cast(new_relationship); + + node(n1)->update(I->To, NewLV, Local_Subtree); + node(I->To)->update(n1, reversePredicate(NewLV), Local_Subtree); + } + } + } + } + } + + node(n1)->update(n2, LV1, Subtree); + node(n2)->update(n1, reversePredicate(LV1), Subtree); + } + + /// remove - removes a node from the graph by removing all references to + /// and from it. + void remove(unsigned n) { + Node *N = node(n); + for (Node::iterator NI = N->begin(), NE = N->end(); NI != NE; ++NI) { + Node::iterator Iter = node(NI->To)->find(n, TreeRoot); + do { + node(NI->To)->Relations.erase(Iter); + Iter = node(NI->To)->find(n, TreeRoot); + } while (Iter != node(NI->To)->end()); + } + N->Relations.clear(); + } + +#ifndef NDEBUG + virtual ~InequalityGraph() {} + virtual void dump() { + dump(*cerr.stream()); + } + + void dump(std::ostream &os) { + for (unsigned i = 1; i <= Nodes.size(); ++i) { + os << i << " = {"; + node(i)->dump(os); + os << "}\n"; + } + } +#endif + }; + + class VRPSolver; + + /// ValueRanges tracks the known integer ranges and anti-ranges of the nodes + /// in the InequalityGraph. 
+ class VISIBILITY_HIDDEN ValueRanges { + ValueNumbering &VN; + TargetData *TD; + + class VISIBILITY_HIDDEN ScopedRange { + typedef std::vector > + RangeListType; + RangeListType RangeList; + + static bool swo(const std::pair &LHS, + const std::pair &RHS) { + return *LHS.first < *RHS.first; + } + + public: +#ifndef NDEBUG + virtual ~ScopedRange() {} + virtual void dump() const { + dump(*cerr.stream()); + } + + void dump(std::ostream &os) const { + os << "{"; + for (const_iterator I = begin(), E = end(); I != E; ++I) { + os << I->second << " (" << I->first->getDFSNumIn() << "), "; + } + os << "}"; + } +#endif + + typedef RangeListType::iterator iterator; + typedef RangeListType::const_iterator const_iterator; + + iterator begin() { return RangeList.begin(); } + iterator end() { return RangeList.end(); } + const_iterator begin() const { return RangeList.begin(); } + const_iterator end() const { return RangeList.end(); } + + iterator find(DomTreeDFS::Node *Subtree) { + static ConstantRange empty(1, false); + iterator E = end(); + iterator I = std::lower_bound(begin(), E, + std::make_pair(Subtree, empty), swo); + + while (I != E && !I->first->dominates(Subtree)) ++I; + return I; + } + + const_iterator find(DomTreeDFS::Node *Subtree) const { + static const ConstantRange empty(1, false); + const_iterator E = end(); + const_iterator I = std::lower_bound(begin(), E, + std::make_pair(Subtree, empty), swo); + + while (I != E && !I->first->dominates(Subtree)) ++I; + return I; + } + + void update(const ConstantRange &CR, DomTreeDFS::Node *Subtree) { + assert(!CR.isEmptySet() && "Empty ConstantRange."); + assert(!CR.isSingleElement() && "Won't store single element."); + + static ConstantRange empty(1, false); + iterator E = end(); + iterator I = + std::lower_bound(begin(), E, std::make_pair(Subtree, empty), swo); + + if (I != end() && I->first == Subtree) { + ConstantRange CR2 = I->second.maximalIntersectWith(CR); + assert(!CR2.isEmptySet() && !CR2.isSingleElement() && + 
"Invalid union of ranges."); + I->second = CR2; + } else + RangeList.insert(I, std::make_pair(Subtree, CR)); + } + }; + + std::vector Ranges; + + void update(unsigned n, const ConstantRange &CR, DomTreeDFS::Node *Subtree){ + if (CR.isFullSet()) return; + if (Ranges.size() < n) Ranges.resize(n); + Ranges[n-1].update(CR, Subtree); + } + + /// create - Creates a ConstantRange that matches the given LatticeVal + /// relation with a given integer. + ConstantRange create(LatticeVal LV, const ConstantRange &CR) { + assert(!CR.isEmptySet() && "Can't deal with empty set."); + + if (LV == NE) + return makeConstantRange(ICmpInst::ICMP_NE, CR); + + unsigned LV_s = LV & (SGT_BIT|SLT_BIT); + unsigned LV_u = LV & (UGT_BIT|ULT_BIT); + bool hasEQ = LV & EQ_BIT; + + ConstantRange Range(CR.getBitWidth()); + + if (LV_s == SGT_BIT) { + Range = Range.maximalIntersectWith(makeConstantRange( + hasEQ ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_SGT, CR)); + } else if (LV_s == SLT_BIT) { + Range = Range.maximalIntersectWith(makeConstantRange( + hasEQ ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_SLT, CR)); + } + + if (LV_u == UGT_BIT) { + Range = Range.maximalIntersectWith(makeConstantRange( + hasEQ ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_UGT, CR)); + } else if (LV_u == ULT_BIT) { + Range = Range.maximalIntersectWith(makeConstantRange( + hasEQ ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT, CR)); + } + + return Range; + } + + /// makeConstantRange - Creates a ConstantRange representing the set of all + /// value that match the ICmpInst::Predicate with any of the values in CR. 
+ ConstantRange makeConstantRange(ICmpInst::Predicate ICmpOpcode, + const ConstantRange &CR) { + uint32_t W = CR.getBitWidth(); + switch (ICmpOpcode) { + default: assert(!"Invalid ICmp opcode to makeConstantRange()"); + case ICmpInst::ICMP_EQ: + return ConstantRange(CR.getLower(), CR.getUpper()); + case ICmpInst::ICMP_NE: + if (CR.isSingleElement()) + return ConstantRange(CR.getUpper(), CR.getLower()); + return ConstantRange(W); + case ICmpInst::ICMP_ULT: + return ConstantRange(APInt::getMinValue(W), CR.getUnsignedMax()); + case ICmpInst::ICMP_SLT: + return ConstantRange(APInt::getSignedMinValue(W), CR.getSignedMax()); + case ICmpInst::ICMP_ULE: { + APInt UMax(CR.getUnsignedMax()); + if (UMax.isMaxValue()) + return ConstantRange(W); + return ConstantRange(APInt::getMinValue(W), UMax + 1); + } + case ICmpInst::ICMP_SLE: { + APInt SMax(CR.getSignedMax()); + if (SMax.isMaxSignedValue() || (SMax+1).isMaxSignedValue()) + return ConstantRange(W); + return ConstantRange(APInt::getSignedMinValue(W), SMax + 1); + } + case ICmpInst::ICMP_UGT: + return ConstantRange(CR.getUnsignedMin() + 1, APInt::getNullValue(W)); + case ICmpInst::ICMP_SGT: + return ConstantRange(CR.getSignedMin() + 1, + APInt::getSignedMinValue(W)); + case ICmpInst::ICMP_UGE: { + APInt UMin(CR.getUnsignedMin()); + if (UMin.isMinValue()) + return ConstantRange(W); + return ConstantRange(UMin, APInt::getNullValue(W)); + } + case ICmpInst::ICMP_SGE: { + APInt SMin(CR.getSignedMin()); + if (SMin.isMinSignedValue()) + return ConstantRange(W); + return ConstantRange(SMin, APInt::getSignedMinValue(W)); + } + } + } + +#ifndef NDEBUG + bool isCanonical(Value *V, DomTreeDFS::Node *Subtree) { + return V == VN.canonicalize(V, Subtree); + } +#endif + + public: + + ValueRanges(ValueNumbering &VN, TargetData *TD) : VN(VN), TD(TD) {} + +#ifndef NDEBUG + virtual ~ValueRanges() {} + + virtual void dump() const { + dump(*cerr.stream()); + } + + void dump(std::ostream &os) const { + for (unsigned i = 0, e = Ranges.size(); i != 
e; ++i) { + os << (i+1) << " = "; + Ranges[i].dump(os); + os << "\n"; + } + } +#endif + + /// range - looks up the ConstantRange associated with a value number. + ConstantRange range(unsigned n, DomTreeDFS::Node *Subtree) { + assert(VN.value(n)); // performs range checks + + if (n <= Ranges.size()) { + ScopedRange::iterator I = Ranges[n-1].find(Subtree); + if (I != Ranges[n-1].end()) return I->second; + } + + Value *V = VN.value(n); + ConstantRange CR = range(V); + return CR; + } + + /// range - determine a range from a Value without performing any lookups. + ConstantRange range(Value *V) const { + if (ConstantInt *C = dyn_cast(V)) + return ConstantRange(C->getValue()); + else if (isa(V)) + return ConstantRange(APInt::getNullValue(typeToWidth(V->getType()))); + else + return typeToWidth(V->getType()); + } + + // typeToWidth - returns the number of bits necessary to store a value of + // this type, or zero if unknown. + uint32_t typeToWidth(const Type *Ty) const { + if (TD) + return TD->getTypeSizeInBits(Ty); + + if (const IntegerType *ITy = dyn_cast(Ty)) + return ITy->getBitWidth(); + + return 0; + } + + static bool isRelatedBy(const ConstantRange &CR1, const ConstantRange &CR2, + LatticeVal LV) { + switch (LV) { + default: assert(!"Impossible lattice value!"); + case NE: + return CR1.maximalIntersectWith(CR2).isEmptySet(); + case ULT: + return CR1.getUnsignedMax().ult(CR2.getUnsignedMin()); + case ULE: + return CR1.getUnsignedMax().ule(CR2.getUnsignedMin()); + case UGT: + return CR1.getUnsignedMin().ugt(CR2.getUnsignedMax()); + case UGE: + return CR1.getUnsignedMin().uge(CR2.getUnsignedMax()); + case SLT: + return CR1.getSignedMax().slt(CR2.getSignedMin()); + case SLE: + return CR1.getSignedMax().sle(CR2.getSignedMin()); + case SGT: + return CR1.getSignedMin().sgt(CR2.getSignedMax()); + case SGE: + return CR1.getSignedMin().sge(CR2.getSignedMax()); + case LT: + return CR1.getUnsignedMax().ult(CR2.getUnsignedMin()) && + CR1.getSignedMax().slt(CR2.getUnsignedMin()); 
+ case LE: + return CR1.getUnsignedMax().ule(CR2.getUnsignedMin()) && + CR1.getSignedMax().sle(CR2.getUnsignedMin()); + case GT: + return CR1.getUnsignedMin().ugt(CR2.getUnsignedMax()) && + CR1.getSignedMin().sgt(CR2.getSignedMax()); + case GE: + return CR1.getUnsignedMin().uge(CR2.getUnsignedMax()) && + CR1.getSignedMin().sge(CR2.getSignedMax()); + case SLTUGT: + return CR1.getSignedMax().slt(CR2.getSignedMin()) && + CR1.getUnsignedMin().ugt(CR2.getUnsignedMax()); + case SLEUGE: + return CR1.getSignedMax().sle(CR2.getSignedMin()) && + CR1.getUnsignedMin().uge(CR2.getUnsignedMax()); + case SGTULT: + return CR1.getSignedMin().sgt(CR2.getSignedMax()) && + CR1.getUnsignedMax().ult(CR2.getUnsignedMin()); + case SGEULE: + return CR1.getSignedMin().sge(CR2.getSignedMax()) && + CR1.getUnsignedMax().ule(CR2.getUnsignedMin()); + } + } + + bool isRelatedBy(unsigned n1, unsigned n2, DomTreeDFS::Node *Subtree, + LatticeVal LV) { + ConstantRange CR1 = range(n1, Subtree); + ConstantRange CR2 = range(n2, Subtree); + + // True iff all values in CR1 are LV to all values in CR2. + return isRelatedBy(CR1, CR2, LV); + } + + void addToWorklist(Value *V, Constant *C, ICmpInst::Predicate Pred, + VRPSolver *VRP); + void markBlock(VRPSolver *VRP); + + void mergeInto(Value **I, unsigned n, unsigned New, + DomTreeDFS::Node *Subtree, VRPSolver *VRP) { + ConstantRange CR_New = range(New, Subtree); + ConstantRange Merged = CR_New; + + for (; n != 0; ++I, --n) { + unsigned i = VN.valueNumber(*I, Subtree); + ConstantRange CR_Kill = i ? 
range(i, Subtree) : range(*I); + if (CR_Kill.isFullSet()) continue; + Merged = Merged.maximalIntersectWith(CR_Kill); + } + + if (Merged.isFullSet() || Merged == CR_New) return; + + applyRange(New, Merged, Subtree, VRP); + } + + void applyRange(unsigned n, const ConstantRange &CR, + DomTreeDFS::Node *Subtree, VRPSolver *VRP) { + ConstantRange Merged = CR.maximalIntersectWith(range(n, Subtree)); + if (Merged.isEmptySet()) { + markBlock(VRP); + return; + } + + if (const APInt *I = Merged.getSingleElement()) { + Value *V = VN.value(n); // XXX: redesign worklist. + const Type *Ty = V->getType(); + if (Ty->isInteger()) { + addToWorklist(V, ConstantInt::get(*I), ICmpInst::ICMP_EQ, VRP); + return; + } else if (const PointerType *PTy = dyn_cast(Ty)) { + assert(*I == 0 && "Pointer is null but not zero?"); + addToWorklist(V, ConstantPointerNull::get(PTy), + ICmpInst::ICMP_EQ, VRP); + return; + } + } + + update(n, Merged, Subtree); + } + + void addNotEquals(unsigned n1, unsigned n2, DomTreeDFS::Node *Subtree, + VRPSolver *VRP) { + ConstantRange CR1 = range(n1, Subtree); + ConstantRange CR2 = range(n2, Subtree); + + uint32_t W = CR1.getBitWidth(); + + if (const APInt *I = CR1.getSingleElement()) { + if (CR2.isFullSet()) { + ConstantRange NewCR2(CR1.getUpper(), CR1.getLower()); + applyRange(n2, NewCR2, Subtree, VRP); + } else if (*I == CR2.getLower()) { + APInt NewLower(CR2.getLower() + 1), + NewUpper(CR2.getUpper()); + if (NewLower == NewUpper) + NewLower = NewUpper = APInt::getMinValue(W); + + ConstantRange NewCR2(NewLower, NewUpper); + applyRange(n2, NewCR2, Subtree, VRP); + } else if (*I == CR2.getUpper() - 1) { + APInt NewLower(CR2.getLower()), + NewUpper(CR2.getUpper() - 1); + if (NewLower == NewUpper) + NewLower = NewUpper = APInt::getMinValue(W); + + ConstantRange NewCR2(NewLower, NewUpper); + applyRange(n2, NewCR2, Subtree, VRP); + } + } + + if (const APInt *I = CR2.getSingleElement()) { + if (CR1.isFullSet()) { + ConstantRange NewCR1(CR2.getUpper(), CR2.getLower()); + 
applyRange(n1, NewCR1, Subtree, VRP); + } else if (*I == CR1.getLower()) { + APInt NewLower(CR1.getLower() + 1), + NewUpper(CR1.getUpper()); + if (NewLower == NewUpper) + NewLower = NewUpper = APInt::getMinValue(W); + + ConstantRange NewCR1(NewLower, NewUpper); + applyRange(n1, NewCR1, Subtree, VRP); + } else if (*I == CR1.getUpper() - 1) { + APInt NewLower(CR1.getLower()), + NewUpper(CR1.getUpper() - 1); + if (NewLower == NewUpper) + NewLower = NewUpper = APInt::getMinValue(W); + + ConstantRange NewCR1(NewLower, NewUpper); + applyRange(n1, NewCR1, Subtree, VRP); + } + } + } + + void addInequality(unsigned n1, unsigned n2, DomTreeDFS::Node *Subtree, + LatticeVal LV, VRPSolver *VRP) { + assert(!isRelatedBy(n1, n2, Subtree, LV) && "Asked to do useless work."); + + if (LV == NE) { + addNotEquals(n1, n2, Subtree, VRP); + return; + } + + ConstantRange CR1 = range(n1, Subtree); + ConstantRange CR2 = range(n2, Subtree); + + if (!CR1.isSingleElement()) { + ConstantRange NewCR1 = CR1.maximalIntersectWith(create(LV, CR2)); + if (NewCR1 != CR1) + applyRange(n1, NewCR1, Subtree, VRP); + } + + if (!CR2.isSingleElement()) { + ConstantRange NewCR2 = CR2.maximalIntersectWith( + create(reversePredicate(LV), CR1)); + if (NewCR2 != CR2) + applyRange(n2, NewCR2, Subtree, VRP); + } + } + }; + + /// UnreachableBlocks keeps tracks of blocks that are for one reason or + /// another discovered to be unreachable. This is used to cull the graph when + /// analyzing instructions, and to mark blocks with the "unreachable" + /// terminator instruction after the function has executed. 
+ class VISIBILITY_HIDDEN UnreachableBlocks { + private: + std::vector DeadBlocks; + + public: + /// mark - mark a block as dead + void mark(BasicBlock *BB) { + std::vector::iterator E = DeadBlocks.end(); + std::vector::iterator I = + std::lower_bound(DeadBlocks.begin(), E, BB); + + if (I == E || *I != BB) DeadBlocks.insert(I, BB); + } + + /// isDead - returns whether a block is known to be dead already + bool isDead(BasicBlock *BB) { + std::vector::iterator E = DeadBlocks.end(); + std::vector::iterator I = + std::lower_bound(DeadBlocks.begin(), E, BB); + + return I != E && *I == BB; + } + + /// kill - replace the dead blocks' terminator with an UnreachableInst. + bool kill() { + bool modified = false; + for (std::vector::iterator I = DeadBlocks.begin(), + E = DeadBlocks.end(); I != E; ++I) { + BasicBlock *BB = *I; + + DOUT << "unreachable block: " << BB->getName() << "\n"; + + for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); + SI != SE; ++SI) { + BasicBlock *Succ = *SI; + Succ->removePredecessor(BB); + } + + TerminatorInst *TI = BB->getTerminator(); + TI->replaceAllUsesWith(UndefValue::get(TI->getType())); + TI->eraseFromParent(); + new UnreachableInst(BB); + ++NumBlocks; + modified = true; + } + DeadBlocks.clear(); + return modified; + } + }; + + /// VRPSolver keeps track of how changes to one variable affect other + /// variables, and forwards changes along to the InequalityGraph. It + /// also maintains the correct choice for "canonical" in the IG. + /// @brief VRPSolver calculates inferences from a new relationship. 
+ class VISIBILITY_HIDDEN VRPSolver { + private: + friend class ValueRanges; + + struct Operation { + Value *LHS, *RHS; + ICmpInst::Predicate Op; + + BasicBlock *ContextBB; // XXX use a DomTreeDFS::Node instead + Instruction *ContextInst; + }; + std::deque WorkList; + + ValueNumbering &VN; + InequalityGraph &IG; + UnreachableBlocks &UB; + ValueRanges &VR; + DomTreeDFS *DTDFS; + DomTreeDFS::Node *Top; + BasicBlock *TopBB; + Instruction *TopInst; + bool &modified; + + typedef InequalityGraph::Node Node; + + // below - true if the Instruction is dominated by the current context + // block or instruction + bool below(Instruction *I) { + BasicBlock *BB = I->getParent(); + if (TopInst && TopInst->getParent() == BB) { + if (isa(TopInst)) return false; + if (isa(I)) return true; + if ( isa(TopInst) && !isa(I)) return true; + if (!isa(TopInst) && isa(I)) return false; + + for (BasicBlock::const_iterator Iter = BB->begin(), E = BB->end(); + Iter != E; ++Iter) { + if (&*Iter == TopInst) return true; + else if (&*Iter == I) return false; + } + assert(!"Instructions not found in parent BasicBlock?"); + } else { + DomTreeDFS::Node *Node = DTDFS->getNodeForBlock(BB); + if (!Node) return false; + return Top->dominates(Node); + } + } + + // aboveOrBelow - true if the Instruction either dominates or is dominated + // by the current context block or instruction + bool aboveOrBelow(Instruction *I) { + BasicBlock *BB = I->getParent(); + DomTreeDFS::Node *Node = DTDFS->getNodeForBlock(BB); + if (!Node) return false; + + return Top == Node || Top->dominates(Node) || Node->dominates(Top); + } + + bool makeEqual(Value *V1, Value *V2) { + DOUT << "makeEqual(" << *V1 << ", " << *V2 << ")\n"; + DOUT << "context is "; + if (TopInst) DOUT << "I: " << *TopInst << "\n"; + else DOUT << "BB: " << TopBB->getName() + << "(" << Top->getDFSNumIn() << ")\n"; + + assert(V1->getType() == V2->getType() && + "Can't make two values with different types equal."); + + if (V1 == V2) return true; + + if (isa(V1) 
&& isa(V2)) + return false; + + unsigned n1 = VN.valueNumber(V1, Top), n2 = VN.valueNumber(V2, Top); + + if (n1 && n2) { + if (n1 == n2) return true; + if (IG.isRelatedBy(n1, n2, Top, NE)) return false; + } + + if (n1) assert(V1 == VN.value(n1) && "Value isn't canonical."); + if (n2) assert(V2 == VN.value(n2) && "Value isn't canonical."); + + assert(!VN.compare(V2, V1) && "Please order parameters to makeEqual."); + + assert(!isa(V2) && "Tried to remove a constant."); + + SetVector Remove; + if (n2) Remove.insert(n2); + + if (n1 && n2) { + // Suppose we're being told that %x == %y, and %x <= %z and %y >= %z. + // We can't just merge %x and %y because the relationship with %z would + // be EQ and that's invalid. What we're doing is looking for any nodes + // %z such that %x <= %z and %y >= %z, and vice versa. + + Node::iterator end = IG.node(n2)->end(); + + // Find the intersection between N1 and N2 which is dominated by + // Top. If we find %x where N1 <= %x <= N2 (or >=) then add %x to + // Remove. + for (Node::iterator I = IG.node(n1)->begin(), E = IG.node(n1)->end(); + I != E; ++I) { + if (!(I->LV & EQ_BIT) || !Top->DominatedBy(I->Subtree)) continue; + + unsigned ILV_s = I->LV & (SLT_BIT|SGT_BIT); + unsigned ILV_u = I->LV & (ULT_BIT|UGT_BIT); + Node::iterator NI = IG.node(n2)->find(I->To, Top); + if (NI != end) { + LatticeVal NILV = reversePredicate(NI->LV); + unsigned NILV_s = NILV & (SLT_BIT|SGT_BIT); + unsigned NILV_u = NILV & (ULT_BIT|UGT_BIT); + + if ((ILV_s != (SLT_BIT|SGT_BIT) && ILV_s == NILV_s) || + (ILV_u != (ULT_BIT|UGT_BIT) && ILV_u == NILV_u)) + Remove.insert(I->To); + } + } + + // See if one of the nodes about to be removed is actually a better + // canonical choice than n1. 
+ unsigned orig_n1 = n1; + SetVector::iterator DontRemove = Remove.end(); + for (SetVector::iterator I = Remove.begin()+1 /* skip n2 */, + E = Remove.end(); I != E; ++I) { + unsigned n = *I; + Value *V = VN.value(n); + if (VN.compare(V, V1)) { + V1 = V; + n1 = n; + DontRemove = I; + } + } + if (DontRemove != Remove.end()) { + unsigned n = *DontRemove; + Remove.remove(n); + Remove.insert(orig_n1); + } + } + + // We'd like to allow makeEqual on two values to perform a simple + // substitution without every creating nodes in the IG whenever possible. + // + // The first iteration through this loop operates on V2 before going + // through the Remove list and operating on those too. If all of the + // iterations performed simple replacements then we exit early. + bool mergeIGNode = false; + unsigned i = 0; + for (Value *R = V2; i == 0 || i < Remove.size(); ++i) { + if (i) R = VN.value(Remove[i]); // skip n2. + + // Try to replace the whole instruction. If we can, we're done. + Instruction *I2 = dyn_cast(R); + if (I2 && below(I2)) { + std::vector ToNotify; + for (Value::use_iterator UI = R->use_begin(), UE = R->use_end(); + UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast(TheUse.getUser())) + ToNotify.push_back(I); + } + + DOUT << "Simply removing " << *I2 + << ", replacing with " << *V1 << "\n"; + I2->replaceAllUsesWith(V1); + // leave it dead; it'll get erased later. + ++NumInstruction; + modified = true; + + for (std::vector::iterator II = ToNotify.begin(), + IE = ToNotify.end(); II != IE; ++II) { + opsToDef(*II); + } + + continue; + } + + // Otherwise, replace all dominated uses. + for (Value::use_iterator UI = R->use_begin(), UE = R->use_end(); + UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast(TheUse.getUser())) { + if (below(I)) { + TheUse.set(V1); + modified = true; + ++NumVarsReplaced; + opsToDef(I); + } + } + } + + // If that killed the instruction, stop here. 
+ if (I2 && isInstructionTriviallyDead(I2)) { + DOUT << "Killed all uses of " << *I2 + << ", replacing with " << *V1 << "\n"; + continue; + } + + // If we make it to here, then we will need to create a node for N1. + // Otherwise, we can skip out early! + mergeIGNode = true; + } + + if (!isa(V1)) { + if (Remove.empty()) { + VR.mergeInto(&V2, 1, VN.getOrInsertVN(V1, Top), Top, this); + } else { + std::vector RemoveVals; + RemoveVals.reserve(Remove.size()); + + for (SetVector::iterator I = Remove.begin(), + E = Remove.end(); I != E; ++I) { + Value *V = VN.value(*I); + if (!V->use_empty()) + RemoveVals.push_back(V); + } + VR.mergeInto(&RemoveVals[0], RemoveVals.size(), + VN.getOrInsertVN(V1, Top), Top, this); + } + } + + if (mergeIGNode) { + // Create N1. + if (!n1) n1 = VN.getOrInsertVN(V1, Top); + + // Migrate relationships from removed nodes to N1. + for (SetVector::iterator I = Remove.begin(), E = Remove.end(); + I != E; ++I) { + unsigned n = *I; + for (Node::iterator NI = IG.node(n)->begin(), NE = IG.node(n)->end(); + NI != NE; ++NI) { + if (NI->Subtree->DominatedBy(Top)) { + if (NI->To == n1) { + assert((NI->LV & EQ_BIT) && "Node inequal to itself."); + continue; + } + if (Remove.count(NI->To)) + continue; + + IG.node(NI->To)->update(n1, reversePredicate(NI->LV), Top); + IG.node(n1)->update(NI->To, NI->LV, Top); + } + } + } + + // Point V2 (and all items in Remove) to N1. + if (!n2) + VN.addEquality(n1, V2, Top); + else { + for (SetVector::iterator I = Remove.begin(), + E = Remove.end(); I != E; ++I) { + VN.addEquality(n1, VN.value(*I), Top); + } + } + + // If !Remove.empty() then V2 = Remove[0]->getValue(). + // Even when Remove is empty, we still want to process V2. + i = 0; + for (Value *R = V2; i == 0 || i < Remove.size(); ++i) { + if (i) R = VN.value(Remove[i]); // skip n2. 
+ + if (Instruction *I2 = dyn_cast(R)) { + if (aboveOrBelow(I2)) + defToOps(I2); + } + for (Value::use_iterator UI = V2->use_begin(), UE = V2->use_end(); + UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast(TheUse.getUser())) { + if (aboveOrBelow(I)) + opsToDef(I); + } + } + } + } + + // re-opsToDef all dominated users of V1. + if (Instruction *I = dyn_cast(V1)) { + for (Value::use_iterator UI = I->use_begin(), UE = I->use_end(); + UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + Value *V = TheUse.getUser(); + if (!V->use_empty()) { + if (Instruction *Inst = dyn_cast(V)) { + if (aboveOrBelow(Inst)) + opsToDef(Inst); + } + } + } + } + + return true; + } + + /// cmpInstToLattice - converts an CmpInst::Predicate to lattice value + /// Requires that the lattice value be valid; does not accept ICMP_EQ. + static LatticeVal cmpInstToLattice(ICmpInst::Predicate Pred) { + switch (Pred) { + case ICmpInst::ICMP_EQ: + assert(!"No matching lattice value."); + return static_cast(EQ_BIT); + default: + assert(!"Invalid 'icmp' predicate."); + case ICmpInst::ICMP_NE: + return NE; + case ICmpInst::ICMP_UGT: + return UGT; + case ICmpInst::ICMP_UGE: + return UGE; + case ICmpInst::ICMP_ULT: + return ULT; + case ICmpInst::ICMP_ULE: + return ULE; + case ICmpInst::ICMP_SGT: + return SGT; + case ICmpInst::ICMP_SGE: + return SGE; + case ICmpInst::ICMP_SLT: + return SLT; + case ICmpInst::ICMP_SLE: + return SLE; + } + } + + public: + VRPSolver(ValueNumbering &VN, InequalityGraph &IG, UnreachableBlocks &UB, + ValueRanges &VR, DomTreeDFS *DTDFS, bool &modified, + BasicBlock *TopBB) + : VN(VN), + IG(IG), + UB(UB), + VR(VR), + DTDFS(DTDFS), + Top(DTDFS->getNodeForBlock(TopBB)), + TopBB(TopBB), + TopInst(NULL), + modified(modified) + { + assert(Top && "VRPSolver created for unreachable basic block."); + } + + VRPSolver(ValueNumbering &VN, InequalityGraph &IG, UnreachableBlocks &UB, + ValueRanges &VR, DomTreeDFS *DTDFS, bool &modified, + Instruction *TopInst) + : 
VN(VN), + IG(IG), + UB(UB), + VR(VR), + DTDFS(DTDFS), + Top(DTDFS->getNodeForBlock(TopInst->getParent())), + TopBB(TopInst->getParent()), + TopInst(TopInst), + modified(modified) + { + assert(Top && "VRPSolver created for unreachable basic block."); + assert(Top->getBlock() == TopInst->getParent() && "Context mismatch."); + } + + bool isRelatedBy(Value *V1, Value *V2, ICmpInst::Predicate Pred) const { + if (Constant *C1 = dyn_cast(V1)) + if (Constant *C2 = dyn_cast(V2)) + return ConstantExpr::getCompare(Pred, C1, C2) == + ConstantInt::getTrue(); + + unsigned n1 = VN.valueNumber(V1, Top); + unsigned n2 = VN.valueNumber(V2, Top); + + if (n1 && n2) { + if (n1 == n2) return Pred == ICmpInst::ICMP_EQ || + Pred == ICmpInst::ICMP_ULE || + Pred == ICmpInst::ICMP_UGE || + Pred == ICmpInst::ICMP_SLE || + Pred == ICmpInst::ICMP_SGE; + if (Pred == ICmpInst::ICMP_EQ) return false; + if (IG.isRelatedBy(n1, n2, Top, cmpInstToLattice(Pred))) return true; + if (VR.isRelatedBy(n1, n2, Top, cmpInstToLattice(Pred))) return true; + } + + if ((n1 && !n2 && isa(V2)) || + (n2 && !n1 && isa(V1))) { + ConstantRange CR1 = n1 ? VR.range(n1, Top) : VR.range(V1); + ConstantRange CR2 = n2 ? 
VR.range(n2, Top) : VR.range(V2); + + if (Pred == ICmpInst::ICMP_EQ) + return CR1.isSingleElement() && + CR1.getSingleElement() == CR2.getSingleElement(); + + return VR.isRelatedBy(CR1, CR2, cmpInstToLattice(Pred)); + } + if (Pred == ICmpInst::ICMP_EQ) return V1 == V2; + return false; + } + + /// add - adds a new property to the work queue + void add(Value *V1, Value *V2, ICmpInst::Predicate Pred, + Instruction *I = NULL) { + DOUT << "adding " << *V1 << " " << Pred << " " << *V2; + if (I) DOUT << " context: " << *I; + else DOUT << " default context (" << Top->getDFSNumIn() << ")"; + DOUT << "\n"; + + assert(V1->getType() == V2->getType() && + "Can't relate two values with different types."); + + WorkList.push_back(Operation()); + Operation &O = WorkList.back(); + O.LHS = V1, O.RHS = V2, O.Op = Pred, O.ContextInst = I; + O.ContextBB = I ? I->getParent() : TopBB; + } + + /// defToOps - Given an instruction definition that we've learned something + /// new about, find any new relationships between its operands. + void defToOps(Instruction *I) { + Instruction *NewContext = below(I) ? I : TopInst; + Value *Canonical = VN.canonicalize(I, Top); + + if (BinaryOperator *BO = dyn_cast(I)) { + const Type *Ty = BO->getType(); + assert(!Ty->isFPOrFPVector() && "Float in work queue!"); + + Value *Op0 = VN.canonicalize(BO->getOperand(0), Top); + Value *Op1 = VN.canonicalize(BO->getOperand(1), Top); + + // TODO: "and i32 -1, %x" EQ %y then %x EQ %y. 
+ + switch (BO->getOpcode()) { + case Instruction::And: { + // "and i32 %a, %b" EQ -1 then %a EQ -1 and %b EQ -1 + ConstantInt *CI = ConstantInt::getAllOnesValue(Ty); + if (Canonical == CI) { + add(CI, Op0, ICmpInst::ICMP_EQ, NewContext); + add(CI, Op1, ICmpInst::ICMP_EQ, NewContext); + } + } break; + case Instruction::Or: { + // "or i32 %a, %b" EQ 0 then %a EQ 0 and %b EQ 0 + Constant *Zero = Constant::getNullValue(Ty); + if (Canonical == Zero) { + add(Zero, Op0, ICmpInst::ICMP_EQ, NewContext); + add(Zero, Op1, ICmpInst::ICMP_EQ, NewContext); + } + } break; + case Instruction::Xor: { + // "xor i32 %c, %a" EQ %b then %a EQ %c ^ %b + // "xor i32 %c, %a" EQ %c then %a EQ 0 + // "xor i32 %c, %a" NE %c then %a NE 0 + // Repeat the above, with order of operands reversed. + Value *LHS = Op0; + Value *RHS = Op1; + if (!isa(LHS)) std::swap(LHS, RHS); + + if (ConstantInt *CI = dyn_cast(Canonical)) { + if (ConstantInt *Arg = dyn_cast(LHS)) { + add(RHS, ConstantInt::get(CI->getValue() ^ Arg->getValue()), + ICmpInst::ICMP_EQ, NewContext); + } + } + if (Canonical == LHS) { + if (isa(Canonical)) + add(RHS, Constant::getNullValue(Ty), ICmpInst::ICMP_EQ, + NewContext); + } else if (isRelatedBy(LHS, Canonical, ICmpInst::ICMP_NE)) { + add(RHS, Constant::getNullValue(Ty), ICmpInst::ICMP_NE, + NewContext); + } + } break; + default: + break; + } + } else if (ICmpInst *IC = dyn_cast(I)) { + // "icmp ult i32 %a, %y" EQ true then %a u< y + // etc. 
+ + if (Canonical == ConstantInt::getTrue()) { + add(IC->getOperand(0), IC->getOperand(1), IC->getPredicate(), + NewContext); + } else if (Canonical == ConstantInt::getFalse()) { + add(IC->getOperand(0), IC->getOperand(1), + ICmpInst::getInversePredicate(IC->getPredicate()), NewContext); + } + } else if (SelectInst *SI = dyn_cast(I)) { + if (I->getType()->isFPOrFPVector()) return; + + // Given: "%a = select i1 %x, i32 %b, i32 %c" + // %a EQ %b and %b NE %c then %x EQ true + // %a EQ %c and %b NE %c then %x EQ false + + Value *True = SI->getTrueValue(); + Value *False = SI->getFalseValue(); + if (isRelatedBy(True, False, ICmpInst::ICMP_NE)) { + if (Canonical == VN.canonicalize(True, Top) || + isRelatedBy(Canonical, False, ICmpInst::ICMP_NE)) + add(SI->getCondition(), ConstantInt::getTrue(), + ICmpInst::ICMP_EQ, NewContext); + else if (Canonical == VN.canonicalize(False, Top) || + isRelatedBy(Canonical, True, ICmpInst::ICMP_NE)) + add(SI->getCondition(), ConstantInt::getFalse(), + ICmpInst::ICMP_EQ, NewContext); + } + } else if (GetElementPtrInst *GEPI = dyn_cast(I)) { + for (GetElementPtrInst::op_iterator OI = GEPI->idx_begin(), + OE = GEPI->idx_end(); OI != OE; ++OI) { + ConstantInt *Op = dyn_cast(VN.canonicalize(*OI, Top)); + if (!Op || !Op->isZero()) return; + } + // TODO: The GEPI indices are all zero. Copy from definition to operand, + // jumping the type plane as needed. 
+ if (isRelatedBy(GEPI, Constant::getNullValue(GEPI->getType()), + ICmpInst::ICMP_NE)) { + Value *Ptr = GEPI->getPointerOperand(); + add(Ptr, Constant::getNullValue(Ptr->getType()), ICmpInst::ICMP_NE, + NewContext); + } + } else if (CastInst *CI = dyn_cast(I)) { + const Type *SrcTy = CI->getSrcTy(); + + unsigned ci = VN.getOrInsertVN(CI, Top); + uint32_t W = VR.typeToWidth(SrcTy); + if (!W) return; + ConstantRange CR = VR.range(ci, Top); + + if (CR.isFullSet()) return; + + switch (CI->getOpcode()) { + default: break; + case Instruction::ZExt: + case Instruction::SExt: + VR.applyRange(VN.getOrInsertVN(CI->getOperand(0), Top), + CR.truncate(W), Top, this); + break; + case Instruction::BitCast: + VR.applyRange(VN.getOrInsertVN(CI->getOperand(0), Top), + CR, Top, this); + break; + } + } + } + + /// opsToDef - A new relationship was discovered involving one of this + /// instruction's operands. Find any new relationship involving the + /// definition, or another operand. + void opsToDef(Instruction *I) { + Instruction *NewContext = below(I) ? 
I : TopInst; + + if (BinaryOperator *BO = dyn_cast(I)) { + Value *Op0 = VN.canonicalize(BO->getOperand(0), Top); + Value *Op1 = VN.canonicalize(BO->getOperand(1), Top); + + if (ConstantInt *CI0 = dyn_cast(Op0)) + if (ConstantInt *CI1 = dyn_cast(Op1)) { + add(BO, ConstantExpr::get(BO->getOpcode(), CI0, CI1), + ICmpInst::ICMP_EQ, NewContext); + return; + } + + // "%y = and i1 true, %x" then %x EQ %y + // "%y = or i1 false, %x" then %x EQ %y + // "%x = add i32 %y, 0" then %x EQ %y + // "%x = mul i32 %y, 0" then %x EQ 0 + + Instruction::BinaryOps Opcode = BO->getOpcode(); + const Type *Ty = BO->getType(); + assert(!Ty->isFPOrFPVector() && "Float in work queue!"); + + Constant *Zero = Constant::getNullValue(Ty); + ConstantInt *AllOnes = ConstantInt::getAllOnesValue(Ty); + + switch (Opcode) { + default: break; + case Instruction::LShr: + case Instruction::AShr: + case Instruction::Shl: + case Instruction::Sub: + if (Op1 == Zero) { + add(BO, Op0, ICmpInst::ICMP_EQ, NewContext); + return; + } + break; + case Instruction::Or: + if (Op0 == AllOnes || Op1 == AllOnes) { + add(BO, AllOnes, ICmpInst::ICMP_EQ, NewContext); + return; + } // fall-through + case Instruction::Xor: + case Instruction::Add: + if (Op0 == Zero) { + add(BO, Op1, ICmpInst::ICMP_EQ, NewContext); + return; + } else if (Op1 == Zero) { + add(BO, Op0, ICmpInst::ICMP_EQ, NewContext); + return; + } + break; + case Instruction::And: + if (Op0 == AllOnes) { + add(BO, Op1, ICmpInst::ICMP_EQ, NewContext); + return; + } else if (Op1 == AllOnes) { + add(BO, Op0, ICmpInst::ICMP_EQ, NewContext); + return; + } + // fall-through + case Instruction::Mul: + if (Op0 == Zero || Op1 == Zero) { + add(BO, Zero, ICmpInst::ICMP_EQ, NewContext); + return; + } + break; + } + + // "%x = add i32 %y, %z" and %x EQ %y then %z EQ 0 + // "%x = add i32 %y, %z" and %x EQ %z then %y EQ 0 + // "%x = shl i32 %y, %z" and %x EQ %y and %y NE 0 then %z EQ 0 + // "%x = udiv i32 %y, %z" and %x EQ %y then %z EQ 1 + + Value *Known = Op0, *Unknown = 
Op1, + *TheBO = VN.canonicalize(BO, Top); + if (Known != TheBO) std::swap(Known, Unknown); + if (Known == TheBO) { + switch (Opcode) { + default: break; + case Instruction::LShr: + case Instruction::AShr: + case Instruction::Shl: + if (!isRelatedBy(Known, Zero, ICmpInst::ICMP_NE)) break; + // otherwise, fall-through. + case Instruction::Sub: + if (Unknown == Op1) break; + // otherwise, fall-through. + case Instruction::Xor: + case Instruction::Add: + add(Unknown, Zero, ICmpInst::ICMP_EQ, NewContext); + break; + case Instruction::UDiv: + case Instruction::SDiv: + if (Unknown == Op1) break; + if (isRelatedBy(Known, Zero, ICmpInst::ICMP_NE)) { + Constant *One = ConstantInt::get(Ty, 1); + add(Unknown, One, ICmpInst::ICMP_EQ, NewContext); + } + break; + } + } + + // TODO: "%a = add i32 %b, 1" and %b > %z then %a >= %z. + + } else if (ICmpInst *IC = dyn_cast(I)) { + // "%a = icmp ult i32 %b, %c" and %b u< %c then %a EQ true + // "%a = icmp ult i32 %b, %c" and %b u>= %c then %a EQ false + // etc. 
+ + Value *Op0 = VN.canonicalize(IC->getOperand(0), Top); + Value *Op1 = VN.canonicalize(IC->getOperand(1), Top); + + ICmpInst::Predicate Pred = IC->getPredicate(); + if (isRelatedBy(Op0, Op1, Pred)) + add(IC, ConstantInt::getTrue(), ICmpInst::ICMP_EQ, NewContext); + else if (isRelatedBy(Op0, Op1, ICmpInst::getInversePredicate(Pred))) + add(IC, ConstantInt::getFalse(), ICmpInst::ICMP_EQ, NewContext); + + } else if (SelectInst *SI = dyn_cast(I)) { + if (I->getType()->isFPOrFPVector()) return; + + // Given: "%a = select i1 %x, i32 %b, i32 %c" + // %x EQ true then %a EQ %b + // %x EQ false then %a EQ %c + // %b EQ %c then %a EQ %b + + Value *Canonical = VN.canonicalize(SI->getCondition(), Top); + if (Canonical == ConstantInt::getTrue()) { + add(SI, SI->getTrueValue(), ICmpInst::ICMP_EQ, NewContext); + } else if (Canonical == ConstantInt::getFalse()) { + add(SI, SI->getFalseValue(), ICmpInst::ICMP_EQ, NewContext); + } else if (VN.canonicalize(SI->getTrueValue(), Top) == + VN.canonicalize(SI->getFalseValue(), Top)) { + add(SI, SI->getTrueValue(), ICmpInst::ICMP_EQ, NewContext); + } + } else if (CastInst *CI = dyn_cast(I)) { + const Type *DestTy = CI->getDestTy(); + if (DestTy->isFPOrFPVector()) return; + + Value *Op = VN.canonicalize(CI->getOperand(0), Top); + Instruction::CastOps Opcode = CI->getOpcode(); + + if (Constant *C = dyn_cast(Op)) { + add(CI, ConstantExpr::getCast(Opcode, C, DestTy), + ICmpInst::ICMP_EQ, NewContext); + } + + uint32_t W = VR.typeToWidth(DestTy); + unsigned ci = VN.getOrInsertVN(CI, Top); + ConstantRange CR = VR.range(VN.getOrInsertVN(Op, Top), Top); + + if (!CR.isFullSet()) { + switch (Opcode) { + default: break; + case Instruction::ZExt: + VR.applyRange(ci, CR.zeroExtend(W), Top, this); + break; + case Instruction::SExt: + VR.applyRange(ci, CR.signExtend(W), Top, this); + break; + case Instruction::Trunc: { + ConstantRange Result = CR.truncate(W); + if (!Result.isFullSet()) + VR.applyRange(ci, Result, Top, this); + } break; + case 
Instruction::BitCast: + VR.applyRange(ci, CR, Top, this); + break; + // TODO: other casts? + } + } + } else if (GetElementPtrInst *GEPI = dyn_cast(I)) { + for (GetElementPtrInst::op_iterator OI = GEPI->idx_begin(), + OE = GEPI->idx_end(); OI != OE; ++OI) { + ConstantInt *Op = dyn_cast(VN.canonicalize(*OI, Top)); + if (!Op || !Op->isZero()) return; + } + // TODO: The GEPI indices are all zero. Copy from operand to definition, + // jumping the type plane as needed. + Value *Ptr = GEPI->getPointerOperand(); + if (isRelatedBy(Ptr, Constant::getNullValue(Ptr->getType()), + ICmpInst::ICMP_NE)) { + add(GEPI, Constant::getNullValue(GEPI->getType()), ICmpInst::ICMP_NE, + NewContext); + } + } + } + + /// solve - process the work queue + void solve() { + //DOUT << "WorkList entry, size: " << WorkList.size() << "\n"; + while (!WorkList.empty()) { + //DOUT << "WorkList size: " << WorkList.size() << "\n"; + + Operation &O = WorkList.front(); + TopInst = O.ContextInst; + TopBB = O.ContextBB; + Top = DTDFS->getNodeForBlock(TopBB); // XXX move this into Context + + O.LHS = VN.canonicalize(O.LHS, Top); + O.RHS = VN.canonicalize(O.RHS, Top); + + assert(O.LHS == VN.canonicalize(O.LHS, Top) && "Canonicalize isn't."); + assert(O.RHS == VN.canonicalize(O.RHS, Top) && "Canonicalize isn't."); + + DOUT << "solving " << *O.LHS << " " << O.Op << " " << *O.RHS; + if (O.ContextInst) DOUT << " context inst: " << *O.ContextInst; + else DOUT << " context block: " << O.ContextBB->getName(); + DOUT << "\n"; + + DEBUG(VN.dump()); + DEBUG(IG.dump()); + DEBUG(VR.dump()); + + // If they're both Constant, skip it. Check for contradiction and mark + // the BB as unreachable if so. 
+ if (Constant *CI_L = dyn_cast(O.LHS)) { + if (Constant *CI_R = dyn_cast(O.RHS)) { + if (ConstantExpr::getCompare(O.Op, CI_L, CI_R) == + ConstantInt::getFalse()) + UB.mark(TopBB); + + WorkList.pop_front(); + continue; + } + } + + if (VN.compare(O.LHS, O.RHS)) { + std::swap(O.LHS, O.RHS); + O.Op = ICmpInst::getSwappedPredicate(O.Op); + } + + if (O.Op == ICmpInst::ICMP_EQ) { + if (!makeEqual(O.RHS, O.LHS)) + UB.mark(TopBB); + } else { + LatticeVal LV = cmpInstToLattice(O.Op); + + if ((LV & EQ_BIT) && + isRelatedBy(O.LHS, O.RHS, ICmpInst::getSwappedPredicate(O.Op))) { + if (!makeEqual(O.RHS, O.LHS)) + UB.mark(TopBB); + } else { + if (isRelatedBy(O.LHS, O.RHS, ICmpInst::getInversePredicate(O.Op))){ + UB.mark(TopBB); + WorkList.pop_front(); + continue; + } + + unsigned n1 = VN.getOrInsertVN(O.LHS, Top); + unsigned n2 = VN.getOrInsertVN(O.RHS, Top); + + if (n1 == n2) { + if (O.Op != ICmpInst::ICMP_UGE && O.Op != ICmpInst::ICMP_ULE && + O.Op != ICmpInst::ICMP_SGE && O.Op != ICmpInst::ICMP_SLE) + UB.mark(TopBB); + + WorkList.pop_front(); + continue; + } + + if (VR.isRelatedBy(n1, n2, Top, LV) || + IG.isRelatedBy(n1, n2, Top, LV)) { + WorkList.pop_front(); + continue; + } + + VR.addInequality(n1, n2, Top, LV, this); + if ((!isa(O.RHS) && !isa(O.LHS)) || + LV == NE) + IG.addInequality(n1, n2, Top, LV); + + if (Instruction *I1 = dyn_cast(O.LHS)) { + if (aboveOrBelow(I1)) + defToOps(I1); + } + if (isa(O.LHS) || isa(O.LHS)) { + for (Value::use_iterator UI = O.LHS->use_begin(), + UE = O.LHS->use_end(); UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast(TheUse.getUser())) { + if (aboveOrBelow(I)) + opsToDef(I); + } + } + } + if (Instruction *I2 = dyn_cast(O.RHS)) { + if (aboveOrBelow(I2)) + defToOps(I2); + } + if (isa(O.RHS) || isa(O.RHS)) { + for (Value::use_iterator UI = O.RHS->use_begin(), + UE = O.RHS->use_end(); UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast(TheUse.getUser())) { + if (aboveOrBelow(I)) + 
opsToDef(I); + } + } + } + } + } + WorkList.pop_front(); + } + } + }; + + void ValueRanges::addToWorklist(Value *V, Constant *C, + ICmpInst::Predicate Pred, VRPSolver *VRP) { + VRP->add(V, C, Pred, VRP->TopInst); + } + + void ValueRanges::markBlock(VRPSolver *VRP) { + VRP->UB.mark(VRP->TopBB); + } + + /// PredicateSimplifier - This class is a simplifier that replaces + /// one equivalent variable with another. It also tracks what + /// can't be equal and will solve setcc instructions when possible. + /// @brief Root of the predicate simplifier optimization. + class VISIBILITY_HIDDEN PredicateSimplifier : public FunctionPass { + DomTreeDFS *DTDFS; + bool modified; + ValueNumbering *VN; + InequalityGraph *IG; + UnreachableBlocks UB; + ValueRanges *VR; + + std::vector WorkList; + + public: + static char ID; // Pass identification, replacement for typeid + PredicateSimplifier() : FunctionPass((intptr_t)&ID) {} + + bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredID(BreakCriticalEdgesID); + AU.addRequired(); + AU.addRequired(); + AU.addPreserved(); + } + + private: + /// Forwards - Adds new properties to VRPSolver and uses them to + /// simplify instructions. Because new properties sometimes apply to + /// a transition from one BasicBlock to another, this will use the + /// PredicateSimplifier::proceedToSuccessor(s) interface to enter the + /// basic block. + /// @brief Performs abstract execution of the program. 
+ class VISIBILITY_HIDDEN Forwards : public InstVisitor { + friend class InstVisitor; + PredicateSimplifier *PS; + DomTreeDFS::Node *DTNode; + + public: + ValueNumbering &VN; + InequalityGraph &IG; + UnreachableBlocks &UB; + ValueRanges &VR; + + Forwards(PredicateSimplifier *PS, DomTreeDFS::Node *DTNode) + : PS(PS), DTNode(DTNode), VN(*PS->VN), IG(*PS->IG), UB(PS->UB), + VR(*PS->VR) {} + + void visitTerminatorInst(TerminatorInst &TI); + void visitBranchInst(BranchInst &BI); + void visitSwitchInst(SwitchInst &SI); + + void visitAllocaInst(AllocaInst &AI); + void visitLoadInst(LoadInst &LI); + void visitStoreInst(StoreInst &SI); + + void visitSExtInst(SExtInst &SI); + void visitZExtInst(ZExtInst &ZI); + + void visitBinaryOperator(BinaryOperator &BO); + void visitICmpInst(ICmpInst &IC); + }; + + // Used by terminator instructions to proceed from the current basic + // block to the next. Verifies that "current" dominates "next", + // then calls visitBasicBlock. + void proceedToSuccessors(DomTreeDFS::Node *Current) { + for (DomTreeDFS::Node::iterator I = Current->begin(), + E = Current->end(); I != E; ++I) { + WorkList.push_back(*I); + } + } + + void proceedToSuccessor(DomTreeDFS::Node *Next) { + WorkList.push_back(Next); + } + + // Visits each instruction in the basic block. + void visitBasicBlock(DomTreeDFS::Node *Node) { + BasicBlock *BB = Node->getBlock(); + DOUT << "Entering Basic Block: " << BB->getName() + << " (" << Node->getDFSNumIn() << ")\n"; + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { + visitInstruction(I++, Node); + } + } + + // Tries to simplify each Instruction and add new properties. + void visitInstruction(Instruction *I, DomTreeDFS::Node *DT) { + DOUT << "Considering instruction " << *I << "\n"; + DEBUG(VN->dump()); + DEBUG(IG->dump()); + DEBUG(VR->dump()); + + // Sometimes instructions are killed in earlier analysis. 
+ if (isInstructionTriviallyDead(I)) { + ++NumSimple; + modified = true; + if (unsigned n = VN->valueNumber(I, DTDFS->getRootNode())) + if (VN->value(n) == I) IG->remove(n); + VN->remove(I); + I->eraseFromParent(); + return; + } + +#ifndef NDEBUG + // Try to replace the whole instruction. + Value *V = VN->canonicalize(I, DT); + assert(V == I && "Late instruction canonicalization."); + if (V != I) { + modified = true; + ++NumInstruction; + DOUT << "Removing " << *I << ", replacing with " << *V << "\n"; + if (unsigned n = VN->valueNumber(I, DTDFS->getRootNode())) + if (VN->value(n) == I) IG->remove(n); + VN->remove(I); + I->replaceAllUsesWith(V); + I->eraseFromParent(); + return; + } + + // Try to substitute operands. + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { + Value *Oper = I->getOperand(i); + Value *V = VN->canonicalize(Oper, DT); + assert(V == Oper && "Late operand canonicalization."); + if (V != Oper) { + modified = true; + ++NumVarsReplaced; + DOUT << "Resolving " << *I; + I->setOperand(i, V); + DOUT << " into " << *I; + } + } +#endif + + std::string name = I->getParent()->getName(); + DOUT << "push (%" << name << ")\n"; + Forwards visit(this, DT); + visit.visit(*I); + DOUT << "pop (%" << name << ")\n"; + } + }; + + bool PredicateSimplifier::runOnFunction(Function &F) { + DominatorTree *DT = &getAnalysis(); + DTDFS = new DomTreeDFS(DT); + TargetData *TD = &getAnalysis(); + + DOUT << "Entering Function: " << F.getName() << "\n"; + + modified = false; + DomTreeDFS::Node *Root = DTDFS->getRootNode(); + VN = new ValueNumbering(DTDFS); + IG = new InequalityGraph(*VN, Root); + VR = new ValueRanges(*VN, TD); + WorkList.push_back(Root); + + do { + DomTreeDFS::Node *DTNode = WorkList.back(); + WorkList.pop_back(); + if (!UB.isDead(DTNode->getBlock())) visitBasicBlock(DTNode); + } while (!WorkList.empty()); + + delete DTDFS; + delete VR; + delete IG; + + modified |= UB.kill(); + + return modified; + } + + void 
PredicateSimplifier::Forwards::visitTerminatorInst(TerminatorInst &TI) { + PS->proceedToSuccessors(DTNode); + } + + void PredicateSimplifier::Forwards::visitBranchInst(BranchInst &BI) { + if (BI.isUnconditional()) { + PS->proceedToSuccessors(DTNode); + return; + } + + Value *Condition = BI.getCondition(); + BasicBlock *TrueDest = BI.getSuccessor(0); + BasicBlock *FalseDest = BI.getSuccessor(1); + + if (isa(Condition) || TrueDest == FalseDest) { + PS->proceedToSuccessors(DTNode); + return; + } + + for (DomTreeDFS::Node::iterator I = DTNode->begin(), E = DTNode->end(); + I != E; ++I) { + BasicBlock *Dest = (*I)->getBlock(); + DOUT << "Branch thinking about %" << Dest->getName() + << "(" << PS->DTDFS->getNodeForBlock(Dest)->getDFSNumIn() << ")\n"; + + if (Dest == TrueDest) { + DOUT << "(" << DTNode->getBlock()->getName() << ") true set:\n"; + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, Dest); + VRP.add(ConstantInt::getTrue(), Condition, ICmpInst::ICMP_EQ); + VRP.solve(); + DEBUG(VN.dump()); + DEBUG(IG.dump()); + DEBUG(VR.dump()); + } else if (Dest == FalseDest) { + DOUT << "(" << DTNode->getBlock()->getName() << ") false set:\n"; + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, Dest); + VRP.add(ConstantInt::getFalse(), Condition, ICmpInst::ICMP_EQ); + VRP.solve(); + DEBUG(VN.dump()); + DEBUG(IG.dump()); + DEBUG(VR.dump()); + } + + PS->proceedToSuccessor(*I); + } + } + + void PredicateSimplifier::Forwards::visitSwitchInst(SwitchInst &SI) { + Value *Condition = SI.getCondition(); + + // Set the EQProperty in each of the cases BBs, and the NEProperties + // in the default BB. 
+ + for (DomTreeDFS::Node::iterator I = DTNode->begin(), E = DTNode->end(); + I != E; ++I) { + BasicBlock *BB = (*I)->getBlock(); + DOUT << "Switch thinking about BB %" << BB->getName() + << "(" << PS->DTDFS->getNodeForBlock(BB)->getDFSNumIn() << ")\n"; + + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, BB); + if (BB == SI.getDefaultDest()) { + for (unsigned i = 1, e = SI.getNumCases(); i < e; ++i) + if (SI.getSuccessor(i) != BB) + VRP.add(Condition, SI.getCaseValue(i), ICmpInst::ICMP_NE); + VRP.solve(); + } else if (ConstantInt *CI = SI.findCaseDest(BB)) { + VRP.add(Condition, CI, ICmpInst::ICMP_EQ); + VRP.solve(); + } + PS->proceedToSuccessor(*I); + } + } + + void PredicateSimplifier::Forwards::visitAllocaInst(AllocaInst &AI) { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &AI); + VRP.add(Constant::getNullValue(AI.getType()), &AI, ICmpInst::ICMP_NE); + VRP.solve(); + } + + void PredicateSimplifier::Forwards::visitLoadInst(LoadInst &LI) { + Value *Ptr = LI.getPointerOperand(); + // avoid "load uint* null" -> null NE null. 
+ if (isa(Ptr)) return; + + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &LI); + VRP.add(Constant::getNullValue(Ptr->getType()), Ptr, ICmpInst::ICMP_NE); + VRP.solve(); + } + + void PredicateSimplifier::Forwards::visitStoreInst(StoreInst &SI) { + Value *Ptr = SI.getPointerOperand(); + if (isa(Ptr)) return; + + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &SI); + VRP.add(Constant::getNullValue(Ptr->getType()), Ptr, ICmpInst::ICMP_NE); + VRP.solve(); + } + + void PredicateSimplifier::Forwards::visitSExtInst(SExtInst &SI) { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &SI); + uint32_t SrcBitWidth = cast(SI.getSrcTy())->getBitWidth(); + uint32_t DstBitWidth = cast(SI.getDestTy())->getBitWidth(); + APInt Min(APInt::getHighBitsSet(DstBitWidth, DstBitWidth-SrcBitWidth+1)); + APInt Max(APInt::getLowBitsSet(DstBitWidth, SrcBitWidth-1)); + VRP.add(ConstantInt::get(Min), &SI, ICmpInst::ICMP_SLE); + VRP.add(ConstantInt::get(Max), &SI, ICmpInst::ICMP_SGE); + VRP.solve(); + } + + void PredicateSimplifier::Forwards::visitZExtInst(ZExtInst &ZI) { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &ZI); + uint32_t SrcBitWidth = cast(ZI.getSrcTy())->getBitWidth(); + uint32_t DstBitWidth = cast(ZI.getDestTy())->getBitWidth(); + APInt Max(APInt::getLowBitsSet(DstBitWidth, SrcBitWidth)); + VRP.add(ConstantInt::get(Max), &ZI, ICmpInst::ICMP_UGE); + VRP.solve(); + } + + void PredicateSimplifier::Forwards::visitBinaryOperator(BinaryOperator &BO) { + Instruction::BinaryOps ops = BO.getOpcode(); + + switch (ops) { + default: break; + case Instruction::URem: + case Instruction::SRem: + case Instruction::UDiv: + case Instruction::SDiv: { + Value *Divisor = BO.getOperand(1); + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(Constant::getNullValue(Divisor->getType()), Divisor, + ICmpInst::ICMP_NE); + VRP.solve(); + break; + } + } + + switch (ops) { + default: break; + case Instruction::Shl: { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, 
PS->modified, &BO); + VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_UGE); + VRP.solve(); + } break; + case Instruction::AShr: { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_SLE); + VRP.solve(); + } break; + case Instruction::LShr: + case Instruction::UDiv: { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_ULE); + VRP.solve(); + } break; + case Instruction::URem: { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(&BO, BO.getOperand(1), ICmpInst::ICMP_ULE); + VRP.solve(); + } break; + case Instruction::And: { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_ULE); + VRP.add(&BO, BO.getOperand(1), ICmpInst::ICMP_ULE); + VRP.solve(); + } break; + case Instruction::Or: { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_UGE); + VRP.add(&BO, BO.getOperand(1), ICmpInst::ICMP_UGE); + VRP.solve(); + } break; + } + } + + void PredicateSimplifier::Forwards::visitICmpInst(ICmpInst &IC) { + // If possible, squeeze the ICmp predicate into something simpler. + // Eg., if x = [0, 4) and we're being asked icmp uge %x, 3 then change + // the predicate to eq. + + // XXX: once we do full PHI handling, modifying the instruction in the + // Forwards visitor will cause missed optimizations. 
+ + ICmpInst::Predicate Pred = IC.getPredicate(); + + switch (Pred) { + default: break; + case ICmpInst::ICMP_ULE: Pred = ICmpInst::ICMP_ULT; break; + case ICmpInst::ICMP_UGE: Pred = ICmpInst::ICMP_UGT; break; + case ICmpInst::ICMP_SLE: Pred = ICmpInst::ICMP_SLT; break; + case ICmpInst::ICMP_SGE: Pred = ICmpInst::ICMP_SGT; break; + } + if (Pred != IC.getPredicate()) { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &IC); + if (VRP.isRelatedBy(IC.getOperand(1), IC.getOperand(0), + ICmpInst::ICMP_NE)) { + ++NumSnuggle; + PS->modified = true; + IC.setPredicate(Pred); + } + } + + Pred = IC.getPredicate(); + + if (ConstantInt *Op1 = dyn_cast(IC.getOperand(1))) { + ConstantInt *NextVal = 0; + switch (Pred) { + default: break; + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_ULT: + if (Op1->getValue() != 0) + NextVal = ConstantInt::get(Op1->getValue()-1); + break; + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_UGT: + if (!Op1->getValue().isAllOnesValue()) + NextVal = ConstantInt::get(Op1->getValue()+1); + break; + + } + if (NextVal) { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &IC); + if (VRP.isRelatedBy(IC.getOperand(0), NextVal, + ICmpInst::getInversePredicate(Pred))) { + ICmpInst *NewIC = new ICmpInst(ICmpInst::ICMP_EQ, IC.getOperand(0), + NextVal, "", &IC); + NewIC->takeName(&IC); + IC.replaceAllUsesWith(NewIC); + + // XXX: prove this isn't necessary + if (unsigned n = VN.valueNumber(&IC, PS->DTDFS->getRootNode())) + if (VN.value(n) == &IC) IG.remove(n); + VN.remove(&IC); + + IC.eraseFromParent(); + ++NumSnuggle; + PS->modified = true; + } + } + } + } + + char PredicateSimplifier::ID = 0; + RegisterPass X("predsimplify", + "Predicate Simplifier"); +} + +FunctionPass *llvm::createPredicateSimplifierPass() { + return new PredicateSimplifier(); +} diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp new file mode 100644 index 0000000..95f9e7b --- /dev/null +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -0,0 
+1,868 @@ +//===- Reassociate.cpp - Reassociate binary expressions -------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass reassociates commutative expressions in an order that is designed +// to promote better constant propagation, GCSE, LICM, PRE... +// +// For example: 4 + (x + 5) -> x + (4 + 5) +// +// In the implementation of this algorithm, constants are assigned rank = 0, +// function arguments are rank = 1, and other values are assigned ranks +// corresponding to the reverse post order traversal of current function +// (starting at 2), which effectively gives values in deep loops higher rank +// than values not in loops. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "reassociate" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/Statistic.h" +#include +using namespace llvm; + +STATISTIC(NumLinear , "Number of insts linearized"); +STATISTIC(NumChanged, "Number of insts reassociated"); +STATISTIC(NumAnnihil, "Number of expr tree annihilated"); +STATISTIC(NumFactor , "Number of multiplies factored"); + +namespace { + struct VISIBILITY_HIDDEN ValueEntry { + unsigned Rank; + Value *Op; + ValueEntry(unsigned R, Value *O) : Rank(R), Op(O) {} + }; + inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) { + return LHS.Rank > RHS.Rank; // Sort so that highest rank goes to start. 
+ } +} + +/// PrintOps - Print out the expression identified in the Ops list. +/// +static void PrintOps(Instruction *I, const std::vector &Ops) { + Module *M = I->getParent()->getParent()->getParent(); + cerr << Instruction::getOpcodeName(I->getOpcode()) << " " + << *Ops[0].Op->getType(); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + WriteAsOperand(*cerr.stream() << " ", Ops[i].Op, false, M) + << "," << Ops[i].Rank; +} + +namespace { + class VISIBILITY_HIDDEN Reassociate : public FunctionPass { + std::map RankMap; + std::map ValueRankMap; + bool MadeChange; + public: + static char ID; // Pass identification, replacement for typeid + Reassociate() : FunctionPass((intptr_t)&ID) {} + + bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + } + private: + void BuildRankMap(Function &F); + unsigned getRank(Value *V); + void ReassociateExpression(BinaryOperator *I); + void RewriteExprTree(BinaryOperator *I, std::vector &Ops, + unsigned Idx = 0); + Value *OptimizeExpression(BinaryOperator *I, std::vector &Ops); + void LinearizeExprTree(BinaryOperator *I, std::vector &Ops); + void LinearizeExpr(BinaryOperator *I); + Value *RemoveFactorFromExpression(Value *V, Value *Factor); + void ReassociateBB(BasicBlock *BB); + + void RemoveDeadBinaryOp(Value *V); + }; + + char Reassociate::ID = 0; + RegisterPass X("reassociate", "Reassociate expressions"); +} + +// Public interface to the Reassociate pass +FunctionPass *llvm::createReassociatePass() { return new Reassociate(); } + +void Reassociate::RemoveDeadBinaryOp(Value *V) { + Instruction *Op = dyn_cast(V); + if (!Op || !isa(Op) || !isa(Op) || !Op->use_empty()) + return; + + Value *LHS = Op->getOperand(0), *RHS = Op->getOperand(1); + RemoveDeadBinaryOp(LHS); + RemoveDeadBinaryOp(RHS); +} + + +static bool isUnmovableInstruction(Instruction *I) { + if (I->getOpcode() == Instruction::PHI || + I->getOpcode() == Instruction::Alloca || + I->getOpcode() == 
Instruction::Load || + I->getOpcode() == Instruction::Malloc || + I->getOpcode() == Instruction::Invoke || + I->getOpcode() == Instruction::Call || + I->getOpcode() == Instruction::UDiv || + I->getOpcode() == Instruction::SDiv || + I->getOpcode() == Instruction::FDiv || + I->getOpcode() == Instruction::URem || + I->getOpcode() == Instruction::SRem || + I->getOpcode() == Instruction::FRem) + return true; + return false; +} + +void Reassociate::BuildRankMap(Function &F) { + unsigned i = 2; + + // Assign distinct ranks to function arguments + for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) + ValueRankMap[I] = ++i; + + ReversePostOrderTraversal RPOT(&F); + for (ReversePostOrderTraversal::rpo_iterator I = RPOT.begin(), + E = RPOT.end(); I != E; ++I) { + BasicBlock *BB = *I; + unsigned BBRank = RankMap[BB] = ++i << 16; + + // Walk the basic block, adding precomputed ranks for any instructions that + // we cannot move. This ensures that the ranks for these instructions are + // all different in the block. + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) + if (isUnmovableInstruction(I)) + ValueRankMap[I] = ++BBRank; + } +} + +unsigned Reassociate::getRank(Value *V) { + if (isa(V)) return ValueRankMap[V]; // Function argument... + + Instruction *I = dyn_cast(V); + if (I == 0) return 0; // Otherwise it's a global or constant, rank 0. + + unsigned &CachedRank = ValueRankMap[I]; + if (CachedRank) return CachedRank; // Rank already known? + + // If this is an expression, return the 1+MAX(rank(LHS), rank(RHS)) so that + // we can reassociate expressions for code motion! Since we do not recurse + // for PHI nodes, we cannot have infinite recursion here, because there + // cannot be loops in the value graph that do not go through PHI nodes. 
+ unsigned Rank = 0, MaxRank = RankMap[I->getParent()]; + for (unsigned i = 0, e = I->getNumOperands(); + i != e && Rank != MaxRank; ++i) + Rank = std::max(Rank, getRank(I->getOperand(i))); + + // If this is a not or neg instruction, do not count it for rank. This + // assures us that X and ~X will have the same rank. + if (!I->getType()->isInteger() || + (!BinaryOperator::isNot(I) && !BinaryOperator::isNeg(I))) + ++Rank; + + //DOUT << "Calculated Rank[" << V->getName() << "] = " + // << Rank << "\n"; + + return CachedRank = Rank; +} + +/// isReassociableOp - Return true if V is an instruction of the specified +/// opcode and if it only has one use. +static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode) { + if ((V->hasOneUse() || V->use_empty()) && isa(V) && + cast(V)->getOpcode() == Opcode) + return cast(V); + return 0; +} + +/// LowerNegateToMultiply - Replace 0-X with X*-1. +/// +static Instruction *LowerNegateToMultiply(Instruction *Neg) { + Constant *Cst = ConstantInt::getAllOnesValue(Neg->getType()); + + Instruction *Res = BinaryOperator::createMul(Neg->getOperand(1), Cst, "",Neg); + Res->takeName(Neg); + Neg->replaceAllUsesWith(Res); + Neg->eraseFromParent(); + return Res; +} + +// Given an expression of the form '(A+B)+(D+C)', turn it into '(((A+B)+C)+D)'. +// Note that if D is also part of the expression tree that we recurse to +// linearize it as well. Besides that case, this does not recurse into A,B, or +// C. +void Reassociate::LinearizeExpr(BinaryOperator *I) { + BinaryOperator *LHS = cast(I->getOperand(0)); + BinaryOperator *RHS = cast(I->getOperand(1)); + assert(isReassociableOp(LHS, I->getOpcode()) && + isReassociableOp(RHS, I->getOpcode()) && + "Not an expression that needs linearization?"); + + DOUT << "Linear" << *LHS << *RHS << *I; + + // Move the RHS instruction to live immediately before I, avoiding breaking + // dominator properties. + RHS->moveBefore(I); + + // Move operands around to do the linearization. 
+ I->setOperand(1, RHS->getOperand(0)); + RHS->setOperand(0, LHS); + I->setOperand(0, RHS); + + ++NumLinear; + MadeChange = true; + DOUT << "Linearized: " << *I; + + // If D is part of this expression tree, tail recurse. + if (isReassociableOp(I->getOperand(1), I->getOpcode())) + LinearizeExpr(I); +} + + +/// LinearizeExprTree - Given an associative binary expression tree, traverse +/// all of the uses putting it into canonical form. This forces a left-linear +/// form of the the expression (((a+b)+c)+d), and collects information about the +/// rank of the non-tree operands. +/// +/// NOTE: These intentionally destroys the expression tree operands (turning +/// them into undef values) to reduce #uses of the values. This means that the +/// caller MUST use something like RewriteExprTree to put the values back in. +/// +void Reassociate::LinearizeExprTree(BinaryOperator *I, + std::vector &Ops) { + Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); + unsigned Opcode = I->getOpcode(); + + // First step, linearize the expression if it is in ((A+B)+(C+D)) form. + BinaryOperator *LHSBO = isReassociableOp(LHS, Opcode); + BinaryOperator *RHSBO = isReassociableOp(RHS, Opcode); + + // If this is a multiply expression tree and it contains internal negations, + // transform them into multiplies by -1 so they can be reassociated. + if (I->getOpcode() == Instruction::Mul) { + if (!LHSBO && LHS->hasOneUse() && BinaryOperator::isNeg(LHS)) { + LHS = LowerNegateToMultiply(cast(LHS)); + LHSBO = isReassociableOp(LHS, Opcode); + } + if (!RHSBO && RHS->hasOneUse() && BinaryOperator::isNeg(RHS)) { + RHS = LowerNegateToMultiply(cast(RHS)); + RHSBO = isReassociableOp(RHS, Opcode); + } + } + + if (!LHSBO) { + if (!RHSBO) { + // Neither the LHS or RHS as part of the tree, thus this is a leaf. As + // such, just remember these operands and their rank. + Ops.push_back(ValueEntry(getRank(LHS), LHS)); + Ops.push_back(ValueEntry(getRank(RHS), RHS)); + + // Clear the leaves out. 
+ I->setOperand(0, UndefValue::get(I->getType())); + I->setOperand(1, UndefValue::get(I->getType())); + return; + } else { + // Turn X+(Y+Z) -> (Y+Z)+X + std::swap(LHSBO, RHSBO); + std::swap(LHS, RHS); + bool Success = !I->swapOperands(); + assert(Success && "swapOperands failed"); + MadeChange = true; + } + } else if (RHSBO) { + // Turn (A+B)+(C+D) -> (((A+B)+C)+D). This guarantees the the RHS is not + // part of the expression tree. + LinearizeExpr(I); + LHS = LHSBO = cast(I->getOperand(0)); + RHS = I->getOperand(1); + RHSBO = 0; + } + + // Okay, now we know that the LHS is a nested expression and that the RHS is + // not. Perform reassociation. + assert(!isReassociableOp(RHS, Opcode) && "LinearizeExpr failed!"); + + // Move LHS right before I to make sure that the tree expression dominates all + // values. + LHSBO->moveBefore(I); + + // Linearize the expression tree on the LHS. + LinearizeExprTree(LHSBO, Ops); + + // Remember the RHS operand and its rank. + Ops.push_back(ValueEntry(getRank(RHS), RHS)); + + // Clear the RHS leaf out. + I->setOperand(1, UndefValue::get(I->getType())); +} + +// RewriteExprTree - Now that the operands for this expression tree are +// linearized and optimized, emit them in-order. This function is written to be +// tail recursive. +void Reassociate::RewriteExprTree(BinaryOperator *I, + std::vector &Ops, + unsigned i) { + if (i+2 == Ops.size()) { + if (I->getOperand(0) != Ops[i].Op || + I->getOperand(1) != Ops[i+1].Op) { + Value *OldLHS = I->getOperand(0); + DOUT << "RA: " << *I; + I->setOperand(0, Ops[i].Op); + I->setOperand(1, Ops[i+1].Op); + DOUT << "TO: " << *I; + MadeChange = true; + ++NumChanged; + + // If we reassociated a tree to fewer operands (e.g. (1+a+2) -> (a+3) + // delete the extra, now dead, nodes. 
+ RemoveDeadBinaryOp(OldLHS); + } + return; + } + assert(i+2 < Ops.size() && "Ops index out of range!"); + + if (I->getOperand(1) != Ops[i].Op) { + DOUT << "RA: " << *I; + I->setOperand(1, Ops[i].Op); + DOUT << "TO: " << *I; + MadeChange = true; + ++NumChanged; + } + + BinaryOperator *LHS = cast(I->getOperand(0)); + assert(LHS->getOpcode() == I->getOpcode() && + "Improper expression tree!"); + + // Compactify the tree instructions together with each other to guarantee + // that the expression tree is dominated by all of Ops. + LHS->moveBefore(I); + RewriteExprTree(LHS, Ops, i+1); +} + + + +// NegateValue - Insert instructions before the instruction pointed to by BI, +// that computes the negative version of the value specified. The negative +// version of the value is returned, and BI is left pointing at the instruction +// that should be processed next by the reassociation pass. +// +static Value *NegateValue(Value *V, Instruction *BI) { + // We are trying to expose opportunity for reassociation. One of the things + // that we want to do to achieve this is to push a negation as deep into an + // expression chain as possible, to expose the add instructions. In practice, + // this means that we turn this: + // X = -(A+12+C+D) into X = -A + -12 + -C + -D = -12 + -A + -C + -D + // so that later, a: Y = 12+X could get reassociated with the -12 to eliminate + // the constants. We assume that instcombine will clean up the mess later if + // we introduce tons of unnecessary negation instructions... + // + if (Instruction *I = dyn_cast(V)) + if (I->getOpcode() == Instruction::Add && I->hasOneUse()) { + // Push the negates through the add. + I->setOperand(0, NegateValue(I->getOperand(0), BI)); + I->setOperand(1, NegateValue(I->getOperand(1), BI)); + + // We must move the add instruction here, because the neg instructions do + // not dominate the old add instruction in general. 
By moving it, we are + // assured that the neg instructions we just inserted dominate the + // instruction we are about to insert after them. + // + I->moveBefore(BI); + I->setName(I->getName()+".neg"); + return I; + } + + // Insert a 'neg' instruction that subtracts the value from zero to get the + // negation. + // + return BinaryOperator::createNeg(V, V->getName() + ".neg", BI); +} + +/// BreakUpSubtract - If we have (X-Y), and if either X is an add, or if this is +/// only used by an add, transform this into (X+(0-Y)) to promote better +/// reassociation. +static Instruction *BreakUpSubtract(Instruction *Sub) { + // Don't bother to break this up unless either the LHS is an associable add or + // if this is only used by one. + if (!isReassociableOp(Sub->getOperand(0), Instruction::Add) && + !isReassociableOp(Sub->getOperand(1), Instruction::Add) && + !(Sub->hasOneUse() &&isReassociableOp(Sub->use_back(), Instruction::Add))) + return 0; + + // Convert a subtract into an add and a neg instruction... so that sub + // instructions can be commuted with other add instructions... + // + // Calculate the negative value of Operand 1 of the sub instruction... + // and set it as the RHS of the add instruction we just made... + // + Value *NegVal = NegateValue(Sub->getOperand(1), Sub); + Instruction *New = + BinaryOperator::createAdd(Sub->getOperand(0), NegVal, "", Sub); + New->takeName(Sub); + + // Everyone now refers to the add instruction. + Sub->replaceAllUsesWith(New); + Sub->eraseFromParent(); + + DOUT << "Negated: " << *New; + return New; +} + +/// ConvertShiftToMul - If this is a shift of a reassociable multiply or is used +/// by one, change this into a multiply by a constant to assist with further +/// reassociation. +static Instruction *ConvertShiftToMul(Instruction *Shl) { + // If an operand of this shift is a reassociable multiply, or if the shift + // is used by a reassociable multiply or add, turn into a multiply. 
+ if (isReassociableOp(Shl->getOperand(0), Instruction::Mul) || + (Shl->hasOneUse() && + (isReassociableOp(Shl->use_back(), Instruction::Mul) || + isReassociableOp(Shl->use_back(), Instruction::Add)))) { + Constant *MulCst = ConstantInt::get(Shl->getType(), 1); + MulCst = ConstantExpr::getShl(MulCst, cast(Shl->getOperand(1))); + + Instruction *Mul = BinaryOperator::createMul(Shl->getOperand(0), MulCst, + "", Shl); + Mul->takeName(Shl); + Shl->replaceAllUsesWith(Mul); + Shl->eraseFromParent(); + return Mul; + } + return 0; +} + +// Scan backwards and forwards among values with the same rank as element i to +// see if X exists. If X does not exist, return i. +static unsigned FindInOperandList(std::vector &Ops, unsigned i, + Value *X) { + unsigned XRank = Ops[i].Rank; + unsigned e = Ops.size(); + for (unsigned j = i+1; j != e && Ops[j].Rank == XRank; ++j) + if (Ops[j].Op == X) + return j; + // Scan backwards + for (unsigned j = i-1; j != ~0U && Ops[j].Rank == XRank; --j) + if (Ops[j].Op == X) + return j; + return i; +} + +/// EmitAddTreeOfValues - Emit a tree of add instructions, summing Ops together +/// and returning the result. Insert the tree before I. +static Value *EmitAddTreeOfValues(Instruction *I, std::vector &Ops) { + if (Ops.size() == 1) return Ops.back(); + + Value *V1 = Ops.back(); + Ops.pop_back(); + Value *V2 = EmitAddTreeOfValues(I, Ops); + return BinaryOperator::createAdd(V2, V1, "tmp", I); +} + +/// RemoveFactorFromExpression - If V is an expression tree that is a +/// multiplication sequence, and if this sequence contains a multiply by Factor, +/// remove Factor from the tree and return the new tree. 
+Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { + BinaryOperator *BO = isReassociableOp(V, Instruction::Mul); + if (!BO) return 0; + + std::vector Factors; + LinearizeExprTree(BO, Factors); + + bool FoundFactor = false; + for (unsigned i = 0, e = Factors.size(); i != e; ++i) + if (Factors[i].Op == Factor) { + FoundFactor = true; + Factors.erase(Factors.begin()+i); + break; + } + if (!FoundFactor) { + // Make sure to restore the operands to the expression tree. + RewriteExprTree(BO, Factors); + return 0; + } + + if (Factors.size() == 1) return Factors[0].Op; + + RewriteExprTree(BO, Factors); + return BO; +} + +/// FindSingleUseMultiplyFactors - If V is a single-use multiply, recursively +/// add its operands as factors, otherwise add V to the list of factors. +static void FindSingleUseMultiplyFactors(Value *V, + std::vector &Factors) { + BinaryOperator *BO; + if ((!V->hasOneUse() && !V->use_empty()) || + !(BO = dyn_cast(V)) || + BO->getOpcode() != Instruction::Mul) { + Factors.push_back(V); + return; + } + + // Otherwise, add the LHS and RHS to the list of factors. + FindSingleUseMultiplyFactors(BO->getOperand(1), Factors); + FindSingleUseMultiplyFactors(BO->getOperand(0), Factors); +} + + + +Value *Reassociate::OptimizeExpression(BinaryOperator *I, + std::vector &Ops) { + // Now that we have the linearized expression tree, try to optimize it. + // Start by folding any constants that we found. + bool IterateOptimization = false; + if (Ops.size() == 1) return Ops[0].Op; + + unsigned Opcode = I->getOpcode(); + + if (Constant *V1 = dyn_cast(Ops[Ops.size()-2].Op)) + if (Constant *V2 = dyn_cast(Ops.back().Op)) { + Ops.pop_back(); + Ops.back().Op = ConstantExpr::get(Opcode, V1, V2); + return OptimizeExpression(I, Ops); + } + + // Check for destructive annihilation due to a constant being used. + if (ConstantInt *CstVal = dyn_cast(Ops.back().Op)) + switch (Opcode) { + default: break; + case Instruction::And: + if (CstVal->isZero()) { // ... 
& 0 -> 0 + ++NumAnnihil; + return CstVal; + } else if (CstVal->isAllOnesValue()) { // ... & -1 -> ... + Ops.pop_back(); + } + break; + case Instruction::Mul: + if (CstVal->isZero()) { // ... * 0 -> 0 + ++NumAnnihil; + return CstVal; + } else if (cast(CstVal)->isOne()) { + Ops.pop_back(); // ... * 1 -> ... + } + break; + case Instruction::Or: + if (CstVal->isAllOnesValue()) { // ... | -1 -> -1 + ++NumAnnihil; + return CstVal; + } + // FALLTHROUGH! + case Instruction::Add: + case Instruction::Xor: + if (CstVal->isZero()) // ... [|^+] 0 -> ... + Ops.pop_back(); + break; + } + if (Ops.size() == 1) return Ops[0].Op; + + // Handle destructive annihilation do to identities between elements in the + // argument list here. + switch (Opcode) { + default: break; + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + // Scan the operand lists looking for X and ~X pairs, along with X,X pairs. + // If we find any, we can simplify the expression. X&~X == 0, X|~X == -1. + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + // First, check for X and ~X in the operand list. + assert(i < Ops.size()); + if (BinaryOperator::isNot(Ops[i].Op)) { // Cannot occur for ^. + Value *X = BinaryOperator::getNotArgument(Ops[i].Op); + unsigned FoundX = FindInOperandList(Ops, i, X); + if (FoundX != i) { + if (Opcode == Instruction::And) { // ...&X&~X = 0 + ++NumAnnihil; + return Constant::getNullValue(X->getType()); + } else if (Opcode == Instruction::Or) { // ...|X|~X = -1 + ++NumAnnihil; + return ConstantInt::getAllOnesValue(X->getType()); + } + } + } + + // Next, check for duplicate pairs of values, which we assume are next to + // each other, due to our sorting criteria. + assert(i < Ops.size()); + if (i+1 != Ops.size() && Ops[i+1].Op == Ops[i].Op) { + if (Opcode == Instruction::And || Opcode == Instruction::Or) { + // Drop duplicate values. 
+ Ops.erase(Ops.begin()+i); + --i; --e; + IterateOptimization = true; + ++NumAnnihil; + } else { + assert(Opcode == Instruction::Xor); + if (e == 2) { + ++NumAnnihil; + return Constant::getNullValue(Ops[0].Op->getType()); + } + // ... X^X -> ... + Ops.erase(Ops.begin()+i, Ops.begin()+i+2); + i -= 1; e -= 2; + IterateOptimization = true; + ++NumAnnihil; + } + } + } + break; + + case Instruction::Add: + // Scan the operand lists looking for X and -X pairs. If we find any, we + // can simplify the expression. X+-X == 0. + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + assert(i < Ops.size()); + // Check for X and -X in the operand list. + if (BinaryOperator::isNeg(Ops[i].Op)) { + Value *X = BinaryOperator::getNegArgument(Ops[i].Op); + unsigned FoundX = FindInOperandList(Ops, i, X); + if (FoundX != i) { + // Remove X and -X from the operand list. + if (Ops.size() == 2) { + ++NumAnnihil; + return Constant::getNullValue(X->getType()); + } else { + Ops.erase(Ops.begin()+i); + if (i < FoundX) + --FoundX; + else + --i; // Need to back up an extra one. + Ops.erase(Ops.begin()+FoundX); + IterateOptimization = true; + ++NumAnnihil; + --i; // Revisit element. + e -= 2; // Removed two elements. + } + } + } + } + + + // Scan the operand list, checking to see if there are any common factors + // between operands. Consider something like A*A+A*B*C+D. We would like to + // reassociate this to A*(A+B*C)+D, which reduces the number of multiplies. + // To efficiently find this, we count the number of times a factor occurs + // for any ADD operands that are MULs. + std::map FactorOccurrences; + unsigned MaxOcc = 0; + Value *MaxOccVal = 0; + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + if (BinaryOperator *BOp = dyn_cast(Ops[i].Op)) { + if (BOp->getOpcode() == Instruction::Mul && BOp->use_empty()) { + // Compute all of the factors of this added value. 
+ std::vector Factors; + FindSingleUseMultiplyFactors(BOp, Factors); + assert(Factors.size() > 1 && "Bad linearize!"); + + // Add one to FactorOccurrences for each unique factor in this op. + if (Factors.size() == 2) { + unsigned Occ = ++FactorOccurrences[Factors[0]]; + if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[0]; } + if (Factors[0] != Factors[1]) { // Don't double count A*A. + Occ = ++FactorOccurrences[Factors[1]]; + if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[1]; } + } + } else { + std::set Duplicates; + for (unsigned i = 0, e = Factors.size(); i != e; ++i) { + if (Duplicates.insert(Factors[i]).second) { + unsigned Occ = ++FactorOccurrences[Factors[i]]; + if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[i]; } + } + } + } + } + } + } + + // If any factor occurred more than one time, we can pull it out. + if (MaxOcc > 1) { + DOUT << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal << "\n"; + + // Create a new instruction that uses the MaxOccVal twice. If we don't do + // this, we could otherwise run into situations where removing a factor + // from an expression will drop a use of maxocc, and this can cause + // RemoveFactorFromExpression on successive values to behave differently. + Instruction *DummyInst = BinaryOperator::createAdd(MaxOccVal, MaxOccVal); + std::vector NewMulOps; + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + if (Value *V = RemoveFactorFromExpression(Ops[i].Op, MaxOccVal)) { + NewMulOps.push_back(V); + Ops.erase(Ops.begin()+i); + --i; --e; + } + } + + // No need for extra uses anymore. + delete DummyInst; + + unsigned NumAddedValues = NewMulOps.size(); + Value *V = EmitAddTreeOfValues(I, NewMulOps); + Value *V2 = BinaryOperator::createMul(V, MaxOccVal, "tmp", I); + + // Now that we have inserted V and its sole use, optimize it. 
This allows + // us to handle cases that require multiple factoring steps, such as this: + // A*A*B + A*A*C --> A*(A*B+A*C) --> A*(A*(B+C)) + if (NumAddedValues > 1) + ReassociateExpression(cast(V)); + + ++NumFactor; + + if (Ops.size() == 0) + return V2; + + // Add the new value to the list of things being added. + Ops.insert(Ops.begin(), ValueEntry(getRank(V2), V2)); + + // Rewrite the tree so that there is now a use of V. + RewriteExprTree(I, Ops); + return OptimizeExpression(I, Ops); + } + break; + //case Instruction::Mul: + } + + if (IterateOptimization) + return OptimizeExpression(I, Ops); + return 0; +} + + +/// ReassociateBB - Inspect all of the instructions in this basic block, +/// reassociating them as we go. +void Reassociate::ReassociateBB(BasicBlock *BB) { + for (BasicBlock::iterator BBI = BB->begin(); BBI != BB->end(); ) { + Instruction *BI = BBI++; + if (BI->getOpcode() == Instruction::Shl && + isa(BI->getOperand(1))) + if (Instruction *NI = ConvertShiftToMul(BI)) { + MadeChange = true; + BI = NI; + } + + // Reject cases where it is pointless to do this. + if (!isa(BI) || BI->getType()->isFloatingPoint() || + isa(BI->getType())) + continue; // Floating point ops are not associative. + + // If this is a subtract instruction which is not already in negate form, + // see if we can convert it to X+-Y. + if (BI->getOpcode() == Instruction::Sub) { + if (!BinaryOperator::isNeg(BI)) { + if (Instruction *NI = BreakUpSubtract(BI)) { + MadeChange = true; + BI = NI; + } + } else { + // Otherwise, this is a negation. See if the operand is a multiply tree + // and if this is not an inner node of a multiply tree. + if (isReassociableOp(BI->getOperand(1), Instruction::Mul) && + (!BI->hasOneUse() || + !isReassociableOp(BI->use_back(), Instruction::Mul))) { + BI = LowerNegateToMultiply(BI); + MadeChange = true; + } + } + } + + // If this instruction is a commutative binary operator, process it. 
+ if (!BI->isAssociative()) continue; + BinaryOperator *I = cast(BI); + + // If this is an interior node of a reassociable tree, ignore it until we + // get to the root of the tree, to avoid N^2 analysis. + if (I->hasOneUse() && isReassociableOp(I->use_back(), I->getOpcode())) + continue; + + // If this is an add tree that is used by a sub instruction, ignore it + // until we process the subtract. + if (I->hasOneUse() && I->getOpcode() == Instruction::Add && + cast(I->use_back())->getOpcode() == Instruction::Sub) + continue; + + ReassociateExpression(I); + } +} + +void Reassociate::ReassociateExpression(BinaryOperator *I) { + + // First, walk the expression tree, linearizing the tree, collecting + std::vector Ops; + LinearizeExprTree(I, Ops); + + DOUT << "RAIn:\t"; DEBUG(PrintOps(I, Ops)); DOUT << "\n"; + + // Now that we have linearized the tree to a list and have gathered all of + // the operands and their ranks, sort the operands by their rank. Use a + // stable_sort so that values with equal ranks will have their relative + // positions maintained (and so the compiler is deterministic). Note that + // this sorts so that the highest ranking values end up at the beginning of + // the vector. + std::stable_sort(Ops.begin(), Ops.end()); + + // OptimizeExpression - Now that we have the expression tree in a convenient + // sorted form, optimize it globally if possible. + if (Value *V = OptimizeExpression(I, Ops)) { + // This expression tree simplified to something that isn't a tree, + // eliminate it. + DOUT << "Reassoc to scalar: " << *V << "\n"; + I->replaceAllUsesWith(V); + RemoveDeadBinaryOp(I); + return; + } + + // We want to sink immediates as deeply as possible except in the case where + // this is a multiply tree used only by an add, and the immediate is a -1. 
+ // In this case we reassociate to put the negation on the outside so that we + // can fold the negation into the add: (-X)*Y + Z -> Z-X*Y + if (I->getOpcode() == Instruction::Mul && I->hasOneUse() && + cast(I->use_back())->getOpcode() == Instruction::Add && + isa(Ops.back().Op) && + cast(Ops.back().Op)->isAllOnesValue()) { + Ops.insert(Ops.begin(), Ops.back()); + Ops.pop_back(); + } + + DOUT << "RAOut:\t"; DEBUG(PrintOps(I, Ops)); DOUT << "\n"; + + if (Ops.size() == 1) { + // This expression tree simplified to something that isn't a tree, + // eliminate it. + I->replaceAllUsesWith(Ops[0].Op); + RemoveDeadBinaryOp(I); + } else { + // Now that we ordered and optimized the expressions, splat them back into + // the expression tree, removing any unneeded nodes. + RewriteExprTree(I, Ops); + } +} + + +bool Reassociate::runOnFunction(Function &F) { + // Recalculate the rank map for F + BuildRankMap(F); + + MadeChange = false; + for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE; ++FI) + ReassociateBB(FI); + + // We are done with the rank map... + RankMap.clear(); + ValueRankMap.clear(); + return MadeChange; +} + diff --git a/lib/Transforms/Scalar/Reg2Mem.cpp b/lib/Transforms/Scalar/Reg2Mem.cpp new file mode 100644 index 0000000..ef7411a --- /dev/null +++ b/lib/Transforms/Scalar/Reg2Mem.cpp @@ -0,0 +1,91 @@ +//===- Reg2Mem.cpp - Convert registers to allocas -------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file demotes all registers to memory references. It is intented to be +// the inverse of PromoteMemoryToRegister. By converting to loads, the only +// values live accross basic blocks are allocas and loads before phi nodes. 
+// It is intended that this should make CFG hacking much easier. +// To make later hacking easier, the entry block is split into two, such that +// all introduced allocas and nothing else are in the entry block. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "reg2mem" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Pass.h" +#include "llvm/Function.h" +#include "llvm/Module.h" +#include "llvm/BasicBlock.h" +#include "llvm/Instructions.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Compiler.h" +#include +using namespace llvm; + +STATISTIC(NumDemoted, "Number of registers demoted"); + +namespace { + struct VISIBILITY_HIDDEN RegToMem : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + RegToMem() : FunctionPass((intptr_t)&ID) {} + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredID(BreakCriticalEdgesID); + AU.addPreservedID(BreakCriticalEdgesID); + } + + bool valueEscapes(Instruction* i) { + BasicBlock* bb = i->getParent(); + for(Value::use_iterator ii = i->use_begin(), ie = i->use_end(); + ii != ie; ++ii) + if (cast(*ii)->getParent() != bb || + isa(*ii)) + return true; + return false; + } + + virtual bool runOnFunction(Function &F) { + if (!F.isDeclaration()) { + //give us a clean block + BasicBlock* bbold = &F.getEntryBlock(); + BasicBlock* bbnew = new BasicBlock("allocablock", &F, + &F.getEntryBlock()); + new BranchInst(bbold, bbnew); + + //find the instructions + std::list worklist; + for (Function::iterator ibb = F.begin(), ibe = F.end(); + ibb != ibe; ++ibb) + for (BasicBlock::iterator iib = ibb->begin(), iie = ibb->end(); + iib != iie; ++iib) { + if(valueEscapes(iib)) + worklist.push_front(&*iib); + } + //demote escaped instructions + NumDemoted += worklist.size(); + for (std::list::iterator ilb = worklist.begin(), + ile = worklist.end(); ilb != ile; ++ilb) + 
DemoteRegToStack(**ilb, false); + return true; + } + return false; + } + }; + + char RegToMem::ID = 0; + RegisterPass X("reg2mem", "Demote all values to stack slots"); +} + +// createDemoteRegisterToMemory - Provide an entry point to create this pass. +// +const PassInfo *llvm::DemoteRegisterToMemoryID = X.getPassInfo(); +FunctionPass *llvm::createDemoteRegisterToMemoryPass() { + return new RegToMem(); +} diff --git a/lib/Transforms/Scalar/SCCP.cpp b/lib/Transforms/Scalar/SCCP.cpp new file mode 100644 index 0000000..0e4fe8f --- /dev/null +++ b/lib/Transforms/Scalar/SCCP.cpp @@ -0,0 +1,1691 @@ +//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements sparse conditional constant propagation and merging: +// +// Specifically, this: +// * Assumes values are constant unless proven otherwise +// * Assumes BasicBlocks are dead unless proven otherwise +// * Proves values to be constant, and replaces them with constants +// * Proves conditional branches to be unconditional +// +// Notice that: +// * This pass has a habit of making definitions be dead. It is a good idea +// to to run a DCE pass sometime after running this pass. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "sccp" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/CallSite.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/InstVisitor.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include +using namespace llvm; + +STATISTIC(NumInstRemoved, "Number of instructions removed"); +STATISTIC(NumDeadBlocks , "Number of basic blocks unreachable"); + +STATISTIC(IPNumInstRemoved, "Number ofinstructions removed by IPSCCP"); +STATISTIC(IPNumDeadBlocks , "Number of basic blocks unreachable by IPSCCP"); +STATISTIC(IPNumArgsElimed ,"Number of arguments constant propagated by IPSCCP"); +STATISTIC(IPNumGlobalConst, "Number of globals found to be constant by IPSCCP"); + +namespace { +/// LatticeVal class - This class represents the different lattice values that +/// an LLVM value may occupy. It is a simple class with value semantics. +/// +class VISIBILITY_HIDDEN LatticeVal { + enum { + /// undefined - This LLVM Value has no known value yet. + undefined, + + /// constant - This LLVM Value has a specific constant value. + constant, + + /// forcedconstant - This LLVM Value was thought to be undef until + /// ResolvedUndefsIn. This is treated just like 'constant', but if merged + /// with another (different) constant, it goes to overdefined, instead of + /// asserting. + forcedconstant, + + /// overdefined - This instruction is not known to be constant, and we know + /// it has a value. 
+ overdefined + } LatticeValue; // The current lattice position + + Constant *ConstantVal; // If Constant value, the current value +public: + inline LatticeVal() : LatticeValue(undefined), ConstantVal(0) {} + + // markOverdefined - Return true if this is a new status to be in... + inline bool markOverdefined() { + if (LatticeValue != overdefined) { + LatticeValue = overdefined; + return true; + } + return false; + } + + // markConstant - Return true if this is a new status for us. + inline bool markConstant(Constant *V) { + if (LatticeValue != constant) { + if (LatticeValue == undefined) { + LatticeValue = constant; + assert(V && "Marking constant with NULL"); + ConstantVal = V; + } else { + assert(LatticeValue == forcedconstant && + "Cannot move from overdefined to constant!"); + // Stay at forcedconstant if the constant is the same. + if (V == ConstantVal) return false; + + // Otherwise, we go to overdefined. Assumptions made based on the + // forced value are possibly wrong. Assuming this is another constant + // could expose a contradiction. 
+ LatticeValue = overdefined; + } + return true; + } else { + assert(ConstantVal == V && "Marking constant with different value"); + } + return false; + } + + inline void markForcedConstant(Constant *V) { + assert(LatticeValue == undefined && "Can't force a defined value!"); + LatticeValue = forcedconstant; + ConstantVal = V; + } + + inline bool isUndefined() const { return LatticeValue == undefined; } + inline bool isConstant() const { + return LatticeValue == constant || LatticeValue == forcedconstant; + } + inline bool isOverdefined() const { return LatticeValue == overdefined; } + + inline Constant *getConstant() const { + assert(isConstant() && "Cannot get the constant of a non-constant!"); + return ConstantVal; + } +}; + +} // end anonymous namespace + + +//===----------------------------------------------------------------------===// +// +/// SCCPSolver - This class is a general purpose solver for Sparse Conditional +/// Constant Propagation. +/// +class SCCPSolver : public InstVisitor { + SmallSet BBExecutable;// The basic blocks that are executable + std::map ValueState; // The state each value is in. + + /// GlobalValue - If we are tracking any values for the contents of a global + /// variable, we keep a mapping from the constant accessor to the element of + /// the global, to the currently known value. If the value becomes + /// overdefined, it's entry is simply removed from this map. + DenseMap TrackedGlobals; + + /// TrackedFunctionRetVals - If we are tracking arguments into and the return + /// value out of a function, it will have an entry in this map, indicating + /// what the known return value for the function is. + DenseMap TrackedFunctionRetVals; + + // The reason for two worklists is that overdefined is the lowest state + // on the lattice, and moving things to overdefined as fast as possible + // makes SCCP converge much faster. 
+ // By having a separate worklist, we accomplish this because everything + // possibly overdefined will become overdefined at the soonest possible + // point. + std::vector OverdefinedInstWorkList; + std::vector InstWorkList; + + + std::vector BBWorkList; // The BasicBlock work list + + /// UsersOfOverdefinedPHIs - Keep track of any users of PHI nodes that are not + /// overdefined, despite the fact that the PHI node is overdefined. + std::multimap UsersOfOverdefinedPHIs; + + /// KnownFeasibleEdges - Entries in this set are edges which have already had + /// PHI nodes retriggered. + typedef std::pair Edge; + std::set KnownFeasibleEdges; +public: + + /// MarkBlockExecutable - This method can be used by clients to mark all of + /// the blocks that are known to be intrinsically live in the processed unit. + void MarkBlockExecutable(BasicBlock *BB) { + DOUT << "Marking Block Executable: " << BB->getName() << "\n"; + BBExecutable.insert(BB); // Basic block is executable! + BBWorkList.push_back(BB); // Add the block to the work list! + } + + /// TrackValueOfGlobalVariable - Clients can use this method to + /// inform the SCCPSolver that it should track loads and stores to the + /// specified global variable if it can. This is only legal to call if + /// performing Interprocedural SCCP. + void TrackValueOfGlobalVariable(GlobalVariable *GV) { + const Type *ElTy = GV->getType()->getElementType(); + if (ElTy->isFirstClassType()) { + LatticeVal &IV = TrackedGlobals[GV]; + if (!isa(GV->getInitializer())) + IV.markConstant(GV->getInitializer()); + } + } + + /// AddTrackedFunction - If the SCCP solver is supposed to track calls into + /// and out of the specified function (which cannot have its address taken), + /// this method must be called. + void AddTrackedFunction(Function *F) { + assert(F->hasInternalLinkage() && "Can only track internal functions!"); + // Add an entry, F -> undef. + TrackedFunctionRetVals[F]; + } + + /// Solve - Solve for constants and executable blocks. 
+ /// + void Solve(); + + /// ResolvedUndefsIn - While solving the dataflow for a function, we assume + /// that branches on undef values cannot reach any of their successors. + /// However, this is not a safe assumption. After we solve dataflow, this + /// method should be use to handle this. If this returns true, the solver + /// should be rerun. + bool ResolvedUndefsIn(Function &F); + + /// getExecutableBlocks - Once we have solved for constants, return the set of + /// blocks that is known to be executable. + SmallSet &getExecutableBlocks() { + return BBExecutable; + } + + /// getValueMapping - Once we have solved for constants, return the mapping of + /// LLVM values to LatticeVals. + std::map &getValueMapping() { + return ValueState; + } + + /// getTrackedFunctionRetVals - Get the inferred return value map. + /// + const DenseMap &getTrackedFunctionRetVals() { + return TrackedFunctionRetVals; + } + + /// getTrackedGlobals - Get and return the set of inferred initializers for + /// global variables. + const DenseMap &getTrackedGlobals() { + return TrackedGlobals; + } + + inline void markOverdefined(Value *V) { + markOverdefined(ValueState[V], V); + } + +private: + // markConstant - Make a value be marked as "constant". If the value + // is not already a constant, add it to the instruction work list so that + // the users of the instruction are updated later. + // + inline void markConstant(LatticeVal &IV, Value *V, Constant *C) { + if (IV.markConstant(C)) { + DOUT << "markConstant: " << *C << ": " << *V; + InstWorkList.push_back(V); + } + } + + inline void markForcedConstant(LatticeVal &IV, Value *V, Constant *C) { + IV.markForcedConstant(C); + DOUT << "markForcedConstant: " << *C << ": " << *V; + InstWorkList.push_back(V); + } + + inline void markConstant(Value *V, Constant *C) { + markConstant(ValueState[V], V, C); + } + + // markOverdefined - Make a value be marked as "overdefined". 
If the + // value is not already overdefined, add it to the overdefined instruction + // work list so that the users of the instruction are updated later. + + inline void markOverdefined(LatticeVal &IV, Value *V) { + if (IV.markOverdefined()) { + DEBUG(DOUT << "markOverdefined: "; + if (Function *F = dyn_cast(V)) + DOUT << "Function '" << F->getName() << "'\n"; + else + DOUT << *V); + // Only instructions go on the work list + OverdefinedInstWorkList.push_back(V); + } + } + + inline void mergeInValue(LatticeVal &IV, Value *V, LatticeVal &MergeWithV) { + if (IV.isOverdefined() || MergeWithV.isUndefined()) + return; // Noop. + if (MergeWithV.isOverdefined()) + markOverdefined(IV, V); + else if (IV.isUndefined()) + markConstant(IV, V, MergeWithV.getConstant()); + else if (IV.getConstant() != MergeWithV.getConstant()) + markOverdefined(IV, V); + } + + inline void mergeInValue(Value *V, LatticeVal &MergeWithV) { + return mergeInValue(ValueState[V], V, MergeWithV); + } + + + // getValueState - Return the LatticeVal object that corresponds to the value. + // This function is necessary because not all values should start out in the + // underdefined state... Argument's should be overdefined, and + // constants should be marked as constants. If a value is not known to be an + // Instruction object, then use this accessor to get its value from the map. + // + inline LatticeVal &getValueState(Value *V) { + std::map::iterator I = ValueState.find(V); + if (I != ValueState.end()) return I->second; // Common case, in the map + + if (Constant *C = dyn_cast(V)) { + if (isa(V)) { + // Nothing to do, remain undefined. + } else { + LatticeVal &LV = ValueState[C]; + LV.markConstant(C); // Constants are constant + return LV; + } + } + // All others are underdefined by default... + return ValueState[V]; + } + + // markEdgeExecutable - Mark a basic block as executable, adding it to the BB + // work list if it is not already executable... 
+ // + void markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest) { + if (!KnownFeasibleEdges.insert(Edge(Source, Dest)).second) + return; // This edge is already known to be executable! + + if (BBExecutable.count(Dest)) { + DOUT << "Marking Edge Executable: " << Source->getName() + << " -> " << Dest->getName() << "\n"; + + // The destination is already executable, but we just made an edge + // feasible that wasn't before. Revisit the PHI nodes in the block + // because they have potentially new operands. + for (BasicBlock::iterator I = Dest->begin(); isa(I); ++I) + visitPHINode(*cast(I)); + + } else { + MarkBlockExecutable(Dest); + } + } + + // getFeasibleSuccessors - Return a vector of booleans to indicate which + // successors are reachable from a given terminator instruction. + // + void getFeasibleSuccessors(TerminatorInst &TI, SmallVector &Succs); + + // isEdgeFeasible - Return true if the control flow edge from the 'From' basic + // block to the 'To' basic block is currently feasible... + // + bool isEdgeFeasible(BasicBlock *From, BasicBlock *To); + + // OperandChangedState - This method is invoked on all of the users of an + // instruction that was just changed state somehow.... Based on this + // information, we need to update the specified user of this instruction. + // + void OperandChangedState(User *U) { + // Only instructions use other variable values! + Instruction &I = cast(*U); + if (BBExecutable.count(I.getParent())) // Inst is executable? + visit(I); + } + +private: + friend class InstVisitor; + + // visit implementations - Something changed in this instruction... Either an + // operand made a transition, or the instruction is newly executable. Change + // the value type of I to reflect these changes if appropriate. 
+ // + void visitPHINode(PHINode &I); + + // Terminators + void visitReturnInst(ReturnInst &I); + void visitTerminatorInst(TerminatorInst &TI); + + void visitCastInst(CastInst &I); + void visitSelectInst(SelectInst &I); + void visitBinaryOperator(Instruction &I); + void visitCmpInst(CmpInst &I); + void visitExtractElementInst(ExtractElementInst &I); + void visitInsertElementInst(InsertElementInst &I); + void visitShuffleVectorInst(ShuffleVectorInst &I); + + // Instructions that cannot be folded away... + void visitStoreInst (Instruction &I); + void visitLoadInst (LoadInst &I); + void visitGetElementPtrInst(GetElementPtrInst &I); + void visitCallInst (CallInst &I) { visitCallSite(CallSite::get(&I)); } + void visitInvokeInst (InvokeInst &II) { + visitCallSite(CallSite::get(&II)); + visitTerminatorInst(II); + } + void visitCallSite (CallSite CS); + void visitUnwindInst (TerminatorInst &I) { /*returns void*/ } + void visitUnreachableInst(TerminatorInst &I) { /*returns void*/ } + void visitAllocationInst(Instruction &I) { markOverdefined(&I); } + void visitVANextInst (Instruction &I) { markOverdefined(&I); } + void visitVAArgInst (Instruction &I) { markOverdefined(&I); } + void visitFreeInst (Instruction &I) { /*returns void*/ } + + void visitInstruction(Instruction &I) { + // If a new instruction is added to LLVM that we don't handle... + cerr << "SCCP: Don't know how to handle: " << I; + markOverdefined(&I); // Just in case + } +}; + +// getFeasibleSuccessors - Return a vector of booleans to indicate which +// successors are reachable from a given terminator instruction. 
+// +void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, + SmallVector &Succs) { + Succs.resize(TI.getNumSuccessors()); + if (BranchInst *BI = dyn_cast(&TI)) { + if (BI->isUnconditional()) { + Succs[0] = true; + } else { + LatticeVal &BCValue = getValueState(BI->getCondition()); + if (BCValue.isOverdefined() || + (BCValue.isConstant() && !isa(BCValue.getConstant()))) { + // Overdefined condition variables, and branches on unfoldable constant + // conditions, mean the branch could go either way. + Succs[0] = Succs[1] = true; + } else if (BCValue.isConstant()) { + // Constant condition variables mean the branch can only go a single way + Succs[BCValue.getConstant() == ConstantInt::getFalse()] = true; + } + } + } else if (isa(&TI)) { + // Invoke instructions successors are always executable. + Succs[0] = Succs[1] = true; + } else if (SwitchInst *SI = dyn_cast(&TI)) { + LatticeVal &SCValue = getValueState(SI->getCondition()); + if (SCValue.isOverdefined() || // Overdefined condition? + (SCValue.isConstant() && !isa(SCValue.getConstant()))) { + // All destinations are executable! + Succs.assign(TI.getNumSuccessors(), true); + } else if (SCValue.isConstant()) { + Constant *CPV = SCValue.getConstant(); + // Make sure to skip the "default value" which isn't a value + for (unsigned i = 1, E = SI->getNumSuccessors(); i != E; ++i) { + if (SI->getSuccessorValue(i) == CPV) {// Found the right branch... + Succs[i] = true; + return; + } + } + + // Constant value not equal to any of the branches... must execute + // default branch then... + Succs[0] = true; + } + } else { + assert(0 && "SCCP: Don't know how to handle this terminator!"); + } +} + + +// isEdgeFeasible - Return true if the control flow edge from the 'From' basic +// block to the 'To' basic block is currently feasible... 
+// +bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) { + assert(BBExecutable.count(To) && "Dest should always be alive!"); + + // Make sure the source basic block is executable!! + if (!BBExecutable.count(From)) return false; + + // Check to make sure this edge itself is actually feasible now... + TerminatorInst *TI = From->getTerminator(); + if (BranchInst *BI = dyn_cast(TI)) { + if (BI->isUnconditional()) + return true; + else { + LatticeVal &BCValue = getValueState(BI->getCondition()); + if (BCValue.isOverdefined()) { + // Overdefined condition variables mean the branch could go either way. + return true; + } else if (BCValue.isConstant()) { + // Not branching on an evaluatable constant? + if (!isa(BCValue.getConstant())) return true; + + // Constant condition variables mean the branch can only go a single way + return BI->getSuccessor(BCValue.getConstant() == + ConstantInt::getFalse()) == To; + } + return false; + } + } else if (isa(TI)) { + // Invoke instructions successors are always executable. + return true; + } else if (SwitchInst *SI = dyn_cast(TI)) { + LatticeVal &SCValue = getValueState(SI->getCondition()); + if (SCValue.isOverdefined()) { // Overdefined condition? + // All destinations are executable! + return true; + } else if (SCValue.isConstant()) { + Constant *CPV = SCValue.getConstant(); + if (!isa(CPV)) + return true; // not a foldable constant? + + // Make sure to skip the "default value" which isn't a value + for (unsigned i = 1, E = SI->getNumSuccessors(); i != E; ++i) + if (SI->getSuccessorValue(i) == CPV) // Found the taken branch... + return SI->getSuccessor(i) == To; + + // Constant value not equal to any of the branches... must execute + // default branch then... + return SI->getDefaultDest() == To; + } + return false; + } else { + cerr << "Unknown terminator instruction: " << *TI; + abort(); + } +} + +// visit Implementations - Something changed in this instruction... 
Either an +// operand made a transition, or the instruction is newly executable. Change +// the value type of I to reflect these changes if appropriate. This method +// makes sure to do the following actions: +// +// 1. If a phi node merges two constants in, and has conflicting value coming +// from different branches, or if the PHI node merges in an overdefined +// value, then the PHI node becomes overdefined. +// 2. If a phi node merges only constants in, and they all agree on value, the +// PHI node becomes a constant value equal to that. +// 3. If V <- x (op) y && isConstant(x) && isConstant(y) V = Constant +// 4. If V <- x (op) y && (isOverdefined(x) || isOverdefined(y)) V = Overdefined +// 5. If V <- MEM or V <- CALL or V <- (unknown) then V = Overdefined +// 6. If a conditional branch has a value that is constant, make the selected +// destination executable +// 7. If a conditional branch has a value that is overdefined, make all +// successors executable. +// +void SCCPSolver::visitPHINode(PHINode &PN) { + LatticeVal &PNIV = getValueState(&PN); + if (PNIV.isOverdefined()) { + // There may be instructions using this PHI node that are not overdefined + // themselves. If so, make sure that they know that the PHI node operand + // changed. + std::multimap::iterator I, E; + tie(I, E) = UsersOfOverdefinedPHIs.equal_range(&PN); + if (I != E) { + SmallVector Users; + for (; I != E; ++I) Users.push_back(I->second); + while (!Users.empty()) { + visit(Users.back()); + Users.pop_back(); + } + } + return; // Quick exit + } + + // Super-extra-high-degree PHI nodes are unlikely to ever be marked constant, + // and slow us down a lot. Just mark them overdefined. + if (PN.getNumIncomingValues() > 64) { + markOverdefined(PNIV, &PN); + return; + } + + // Look at all of the executable operands of the PHI node. If any of them + // are overdefined, the PHI becomes overdefined as well. 
If they are all + // constant, and they agree with each other, the PHI becomes the identical + // constant. If they are constant and don't agree, the PHI is overdefined. + // If there are no executable operands, the PHI remains undefined. + // + Constant *OperandVal = 0; + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { + LatticeVal &IV = getValueState(PN.getIncomingValue(i)); + if (IV.isUndefined()) continue; // Doesn't influence PHI node. + + if (isEdgeFeasible(PN.getIncomingBlock(i), PN.getParent())) { + if (IV.isOverdefined()) { // PHI node becomes overdefined! + markOverdefined(PNIV, &PN); + return; + } + + if (OperandVal == 0) { // Grab the first value... + OperandVal = IV.getConstant(); + } else { // Another value is being merged in! + // There is already a reachable operand. If we conflict with it, + // then the PHI node becomes overdefined. If we agree with it, we + // can continue on. + + // Check to see if there are two different constants merging... + if (IV.getConstant() != OperandVal) { + // Yes there is. This means the PHI node is not constant. + // You must be overdefined poor PHI. + // + markOverdefined(PNIV, &PN); // The PHI node now becomes overdefined + return; // I'm done analyzing you + } + } + } + } + + // If we exited the loop, this means that the PHI node only has constant + // arguments that agree with each other(and OperandVal is the constant) or + // OperandVal is null because there are no defined incoming arguments. If + // this is the case, the PHI remains undefined. + // + if (OperandVal) + markConstant(PNIV, &PN, OperandVal); // Acquire operand value +} + +void SCCPSolver::visitReturnInst(ReturnInst &I) { + if (I.getNumOperands() == 0) return; // Ret void + + // If we are tracking the return value of this function, merge it in. 
+ Function *F = I.getParent()->getParent(); + if (F->hasInternalLinkage() && !TrackedFunctionRetVals.empty()) { + DenseMap::iterator TFRVI = + TrackedFunctionRetVals.find(F); + if (TFRVI != TrackedFunctionRetVals.end() && + !TFRVI->second.isOverdefined()) { + LatticeVal &IV = getValueState(I.getOperand(0)); + mergeInValue(TFRVI->second, F, IV); + } + } +} + + +void SCCPSolver::visitTerminatorInst(TerminatorInst &TI) { + SmallVector SuccFeasible; + getFeasibleSuccessors(TI, SuccFeasible); + + BasicBlock *BB = TI.getParent(); + + // Mark all feasible successors executable... + for (unsigned i = 0, e = SuccFeasible.size(); i != e; ++i) + if (SuccFeasible[i]) + markEdgeExecutable(BB, TI.getSuccessor(i)); +} + +void SCCPSolver::visitCastInst(CastInst &I) { + Value *V = I.getOperand(0); + LatticeVal &VState = getValueState(V); + if (VState.isOverdefined()) // Inherit overdefinedness of operand + markOverdefined(&I); + else if (VState.isConstant()) // Propagate constant value + markConstant(&I, ConstantExpr::getCast(I.getOpcode(), + VState.getConstant(), I.getType())); +} + +void SCCPSolver::visitSelectInst(SelectInst &I) { + LatticeVal &CondValue = getValueState(I.getCondition()); + if (CondValue.isUndefined()) + return; + if (CondValue.isConstant()) { + if (ConstantInt *CondCB = dyn_cast(CondValue.getConstant())){ + mergeInValue(&I, getValueState(CondCB->getZExtValue() ? I.getTrueValue() + : I.getFalseValue())); + return; + } + } + + // Otherwise, the condition is overdefined or a constant we can't evaluate. + // See if we can produce something better than overdefined based on the T/F + // value. + LatticeVal &TVal = getValueState(I.getTrueValue()); + LatticeVal &FVal = getValueState(I.getFalseValue()); + + // select ?, C, C -> C. + if (TVal.isConstant() && FVal.isConstant() && + TVal.getConstant() == FVal.getConstant()) { + markConstant(&I, FVal.getConstant()); + return; + } + + if (TVal.isUndefined()) { // select ?, undef, X -> X. 
+ mergeInValue(&I, FVal); + } else if (FVal.isUndefined()) { // select ?, X, undef -> X. + mergeInValue(&I, TVal); + } else { + markOverdefined(&I); + } +} + +// Handle BinaryOperators and Shift Instructions... +void SCCPSolver::visitBinaryOperator(Instruction &I) { + LatticeVal &IV = ValueState[&I]; + if (IV.isOverdefined()) return; + + LatticeVal &V1State = getValueState(I.getOperand(0)); + LatticeVal &V2State = getValueState(I.getOperand(1)); + + if (V1State.isOverdefined() || V2State.isOverdefined()) { + // If this is an AND or OR with 0 or -1, it doesn't matter that the other + // operand is overdefined. + if (I.getOpcode() == Instruction::And || I.getOpcode() == Instruction::Or) { + LatticeVal *NonOverdefVal = 0; + if (!V1State.isOverdefined()) { + NonOverdefVal = &V1State; + } else if (!V2State.isOverdefined()) { + NonOverdefVal = &V2State; + } + + if (NonOverdefVal) { + if (NonOverdefVal->isUndefined()) { + // Could annihilate value. + if (I.getOpcode() == Instruction::And) + markConstant(IV, &I, Constant::getNullValue(I.getType())); + else if (const VectorType *PT = dyn_cast(I.getType())) + markConstant(IV, &I, ConstantVector::getAllOnesValue(PT)); + else + markConstant(IV, &I, ConstantInt::getAllOnesValue(I.getType())); + return; + } else { + if (I.getOpcode() == Instruction::And) { + if (NonOverdefVal->getConstant()->isNullValue()) { + markConstant(IV, &I, NonOverdefVal->getConstant()); + return; // X and 0 = 0 + } + } else { + if (ConstantInt *CI = + dyn_cast(NonOverdefVal->getConstant())) + if (CI->isAllOnesValue()) { + markConstant(IV, &I, NonOverdefVal->getConstant()); + return; // X or -1 = -1 + } + } + } + } + } + + + // If both operands are PHI nodes, it is possible that this instruction has + // a constant value, despite the fact that the PHI node doesn't. Check for + // this condition now. 
+ if (PHINode *PN1 = dyn_cast(I.getOperand(0))) + if (PHINode *PN2 = dyn_cast(I.getOperand(1))) + if (PN1->getParent() == PN2->getParent()) { + // Since the two PHI nodes are in the same basic block, they must have + // entries for the same predecessors. Walk the predecessor list, and + // if all of the incoming values are constants, and the result of + // evaluating this expression with all incoming value pairs is the + // same, then this expression is a constant even though the PHI node + // is not a constant! + LatticeVal Result; + for (unsigned i = 0, e = PN1->getNumIncomingValues(); i != e; ++i) { + LatticeVal &In1 = getValueState(PN1->getIncomingValue(i)); + BasicBlock *InBlock = PN1->getIncomingBlock(i); + LatticeVal &In2 = + getValueState(PN2->getIncomingValueForBlock(InBlock)); + + if (In1.isOverdefined() || In2.isOverdefined()) { + Result.markOverdefined(); + break; // Cannot fold this operation over the PHI nodes! + } else if (In1.isConstant() && In2.isConstant()) { + Constant *V = ConstantExpr::get(I.getOpcode(), In1.getConstant(), + In2.getConstant()); + if (Result.isUndefined()) + Result.markConstant(V); + else if (Result.isConstant() && Result.getConstant() != V) { + Result.markOverdefined(); + break; + } + } + } + + // If we found a constant value here, then we know the instruction is + // constant despite the fact that the PHI nodes are overdefined. + if (Result.isConstant()) { + markConstant(IV, &I, Result.getConstant()); + // Remember that this instruction is virtually using the PHI node + // operands. + UsersOfOverdefinedPHIs.insert(std::make_pair(PN1, &I)); + UsersOfOverdefinedPHIs.insert(std::make_pair(PN2, &I)); + return; + } else if (Result.isUndefined()) { + return; + } + + // Okay, this really is overdefined now. 
Since we might have + // speculatively thought that this was not overdefined before, and + // added ourselves to the UsersOfOverdefinedPHIs list for the PHIs, + // make sure to clean out any entries that we put there, for + // efficiency. + std::multimap::iterator It, E; + tie(It, E) = UsersOfOverdefinedPHIs.equal_range(PN1); + while (It != E) { + if (It->second == &I) { + UsersOfOverdefinedPHIs.erase(It++); + } else + ++It; + } + tie(It, E) = UsersOfOverdefinedPHIs.equal_range(PN2); + while (It != E) { + if (It->second == &I) { + UsersOfOverdefinedPHIs.erase(It++); + } else + ++It; + } + } + + markOverdefined(IV, &I); + } else if (V1State.isConstant() && V2State.isConstant()) { + markConstant(IV, &I, ConstantExpr::get(I.getOpcode(), V1State.getConstant(), + V2State.getConstant())); + } +} + +// Handle ICmpInst instruction... +void SCCPSolver::visitCmpInst(CmpInst &I) { + LatticeVal &IV = ValueState[&I]; + if (IV.isOverdefined()) return; + + LatticeVal &V1State = getValueState(I.getOperand(0)); + LatticeVal &V2State = getValueState(I.getOperand(1)); + + if (V1State.isOverdefined() || V2State.isOverdefined()) { + // If both operands are PHI nodes, it is possible that this instruction has + // a constant value, despite the fact that the PHI node doesn't. Check for + // this condition now. + if (PHINode *PN1 = dyn_cast(I.getOperand(0))) + if (PHINode *PN2 = dyn_cast(I.getOperand(1))) + if (PN1->getParent() == PN2->getParent()) { + // Since the two PHI nodes are in the same basic block, they must have + // entries for the same predecessors. Walk the predecessor list, and + // if all of the incoming values are constants, and the result of + // evaluating this expression with all incoming value pairs is the + // same, then this expression is a constant even though the PHI node + // is not a constant! 
+ LatticeVal Result; + for (unsigned i = 0, e = PN1->getNumIncomingValues(); i != e; ++i) { + LatticeVal &In1 = getValueState(PN1->getIncomingValue(i)); + BasicBlock *InBlock = PN1->getIncomingBlock(i); + LatticeVal &In2 = + getValueState(PN2->getIncomingValueForBlock(InBlock)); + + if (In1.isOverdefined() || In2.isOverdefined()) { + Result.markOverdefined(); + break; // Cannot fold this operation over the PHI nodes! + } else if (In1.isConstant() && In2.isConstant()) { + Constant *V = ConstantExpr::getCompare(I.getPredicate(), + In1.getConstant(), + In2.getConstant()); + if (Result.isUndefined()) + Result.markConstant(V); + else if (Result.isConstant() && Result.getConstant() != V) { + Result.markOverdefined(); + break; + } + } + } + + // If we found a constant value here, then we know the instruction is + // constant despite the fact that the PHI nodes are overdefined. + if (Result.isConstant()) { + markConstant(IV, &I, Result.getConstant()); + // Remember that this instruction is virtually using the PHI node + // operands. + UsersOfOverdefinedPHIs.insert(std::make_pair(PN1, &I)); + UsersOfOverdefinedPHIs.insert(std::make_pair(PN2, &I)); + return; + } else if (Result.isUndefined()) { + return; + } + + // Okay, this really is overdefined now. Since we might have + // speculatively thought that this was not overdefined before, and + // added ourselves to the UsersOfOverdefinedPHIs list for the PHIs, + // make sure to clean out any entries that we put there, for + // efficiency. 
+ std::multimap::iterator It, E; + tie(It, E) = UsersOfOverdefinedPHIs.equal_range(PN1); + while (It != E) { + if (It->second == &I) { + UsersOfOverdefinedPHIs.erase(It++); + } else + ++It; + } + tie(It, E) = UsersOfOverdefinedPHIs.equal_range(PN2); + while (It != E) { + if (It->second == &I) { + UsersOfOverdefinedPHIs.erase(It++); + } else + ++It; + } + } + + markOverdefined(IV, &I); + } else if (V1State.isConstant() && V2State.isConstant()) { + markConstant(IV, &I, ConstantExpr::getCompare(I.getPredicate(), + V1State.getConstant(), + V2State.getConstant())); + } +} + +void SCCPSolver::visitExtractElementInst(ExtractElementInst &I) { + // FIXME : SCCP does not handle vectors properly. + markOverdefined(&I); + return; + +#if 0 + LatticeVal &ValState = getValueState(I.getOperand(0)); + LatticeVal &IdxState = getValueState(I.getOperand(1)); + + if (ValState.isOverdefined() || IdxState.isOverdefined()) + markOverdefined(&I); + else if(ValState.isConstant() && IdxState.isConstant()) + markConstant(&I, ConstantExpr::getExtractElement(ValState.getConstant(), + IdxState.getConstant())); +#endif +} + +void SCCPSolver::visitInsertElementInst(InsertElementInst &I) { + // FIXME : SCCP does not handle vectors properly. 
+ markOverdefined(&I); + return; +#if 0 + LatticeVal &ValState = getValueState(I.getOperand(0)); + LatticeVal &EltState = getValueState(I.getOperand(1)); + LatticeVal &IdxState = getValueState(I.getOperand(2)); + + if (ValState.isOverdefined() || EltState.isOverdefined() || + IdxState.isOverdefined()) + markOverdefined(&I); + else if(ValState.isConstant() && EltState.isConstant() && + IdxState.isConstant()) + markConstant(&I, ConstantExpr::getInsertElement(ValState.getConstant(), + EltState.getConstant(), + IdxState.getConstant())); + else if (ValState.isUndefined() && EltState.isConstant() && + IdxState.isConstant()) + markConstant(&I,ConstantExpr::getInsertElement(UndefValue::get(I.getType()), + EltState.getConstant(), + IdxState.getConstant())); +#endif +} + +void SCCPSolver::visitShuffleVectorInst(ShuffleVectorInst &I) { + // FIXME : SCCP does not handle vectors properly. + markOverdefined(&I); + return; +#if 0 + LatticeVal &V1State = getValueState(I.getOperand(0)); + LatticeVal &V2State = getValueState(I.getOperand(1)); + LatticeVal &MaskState = getValueState(I.getOperand(2)); + + if (MaskState.isUndefined() || + (V1State.isUndefined() && V2State.isUndefined())) + return; // Undefined output if mask or both inputs undefined. + + if (V1State.isOverdefined() || V2State.isOverdefined() || + MaskState.isOverdefined()) { + markOverdefined(&I); + } else { + // A mix of constant/undef inputs. + Constant *V1 = V1State.isConstant() ? + V1State.getConstant() : UndefValue::get(I.getType()); + Constant *V2 = V2State.isConstant() ? + V2State.getConstant() : UndefValue::get(I.getType()); + Constant *Mask = MaskState.isConstant() ? + MaskState.getConstant() : UndefValue::get(I.getOperand(2)->getType()); + markConstant(&I, ConstantExpr::getShuffleVector(V1, V2, Mask)); + } +#endif +} + +// Handle getelementptr instructions... if all operands are constants then we +// can turn this into a getelementptr ConstantExpr. 
+// +void SCCPSolver::visitGetElementPtrInst(GetElementPtrInst &I) { + LatticeVal &IV = ValueState[&I]; + if (IV.isOverdefined()) return; + + SmallVector Operands; + Operands.reserve(I.getNumOperands()); + + for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) { + LatticeVal &State = getValueState(I.getOperand(i)); + if (State.isUndefined()) + return; // Operands are not resolved yet... + else if (State.isOverdefined()) { + markOverdefined(IV, &I); + return; + } + assert(State.isConstant() && "Unknown state!"); + Operands.push_back(State.getConstant()); + } + + Constant *Ptr = Operands[0]; + Operands.erase(Operands.begin()); // Erase the pointer from idx list... + + markConstant(IV, &I, ConstantExpr::getGetElementPtr(Ptr, &Operands[0], + Operands.size())); +} + +void SCCPSolver::visitStoreInst(Instruction &SI) { + if (TrackedGlobals.empty() || !isa(SI.getOperand(1))) + return; + GlobalVariable *GV = cast(SI.getOperand(1)); + DenseMap::iterator I = TrackedGlobals.find(GV); + if (I == TrackedGlobals.end() || I->second.isOverdefined()) return; + + // Get the value we are storing into the global. + LatticeVal &PtrVal = getValueState(SI.getOperand(0)); + + mergeInValue(I->second, GV, PtrVal); + if (I->second.isOverdefined()) + TrackedGlobals.erase(I); // No need to keep tracking this! +} + + +// Handle load instructions. If the operand is a constant pointer to a constant +// global, we can replace the load with the loaded constant value! +void SCCPSolver::visitLoadInst(LoadInst &I) { + LatticeVal &IV = ValueState[&I]; + if (IV.isOverdefined()) return; + + LatticeVal &PtrVal = getValueState(I.getOperand(0)); + if (PtrVal.isUndefined()) return; // The pointer is not resolved yet! + if (PtrVal.isConstant() && !I.isVolatile()) { + Value *Ptr = PtrVal.getConstant(); + if (isa(Ptr)) { + // load null -> null + markConstant(IV, &I, Constant::getNullValue(I.getType())); + return; + } + + // Transform load (constant global) into the value loaded. 
+ if (GlobalVariable *GV = dyn_cast(Ptr)) { + if (GV->isConstant()) { + if (!GV->isDeclaration()) { + markConstant(IV, &I, GV->getInitializer()); + return; + } + } else if (!TrackedGlobals.empty()) { + // If we are tracking this global, merge in the known value for it. + DenseMap::iterator It = + TrackedGlobals.find(GV); + if (It != TrackedGlobals.end()) { + mergeInValue(IV, &I, It->second); + return; + } + } + } + + // Transform load (constantexpr_GEP global, 0, ...) into the value loaded. + if (ConstantExpr *CE = dyn_cast(Ptr)) + if (CE->getOpcode() == Instruction::GetElementPtr) + if (GlobalVariable *GV = dyn_cast(CE->getOperand(0))) + if (GV->isConstant() && !GV->isDeclaration()) + if (Constant *V = + ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE)) { + markConstant(IV, &I, V); + return; + } + } + + // Otherwise we cannot say for certain what value this load will produce. + // Bail out. + markOverdefined(IV, &I); +} + +void SCCPSolver::visitCallSite(CallSite CS) { + Function *F = CS.getCalledFunction(); + + // If we are tracking this function, we must make sure to bind arguments as + // appropriate. + DenseMap::iterator TFRVI =TrackedFunctionRetVals.end(); + if (F && F->hasInternalLinkage()) + TFRVI = TrackedFunctionRetVals.find(F); + + if (TFRVI != TrackedFunctionRetVals.end()) { + // If this is the first call to the function hit, mark its entry block + // executable. + if (!BBExecutable.count(F->begin())) + MarkBlockExecutable(F->begin()); + + CallSite::arg_iterator CAI = CS.arg_begin(); + for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); + AI != E; ++AI, ++CAI) { + LatticeVal &IV = ValueState[AI]; + if (!IV.isOverdefined()) + mergeInValue(IV, AI, getValueState(*CAI)); + } + } + Instruction *I = CS.getInstruction(); + if (I->getType() == Type::VoidTy) return; + + LatticeVal &IV = ValueState[I]; + if (IV.isOverdefined()) return; + + // Propagate the return value of the function to the value of the instruction. 
+ if (TFRVI != TrackedFunctionRetVals.end()) { + mergeInValue(IV, I, TFRVI->second); + return; + } + + if (F == 0 || !F->isDeclaration() || !canConstantFoldCallTo(F)) { + markOverdefined(IV, I); + return; + } + + SmallVector Operands; + Operands.reserve(I->getNumOperands()-1); + + for (CallSite::arg_iterator AI = CS.arg_begin(), E = CS.arg_end(); + AI != E; ++AI) { + LatticeVal &State = getValueState(*AI); + if (State.isUndefined()) + return; // Operands are not resolved yet... + else if (State.isOverdefined()) { + markOverdefined(IV, I); + return; + } + assert(State.isConstant() && "Unknown state!"); + Operands.push_back(State.getConstant()); + } + + if (Constant *C = ConstantFoldCall(F, &Operands[0], Operands.size())) + markConstant(IV, I, C); + else + markOverdefined(IV, I); +} + + +void SCCPSolver::Solve() { + // Process the work lists until they are empty! + while (!BBWorkList.empty() || !InstWorkList.empty() || + !OverdefinedInstWorkList.empty()) { + // Process the instruction work list... + while (!OverdefinedInstWorkList.empty()) { + Value *I = OverdefinedInstWorkList.back(); + OverdefinedInstWorkList.pop_back(); + + DOUT << "\nPopped off OI-WL: " << *I; + + // "I" got into the work list because it either made the transition from + // bottom to constant + // + // Anything on this worklist that is overdefined need not be visited + // since all of its users will have already been marked as overdefined + // Update all of the users of this instruction's value... + // + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); + UI != E; ++UI) + OperandChangedState(*UI); + } + // Process the instruction work list... 
+ while (!InstWorkList.empty()) { + Value *I = InstWorkList.back(); + InstWorkList.pop_back(); + + DOUT << "\nPopped off I-WL: " << *I; + + // "I" got into the work list because it either made the transition from + // bottom to constant + // + // Anything on this worklist that is overdefined need not be visited + // since all of its users will have already been marked as overdefined. + // Update all of the users of this instruction's value... + // + if (!getValueState(I).isOverdefined()) + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); + UI != E; ++UI) + OperandChangedState(*UI); + } + + // Process the basic block work list... + while (!BBWorkList.empty()) { + BasicBlock *BB = BBWorkList.back(); + BBWorkList.pop_back(); + + DOUT << "\nPopped off BBWL: " << *BB; + + // Notify all instructions in this basic block that they are newly + // executable. + visit(BB); + } + } +} + +/// ResolvedUndefsIn - While solving the dataflow for a function, we assume +/// that branches on undef values cannot reach any of their successors. +/// However, this is not a safe assumption. After we solve dataflow, this +/// method should be use to handle this. If this returns true, the solver +/// should be rerun. +/// +/// This method handles this by finding an unresolved branch and marking it one +/// of the edges from the block as being feasible, even though the condition +/// doesn't say it would otherwise be. This allows SCCP to find the rest of the +/// CFG and only slightly pessimizes the analysis results (by marking one, +/// potentially infeasible, edge feasible). This cannot usefully modify the +/// constraints on the condition of the branch, as that would impact other users +/// of the value. +/// +/// This scan also checks for values that use undefs, whose results are actually +/// defined. For example, 'zext i8 undef to i32' should produce all zeros +/// conservatively, as "(zext i8 X -> i32) & 0xFF00" must always return zero, +/// even if X isn't defined. 
+bool SCCPSolver::ResolvedUndefsIn(Function &F) { + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { + if (!BBExecutable.count(BB)) + continue; + + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + // Look for instructions which produce undef values. + if (I->getType() == Type::VoidTy) continue; + + LatticeVal &LV = getValueState(I); + if (!LV.isUndefined()) continue; + + // Get the lattice values of the first two operands for use below. + LatticeVal &Op0LV = getValueState(I->getOperand(0)); + LatticeVal Op1LV; + if (I->getNumOperands() == 2) { + // If this is a two-operand instruction, and if both operands are + // undefs, the result stays undef. + Op1LV = getValueState(I->getOperand(1)); + if (Op0LV.isUndefined() && Op1LV.isUndefined()) + continue; + } + + // If this is an instructions whose result is defined even if the input is + // not fully defined, propagate the information. + const Type *ITy = I->getType(); + switch (I->getOpcode()) { + default: break; // Leave the instruction as an undef. + case Instruction::ZExt: + // After a zero extend, we know the top part is zero. SExt doesn't have + // to be handled here, because we don't know whether the top part is 1's + // or 0's. + assert(Op0LV.isUndefined()); + markForcedConstant(LV, I, Constant::getNullValue(ITy)); + return true; + case Instruction::Mul: + case Instruction::And: + // undef * X -> 0. X could be zero. + // undef & X -> 0. X could be zero. + markForcedConstant(LV, I, Constant::getNullValue(ITy)); + return true; + + case Instruction::Or: + // undef | X -> -1. X could be -1. + if (const VectorType *PTy = dyn_cast(ITy)) + markForcedConstant(LV, I, ConstantVector::getAllOnesValue(PTy)); + else + markForcedConstant(LV, I, ConstantInt::getAllOnesValue(ITy)); + return true; + + case Instruction::SDiv: + case Instruction::UDiv: + case Instruction::SRem: + case Instruction::URem: + // X / undef -> undef. No change. + // X % undef -> undef. No change. 
+ if (Op1LV.isUndefined()) break; + + // undef / X -> 0. X could be maxint. + // undef % X -> 0. X could be 1. + markForcedConstant(LV, I, Constant::getNullValue(ITy)); + return true; + + case Instruction::AShr: + // undef >>s X -> undef. No change. + if (Op0LV.isUndefined()) break; + + // X >>s undef -> X. X could be 0, X could have the high-bit known set. + if (Op0LV.isConstant()) + markForcedConstant(LV, I, Op0LV.getConstant()); + else + markOverdefined(LV, I); + return true; + case Instruction::LShr: + case Instruction::Shl: + // undef >> X -> undef. No change. + // undef << X -> undef. No change. + if (Op0LV.isUndefined()) break; + + // X >> undef -> 0. X could be 0. + // X << undef -> 0. X could be 0. + markForcedConstant(LV, I, Constant::getNullValue(ITy)); + return true; + case Instruction::Select: + // undef ? X : Y -> X or Y. There could be commonality between X/Y. + if (Op0LV.isUndefined()) { + if (!Op1LV.isConstant()) // Pick the constant one if there is any. + Op1LV = getValueState(I->getOperand(2)); + } else if (Op1LV.isUndefined()) { + // c ? undef : undef -> undef. No change. + Op1LV = getValueState(I->getOperand(2)); + if (Op1LV.isUndefined()) + break; + // Otherwise, c ? undef : x -> x. + } else { + // Leave Op1LV as Operand(1)'s LatticeValue. + } + + if (Op1LV.isConstant()) + markForcedConstant(LV, I, Op1LV.getConstant()); + else + markOverdefined(LV, I); + return true; + } + } + + TerminatorInst *TI = BB->getTerminator(); + if (BranchInst *BI = dyn_cast(TI)) { + if (!BI->isConditional()) continue; + if (!getValueState(BI->getCondition()).isUndefined()) + continue; + } else if (SwitchInst *SI = dyn_cast(TI)) { + if (!getValueState(SI->getCondition()).isUndefined()) + continue; + } else { + continue; + } + + // If the edge to the first successor isn't thought to be feasible yet, mark + // it so now. + if (KnownFeasibleEdges.count(Edge(BB, TI->getSuccessor(0)))) + continue; + + // Otherwise, it isn't already thought to be feasible. 
Mark it as such now + // and return. This will make other blocks reachable, which will allow new + // values to be discovered and existing ones to be moved in the lattice. + markEdgeExecutable(BB, TI->getSuccessor(0)); + return true; + } + + return false; +} + + +namespace { + //===--------------------------------------------------------------------===// + // + /// SCCP Class - This class uses the SCCPSolver to implement a per-function + /// Sparse Conditional Constant Propagator. + /// + struct VISIBILITY_HIDDEN SCCP : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + SCCP() : FunctionPass((intptr_t)&ID) {} + + // runOnFunction - Run the Sparse Conditional Constant Propagation + // algorithm, and return true if the function was modified. + // + bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + } + }; + + char SCCP::ID = 0; + RegisterPass X("sccp", "Sparse Conditional Constant Propagation"); +} // end anonymous namespace + + +// createSCCPPass - This is the public interface to this file... +FunctionPass *llvm::createSCCPPass() { + return new SCCP(); +} + + +// runOnFunction() - Run the Sparse Conditional Constant Propagation algorithm, +// and return true if the function was modified. +// +bool SCCP::runOnFunction(Function &F) { + DOUT << "SCCP on function '" << F.getName() << "'\n"; + SCCPSolver Solver; + + // Mark the first block of the function as being executable. + Solver.MarkBlockExecutable(F.begin()); + + // Mark all arguments to the function as being overdefined. + for (Function::arg_iterator AI = F.arg_begin(), E = F.arg_end(); AI != E;++AI) + Solver.markOverdefined(AI); + + // Solve for constants. 
+ bool ResolvedUndefs = true; + while (ResolvedUndefs) { + Solver.Solve(); + DOUT << "RESOLVING UNDEFs\n"; + ResolvedUndefs = Solver.ResolvedUndefsIn(F); + } + + bool MadeChanges = false; + + // If we decided that there are basic blocks that are dead in this function, + // delete their contents now. Note that we cannot actually delete the blocks, + // as we cannot modify the CFG of the function. + // + SmallSet &ExecutableBBs = Solver.getExecutableBlocks(); + SmallVector Insts; + std::map &Values = Solver.getValueMapping(); + + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + if (!ExecutableBBs.count(BB)) { + DOUT << " BasicBlock Dead:" << *BB; + ++NumDeadBlocks; + + // Delete the instructions backwards, as it has a reduced likelihood of + // having to update as many def-use and use-def chains. + for (BasicBlock::iterator I = BB->begin(), E = BB->getTerminator(); + I != E; ++I) + Insts.push_back(I); + while (!Insts.empty()) { + Instruction *I = Insts.back(); + Insts.pop_back(); + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + BB->getInstList().erase(I); + MadeChanges = true; + ++NumInstRemoved; + } + } else { + // Iterate over all of the instructions in a function, replacing them with + // constants if we have found them to be of constant values. + // + for (BasicBlock::iterator BI = BB->begin(), E = BB->end(); BI != E; ) { + Instruction *Inst = BI++; + if (Inst->getType() != Type::VoidTy) { + LatticeVal &IV = Values[Inst]; + if ((IV.isConstant() || IV.isUndefined()) && + !isa(Inst)) { + Constant *Const = IV.isConstant() + ? IV.getConstant() : UndefValue::get(Inst->getType()); + DOUT << " Constant: " << *Const << " = " << *Inst; + + // Replaces all of the uses of a variable with uses of the constant. + Inst->replaceAllUsesWith(Const); + + // Delete the instruction. + BB->getInstList().erase(Inst); + + // Hey, we just changed something! 
+ MadeChanges = true; + ++NumInstRemoved; + } + } + } + } + + return MadeChanges; +} + +namespace { + //===--------------------------------------------------------------------===// + // + /// IPSCCP Class - This class implements interprocedural Sparse Conditional + /// Constant Propagation. + /// + struct VISIBILITY_HIDDEN IPSCCP : public ModulePass { + static char ID; + IPSCCP() : ModulePass((intptr_t)&ID) {} + bool runOnModule(Module &M); + }; + + char IPSCCP::ID = 0; + RegisterPass + Y("ipsccp", "Interprocedural Sparse Conditional Constant Propagation"); +} // end anonymous namespace + +// createIPSCCPPass - This is the public interface to this file... +ModulePass *llvm::createIPSCCPPass() { + return new IPSCCP(); +} + + +static bool AddressIsTaken(GlobalValue *GV) { + // Delete any dead constantexpr klingons. + GV->removeDeadConstantUsers(); + + for (Value::use_iterator UI = GV->use_begin(), E = GV->use_end(); + UI != E; ++UI) + if (StoreInst *SI = dyn_cast(*UI)) { + if (SI->getOperand(0) == GV || SI->isVolatile()) + return true; // Storing addr of GV. + } else if (isa(*UI) || isa(*UI)) { + // Make sure we are calling the function, not passing the address. + CallSite CS = CallSite::get(cast(*UI)); + for (CallSite::arg_iterator AI = CS.arg_begin(), + E = CS.arg_end(); AI != E; ++AI) + if (*AI == GV) + return true; + } else if (LoadInst *LI = dyn_cast(*UI)) { + if (LI->isVolatile()) + return true; + } else { + return true; + } + return false; +} + +bool IPSCCP::runOnModule(Module &M) { + SCCPSolver Solver; + + // Loop over all functions, marking arguments to those with their addresses + // taken or that are external as overdefined. 
+ // + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) + if (!F->hasInternalLinkage() || AddressIsTaken(F)) { + if (!F->isDeclaration()) + Solver.MarkBlockExecutable(F->begin()); + for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); + AI != E; ++AI) + Solver.markOverdefined(AI); + } else { + Solver.AddTrackedFunction(F); + } + + // Loop over global variables. We inform the solver about any internal global + // variables that do not have their 'addresses taken'. If they don't have + // their addresses taken, we can propagate constants through them. + for (Module::global_iterator G = M.global_begin(), E = M.global_end(); + G != E; ++G) + if (!G->isConstant() && G->hasInternalLinkage() && !AddressIsTaken(G)) + Solver.TrackValueOfGlobalVariable(G); + + // Solve for constants. + bool ResolvedUndefs = true; + while (ResolvedUndefs) { + Solver.Solve(); + + DOUT << "RESOLVING UNDEFS\n"; + ResolvedUndefs = false; + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) + ResolvedUndefs |= Solver.ResolvedUndefsIn(*F); + } + + bool MadeChanges = false; + + // Iterate over all of the instructions in the module, replacing them with + // constants if we have found them to be of constant values. + // + SmallSet &ExecutableBBs = Solver.getExecutableBlocks(); + SmallVector Insts; + SmallVector BlocksToErase; + std::map &Values = Solver.getValueMapping(); + + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) { + for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); + AI != E; ++AI) + if (!AI->use_empty()) { + LatticeVal &IV = Values[AI]; + if (IV.isConstant() || IV.isUndefined()) { + Constant *CST = IV.isConstant() ? + IV.getConstant() : UndefValue::get(AI->getType()); + DOUT << "*** Arg " << *AI << " = " << *CST <<"\n"; + + // Replaces all of the uses of a variable with uses of the + // constant. 
+ AI->replaceAllUsesWith(CST); + ++IPNumArgsElimed; + } + } + + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) + if (!ExecutableBBs.count(BB)) { + DOUT << " BasicBlock Dead:" << *BB; + ++IPNumDeadBlocks; + + // Delete the instructions backwards, as it has a reduced likelihood of + // having to update as many def-use and use-def chains. + TerminatorInst *TI = BB->getTerminator(); + for (BasicBlock::iterator I = BB->begin(), E = TI; I != E; ++I) + Insts.push_back(I); + + while (!Insts.empty()) { + Instruction *I = Insts.back(); + Insts.pop_back(); + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + BB->getInstList().erase(I); + MadeChanges = true; + ++IPNumInstRemoved; + } + + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + BasicBlock *Succ = TI->getSuccessor(i); + if (Succ->begin() != Succ->end() && isa(Succ->begin())) + TI->getSuccessor(i)->removePredecessor(BB); + } + if (!TI->use_empty()) + TI->replaceAllUsesWith(UndefValue::get(TI->getType())); + BB->getInstList().erase(TI); + + if (&*BB != &F->front()) + BlocksToErase.push_back(BB); + else + new UnreachableInst(BB); + + } else { + for (BasicBlock::iterator BI = BB->begin(), E = BB->end(); BI != E; ) { + Instruction *Inst = BI++; + if (Inst->getType() != Type::VoidTy) { + LatticeVal &IV = Values[Inst]; + if (IV.isConstant() || IV.isUndefined() && + !isa(Inst)) { + Constant *Const = IV.isConstant() + ? IV.getConstant() : UndefValue::get(Inst->getType()); + DOUT << " Constant: " << *Const << " = " << *Inst; + + // Replaces all of the uses of a variable with uses of the + // constant. + Inst->replaceAllUsesWith(Const); + + // Delete the instruction. + if (!isa(Inst) && !isa(Inst)) + BB->getInstList().erase(Inst); + + // Hey, we just changed something! 
+ MadeChanges = true; + ++IPNumInstRemoved; + } + } + } + } + + // Now that all instructions in the function are constant folded, erase dead + // blocks, because we can now use ConstantFoldTerminator to get rid of + // in-edges. + for (unsigned i = 0, e = BlocksToErase.size(); i != e; ++i) { + // If there are any PHI nodes in this successor, drop entries for BB now. + BasicBlock *DeadBB = BlocksToErase[i]; + while (!DeadBB->use_empty()) { + Instruction *I = cast(DeadBB->use_back()); + bool Folded = ConstantFoldTerminator(I->getParent()); + if (!Folded) { + // The constant folder may not have been able to fold the terminator + // if this is a branch or switch on undef. Fold it manually as a + // branch to the first successor. + if (BranchInst *BI = dyn_cast(I)) { + assert(BI->isConditional() && isa(BI->getCondition()) && + "Branch should be foldable!"); + } else if (SwitchInst *SI = dyn_cast(I)) { + assert(isa(SI->getCondition()) && "Switch should fold"); + } else { + assert(0 && "Didn't fold away reference to block!"); + } + + // Make this an uncond branch to the first successor. + TerminatorInst *TI = I->getParent()->getTerminator(); + new BranchInst(TI->getSuccessor(0), TI); + + // Remove entries in successor phi nodes to remove edges. + for (unsigned i = 1, e = TI->getNumSuccessors(); i != e; ++i) + TI->getSuccessor(i)->removePredecessor(TI->getParent()); + + // Remove the old terminator. + TI->eraseFromParent(); + } + } + + // Finally, delete the basic block. + F->getBasicBlockList().erase(DeadBB); + } + BlocksToErase.clear(); + } + + // If we inferred constant or undef return values for a function, we replaced + // all call uses with the inferred value. This means we don't need to bother + // actually returning anything from the function. Replace all return + // instructions with return undef. 
+ const DenseMap &RV =Solver.getTrackedFunctionRetVals(); + for (DenseMap::const_iterator I = RV.begin(), + E = RV.end(); I != E; ++I) + if (!I->second.isOverdefined() && + I->first->getReturnType() != Type::VoidTy) { + Function *F = I->first; + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) + if (ReturnInst *RI = dyn_cast(BB->getTerminator())) + if (!isa(RI->getOperand(0))) + RI->setOperand(0, UndefValue::get(F->getReturnType())); + } + + // If we infered constant or undef values for globals variables, we can delete + // the global and any stores that remain to it. + const DenseMap &TG = Solver.getTrackedGlobals(); + for (DenseMap::const_iterator I = TG.begin(), + E = TG.end(); I != E; ++I) { + GlobalVariable *GV = I->first; + assert(!I->second.isOverdefined() && + "Overdefined values should have been taken out of the map!"); + DOUT << "Found that GV '" << GV->getName()<< "' is constant!\n"; + while (!GV->use_empty()) { + StoreInst *SI = cast(GV->use_back()); + SI->eraseFromParent(); + } + M.getGlobalList().erase(GV); + ++IPNumGlobalConst; + } + + return MadeChanges; +} diff --git a/lib/Transforms/Scalar/ScalarReplAggregates.cpp b/lib/Transforms/Scalar/ScalarReplAggregates.cpp new file mode 100644 index 0000000..e303468 --- /dev/null +++ b/lib/Transforms/Scalar/ScalarReplAggregates.cpp @@ -0,0 +1,1335 @@ +//===- ScalarReplAggregates.cpp - Scalar Replacement of Aggregates --------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This transformation implements the well known scalar replacement of +// aggregates transformation. This xform breaks up alloca instructions of +// aggregate type (structure or array) into individual alloca instructions for +// each member (if possible). 
Then, if possible, it transforms the individual +// alloca instructions into nice clean scalar SSA form. +// +// This combines a simple SRoA algorithm with the Mem2Reg algorithm because +// often interact, especially for C++ programs. As such, iterating between +// SRoA, then Mem2Reg until we run out of things to promote works well. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "scalarrepl" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/GlobalVariable.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Transforms/Utils/PromoteMemToReg.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/Compiler.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" +using namespace llvm; + +STATISTIC(NumReplaced, "Number of allocas broken up"); +STATISTIC(NumPromoted, "Number of allocas promoted"); +STATISTIC(NumConverted, "Number of aggregates converted to scalar"); +STATISTIC(NumGlobals, "Number of allocas copied from constant global"); + +namespace { + struct VISIBILITY_HIDDEN SROA : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + SROA(signed T = -1) : FunctionPass((intptr_t)&ID) { + if (T == -1) + SRThreshold = 128; + else + SRThreshold = T; + } + + bool runOnFunction(Function &F); + + bool performScalarRepl(Function &F); + bool performPromotion(Function &F); + + // getAnalysisUsage - This pass does not require any passes, but we know it + // will not alter the CFG, so say so. 
+ virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.setPreservesCFG(); + } + + private: + /// AllocaInfo - When analyzing uses of an alloca instruction, this captures + /// information about the uses. All these fields are initialized to false + /// and set to true when something is learned. + struct AllocaInfo { + /// isUnsafe - This is set to true if the alloca cannot be SROA'd. + bool isUnsafe : 1; + + /// needsCanon - This is set to true if there is some use of the alloca + /// that requires canonicalization. + bool needsCanon : 1; + + /// isMemCpySrc - This is true if this aggregate is memcpy'd from. + bool isMemCpySrc : 1; + + /// isMemCpyDst - This is true if this aggregate is memcpy'd into. + bool isMemCpyDst : 1; + + AllocaInfo() + : isUnsafe(false), needsCanon(false), + isMemCpySrc(false), isMemCpyDst(false) {} + }; + + unsigned SRThreshold; + + void MarkUnsafe(AllocaInfo &I) { I.isUnsafe = true; } + + int isSafeAllocaToScalarRepl(AllocationInst *AI); + + void isSafeUseOfAllocation(Instruction *User, AllocationInst *AI, + AllocaInfo &Info); + void isSafeElementUse(Value *Ptr, bool isFirstElt, AllocationInst *AI, + AllocaInfo &Info); + void isSafeMemIntrinsicOnAllocation(MemIntrinsic *MI, AllocationInst *AI, + unsigned OpNo, AllocaInfo &Info); + void isSafeUseOfBitCastedAllocation(BitCastInst *User, AllocationInst *AI, + AllocaInfo &Info); + + void DoScalarReplacement(AllocationInst *AI, + std::vector &WorkList); + void CanonicalizeAllocaUsers(AllocationInst *AI); + AllocaInst *AddNewAlloca(Function &F, const Type *Ty, AllocationInst *Base); + + void RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocationInst *AI, + SmallVector &NewElts); + + const Type *CanConvertToScalar(Value *V, bool &IsNotTrivial); + void ConvertToScalar(AllocationInst *AI, const Type *Ty); + void ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, unsigned Offset); + static Instruction 
*isOnlyCopiedFromConstantGlobal(AllocationInst *AI); + }; + + char SROA::ID = 0; + RegisterPass X("scalarrepl", "Scalar Replacement of Aggregates"); +} + +// Public interface to the ScalarReplAggregates pass +FunctionPass *llvm::createScalarReplAggregatesPass(signed int Threshold) { + return new SROA(Threshold); +} + + +bool SROA::runOnFunction(Function &F) { + bool Changed = performPromotion(F); + while (1) { + bool LocalChange = performScalarRepl(F); + if (!LocalChange) break; // No need to repromote if no scalarrepl + Changed = true; + LocalChange = performPromotion(F); + if (!LocalChange) break; // No need to re-scalarrepl if no promotion + } + + return Changed; +} + + +bool SROA::performPromotion(Function &F) { + std::vector Allocas; + DominatorTree &DT = getAnalysis(); + DominanceFrontier &DF = getAnalysis(); + + BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function + + bool Changed = false; + + while (1) { + Allocas.clear(); + + // Find allocas that are safe to promote, by looking at all instructions in + // the entry node + for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I) + if (AllocaInst *AI = dyn_cast(I)) // Is it an alloca? + if (isAllocaPromotable(AI)) + Allocas.push_back(AI); + + if (Allocas.empty()) break; + + PromoteMemToReg(Allocas, DT, DF); + NumPromoted += Allocas.size(); + Changed = true; + } + + return Changed; +} + +// performScalarRepl - This algorithm is a simple worklist driven algorithm, +// which runs on all of the malloc/alloca instructions in the function, removing +// them if they are only used by getelementptr instructions. 
+// +bool SROA::performScalarRepl(Function &F) { + std::vector WorkList; + + // Scan the entry basic block, adding any alloca's and mallocs to the worklist + BasicBlock &BB = F.getEntryBlock(); + for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) + if (AllocationInst *A = dyn_cast(I)) + WorkList.push_back(A); + + const TargetData &TD = getAnalysis(); + + // Process the worklist + bool Changed = false; + while (!WorkList.empty()) { + AllocationInst *AI = WorkList.back(); + WorkList.pop_back(); + + // Handle dead allocas trivially. These can be formed by SROA'ing arrays + // with unused elements. + if (AI->use_empty()) { + AI->eraseFromParent(); + continue; + } + + // If we can turn this aggregate value (potentially with casts) into a + // simple scalar value that can be mem2reg'd into a register value. + bool IsNotTrivial = false; + if (const Type *ActualType = CanConvertToScalar(AI, IsNotTrivial)) + if (IsNotTrivial && ActualType != Type::VoidTy) { + ConvertToScalar(AI, ActualType); + Changed = true; + continue; + } + + // Check to see if we can perform the core SROA transformation. We cannot + // transform the allocation instruction if it is an array allocation + // (allocations OF arrays are ok though), and an allocation of a scalar + // value cannot be decomposed at all. + if (!AI->isArrayAllocation() && + (isa(AI->getAllocatedType()) || + isa(AI->getAllocatedType())) && + AI->getAllocatedType()->isSized() && + TD.getTypeSize(AI->getAllocatedType()) < SRThreshold) { + // Check that all of the users of the allocation are capable of being + // transformed. + switch (isSafeAllocaToScalarRepl(AI)) { + default: assert(0 && "Unexpected value!"); + case 0: // Not safe to scalar replace. + break; + case 1: // Safe, but requires cleanup/canonicalizations first + CanonicalizeAllocaUsers(AI); + // FALL THROUGH. + case 3: // Safe to scalar replace. 
+ DoScalarReplacement(AI, WorkList); + Changed = true; + continue; + } + } + + // Check to see if this allocation is only modified by a memcpy/memmove from + // a constant global. If this is the case, we can change all users to use + // the constant global instead. This is commonly produced by the CFE by + // constructs like "void foo() { int A[] = {1,2,3,4,5,6,7,8,9...}; }" if 'A' + // is only subsequently read. + if (Instruction *TheCopy = isOnlyCopiedFromConstantGlobal(AI)) { + DOUT << "Found alloca equal to global: " << *AI; + DOUT << " memcpy = " << *TheCopy; + Constant *TheSrc = cast(TheCopy->getOperand(2)); + AI->replaceAllUsesWith(ConstantExpr::getBitCast(TheSrc, AI->getType())); + TheCopy->eraseFromParent(); // Don't mutate the global. + AI->eraseFromParent(); + ++NumGlobals; + Changed = true; + continue; + } + + // Otherwise, couldn't process this. + } + + return Changed; +} + +/// DoScalarReplacement - This alloca satisfied the isSafeAllocaToScalarRepl +/// predicate, do SROA now. +void SROA::DoScalarReplacement(AllocationInst *AI, + std::vector &WorkList) { + DOUT << "Found inst to SROA: " << *AI; + SmallVector ElementAllocas; + if (const StructType *ST = dyn_cast(AI->getAllocatedType())) { + ElementAllocas.reserve(ST->getNumContainedTypes()); + for (unsigned i = 0, e = ST->getNumContainedTypes(); i != e; ++i) { + AllocaInst *NA = new AllocaInst(ST->getContainedType(i), 0, + AI->getAlignment(), + AI->getName() + "." + utostr(i), AI); + ElementAllocas.push_back(NA); + WorkList.push_back(NA); // Add to worklist for recursive processing + } + } else { + const ArrayType *AT = cast(AI->getAllocatedType()); + ElementAllocas.reserve(AT->getNumElements()); + const Type *ElTy = AT->getElementType(); + for (unsigned i = 0, e = AT->getNumElements(); i != e; ++i) { + AllocaInst *NA = new AllocaInst(ElTy, 0, AI->getAlignment(), + AI->getName() + "." 
+ utostr(i), AI); + ElementAllocas.push_back(NA); + WorkList.push_back(NA); // Add to worklist for recursive processing + } + } + + // Now that we have created the alloca instructions that we want to use, + // expand the getelementptr instructions to use them. + // + while (!AI->use_empty()) { + Instruction *User = cast(AI->use_back()); + if (BitCastInst *BCInst = dyn_cast(User)) { + RewriteBitCastUserOfAlloca(BCInst, AI, ElementAllocas); + BCInst->eraseFromParent(); + continue; + } + + GetElementPtrInst *GEPI = cast(User); + // We now know that the GEP is of the form: GEP , 0, + unsigned Idx = + (unsigned)cast(GEPI->getOperand(2))->getZExtValue(); + + assert(Idx < ElementAllocas.size() && "Index out of range?"); + AllocaInst *AllocaToUse = ElementAllocas[Idx]; + + Value *RepValue; + if (GEPI->getNumOperands() == 3) { + // Do not insert a new getelementptr instruction with zero indices, only + // to have it optimized out later. + RepValue = AllocaToUse; + } else { + // We are indexing deeply into the structure, so we still need a + // getelement ptr instruction to finish the indexing. This may be + // expanded itself once the worklist is rerun. + // + SmallVector NewArgs; + NewArgs.push_back(Constant::getNullValue(Type::Int32Ty)); + NewArgs.append(GEPI->op_begin()+3, GEPI->op_end()); + RepValue = new GetElementPtrInst(AllocaToUse, &NewArgs[0], + NewArgs.size(), "", GEPI); + RepValue->takeName(GEPI); + } + + // If this GEP is to the start of the aggregate, check for memcpys. + if (Idx == 0) { + bool IsStartOfAggregateGEP = true; + for (unsigned i = 3, e = GEPI->getNumOperands(); i != e; ++i) { + if (!isa(GEPI->getOperand(i))) { + IsStartOfAggregateGEP = false; + break; + } + if (!cast(GEPI->getOperand(i))->isZero()) { + IsStartOfAggregateGEP = false; + break; + } + } + + if (IsStartOfAggregateGEP) + RewriteBitCastUserOfAlloca(GEPI, AI, ElementAllocas); + } + + + // Move all of the users over to the new GEP. 
+ GEPI->replaceAllUsesWith(RepValue); + // Delete the old GEP + GEPI->eraseFromParent(); + } + + // Finally, delete the Alloca instruction + AI->eraseFromParent(); + NumReplaced++; +} + + +/// isSafeElementUse - Check to see if this use is an allowed use for a +/// getelementptr instruction of an array aggregate allocation. isFirstElt +/// indicates whether Ptr is known to the start of the aggregate. +/// +void SROA::isSafeElementUse(Value *Ptr, bool isFirstElt, AllocationInst *AI, + AllocaInfo &Info) { + for (Value::use_iterator I = Ptr->use_begin(), E = Ptr->use_end(); + I != E; ++I) { + Instruction *User = cast(*I); + switch (User->getOpcode()) { + case Instruction::Load: break; + case Instruction::Store: + // Store is ok if storing INTO the pointer, not storing the pointer + if (User->getOperand(0) == Ptr) return MarkUnsafe(Info); + break; + case Instruction::GetElementPtr: { + GetElementPtrInst *GEP = cast(User); + bool AreAllZeroIndices = isFirstElt; + if (GEP->getNumOperands() > 1) { + if (!isa(GEP->getOperand(1)) || + !cast(GEP->getOperand(1))->isZero()) + // Using pointer arithmetic to navigate the array. 
+ return MarkUnsafe(Info); + + if (AreAllZeroIndices) { + for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i) { + if (!isa(GEP->getOperand(i)) || + !cast(GEP->getOperand(i))->isZero()) { + AreAllZeroIndices = false; + break; + } + } + } + } + isSafeElementUse(GEP, AreAllZeroIndices, AI, Info); + if (Info.isUnsafe) return; + break; + } + case Instruction::BitCast: + if (isFirstElt) { + isSafeUseOfBitCastedAllocation(cast(User), AI, Info); + if (Info.isUnsafe) return; + break; + } + DOUT << " Transformation preventing inst: " << *User; + return MarkUnsafe(Info); + case Instruction::Call: + if (MemIntrinsic *MI = dyn_cast(User)) { + if (isFirstElt) { + isSafeMemIntrinsicOnAllocation(MI, AI, I.getOperandNo(), Info); + if (Info.isUnsafe) return; + break; + } + } + DOUT << " Transformation preventing inst: " << *User; + return MarkUnsafe(Info); + default: + DOUT << " Transformation preventing inst: " << *User; + return MarkUnsafe(Info); + } + } + return; // All users look ok :) +} + +/// AllUsersAreLoads - Return true if all users of this value are loads. +static bool AllUsersAreLoads(Value *Ptr) { + for (Value::use_iterator I = Ptr->use_begin(), E = Ptr->use_end(); + I != E; ++I) + if (cast(*I)->getOpcode() != Instruction::Load) + return false; + return true; +} + +/// isSafeUseOfAllocation - Check to see if this user is an allowed use for an +/// aggregate allocation. +/// +void SROA::isSafeUseOfAllocation(Instruction *User, AllocationInst *AI, + AllocaInfo &Info) { + if (BitCastInst *C = dyn_cast(User)) + return isSafeUseOfBitCastedAllocation(C, AI, Info); + + GetElementPtrInst *GEPI = dyn_cast(User); + if (GEPI == 0) + return MarkUnsafe(Info); + + gep_type_iterator I = gep_type_begin(GEPI), E = gep_type_end(GEPI); + + // The GEP is not safe to transform if not of the form "GEP , 0, ". 
+ if (I == E || + I.getOperand() != Constant::getNullValue(I.getOperand()->getType())) { + return MarkUnsafe(Info); + } + + ++I; + if (I == E) return MarkUnsafe(Info); // ran out of GEP indices?? + + bool IsAllZeroIndices = true; + + // If this is a use of an array allocation, do a bit more checking for sanity. + if (const ArrayType *AT = dyn_cast(*I)) { + uint64_t NumElements = AT->getNumElements(); + + if (ConstantInt *Idx = dyn_cast(I.getOperand())) { + IsAllZeroIndices &= Idx->isZero(); + + // Check to make sure that index falls within the array. If not, + // something funny is going on, so we won't do the optimization. + // + if (Idx->getZExtValue() >= NumElements) + return MarkUnsafe(Info); + + // We cannot scalar repl this level of the array unless any array + // sub-indices are in-range constants. In particular, consider: + // A[0][i]. We cannot know that the user isn't doing invalid things like + // allowing i to index an out-of-range subscript that accesses A[1]. + // + // Scalar replacing *just* the outer index of the array is probably not + // going to be a win anyway, so just give up. + for (++I; I != E && (isa(*I) || isa(*I)); ++I) { + uint64_t NumElements; + if (const ArrayType *SubArrayTy = dyn_cast(*I)) + NumElements = SubArrayTy->getNumElements(); + else + NumElements = cast(*I)->getNumElements(); + + ConstantInt *IdxVal = dyn_cast(I.getOperand()); + if (!IdxVal) return MarkUnsafe(Info); + if (IdxVal->getZExtValue() >= NumElements) + return MarkUnsafe(Info); + IsAllZeroIndices &= IdxVal->isZero(); + } + + } else { + IsAllZeroIndices = 0; + + // If this is an array index and the index is not constant, we cannot + // promote... that is unless the array has exactly one or two elements in + // it, in which case we CAN promote it, but we have to canonicalize this + // out if this is the only problem. + if ((NumElements == 1 || NumElements == 2) && + AllUsersAreLoads(GEPI)) { + Info.needsCanon = true; + return; // Canonicalization required! 
+ } + return MarkUnsafe(Info); + } + } + + // If there are any non-simple uses of this getelementptr, make sure to reject + // them. + return isSafeElementUse(GEPI, IsAllZeroIndices, AI, Info); +} + +/// isSafeMemIntrinsicOnAllocation - Return true if the specified memory +/// intrinsic can be promoted by SROA. At this point, we know that the operand +/// of the memintrinsic is a pointer to the beginning of the allocation. +void SROA::isSafeMemIntrinsicOnAllocation(MemIntrinsic *MI, AllocationInst *AI, + unsigned OpNo, AllocaInfo &Info) { + // If not constant length, give up. + ConstantInt *Length = dyn_cast(MI->getLength()); + if (!Length) return MarkUnsafe(Info); + + // If not the whole aggregate, give up. + const TargetData &TD = getAnalysis(); + if (Length->getZExtValue() != TD.getTypeSize(AI->getType()->getElementType())) + return MarkUnsafe(Info); + + // We only know about memcpy/memset/memmove. + if (!isa(MI) && !isa(MI) && !isa(MI)) + return MarkUnsafe(Info); + + // Otherwise, we can transform it. Determine whether this is a memcpy/set + // into or out of the aggregate. + if (OpNo == 1) + Info.isMemCpyDst = true; + else { + assert(OpNo == 2); + Info.isMemCpySrc = true; + } +} + +/// isSafeUseOfBitCastedAllocation - Return true if all users of this bitcast +/// are +void SROA::isSafeUseOfBitCastedAllocation(BitCastInst *BC, AllocationInst *AI, + AllocaInfo &Info) { + for (Value::use_iterator UI = BC->use_begin(), E = BC->use_end(); + UI != E; ++UI) { + if (BitCastInst *BCU = dyn_cast(UI)) { + isSafeUseOfBitCastedAllocation(BCU, AI, Info); + } else if (MemIntrinsic *MI = dyn_cast(UI)) { + isSafeMemIntrinsicOnAllocation(MI, AI, UI.getOperandNo(), Info); + } else { + return MarkUnsafe(Info); + } + if (Info.isUnsafe) return; + } +} + +/// RewriteBitCastUserOfAlloca - BCInst (transitively) bitcasts AI, or indexes +/// to its first element. Transform users of the cast to use the new values +/// instead. 
+// NOTE(review): this hunk is extraction-damaged. Original newlines were collapsed into
+// mega-lines, and every '<...>' template-argument list in code was stripped (e.g.
+// 'dyn_cast(*UI)' was presumably 'dyn_cast<BitCastInst>(*UI)', 'SmallVector &NewElts'
+// lost its element type/inline size, 'getAnalysis()' lost '<TargetData>'). Treat this
+// text as non-compilable reference only; recover the lost tokens from upstream LLVM
+// r40004 before editing code here.
+// RewriteBitCastUserOfAlloca: given BCInst, a bitcast of alloca AI, rewrite whole-
+// aggregate mem-intrinsic users into per-element operations on NewElts (the new
+// per-field allocas). Recurses through nested bitcasts; leaves non-mem-intrinsic
+// users untouched.
+void SROA::RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocationInst *AI, + SmallVector &NewElts) { + Constant *Zero = Constant::getNullValue(Type::Int32Ty); + const TargetData &TD = getAnalysis(); + + Value::use_iterator UI = BCInst->use_begin(), UE = BCInst->use_end(); + while (UI != UE) { + if (BitCastInst *BCU = dyn_cast(*UI)) { + RewriteBitCastUserOfAlloca(BCU, AI, NewElts); + ++UI; + BCU->eraseFromParent(); + continue; + } + + // Otherwise, must be memcpy/memmove/memset of the entire aggregate. Split + // into one per element. + MemIntrinsic *MI = dyn_cast(*UI); + + // If it's not a mem intrinsic, it must be some other user of a gep of the + // first pointer. Just leave these alone. + if (!MI) { + ++UI; + continue; + } + + // If this is a memcpy/memmove, construct the other pointer as the + // appropriate type. + Value *OtherPtr = 0; + if (MemCpyInst *MCI = dyn_cast(MI)) { + if (BCInst == MCI->getRawDest()) + OtherPtr = MCI->getRawSource(); + else { + assert(BCInst == MCI->getRawSource()); + OtherPtr = MCI->getRawDest(); + } + } else if (MemMoveInst *MMI = dyn_cast(MI)) { + if (BCInst == MMI->getRawDest()) + OtherPtr = MMI->getRawSource(); + else { + assert(BCInst == MMI->getRawSource()); + OtherPtr = MMI->getRawDest(); + } + } + + // If there is an other pointer, we want to convert it to the same pointer + // type as AI has, so we can GEP through it. + if (OtherPtr) { + // It is likely that OtherPtr is a bitcast, if so, remove it. + if (BitCastInst *BC = dyn_cast(OtherPtr)) + OtherPtr = BC->getOperand(0); + if (ConstantExpr *BCE = dyn_cast(OtherPtr)) + if (BCE->getOpcode() == Instruction::BitCast) + OtherPtr = BCE->getOperand(0); + + // If the pointer is not the right type, insert a bitcast to the right + // type. + if (OtherPtr->getType() != AI->getType()) + OtherPtr = new BitCastInst(OtherPtr, AI->getType(), OtherPtr->getName(), + MI); + } + + // Process each element of the aggregate.
+// NOTE(review): per-element loop. SROADest records whether the alloca side is the
+// destination of the intrinsic; for each element either a scalar load/store pair is
+// emitted (first-class element types) or a narrower mem intrinsic on the element.
+ Value *TheFn = MI->getOperand(0); + const Type *BytePtrTy = MI->getRawDest()->getType(); + bool SROADest = MI->getRawDest() == BCInst; + + for (unsigned i = 0, e = NewElts.size(); i != e; ++i) { + // If this is a memcpy/memmove, emit a GEP of the other element address. + Value *OtherElt = 0; + if (OtherPtr) { + OtherElt = new GetElementPtrInst(OtherPtr, Zero, + ConstantInt::get(Type::Int32Ty, i), + OtherPtr->getNameStr()+"."+utostr(i), + MI); + } + + Value *EltPtr = NewElts[i]; + const Type *EltTy =cast(EltPtr->getType())->getElementType(); + + // If we got down to a scalar, insert a load or store as appropriate. + if (EltTy->isFirstClassType()) { + if (isa(MI) || isa(MI)) { + Value *Elt = new LoadInst(SROADest ? OtherElt : EltPtr, "tmp", + MI); + new StoreInst(Elt, SROADest ? EltPtr : OtherElt, MI); + continue; + } else { + assert(isa(MI)); + + // If the stored element is zero (common case), just store a null + // constant. + Constant *StoreVal; + if (ConstantInt *CI = dyn_cast(MI->getOperand(2))) { + if (CI->isZero()) { + StoreVal = Constant::getNullValue(EltTy); // 0.0, null, 0, <0,0> + } else { + // If EltTy is a vector type, get the element type. + const Type *ValTy = EltTy; + if (const VectorType *VTy = dyn_cast(ValTy)) + ValTy = VTy->getElementType(); + + // Construct an integer with the right value. + unsigned EltSize = TD.getTypeSize(ValTy); + APInt OneVal(EltSize*8, CI->getZExtValue()); + APInt TotalVal(OneVal); + // Set each byte. + for (unsigned i = 0; i != EltSize-1; ++i) { + TotalVal = TotalVal.shl(8); + TotalVal |= OneVal; + } + + // Convert the integer value to the appropriate type. + StoreVal = ConstantInt::get(TotalVal); + if (isa(ValTy)) + StoreVal = ConstantExpr::getIntToPtr(StoreVal, ValTy); + else if (ValTy->isFloatingPoint()) + StoreVal = ConstantExpr::getBitCast(StoreVal, ValTy); + assert(StoreVal->getType() == ValTy && "Type mismatch!"); + + // If the requested value was a vector constant, create it.
+// NOTE(review): memset lowering — the byte value is replicated across the element
+// width (shl 8 / or loop above), then bitcast or inttoptr'd to the element type;
+// vector elements become a splat ConstantVector. Non-scalar elements fall through
+// to a narrower memset/memcpy/memmove call with recomputed size; Align is emitted
+// as the constant Zero defined at the top of the function.
+ if (EltTy != ValTy) { + unsigned NumElts = cast(ValTy)->getNumElements(); + SmallVector Elts(NumElts, StoreVal); + StoreVal = ConstantVector::get(&Elts[0], NumElts); + } + } + new StoreInst(StoreVal, EltPtr, MI); + continue; + } + // Otherwise, if we're storing a byte variable, use a memset call for + // this element. + } + } + + // Cast the element pointer to BytePtrTy. + if (EltPtr->getType() != BytePtrTy) + EltPtr = new BitCastInst(EltPtr, BytePtrTy, EltPtr->getNameStr(), MI); + + // Cast the other pointer (if we have one) to BytePtrTy. + if (OtherElt && OtherElt->getType() != BytePtrTy) + OtherElt = new BitCastInst(OtherElt, BytePtrTy,OtherElt->getNameStr(), + MI); + + unsigned EltSize = TD.getTypeSize(EltTy); + + // Finally, insert the meminst for this element. + if (isa(MI) || isa(MI)) { + Value *Ops[] = { + SROADest ? EltPtr : OtherElt, // Dest ptr + SROADest ? OtherElt : EltPtr, // Src ptr + ConstantInt::get(MI->getOperand(3)->getType(), EltSize), // Size + Zero // Align + }; + new CallInst(TheFn, Ops, 4, "", MI); + } else { + assert(isa(MI)); + Value *Ops[] = { + EltPtr, MI->getOperand(2), // Dest, Value, + ConstantInt::get(MI->getOperand(3)->getType(), EltSize), // Size + Zero // Align + }; + new CallInst(TheFn, Ops, 4, "", MI); + } + } + + // Finally, MI is now dead, as we've modified its actions to occur on all of + // the elements of the aggregate. + ++UI; + MI->eraseFromParent(); + } +} + +/// HasStructPadding - Return true if the specified type has any structure +/// padding, false otherwise. +static bool HasStructPadding(const Type *Ty, const TargetData &TD) { + if (const StructType *STy = dyn_cast(Ty)) { + const StructLayout *SL = TD.getStructLayout(STy); + unsigned PrevFieldBitOffset = 0; + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + unsigned FieldBitOffset = SL->getElementOffset(i)*8; + + // Padding in sub-elements?
+// NOTE(review): HasStructPadding recurses into struct fields and array element
+// types, detecting inter-field gaps (PrevFieldEnd < FieldBitOffset) and tail
+// padding against the struct's total size. isSafeAllocaToScalarRepl below uses it
+// to refuse SROA on aggregates that are both memcpy source and dest, since the
+// copy could move live bytes that sit in padding of the LLVM type.
+ if (HasStructPadding(STy->getElementType(i), TD)) + return true; + + // Check to see if there is any padding between this element and the + // previous one. + if (i) { + unsigned PrevFieldEnd = + PrevFieldBitOffset+TD.getTypeSizeInBits(STy->getElementType(i-1)); + if (PrevFieldEnd < FieldBitOffset) + return true; + } + + PrevFieldBitOffset = FieldBitOffset; + } + + // Check for tail padding. + if (unsigned EltCount = STy->getNumElements()) { + unsigned PrevFieldEnd = PrevFieldBitOffset + + TD.getTypeSizeInBits(STy->getElementType(EltCount-1)); + if (PrevFieldEnd < SL->getSizeInBytes()*8) + return true; + } + + } else if (const ArrayType *ATy = dyn_cast(Ty)) { + return HasStructPadding(ATy->getElementType(), TD); + } + return false; +} + +/// isSafeStructAllocaToScalarRepl - Check to see if the specified allocation of +/// an aggregate can be broken down into elements. Return 0 if not, 3 if safe, +/// or 1 if safe after canonicalization has been performed. +/// +int SROA::isSafeAllocaToScalarRepl(AllocationInst *AI) { + // Loop over the use list of the alloca. We can only transform it if all of + // the users are safe to transform. + AllocaInfo Info; + + for (Value::use_iterator I = AI->use_begin(), E = AI->use_end(); + I != E; ++I) { + isSafeUseOfAllocation(cast(*I), AI, Info); + if (Info.isUnsafe) { + DOUT << "Cannot transform: " << *AI << " due to user: " << **I; + return 0; + } + } + + // Okay, we know all the users are promotable. If the aggregate is a memcpy + // source and destination, we have to be careful. In particular, the memcpy + // could be moving around elements that live in structure padding of the LLVM + // types, but may actually be used. In these cases, we refuse to promote the + // struct. + if (Info.isMemCpySrc && Info.isMemCpyDst && + HasStructPadding(AI->getType()->getElementType(), + getAnalysis())) + return 0; + + // If we require cleanup, return 1, otherwise return 3. + return Info.needsCanon ?
+// NOTE(review): return codes — 0 = unsafe, 1 = safe after CanonicalizeAllocaUsers
+// runs, 3 = safe as-is (per the doc comment above; the '2' value is unused here).
+// CanonicalizeAllocaUsers rewrites variable-index GEPs into 1- or 2-element arrays:
+// a 1-element array index is forced to 0; a 2-element array's loads become two
+// fixed-index loads plus a select on (index != 0). It asserts all GEP users are
+// loads — TODO confirm callers guarantee that before the assert-backed casts run.
+ 1 : 3; +} + +/// CanonicalizeAllocaUsers - If SROA reported that it can promote the specified +/// allocation, but only if cleaned up, perform the cleanups required. +void SROA::CanonicalizeAllocaUsers(AllocationInst *AI) { + // At this point, we know that the end result will be SROA'd and promoted, so + // we can insert ugly code if required so long as sroa+mem2reg will clean it + // up. + for (Value::use_iterator UI = AI->use_begin(), E = AI->use_end(); + UI != E; ) { + GetElementPtrInst *GEPI = dyn_cast(*UI++); + if (!GEPI) continue; + gep_type_iterator I = gep_type_begin(GEPI); + ++I; + + if (const ArrayType *AT = dyn_cast(*I)) { + uint64_t NumElements = AT->getNumElements(); + + if (!isa(I.getOperand())) { + if (NumElements == 1) { + GEPI->setOperand(2, Constant::getNullValue(Type::Int32Ty)); + } else { + assert(NumElements == 2 && "Unhandled case!"); + // All users of the GEP must be loads. At each use of the GEP, insert + // two loads of the appropriate indexed GEP and select between them. + Value *IsOne = new ICmpInst(ICmpInst::ICMP_NE, I.getOperand(), + Constant::getNullValue(I.getOperand()->getType()), + "isone", GEPI); + // Insert the new GEP instructions, which are properly indexed. + SmallVector Indices(GEPI->op_begin()+1, GEPI->op_end()); + Indices[1] = Constant::getNullValue(Type::Int32Ty); + Value *ZeroIdx = new GetElementPtrInst(GEPI->getOperand(0), + &Indices[0], Indices.size(), + GEPI->getName()+".0", GEPI); + Indices[1] = ConstantInt::get(Type::Int32Ty, 1); + Value *OneIdx = new GetElementPtrInst(GEPI->getOperand(0), + &Indices[0], Indices.size(), + GEPI->getName()+".1", GEPI); + // Replace all loads of the variable index GEP with loads from both + // indexes and a select.
+// NOTE(review): MergeInType below is the type-unification lattice for converting a
+// union-like alloca into one scalar/vector: integers widen to the larger width,
+// pointers stay pointers, vector-vs-element and same-width vector-vs-vector are
+// accepted, and mixed ptr/float/int pairs are both lowered to integers and re-merged.
+// Returns true on incompatibility (caller must give up).
+ while (!GEPI->use_empty()) { + LoadInst *LI = cast(GEPI->use_back()); + Value *Zero = new LoadInst(ZeroIdx, LI->getName()+".0", LI); + Value *One = new LoadInst(OneIdx , LI->getName()+".1", LI); + Value *R = new SelectInst(IsOne, One, Zero, LI->getName(), LI); + LI->replaceAllUsesWith(R); + LI->eraseFromParent(); + } + GEPI->eraseFromParent(); + } + } + } + } +} + +/// MergeInType - Add the 'In' type to the accumulated type so far. If the +/// types are incompatible, return true, otherwise update Accum and return +/// false. +/// +/// There are three cases we handle here: +/// 1) An effectively-integer union, where the pieces are stored into as +/// smaller integers (common with byte swap and other idioms). +/// 2) A union of vector types of the same size and potentially its elements. +/// Here we turn element accesses into insert/extract element operations. +/// 3) A union of scalar types, such as int/float or int/pointer. Here we +/// merge together into integers, allowing the xform to work with #1 as +/// well. +static bool MergeInType(const Type *In, const Type *&Accum, + const TargetData &TD) { + // If this is our first type, just use it. + const VectorType *PTy; + if (Accum == Type::VoidTy || In == Accum) { + Accum = In; + } else if (In == Type::VoidTy) { + // Noop. + } else if (In->isInteger() && Accum->isInteger()) { // integer union. + // Otherwise pick whichever type is larger. + if (cast(In)->getBitWidth() > + cast(Accum)->getBitWidth()) + Accum = In; + } else if (isa(In) && isa(Accum)) { + // Pointer unions just stay as one of the pointers. + } else if (isa(In) || isa(Accum)) { + if ((PTy = dyn_cast(Accum)) && + PTy->getElementType() == In) { + // Accum is a vector, and we are accessing an element: ok. + } else if ((PTy = dyn_cast(In)) && + PTy->getElementType() == Accum) { + // In is a vector, and accum is an element: ok, remember In.
+// NOTE(review): CanConvertToScalar walks all uses of a pointer (loads, stores,
+// bitcasts, and constant-index GEPs) and MergeInType's them into one candidate
+// scalar/vector type; it returns VoidTy for "no uses", null for "cannot convert",
+// and sets IsNotTrivial when the rewrite needs more than plain mem2reg.
+// getUIntAtLeastAsBitAs rounds a bit count up to i8/i16/i32/i64 (null above 64).
+ Accum = In; + } else if ((PTy = dyn_cast(In)) && isa(Accum) && + PTy->getBitWidth() == cast(Accum)->getBitWidth()) { + // Two vectors of the same size: keep Accum. + } else { + // Cannot insert an short into a <4 x int> or handle + // <2 x int> -> <4 x int> + return true; + } + } else { + // Pointer/FP/Integer unions merge together as integers. + switch (Accum->getTypeID()) { + case Type::PointerTyID: Accum = TD.getIntPtrType(); break; + case Type::FloatTyID: Accum = Type::Int32Ty; break; + case Type::DoubleTyID: Accum = Type::Int64Ty; break; + default: + assert(Accum->isInteger() && "Unknown FP type!"); + break; + } + + switch (In->getTypeID()) { + case Type::PointerTyID: In = TD.getIntPtrType(); break; + case Type::FloatTyID: In = Type::Int32Ty; break; + case Type::DoubleTyID: In = Type::Int64Ty; break; + default: + assert(In->isInteger() && "Unknown FP type!"); + break; + } + return MergeInType(In, Accum, TD); + } + return false; +} + +/// getUIntAtLeastAsBitAs - Return an unsigned integer type that is at least +/// as big as the specified type. If there is no suitable type, this returns +/// null. +const Type *getUIntAtLeastAsBitAs(unsigned NumBits) { + if (NumBits > 64) return 0; + if (NumBits > 32) return Type::Int64Ty; + if (NumBits > 16) return Type::Int32Ty; + if (NumBits > 8) return Type::Int16Ty; + return Type::Int8Ty; +} + +/// CanConvertToScalar - V is a pointer. If we can convert the pointee to a +/// single scalar integer type, return that type. Further, if the use is not +/// a completely trivial use that mem2reg could promote, set IsNotTrivial. If +/// there are no uses of this pointer, return Type::VoidTy to differentiate from +/// failure. +/// +const Type *SROA::CanConvertToScalar(Value *V, bool &IsNotTrivial) { + const Type *UsedType = Type::VoidTy; // No uses, no forced type.
+// NOTE(review): the GEP cases here accept two shapes: 'GEP Ptr, C' (stepping over
+// whole elements; the resulting bit offset must fit in 64 bits and the element size
+// must be a power of two) and 'GEP Ptr, 0, C' (stepping into an array/vector/struct
+// element, with range checks for arrays and vectors). Everything else bails with 0.
+ const TargetData &TD = getAnalysis(); + const PointerType *PTy = cast(V->getType()); + + for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI!=E; ++UI) { + Instruction *User = cast(*UI); + + if (LoadInst *LI = dyn_cast(User)) { + if (MergeInType(LI->getType(), UsedType, TD)) + return 0; + + } else if (StoreInst *SI = dyn_cast(User)) { + // Storing the pointer, not into the value? + if (SI->getOperand(0) == V) return 0; + + // NOTE: We could handle storing of FP imms into integers here! + + if (MergeInType(SI->getOperand(0)->getType(), UsedType, TD)) + return 0; + } else if (BitCastInst *CI = dyn_cast(User)) { + IsNotTrivial = true; + const Type *SubTy = CanConvertToScalar(CI, IsNotTrivial); + if (!SubTy || MergeInType(SubTy, UsedType, TD)) return 0; + } else if (GetElementPtrInst *GEP = dyn_cast(User)) { + // Check to see if this is stepping over an element: GEP Ptr, int C + if (GEP->getNumOperands() == 2 && isa(GEP->getOperand(1))) { + unsigned Idx = cast(GEP->getOperand(1))->getZExtValue(); + unsigned ElSize = TD.getTypeSize(PTy->getElementType()); + unsigned BitOffset = Idx*ElSize*8; + if (BitOffset > 64 || !isPowerOf2_32(ElSize)) return 0; + + IsNotTrivial = true; + const Type *SubElt = CanConvertToScalar(GEP, IsNotTrivial); + if (SubElt == 0) return 0; + if (SubElt != Type::VoidTy && SubElt->isInteger()) { + const Type *NewTy = + getUIntAtLeastAsBitAs(TD.getTypeSize(SubElt)*8+BitOffset); + if (NewTy == 0 || MergeInType(NewTy, UsedType, TD)) return 0; + continue; + } + } else if (GEP->getNumOperands() == 3 && + isa(GEP->getOperand(1)) && + isa(GEP->getOperand(2)) && + cast(GEP->getOperand(1))->isZero()) { + // We are stepping into an element, e.g. a structure or an array: + // GEP Ptr, int 0, uint C + const Type *AggTy = PTy->getElementType(); + unsigned Idx = cast(GEP->getOperand(2))->getZExtValue(); + + if (const ArrayType *ATy = dyn_cast(AggTy)) { + if (Idx >= ATy->getNumElements()) return 0; // Out of range.
+// NOTE(review): ConvertToScalar (further below) removes the old alloca from the
+// entry block, inserts a new alloca of the unified type at the block start, rewrites
+// all uses via ConvertUsesToScalar, then deletes the detached old instruction.
+ } else if (const VectorType *VectorTy = dyn_cast(AggTy)) { + // Getting an element of the vector. + if (Idx >= VectorTy->getNumElements()) return 0; // Out of range. + + // Merge in the vector type. + if (MergeInType(VectorTy, UsedType, TD)) return 0; + + const Type *SubTy = CanConvertToScalar(GEP, IsNotTrivial); + if (SubTy == 0) return 0; + + if (SubTy != Type::VoidTy && MergeInType(SubTy, UsedType, TD)) + return 0; + + // We'll need to change this to an insert/extract element operation. + IsNotTrivial = true; + continue; // Everything looks ok + + } else if (isa(AggTy)) { + // Structs are always ok. + } else { + return 0; + } + const Type *NTy = getUIntAtLeastAsBitAs(TD.getTypeSize(AggTy)*8); + if (NTy == 0 || MergeInType(NTy, UsedType, TD)) return 0; + const Type *SubTy = CanConvertToScalar(GEP, IsNotTrivial); + if (SubTy == 0) return 0; + if (SubTy != Type::VoidTy && MergeInType(SubTy, UsedType, TD)) + return 0; + continue; // Everything looks ok + } + return 0; + } else { + // Cannot handle this! + return 0; + } + } + + return UsedType; +} + +/// ConvertToScalar - The specified alloca passes the CanConvertToScalar +/// predicate and is non-trivial. Convert it to something that can be trivially +/// promoted into a register by mem2reg. +void SROA::ConvertToScalar(AllocationInst *AI, const Type *ActualTy) { + DOUT << "CONVERT TO SCALAR: " << *AI << " TYPE = " + << *ActualTy << "\n"; + ++NumConverted; + + BasicBlock *EntryBlock = AI->getParent(); + assert(EntryBlock == &EntryBlock->getParent()->getEntryBlock() && + "Not in the entry block!"); + EntryBlock->getInstList().remove(AI); // Take the alloca out of the program. + + // Create and insert the alloca. + AllocaInst *NewAI = new AllocaInst(ActualTy, 0, AI->getName(), + EntryBlock->begin()); + ConvertUsesToScalar(AI, NewAI, 0); + delete AI; +} + + +/// ConvertUsesToScalar - Convert all of the users of Ptr to use the new alloca +/// directly.
+// NOTE(review): ConvertUsesToScalar — loads become extractelement (vector case) or
+// lshr/shl + trunc + bitcast/inttoptr (integer case); the shift amount depends on
+// TD.isBigEndian(), and negative shift amounts are deliberately supported for
+// accesses off the end of a structure (see the inline comments).
+ This happens when we are converting an "integer union" to a +/// single integer scalar, or when we are converting a "vector union" to a +/// vector with insert/extractelement instructions. +/// +/// Offset is an offset from the original alloca, in bits that need to be +/// shifted to the right. By the end of this, there should be no uses of Ptr. +void SROA::ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, unsigned Offset) { + const TargetData &TD = getAnalysis(); + while (!Ptr->use_empty()) { + Instruction *User = cast(Ptr->use_back()); + + if (LoadInst *LI = dyn_cast(User)) { + // The load is a bit extract from NewAI shifted right by Offset bits. + Value *NV = new LoadInst(NewAI, LI->getName(), LI); + if (NV->getType() == LI->getType()) { + // We win, no conversion needed. + } else if (const VectorType *PTy = dyn_cast(NV->getType())) { + // If the result alloca is a vector type, this is either an element + // access or a bitcast to another vector type. + if (isa(LI->getType())) { + NV = new BitCastInst(NV, LI->getType(), LI->getName(), LI); + } else { + // Must be an element access. + unsigned Elt = Offset/(TD.getTypeSize(PTy->getElementType())*8); + NV = new ExtractElementInst( + NV, ConstantInt::get(Type::Int32Ty, Elt), "tmp", LI); + } + } else if (isa(NV->getType())) { + assert(isa(LI->getType())); + // Must be ptr->ptr cast. Anything else would result in NV being + // an integer. + NV = new BitCastInst(NV, LI->getType(), LI->getName(), LI); + } else { + const IntegerType *NTy = cast(NV->getType()); + unsigned LIBitWidth = TD.getTypeSizeInBits(LI->getType()); + + // If this is a big-endian system and the load is narrower than the + // full alloca type, we need to do a shift to get the right bits. + int ShAmt = 0; + if (TD.isBigEndian()) { + ShAmt = NTy->getBitWidth()-LIBitWidth-Offset; + } else { + ShAmt = Offset; + } + + // Note: we support negative bitwidths (with shl) which are not defined. + // We do this to support (f.e.)
+// NOTE(review): store path below mirrors the load path — read-modify-write of the
+// whole scalar: load old value, widen/convert SV, shift into position, mask out the
+// target bits with ~Mask, then or the new bits in and store back.
+ loads off the end of a structure where + // only some bits are used. + if (ShAmt > 0 && (unsigned)ShAmt < NTy->getBitWidth()) + NV = BinaryOperator::createLShr(NV, + ConstantInt::get(NV->getType(),ShAmt), + LI->getName(), LI); + else if (ShAmt < 0 && (unsigned)-ShAmt < NTy->getBitWidth()) + NV = BinaryOperator::createShl(NV, + ConstantInt::get(NV->getType(),-ShAmt), + LI->getName(), LI); + + // Finally, unconditionally truncate the integer to the right width. + if (LIBitWidth < NTy->getBitWidth()) + NV = new TruncInst(NV, IntegerType::get(LIBitWidth), + LI->getName(), LI); + + // If the result is an integer, this is a trunc or bitcast. + if (isa(LI->getType())) { + assert(NV->getType() == LI->getType() && "Truncate wasn't enough?"); + } else if (LI->getType()->isFloatingPoint()) { + // Just do a bitcast, we know the sizes match up. + NV = new BitCastInst(NV, LI->getType(), LI->getName(), LI); + } else { + // Otherwise must be a pointer. + NV = new IntToPtrInst(NV, LI->getType(), LI->getName(), LI); + } + } + LI->replaceAllUsesWith(NV); + LI->eraseFromParent(); + } else if (StoreInst *SI = dyn_cast(User)) { + assert(SI->getOperand(0) != Ptr && "Consistency error!"); + + // Convert the stored type to the actual type, shift it left to insert + // then 'or' into place. + Value *SV = SI->getOperand(0); + const Type *AllocaType = NewAI->getType()->getElementType(); + if (SV->getType() == AllocaType) { + // All is well. + } else if (const VectorType *PTy = dyn_cast(AllocaType)) { + Value *Old = new LoadInst(NewAI, NewAI->getName()+".in", SI); + + // If the result alloca is a vector type, this is either an element + // access or a bitcast to another vector type. + if (isa(SV->getType())) { + SV = new BitCastInst(SV, AllocaType, SV->getName(), SI); + } else { + // Must be an element insertion.
+// NOTE(review): possible latent issue to confirm against upstream — in the
+// negative-shift branch below the mask is updated with 'Mask.lshr(ShAmt)' while
+// the value uses '-ShAmt'; verify whether the mask should also shift by -ShAmt.
+ unsigned Elt = Offset/(TD.getTypeSize(PTy->getElementType())*8); + SV = new InsertElementInst(Old, SV, + ConstantInt::get(Type::Int32Ty, Elt), + "tmp", SI); + } + } else if (isa(AllocaType)) { + // If the alloca type is a pointer, then all the elements must be + // pointers. + if (SV->getType() != AllocaType) + SV = new BitCastInst(SV, AllocaType, SV->getName(), SI); + } else { + Value *Old = new LoadInst(NewAI, NewAI->getName()+".in", SI); + + // If SV is a float, convert it to the appropriate integer type. + // If it is a pointer, do the same, and also handle ptr->ptr casts + // here. + unsigned SrcWidth = TD.getTypeSizeInBits(SV->getType()); + unsigned DestWidth = AllocaType->getPrimitiveSizeInBits(); + if (SV->getType()->isFloatingPoint()) + SV = new BitCastInst(SV, IntegerType::get(SrcWidth), + SV->getName(), SI); + else if (isa(SV->getType())) + SV = new PtrToIntInst(SV, TD.getIntPtrType(), SV->getName(), SI); + + // Always zero extend the value if needed. + if (SV->getType() != AllocaType) + SV = new ZExtInst(SV, AllocaType, SV->getName(), SI); + + // If this is a big-endian system and the store is narrower than the + // full alloca type, we need to do a shift to get the right bits. + int ShAmt = 0; + if (TD.isBigEndian()) { + ShAmt = DestWidth-SrcWidth-Offset; + } else { + ShAmt = Offset; + } + + // Note: we support negative bitwidths (with shr) which are not defined. + // We do this to support (f.e.) stores off the end of a structure where + // only some bits in the structure are set.
+// NOTE(review): GEP users recompute the bit offset (NewOffset) from either a flat
+// 'GEP p, C' (whole-aggregate strides) or 'GEP p, 0, C' (sequential element size or
+// StructLayout field offset) and recurse; any other shape aborts.
+ APInt Mask(APInt::getLowBitsSet(DestWidth, SrcWidth)); + if (ShAmt > 0 && (unsigned)ShAmt < DestWidth) { + SV = BinaryOperator::createShl(SV, + ConstantInt::get(SV->getType(), ShAmt), + SV->getName(), SI); + Mask <<= ShAmt; + } else if (ShAmt < 0 && (unsigned)-ShAmt < DestWidth) { + SV = BinaryOperator::createLShr(SV, + ConstantInt::get(SV->getType(),-ShAmt), + SV->getName(), SI); + Mask = Mask.lshr(ShAmt); + } + + // Mask out the bits we are about to insert from the old value, and or + // in the new bits. + if (SrcWidth != DestWidth) { + assert(DestWidth > SrcWidth); + Old = BinaryOperator::createAnd(Old, ConstantInt::get(~Mask), + Old->getName()+".mask", SI); + SV = BinaryOperator::createOr(Old, SV, SV->getName()+".ins", SI); + } + } + new StoreInst(SV, NewAI, SI); + SI->eraseFromParent(); + + } else if (BitCastInst *CI = dyn_cast(User)) { + ConvertUsesToScalar(CI, NewAI, Offset); + CI->eraseFromParent(); + } else if (GetElementPtrInst *GEP = dyn_cast(User)) { + const PointerType *AggPtrTy = + cast(GEP->getOperand(0)->getType()); + const TargetData &TD = getAnalysis(); + unsigned AggSizeInBits = TD.getTypeSize(AggPtrTy->getElementType())*8; + + // Check to see if this is stepping over an element: GEP Ptr, int C + unsigned NewOffset = Offset; + if (GEP->getNumOperands() == 2) { + unsigned Idx = cast(GEP->getOperand(1))->getZExtValue(); + unsigned BitOffset = Idx*AggSizeInBits; + + NewOffset += BitOffset; + } else if (GEP->getNumOperands() == 3) { + // We know that operand #2 is zero.
+// NOTE(review): PointsToConstantGlobal / isOnlyCopiedFromConstantGlobal — the
+// static walker tolerates loads and zero-offset bitcast/GEP chains, accepts at most
+// one memcpy/memmove whose dest operand is the (unoffset) alloca and whose source
+// points to a constant global; any store or unknown use rejects the alloca.
+ unsigned Idx = cast(GEP->getOperand(2))->getZExtValue(); + const Type *AggTy = AggPtrTy->getElementType(); + if (const SequentialType *SeqTy = dyn_cast(AggTy)) { + unsigned ElSizeBits = TD.getTypeSize(SeqTy->getElementType())*8; + + NewOffset += ElSizeBits*Idx; + } else if (const StructType *STy = dyn_cast(AggTy)) { + unsigned EltBitOffset = + TD.getStructLayout(STy)->getElementOffset(Idx)*8; + + NewOffset += EltBitOffset; + } else { + assert(0 && "Unsupported operation!"); + abort(); + } + } else { + assert(0 && "Unsupported operation!"); + abort(); + } + ConvertUsesToScalar(GEP, NewAI, NewOffset); + GEP->eraseFromParent(); + } else { + assert(0 && "Unsupported operation!"); + abort(); + } + } +} + + +/// PointsToConstantGlobal - Return true if V (possibly indirectly) points to +/// some part of a constant global variable. This intentionally only accepts +/// constant expressions because we don't can't rewrite arbitrary instructions. +static bool PointsToConstantGlobal(Value *V) { + if (GlobalVariable *GV = dyn_cast(V)) + return GV->isConstant(); + if (ConstantExpr *CE = dyn_cast(V)) + if (CE->getOpcode() == Instruction::BitCast || + CE->getOpcode() == Instruction::GetElementPtr) + return PointsToConstantGlobal(CE->getOperand(0)); + return false; +} + +/// isOnlyCopiedFromConstantGlobal - Recursively walk the uses of a (derived) +/// pointer to an alloca. Ignore any reads of the pointer, return false if we +/// see any stores or other unknown uses. If we see pointer arithmetic, keep +/// track of whether it moves the pointer (with isOffset) but otherwise traverse +/// the uses. If we see a memcpy/memmove that targets an unoffseted pointer to +/// the alloca, and if the source pointer is a pointer to a constant global, we +/// can optimize this.
+// NOTE(review): UI.getOperandNo() != 1 checks that the alloca-derived pointer is the
+// intrinsic's dest operand (operand 1 in this API era); operand 2 is the source.
+static bool isOnlyCopiedFromConstantGlobal(Value *V, Instruction *&TheCopy, + bool isOffset) { + for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI!=E; ++UI) { + if (isa(*UI)) { + // Ignore loads, they are always ok. + continue; + } + if (BitCastInst *BCI = dyn_cast(*UI)) { + // If uses of the bitcast are ok, we are ok. + if (!isOnlyCopiedFromConstantGlobal(BCI, TheCopy, isOffset)) + return false; + continue; + } + if (GetElementPtrInst *GEP = dyn_cast(*UI)) { + // If the GEP has all zero indices, it doesn't offset the pointer. If it + // doesn't, it does. + if (!isOnlyCopiedFromConstantGlobal(GEP, TheCopy, + isOffset || !GEP->hasAllZeroIndices())) + return false; + continue; + } + + // If this is isn't our memcpy/memmove, reject it as something we can't + // handle. + if (!isa(*UI) && !isa(*UI)) + return false; + + // If we already have seen a copy, reject the second one. + if (TheCopy) return false; + + // If the pointer has been offset from the start of the alloca, we can't + // safely handle this. + if (isOffset) return false; + + // If the memintrinsic isn't using the alloca as the dest, reject it. + if (UI.getOperandNo() != 1) return false; + + MemIntrinsic *MI = cast(*UI); + + // If the source of the memcpy/move is not a constant global, reject it. + if (!PointsToConstantGlobal(MI->getOperand(2))) + return false; + + // Otherwise, the transform is safe. Remember the copy instruction. + TheCopy = MI; + } + return true; +} + +/// isOnlyCopiedFromConstantGlobal - Return true if the specified alloca is only +/// modified by a copy from a constant global. If we can prove this, we can +/// replace any uses of the alloca with uses of the global directly.
+// NOTE(review): thin member wrapper over the static walker above; returns the single
+// copy instruction or null. The remainder of this span is the unified-diff header
+// opening the new file lib/Transforms/Scalar/SimplifyCFG.cpp plus its license/header
+// comment — kept verbatim; it is diff metadata, not code.
+Instruction *SROA::isOnlyCopiedFromConstantGlobal(AllocationInst *AI) { + Instruction *TheCopy = 0; + if (::isOnlyCopiedFromConstantGlobal(AI, TheCopy, false)) + return TheCopy; + return 0; +} diff --git a/lib/Transforms/Scalar/SimplifyCFG.cpp b/lib/Transforms/Scalar/SimplifyCFG.cpp new file mode 100644 index 0000000..6b47ef7 --- /dev/null +++ b/lib/Transforms/Scalar/SimplifyCFG.cpp @@ -0,0 +1,145 @@ +//===- SimplifyCFG.cpp - CFG Simplification Pass --------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements dead code elimination and basic block merging. +// +// Specifically, this: +// * removes basic blocks with no predecessors +// * merges a basic block into its predecessor if there is only one and the +// predecessor only has one successor.
+// NOTE(review): same extraction damage as above — collapsed lines and stripped
+// template arguments ('RegisterPass X(...)' lost '<CFGSimplifyPass>',
+// 'SmallPtrSet &Reachable' and 'std::vector Worklist' lost their parameters).
+// Code lines below are kept verbatim; only these review comments are added.
+// * Eliminates PHI nodes for basic blocks with a single predecessor +// * Eliminates a basic block that only contains an unconditional branch +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "simplifycfg" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Pass.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +using namespace llvm; + +STATISTIC(NumSimpl, "Number of blocks simplified"); + +namespace { + struct VISIBILITY_HIDDEN CFGSimplifyPass : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + CFGSimplifyPass() : FunctionPass((intptr_t)&ID) {} + + virtual bool runOnFunction(Function &F); + }; + char CFGSimplifyPass::ID = 0; + RegisterPass X("simplifycfg", "Simplify the CFG"); +} + +// Public interface to the CFGSimplification pass +FunctionPass *llvm::createCFGSimplificationPass() { + return new CFGSimplifyPass(); +} + +static bool MarkAliveBlocks(BasicBlock *BB, + SmallPtrSet &Reachable) { + + std::vector Worklist; + Worklist.push_back(BB); + bool Changed = false; + while (!Worklist.empty()) { + BB = Worklist.back(); + Worklist.pop_back(); + + if (!Reachable.insert(BB)) + continue; + + // Do a quick scan of the basic block, turning any obviously unreachable + // instructions into LLVM unreachable insts. The instruction combining pass + // canonnicalizes unreachable insts into stores to null or undef. + for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E; ++BBI) + if (StoreInst *SI = dyn_cast(BBI)) + if (isa(SI->getOperand(1)) || + isa(SI->getOperand(1))) { + // Loop over all of the successors, removing BB's entry from any PHI + // nodes.
+// NOTE(review): MarkAliveBlocks is a worklist DFS from the entry block. A store to
+// null/undef marks the rest of the block dead: successors' PHIs drop BB, an
+// UnreachableInst is inserted before the store, and all following instructions are
+// RAUW'd with undef and erased. Terminators are constant-folded along the way.
+ for (succ_iterator I = succ_begin(BB), SE = succ_end(BB); I != SE;++I) + (*I)->removePredecessor(BB); + + new UnreachableInst(SI); + + // All instructions after this are dead. + while (BBI != E) { + if (!BBI->use_empty()) + BBI->replaceAllUsesWith(UndefValue::get(BBI->getType())); + BB->getInstList().erase(BBI++); + } + break; + } + + + Changed |= ConstantFoldTerminator(BB); + for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) + Worklist.push_back(*SI); + } + return Changed; +} + + +// It is possible that we may require multiple passes over the code to fully +// simplify the CFG. +// +bool CFGSimplifyPass::runOnFunction(Function &F) { + SmallPtrSet Reachable; + bool Changed = MarkAliveBlocks(F.begin(), Reachable); + + // If there are unreachable blocks in the CFG... + if (Reachable.size() != F.size()) { + assert(Reachable.size() < F.size()); + NumSimpl += F.size()-Reachable.size(); + + // Loop over all of the basic blocks that are not reachable, dropping all of + // their internal references... + for (Function::iterator BB = ++F.begin(), E = F.end(); BB != E; ++BB) + if (!Reachable.count(BB)) { + for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI!=SE; ++SI) + if (Reachable.count(*SI)) + (*SI)->removePredecessor(BB); + BB->dropAllReferences(); + } + + for (Function::iterator I = ++F.begin(); I != F.end();) + if (!Reachable.count(I)) + I = F.getBasicBlockList().erase(I); + else + ++I; + + Changed = true; + } + + bool LocalChange = true; + while (LocalChange) { + LocalChange = false; + + // Loop over all of the basic blocks (except the first one) and remove them + // if they are unneeded...
+// NOTE(review): runOnFunction deletes unreachable blocks in two phases (drop
+// references first, erase second — required because dead blocks may reference each
+// other), then iterates SimplifyCFG over the remaining blocks to a fixed point.
+// The tail of this span is the diff header opening TailDuplication.cpp plus that
+// file's license/header comment — kept verbatim; it is diff metadata, not code.
+ // + for (Function::iterator BBIt = ++F.begin(); BBIt != F.end(); ) { + if (SimplifyCFG(BBIt++)) { + LocalChange = true; + ++NumSimpl; + } + } + Changed |= LocalChange; + } + + return Changed; +} diff --git a/lib/Transforms/Scalar/TailDuplication.cpp b/lib/Transforms/Scalar/TailDuplication.cpp new file mode 100644 index 0000000..22d8157 --- /dev/null +++ b/lib/Transforms/Scalar/TailDuplication.cpp @@ -0,0 +1,364 @@ +//===- TailDuplication.cpp - Simplify CFG through tail duplication --------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs a limited form of tail duplication, intended to simplify +// CFGs by removing some unconditional branches. This pass is necessary to +// straighten out loops created by the C front-end, but also is capable of +// making other code nicer. After this pass is run, the CFG simplify pass +// should be run to clean up the mess. +// +// This pass could be enhanced in the future to use profile information to be +// more aggressive.
+// NOTE(review): extraction damage continues here — 'cl::opt Threshold' lost
+// '<unsigned>', 'RegisterPass X(...)' lost '<TailDup>'. Code lines kept verbatim.
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "tailduplicate" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constant.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Pass.h" +#include "llvm/Type.h" +#include "llvm/Support/CFG.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/Statistic.h" +using namespace llvm; + +STATISTIC(NumEliminated, "Number of unconditional branches eliminated"); + +namespace { + cl::opt + Threshold("taildup-threshold", cl::desc("Max block size to tail duplicate"), + cl::init(6), cl::Hidden); + class VISIBILITY_HIDDEN TailDup : public FunctionPass { + bool runOnFunction(Function &F); + public: + static char ID; // Pass identification, replacement for typeid + TailDup() : FunctionPass((intptr_t)&ID) {} + + private: + inline bool shouldEliminateUnconditionalBranch(TerminatorInst *TI); + inline void eliminateUnconditionalBranch(BranchInst *BI); + }; + char TailDup::ID = 0; + RegisterPass X("tailduplicate", "Tail Duplication"); +} + +// Public interface to the Tail Duplication pass +FunctionPass *llvm::createTailDuplicationPass() { return new TailDup(); } + +/// runOnFunction - Top level algorithm - Loop over each unconditional branch in +/// the function, eliminating it if it looks attractive enough. +/// +bool TailDup::runOnFunction(Function &F) { + bool Changed = false; + for (Function::iterator I = F.begin(), E = F.end(); I != E; ) + if (shouldEliminateUnconditionalBranch(I->getTerminator())) { + eliminateUnconditionalBranch(cast(I->getTerminator())); + Changed = true; + } else { + ++I; + } + return Changed; +} + +/// shouldEliminateUnconditionalBranch - Return true if this branch looks +/// attractive to eliminate.
+// NOTE(review): the heuristic only counts non-debug-intrinsic instructions against
+// the 'taildup-threshold' (default 6) and skips PHIs, which disappear when the
+// block is cloned. It rejects: self-loops, branch-to-same-block chains, terminators
+// with uses (invoke demotion limitation, per the FIXME), dead blocks, and
+// single-predecessor blocks (simplifycfg merges those anyway).
+ We eliminate the branch if the destination basic +/// block has <= 5 instructions in it, not counting PHI nodes. In practice, +/// since one of these is a terminator instruction, this means that we will add +/// up to 4 instructions to the new block. +/// +/// We don't count PHI nodes in the count since they will be removed when the +/// contents of the block are copied over. +/// +bool TailDup::shouldEliminateUnconditionalBranch(TerminatorInst *TI) { + BranchInst *BI = dyn_cast(TI); + if (!BI || !BI->isUnconditional()) return false; // Not an uncond branch! + + BasicBlock *Dest = BI->getSuccessor(0); + if (Dest == BI->getParent()) return false; // Do not loop infinitely! + + // Do not inline a block if we will just get another branch to the same block! + TerminatorInst *DTI = Dest->getTerminator(); + if (BranchInst *DBI = dyn_cast(DTI)) + if (DBI->isUnconditional() && DBI->getSuccessor(0) == Dest) + return false; // Do not loop infinitely! + + // FIXME: DemoteRegToStack cannot yet demote invoke instructions to the stack, + // because doing so would require breaking critical edges. This should be + // fixed eventually. + if (!DTI->use_empty()) + return false; + + // Do not bother working on dead blocks... + pred_iterator PI = pred_begin(Dest), PE = pred_end(Dest); + if (PI == PE && Dest != Dest->getParent()->begin()) + return false; // It's just a dead block, ignore it... + + // Also, do not bother with blocks with only a single predecessor: simplify + // CFG will fold these two blocks together! + ++PI; + if (PI == PE) return false; // Exactly one predecessor! + + BasicBlock::iterator I = Dest->begin(); + while (isa(*I)) ++I; + + for (unsigned Size = 0; I != Dest->end(); ++I) { + if (Size == Threshold) return false; // The block is too large. + // Only count instructions that are not debugger intrinsics.
+// NOTE(review): the successor-count guard caps fan-out (NumSuccs in (8,128)) scaled
+// by remaining predecessors, to avoid the N^2 CFG-edge blowup described inline;
+// the fall-through check refuses to duplicate a block whose clone would destroy a
+// fall-through edge while the original block survives.
+ if (!isa(I)) ++Size; + } + + // Do not tail duplicate a block that has thousands of successors into a block + // with a single successor if the block has many other predecessors. This can + // cause an N^2 explosion in CFG edges (and PHI node entries), as seen in + // cases that have a large number of indirect gotos. + unsigned NumSuccs = DTI->getNumSuccessors(); + if (NumSuccs > 8) { + unsigned TooMany = 128; + if (NumSuccs >= TooMany) return false; + TooMany = TooMany/NumSuccs; + for (; PI != PE; ++PI) + if (TooMany-- == 0) return false; + } + + // Finally, if this unconditional branch is a fall-through, be careful about + // tail duplicating it. In particular, we don't want to taildup it if the + // original block will still be there after taildup is completed: doing so + // would eliminate the fall-through, requiring unconditional branches. + Function::iterator DestI = Dest; + if (&*--DestI == BI->getParent()) { + // The uncond branch is a fall-through. Tail duplication of the block is + // will eliminate the fall-through-ness and end up cloning the terminator + // at the end of the Dest block. Since the original Dest block will + // continue to exist, this means that one or the other will not be able to + // fall through. One typical example that this helps with is code like: + // if (a) + // foo(); + // if (b) + // foo(); + // Cloning the 'if b' block into the end of the first foo block is messy. + + // The messy case is when the fall-through block falls through to other + // blocks. This is what we would be preventing if we cloned the block. + DestI = Dest; + if (++DestI != Dest->getParent()->end()) { + BasicBlock *DestSucc = DestI; + // If any of Dest's successors are fall-throughs, don't do this xform.
+// NOTE(review): FindObviousSharedDomOf handles exactly the 'if-then' (other pred of
+// Dst IS Src's single pred) and 'if-then-else' (other pred's single pred is Src's
+// pred) diamonds, returning the shared dominator or null. The final doc comment
+// belongs to eliminateUnconditionalBranch, whose body is past the end of this
+// chunk — not documented further here.
+ for (succ_iterator SI = succ_begin(Dest), SE = succ_end(Dest); + SI != SE; ++SI) + if (*SI == DestSucc) + return false; + } + } + + return true; +} + +/// FindObviousSharedDomOf - We know there is a branch from SrcBlock to +/// DestBlock, and that SrcBlock is not the only predecessor of DstBlock. If we +/// can find a predecessor of SrcBlock that is a dominator of both SrcBlock and +/// DstBlock, return it. +static BasicBlock *FindObviousSharedDomOf(BasicBlock *SrcBlock, + BasicBlock *DstBlock) { + // SrcBlock must have a single predecessor. + pred_iterator PI = pred_begin(SrcBlock), PE = pred_end(SrcBlock); + if (PI == PE || ++PI != PE) return 0; + + BasicBlock *SrcPred = *pred_begin(SrcBlock); + + // Look at the predecessors of DstBlock. One of them will be SrcBlock. If + // there is only one other pred, get it, otherwise we can't handle it. + PI = pred_begin(DstBlock); PE = pred_end(DstBlock); + BasicBlock *DstOtherPred = 0; + if (*PI == SrcBlock) { + if (++PI == PE) return 0; + DstOtherPred = *PI; + if (++PI != PE) return 0; + } else { + DstOtherPred = *PI; + if (++PI == PE || *PI != SrcBlock || ++PI != PE) return 0; + } + + // We can handle two situations here: "if then" and "if then else" blocks. An + // 'if then' situation is just where DstOtherPred == SrcPred. + if (DstOtherPred == SrcPred) + return SrcPred; + + // Check to see if we have an "if then else" situation, which means that + // DstOtherPred will have a single predecessor and it will be SrcPred. + PI = pred_begin(DstOtherPred); PE = pred_end(DstOtherPred); + if (PI != PE && *PI == SrcPred) { + if (++PI != PE) return 0; // Not a single pred. + return SrcPred; // Otherwise, it's an "if then" situation. Return the if. + } + + // Otherwise, this is something we can't handle. + return 0; +} + + +/// eliminateUnconditionalBranch - Clone the instructions from the destination +/// block into the source block, eliminating the specified unconditional branch.
+/// If the destination block defines values used by successors of the dest +/// block, we may need to insert PHI nodes. +/// +void TailDup::eliminateUnconditionalBranch(BranchInst *Branch) { + BasicBlock *SourceBlock = Branch->getParent(); + BasicBlock *DestBlock = Branch->getSuccessor(0); + assert(SourceBlock != DestBlock && "Our predicate is broken!"); + + DOUT << "TailDuplication[" << SourceBlock->getParent()->getName() + << "]: Eliminating branch: " << *Branch; + + // See if we can avoid duplicating code by moving it up to a dominator of both + // blocks. + if (BasicBlock *DomBlock = FindObviousSharedDomOf(SourceBlock, DestBlock)) { + DOUT << "Found shared dominator: " << DomBlock->getName() << "\n"; + + // If there are non-phi instructions in DestBlock that have no operands + // defined in DestBlock, and if the instruction has no side effects, we can + // move the instruction to DomBlock instead of duplicating it. + BasicBlock::iterator BBI = DestBlock->begin(); + while (isa(BBI)) ++BBI; + while (!isa(BBI)) { + Instruction *I = BBI++; + + bool CanHoist = !I->isTrapping() && !I->mayWriteToMemory(); + if (CanHoist) { + for (unsigned op = 0, e = I->getNumOperands(); op != e; ++op) + if (Instruction *OpI = dyn_cast(I->getOperand(op))) + if (OpI->getParent() == DestBlock || + (isa(OpI) && OpI->getParent() == DomBlock)) { + CanHoist = false; + break; + } + if (CanHoist) { + // Remove from DestBlock, move right before the term in DomBlock. + DestBlock->getInstList().remove(I); + DomBlock->getInstList().insert(DomBlock->getTerminator(), I); + DOUT << "Hoisted: " << *I; + } + } + } + } + + // Tail duplication can not update SSA properties correctly if the values + // defined in the duplicated tail are used outside of the tail itself. For + // this reason, we spill all values that are used outside of the tail to the + // stack. 
+ for (BasicBlock::iterator I = DestBlock->begin(); I != DestBlock->end(); ++I) + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); UI != E; + ++UI) { + bool ShouldDemote = false; + if (cast(*UI)->getParent() != DestBlock) { + // We must allow our successors to use tail values in their PHI nodes + // (if the incoming value corresponds to the tail block). + if (PHINode *PN = dyn_cast(*UI)) { + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == I && + PN->getIncomingBlock(i) != DestBlock) { + ShouldDemote = true; + break; + } + + } else { + ShouldDemote = true; + } + } else if (PHINode *PN = dyn_cast(cast(*UI))) { + // If the user of this instruction is a PHI node in the current block, + // which has an entry from another block using the value, spill it. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == I && + PN->getIncomingBlock(i) != DestBlock) { + ShouldDemote = true; + break; + } + } + + if (ShouldDemote) { + // We found a use outside of the tail. Create a new stack slot to + // break this inter-block usage pattern. + DemoteRegToStack(*I); + break; + } + } + + // We are going to have to map operands from the original block B to the new + // copy of the block B'. If there are PHI nodes in the DestBlock, these PHI + // nodes also define part of this mapping. Loop over these PHI nodes, adding + // them to our mapping. + // + std::map ValueMapping; + + BasicBlock::iterator BI = DestBlock->begin(); + bool HadPHINodes = isa(BI); + for (; PHINode *PN = dyn_cast(BI); ++BI) + ValueMapping[PN] = PN->getIncomingValueForBlock(SourceBlock); + + // Clone the non-phi instructions of the dest block into the source block, + // keeping track of the mapping... 
+ // + for (; BI != DestBlock->end(); ++BI) { + Instruction *New = BI->clone(); + New->setName(BI->getName()); + SourceBlock->getInstList().push_back(New); + ValueMapping[BI] = New; + } + + // Now that we have built the mapping information and cloned all of the + // instructions (giving us a new terminator, among other things), walk the new + // instructions, rewriting references of old instructions to use new + // instructions. + // + BI = Branch; ++BI; // Get an iterator to the first new instruction + for (; BI != SourceBlock->end(); ++BI) + for (unsigned i = 0, e = BI->getNumOperands(); i != e; ++i) + if (Value *Remapped = ValueMapping[BI->getOperand(i)]) + BI->setOperand(i, Remapped); + + // Next we check to see if any of the successors of DestBlock had PHI nodes. + // If so, we need to add entries to the PHI nodes for SourceBlock now. + for (succ_iterator SI = succ_begin(DestBlock), SE = succ_end(DestBlock); + SI != SE; ++SI) { + BasicBlock *Succ = *SI; + for (BasicBlock::iterator PNI = Succ->begin(); isa(PNI); ++PNI) { + PHINode *PN = cast(PNI); + // Ok, we have a PHI node. Figure out what the incoming value was for the + // DestBlock. + Value *IV = PN->getIncomingValueForBlock(DestBlock); + + // Remap the value if necessary... + if (Value *MappedIV = ValueMapping[IV]) + IV = MappedIV; + PN->addIncoming(IV, SourceBlock); + } + } + + // Next, remove the old branch instruction, and any PHI node entries that we + // had. + BI = Branch; ++BI; // Get an iterator to the first new instruction + DestBlock->removePredecessor(SourceBlock); // Remove entries in PHI nodes... + SourceBlock->getInstList().erase(Branch); // Destroy the uncond branch... + + // Final step: now that we have finished everything up, walk the cloned + // instructions one last time, constant propagating and DCE'ing them, because + // they may not be needed anymore. 
+ // + if (HadPHINodes) + while (BI != SourceBlock->end()) + if (!dceInstruction(BI) && !doConstantPropagation(BI)) + ++BI; + + ++NumEliminated; // We just killed a branch! +} diff --git a/lib/Transforms/Scalar/TailRecursionElimination.cpp b/lib/Transforms/Scalar/TailRecursionElimination.cpp new file mode 100644 index 0000000..497b81f --- /dev/null +++ b/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -0,0 +1,462 @@ +//===- TailRecursionElimination.cpp - Eliminate Tail Calls ----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file transforms calls of the current function (self recursion) followed +// by a return instruction with a branch to the entry of the function, creating +// a loop. This pass also implements the following extensions to the basic +// algorithm: +// +// 1. Trivial instructions between the call and return do not prevent the +// transformation from taking place, though currently the analysis cannot +// support moving any really useful instructions (only dead ones). +// 2. This pass transforms functions that are prevented from being tail +// recursive by an associative expression to use an accumulator variable, +// thus compiling the typical naive factorial or 'fib' implementation into +// efficient code. +// 3. TRE is performed if the function returns void, if the return +// returns the result returned by the call, or if the function returns a +// run-time constant on all exits from the function. It is possible, though +// unlikely, that the return returns something else (like constant 0), and +// can still be TRE'd. It can be TRE'd if ALL OTHER return instructions in +// the function return the exact same value. +// 4. 
If it can prove that callees do not access their caller stack frame, +// they are marked as eligible for tail call elimination (by the code +// generator). +// +// There are several improvements that could be made: +// +// 1. If the function has any alloca instructions, these instructions will be +// moved out of the entry block of the function, causing them to be +// evaluated each time through the tail recursion. Safely keeping allocas +// in the entry block requires analysis to prove that the tail-called +// function does not read or write the stack object. +// 2. Tail recursion is only performed if the call immediately precedes the +// return instruction. It's possible that there could be a jump between +// the call and the return. +// 3. There can be intervening operations between the call and the return that +// prevent the TRE from occurring. For example, there could be GEP's and +// stores to memory that will not be read or written by the call. This +// requires some substantial analysis (such as with DSA) to prove safe to +// move ahead of the call, but doing so could allow many more TREs to be +// performed, for example in TreeAdd/TreeAlloc from the treeadd benchmark. +// 4. The algorithm we use to detect if callees access their caller stack +// frames is very primitive. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "tailcallelim" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/Support/CFG.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Compiler.h" +using namespace llvm; + +STATISTIC(NumEliminated, "Number of tail calls removed"); +STATISTIC(NumAccumAdded, "Number of accumulators introduced"); + +namespace { + struct VISIBILITY_HIDDEN TailCallElim : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + TailCallElim() : FunctionPass((intptr_t)&ID) {} + + virtual bool runOnFunction(Function &F); + + private: + bool ProcessReturningBlock(ReturnInst *RI, BasicBlock *&OldEntry, + bool &TailCallsAreMarkedTail, + std::vector &ArgumentPHIs, + bool CannotTailCallElimCallsMarkedTail); + bool CanMoveAboveCall(Instruction *I, CallInst *CI); + Value *CanTransformAccumulatorRecursion(Instruction *I, CallInst *CI); + }; + char TailCallElim::ID = 0; + RegisterPass X("tailcallelim", "Tail Call Elimination"); +} + +// Public interface to the TailCallElimination pass +FunctionPass *llvm::createTailCallEliminationPass() { + return new TailCallElim(); +} + + +/// AllocaMightEscapeToCalls - Return true if this alloca may be accessed by +/// callees of this function. We only do very simple analysis right now, this +/// could be expanded in the future to use mod/ref information for particular +/// call sites if desired. +static bool AllocaMightEscapeToCalls(AllocaInst *AI) { + // FIXME: do simple 'address taken' analysis. + return true; +} + +/// FunctionContainsAllocas - Scan the specified basic block for alloca +/// instructions. If it contains any that might be accessed by calls, return +/// true. 
+static bool CheckForEscapingAllocas(BasicBlock *BB, + bool &CannotTCETailMarkedCall) { + bool RetVal = false; + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) + if (AllocaInst *AI = dyn_cast(I)) { + RetVal |= AllocaMightEscapeToCalls(AI); + + // If this alloca is in the body of the function, or if it is a variable + // sized allocation, we cannot tail call eliminate calls marked 'tail' + // with this mechanism. + if (BB != &BB->getParent()->getEntryBlock() || + !isa(AI->getArraySize())) + CannotTCETailMarkedCall = true; + } + return RetVal; +} + +bool TailCallElim::runOnFunction(Function &F) { + // If this function is a varargs function, we won't be able to PHI the args + // right, so don't even try to convert it... + if (F.getFunctionType()->isVarArg()) return false; + + BasicBlock *OldEntry = 0; + bool TailCallsAreMarkedTail = false; + std::vector ArgumentPHIs; + bool MadeChange = false; + + bool FunctionContainsEscapingAllocas = false; + + // CannotTCETailMarkedCall - If true, we cannot perform TCE on tail calls + // marked with the 'tail' attribute, because doing so would cause the stack + // size to increase (real TCE would deallocate variable sized allocas, TCE + // doesn't). + bool CannotTCETailMarkedCall = false; + + // Loop over the function, looking for any returning blocks, and keeping track + // of whether this function has any non-trivially used allocas. + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { + if (FunctionContainsEscapingAllocas && CannotTCETailMarkedCall) + break; + + FunctionContainsEscapingAllocas |= + CheckForEscapingAllocas(BB, CannotTCETailMarkedCall); + } + + /// FIXME: The code generator produces really bad code when an 'escaping + /// alloca' is changed from being a static alloca to being a dynamic alloca. + /// Until this is resolved, disable this transformation if that would ever + /// happen. This bug is PR962. 
+ if (FunctionContainsEscapingAllocas) + return false; + + + // Second pass, change any tail calls to loops. + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + if (ReturnInst *Ret = dyn_cast(BB->getTerminator())) + MadeChange |= ProcessReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail, + ArgumentPHIs,CannotTCETailMarkedCall); + + // If we eliminated any tail recursions, it's possible that we inserted some + // silly PHI nodes which just merge an initial value (the incoming operand) + // with themselves. Check to see if we did and clean up our mess if so. This + // occurs when a function passes an argument straight through to its tail + // call. + if (!ArgumentPHIs.empty()) { + for (unsigned i = 0, e = ArgumentPHIs.size(); i != e; ++i) { + PHINode *PN = ArgumentPHIs[i]; + + // If the PHI Node is a dynamic constant, replace it with the value it is. + if (Value *PNV = PN->hasConstantValue()) { + PN->replaceAllUsesWith(PNV); + PN->eraseFromParent(); + } + } + } + + // Finally, if this function contains no non-escaping allocas, mark all calls + // in the function as eligible for tail calls (there is no stack memory for + // them to access). + if (!FunctionContainsEscapingAllocas) + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) + if (CallInst *CI = dyn_cast(I)) + CI->setTailCall(); + + return MadeChange; +} + + +/// CanMoveAboveCall - Return true if it is safe to move the specified +/// instruction from after the call to before the call, assuming that all +/// instructions between the call and this instruction are movable. +/// +bool TailCallElim::CanMoveAboveCall(Instruction *I, CallInst *CI) { + // FIXME: We can move load/store/call/free instructions above the call if the + // call does not mod/ref the memory location being processed. 
+ if (I->mayWriteToMemory() || isa(I)) + return false; + + // Otherwise, if this is a side-effect free instruction, check to make sure + // that it does not use the return value of the call. If it doesn't use the + // return value of the call, it must only use things that are defined before + // the call, or movable instructions between the call and the instruction + // itself. + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (I->getOperand(i) == CI) + return false; + return true; +} + +// isDynamicConstant - Return true if the specified value is the same when the +// return would exit as it was when the initial iteration of the recursive +// function was executed. +// +// We currently handle static constants and arguments that are not modified as +// part of the recursion. +// +static bool isDynamicConstant(Value *V, CallInst *CI) { + if (isa(V)) return true; // Static constants are always dyn consts + + // Check to see if this is an immutable argument, if so, the value + // will be available to initialize the accumulator. + if (Argument *Arg = dyn_cast(V)) { + // Figure out which argument number this is... + unsigned ArgNo = 0; + Function *F = CI->getParent()->getParent(); + for (Function::arg_iterator AI = F->arg_begin(); &*AI != Arg; ++AI) + ++ArgNo; + + // If we are passing this argument into call as the corresponding + // argument operand, then the argument is dynamically constant. + // Otherwise, we cannot transform this function safely. + if (CI->getOperand(ArgNo+1) == Arg) + return true; + } + // Not a constant or immutable argument, we can't safely transform. + return false; +} + +// getCommonReturnValue - Check to see if the function containing the specified +// return instruction and tail call consistently returns the same +// runtime-constant value at all exit points. If so, return the returned value. 
+// +static Value *getCommonReturnValue(ReturnInst *TheRI, CallInst *CI) { + Function *F = TheRI->getParent()->getParent(); + Value *ReturnedValue = 0; + + for (Function::iterator BBI = F->begin(), E = F->end(); BBI != E; ++BBI) + if (ReturnInst *RI = dyn_cast(BBI->getTerminator())) + if (RI != TheRI) { + Value *RetOp = RI->getOperand(0); + + // We can only perform this transformation if the value returned is + // evaluatable at the start of the initial invocation of the function, + // instead of at the end of the evaluation. + // + if (!isDynamicConstant(RetOp, CI)) + return 0; + + if (ReturnedValue && RetOp != ReturnedValue) + return 0; // Cannot transform if differing values are returned. + ReturnedValue = RetOp; + } + return ReturnedValue; +} + +/// CanTransformAccumulatorRecursion - If the specified instruction can be +/// transformed using accumulator recursion elimination, return the constant +/// which is the start of the accumulator value. Otherwise return null. +/// +Value *TailCallElim::CanTransformAccumulatorRecursion(Instruction *I, + CallInst *CI) { + if (!I->isAssociative()) return 0; + assert(I->getNumOperands() == 2 && + "Associative operations should have 2 args!"); + + // Exactly one operand should be the result of the call instruction... + if (I->getOperand(0) == CI && I->getOperand(1) == CI || + I->getOperand(0) != CI && I->getOperand(1) != CI) + return 0; + + // The only user of this instruction we allow is a single return instruction. + if (!I->hasOneUse() || !isa(I->use_back())) + return 0; + + // Ok, now we have to check all of the other return instructions in this + // function. If they return non-constants or differing values, then we cannot + // transform the function safely. 
+ return getCommonReturnValue(cast(I->use_back()), CI); +} + +bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry, + bool &TailCallsAreMarkedTail, + std::vector &ArgumentPHIs, + bool CannotTailCallElimCallsMarkedTail) { + BasicBlock *BB = Ret->getParent(); + Function *F = BB->getParent(); + + if (&BB->front() == Ret) // Make sure there is something before the ret... + return false; + + // Scan backwards from the return, checking to see if there is a tail call in + // this block. If so, set CI to it. + CallInst *CI; + BasicBlock::iterator BBI = Ret; + while (1) { + CI = dyn_cast(BBI); + if (CI && CI->getCalledFunction() == F) + break; + + if (BBI == BB->begin()) + return false; // Didn't find a potential tail call. + --BBI; + } + + // If this call is marked as a tail call, and if there are dynamic allocas in + // the function, we cannot perform this optimization. + if (CI->isTailCall() && CannotTailCallElimCallsMarkedTail) + return false; + + // If we are introducing accumulator recursion to eliminate associative + // operations after the call instruction, this variable contains the initial + // value for the accumulator. If this value is set, we actually perform + // accumulator recursion elimination instead of simple tail recursion + // elimination. + Value *AccumulatorRecursionEliminationInitVal = 0; + Instruction *AccumulatorRecursionInstr = 0; + + // Ok, we found a potential tail call. We can currently only transform the + // tail call if all of the instructions between the call and the return are + // movable to above the call itself, leaving the call next to the return. + // Check that this is the case now. + for (BBI = CI, ++BBI; &*BBI != Ret; ++BBI) + if (!CanMoveAboveCall(BBI, CI)) { + // If we can't move the instruction above the call, it might be because it + // is an associative operation that could be tranformed using accumulator + // recursion elimination. 
Check to see if this is the case, and if so, + // remember the initial accumulator value for later. + if ((AccumulatorRecursionEliminationInitVal = + CanTransformAccumulatorRecursion(BBI, CI))) { + // Yes, this is accumulator recursion. Remember which instruction + // accumulates. + AccumulatorRecursionInstr = BBI; + } else { + return false; // Otherwise, we cannot eliminate the tail recursion! + } + } + + // We can only transform call/return pairs that either ignore the return value + // of the call and return void, ignore the value of the call and return a + // constant, return the value returned by the tail call, or that are being + // accumulator recursion variable eliminated. + if (Ret->getNumOperands() != 0 && Ret->getReturnValue() != CI && + !isa(Ret->getReturnValue()) && + AccumulatorRecursionEliminationInitVal == 0 && + !getCommonReturnValue(Ret, CI)) + return false; + + // OK! We can transform this tail call. If this is the first one found, + // create the new entry block, allowing us to branch back to the old entry. + if (OldEntry == 0) { + OldEntry = &F->getEntryBlock(); + BasicBlock *NewEntry = new BasicBlock("", F, OldEntry); + NewEntry->takeName(OldEntry); + OldEntry->setName("tailrecurse"); + new BranchInst(OldEntry, NewEntry); + + // If this tail call is marked 'tail' and if there are any allocas in the + // entry block, move them up to the new entry block. + TailCallsAreMarkedTail = CI->isTailCall(); + if (TailCallsAreMarkedTail) + // Move all fixed sized allocas from OldEntry to NewEntry. + for (BasicBlock::iterator OEBI = OldEntry->begin(), E = OldEntry->end(), + NEBI = NewEntry->begin(); OEBI != E; ) + if (AllocaInst *AI = dyn_cast(OEBI++)) + if (isa(AI->getArraySize())) + AI->moveBefore(NEBI); + + // Now that we have created a new block, which jumps to the entry + // block, insert a PHI node for each argument of the function. + // For now, we initialize each PHI to only have the real arguments + // which are passed in. 
+ Instruction *InsertPos = OldEntry->begin(); + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); + I != E; ++I) { + PHINode *PN = new PHINode(I->getType(), I->getName()+".tr", InsertPos); + I->replaceAllUsesWith(PN); // Everyone use the PHI node now! + PN->addIncoming(I, NewEntry); + ArgumentPHIs.push_back(PN); + } + } + + // If this function has self recursive calls in the tail position where some + // are marked tail and some are not, only transform one flavor or another. We + // have to choose whether we move allocas in the entry block to the new entry + // block or not, so we can't make a good choice for both. NOTE: We could do + // slightly better here in the case that the function has no entry block + // allocas. + if (TailCallsAreMarkedTail && !CI->isTailCall()) + return false; + + // Ok, now that we know we have a pseudo-entry block WITH all of the + // required PHI nodes, add entries into the PHI node for the actual + // parameters passed into the tail-recursive call. + for (unsigned i = 0, e = CI->getNumOperands()-1; i != e; ++i) + ArgumentPHIs[i]->addIncoming(CI->getOperand(i+1), BB); + + // If we are introducing an accumulator variable to eliminate the recursion, + // do so now. Note that we _know_ that no subsequent tail recursion + // eliminations will happen on this function because of the way the + // accumulator recursion predicate is set up. + // + if (AccumulatorRecursionEliminationInitVal) { + Instruction *AccRecInstr = AccumulatorRecursionInstr; + // Start by inserting a new PHI node for the accumulator. + PHINode *AccPN = new PHINode(AccRecInstr->getType(), "accumulator.tr", + OldEntry->begin()); + + // Loop over all of the predecessors of the tail recursion block. For the + // real entry into the function we seed the PHI with the initial value, + // computed earlier. For any other existing branches to this block (due to + // other tail recursions eliminated) the accumulator is not modified. 
+ // Because we haven't added the branch in the current block to OldEntry yet, + // it will not show up as a predecessor. + for (pred_iterator PI = pred_begin(OldEntry), PE = pred_end(OldEntry); + PI != PE; ++PI) { + if (*PI == &F->getEntryBlock()) + AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, *PI); + else + AccPN->addIncoming(AccPN, *PI); + } + + // Add an incoming argument for the current block, which is computed by our + // associative accumulator instruction. + AccPN->addIncoming(AccRecInstr, BB); + + // Next, rewrite the accumulator recursion instruction so that it does not + // use the result of the call anymore, instead, use the PHI node we just + // inserted. + AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN); + + // Finally, rewrite any return instructions in the program to return the PHI + // node instead of the "initval" that they do currently. This loop will + // actually rewrite the return value we are destroying, but that's ok. + for (Function::iterator BBI = F->begin(), E = F->end(); BBI != E; ++BBI) + if (ReturnInst *RI = dyn_cast(BBI->getTerminator())) + RI->setOperand(0, AccPN); + ++NumAccumAdded; + } + + // Now that all of the PHI nodes are in place, remove the call and + // ret instructions, replacing them with an unconditional branch. + new BranchInst(OldEntry, Ret); + BB->getInstList().erase(Ret); // Remove return. + BB->getInstList().erase(CI); // Remove call. + ++NumEliminated; + return true; +} diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp new file mode 100644 index 0000000..520cfeb --- /dev/null +++ b/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -0,0 +1,175 @@ +//===-- BasicBlockUtils.cpp - BasicBlock Utilities -------------------------==// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This family of functions perform manipulations on basic blocks, and +// instructions contained within basic blocks. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Constant.h" +#include "llvm/Type.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/Dominators.h" +#include +using namespace llvm; + +/// ReplaceInstWithValue - Replace all uses of an instruction (specified by BI) +/// with a value, then remove and delete the original instruction. +/// +void llvm::ReplaceInstWithValue(BasicBlock::InstListType &BIL, + BasicBlock::iterator &BI, Value *V) { + Instruction &I = *BI; + // Replaces all of the uses of the instruction with uses of the value + I.replaceAllUsesWith(V); + + // Make sure to propagate a name if there is one already. + if (I.hasName() && !V->hasName()) + V->takeName(&I); + + // Delete the unnecessary instruction now... + BI = BIL.erase(BI); +} + + +/// ReplaceInstWithInst - Replace the instruction specified by BI with the +/// instruction specified by I. The original instruction is deleted and BI is +/// updated to point to the new instruction. +/// +void llvm::ReplaceInstWithInst(BasicBlock::InstListType &BIL, + BasicBlock::iterator &BI, Instruction *I) { + assert(I->getParent() == 0 && + "ReplaceInstWithInst: Instruction already inserted into basic block!"); + + // Insert the new instruction into the basic block... + BasicBlock::iterator New = BIL.insert(BI, I); + + // Replace all uses of the old instruction, and delete it. + ReplaceInstWithValue(BIL, BI, I); + + // Move BI back to point to the newly inserted instruction + BI = New; +} + +/// ReplaceInstWithInst - Replace the instruction specified by From with the +/// instruction specified by To. 
+/// +void llvm::ReplaceInstWithInst(Instruction *From, Instruction *To) { + BasicBlock::iterator BI(From); + ReplaceInstWithInst(From->getParent()->getInstList(), BI, To); +} + +/// RemoveSuccessor - Change the specified terminator instruction such that its +/// successor SuccNum no longer exists. Because this reduces the outgoing +/// degree of the current basic block, the actual terminator instruction itself +/// may have to be changed. In the case where the last successor of the block +/// is deleted, a return instruction is inserted in its place which can cause a +/// surprising change in program behavior if it is not expected. +/// +void llvm::RemoveSuccessor(TerminatorInst *TI, unsigned SuccNum) { + assert(SuccNum < TI->getNumSuccessors() && + "Trying to remove a nonexistant successor!"); + + // If our old successor block contains any PHI nodes, remove the entry in the + // PHI nodes that comes from this branch... + // + BasicBlock *BB = TI->getParent(); + TI->getSuccessor(SuccNum)->removePredecessor(BB); + + TerminatorInst *NewTI = 0; + switch (TI->getOpcode()) { + case Instruction::Br: + // If this is a conditional branch... convert to unconditional branch. + if (TI->getNumSuccessors() == 2) { + cast(TI)->setUnconditionalDest(TI->getSuccessor(1-SuccNum)); + } else { // Otherwise convert to a return instruction... + Value *RetVal = 0; + + // Create a value to return... if the function doesn't return null... + if (BB->getParent()->getReturnType() != Type::VoidTy) + RetVal = Constant::getNullValue(BB->getParent()->getReturnType()); + + // Create the return... + NewTI = new ReturnInst(RetVal); + } + break; + + case Instruction::Invoke: // Should convert to call + case Instruction::Switch: // Should remove entry + default: + case Instruction::Ret: // Cannot happen, has no successors! + assert(0 && "Unhandled terminator instruction type in RemoveSuccessor!"); + abort(); + } + + if (NewTI) // If it's a different instruction, replace. 
+ ReplaceInstWithInst(TI, NewTI); +} + +/// SplitEdge - Split the edge connecting specified block. Pass P must +/// not be NULL. +BasicBlock *llvm::SplitEdge(BasicBlock *BB, BasicBlock *Succ, Pass *P) { + TerminatorInst *LatchTerm = BB->getTerminator(); + unsigned SuccNum = 0; + for (unsigned i = 0, e = LatchTerm->getNumSuccessors(); ; ++i) { + assert(i != e && "Didn't find edge?"); + if (LatchTerm->getSuccessor(i) == Succ) { + SuccNum = i; + break; + } + } + + // If this is a critical edge, let SplitCriticalEdge do it. + if (SplitCriticalEdge(BB->getTerminator(), SuccNum, P)) + return LatchTerm->getSuccessor(SuccNum); + + // If the edge isn't critical, then BB has a single successor or Succ has a + // single pred. Split the block. + BasicBlock::iterator SplitPoint; + if (BasicBlock *SP = Succ->getSinglePredecessor()) { + // If the successor only has a single pred, split the top of the successor + // block. + assert(SP == BB && "CFG broken"); + return SplitBlock(Succ, Succ->begin(), P); + } else { + // Otherwise, if BB has a single successor, split it at the bottom of the + // block. + assert(BB->getTerminator()->getNumSuccessors() == 1 && + "Should have a single succ!"); + return SplitBlock(BB, BB->getTerminator(), P); + } +} + +/// SplitBlock - Split the specified block at the specified instruction - every +/// thing before SplitPt stays in Old and everything starting with SplitPt moves +/// to a new block. The two blocks are joined by an unconditional branch and +/// the loop info is updated. +/// +BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt, Pass *P) { + + LoopInfo &LI = P->getAnalysis(); + BasicBlock::iterator SplitIt = SplitPt; + while (isa(SplitIt)) + ++SplitIt; + BasicBlock *New = Old->splitBasicBlock(SplitIt, Old->getName()+".split"); + + // The new block lives in whichever loop the old one did. 
+ if (Loop *L = LI.getLoopFor(Old)) + L->addBasicBlockToLoop(New, LI); + + if (DominatorTree *DT = P->getAnalysisToUpdate()) + DT->addNewBlock(New, Old); + + if (DominanceFrontier *DF = P->getAnalysisToUpdate()) + DF->splitBlock(Old); + + return New; +} diff --git a/lib/Transforms/Utils/BreakCriticalEdges.cpp b/lib/Transforms/Utils/BreakCriticalEdges.cpp new file mode 100644 index 0000000..af9a114 --- /dev/null +++ b/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -0,0 +1,269 @@ +//===- BreakCriticalEdges.cpp - Critical Edge Elimination Pass ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// BreakCriticalEdges pass - Break all of the critical edges in the CFG by +// inserting a dummy basic block. This pass may be "required" by passes that +// cannot deal with critical edges. For this usage, the structure type is +// forward declared. This pass obviously invalidates the CFG, but can update +// forward dominator (set, immediate dominators, tree, and frontier) +// information. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "break-crit-edges" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Type.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +using namespace llvm; + +STATISTIC(NumBroken, "Number of blocks inserted"); + +namespace { + struct VISIBILITY_HIDDEN BreakCriticalEdges : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + BreakCriticalEdges() : FunctionPass((intptr_t)&ID) {} + + virtual bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addPreserved(); + AU.addPreserved(); + AU.addPreserved(); + + // No loop canonicalization guarantees are broken by this pass. + AU.addPreservedID(LoopSimplifyID); + } + }; + + char BreakCriticalEdges::ID = 0; + RegisterPass X("break-crit-edges", + "Break critical edges in CFG"); +} + +// Publically exposed interface to pass... +const PassInfo *llvm::BreakCriticalEdgesID = X.getPassInfo(); +FunctionPass *llvm::createBreakCriticalEdgesPass() { + return new BreakCriticalEdges(); +} + +// runOnFunction - Loop over all of the edges in the CFG, breaking critical +// edges as they are found. 
+// +bool BreakCriticalEdges::runOnFunction(Function &F) { + bool Changed = false; + for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) { + TerminatorInst *TI = I->getTerminator(); + if (TI->getNumSuccessors() > 1) + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + if (SplitCriticalEdge(TI, i, this)) { + ++NumBroken; + Changed = true; + } + } + + return Changed; +} + +//===----------------------------------------------------------------------===// +// Implementation of the external critical edge manipulation functions +//===----------------------------------------------------------------------===// + +// isCriticalEdge - Return true if the specified edge is a critical edge. +// Critical edges are edges from a block with multiple successors to a block +// with multiple predecessors. +// +bool llvm::isCriticalEdge(const TerminatorInst *TI, unsigned SuccNum, + bool AllowIdenticalEdges) { + assert(SuccNum < TI->getNumSuccessors() && "Illegal edge specification!"); + if (TI->getNumSuccessors() == 1) return false; + + const BasicBlock *Dest = TI->getSuccessor(SuccNum); + pred_const_iterator I = pred_begin(Dest), E = pred_end(Dest); + + // If there is more than one predecessor, this is a critical edge... + assert(I != E && "No preds, but we have an edge to the block?"); + const BasicBlock *FirstPred = *I; + ++I; // Skip one edge due to the incoming arc from TI. + if (!AllowIdenticalEdges) + return I != E; + + // If AllowIdenticalEdges is true, then we allow this edge to be considered + // non-critical iff all preds come from TI's block. + for (; I != E; ++I) + if (*I != FirstPred) return true; + return false; +} + +// SplitCriticalEdge - If this edge is a critical edge, insert a new node to +// split the critical edge. This will update DominatorTree, and DominatorFrontier +// information if it is available, thus calling this pass will not invalidate +// any of them. This returns true if the edge was split, false otherwise. 
+// This ensures that all edges to that dest go to one block instead of each +// going to a different block. +// +bool llvm::SplitCriticalEdge(TerminatorInst *TI, unsigned SuccNum, Pass *P, + bool MergeIdenticalEdges) { + if (!isCriticalEdge(TI, SuccNum, MergeIdenticalEdges)) return false; + BasicBlock *TIBB = TI->getParent(); + BasicBlock *DestBB = TI->getSuccessor(SuccNum); + + // Create a new basic block, linking it into the CFG. + BasicBlock *NewBB = new BasicBlock(TIBB->getName() + "." + + DestBB->getName() + "_crit_edge"); + // Create our unconditional branch... + new BranchInst(DestBB, NewBB); + + // Branch to the new block, breaking the edge. + TI->setSuccessor(SuccNum, NewBB); + + // Insert the block into the function... right after the block TI lives in. + Function &F = *TIBB->getParent(); + Function::iterator FBBI = TIBB; + F.getBasicBlockList().insert(++FBBI, NewBB); + + // If there are any PHI nodes in DestBB, we need to update them so that they + // merge incoming values from NewBB instead of from TIBB. + // + for (BasicBlock::iterator I = DestBB->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + // We no longer enter through TIBB, now we come in through NewBB. Revector + // exactly one entry in the PHI node that used to come from TIBB to come + // from NewBB. + int BBIdx = PN->getBasicBlockIndex(TIBB); + PN->setIncomingBlock(BBIdx, NewBB); + } + + // If there are any other edges from TIBB to DestBB, update those to go + // through the split block, making those edges non-critical as well (and + // reducing the number of phi entries in the DestBB if relevant). + if (MergeIdenticalEdges) { + for (unsigned i = SuccNum+1, e = TI->getNumSuccessors(); i != e; ++i) { + if (TI->getSuccessor(i) != DestBB) continue; + + // Remove an entry for TIBB from DestBB phi nodes. + DestBB->removePredecessor(TIBB); + + // We found another edge to DestBB, go to NewBB instead. 
+ TI->setSuccessor(i, NewBB); + } + } + + + + // If we don't have a pass object, we can't update anything... + if (P == 0) return true; + + // Now update analysis information. Since the only predecessor of NewBB is + // the TIBB, TIBB clearly dominates NewBB. TIBB usually doesn't dominate + // anything, as there are other successors of DestBB. However, if all other + // predecessors of DestBB are already dominated by DestBB (e.g. DestBB is a + // loop header) then NewBB dominates DestBB. + SmallVector OtherPreds; + + for (pred_iterator I = pred_begin(DestBB), E = pred_end(DestBB); I != E; ++I) + if (*I != NewBB) + OtherPreds.push_back(*I); + + bool NewBBDominatesDestBB = true; + + // Should we update DominatorTree information? + if (DominatorTree *DT = P->getAnalysisToUpdate()) { + DomTreeNode *TINode = DT->getNode(TIBB); + + // The new block is not the immediate dominator for any other nodes, but + // TINode is the immediate dominator for the new node. + // + if (TINode) { // Don't break unreachable code! + DomTreeNode *NewBBNode = DT->addNewBlock(NewBB, TIBB); + DomTreeNode *DestBBNode = 0; + + // If NewBBDominatesDestBB hasn't been computed yet, do so with DT. + if (!OtherPreds.empty()) { + DestBBNode = DT->getNode(DestBB); + while (!OtherPreds.empty() && NewBBDominatesDestBB) { + if (DomTreeNode *OPNode = DT->getNode(OtherPreds.back())) + NewBBDominatesDestBB = DT->dominates(DestBBNode, OPNode); + OtherPreds.pop_back(); + } + OtherPreds.clear(); + } + + // If NewBBDominatesDestBB, then NewBB dominates DestBB, otherwise it + // doesn't dominate anything. + if (NewBBDominatesDestBB) { + if (!DestBBNode) DestBBNode = DT->getNode(DestBB); + DT->changeImmediateDominator(DestBBNode, NewBBNode); + } + } + } + + // Should we update DominanceFrontier information? + if (DominanceFrontier *DF = P->getAnalysisToUpdate()) { + // If NewBBDominatesDestBB hasn't been computed yet, do so with DF. + if (!OtherPreds.empty()) { + // FIXME: IMPLEMENT THIS! 
+ assert(0 && "Requiring domfrontiers but not idom/domtree/domset." + " not implemented yet!"); + } + + // Since the new block is dominated by its only predecessor TIBB, + // it cannot be in any block's dominance frontier. If NewBB dominates + // DestBB, its dominance frontier is the same as DestBB's, otherwise it is + // just {DestBB}. + DominanceFrontier::DomSetType NewDFSet; + if (NewBBDominatesDestBB) { + DominanceFrontier::iterator I = DF->find(DestBB); + if (I != DF->end()) + DF->addBasicBlock(NewBB, I->second); + else + DF->addBasicBlock(NewBB, DominanceFrontier::DomSetType()); + } else { + DominanceFrontier::DomSetType NewDFSet; + NewDFSet.insert(DestBB); + DF->addBasicBlock(NewBB, NewDFSet); + } + } + + // Update LoopInfo if it is around. + if (LoopInfo *LI = P->getAnalysisToUpdate()) { + // If one or the other blocks were not in a loop, the new block is not + // either, and thus LI doesn't need to be updated. + if (Loop *TIL = LI->getLoopFor(TIBB)) + if (Loop *DestLoop = LI->getLoopFor(DestBB)) { + if (TIL == DestLoop) { + // Both in the same loop, the NewBB joins loop. + DestLoop->addBasicBlockToLoop(NewBB, *LI); + } else if (TIL->contains(DestLoop->getHeader())) { + // Edge from an outer loop to an inner loop. Add to the outer loop. + TIL->addBasicBlockToLoop(NewBB, *LI); + } else if (DestLoop->contains(TIL->getHeader())) { + // Edge from an inner loop to an outer loop. Add to the outer loop. + DestLoop->addBasicBlockToLoop(NewBB, *LI); + } else { + // Edge from two loops with no containment relation. Because these + // are natural loops, we know that the destination block must be the + // header of its loop (adding a branch into a loop elsewhere would + // create an irreducible loop). 
+ assert(DestLoop->getHeader() == DestBB && + "Should not create irreducible loops!"); + if (Loop *P = DestLoop->getParentLoop()) + P->addBasicBlockToLoop(NewBB, *LI); + } + } + } + return true; +} diff --git a/lib/Transforms/Utils/CloneFunction.cpp b/lib/Transforms/Utils/CloneFunction.cpp new file mode 100644 index 0000000..cff58ab --- /dev/null +++ b/lib/Transforms/Utils/CloneFunction.cpp @@ -0,0 +1,485 @@ +//===- CloneFunction.cpp - Clone a function into another function ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the CloneFunctionInto interface, which is used as the +// low-level function cloner. This is used by the CloneFunction and function +// inliner to do the dirty work of copying the body of a function around. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Function.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "ValueMapper.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/ADT/SmallVector.h" +#include +using namespace llvm; + +// CloneBasicBlock - See comments in Cloning.h +BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, + DenseMap &ValueMap, + const char *NameSuffix, Function *F, + ClonedCodeInfo *CodeInfo) { + BasicBlock *NewBB = new BasicBlock("", F); + if (BB->hasName()) NewBB->setName(BB->getName()+NameSuffix); + + bool hasCalls = false, hasDynamicAllocas = false, hasStaticAllocas = false; + + // Loop over all instructions, and copy them over. 
+ for (BasicBlock::const_iterator II = BB->begin(), IE = BB->end(); + II != IE; ++II) { + Instruction *NewInst = II->clone(); + if (II->hasName()) + NewInst->setName(II->getName()+NameSuffix); + NewBB->getInstList().push_back(NewInst); + ValueMap[II] = NewInst; // Add instruction map to value. + + hasCalls |= isa(II); + if (const AllocaInst *AI = dyn_cast(II)) { + if (isa(AI->getArraySize())) + hasStaticAllocas = true; + else + hasDynamicAllocas = true; + } + } + + if (CodeInfo) { + CodeInfo->ContainsCalls |= hasCalls; + CodeInfo->ContainsUnwinds |= isa(BB->getTerminator()); + CodeInfo->ContainsDynamicAllocas |= hasDynamicAllocas; + CodeInfo->ContainsDynamicAllocas |= hasStaticAllocas && + BB != &BB->getParent()->getEntryBlock(); + } + return NewBB; +} + +// Clone OldFunc into NewFunc, transforming the old arguments into references to +// ArgMap values. +// +void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, + DenseMap &ValueMap, + std::vector &Returns, + const char *NameSuffix, ClonedCodeInfo *CodeInfo) { + assert(NameSuffix && "NameSuffix cannot be null!"); + +#ifndef NDEBUG + for (Function::const_arg_iterator I = OldFunc->arg_begin(), + E = OldFunc->arg_end(); I != E; ++I) + assert(ValueMap.count(I) && "No mapping from source argument specified!"); +#endif + + // Loop over all of the basic blocks in the function, cloning them as + // appropriate. Note that we save BE this way in order to handle cloning of + // recursive functions into themselves. + // + for (Function::const_iterator BI = OldFunc->begin(), BE = OldFunc->end(); + BI != BE; ++BI) { + const BasicBlock &BB = *BI; + + // Create a new basic block and copy instructions into it! + BasicBlock *CBB = CloneBasicBlock(&BB, ValueMap, NameSuffix, NewFunc, + CodeInfo); + ValueMap[&BB] = CBB; // Add basic block mapping. 
+ + if (ReturnInst *RI = dyn_cast(CBB->getTerminator())) + Returns.push_back(RI); + } + + // Loop over all of the instructions in the function, fixing up operand + // references as we go. This uses ValueMap to do all the hard work. + // + for (Function::iterator BB = cast(ValueMap[OldFunc->begin()]), + BE = NewFunc->end(); BB != BE; ++BB) + // Loop over all instructions, fixing each one as we find it... + for (BasicBlock::iterator II = BB->begin(); II != BB->end(); ++II) + RemapInstruction(II, ValueMap); +} + +/// CloneFunction - Return a copy of the specified function, but without +/// embedding the function into another module. Also, any references specified +/// in the ValueMap are changed to refer to their mapped value instead of the +/// original one. If any of the arguments to the function are in the ValueMap, +/// the arguments are deleted from the resultant function. The ValueMap is +/// updated to include mappings from all of the instructions and basicblocks in +/// the function from their old to new values. +/// +Function *llvm::CloneFunction(const Function *F, + DenseMap &ValueMap, + ClonedCodeInfo *CodeInfo) { + std::vector ArgTypes; + + // The user might be deleting arguments to the function by specifying them in + // the ValueMap. If so, we need to not add the arguments to the arg ty vector + // + for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end(); + I != E; ++I) + if (ValueMap.count(I) == 0) // Haven't mapped the argument to anything yet? + ArgTypes.push_back(I->getType()); + + // Create a new function type... + FunctionType *FTy = FunctionType::get(F->getFunctionType()->getReturnType(), + ArgTypes, F->getFunctionType()->isVarArg()); + + // Create the new function... + Function *NewF = new Function(FTy, F->getLinkage(), F->getName()); + + // Loop over the arguments, copying the names of the mapped arguments over... 
+ Function::arg_iterator DestI = NewF->arg_begin(); + for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end(); + I != E; ++I) + if (ValueMap.count(I) == 0) { // Is this argument preserved? + DestI->setName(I->getName()); // Copy the name over... + ValueMap[I] = DestI++; // Add mapping to ValueMap + } + + std::vector Returns; // Ignore returns cloned... + CloneFunctionInto(NewF, F, ValueMap, Returns, "", CodeInfo); + return NewF; +} + + + +namespace { + /// PruningFunctionCloner - This class is a private class used to implement + /// the CloneAndPruneFunctionInto method. + struct VISIBILITY_HIDDEN PruningFunctionCloner { + Function *NewFunc; + const Function *OldFunc; + DenseMap &ValueMap; + std::vector &Returns; + const char *NameSuffix; + ClonedCodeInfo *CodeInfo; + const TargetData *TD; + + public: + PruningFunctionCloner(Function *newFunc, const Function *oldFunc, + DenseMap &valueMap, + std::vector &returns, + const char *nameSuffix, + ClonedCodeInfo *codeInfo, + const TargetData *td) + : NewFunc(newFunc), OldFunc(oldFunc), ValueMap(valueMap), Returns(returns), + NameSuffix(nameSuffix), CodeInfo(codeInfo), TD(td) { + } + + /// CloneBlock - The specified block is found to be reachable, clone it and + /// anything that it can reach. + void CloneBlock(const BasicBlock *BB, + std::vector &ToClone); + + public: + /// ConstantFoldMappedInstruction - Constant fold the specified instruction, + /// mapping its operands through ValueMap if they are available. + Constant *ConstantFoldMappedInstruction(const Instruction *I); + }; +} + +/// CloneBlock - The specified block is found to be reachable, clone it and +/// anything that it can reach. +void PruningFunctionCloner::CloneBlock(const BasicBlock *BB, + std::vector &ToClone){ + Value *&BBEntry = ValueMap[BB]; + + // Have we already cloned this block? + if (BBEntry) return; + + // Nope, clone it now. 
+ BasicBlock *NewBB; + BBEntry = NewBB = new BasicBlock(); + if (BB->hasName()) NewBB->setName(BB->getName()+NameSuffix); + + bool hasCalls = false, hasDynamicAllocas = false, hasStaticAllocas = false; + + // Loop over all instructions, and copy them over, DCE'ing as we go. This + // loop doesn't include the terminator. + for (BasicBlock::const_iterator II = BB->begin(), IE = --BB->end(); + II != IE; ++II) { + // If this instruction constant folds, don't bother cloning the instruction, + // instead, just add the constant to the value map. + if (Constant *C = ConstantFoldMappedInstruction(II)) { + ValueMap[II] = C; + continue; + } + + Instruction *NewInst = II->clone(); + if (II->hasName()) + NewInst->setName(II->getName()+NameSuffix); + NewBB->getInstList().push_back(NewInst); + ValueMap[II] = NewInst; // Add instruction map to value. + + hasCalls |= isa(II); + if (const AllocaInst *AI = dyn_cast(II)) { + if (isa(AI->getArraySize())) + hasStaticAllocas = true; + else + hasDynamicAllocas = true; + } + } + + // Finally, clone over the terminator. + const TerminatorInst *OldTI = BB->getTerminator(); + bool TerminatorDone = false; + if (const BranchInst *BI = dyn_cast(OldTI)) { + if (BI->isConditional()) { + // If the condition was a known constant in the callee... + ConstantInt *Cond = dyn_cast(BI->getCondition()); + // Or is a known constant in the caller... + if (Cond == 0) + Cond = dyn_cast_or_null(ValueMap[BI->getCondition()]); + + // Constant fold to uncond branch! + if (Cond) { + BasicBlock *Dest = BI->getSuccessor(!Cond->getZExtValue()); + ValueMap[OldTI] = new BranchInst(Dest, NewBB); + ToClone.push_back(Dest); + TerminatorDone = true; + } + } + } else if (const SwitchInst *SI = dyn_cast(OldTI)) { + // If switching on a value known constant in the caller. + ConstantInt *Cond = dyn_cast(SI->getCondition()); + if (Cond == 0) // Or known constant after constant prop in the callee... 
+ Cond = dyn_cast_or_null(ValueMap[SI->getCondition()]); + if (Cond) { // Constant fold to uncond branch! + BasicBlock *Dest = SI->getSuccessor(SI->findCaseValue(Cond)); + ValueMap[OldTI] = new BranchInst(Dest, NewBB); + ToClone.push_back(Dest); + TerminatorDone = true; + } + } + + if (!TerminatorDone) { + Instruction *NewInst = OldTI->clone(); + if (OldTI->hasName()) + NewInst->setName(OldTI->getName()+NameSuffix); + NewBB->getInstList().push_back(NewInst); + ValueMap[OldTI] = NewInst; // Add instruction map to value. + + // Recursively clone any reachable successor blocks. + const TerminatorInst *TI = BB->getTerminator(); + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + ToClone.push_back(TI->getSuccessor(i)); + } + + if (CodeInfo) { + CodeInfo->ContainsCalls |= hasCalls; + CodeInfo->ContainsUnwinds |= isa(OldTI); + CodeInfo->ContainsDynamicAllocas |= hasDynamicAllocas; + CodeInfo->ContainsDynamicAllocas |= hasStaticAllocas && + BB != &BB->getParent()->front(); + } + + if (ReturnInst *RI = dyn_cast(NewBB->getTerminator())) + Returns.push_back(RI); +} + +/// ConstantFoldMappedInstruction - Constant fold the specified instruction, +/// mapping its operands through ValueMap if they are available. +Constant *PruningFunctionCloner:: +ConstantFoldMappedInstruction(const Instruction *I) { + SmallVector Ops; + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (Constant *Op = dyn_cast_or_null(MapValue(I->getOperand(i), + ValueMap))) + Ops.push_back(Op); + else + return 0; // All operands not constant! + + return ConstantFoldInstOperands(I, &Ops[0], Ops.size(), TD); +} + +/// CloneAndPruneFunctionInto - This works exactly like CloneFunctionInto, +/// except that it does some simple constant prop and DCE on the fly. 
The +/// effect of this is to copy significantly less code in cases where (for +/// example) a function call with constant arguments is inlined, and those +/// constant arguments cause a significant amount of code in the callee to be +/// dead. Since this doesn't produce an exactly copy of the input, it can't be +/// used for things like CloneFunction or CloneModule. +void llvm::CloneAndPruneFunctionInto(Function *NewFunc, const Function *OldFunc, + DenseMap &ValueMap, + std::vector &Returns, + const char *NameSuffix, + ClonedCodeInfo *CodeInfo, + const TargetData *TD) { + assert(NameSuffix && "NameSuffix cannot be null!"); + +#ifndef NDEBUG + for (Function::const_arg_iterator II = OldFunc->arg_begin(), + E = OldFunc->arg_end(); II != E; ++II) + assert(ValueMap.count(II) && "No mapping from source argument specified!"); +#endif + + PruningFunctionCloner PFC(NewFunc, OldFunc, ValueMap, Returns, + NameSuffix, CodeInfo, TD); + + // Clone the entry block, and anything recursively reachable from it. + std::vector CloneWorklist; + CloneWorklist.push_back(&OldFunc->getEntryBlock()); + while (!CloneWorklist.empty()) { + const BasicBlock *BB = CloneWorklist.back(); + CloneWorklist.pop_back(); + PFC.CloneBlock(BB, CloneWorklist); + } + + // Loop over all of the basic blocks in the old function. If the block was + // reachable, we have cloned it and the old block is now in the value map: + // insert it into the new function in the right order. If not, ignore it. + // + // Defer PHI resolution until rest of function is resolved. + std::vector PHIToResolve; + for (Function::const_iterator BI = OldFunc->begin(), BE = OldFunc->end(); + BI != BE; ++BI) { + BasicBlock *NewBB = cast_or_null(ValueMap[BI]); + if (NewBB == 0) continue; // Dead block. + + // Add the new block to the new function. + NewFunc->getBasicBlockList().push_back(NewBB); + + // Loop over all of the instructions in the block, fixing up operand + // references as we go. This uses ValueMap to do all the hard work. 
+ // + BasicBlock::iterator I = NewBB->begin(); + + // Handle PHI nodes specially, as we have to remove references to dead + // blocks. + if (PHINode *PN = dyn_cast(I)) { + // Skip over all PHI nodes, remembering them for later. + BasicBlock::const_iterator OldI = BI->begin(); + for (; (PN = dyn_cast(I)); ++I, ++OldI) + PHIToResolve.push_back(cast(OldI)); + } + + // Otherwise, remap the rest of the instructions normally. + for (; I != NewBB->end(); ++I) + RemapInstruction(I, ValueMap); + } + + // Defer PHI resolution until rest of function is resolved, PHI resolution + // requires the CFG to be up-to-date. + for (unsigned phino = 0, e = PHIToResolve.size(); phino != e; ) { + const PHINode *OPN = PHIToResolve[phino]; + unsigned NumPreds = OPN->getNumIncomingValues(); + const BasicBlock *OldBB = OPN->getParent(); + BasicBlock *NewBB = cast(ValueMap[OldBB]); + + // Map operands for blocks that are live and remove operands for blocks + // that are dead. + for (; phino != PHIToResolve.size() && + PHIToResolve[phino]->getParent() == OldBB; ++phino) { + OPN = PHIToResolve[phino]; + PHINode *PN = cast(ValueMap[OPN]); + for (unsigned pred = 0, e = NumPreds; pred != e; ++pred) { + if (BasicBlock *MappedBlock = + cast_or_null(ValueMap[PN->getIncomingBlock(pred)])) { + Value *InVal = MapValue(PN->getIncomingValue(pred), ValueMap); + assert(InVal && "Unknown input value?"); + PN->setIncomingValue(pred, InVal); + PN->setIncomingBlock(pred, MappedBlock); + } else { + PN->removeIncomingValue(pred, false); + --pred, --e; // Revisit the next entry. + } + } + } + + // The loop above has removed PHI entries for those blocks that are dead + // and has updated others. However, if a block is live (i.e. copied over) + // but its terminator has been changed to not go to this block, then our + // phi nodes will have invalid entries. Update the PHI nodes in this + // case. 
+ PHINode *PN = cast(NewBB->begin()); + NumPreds = std::distance(pred_begin(NewBB), pred_end(NewBB)); + if (NumPreds != PN->getNumIncomingValues()) { + assert(NumPreds < PN->getNumIncomingValues()); + // Count how many times each predecessor comes to this block. + std::map PredCount; + for (pred_iterator PI = pred_begin(NewBB), E = pred_end(NewBB); + PI != E; ++PI) + --PredCount[*PI]; + + // Figure out how many entries to remove from each PHI. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + ++PredCount[PN->getIncomingBlock(i)]; + + // At this point, the excess predecessor entries are positive in the + // map. Loop over all of the PHIs and remove excess predecessor + // entries. + BasicBlock::iterator I = NewBB->begin(); + for (; (PN = dyn_cast(I)); ++I) { + for (std::map::iterator PCI =PredCount.begin(), + E = PredCount.end(); PCI != E; ++PCI) { + BasicBlock *Pred = PCI->first; + for (unsigned NumToRemove = PCI->second; NumToRemove; --NumToRemove) + PN->removeIncomingValue(Pred, false); + } + } + } + + // If the loops above have made these phi nodes have 0 or 1 operand, + // replace them with undef or the input value. We must do this for + // correctness, because 0-operand phis are not valid. + PN = cast(NewBB->begin()); + if (PN->getNumIncomingValues() == 0) { + BasicBlock::iterator I = NewBB->begin(); + BasicBlock::const_iterator OldI = OldBB->begin(); + while ((PN = dyn_cast(I++))) { + Value *NV = UndefValue::get(PN->getType()); + PN->replaceAllUsesWith(NV); + assert(ValueMap[OldI] == PN && "ValueMap mismatch"); + ValueMap[OldI] = NV; + PN->eraseFromParent(); + ++OldI; + } + } + // NOTE: We cannot eliminate single entry phi nodes here, because of + // ValueMap. Single entry phi nodes can have multiple ValueMap entries + // pointing at them. Thus, deleting one would require scanning the ValueMap + // to update any entries in it that would require that. This would be + // really slow. 
+ } + + // Now that the inlined function body has been fully constructed, go through + // and zap unconditional fall-through branches. This happen all the time when + // specializing code: code specialization turns conditional branches into + // uncond branches, and this code folds them. + Function::iterator I = cast(ValueMap[&OldFunc->getEntryBlock()]); + while (I != NewFunc->end()) { + BranchInst *BI = dyn_cast(I->getTerminator()); + if (!BI || BI->isConditional()) { ++I; continue; } + + // Note that we can't eliminate uncond branches if the destination has + // single-entry PHI nodes. Eliminating the single-entry phi nodes would + // require scanning the ValueMap to update any entries that point to the phi + // node. + BasicBlock *Dest = BI->getSuccessor(0); + if (!Dest->getSinglePredecessor() || isa(Dest->begin())) { + ++I; continue; + } + + // We know all single-entry PHI nodes in the inlined function have been + // removed, so we just need to splice the blocks. + BI->eraseFromParent(); + + // Move all the instructions in the succ to the pred. + I->getInstList().splice(I->end(), Dest->getInstList()); + + // Make all PHI nodes that referred to Dest now refer to I as their source. + Dest->replaceAllUsesWith(I); + + // Remove the dest block. + Dest->eraseFromParent(); + + // Do not increment I, iteratively merge all things this block branches to. + } +} diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp new file mode 100644 index 0000000..d64d58f --- /dev/null +++ b/lib/Transforms/Utils/CloneModule.cpp @@ -0,0 +1,124 @@ +//===- CloneModule.cpp - Clone an entire module ---------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This file implements the CloneModule interface which makes a copy of an +// entire module. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Module.h" +#include "llvm/DerivedTypes.h" +#include "llvm/TypeSymbolTable.h" +#include "llvm/Constant.h" +#include "ValueMapper.h" +using namespace llvm; + +/// CloneModule - Return an exact copy of the specified module. This is not as +/// easy as it might seem because we have to worry about making copies of global +/// variables and functions, and making their (initializers and references, +/// respectively) refer to the right globals. +/// +Module *llvm::CloneModule(const Module *M) { + // Create the value map that maps things from the old module over to the new + // module. + DenseMap ValueMap; + return CloneModule(M, ValueMap); +} + +Module *llvm::CloneModule(const Module *M, + DenseMap &ValueMap) { + // First off, we need to create the new module... + Module *New = new Module(M->getModuleIdentifier()); + New->setDataLayout(M->getDataLayout()); + New->setTargetTriple(M->getTargetTriple()); + New->setModuleInlineAsm(M->getModuleInlineAsm()); + + // Copy all of the type symbol table entries over. + const TypeSymbolTable &TST = M->getTypeSymbolTable(); + for (TypeSymbolTable::const_iterator TI = TST.begin(), TE = TST.end(); + TI != TE; ++TI) + New->addTypeName(TI->first, TI->second); + + // Copy all of the dependent libraries over. + for (Module::lib_iterator I = M->lib_begin(), E = M->lib_end(); I != E; ++I) + New->addLibrary(*I); + + // Loop over all of the global variables, making corresponding globals in the + // new module. Here we add them to the ValueMap and to the new Module. We + // don't worry about attributes or initializers, they will come later. 
+ // + for (Module::const_global_iterator I = M->global_begin(), E = M->global_end(); + I != E; ++I) + ValueMap[I] = new GlobalVariable(I->getType()->getElementType(), false, + GlobalValue::ExternalLinkage, 0, + I->getName(), New); + + // Loop over the functions in the module, making external functions as before + for (Module::const_iterator I = M->begin(), E = M->end(); I != E; ++I) { + Function *NF = + new Function(cast(I->getType()->getElementType()), + GlobalValue::ExternalLinkage, I->getName(), New); + NF->setCallingConv(I->getCallingConv()); + ValueMap[I]= NF; + } + + // Loop over the aliases in the module + for (Module::const_alias_iterator I = M->alias_begin(), E = M->alias_end(); + I != E; ++I) + ValueMap[I] = new GlobalAlias(I->getType(), GlobalAlias::ExternalLinkage, + I->getName(), NULL, New); + + // Now that all of the things that global variable initializer can refer to + // have been created, loop through and copy the global variable referrers + // over... We also set the attributes on the global now. + // + for (Module::const_global_iterator I = M->global_begin(), E = M->global_end(); + I != E; ++I) { + GlobalVariable *GV = cast(ValueMap[I]); + if (I->hasInitializer()) + GV->setInitializer(cast(MapValue(I->getInitializer(), + ValueMap))); + GV->setLinkage(I->getLinkage()); + GV->setThreadLocal(I->isThreadLocal()); + GV->setConstant(I->isConstant()); + } + + // Similarly, copy over function bodies now... + // + for (Module::const_iterator I = M->begin(), E = M->end(); I != E; ++I) { + Function *F = cast(ValueMap[I]); + if (!I->isDeclaration()) { + Function::arg_iterator DestI = F->arg_begin(); + for (Function::const_arg_iterator J = I->arg_begin(); J != I->arg_end(); + ++J) { + DestI->setName(J->getName()); + ValueMap[J] = DestI++; + } + + std::vector Returns; // Ignore returns cloned... 
+ CloneFunctionInto(F, I, ValueMap, Returns); + } + + F->setLinkage(I->getLinkage()); + } + + // And aliases + for (Module::const_alias_iterator I = M->alias_begin(), E = M->alias_end(); + I != E; ++I) { + GlobalAlias *GA = cast(ValueMap[I]); + GA->setLinkage(I->getLinkage()); + if (const Constant* C = I->getAliasee()) + GA->setAliasee(cast(MapValue(C, ValueMap))); + } + + return New; +} + +// vim: sw=2 diff --git a/lib/Transforms/Utils/CloneTrace.cpp b/lib/Transforms/Utils/CloneTrace.cpp new file mode 100644 index 0000000..97e57b2 --- /dev/null +++ b/lib/Transforms/Utils/CloneTrace.cpp @@ -0,0 +1,120 @@ +//===- CloneTrace.cpp - Clone a trace -------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the CloneTrace interface, which is used when writing +// runtime optimizations. It takes a vector of basic blocks clones the basic +// blocks, removes internal phi nodes, adds it to the same function as the +// original (although there is no jump to it) and returns the new vector of +// basic blocks. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Trace.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Instructions.h" +#include "llvm/Function.h" +#include "ValueMapper.h" +using namespace llvm; + +//Clones the trace (a vector of basic blocks) +std::vector +llvm::CloneTrace(const std::vector &origTrace) { + std::vector clonedTrace; + DenseMap ValueMap; + + //First, loop over all the Basic Blocks in the trace and copy + //them using CloneBasicBlock. Also fix the phi nodes during + //this loop. To fix the phi nodes, we delete incoming branches + //that are not in the trace. 
+ for(std::vector::const_iterator T = origTrace.begin(), + End = origTrace.end(); T != End; ++T) { + + //Clone Basic Block + BasicBlock *clonedBlock = + CloneBasicBlock(*T, ValueMap, ".tr", (*T)->getParent()); + + //Add it to our new trace + clonedTrace.push_back(clonedBlock); + + //Add this new mapping to our Value Map + ValueMap[*T] = clonedBlock; + + //Loop over the phi instructions and delete operands + //that are from blocks not in the trace + //only do this if we are NOT the first block + if(T != origTrace.begin()) { + for (BasicBlock::iterator I = clonedBlock->begin(); + isa(I); ++I) { + PHINode *PN = cast(I); + //get incoming value for the previous BB + Value *V = PN->getIncomingValueForBlock(*(T-1)); + assert(V && "No incoming value from a BasicBlock in our trace!"); + + //remap our phi node to point to incoming value + ValueMap[*&I] = V; + + //remove phi node + clonedBlock->getInstList().erase(PN); + } + } + } + + //Second loop to do the remapping + for(std::vector::const_iterator BB = clonedTrace.begin(), + BE = clonedTrace.end(); BB != BE; ++BB) { + for(BasicBlock::iterator I = (*BB)->begin(); I != (*BB)->end(); ++I) { + + //Loop over all the operands of the instruction + for(unsigned op=0, E = I->getNumOperands(); op != E; ++op) { + const Value *Op = I->getOperand(op); + + //Get it out of the value map + Value *V = ValueMap[Op]; + + //If not in the value map, then its outside our trace so ignore + if(V != 0) + I->setOperand(op,V); + } + } + } + + //return new vector of basic blocks + return clonedTrace; +} + +/// CloneTraceInto - Clone T into NewFunc. Original<->clone mapping is +/// saved in ValueMap. +/// +void llvm::CloneTraceInto(Function *NewFunc, Trace &T, + DenseMap &ValueMap, + const char *NameSuffix) { + assert(NameSuffix && "NameSuffix cannot be null!"); + + // Loop over all of the basic blocks in the trace, cloning them as + // appropriate. 
+ // + for (Trace::const_iterator BI = T.begin(), BE = T.end(); BI != BE; ++BI) { + const BasicBlock *BB = *BI; + + // Create a new basic block and copy instructions into it! + BasicBlock *CBB = CloneBasicBlock(BB, ValueMap, NameSuffix, NewFunc); + ValueMap[BB] = CBB; // Add basic block mapping. + } + + // Loop over all of the instructions in the new function, fixing up operand + // references as we go. This uses ValueMap to do all the hard work. + // + for (Function::iterator BB = + cast(ValueMap[T.getEntryBasicBlock()]), + BE = NewFunc->end(); BB != BE; ++BB) + // Loop over all instructions, fixing each one as we find it... + for (BasicBlock::iterator II = BB->begin(); II != BB->end(); ++II) + RemapInstruction(II, ValueMap); +} + diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp new file mode 100644 index 0000000..aaf9986 --- /dev/null +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -0,0 +1,737 @@ +//===- CodeExtractor.cpp - Pull code region into a new function -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the interface to tear out a code region, such as an +// individual loop or a parallel section, into a new function, replacing it with +// a call to the new function. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/FunctionUtils.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Intrinsics.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/Verifier.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/StringExtras.h" +#include +#include +using namespace llvm; + +// Provide a command-line option to aggregate function arguments into a struct +// for functions produced by the code extrator. This is useful when converting +// extracted functions to pthread-based code, as only one argument (void*) can +// be passed in to pthread_create(). +static cl::opt +AggregateArgsOpt("aggregate-extracted-args", cl::Hidden, + cl::desc("Aggregate arguments to code-extracted functions")); + +namespace { + class VISIBILITY_HIDDEN CodeExtractor { + typedef std::vector Values; + std::set BlocksToExtract; + DominatorTree* DT; + bool AggregateArgs; + unsigned NumExitBlocks; + const Type *RetTy; + public: + CodeExtractor(DominatorTree* dt = 0, bool AggArgs = false) + : DT(dt), AggregateArgs(AggArgs||AggregateArgsOpt), NumExitBlocks(~0U) {} + + Function *ExtractCodeRegion(const std::vector &code); + + bool isEligible(const std::vector &code); + + private: + /// definedInRegion - Return true if the specified value is defined in the + /// extracted region. + bool definedInRegion(Value *V) const { + if (Instruction *I = dyn_cast(V)) + if (BlocksToExtract.count(I->getParent())) + return true; + return false; + } + + /// definedInCaller - Return true if the specified value is defined in the + /// function being code extracted, but not in the region being extracted. 
+ /// These values must be passed in as live-ins to the function. + bool definedInCaller(Value *V) const { + if (isa(V)) return true; + if (Instruction *I = dyn_cast(V)) + if (!BlocksToExtract.count(I->getParent())) + return true; + return false; + } + + void severSplitPHINodes(BasicBlock *&Header); + void splitReturnBlocks(); + void findInputsOutputs(Values &inputs, Values &outputs); + + Function *constructFunction(const Values &inputs, + const Values &outputs, + BasicBlock *header, + BasicBlock *newRootNode, BasicBlock *newHeader, + Function *oldFunction, Module *M); + + void moveCodeToFunction(Function *newFunction); + + void emitCallAndSwitchStatement(Function *newFunction, + BasicBlock *newHeader, + Values &inputs, + Values &outputs); + + }; +} + +/// severSplitPHINodes - If a PHI node has multiple inputs from outside of the +/// region, we need to split the entry block of the region so that the PHI node +/// is easier to deal with. +void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { + bool HasPredsFromRegion = false; + unsigned NumPredsOutsideRegion = 0; + + if (Header != &Header->getParent()->getEntryBlock()) { + PHINode *PN = dyn_cast(Header->begin()); + if (!PN) return; // No PHI nodes. + + // If the header node contains any PHI nodes, check to see if there is more + // than one entry from outside the region. If so, we need to sever the + // header block into two. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (BlocksToExtract.count(PN->getIncomingBlock(i))) + HasPredsFromRegion = true; + else + ++NumPredsOutsideRegion; + + // If there is one (or fewer) predecessor from outside the region, we don't + // need to do anything special. 
+ if (NumPredsOutsideRegion <= 1) return; + } + + // Otherwise, we need to split the header block into two pieces: one + // containing PHI nodes merging values from outside of the region, and a + // second that contains all of the code for the block and merges back any + // incoming values from inside of the region. + BasicBlock::iterator AfterPHIs = Header->begin(); + while (isa(AfterPHIs)) ++AfterPHIs; + BasicBlock *NewBB = Header->splitBasicBlock(AfterPHIs, + Header->getName()+".ce"); + + // We only want to code extract the second block now, and it becomes the new + // header of the region. + BasicBlock *OldPred = Header; + BlocksToExtract.erase(OldPred); + BlocksToExtract.insert(NewBB); + Header = NewBB; + + // Okay, update dominator sets. The blocks that dominate the new one are the + // blocks that dominate TIBB plus the new block itself. + if (DT) + DT->splitBlock(NewBB); + + // Okay, now we need to adjust the PHI nodes and any branches from within the + // region to go to the new header block instead of the old header block. + if (HasPredsFromRegion) { + PHINode *PN = cast(OldPred->begin()); + // Loop over all of the predecessors of OldPred that are in the region, + // changing them to branch to NewBB instead. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (BlocksToExtract.count(PN->getIncomingBlock(i))) { + TerminatorInst *TI = PN->getIncomingBlock(i)->getTerminator(); + TI->replaceUsesOfWith(OldPred, NewBB); + } + + // Okay, everthing within the region is now branching to the right block, we + // just have to update the PHI nodes now, inserting PHI nodes into NewBB. + for (AfterPHIs = OldPred->begin(); isa(AfterPHIs); ++AfterPHIs) { + PHINode *PN = cast(AfterPHIs); + // Create a new PHI node in the new region, which has an incoming value + // from OldPred of PN. 
+ PHINode *NewPN = new PHINode(PN->getType(), PN->getName()+".ce", + NewBB->begin()); + NewPN->addIncoming(PN, OldPred); + + // Loop over all of the incoming value in PN, moving them to NewPN if they + // are from the extracted region. + for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) { + if (BlocksToExtract.count(PN->getIncomingBlock(i))) { + NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i)); + PN->removeIncomingValue(i); + --i; + } + } + } + } +} + +void CodeExtractor::splitReturnBlocks() { + for (std::set::iterator I = BlocksToExtract.begin(), + E = BlocksToExtract.end(); I != E; ++I) + if (ReturnInst *RI = dyn_cast((*I)->getTerminator())) + (*I)->splitBasicBlock(RI, (*I)->getName()+".ret"); +} + +// findInputsOutputs - Find inputs to, outputs from the code region. +// +void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs) { + std::set ExitBlocks; + for (std::set::const_iterator ci = BlocksToExtract.begin(), + ce = BlocksToExtract.end(); ci != ce; ++ci) { + BasicBlock *BB = *ci; + + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + // If a used value is defined outside the region, it's an input. If an + // instruction is used outside the region, it's an output. + for (User::op_iterator O = I->op_begin(), E = I->op_end(); O != E; ++O) + if (definedInCaller(*O)) + inputs.push_back(*O); + + // Consider uses of this instruction (outputs). + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); + UI != E; ++UI) + if (!definedInRegion(*UI)) { + outputs.push_back(I); + break; + } + } // for: insts + + // Keep track of the exit blocks from the region. + TerminatorInst *TI = BB->getTerminator(); + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + if (!BlocksToExtract.count(TI->getSuccessor(i))) + ExitBlocks.insert(TI->getSuccessor(i)); + } // for: basic blocks + + NumExitBlocks = ExitBlocks.size(); + + // Eliminate duplicates. 
+ std::sort(inputs.begin(), inputs.end()); + inputs.erase(std::unique(inputs.begin(), inputs.end()), inputs.end()); + std::sort(outputs.begin(), outputs.end()); + outputs.erase(std::unique(outputs.begin(), outputs.end()), outputs.end()); +} + +/// constructFunction - make a function based on inputs and outputs, as follows: +/// f(in0, ..., inN, out0, ..., outN) +/// +Function *CodeExtractor::constructFunction(const Values &inputs, + const Values &outputs, + BasicBlock *header, + BasicBlock *newRootNode, + BasicBlock *newHeader, + Function *oldFunction, + Module *M) { + DOUT << "inputs: " << inputs.size() << "\n"; + DOUT << "outputs: " << outputs.size() << "\n"; + + // This function returns unsigned, outputs will go back by reference. + switch (NumExitBlocks) { + case 0: + case 1: RetTy = Type::VoidTy; break; + case 2: RetTy = Type::Int1Ty; break; + default: RetTy = Type::Int16Ty; break; + } + + std::vector paramTy; + + // Add the types of the input values to the function's argument list + for (Values::const_iterator i = inputs.begin(), + e = inputs.end(); i != e; ++i) { + const Value *value = *i; + DOUT << "value used in func: " << *value << "\n"; + paramTy.push_back(value->getType()); + } + + // Add the types of the output values to the function's argument list. 
+ for (Values::const_iterator I = outputs.begin(), E = outputs.end(); + I != E; ++I) { + DOUT << "instr used in func: " << **I << "\n"; + if (AggregateArgs) + paramTy.push_back((*I)->getType()); + else + paramTy.push_back(PointerType::get((*I)->getType())); + } + + DOUT << "Function type: " << *RetTy << " f("; + for (std::vector::iterator i = paramTy.begin(), + e = paramTy.end(); i != e; ++i) + DOUT << **i << ", "; + DOUT << ")\n"; + + if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { + PointerType *StructPtr = PointerType::get(StructType::get(paramTy)); + paramTy.clear(); + paramTy.push_back(StructPtr); + } + const FunctionType *funcType = FunctionType::get(RetTy, paramTy, false); + + // Create the new function + Function *newFunction = new Function(funcType, + GlobalValue::InternalLinkage, + oldFunction->getName() + "_" + + header->getName(), M); + newFunction->getBasicBlockList().push_back(newRootNode); + + // Create an iterator to name all of the arguments we inserted. + Function::arg_iterator AI = newFunction->arg_begin(); + + // Rewrite all users of the inputs in the extracted region to use the + // arguments (or appropriate addressing into struct) instead. 
+ for (unsigned i = 0, e = inputs.size(); i != e; ++i) { + Value *RewriteVal; + if (AggregateArgs) { + Value *Idx0 = Constant::getNullValue(Type::Int32Ty); + Value *Idx1 = ConstantInt::get(Type::Int32Ty, i); + std::string GEPname = "gep_" + inputs[i]->getName(); + TerminatorInst *TI = newFunction->begin()->getTerminator(); + GetElementPtrInst *GEP = new GetElementPtrInst(AI, Idx0, Idx1, + GEPname, TI); + RewriteVal = new LoadInst(GEP, "load" + GEPname, TI); + } else + RewriteVal = AI++; + + std::vector Users(inputs[i]->use_begin(), inputs[i]->use_end()); + for (std::vector::iterator use = Users.begin(), useE = Users.end(); + use != useE; ++use) + if (Instruction* inst = dyn_cast(*use)) + if (BlocksToExtract.count(inst->getParent())) + inst->replaceUsesOfWith(inputs[i], RewriteVal); + } + + // Set names for input and output arguments. + if (!AggregateArgs) { + AI = newFunction->arg_begin(); + for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI) + AI->setName(inputs[i]->getName()); + for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI) + AI->setName(outputs[i]->getName()+".out"); + } + + // Rewrite branches to basic blocks outside of the loop to new dummy blocks + // within the new function. This must be done before we lose track of which + // blocks were originally in the code region. + std::vector Users(header->use_begin(), header->use_end()); + for (unsigned i = 0, e = Users.size(); i != e; ++i) + // The BasicBlock which contains the branch is not in the region + // modify the branch target to a new block + if (TerminatorInst *TI = dyn_cast(Users[i])) + if (!BlocksToExtract.count(TI->getParent()) && + TI->getParent()->getParent() == oldFunction) + TI->replaceUsesOfWith(header, newHeader); + + return newFunction; +} + +/// emitCallAndSwitchStatement - This method sets up the caller side by adding +/// the call instruction, splitting any PHI nodes in the header block as +/// necessary. 
+void CodeExtractor:: +emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, + Values &inputs, Values &outputs) { + // Emit a call to the new function, passing in: *pointer to struct (if + // aggregating parameters), or plan inputs and allocated memory for outputs + std::vector params, StructValues, ReloadOutputs; + + // Add inputs as params, or to be filled into the struct + for (Values::iterator i = inputs.begin(), e = inputs.end(); i != e; ++i) + if (AggregateArgs) + StructValues.push_back(*i); + else + params.push_back(*i); + + // Create allocas for the outputs + for (Values::iterator i = outputs.begin(), e = outputs.end(); i != e; ++i) { + if (AggregateArgs) { + StructValues.push_back(*i); + } else { + AllocaInst *alloca = + new AllocaInst((*i)->getType(), 0, (*i)->getName()+".loc", + codeReplacer->getParent()->begin()->begin()); + ReloadOutputs.push_back(alloca); + params.push_back(alloca); + } + } + + AllocaInst *Struct = 0; + if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { + std::vector ArgTypes; + for (Values::iterator v = StructValues.begin(), + ve = StructValues.end(); v != ve; ++v) + ArgTypes.push_back((*v)->getType()); + + // Allocate a struct at the beginning of this function + Type *StructArgTy = StructType::get(ArgTypes); + Struct = + new AllocaInst(StructArgTy, 0, "structArg", + codeReplacer->getParent()->begin()->begin()); + params.push_back(Struct); + + for (unsigned i = 0, e = inputs.size(); i != e; ++i) { + Value *Idx0 = Constant::getNullValue(Type::Int32Ty); + Value *Idx1 = ConstantInt::get(Type::Int32Ty, i); + GetElementPtrInst *GEP = + new GetElementPtrInst(Struct, Idx0, Idx1, + "gep_" + StructValues[i]->getName()); + codeReplacer->getInstList().push_back(GEP); + StoreInst *SI = new StoreInst(StructValues[i], GEP); + codeReplacer->getInstList().push_back(SI); + } + } + + // Emit the call to the function + CallInst *call = new CallInst(newFunction, ¶ms[0], params.size(), + NumExitBlocks > 1 ? 
"targetBlock" : ""); + codeReplacer->getInstList().push_back(call); + + Function::arg_iterator OutputArgBegin = newFunction->arg_begin(); + unsigned FirstOut = inputs.size(); + if (!AggregateArgs) + std::advance(OutputArgBegin, inputs.size()); + + // Reload the outputs passed in by reference + for (unsigned i = 0, e = outputs.size(); i != e; ++i) { + Value *Output = 0; + if (AggregateArgs) { + Value *Idx0 = Constant::getNullValue(Type::Int32Ty); + Value *Idx1 = ConstantInt::get(Type::Int32Ty, FirstOut + i); + GetElementPtrInst *GEP + = new GetElementPtrInst(Struct, Idx0, Idx1, + "gep_reload_" + outputs[i]->getName()); + codeReplacer->getInstList().push_back(GEP); + Output = GEP; + } else { + Output = ReloadOutputs[i]; + } + LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload"); + codeReplacer->getInstList().push_back(load); + std::vector Users(outputs[i]->use_begin(), outputs[i]->use_end()); + for (unsigned u = 0, e = Users.size(); u != e; ++u) { + Instruction *inst = cast(Users[u]); + if (!BlocksToExtract.count(inst->getParent())) + inst->replaceUsesOfWith(outputs[i], load); + } + } + + // Now we can emit a switch statement using the call as a value. + SwitchInst *TheSwitch = + new SwitchInst(ConstantInt::getNullValue(Type::Int16Ty), + codeReplacer, 0, codeReplacer); + + // Since there may be multiple exits from the original region, make the new + // function return an unsigned, switch on that number. This loop iterates + // over all of the blocks in the extracted region, updating any terminator + // instructions in the to-be-extracted region that branch to blocks that are + // not in the region to be extracted. 
+ std::map ExitBlockMap; + + unsigned switchVal = 0; + for (std::set::const_iterator i = BlocksToExtract.begin(), + e = BlocksToExtract.end(); i != e; ++i) { + TerminatorInst *TI = (*i)->getTerminator(); + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + if (!BlocksToExtract.count(TI->getSuccessor(i))) { + BasicBlock *OldTarget = TI->getSuccessor(i); + // add a new basic block which returns the appropriate value + BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; + if (!NewTarget) { + // If we don't already have an exit stub for this non-extracted + // destination, create one now! + NewTarget = new BasicBlock(OldTarget->getName() + ".exitStub", + newFunction); + unsigned SuccNum = switchVal++; + + Value *brVal = 0; + switch (NumExitBlocks) { + case 0: + case 1: break; // No value needed. + case 2: // Conditional branch, return a bool + brVal = ConstantInt::get(Type::Int1Ty, !SuccNum); + break; + default: + brVal = ConstantInt::get(Type::Int16Ty, SuccNum); + break; + } + + ReturnInst *NTRet = new ReturnInst(brVal, NewTarget); + + // Update the switch instruction. + TheSwitch->addCase(ConstantInt::get(Type::Int16Ty, SuccNum), + OldTarget); + + // Restore values just before we exit + Function::arg_iterator OAI = OutputArgBegin; + for (unsigned out = 0, e = outputs.size(); out != e; ++out) { + // For an invoke, the normal destination is the only one that is + // dominated by the result of the invocation + BasicBlock *DefBlock = cast(outputs[out])->getParent(); + + bool DominatesDef = true; + + if (InvokeInst *Invoke = dyn_cast(outputs[out])) { + DefBlock = Invoke->getNormalDest(); + + // Make sure we are looking at the original successor block, not + // at a newly inserted exit block, which won't be in the dominator + // info. 
+ for (std::map::iterator I = + ExitBlockMap.begin(), E = ExitBlockMap.end(); I != E; ++I) + if (DefBlock == I->second) { + DefBlock = I->first; + break; + } + + // In the extract block case, if the block we are extracting ends + // with an invoke instruction, make sure that we don't emit a + // store of the invoke value for the unwind block. + if (!DT && DefBlock != OldTarget) + DominatesDef = false; + } + + if (DT) + DominatesDef = DT->dominates(DefBlock, OldTarget); + + if (DominatesDef) { + if (AggregateArgs) { + Value *Idx0 = Constant::getNullValue(Type::Int32Ty); + Value *Idx1 = ConstantInt::get(Type::Int32Ty,FirstOut+out); + GetElementPtrInst *GEP = + new GetElementPtrInst(OAI, Idx0, Idx1, + "gep_" + outputs[out]->getName(), + NTRet); + new StoreInst(outputs[out], GEP, NTRet); + } else { + new StoreInst(outputs[out], OAI, NTRet); + } + } + // Advance output iterator even if we don't emit a store + if (!AggregateArgs) ++OAI; + } + } + + // rewrite the original branch instruction with this new target + TI->setSuccessor(i, NewTarget); + } + } + + // Now that we've done the deed, simplify the switch instruction. + const Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType(); + switch (NumExitBlocks) { + case 0: + // There are no successors (the block containing the switch itself), which + // means that previously this was the last part of the function, and hence + // this should be rewritten as a `ret' + + // Check if the function should return a value + if (OldFnRetTy == Type::VoidTy) { + new ReturnInst(0, TheSwitch); // Return void + } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { + // return what we have + new ReturnInst(TheSwitch->getCondition(), TheSwitch); + } else { + // Otherwise we must have code extracted an unwind or something, just + // return whatever we want. 
+ new ReturnInst(Constant::getNullValue(OldFnRetTy), TheSwitch); + } + + TheSwitch->getParent()->getInstList().erase(TheSwitch); + break; + case 1: + // Only a single destination, change the switch into an unconditional + // branch. + new BranchInst(TheSwitch->getSuccessor(1), TheSwitch); + TheSwitch->getParent()->getInstList().erase(TheSwitch); + break; + case 2: + new BranchInst(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), + call, TheSwitch); + TheSwitch->getParent()->getInstList().erase(TheSwitch); + break; + default: + // Otherwise, make the default destination of the switch instruction be one + // of the other successors. + TheSwitch->setOperand(0, call); + TheSwitch->setSuccessor(0, TheSwitch->getSuccessor(NumExitBlocks)); + TheSwitch->removeCase(NumExitBlocks); // Remove redundant case + break; + } +} + +void CodeExtractor::moveCodeToFunction(Function *newFunction) { + Function *oldFunc = (*BlocksToExtract.begin())->getParent(); + Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); + Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); + + for (std::set::const_iterator i = BlocksToExtract.begin(), + e = BlocksToExtract.end(); i != e; ++i) { + // Delete the basic block from the old function, and the list of blocks + oldBlocks.remove(*i); + + // Insert this basic block into the new function + newBlocks.push_back(*i); + } +} + +/// ExtractRegion - Removes a loop from a function, replaces it with a call to +/// new function. Returns pointer to the new function. +/// +/// algorithm: +/// +/// find inputs and outputs for the region +/// +/// for inputs: add to function as args, map input instr* to arg# +/// for outputs: add allocas for scalars, +/// add to func as args, map output instr* to arg# +/// +/// rewrite func to use argument #s instead of instr* +/// +/// for each scalar output in the function: at every exit, store intermediate +/// computed result back into memory. 
+/// +Function *CodeExtractor:: +ExtractCodeRegion(const std::vector &code) { + if (!isEligible(code)) + return 0; + + // 1) Find inputs, outputs + // 2) Construct new function + // * Add allocas for defs, pass as args by reference + // * Pass in uses as args + // 3) Move code region, add call instr to func + // + BlocksToExtract.insert(code.begin(), code.end()); + + Values inputs, outputs; + + // Assumption: this is a single-entry code region, and the header is the first + // block in the region. + BasicBlock *header = code[0]; + + for (unsigned i = 1, e = code.size(); i != e; ++i) + for (pred_iterator PI = pred_begin(code[i]), E = pred_end(code[i]); + PI != E; ++PI) + assert(BlocksToExtract.count(*PI) && + "No blocks in this region may have entries from outside the region" + " except for the first block!"); + + // If we have to split PHI nodes or the entry block, do so now. + severSplitPHINodes(header); + + // If we have any return instructions in the region, split those blocks so + // that the return is not in the region. + splitReturnBlocks(); + + Function *oldFunction = header->getParent(); + + // This takes place of the original loop + BasicBlock *codeReplacer = new BasicBlock("codeRepl", oldFunction, header); + + // The new function needs a root node because other nodes can branch to the + // head of the region, but the entry node of a function cannot have preds. + BasicBlock *newFuncRoot = new BasicBlock("newFuncRoot"); + newFuncRoot->getInstList().push_back(new BranchInst(header)); + + // Find inputs to, outputs from the code region. + findInputsOutputs(inputs, outputs); + + // Construct new function based on inputs/outputs & add allocas for all defs. 
+ Function *newFunction = constructFunction(inputs, outputs, header, + newFuncRoot, + codeReplacer, oldFunction, + oldFunction->getParent()); + + emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); + + moveCodeToFunction(newFunction); + + // Loop over all of the PHI nodes in the header block, and change any + // references to the old incoming edge to be the new incoming edge. + for (BasicBlock::iterator I = header->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (!BlocksToExtract.count(PN->getIncomingBlock(i))) + PN->setIncomingBlock(i, newFuncRoot); + } + + // Look at all successors of the codeReplacer block. If any of these blocks + // had PHI nodes in them, we need to update the "from" block to be the code + // replacer, not the original block in the extracted region. + std::vector Succs(succ_begin(codeReplacer), + succ_end(codeReplacer)); + for (unsigned i = 0, e = Succs.size(); i != e; ++i) + for (BasicBlock::iterator I = Succs[i]->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + std::set ProcessedPreds; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (BlocksToExtract.count(PN->getIncomingBlock(i))) + if (ProcessedPreds.insert(PN->getIncomingBlock(i)).second) + PN->setIncomingBlock(i, codeReplacer); + else { + // There were multiple entries in the PHI for this block, now there + // is only one, so remove the duplicated entries. + PN->removeIncomingValue(i, false); + --i; --e; + } + } + + //cerr << "NEW FUNCTION: " << *newFunction; + // verifyFunction(*newFunction); + + // cerr << "OLD FUNCTION: " << *oldFunction; + // verifyFunction(*oldFunction); + + DEBUG(if (verifyFunction(*newFunction)) abort()); + return newFunction; +} + +bool CodeExtractor::isEligible(const std::vector &code) { + // Deny code region if it contains allocas or vastarts. 
+ for (std::vector::const_iterator BB = code.begin(), e=code.end(); + BB != e; ++BB) + for (BasicBlock::const_iterator I = (*BB)->begin(), Ie = (*BB)->end(); + I != Ie; ++I) + if (isa(*I)) + return false; + else if (const CallInst *CI = dyn_cast(I)) + if (const Function *F = CI->getCalledFunction()) + if (F->getIntrinsicID() == Intrinsic::vastart) + return false; + return true; +} + + +/// ExtractCodeRegion - slurp a sequence of basic blocks into a brand new +/// function +/// +Function* llvm::ExtractCodeRegion(DominatorTree &DT, + const std::vector &code, + bool AggregateArgs) { + return CodeExtractor(&DT, AggregateArgs).ExtractCodeRegion(code); +} + +/// ExtractBasicBlock - slurp a natural loop into a brand new function +/// +Function* llvm::ExtractLoop(DominatorTree &DT, Loop *L, bool AggregateArgs) { + return CodeExtractor(&DT, AggregateArgs).ExtractCodeRegion(L->getBlocks()); +} + +/// ExtractBasicBlock - slurp a basic block into a brand new function +/// +Function* llvm::ExtractBasicBlock(BasicBlock *BB, bool AggregateArgs) { + std::vector Blocks; + Blocks.push_back(BB); + return CodeExtractor(0, AggregateArgs).ExtractCodeRegion(Blocks); +} diff --git a/lib/Transforms/Utils/DemoteRegToStack.cpp b/lib/Transforms/Utils/DemoteRegToStack.cpp new file mode 100644 index 0000000..df332b2 --- /dev/null +++ b/lib/Transforms/Utils/DemoteRegToStack.cpp @@ -0,0 +1,133 @@ +//===- DemoteRegToStack.cpp - Move a virtual register to the stack --------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file provide the function DemoteRegToStack(). This function takes a +// virtual register computed by an Instruction and replaces it with a slot in +// the stack frame, allocated via alloca. 
It returns the pointer to the +// AllocaInst inserted. After this function is called on an instruction, we are +// guaranteed that the only user of the instruction is a store that is +// immediately after it. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Type.h" +#include +using namespace llvm; + +/// DemoteRegToStack - This function takes a virtual register computed by an +/// Instruction and replaces it with a slot in the stack frame, allocated via +/// alloca. This allows the CFG to be changed around without fear of +/// invalidating the SSA information for the value. It returns the pointer to +/// the alloca inserted to create a stack slot for I. +/// +AllocaInst* llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads) { + if (I.use_empty()) return 0; // nothing to do! + + // Create a stack slot to hold the value. + Function *F = I.getParent()->getParent(); + AllocaInst *Slot = new AllocaInst(I.getType(), 0, I.getName(), + F->getEntryBlock().begin()); + + // Change all of the users of the instruction to read from the stack slot + // instead. + while (!I.use_empty()) { + Instruction *U = cast(I.use_back()); + if (PHINode *PN = dyn_cast(U)) { + // If this is a PHI node, we can't insert a load of the value before the + // use. Instead, insert the load in the predecessor block corresponding + // to the incoming value. + // + // Note that if there are multiple edges from a basic block to this PHI + // node that we cannot multiple loads. The problem is that the resultant + // PHI node will have multiple values (from each load) coming in from the + // same block, which is illegal SSA form. For this reason, we keep track + // and reuse loads we insert. 
+ std::map Loads; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == &I) { + Value *&V = Loads[PN->getIncomingBlock(i)]; + if (V == 0) { + // Insert the load into the predecessor block + V = new LoadInst(Slot, I.getName()+".reload", VolatileLoads, + PN->getIncomingBlock(i)->getTerminator()); + } + PN->setIncomingValue(i, V); + } + + } else { + // If this is a normal instruction, just insert a load. + Value *V = new LoadInst(Slot, I.getName()+".reload", VolatileLoads, U); + U->replaceUsesOfWith(&I, V); + } + } + + + // Insert stores of the computed value into the stack slot. We have to be + // careful is I is an invoke instruction though, because we can't insert the + // store AFTER the terminator instruction. + BasicBlock::iterator InsertPt; + if (!isa(I)) { + InsertPt = &I; + ++InsertPt; + } else { + // We cannot demote invoke instructions to the stack if their normal edge + // is critical. + InvokeInst &II = cast(I); + assert(II.getNormalDest()->getSinglePredecessor() && + "Cannot demote invoke with a critical successor!"); + InsertPt = II.getNormalDest()->begin(); + } + + for (; isa(InsertPt); ++InsertPt) + /* empty */; // Don't insert before any PHI nodes. + new StoreInst(&I, Slot, InsertPt); + + return Slot; +} + + +/// DemotePHIToStack - This function takes a virtual register computed by a phi +/// node and replaces it with a slot in the stack frame, allocated via alloca. +/// The phi node is deleted and it returns the pointer to the alloca inserted. +AllocaInst* llvm::DemotePHIToStack(PHINode *P) { + if (P->use_empty()) { + P->eraseFromParent(); + return 0; + } + + // Create a stack slot to hold the value. + Function *F = P->getParent()->getParent(); + AllocaInst *Slot = new AllocaInst(P->getType(), 0, P->getName(), + F->getEntryBlock().begin()); + + // Iterate over each operand, insert store in each predecessor. 
+ for (unsigned i = 0, e = P->getNumIncomingValues(); i < e; ++i) { + if (InvokeInst *II = dyn_cast(P->getIncomingValue(i))) { + assert(II->getParent() != P->getIncomingBlock(i) && + "Invoke edge not supported yet"); + } + new StoreInst(P->getIncomingValue(i), Slot, + P->getIncomingBlock(i)->getTerminator()); + } + + // Insert load in place of the phi and replace all uses. + BasicBlock::iterator InsertPt; + for (InsertPt = P->getParent()->getInstList().begin(); + isa(InsertPt); ++InsertPt); + Value *V = new LoadInst(Slot, P->getName()+".reload", P); + P->replaceAllUsesWith(V); + + // Delete phi. + P->eraseFromParent(); + + return Slot; +} diff --git a/lib/Transforms/Utils/InlineFunction.cpp b/lib/Transforms/Utils/InlineFunction.cpp new file mode 100644 index 0000000..9735a2f --- /dev/null +++ b/lib/Transforms/Utils/InlineFunction.cpp @@ -0,0 +1,496 @@ +//===- InlineFunction.cpp - Code to perform function inlining -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements inlining of a function into a call site, resolving +// parameters and the return value as appropriate. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Module.h" +#include "llvm/Instructions.h" +#include "llvm/Intrinsics.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/CallSite.h" +using namespace llvm; + +bool llvm::InlineFunction(CallInst *CI, CallGraph *CG, const TargetData *TD) { + return InlineFunction(CallSite(CI), CG, TD); +} +bool llvm::InlineFunction(InvokeInst *II, CallGraph *CG, const TargetData *TD) { + return InlineFunction(CallSite(II), CG, TD); +} + +/// HandleInlinedInvoke - If we inlined an invoke site, we need to convert calls +/// in the body of the inlined function into invokes and turn unwind +/// instructions into branches to the invoke unwind dest. +/// +/// II is the invoke instruction begin inlined. FirstNewBlock is the first +/// block of the inlined code (the last block is the end of the function), +/// and InlineCodeInfo is information about the code that got inlined. +static void HandleInlinedInvoke(InvokeInst *II, BasicBlock *FirstNewBlock, + ClonedCodeInfo &InlinedCodeInfo) { + BasicBlock *InvokeDest = II->getUnwindDest(); + std::vector InvokeDestPHIValues; + + // If there are PHI nodes in the unwind destination block, we need to + // keep track of which values came into them from this invoke, then remove + // the entry for this block. + BasicBlock *InvokeBlock = II->getParent(); + for (BasicBlock::iterator I = InvokeDest->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + // Save the value to use for this edge. + InvokeDestPHIValues.push_back(PN->getIncomingValueForBlock(InvokeBlock)); + } + + Function *Caller = FirstNewBlock->getParent(); + + // The inlined code is currently at the end of the function, scan from the + // start of the inlined code to its end, checking for stuff we need to + // rewrite. 
+ if (InlinedCodeInfo.ContainsCalls || InlinedCodeInfo.ContainsUnwinds) { + for (Function::iterator BB = FirstNewBlock, E = Caller->end(); + BB != E; ++BB) { + if (InlinedCodeInfo.ContainsCalls) { + for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E; ){ + Instruction *I = BBI++; + + // We only need to check for function calls: inlined invoke + // instructions require no special handling. + if (!isa(I)) continue; + CallInst *CI = cast(I); + + // If this is an intrinsic function call or an inline asm, don't + // convert it to an invoke. + if ((CI->getCalledFunction() && + CI->getCalledFunction()->getIntrinsicID()) || + isa(CI->getCalledValue())) + continue; + + // Convert this function call into an invoke instruction. + // First, split the basic block. + BasicBlock *Split = BB->splitBasicBlock(CI, CI->getName()+".noexc"); + + // Next, create the new invoke instruction, inserting it at the end + // of the old basic block. + SmallVector InvokeArgs(CI->op_begin()+1, CI->op_end()); + InvokeInst *II = + new InvokeInst(CI->getCalledValue(), Split, InvokeDest, + &InvokeArgs[0], InvokeArgs.size(), + CI->getName(), BB->getTerminator()); + II->setCallingConv(CI->getCallingConv()); + + // Make sure that anything using the call now uses the invoke! + CI->replaceAllUsesWith(II); + + // Delete the unconditional branch inserted by splitBasicBlock + BB->getInstList().pop_back(); + Split->getInstList().pop_front(); // Delete the original call + + // Update any PHI nodes in the exceptional block to indicate that + // there is now a new entry in them. + unsigned i = 0; + for (BasicBlock::iterator I = InvokeDest->begin(); + isa(I); ++I, ++i) { + PHINode *PN = cast(I); + PN->addIncoming(InvokeDestPHIValues[i], BB); + } + + // This basic block is now complete, start scanning the next one. + break; + } + } + + if (UnwindInst *UI = dyn_cast(BB->getTerminator())) { + // An UnwindInst requires special handling when it gets inlined into an + // invoke site. 
Once this happens, we know that the unwind would cause + // a control transfer to the invoke exception destination, so we can + // transform it into a direct branch to the exception destination. + new BranchInst(InvokeDest, UI); + + // Delete the unwind instruction! + UI->getParent()->getInstList().pop_back(); + + // Update any PHI nodes in the exceptional block to indicate that + // there is now a new entry in them. + unsigned i = 0; + for (BasicBlock::iterator I = InvokeDest->begin(); + isa(I); ++I, ++i) { + PHINode *PN = cast(I); + PN->addIncoming(InvokeDestPHIValues[i], BB); + } + } + } + } + + // Now that everything is happy, we have one final detail. The PHI nodes in + // the exception destination block still have entries due to the original + // invoke instruction. Eliminate these entries (which might even delete the + // PHI node) now. + InvokeDest->removePredecessor(II->getParent()); +} + +/// UpdateCallGraphAfterInlining - Once we have cloned code over from a callee +/// into the caller, update the specified callgraph to reflect the changes we +/// made. Note that it's possible that not all code was copied over, so only +/// some edges of the callgraph will be remain. +static void UpdateCallGraphAfterInlining(const Function *Caller, + const Function *Callee, + Function::iterator FirstNewBlock, + DenseMap &ValueMap, + CallGraph &CG) { + // Update the call graph by deleting the edge from Callee to Caller + CallGraphNode *CalleeNode = CG[Callee]; + CallGraphNode *CallerNode = CG[Caller]; + CallerNode->removeCallEdgeTo(CalleeNode); + + // Since we inlined some uninlined call sites in the callee into the caller, + // add edges from the caller to all of the callees of the callee. + for (CallGraphNode::iterator I = CalleeNode->begin(), + E = CalleeNode->end(); I != E; ++I) { + const Instruction *OrigCall = I->first.getInstruction(); + + DenseMap::iterator VMI = ValueMap.find(OrigCall); + // Only copy the edge if the call was inlined! 
+ if (VMI != ValueMap.end() && VMI->second) { + // If the call was inlined, but then constant folded, there is no edge to + // add. Check for this case. + if (Instruction *NewCall = dyn_cast(VMI->second)) + CallerNode->addCalledFunction(CallSite::get(NewCall), I->second); + } + } +} + + +// InlineFunction - This function inlines the called function into the basic +// block of the caller. This returns false if it is not possible to inline this +// call. The program is still in a well defined state if this occurs though. +// +// Note that this only does one level of inlining. For example, if the +// instruction 'call B' is inlined, and 'B' calls 'C', then the call to 'C' now +// exists in the instruction stream. Similiarly this will inline a recursive +// function by one level. +// +bool llvm::InlineFunction(CallSite CS, CallGraph *CG, const TargetData *TD) { + Instruction *TheCall = CS.getInstruction(); + assert(TheCall->getParent() && TheCall->getParent()->getParent() && + "Instruction not in function!"); + + const Function *CalledFunc = CS.getCalledFunction(); + if (CalledFunc == 0 || // Can't inline external function or indirect + CalledFunc->isDeclaration() || // call, or call to a vararg function! + CalledFunc->getFunctionType()->isVarArg()) return false; + + + // If the call to the callee is a non-tail call, we must clear the 'tail' + // flags on any calls that we inline. + bool MustClearTailCallFlags = + isa(TheCall) && !cast(TheCall)->isTailCall(); + + BasicBlock *OrigBB = TheCall->getParent(); + Function *Caller = OrigBB->getParent(); + + // Get an iterator to the last basic block in the function, which will have + // the new function inlined after it. + // + Function::iterator LastBlock = &Caller->back(); + + // Make sure to capture all of the return instructions from the cloned + // function. + std::vector Returns; + ClonedCodeInfo InlinedFunctionInfo; + Function::iterator FirstNewBlock; + + { // Scope to destroy ValueMap after cloning. 
+ DenseMap ValueMap; + + // Calculate the vector of arguments to pass into the function cloner, which + // matches up the formal to the actual argument values. + assert(std::distance(CalledFunc->arg_begin(), CalledFunc->arg_end()) == + std::distance(CS.arg_begin(), CS.arg_end()) && + "No varargs calls can be inlined!"); + CallSite::arg_iterator AI = CS.arg_begin(); + for (Function::const_arg_iterator I = CalledFunc->arg_begin(), + E = CalledFunc->arg_end(); I != E; ++I, ++AI) + ValueMap[I] = *AI; + + // We want the inliner to prune the code as it copies. We would LOVE to + // have no dead or constant instructions leftover after inlining occurs + // (which can happen, e.g., because an argument was constant), but we'll be + // happy with whatever the cloner can do. + CloneAndPruneFunctionInto(Caller, CalledFunc, ValueMap, Returns, ".i", + &InlinedFunctionInfo, TD); + + // Remember the first block that is newly cloned over. + FirstNewBlock = LastBlock; ++FirstNewBlock; + + // Update the callgraph if requested. + if (CG) + UpdateCallGraphAfterInlining(Caller, CalledFunc, FirstNewBlock, ValueMap, + *CG); + } + + // If there are any alloca instructions in the block that used to be the entry + // block for the callee, move them to the entry block of the caller. First + // calculate which instruction they should be inserted before. We insert the + // instructions at the end of the current alloca list. + // + { + BasicBlock::iterator InsertPoint = Caller->begin()->begin(); + for (BasicBlock::iterator I = FirstNewBlock->begin(), + E = FirstNewBlock->end(); I != E; ) + if (AllocaInst *AI = dyn_cast(I++)) { + // If the alloca is now dead, remove it. This often occurs due to code + // specialization. + if (AI->use_empty()) { + AI->eraseFromParent(); + continue; + } + + if (isa(AI->getArraySize())) { + // Scan for the block of allocas that we can move over, and move them + // all at once. 
+ while (isa(I) && + isa(cast(I)->getArraySize())) + ++I; + + // Transfer all of the allocas over in a block. Using splice means + // that the instructions aren't removed from the symbol table, then + // reinserted. + Caller->getEntryBlock().getInstList().splice( + InsertPoint, + FirstNewBlock->getInstList(), + AI, I); + } + } + } + + // If the inlined code contained dynamic alloca instructions, wrap the inlined + // code with llvm.stacksave/llvm.stackrestore intrinsics. + if (InlinedFunctionInfo.ContainsDynamicAllocas) { + Module *M = Caller->getParent(); + const Type *BytePtr = PointerType::get(Type::Int8Ty); + // Get the two intrinsics we care about. + Constant *StackSave, *StackRestore; + StackSave = M->getOrInsertFunction("llvm.stacksave", BytePtr, NULL); + StackRestore = M->getOrInsertFunction("llvm.stackrestore", Type::VoidTy, + BytePtr, NULL); + + // If we are preserving the callgraph, add edges to the stacksave/restore + // functions for the calls we insert. + CallGraphNode *StackSaveCGN = 0, *StackRestoreCGN = 0, *CallerNode = 0; + if (CG) { + // We know that StackSave/StackRestore are Function*'s, because they are + // intrinsics which must have the right types. + StackSaveCGN = CG->getOrInsertFunction(cast(StackSave)); + StackRestoreCGN = CG->getOrInsertFunction(cast(StackRestore)); + CallerNode = (*CG)[Caller]; + } + + // Insert the llvm.stacksave. + CallInst *SavedPtr = new CallInst(StackSave, "savedstack", + FirstNewBlock->begin()); + if (CG) CallerNode->addCalledFunction(SavedPtr, StackSaveCGN); + + // Insert a call to llvm.stackrestore before any return instructions in the + // inlined function. + for (unsigned i = 0, e = Returns.size(); i != e; ++i) { + CallInst *CI = new CallInst(StackRestore, SavedPtr, "", Returns[i]); + if (CG) CallerNode->addCalledFunction(CI, StackRestoreCGN); + } + + // Count the number of StackRestore calls we insert. 
+ unsigned NumStackRestores = Returns.size(); + + // If we are inlining an invoke instruction, insert restores before each + // unwind. These unwinds will be rewritten into branches later. + if (InlinedFunctionInfo.ContainsUnwinds && isa(TheCall)) { + for (Function::iterator BB = FirstNewBlock, E = Caller->end(); + BB != E; ++BB) + if (UnwindInst *UI = dyn_cast(BB->getTerminator())) { + new CallInst(StackRestore, SavedPtr, "", UI); + ++NumStackRestores; + } + } + } + + // If we are inlining tail call instruction through a call site that isn't + // marked 'tail', we must remove the tail marker for any calls in the inlined + // code. + if (MustClearTailCallFlags && InlinedFunctionInfo.ContainsCalls) { + for (Function::iterator BB = FirstNewBlock, E = Caller->end(); + BB != E; ++BB) + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) + if (CallInst *CI = dyn_cast(I)) + CI->setTailCall(false); + } + + // If we are inlining for an invoke instruction, we must make sure to rewrite + // any inlined 'unwind' instructions into branches to the invoke exception + // destination, and call instructions into invoke instructions. + if (InvokeInst *II = dyn_cast(TheCall)) + HandleInlinedInvoke(II, FirstNewBlock, InlinedFunctionInfo); + + // If we cloned in _exactly one_ basic block, and if that block ends in a + // return instruction, we splice the body of the inlined callee directly into + // the calling basic block. + if (Returns.size() == 1 && std::distance(FirstNewBlock, Caller->end()) == 1) { + // Move all of the instructions right before the call. + OrigBB->getInstList().splice(TheCall, FirstNewBlock->getInstList(), + FirstNewBlock->begin(), FirstNewBlock->end()); + // Remove the cloned basic block. + Caller->getBasicBlockList().pop_back(); + + // If the call site was an invoke instruction, add a branch to the normal + // destination. 
+ if (InvokeInst *II = dyn_cast(TheCall)) + new BranchInst(II->getNormalDest(), TheCall); + + // If the return instruction returned a value, replace uses of the call with + // uses of the returned value. + if (!TheCall->use_empty()) + TheCall->replaceAllUsesWith(Returns[0]->getReturnValue()); + + // Since we are now done with the Call/Invoke, we can delete it. + TheCall->getParent()->getInstList().erase(TheCall); + + // Since we are now done with the return instruction, delete it also. + Returns[0]->getParent()->getInstList().erase(Returns[0]); + + // We are now done with the inlining. + return true; + } + + // Otherwise, we have the normal case, of more than one block to inline or + // multiple return sites. + + // We want to clone the entire callee function into the hole between the + // "starter" and "ender" blocks. How we accomplish this depends on whether + // this is an invoke instruction or a call instruction. + BasicBlock *AfterCallBB; + if (InvokeInst *II = dyn_cast(TheCall)) { + + // Add an unconditional branch to make this look like the CallInst case... + BranchInst *NewBr = new BranchInst(II->getNormalDest(), TheCall); + + // Split the basic block. This guarantees that no PHI nodes will have to be + // updated due to new incoming edges, and make the invoke case more + // symmetric to the call case. + AfterCallBB = OrigBB->splitBasicBlock(NewBr, + CalledFunc->getName()+".exit"); + + } else { // It's a call + // If this is a call instruction, we need to split the basic block that + // the call lives in. + // + AfterCallBB = OrigBB->splitBasicBlock(TheCall, + CalledFunc->getName()+".exit"); + } + + // Change the branch that used to go to AfterCallBB to branch to the first + // basic block of the inlined function. + // + TerminatorInst *Br = OrigBB->getTerminator(); + assert(Br && Br->getOpcode() == Instruction::Br && + "splitBasicBlock broken!"); + Br->setOperand(0, FirstNewBlock); + + + // Now that the function is correct, make it a little bit nicer. 
In + // particular, move the basic blocks inserted from the end of the function + // into the space made by splitting the source basic block. + // + Caller->getBasicBlockList().splice(AfterCallBB, Caller->getBasicBlockList(), + FirstNewBlock, Caller->end()); + + // Handle all of the return instructions that we just cloned in, and eliminate + // any users of the original call/invoke instruction. + if (Returns.size() > 1) { + // The PHI node should go at the front of the new basic block to merge all + // possible incoming values. + // + PHINode *PHI = 0; + if (!TheCall->use_empty()) { + PHI = new PHINode(CalledFunc->getReturnType(), + TheCall->getName(), AfterCallBB->begin()); + + // Anything that used the result of the function call should now use the + // PHI node as their operand. + // + TheCall->replaceAllUsesWith(PHI); + } + + // Loop over all of the return instructions, turning them into unconditional + // branches to the merge point now, and adding entries to the PHI node as + // appropriate. + for (unsigned i = 0, e = Returns.size(); i != e; ++i) { + ReturnInst *RI = Returns[i]; + + if (PHI) { + assert(RI->getReturnValue() && "Ret should have value!"); + assert(RI->getReturnValue()->getType() == PHI->getType() && + "Ret value not consistent in function!"); + PHI->addIncoming(RI->getReturnValue(), RI->getParent()); + } + + // Add a branch to the merge point where the PHI node lives if it exists. + new BranchInst(AfterCallBB, RI); + + // Delete the return instruction now + RI->getParent()->getInstList().erase(RI); + } + + } else if (!Returns.empty()) { + // Otherwise, if there is exactly one return value, just replace anything + // using the return value of the call with the computed value. + if (!TheCall->use_empty()) + TheCall->replaceAllUsesWith(Returns[0]->getReturnValue()); + + // Splice the code from the return block into the block that it will return + // to, which contains the code that was after the call. 
+ BasicBlock *ReturnBB = Returns[0]->getParent(); + AfterCallBB->getInstList().splice(AfterCallBB->begin(), + ReturnBB->getInstList()); + + // Update PHI nodes that use the ReturnBB to use the AfterCallBB. + ReturnBB->replaceAllUsesWith(AfterCallBB); + + // Delete the return instruction now and empty ReturnBB now. + Returns[0]->eraseFromParent(); + ReturnBB->eraseFromParent(); + } else if (!TheCall->use_empty()) { + // No returns, but something is using the return value of the call. Just + // nuke the result. + TheCall->replaceAllUsesWith(UndefValue::get(TheCall->getType())); + } + + // Since we are now done with the Call/Invoke, we can delete it. + TheCall->eraseFromParent(); + + // We should always be able to fold the entry block of the function into the + // single predecessor of the block... + assert(cast(Br)->isUnconditional() && "splitBasicBlock broken!"); + BasicBlock *CalleeEntry = cast(Br)->getSuccessor(0); + + // Splice the code entry block into calling block, right before the + // unconditional branch. + OrigBB->getInstList().splice(Br, CalleeEntry->getInstList()); + CalleeEntry->replaceAllUsesWith(OrigBB); // Update PHI nodes + + // Remove the unconditional branch. + OrigBB->getInstList().erase(Br); + + // Now we can remove the CalleeEntry block, which is now empty. + Caller->getBasicBlockList().erase(CalleeEntry); + + return true; +} diff --git a/lib/Transforms/Utils/LCSSA.cpp b/lib/Transforms/Utils/LCSSA.cpp new file mode 100644 index 0000000..220241d --- /dev/null +++ b/lib/Transforms/Utils/LCSSA.cpp @@ -0,0 +1,269 @@ +//===-- LCSSA.cpp - Convert loops into loop-closed SSA form ---------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by Owen Anderson and is distributed under the +// University of Illinois Open Source License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This pass transforms loops by placing phi nodes at the end of the loops for +// all values that are live across the loop boundary. For example, it turns +// the left into the right code: +// +// for (...) for (...) +// if (c) if (c) +// X1 = ... X1 = ... +// else else +// X2 = ... X2 = ... +// X3 = phi(X1, X2) X3 = phi(X1, X2) +// ... = X3 + 4 X4 = phi(X3) +// ... = X4 + 4 +// +// This is still valid LLVM; the extra phi nodes are purely redundant, and will +// be trivially eliminated by InstCombine. The major benefit of this +// transformation is that it makes many other loop optimizations, such as +// LoopUnswitching, simpler. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "lcssa" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/Pass.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include +#include +using namespace llvm; + +STATISTIC(NumLCSSA, "Number of live out of a loop variables"); + +namespace { + struct VISIBILITY_HIDDEN LCSSA : public LoopPass { + static char ID; // Pass identification, replacement for typeid + LCSSA() : LoopPass((intptr_t)&ID) {} + + // Cached analysis information for the current function. + LoopInfo *LI; + DominatorTree *DT; + std::vector LoopBlocks; + + virtual bool runOnLoop(Loop *L, LPPassManager &LPM); + + void ProcessInstruction(Instruction* Instr, + const std::vector& exitBlocks); + + /// This transformation requires natural loop information & requires that + /// loop preheaders be inserted into the CFG. It maintains both of these, + /// as well as the CFG. 
It also requires dominator information. + /// + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequiredID(LoopSimplifyID); + AU.addPreservedID(LoopSimplifyID); + AU.addRequired(); + AU.addPreserved(); + AU.addRequired(); + AU.addPreserved(); + } + private: + void getLoopValuesUsedOutsideLoop(Loop *L, + SetVector &AffectedValues); + + Value *GetValueForBlock(DomTreeNode *BB, Instruction *OrigInst, + std::map &Phis); + + /// inLoop - returns true if the given block is within the current loop + const bool inLoop(BasicBlock* B) { + return std::binary_search(LoopBlocks.begin(), LoopBlocks.end(), B); + } + }; + + char LCSSA::ID = 0; + RegisterPass X("lcssa", "Loop-Closed SSA Form Pass"); +} + +LoopPass *llvm::createLCSSAPass() { return new LCSSA(); } +const PassInfo *llvm::LCSSAID = X.getPassInfo(); + +/// runOnFunction - Process all loops in the function, inner-most out. +bool LCSSA::runOnLoop(Loop *L, LPPassManager &LPM) { + + LI = &LPM.getAnalysis(); + DT = &getAnalysis(); + + // Speed up queries by creating a sorted list of blocks + LoopBlocks.clear(); + LoopBlocks.insert(LoopBlocks.end(), L->block_begin(), L->block_end()); + std::sort(LoopBlocks.begin(), LoopBlocks.end()); + + SetVector AffectedValues; + getLoopValuesUsedOutsideLoop(L, AffectedValues); + + // If no values are affected, we can save a lot of work, since we know that + // nothing will be changed. + if (AffectedValues.empty()) + return false; + + std::vector exitBlocks; + L->getExitBlocks(exitBlocks); + + + // Iterate over all affected values for this loop and insert Phi nodes + // for them in the appropriate exit blocks + + for (SetVector::iterator I = AffectedValues.begin(), + E = AffectedValues.end(); I != E; ++I) + ProcessInstruction(*I, exitBlocks); + + assert(L->isLCSSAForm()); + + return true; +} + +/// processInstruction - Given a live-out instruction, insert LCSSA Phi nodes, +/// eliminate all out-of-loop uses. 
+void LCSSA::ProcessInstruction(Instruction *Instr, + const std::vector& exitBlocks) { + ++NumLCSSA; // We are applying the transformation + + // Keep track of the blocks that have the value available already. + std::map Phis; + + DomTreeNode *InstrNode = DT->getNode(Instr->getParent()); + + // Insert the LCSSA phi's into the exit blocks (dominated by the value), and + // add them to the Phi's map. + for (std::vector::const_iterator BBI = exitBlocks.begin(), + BBE = exitBlocks.end(); BBI != BBE; ++BBI) { + BasicBlock *BB = *BBI; + DomTreeNode *ExitBBNode = DT->getNode(BB); + Value *&Phi = Phis[ExitBBNode]; + if (!Phi && DT->dominates(InstrNode, ExitBBNode)) { + PHINode *PN = new PHINode(Instr->getType(), Instr->getName()+".lcssa", + BB->begin()); + PN->reserveOperandSpace(std::distance(pred_begin(BB), pred_end(BB))); + + // Remember that this phi makes the value alive in this block. + Phi = PN; + + // Add inputs from inside the loop for this PHI. + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) + PN->addIncoming(Instr, *PI); + } + } + + + // Record all uses of Instr outside the loop. We need to rewrite these. The + // LCSSA phis won't be included because they use the value in the loop. + for (Value::use_iterator UI = Instr->use_begin(), E = Instr->use_end(); + UI != E;) { + BasicBlock *UserBB = cast(*UI)->getParent(); + if (PHINode *P = dyn_cast(*UI)) { + unsigned OperandNo = UI.getOperandNo(); + UserBB = P->getIncomingBlock(OperandNo/2); + } + + // If the user is in the loop, don't rewrite it! + if (UserBB == Instr->getParent() || inLoop(UserBB)) { + ++UI; + continue; + } + + // Otherwise, patch up uses of the value with the appropriate LCSSA Phi, + // inserting PHI nodes into join points where needed. + Value *Val = GetValueForBlock(DT->getNode(UserBB), Instr, Phis); + + // Preincrement the iterator to avoid invalidating it when we change the + // value. 
+ Use &U = UI.getUse(); + ++UI; + U.set(Val); + } +} + +/// getLoopValuesUsedOutsideLoop - Return any values defined in the loop that +/// are used by instructions outside of it. +void LCSSA::getLoopValuesUsedOutsideLoop(Loop *L, + SetVector &AffectedValues) { + // FIXME: For large loops, we may be able to avoid a lot of use-scanning + // by using dominance information. In particular, if a block does not + // dominate any of the loop exits, then none of the values defined in the + // block could be used outside the loop. + for (Loop::block_iterator BB = L->block_begin(), E = L->block_end(); + BB != E; ++BB) { + for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end(); I != E; ++I) + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); UI != E; + ++UI) { + BasicBlock *UserBB = cast(*UI)->getParent(); + if (PHINode* p = dyn_cast(*UI)) { + unsigned OperandNo = UI.getOperandNo(); + UserBB = p->getIncomingBlock(OperandNo/2); + } + + if (*BB != UserBB && !inLoop(UserBB)) { + AffectedValues.insert(I); + break; + } + } + } +} + +/// GetValueForBlock - Get the value to use within the specified basic block. +/// available values are in Phis. +Value *LCSSA::GetValueForBlock(DomTreeNode *BB, Instruction *OrigInst, + std::map &Phis) { + // If there is no dominator info for this BB, it is unreachable. + if (BB == 0) + return UndefValue::get(OrigInst->getType()); + + // If we have already computed this value, return the previously computed val. + Value *&V = Phis[BB]; + if (V) return V; + + DomTreeNode *IDom = BB->getIDom(); + + // If the block has no dominator, bail + if (!IDom) + return V = UndefValue::get(OrigInst->getType()); + + // Otherwise, there are two cases: we either have to insert a PHI node or we + // don't. 
We need to insert a PHI node if this block is not dominated by one + // of the exit nodes from the loop (the loop could have multiple exits, and + // though the value defined *inside* the loop dominated all its uses, each + // exit by itself may not dominate all the uses). + // + // The simplest way to check for this condition is by checking to see if the + // idom is in the loop. If so, we *know* that none of the exit blocks + // dominate this block. Note that we *know* that the block defining the + // original instruction is in the idom chain, because if it weren't, then the + // original value didn't dominate this use. + if (!inLoop(IDom->getBlock())) { + // Idom is not in the loop, we must still be "below" the exit block and must + // be fully dominated by the value live in the idom. + return V = GetValueForBlock(IDom, OrigInst, Phis); + } + + BasicBlock *BBN = BB->getBlock(); + + // Otherwise, the idom is the loop, so we need to insert a PHI node. Do so + // now, then get values to fill in the incoming values for the PHI. + PHINode *PN = new PHINode(OrigInst->getType(), OrigInst->getName()+".lcssa", + BBN->begin()); + PN->reserveOperandSpace(std::distance(pred_begin(BBN), pred_end(BBN))); + V = PN; + + // Fill in the incoming values for the block. + for (pred_iterator PI = pred_begin(BBN), E = pred_end(BBN); PI != E; ++PI) + PN->addIncoming(GetValueForBlock(DT->getNode(*PI), OrigInst, Phis), *PI); + return PN; +} + diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp new file mode 100644 index 0000000..5e2d237 --- /dev/null +++ b/lib/Transforms/Utils/Local.cpp @@ -0,0 +1,200 @@ +//===-- Local.cpp - Functions to perform local transformations ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This family of functions perform various local transformations to the +// program. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Intrinsics.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +#include "llvm/Support/MathExtras.h" +#include +using namespace llvm; + +//===----------------------------------------------------------------------===// +// Local constant propagation... +// + +/// doConstantPropagation - If an instruction references constants, try to fold +/// them together... +/// +bool llvm::doConstantPropagation(BasicBlock::iterator &II, + const TargetData *TD) { + if (Constant *C = ConstantFoldInstruction(II, TD)) { + // Replaces all of the uses of a variable with uses of the constant. + II->replaceAllUsesWith(C); + + // Remove the instruction from the basic block... + II = II->getParent()->getInstList().erase(II); + return true; + } + + return false; +} + +// ConstantFoldTerminator - If a terminator instruction is predicated on a +// constant value, convert it into an unconditional branch to the constant +// destination. +// +bool llvm::ConstantFoldTerminator(BasicBlock *BB) { + TerminatorInst *T = BB->getTerminator(); + + // Branch - See if we are conditional jumping on constant + if (BranchInst *BI = dyn_cast(T)) { + if (BI->isUnconditional()) return false; // Can't optimize uncond branch + BasicBlock *Dest1 = cast(BI->getOperand(0)); + BasicBlock *Dest2 = cast(BI->getOperand(1)); + + if (ConstantInt *Cond = dyn_cast(BI->getCondition())) { + // Are we branching on constant? + // YES. Change to unconditional branch... + BasicBlock *Destination = Cond->getZExtValue() ? 
Dest1 : Dest2; + BasicBlock *OldDest = Cond->getZExtValue() ? Dest2 : Dest1; + + //cerr << "Function: " << T->getParent()->getParent() + // << "\nRemoving branch from " << T->getParent() + // << "\n\nTo: " << OldDest << endl; + + // Let the basic block know that we are letting go of it. Based on this, + // it will adjust its PHI nodes. + assert(BI->getParent() && "Terminator not inserted in block!"); + OldDest->removePredecessor(BI->getParent()); + + // Set the unconditional destination, and change the insn to be an + // unconditional branch. + BI->setUnconditionalDest(Destination); + return true; + } else if (Dest2 == Dest1) { // Conditional branch to same location? + // This branch matches something like this: + // br bool %cond, label %Dest, label %Dest + // and changes it into: br label %Dest + + // Let the basic block know that we are letting go of one copy of it. + assert(BI->getParent() && "Terminator not inserted in block!"); + Dest1->removePredecessor(BI->getParent()); + + // Change a conditional branch to unconditional. + BI->setUnconditionalDest(Dest1); + return true; + } + } else if (SwitchInst *SI = dyn_cast(T)) { + // If we are switching on a constant, we can convert the switch into a + // single branch instruction! + ConstantInt *CI = dyn_cast(SI->getCondition()); + BasicBlock *TheOnlyDest = SI->getSuccessor(0); // The default dest + BasicBlock *DefaultDest = TheOnlyDest; + assert(TheOnlyDest == SI->getDefaultDest() && + "Default destination is not successor #0?"); + + // Figure out which case it goes to... + for (unsigned i = 1, e = SI->getNumSuccessors(); i != e; ++i) { + // Found case matching a constant operand? + if (SI->getSuccessorValue(i) == CI) { + TheOnlyDest = SI->getSuccessor(i); + break; + } + + // Check to see if this branch is going to the same place as the default + // dest. If so, eliminate it as an explicit compare. + if (SI->getSuccessor(i) == DefaultDest) { + // Remove this entry...
+ DefaultDest->removePredecessor(SI->getParent()); + SI->removeCase(i); + --i; --e; // Don't skip an entry... + continue; + } + + // Otherwise, check to see if the switch only branches to one destination. + // We do this by resetting "TheOnlyDest" to null when we find two non-equal + // destinations. + if (SI->getSuccessor(i) != TheOnlyDest) TheOnlyDest = 0; + } + + if (CI && !TheOnlyDest) { + // Branching on a constant, but not any of the cases, go to the default + // successor. + TheOnlyDest = SI->getDefaultDest(); + } + + // If we found a single destination that we can fold the switch into, do so + // now. + if (TheOnlyDest) { + // Insert the new branch.. + new BranchInst(TheOnlyDest, SI); + BasicBlock *BB = SI->getParent(); + + // Remove entries from PHI nodes which we no longer branch to... + for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) { + // Found case matching a constant operand? + BasicBlock *Succ = SI->getSuccessor(i); + if (Succ == TheOnlyDest) + TheOnlyDest = 0; // Don't modify the first branch to TheOnlyDest + else + Succ->removePredecessor(BB); + } + + // Delete the old switch... + BB->getInstList().erase(SI); + return true; + } else if (SI->getNumSuccessors() == 2) { + // Otherwise, we can fold this switch into a conditional branch + // instruction if it has only one non-default destination. + Value *Cond = new ICmpInst(ICmpInst::ICMP_EQ, SI->getCondition(), + SI->getSuccessorValue(1), "cond", SI); + // Insert the new branch... + new BranchInst(SI->getSuccessor(1), SI->getSuccessor(0), Cond, SI); + + // Delete the old switch... + SI->getParent()->getInstList().erase(SI); + return true; + } + } + return false; +} + + +//===----------------------------------------------------------------------===// +// Local dead code elimination...
+// + +bool llvm::isInstructionTriviallyDead(Instruction *I) { + if (!I->use_empty() || isa(I)) return false; + + if (!I->mayWriteToMemory()) return true; + + if (CallInst *CI = dyn_cast(I)) + if (Function *F = CI->getCalledFunction()) { + unsigned IntrinsicID = F->getIntrinsicID(); +#define GET_SIDE_EFFECT_INFO +#include "llvm/Intrinsics.gen" +#undef GET_SIDE_EFFECT_INFO + } + return false; +} + +// dceInstruction - Inspect the instruction at *BBI and figure out if it's +// [trivially] dead. If so, remove the instruction and update the iterator +// to point to the instruction that immediately succeeded the original +// instruction. +// +bool llvm::dceInstruction(BasicBlock::iterator &BBI) { + // Look for un"used" definitions... + if (isInstructionTriviallyDead(BBI)) { + BBI = BBI->getParent()->getInstList().erase(BBI); // Bye bye + return true; + } + return false; +} diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp new file mode 100644 index 0000000..0a5de2b --- /dev/null +++ b/lib/Transforms/Utils/LoopSimplify.cpp @@ -0,0 +1,692 @@ +//===- LoopSimplify.cpp - Loop Canonicalization Pass ----------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs several transformations to transform natural loops into a +// simpler form, which makes subsequent analyses and transformations simpler and +// more effective. +// +// Loop pre-header insertion guarantees that there is a single, non-critical +// entry edge from outside of the loop to the loop header. This simplifies a +// number of analyses and transformations, such as LICM. 
+// +// Loop exit-block insertion guarantees that all exit blocks from the loop +// (blocks which are outside of the loop that have predecessors inside of the +// loop) only have predecessors from inside of the loop (and are thus dominated +// by the loop header). This simplifies transformations such as store-sinking +// that are built into LICM. +// +// This pass also guarantees that loops will have exactly one backedge. +// +// Note that the simplifycfg pass will clean up blocks which are split out but +// end up being unnecessary, so usage of this pass should not pessimize +// generated code. +// +// This pass obviously modifies the CFG, but updates loop information and +// dominator information. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "loopsimplify" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constant.h" +#include "llvm/Instructions.h" +#include "llvm/Function.h" +#include "llvm/Type.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/DepthFirstIterator.h" +using namespace llvm; + +STATISTIC(NumInserted, "Number of pre-header or exit blocks inserted"); +STATISTIC(NumNested , "Number of nested loops split out"); + +namespace { + struct VISIBILITY_HIDDEN LoopSimplify : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + LoopSimplify() : FunctionPass((intptr_t)&ID) {} + + // AA - If we have an alias analysis object to update, this is it, otherwise + // this is null. + AliasAnalysis *AA; + LoopInfo *LI; + DominatorTree *DT; + virtual bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + // We need loop information to identify the loops... 
+ AU.addRequired(); + AU.addRequired(); + + AU.addPreserved(); + AU.addPreserved(); + AU.addPreserved(); + AU.addPreservedID(BreakCriticalEdgesID); // No critical edges added. + } + private: + bool ProcessLoop(Loop *L); + BasicBlock *SplitBlockPredecessors(BasicBlock *BB, const char *Suffix, + const std::vector &Preds); + BasicBlock *RewriteLoopExitBlock(Loop *L, BasicBlock *Exit); + void InsertPreheaderForLoop(Loop *L); + Loop *SeparateNestedLoop(Loop *L); + void InsertUniqueBackedgeBlock(Loop *L); + void PlaceSplitBlockCarefully(BasicBlock *NewBB, + std::vector &SplitPreds, + Loop *L); + }; + + char LoopSimplify::ID = 0; + RegisterPass + X("loopsimplify", "Canonicalize natural loops", true); +} + +// Publicly exposed interface to pass... +const PassInfo *llvm::LoopSimplifyID = X.getPassInfo(); +FunctionPass *llvm::createLoopSimplifyPass() { return new LoopSimplify(); } + +/// runOnFunction - Run down all loops in the CFG (recursively, but we could do +/// it in any convenient order) inserting preheaders... +/// +bool LoopSimplify::runOnFunction(Function &F) { + bool Changed = false; + LI = &getAnalysis(); + AA = getAnalysisToUpdate(); + DT = &getAnalysis(); + + // Check to see that no blocks (other than the header) in loops have + // predecessors that are not in loops. This is not valid for natural loops, + // but can occur if the blocks are unreachable. Since they are unreachable we + // can just shamelessly destroy their terminators to make them not branch into + // the loop! + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { + // This case can only occur for unreachable blocks. Blocks that are + // unreachable can't be in loops, so filter those blocks out. + if (LI->getLoopFor(BB)) continue; + + bool BlockUnreachable = false; + TerminatorInst *TI = BB->getTerminator(); + + // Check to see if any successors of this block are non-loop-header loops + // that are not the header.
+ for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + // If this successor is not in a loop, BB is clearly ok. + Loop *L = LI->getLoopFor(TI->getSuccessor(i)); + if (!L) continue; + + // If the succ is the loop header, and if L is a top-level loop, then this + // is an entrance into a loop through the header, which is also ok. + if (L->getHeader() == TI->getSuccessor(i) && L->getParentLoop() == 0) + continue; + + // Otherwise, this is an entrance into a loop from some place invalid. + // Either the loop structure is invalid and this is not a natural loop (in + // which case the compiler is buggy somewhere else) or BB is unreachable. + BlockUnreachable = true; + break; + } + + // If this block is ok, check the next one. + if (!BlockUnreachable) continue; + + // Otherwise, this block is dead. To clean up the CFG and to allow later + // loop transformations to ignore this case, we delete the edges into the + // loop by replacing the terminator. + + // Remove PHI entries from the successors. + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + TI->getSuccessor(i)->removePredecessor(BB); + + // Add a new unreachable instruction. + new UnreachableInst(TI); + + // Delete the dead terminator. + if (AA) AA->deleteValue(&BB->back()); + BB->getInstList().pop_back(); + Changed |= true; + } + + for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) + Changed |= ProcessLoop(*I); + + return Changed; +} + +/// ProcessLoop - Walk the loop structure in depth first order, ensuring that +/// all loops have preheaders. +/// +bool LoopSimplify::ProcessLoop(Loop *L) { + bool Changed = false; +ReprocessLoop: + + // Canonicalize inner loops before outer loops. Inner loop canonicalization + // can provide work for the outer loop to canonicalize. 
+ for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) + Changed |= ProcessLoop(*I); + + assert(L->getBlocks()[0] == L->getHeader() && + "Header isn't first block in loop?"); + + // Does the loop already have a preheader? If so, don't insert one. + if (L->getLoopPreheader() == 0) { + InsertPreheaderForLoop(L); + NumInserted++; + Changed = true; + } + + // Next, check to make sure that all exit nodes of the loop only have + // predecessors that are inside of the loop. This check guarantees that the + // loop preheader/header will dominate the exit blocks. If the exit block has + // predecessors from outside of the loop, split the edge now. + std::vector ExitBlocks; + L->getExitBlocks(ExitBlocks); + + SetVector ExitBlockSet(ExitBlocks.begin(), ExitBlocks.end()); + for (SetVector::iterator I = ExitBlockSet.begin(), + E = ExitBlockSet.end(); I != E; ++I) { + BasicBlock *ExitBlock = *I; + for (pred_iterator PI = pred_begin(ExitBlock), PE = pred_end(ExitBlock); + PI != PE; ++PI) + // Must be exactly this loop: no subloops, parent loops, or non-loop preds + // allowed. + if (!L->contains(*PI)) { + RewriteLoopExitBlock(L, ExitBlock); + NumInserted++; + Changed = true; + break; + } + } + + // If the header has more than two predecessors at this point (from the + // preheader and from multiple backedges), we must adjust the loop. + unsigned NumBackedges = L->getNumBackEdges(); + if (NumBackedges != 1) { + // If this is really a nested loop, rip it out into a child loop. Don't do + // this for loops with a giant number of backedges, just factor them into a + // common backedge instead. + if (NumBackedges < 8) { + if (Loop *NL = SeparateNestedLoop(L)) { + ++NumNested; + // This is a big restructuring change, reprocess the whole loop. + ProcessLoop(NL); + Changed = true; + // GCC doesn't tail recursion eliminate this. 
+ goto ReprocessLoop; + } + } + + // If we either couldn't, or didn't want to, identify nesting of the loops, + // insert a new block that all backedges target, then make it jump to the + // loop header. + InsertUniqueBackedgeBlock(L); + NumInserted++; + Changed = true; + } + + // Scan over the PHI nodes in the loop header. Since they now have only two + // incoming values (the loop is canonicalized), we may have simplified the PHI + // down to 'X = phi [X, Y]', which should be replaced with 'Y'. + PHINode *PN; + for (BasicBlock::iterator I = L->getHeader()->begin(); + (PN = dyn_cast(I++)); ) + if (Value *V = PN->hasConstantValue()) { + PN->replaceAllUsesWith(V); + PN->eraseFromParent(); + } + + return Changed; +} + +/// SplitBlockPredecessors - Split the specified block into two blocks. We want +/// to move the predecessors specified in the Preds list to point to the new +/// block, leaving the remaining predecessors pointing to BB. This method +/// updates the SSA PHINode's, but no other analyses. +/// +BasicBlock *LoopSimplify::SplitBlockPredecessors(BasicBlock *BB, + const char *Suffix, + const std::vector &Preds) { + + // Create new basic block, insert right before the original block... + BasicBlock *NewBB = new BasicBlock(BB->getName()+Suffix, BB->getParent(), BB); + + // The preheader first gets an unconditional branch to the loop header... + BranchInst *BI = new BranchInst(BB, NewBB); + + // For every PHI node in the block, insert a PHI node into NewBB where the + // incoming values from the out of loop edges are moved to NewBB. We have two + // possible cases here. If the loop is dead, we just insert dummy entries + // into the PHI nodes for the new edge. If the loop is not dead, we move the + // incoming edges in BB into new PHI nodes in NewBB. + // + if (!Preds.empty()) { // Is the loop not obviously dead? + // Check to see if the values being merged into the new block need PHI + // nodes. If so, insert them. 
+ for (BasicBlock::iterator I = BB->begin(); isa(I); ) { + PHINode *PN = cast(I); + ++I; + + // Check to see if all of the values coming in are the same. If so, we + // don't need to create a new PHI node. + Value *InVal = PN->getIncomingValueForBlock(Preds[0]); + for (unsigned i = 1, e = Preds.size(); i != e; ++i) + if (InVal != PN->getIncomingValueForBlock(Preds[i])) { + InVal = 0; + break; + } + + // If the values coming into the block are not the same, we need a PHI. + if (InVal == 0) { + // Create the new PHI node, insert it into NewBB at the end of the block + PHINode *NewPHI = new PHINode(PN->getType(), PN->getName()+".ph", BI); + if (AA) AA->copyValue(PN, NewPHI); + + // Move all of the edges from blocks outside the loop to the new PHI + for (unsigned i = 0, e = Preds.size(); i != e; ++i) { + Value *V = PN->removeIncomingValue(Preds[i], false); + NewPHI->addIncoming(V, Preds[i]); + } + InVal = NewPHI; + } else { + // Remove all of the edges coming into the PHI nodes from outside of the + // block. + for (unsigned i = 0, e = Preds.size(); i != e; ++i) + PN->removeIncomingValue(Preds[i], false); + } + + // Add an incoming value to the PHI node in the loop for the preheader + // edge. + PN->addIncoming(InVal, NewBB); + + // Can we eliminate this phi node now? + if (Value *V = PN->hasConstantValue(true)) { + Instruction *I = dyn_cast(V); + // If I is in NewBB, the Dominator call will fail, because NewBB isn't + // registered in DominatorTree yet. Handle this case explicitly. + if (!I || (I->getParent() != NewBB && + getAnalysis().dominates(I, PN))) { + PN->replaceAllUsesWith(V); + if (AA) AA->deleteValue(PN); + BB->getInstList().erase(PN); + } + } + } + + // Now that the PHI nodes are updated, actually move the edges from + // Preds to point to NewBB instead of BB. 
+ // + for (unsigned i = 0, e = Preds.size(); i != e; ++i) { + TerminatorInst *TI = Preds[i]->getTerminator(); + for (unsigned s = 0, e = TI->getNumSuccessors(); s != e; ++s) + if (TI->getSuccessor(s) == BB) + TI->setSuccessor(s, NewBB); + } + + } else { // Otherwise the loop is dead... + for (BasicBlock::iterator I = BB->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + // Insert dummy values as the incoming value... + PN->addIncoming(Constant::getNullValue(PN->getType()), NewBB); + } + } + + return NewBB; +} + +/// InsertPreheaderForLoop - Once we discover that a loop doesn't have a +/// preheader, this method is called to insert one. This method has two phases: +/// preheader insertion and analysis updating. +/// +void LoopSimplify::InsertPreheaderForLoop(Loop *L) { + BasicBlock *Header = L->getHeader(); + + // Compute the set of predecessors of the loop that are not in the loop. + std::vector OutsideBlocks; + for (pred_iterator PI = pred_begin(Header), PE = pred_end(Header); + PI != PE; ++PI) + if (!L->contains(*PI)) // Coming in from outside the loop? + OutsideBlocks.push_back(*PI); // Keep track of it... + + // Split out the loop pre-header. + BasicBlock *NewBB = + SplitBlockPredecessors(Header, ".preheader", OutsideBlocks); + + + //===--------------------------------------------------------------------===// + // Update analysis results now that we have performed the transformation + // + + // We know that we have loop information to update... update it now. + if (Loop *Parent = L->getParentLoop()) + Parent->addBasicBlockToLoop(NewBB, *LI); + + DT->splitBlock(NewBB); + if (DominanceFrontier *DF = getAnalysisToUpdate()) + DF->splitBlock(NewBB); + + // Make sure that NewBB is put someplace intelligent, which doesn't mess up + // code layout too horribly. + PlaceSplitBlockCarefully(NewBB, OutsideBlocks, L); +} + +/// RewriteLoopExitBlock - Ensure that the loop preheader dominates all exit +/// blocks. 
This method is used to split exit blocks that have predecessors +/// outside of the loop. +BasicBlock *LoopSimplify::RewriteLoopExitBlock(Loop *L, BasicBlock *Exit) { + std::vector LoopBlocks; + for (pred_iterator I = pred_begin(Exit), E = pred_end(Exit); I != E; ++I) + if (L->contains(*I)) + LoopBlocks.push_back(*I); + + assert(!LoopBlocks.empty() && "No edges coming in from outside the loop?"); + BasicBlock *NewBB = SplitBlockPredecessors(Exit, ".loopexit", LoopBlocks); + + // Update Loop Information - we know that the new block will be in whichever + // loop the Exit block is in. Note that it may not be in that immediate loop, + // if the successor is some other loop header. In that case, we continue + // walking up the loop tree to find a loop that contains both the successor + // block and the predecessor block. + Loop *SuccLoop = LI->getLoopFor(Exit); + while (SuccLoop && !SuccLoop->contains(L->getHeader())) + SuccLoop = SuccLoop->getParentLoop(); + if (SuccLoop) + SuccLoop->addBasicBlockToLoop(NewBB, *LI); + + // Update Dominator Information + DT->splitBlock(NewBB); + if (DominanceFrontier *DF = getAnalysisToUpdate()) + DF->splitBlock(NewBB); + + return NewBB; +} + +/// AddBlockAndPredsToSet - Add the specified block, and all of its +/// predecessors, to the specified set, if it's not already in there. Stop +/// predecessor traversal when we reach StopBlock. 
+static void AddBlockAndPredsToSet(BasicBlock *InputBB, BasicBlock *StopBlock, + std::set &Blocks) { + std::vector WorkList; + WorkList.push_back(InputBB); + do { + BasicBlock *BB = WorkList.back(); WorkList.pop_back(); + if (Blocks.insert(BB).second && BB != StopBlock) + // If BB is not already processed and it is not a stop block then + // insert its predecessor in the work list + for (pred_iterator I = pred_begin(BB), E = pred_end(BB); I != E; ++I) { + BasicBlock *WBB = *I; + WorkList.push_back(WBB); + } + } while(!WorkList.empty()); +} + +/// FindPHIToPartitionLoops - The first part of loop-nestification is to find a +/// PHI node that tells us how to partition the loops. +static PHINode *FindPHIToPartitionLoops(Loop *L, DominatorTree *DT, + AliasAnalysis *AA) { + for (BasicBlock::iterator I = L->getHeader()->begin(); isa(I); ) { + PHINode *PN = cast(I); + ++I; + if (Value *V = PN->hasConstantValue()) + if (!isa(V) || DT->dominates(cast(V), PN)) { + // This is a degenerate PHI already, don't modify it! + PN->replaceAllUsesWith(V); + if (AA) AA->deleteValue(PN); + PN->eraseFromParent(); + continue; + } + + // Scan this PHI node looking for a use of the PHI node by itself. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == PN && + L->contains(PN->getIncomingBlock(i))) + // We found something tasty to remove. + return PN; + } + return 0; +} + +// PlaceSplitBlockCarefully - If the block isn't already, move the new block to +// right after some 'outside block' block. This prevents the preheader from +// being placed inside the loop body, e.g. when the loop hasn't been rotated. +void LoopSimplify::PlaceSplitBlockCarefully(BasicBlock *NewBB, + std::vector&SplitPreds, + Loop *L) { + // Check to see if NewBB is already well placed. 
+ Function::iterator BBI = NewBB; --BBI; + for (unsigned i = 0, e = SplitPreds.size(); i != e; ++i) { + if (&*BBI == SplitPreds[i]) + return; + } + + // If it isn't already after an outside block, move it after one. This is + // always good as it makes the uncond branch from the outside block into a + // fall-through. + + // Figure out *which* outside block to put this after. Prefer an outside + // block that neighbors a BB actually in the loop. + BasicBlock *FoundBB = 0; + for (unsigned i = 0, e = SplitPreds.size(); i != e; ++i) { + Function::iterator BBI = SplitPreds[i]; + if (++BBI != NewBB->getParent()->end() && + L->contains(BBI)) { + FoundBB = SplitPreds[i]; + break; + } + } + + // If our heuristic for a *good* bb to place this after doesn't find + // anything, just pick something. It's likely better than leaving it within + // the loop. + if (!FoundBB) + FoundBB = SplitPreds[0]; + NewBB->moveAfter(FoundBB); +} + + +/// SeparateNestedLoop - If this loop has multiple backedges, try to pull one of +/// them out into a nested loop. This is important for code that looks like +/// this: +/// +/// Loop: +/// ... +/// br cond, Loop, Next +/// ... +/// br cond2, Loop, Out +/// +/// To identify this common case, we look at the PHI nodes in the header of the +/// loop. PHI nodes with unchanging values on one backedge correspond to values +/// that change in the "outer" loop, but not in the "inner" loop. +/// +/// If we are able to separate out a loop, return the new outer loop that was +/// created. +/// +Loop *LoopSimplify::SeparateNestedLoop(Loop *L) { + PHINode *PN = FindPHIToPartitionLoops(L, DT, AA); + if (PN == 0) return 0; // No known way to partition. + + // Pull out all predecessors that have varying values in the loop. This + // handles the case when a PHI node has multiple instances of itself as + // arguments. 
+ std::vector OuterLoopPreds; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) != PN || + !L->contains(PN->getIncomingBlock(i))) + OuterLoopPreds.push_back(PN->getIncomingBlock(i)); + + BasicBlock *Header = L->getHeader(); + BasicBlock *NewBB = SplitBlockPredecessors(Header, ".outer", OuterLoopPreds); + + // Update dominator information + DT->splitBlock(NewBB); + if (DominanceFrontier *DF = getAnalysisToUpdate()) + DF->splitBlock(NewBB); + + // Make sure that NewBB is put someplace intelligent, which doesn't mess up + // code layout too horribly. + PlaceSplitBlockCarefully(NewBB, OuterLoopPreds, L); + + // Create the new outer loop. + Loop *NewOuter = new Loop(); + + // Change the parent loop to use the outer loop as its child now. + if (Loop *Parent = L->getParentLoop()) + Parent->replaceChildLoopWith(L, NewOuter); + else + LI->changeTopLevelLoop(L, NewOuter); + + // This block is going to be our new header block: add it to this loop and all + // parent loops. + NewOuter->addBasicBlockToLoop(NewBB, *LI); + + // L is now a subloop of our outer loop. + NewOuter->addChildLoop(L); + + for (unsigned i = 0, e = L->getBlocks().size(); i != e; ++i) + NewOuter->addBlockEntry(L->getBlocks()[i]); + + // Determine which blocks should stay in L and which should be moved out to + // the Outer loop now. + std::set BlocksInL; + for (pred_iterator PI = pred_begin(Header), E = pred_end(Header); PI!=E; ++PI) + if (DT->dominates(Header, *PI)) + AddBlockAndPredsToSet(*PI, Header, BlocksInL); + + + // Scan all of the loop children of L, moving them to OuterLoop if they are + // not part of the inner loop. 
+  // Decide which subloops stay in L: a subloop belongs to L iff its header is
+  // one of the blocks we computed as remaining in L (BlocksInL).
+  const std::vector<Loop*> &SubLoops = L->getSubLoops();
+  for (size_t I = 0; I != SubLoops.size(); )
+    if (BlocksInL.count(SubLoops[I]->getHeader()))
+      ++I;   // Loop remains in L
+    else
+      NewOuter->addChildLoop(L->removeChildLoop(SubLoops.begin() + I));
+
+  // Now that we know which blocks are in L and which need to be moved to
+  // OuterLoop, move any blocks that need it.
+  for (unsigned i = 0; i != L->getBlocks().size(); ++i) {
+    BasicBlock *BB = L->getBlocks()[i];
+    if (!BlocksInL.count(BB)) {
+      // Move this block to the parent, updating the exit blocks sets
+      L->removeBlockFromLoop(BB);
+      if ((*LI)[BB] == L)
+        LI->changeLoopFor(BB, NewOuter);
+      --i;
+    }
+  }
+
+  return NewOuter;
+}
+
+
+
+/// InsertUniqueBackedgeBlock - This method is called when the specified loop
+/// has more than one backedge in it.  If this occurs, revector all of these
+/// backedges to target a new basic block and have that block branch to the loop
+/// header.  This ensures that loops have exactly one backedge.
+///
+void LoopSimplify::InsertUniqueBackedgeBlock(Loop *L) {
+  assert(L->getNumBackEdges() > 1 && "Must have > 1 backedge!");
+
+  // Get information about the loop
+  BasicBlock *Preheader = L->getLoopPreheader();
+  BasicBlock *Header = L->getHeader();
+  Function *F = Header->getParent();
+
+  // Figure out which basic blocks contain back-edges to the loop header.
+  std::vector<BasicBlock*> BackedgeBlocks;
+  for (pred_iterator I = pred_begin(Header), E = pred_end(Header); I != E; ++I)
+    if (*I != Preheader) BackedgeBlocks.push_back(*I);
+
+  // Create and insert the new backedge block...
+  BasicBlock *BEBlock = new BasicBlock(Header->getName()+".backedge", F);
+  BranchInst *BETerminator = new BranchInst(Header, BEBlock);
+
+  // Move the new backedge block to right after the last backedge block.
+  Function::iterator InsertPos = BackedgeBlocks.back(); ++InsertPos;
+  F->getBasicBlockList().splice(InsertPos, F->getBasicBlockList(), BEBlock);
+
+  // Now that the block has been inserted into the function, create PHI nodes in
+  // the backedge block which correspond to any PHI nodes in the header block.
+  for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) {
+    PHINode *PN = cast<PHINode>(I);
+    PHINode *NewPN = new PHINode(PN->getType(), PN->getName()+".be",
+                                 BETerminator);
+    NewPN->reserveOperandSpace(BackedgeBlocks.size());
+    if (AA) AA->copyValue(PN, NewPN);
+
+    // Loop over the PHI node, moving all entries except the one for the
+    // preheader over to the new PHI node.
+    unsigned PreheaderIdx = ~0U;
+    bool HasUniqueIncomingValue = true;
+    Value *UniqueValue = 0;
+    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+      BasicBlock *IBB = PN->getIncomingBlock(i);
+      Value *IV = PN->getIncomingValue(i);
+      if (IBB == Preheader) {
+        PreheaderIdx = i;
+      } else {
+        NewPN->addIncoming(IV, IBB);
+        if (HasUniqueIncomingValue) {
+          if (UniqueValue == 0)
+            UniqueValue = IV;
+          else if (UniqueValue != IV)
+            HasUniqueIncomingValue = false;
+        }
+      }
+    }
+
+    // Delete all of the incoming values from the old PN except the preheader's
+    assert(PreheaderIdx != ~0U && "PHI has no preheader entry??");
+    if (PreheaderIdx != 0) {
+      PN->setIncomingValue(0, PN->getIncomingValue(PreheaderIdx));
+      PN->setIncomingBlock(0, PN->getIncomingBlock(PreheaderIdx));
+    }
+    // Nuke all entries except the zero'th.
+    for (unsigned i = 0, e = PN->getNumIncomingValues()-1; i != e; ++i)
+      PN->removeIncomingValue(e-i, false);
+
+    // Finally, add the newly constructed PHI node as the entry for the BEBlock.
+    PN->addIncoming(NewPN, BEBlock);
+
+    // As an optimization, if all incoming values in the new PhiNode (which is a
+    // subset of the incoming values of the old PHI node) have the same value,
+    // eliminate the PHI Node.
+    if (HasUniqueIncomingValue) {
+      NewPN->replaceAllUsesWith(UniqueValue);
+      if (AA) AA->deleteValue(NewPN);
+      BEBlock->getInstList().erase(NewPN);
+    }
+  }
+
+  // Now that all of the PHI nodes have been inserted and adjusted, modify the
+  // backedge blocks to jump to the BEBlock instead of the header.
+  for (unsigned i = 0, e = BackedgeBlocks.size(); i != e; ++i) {
+    TerminatorInst *TI = BackedgeBlocks[i]->getTerminator();
+    for (unsigned Op = 0, e = TI->getNumSuccessors(); Op != e; ++Op)
+      if (TI->getSuccessor(Op) == Header)
+        TI->setSuccessor(Op, BEBlock);
+  }
+
+  //===--- Update all analyses which we must preserve now -----------------===//
+
+  // Update Loop Information - we know that this block is now in the current
+  // loop and all parent loops.
+  L->addBasicBlockToLoop(BEBlock, *LI);
+
+  // Update dominator information
+  DT->splitBlock(BEBlock);
+  if (DominanceFrontier *DF = getAnalysisToUpdate<DominanceFrontier>())
+    DF->splitBlock(BEBlock);
+}
+
+
diff --git a/lib/Transforms/Utils/LowerAllocations.cpp b/lib/Transforms/Utils/LowerAllocations.cpp
new file mode 100644
index 0000000..7ce2479
--- /dev/null
+++ b/lib/Transforms/Utils/LowerAllocations.cpp
@@ -0,0 +1,176 @@
+//===- LowerAllocations.cpp - Reduce malloc & free insts to calls ---------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file was developed by the LLVM research group and is distributed under
+// the University of Illinois Open Source License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// The LowerAllocations transformation is a target-dependent transformation
+// because it depends on the size of data types and alignment constraints.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "lowerallocs"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
+#include "llvm/Module.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Instructions.h"
+#include "llvm/Constants.h"
+#include "llvm/Pass.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Target/TargetData.h"
+#include "llvm/Support/Compiler.h"
+using namespace llvm;
+
+STATISTIC(NumLowered, "Number of allocations lowered");
+
+namespace {
+  /// LowerAllocations - Turn malloc and free instructions into %malloc and
+  /// %free calls.
+  ///
+  class VISIBILITY_HIDDEN LowerAllocations : public BasicBlockPass {
+    Constant *MallocFunc;   // Functions in the module we are processing
+    Constant *FreeFunc;     // Initialized by doInitialization
+    bool LowerMallocArgToInteger;
+  public:
+    static char ID; // Pass ID, replacement for typeid
+    LowerAllocations(bool LowerToInt = false)
+      : BasicBlockPass((intptr_t)&ID), MallocFunc(0), FreeFunc(0),
+        LowerMallocArgToInteger(LowerToInt) {}
+
+    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+      AU.addRequired<TargetData>();
+      AU.setPreservesCFG();
+
+      // This is a cluster of orthogonal Transforms:
+      AU.addPreserved<UnifyFunctionExitNodes>();
+      AU.addPreservedID(PromoteMemoryToRegisterID);
+      AU.addPreservedID(LowerSelectID);
+      AU.addPreservedID(LowerSwitchID);
+      AU.addPreservedID(LowerInvokePassID);
+    }
+
+    /// doPassInitialization - For the lower allocations pass, this ensures that
+    /// a module contains a declaration for a malloc and a free function.
+    ///
+    bool doInitialization(Module &M);
+
+    virtual bool doInitialization(Function &F) {
+      return BasicBlockPass::doInitialization(F);
+    }
+
+    /// runOnBasicBlock - This method does the actual work of converting
+    /// instructions over, assuming that the pass has already been initialized.
+    ///
+    bool runOnBasicBlock(BasicBlock &BB);
+  };
+
+  char LowerAllocations::ID = 0;
+  RegisterPass<LowerAllocations>
+  X("lowerallocs", "Lower allocations from instructions to calls");
+}
+
+// Publically exposed interface to pass...
+const PassInfo *llvm::LowerAllocationsID = X.getPassInfo();
+// createLowerAllocationsPass - Interface to this file...
+Pass *llvm::createLowerAllocationsPass(bool LowerMallocArgToInteger) {
+  return new LowerAllocations(LowerMallocArgToInteger);
+}
+
+
+// doInitialization - For the lower allocations pass, this ensures that a
+// module contains a declaration for a malloc and a free function.
+//
+// This function is always successful.
+//
+bool LowerAllocations::doInitialization(Module &M) {
+  const Type *BPTy = PointerType::get(Type::Int8Ty);
+  // Prototype malloc as "char* malloc(...)", because we don't know in
+  // doInitialization whether size_t is int or long.
+  FunctionType *FT = FunctionType::get(BPTy, std::vector<const Type*>(), true);
+  MallocFunc = M.getOrInsertFunction("malloc", FT);
+  FreeFunc = M.getOrInsertFunction("free"  , Type::VoidTy, BPTy, (Type *)0);
+  return true;
+}
+
+// runOnBasicBlock - This method does the actual work of converting
+// instructions over, assuming that the pass has already been initialized.
+//
+bool LowerAllocations::runOnBasicBlock(BasicBlock &BB) {
+  bool Changed = false;
+  assert(MallocFunc && FreeFunc && "Pass not initialized!");
+
+  BasicBlock::InstListType &BBIL = BB.getInstList();
+
+  const TargetData &TD = getAnalysis<TargetData>();
+  const Type *IntPtrTy = TD.getIntPtrType();
+
+  // Loop over all of the instructions, looking for malloc or free instructions
+  for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
+    if (MallocInst *MI = dyn_cast<MallocInst>(I)) {
+      const Type *AllocTy = MI->getType()->getElementType();
+
+      // malloc(type) becomes sbyte *malloc(size)
+      Value *MallocArg;
+      if (LowerMallocArgToInteger)
+        MallocArg = ConstantInt::get(Type::Int64Ty, TD.getTypeSize(AllocTy));
+      else
+        MallocArg = ConstantExpr::getSizeOf(AllocTy);
+      MallocArg = ConstantExpr::getTruncOrBitCast(cast<Constant>(MallocArg),
+                                                  IntPtrTy);
+
+      if (MI->isArrayAllocation()) {
+        if (isa<ConstantInt>(MallocArg) &&
+            cast<ConstantInt>(MallocArg)->isOne()) {
+          MallocArg = MI->getOperand(0);         // Operand * 1 = Operand
+        } else if (Constant *CO = dyn_cast<Constant>(MI->getOperand(0))) {
+          CO = ConstantExpr::getIntegerCast(CO, IntPtrTy, false /*ZExt*/);
+          MallocArg = ConstantExpr::getMul(CO, cast<Constant>(MallocArg));
+        } else {
+          Value *Scale = MI->getOperand(0);
+          if (Scale->getType() != IntPtrTy)
+            Scale = CastInst::createIntegerCast(Scale, IntPtrTy, false /*ZExt*/,
+                                                "", I);
+
+          // Multiply it by the array size if necessary...
+          MallocArg = BinaryOperator::create(Instruction::Mul, Scale,
+                                             MallocArg, "", I);
+        }
+      }
+
+      // Create the call to Malloc.
+      CallInst *MCall = new CallInst(MallocFunc, MallocArg, "", I);
+      MCall->setTailCall();
+
+      // Create a cast instruction to convert to the right type...
+      Value *MCast;
+      if (MCall->getType() != Type::VoidTy)
+        MCast = new BitCastInst(MCall, MI->getType(), "", I);
+      else
+        MCast = Constant::getNullValue(MI->getType());
+
+      // Replace all uses of the old malloc inst with the cast inst
+      MI->replaceAllUsesWith(MCast);
+      I = --BBIL.erase(I);         // remove and delete the malloc instr...
+      Changed = true;
+      ++NumLowered;
+    } else if (FreeInst *FI = dyn_cast<FreeInst>(I)) {
+      Value *PtrCast = new BitCastInst(FI->getOperand(0),
+                                       PointerType::get(Type::Int8Ty), "", I);
+
+      // Insert a call to the free function...
+      (new CallInst(FreeFunc, PtrCast, "", I))->setTailCall();
+
+      // Delete the old free instruction
+      I = --BBIL.erase(I);
+      Changed = true;
+      ++NumLowered;
+    }
+  }
+
+  return Changed;
+}
+
diff --git a/lib/Transforms/Utils/LowerInvoke.cpp b/lib/Transforms/Utils/LowerInvoke.cpp
new file mode 100644
index 0000000..d72c018
--- /dev/null
+++ b/lib/Transforms/Utils/LowerInvoke.cpp
@@ -0,0 +1,585 @@
+//===- LowerInvoke.cpp - Eliminate Invoke & Unwind instructions -----------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file was developed by the LLVM research group and is distributed under
+// the University of Illinois Open Source License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation is designed for use by code generators which do not yet
+// support stack unwinding.  This pass supports two models of exception handling
+// lowering, the 'cheap' support and the 'expensive' support.
+//
+// 'Cheap' exception handling support gives the program the ability to execute
+// any program which does not "throw an exception", by turning 'invoke'
+// instructions into calls and by turning 'unwind' instructions into calls to
+// abort().  If the program does dynamically use the unwind instruction, the
+// program will print a message then abort.
+//
+// 'Expensive' exception handling support gives the full exception handling
+// support to the program at the cost of making the 'invoke' instruction
+// really expensive.  It basically inserts setjmp/longjmp calls to emulate the
+// exception handling as necessary.
+//
+// Because the 'expensive' support slows down programs a lot, and EH is only
+// used for a subset of the programs, it must be specifically enabled by an
+// option.
+//
+// Note that after this pass runs the CFG is not entirely accurate (exceptional
+// control flow edges are not correct anymore) so only very simple things should
+// be done after the lowerinvoke pass has run (like generation of native code).
+// This should not be used as a general purpose "my LLVM-to-LLVM pass doesn't
+// support the invoke instruction yet" lowering pass.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "lowerinvoke"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Constants.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Instructions.h"
+#include "llvm/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Target/TargetLowering.h"
+#include <csetjmp>
+#include <set>
+using namespace llvm;
+
+STATISTIC(NumInvokes, "Number of invokes replaced");
+STATISTIC(NumUnwinds, "Number of unwinds replaced");
+STATISTIC(NumSpilled, "Number of registers live across unwind edges");
+
+static cl::opt<bool> ExpensiveEHSupport("enable-correct-eh-support",
+ cl::desc("Make the -lowerinvoke pass insert expensive, but correct, EH code"));
+
+namespace {
+  class VISIBILITY_HIDDEN LowerInvoke : public FunctionPass {
+    // Used for both models.
+    Constant *WriteFn;
+    Constant *AbortFn;
+    Value *AbortMessage;
+    unsigned AbortMessageLength;
+
+    // Used for expensive EH support.
+    const Type *JBLinkTy;
+    GlobalVariable *JBListHead;
+    Constant *SetJmpFn, *LongJmpFn;
+
+    // We peek in TLI to grab the target's jmp_buf size and alignment
+    const TargetLowering *TLI;
+
+  public:
+    static char ID; // Pass identification, replacement for typeid
+    LowerInvoke(const TargetLowering *tli = NULL) : FunctionPass((intptr_t)&ID),
+      TLI(tli) { }
+    bool doInitialization(Module &M);
+    bool runOnFunction(Function &F);
+
+    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+      // This is a cluster of orthogonal Transforms
+      AU.addPreservedID(PromoteMemoryToRegisterID);
+      AU.addPreservedID(LowerSelectID);
+      AU.addPreservedID(LowerSwitchID);
+      AU.addPreservedID(LowerAllocationsID);
+    }
+
+  private:
+    void createAbortMessage(Module *M);
+    void writeAbortMessage(Instruction *IB);
+    bool insertCheapEHSupport(Function &F);
+    void splitLiveRangesLiveAcrossInvokes(std::vector<InvokeInst*> &Invokes);
+    void rewriteExpensiveInvoke(InvokeInst *II, unsigned InvokeNo,
+                                AllocaInst *InvokeNum, SwitchInst *CatchSwitch);
+    bool insertExpensiveEHSupport(Function &F);
+  };
+
+  char LowerInvoke::ID = 0;
+  RegisterPass<LowerInvoke>
+  X("lowerinvoke", "Lower invoke and unwind, for unwindless code generators");
+}
+
+const PassInfo *llvm::LowerInvokePassID = X.getPassInfo();
+
+// Public Interface To the LowerInvoke pass.
+FunctionPass *llvm::createLowerInvokePass(const TargetLowering *TLI) {
+  return new LowerInvoke(TLI);
+}
+
+// doInitialization - Make sure that there is a prototype for abort in the
+// current module.
+bool LowerInvoke::doInitialization(Module &M) {
+  const Type *VoidPtrTy = PointerType::get(Type::Int8Ty);
+  AbortMessage = 0;
+  if (ExpensiveEHSupport) {
+    // Insert a type for the linked list of jump buffers.
+    unsigned JBSize = TLI ? TLI->getJumpBufSize() : 0;
+    JBSize = JBSize ? JBSize : 200;
+    const Type *JmpBufTy = ArrayType::get(VoidPtrTy, JBSize);
+
+    { // The type is recursive, so use a type holder.
+      std::vector<const Type*> Elements;
+      Elements.push_back(JmpBufTy);
+      OpaqueType *OT = OpaqueType::get();
+      Elements.push_back(PointerType::get(OT));
+      PATypeHolder JBLType(StructType::get(Elements));
+      OT->refineAbstractTypeTo(JBLType.get());  // Complete the cycle.
+      JBLinkTy = JBLType.get();
+      M.addTypeName("llvm.sjljeh.jmpbufty", JBLinkTy);
+    }
+
+    const Type *PtrJBList = PointerType::get(JBLinkTy);
+
+    // Now that we've done that, insert the jmpbuf list head global, unless it
+    // already exists.
+    if (!(JBListHead = M.getGlobalVariable("llvm.sjljeh.jblist", PtrJBList))) {
+      JBListHead = new GlobalVariable(PtrJBList, false,
+                                      GlobalValue::LinkOnceLinkage,
+                                      Constant::getNullValue(PtrJBList),
+                                      "llvm.sjljeh.jblist", &M);
+    }
+    SetJmpFn = M.getOrInsertFunction("llvm.setjmp", Type::Int32Ty,
+                                     PointerType::get(JmpBufTy), (Type *)0);
+    LongJmpFn = M.getOrInsertFunction("llvm.longjmp", Type::VoidTy,
+                                      PointerType::get(JmpBufTy),
+                                      Type::Int32Ty, (Type *)0);
+  }
+
+  // We need the 'write' and 'abort' functions for both models.
+  AbortFn = M.getOrInsertFunction("abort", Type::VoidTy, (Type *)0);
+#if 0 // "write" is Unix-specific.. code is going away soon anyway.
+  WriteFn = M.getOrInsertFunction("write", Type::VoidTy, Type::Int32Ty,
+                                  VoidPtrTy, Type::Int32Ty, (Type *)0);
+#else
+  WriteFn = 0;
+#endif
+  return true;
+}
+
+void LowerInvoke::createAbortMessage(Module *M) {
+  if (ExpensiveEHSupport) {
+    // The abort message for expensive EH support tells the user that the
+    // program 'unwound' without an 'invoke' instruction.
+    Constant *Msg =
+      ConstantArray::get("ERROR: Exception thrown, but not caught!\n");
+    AbortMessageLength = Msg->getNumOperands()-1;  // don't include \0
+
+    GlobalVariable *MsgGV = new GlobalVariable(Msg->getType(), true,
+                                               GlobalValue::InternalLinkage,
+                                               Msg, "abortmsg", M);
+    std::vector<Constant*> GEPIdx(2, Constant::getNullValue(Type::Int32Ty));
+    AbortMessage = ConstantExpr::getGetElementPtr(MsgGV, &GEPIdx[0], 2);
+  } else {
+    // The abort message for cheap EH support tells the user that EH is not
+    // enabled.
+    Constant *Msg =
+      ConstantArray::get("Exception handler needed, but not enabled. Recompile"
+                         " program with -enable-correct-eh-support.\n");
+    AbortMessageLength = Msg->getNumOperands()-1;  // don't include \0
+
+    GlobalVariable *MsgGV = new GlobalVariable(Msg->getType(), true,
+                                               GlobalValue::InternalLinkage,
+                                               Msg, "abortmsg", M);
+    std::vector<Constant*> GEPIdx(2, Constant::getNullValue(Type::Int32Ty));
+    AbortMessage = ConstantExpr::getGetElementPtr(MsgGV, &GEPIdx[0], 2);
+  }
+}
+
+
+void LowerInvoke::writeAbortMessage(Instruction *IB) {
+#if 0
+  if (AbortMessage == 0)
+    createAbortMessage(IB->getParent()->getParent()->getParent());
+
+  // These are the arguments we WANT...
+  Value* Args[3];
+  Args[0] = ConstantInt::get(Type::Int32Ty, 2);
+  Args[1] = AbortMessage;
+  Args[2] = ConstantInt::get(Type::Int32Ty, AbortMessageLength);
+  (new CallInst(WriteFn, Args, 3, "", IB))->setTailCall();
+#endif
+}
+
+bool LowerInvoke::insertCheapEHSupport(Function &F) {
+  bool Changed = false;
+  for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
+    if (InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) {
+      std::vector<Value*> CallArgs(II->op_begin()+3, II->op_end());
+      // Insert a normal call instruction...
+      CallInst *NewCall = new CallInst(II->getCalledValue(),
+                                       &CallArgs[0], CallArgs.size(), "", II);
+      NewCall->takeName(II);
+      NewCall->setCallingConv(II->getCallingConv());
+      II->replaceAllUsesWith(NewCall);
+
+      // Insert an unconditional branch to the normal destination.
+      new BranchInst(II->getNormalDest(), II);
+
+      // Remove any PHI node entries from the exception destination.
+      II->getUnwindDest()->removePredecessor(BB);
+
+      // Remove the invoke instruction now.
+      BB->getInstList().erase(II);
+
+      ++NumInvokes; Changed = true;
+    } else if (UnwindInst *UI = dyn_cast<UnwindInst>(BB->getTerminator())) {
+      // Insert a new call to write(2, AbortMessage, AbortMessageLength);
+      writeAbortMessage(UI);
+
+      // Insert a call to abort()
+      (new CallInst(AbortFn, "", UI))->setTailCall();
+
+      // Insert a return instruction.  This really should be a "barrier", as it
+      // is unreachable.
+      new ReturnInst(F.getReturnType() == Type::VoidTy ? 0 :
+                     Constant::getNullValue(F.getReturnType()), UI);
+
+      // Remove the unwind instruction now.
+      BB->getInstList().erase(UI);
+
+      ++NumUnwinds; Changed = true;
+    }
+  return Changed;
+}
+
+/// rewriteExpensiveInvoke - Insert code and hack the function to replace the
+/// specified invoke instruction with a call.
+void LowerInvoke::rewriteExpensiveInvoke(InvokeInst *II, unsigned InvokeNo,
+                                         AllocaInst *InvokeNum,
+                                         SwitchInst *CatchSwitch) {
+  ConstantInt *InvokeNoC = ConstantInt::get(Type::Int32Ty, InvokeNo);
+
+  // Insert a store of the invoke num before the invoke and store zero into the
+  // location afterward.
+  new StoreInst(InvokeNoC, InvokeNum, true, II);  // volatile
+
+  BasicBlock::iterator NI = II->getNormalDest()->begin();
+  while (isa<PHINode>(NI)) ++NI;
+  // nonvolatile.
+  new StoreInst(Constant::getNullValue(Type::Int32Ty), InvokeNum, false, NI);
+
+  // Add a switch case to our unwind block.
+  CatchSwitch->addCase(InvokeNoC, II->getUnwindDest());
+
+  // Insert a normal call instruction.
+  std::vector<Value*> CallArgs(II->op_begin()+3, II->op_end());
+  CallInst *NewCall = new CallInst(II->getCalledValue(),
+                                   &CallArgs[0], CallArgs.size(), "",
+                                   II);
+  NewCall->takeName(II);
+  NewCall->setCallingConv(II->getCallingConv());
+  II->replaceAllUsesWith(NewCall);
+
+  // Replace the invoke with an uncond branch.
+  new BranchInst(II->getNormalDest(), NewCall->getParent());
+  II->eraseFromParent();
+}
+
+/// MarkBlocksLiveIn - Insert BB and all of its predecessors into LiveBBs until
+/// we reach blocks we've already seen.
+static void MarkBlocksLiveIn(BasicBlock *BB, std::set<BasicBlock*> &LiveBBs) {
+  if (!LiveBBs.insert(BB).second) return; // already been here.
+
+  for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI)
+    MarkBlocksLiveIn(*PI, LiveBBs);
+}
+
+// First thing we need to do is scan the whole function for values that are
+// live across unwind edges.  Each value that is live across an unwind edge
+// we spill into a stack location, guaranteeing that there is nothing live
+// across the unwind edge.  This process also splits all critical edges
+// coming out of invoke's.
+void LowerInvoke::
+splitLiveRangesLiveAcrossInvokes(std::vector<InvokeInst*> &Invokes) {
+  // First step, split all critical edges from invoke instructions.
+  for (unsigned i = 0, e = Invokes.size(); i != e; ++i) {
+    InvokeInst *II = Invokes[i];
+    SplitCriticalEdge(II, 0, this);
+    SplitCriticalEdge(II, 1, this);
+    assert(!isa<PHINode>(II->getNormalDest()) &&
+           !isa<PHINode>(II->getUnwindDest()) &&
+           "critical edge splitting left single entry phi nodes?");
+  }
+
+  Function *F = Invokes.back()->getParent()->getParent();
+
+  // To avoid having to handle incoming arguments specially, we lower each arg
+  // to a copy instruction in the entry block.  This ensures that the argument
+  // value itself cannot be live across the entry block.
+  BasicBlock::iterator AfterAllocaInsertPt = F->begin()->begin();
+  while (isa<AllocaInst>(AfterAllocaInsertPt) &&
+         isa<ConstantInt>(cast<AllocaInst>(AfterAllocaInsertPt)->getArraySize()))
+    ++AfterAllocaInsertPt;
+  for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end();
+       AI != E; ++AI) {
+    // This is always a no-op cast because we're casting AI to AI->getType() so
+    // src and destination types are identical. BitCast is the only possibility.
+    CastInst *NC = new BitCastInst(
+      AI, AI->getType(), AI->getName()+".tmp", AfterAllocaInsertPt);
+    AI->replaceAllUsesWith(NC);
+    // Normally it is forbidden to replace a CastInst's operand because it
+    // could cause the opcode to reflect an illegal conversion. However, we're
+    // replacing it here with the same value it was constructed with to simply
+    // make NC its user.
+    NC->setOperand(0, AI);
+  }
+
+  // Finally, scan the code looking for instructions with bad live ranges.
+  for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB)
+    for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E; ++II) {
+      // Ignore obvious cases we don't have to handle.  In particular, most
+      // instructions either have no uses or only have a single use inside the
+      // current block.  Ignore them quickly.
+      Instruction *Inst = II;
+      if (Inst->use_empty()) continue;
+      if (Inst->hasOneUse() &&
+          cast<Instruction>(Inst->use_back())->getParent() == BB &&
+          !isa<PHINode>(Inst->use_back())) continue;
+
+      // If this is an alloca in the entry block, it's not a real register
+      // value.
+      if (AllocaInst *AI = dyn_cast<AllocaInst>(Inst))
+        if (isa<ConstantInt>(AI->getArraySize()) && BB == F->begin())
+          continue;
+
+      // Avoid iterator invalidation by copying users to a temporary vector.
+      std::vector<Instruction*> Users;
+      for (Value::use_iterator UI = Inst->use_begin(), E = Inst->use_end();
+           UI != E; ++UI) {
+        Instruction *User = cast<Instruction>(*UI);
+        if (User->getParent() != BB || isa<PHINode>(User))
+          Users.push_back(User);
+      }
+
+      // Scan all of the uses and see if the live range is live across an unwind
+      // edge.  If we find a use live across an invoke edge, create an alloca
+      // and spill the value.
+      std::set<InvokeInst*> InvokesWithStoreInserted;
+
+      // Find all of the blocks that this value is live in.
+      std::set<BasicBlock*> LiveBBs;
+      LiveBBs.insert(Inst->getParent());
+      while (!Users.empty()) {
+        Instruction *U = Users.back();
+        Users.pop_back();
+
+        if (!isa<PHINode>(U)) {
+          MarkBlocksLiveIn(U->getParent(), LiveBBs);
+        } else {
+          // Uses for a PHI node occur in their predecessor block.
+          PHINode *PN = cast<PHINode>(U);
+          for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+            if (PN->getIncomingValue(i) == Inst)
+              MarkBlocksLiveIn(PN->getIncomingBlock(i), LiveBBs);
+        }
+      }
+
+      // Now that we know all of the blocks that this thing is live in, see if
+      // it includes any of the unwind locations.
+      bool NeedsSpill = false;
+      for (unsigned i = 0, e = Invokes.size(); i != e; ++i) {
+        BasicBlock *UnwindBlock = Invokes[i]->getUnwindDest();
+        if (UnwindBlock != BB && LiveBBs.count(UnwindBlock)) {
+          NeedsSpill = true;
+        }
+      }
+
+      // If we decided we need a spill, do it.
+      if (NeedsSpill) {
+        ++NumSpilled;
+        DemoteRegToStack(*Inst, true);
+      }
+    }
+}
+
+bool LowerInvoke::insertExpensiveEHSupport(Function &F) {
+  std::vector<ReturnInst*> Returns;
+  std::vector<UnwindInst*> Unwinds;
+  std::vector<InvokeInst*> Invokes;
+
+  for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
+    if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator())) {
+      // Remember all return instructions in case we insert an invoke into this
+      // function.
+      Returns.push_back(RI);
+    } else if (InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) {
+      Invokes.push_back(II);
+    } else if (UnwindInst *UI = dyn_cast<UnwindInst>(BB->getTerminator())) {
+      Unwinds.push_back(UI);
+    }
+
+  if (Unwinds.empty() && Invokes.empty()) return false;
+
+  NumInvokes += Invokes.size();
+  NumUnwinds += Unwinds.size();
+
+  // TODO: This is not an optimal way to do this.  In particular, this always
+  // inserts setjmp calls into the entries of functions with invoke instructions
+  // even though there are possibly paths through the function that do not
+  // execute any invokes.  In particular, for functions with early exits, e.g.
+  // the 'addMove' method in hexxagon, it would be nice to not have to do the
+  // setjmp stuff on the early exit path.  This requires a bit of dataflow, but
+  // would not be too hard to do.
+
+  // If we have an invoke instruction, insert a setjmp that dominates all
+  // invokes.  After the setjmp, use a cond branch that goes to the original
+  // code path on zero, and to a designated 'catch' block of nonzero.
+  Value *OldJmpBufPtr = 0;
+  if (!Invokes.empty()) {
+    // First thing we need to do is scan the whole function for values that are
+    // live across unwind edges.  Each value that is live across an unwind edge
+    // we spill into a stack location, guaranteeing that there is nothing live
+    // across the unwind edge.  This process also splits all critical edges
+    // coming out of invoke's.
+    splitLiveRangesLiveAcrossInvokes(Invokes);
+
+    BasicBlock *EntryBB = F.begin();
+
+    // Create an alloca for the incoming jump buffer ptr and the new jump buffer
+    // that needs to be restored on all exits from the function.  This is an
+    // alloca because the value needs to be live across invokes.
+    unsigned Align = TLI ? TLI->getJumpBufAlignment() : 0;
+    AllocaInst *JmpBuf =
+      new AllocaInst(JBLinkTy, 0, Align, "jblink", F.begin()->begin());
+
+    std::vector<Value*> Idx;
+    Idx.push_back(Constant::getNullValue(Type::Int32Ty));
+    Idx.push_back(ConstantInt::get(Type::Int32Ty, 1));
+    OldJmpBufPtr = new GetElementPtrInst(JmpBuf, &Idx[0], 2, "OldBuf",
+                                         EntryBB->getTerminator());
+
+    // Copy the JBListHead to the alloca.
+    Value *OldBuf = new LoadInst(JBListHead, "oldjmpbufptr", true,
+                                 EntryBB->getTerminator());
+    new StoreInst(OldBuf, OldJmpBufPtr, true, EntryBB->getTerminator());
+
+    // Add the new jumpbuf to the list.
+    new StoreInst(JmpBuf, JBListHead, true, EntryBB->getTerminator());
+
+    // Create the catch block.  The catch block is basically a big switch
+    // statement that goes to all of the invoke catch blocks.
+    BasicBlock *CatchBB = new BasicBlock("setjmp.catch", &F);
+
+    // Create an alloca which keeps track of which invoke is currently
+    // executing.  For normal calls it contains zero.
+    AllocaInst *InvokeNum = new AllocaInst(Type::Int32Ty, 0, "invokenum",
+                                           EntryBB->begin());
+    new StoreInst(ConstantInt::get(Type::Int32Ty, 0), InvokeNum, true,
+                  EntryBB->getTerminator());
+
+    // Insert a load in the Catch block, and a switch on its value.  By default,
+    // we go to a block that just does an unwind (which is the correct action
+    // for a standard call).
+    BasicBlock *UnwindBB = new BasicBlock("unwindbb", &F);
+    Unwinds.push_back(new UnwindInst(UnwindBB));
+
+    Value *CatchLoad = new LoadInst(InvokeNum, "invoke.num", true, CatchBB);
+    SwitchInst *CatchSwitch =
+      new SwitchInst(CatchLoad, UnwindBB, Invokes.size(), CatchBB);
+
+    // Now that things are set up, insert the setjmp call itself.
+
+    // Split the entry block to insert the conditional branch for the setjmp.
+    BasicBlock *ContBlock = EntryBB->splitBasicBlock(EntryBB->getTerminator(),
+                                                     "setjmp.cont");
+
+    Idx[1] = ConstantInt::get(Type::Int32Ty, 0);
+    Value *JmpBufPtr = new GetElementPtrInst(JmpBuf, &Idx[0], Idx.size(),
+                                             "TheJmpBuf",
+                                             EntryBB->getTerminator());
+    Value *SJRet = new CallInst(SetJmpFn, JmpBufPtr, "sjret",
+                                EntryBB->getTerminator());
+
+    // Compare the return value to zero.
+    Value *IsNormal = new ICmpInst(ICmpInst::ICMP_EQ, SJRet,
+                                   Constant::getNullValue(SJRet->getType()),
+                                   "notunwind", EntryBB->getTerminator());
+    // Nuke the uncond branch.
+    EntryBB->getTerminator()->eraseFromParent();
+
+    // Put in a new condbranch in its place.
+    new BranchInst(ContBlock, CatchBB, IsNormal, EntryBB);
+
+    // At this point, we are all set up, rewrite each invoke instruction.
+    for (unsigned i = 0, e = Invokes.size(); i != e; ++i)
+      rewriteExpensiveInvoke(Invokes[i], i+1, InvokeNum, CatchSwitch);
+  }
+
+  // We know that there is at least one unwind.
+
+  // Create three new blocks, the block to load the jmpbuf ptr and compare
+  // against null, the block to do the longjmp, and the error block for if it
+  // is null.  Add them at the end of the function because they are not hot.
+  BasicBlock *UnwindHandler = new BasicBlock("dounwind", &F);
+  BasicBlock *UnwindBlock = new BasicBlock("unwind", &F);
+  BasicBlock *TermBlock = new BasicBlock("unwinderror", &F);
+
+  // If this function contains an invoke, restore the old jumpbuf ptr.
+  Value *BufPtr;
+  if (OldJmpBufPtr) {
+    // Before the return, insert a copy from the saved value to the new value.
+    BufPtr = new LoadInst(OldJmpBufPtr, "oldjmpbufptr", UnwindHandler);
+    new StoreInst(BufPtr, JBListHead, UnwindHandler);
+  } else {
+    BufPtr = new LoadInst(JBListHead, "ehlist", UnwindHandler);
+  }
+
+  // Load the JBList, if it's null, then there was no catch!
+  Value *NotNull = new ICmpInst(ICmpInst::ICMP_NE, BufPtr,
+                                Constant::getNullValue(BufPtr->getType()),
+                                "notnull", UnwindHandler);
+  new BranchInst(UnwindBlock, TermBlock, NotNull, UnwindHandler);
+
+  // Create the block to do the longjmp.
+  // Get a pointer to the jmpbuf and longjmp.
+  std::vector<Value*> Idx;
+  Idx.push_back(Constant::getNullValue(Type::Int32Ty));
+  Idx.push_back(ConstantInt::get(Type::Int32Ty, 0));
+  Idx[0] = new GetElementPtrInst(BufPtr, &Idx[0], 2, "JmpBuf", UnwindBlock);
+  Idx[1] = ConstantInt::get(Type::Int32Ty, 1);
+  new CallInst(LongJmpFn, &Idx[0], Idx.size(), "", UnwindBlock);
+  new UnreachableInst(UnwindBlock);
+
+  // Set up the term block ("throw without a catch").
+  new UnreachableInst(TermBlock);
+
+  // Insert a new call to write(2, AbortMessage, AbortMessageLength);
+  writeAbortMessage(TermBlock->getTerminator());
+
+  // Insert a call to abort()
+  (new CallInst(AbortFn, "",
+                TermBlock->getTerminator()))->setTailCall();
+
+
+  // Replace all unwinds with a branch to the unwind handler.
+  for (unsigned i = 0, e = Unwinds.size(); i != e; ++i) {
+    new BranchInst(UnwindHandler, Unwinds[i]);
+    Unwinds[i]->eraseFromParent();
+  }
+
+  // Finally, for any returns from this function, if this function contains an
+  // invoke, restore the old jmpbuf pointer to its input value.
+  if (OldJmpBufPtr) {
+    for (unsigned i = 0, e = Returns.size(); i != e; ++i) {
+      ReturnInst *R = Returns[i];
+
+      // Before the return, insert a copy from the saved value to the new value.
+      Value *OldBuf = new LoadInst(OldJmpBufPtr, "oldjmpbufptr", true, R);
+      new StoreInst(OldBuf, JBListHead, true, R);
+    }
+  }
+
+  return true;
+}
+
+bool LowerInvoke::runOnFunction(Function &F) {
+  if (ExpensiveEHSupport)
+    return insertExpensiveEHSupport(F);
+  else
+    return insertCheapEHSupport(F);
+}
diff --git a/lib/Transforms/Utils/LowerSelect.cpp b/lib/Transforms/Utils/LowerSelect.cpp
new file mode 100644
index 0000000..1882695
--- /dev/null
+++ b/lib/Transforms/Utils/LowerSelect.cpp
@@ -0,0 +1,105 @@
+//===- LowerSelect.cpp - Transform select insts to branches ---------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file was developed by the LLVM research group and is distributed under
+// the University of Illinois Open Source License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers select instructions into conditional branches for targets
+// that do not have conditional moves or that have not implemented the select
+// instruction yet.
+//
+// Note that this pass could be improved.  In particular it turns every select
+// instruction into a new conditional branch, even though some common cases have
+// select instructions on the same predicate next to each other.  It would be
+// better to use the same branch for the whole group of selects.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
+#include "llvm/Function.h"
+#include "llvm/Instructions.h"
+#include "llvm/Pass.h"
+#include "llvm/Type.h"
+#include "llvm/Support/Compiler.h"
+using namespace llvm;
+
+namespace {
+  /// LowerSelect - Turn select instructions into conditional branches.
+  ///
+  class VISIBILITY_HIDDEN LowerSelect : public FunctionPass {
+    bool OnlyFP;   // Only lower FP select instructions?
+  public:
+    static char ID; // Pass identification, replacement for typeid
+    LowerSelect(bool onlyfp = false) : FunctionPass((intptr_t)&ID),
+      OnlyFP(onlyfp) {}
+
+    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+      // This certainly destroys the CFG.
+      // This is a cluster of orthogonal Transforms:
+      AU.addPreserved<UnifyFunctionExitNodes>();
+      AU.addPreservedID(PromoteMemoryToRegisterID);
+      AU.addPreservedID(LowerSwitchID);
+      AU.addPreservedID(LowerInvokePassID);
+      AU.addPreservedID(LowerAllocationsID);
+    }
+
+    bool runOnFunction(Function &F);
+  };
+
+  char LowerSelect::ID = 0;
+  RegisterPass<LowerSelect>
+  X("lowerselect", "Lower select instructions to branches");
+}
+
+// Publically exposed interface to pass...
+const PassInfo *llvm::LowerSelectID = X.getPassInfo();
+//===----------------------------------------------------------------------===//
+// This pass converts SelectInst instructions into conditional branch and PHI
+// instructions.  If the OnlyFP flag is set to true, then only floating point
+// select instructions are lowered.
+//
+FunctionPass *llvm::createLowerSelectPass(bool OnlyFP) {
+  return new LowerSelect(OnlyFP);
+}
+
+
+bool LowerSelect::runOnFunction(Function &F) {
+  bool Changed = false;
+  for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
+    for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) {
+      if (SelectInst *SI = dyn_cast<SelectInst>(I))
+        if (!OnlyFP || SI->getType()->isFloatingPoint()) {
+          // Split this basic block in half right before the select instruction.
+          BasicBlock *NewCont =
+            BB->splitBasicBlock(I, BB->getName()+".selectcont");
+
+          // Make the true block, and make it branch to the continue block.
+          BasicBlock *NewTrue = new BasicBlock(BB->getName()+".selecttrue",
+                                               BB->getParent(), NewCont);
+          new BranchInst(NewCont, NewTrue);
+
+          // Make the unconditional branch in the incoming block be a
+          // conditional branch on the select predicate.
+          BB->getInstList().erase(BB->getTerminator());
+          new BranchInst(NewTrue, NewCont, SI->getCondition(), BB);
+
+          // Create a new PHI node in the cont block with the entries we need.
+          PHINode *PN = new PHINode(SI->getType(), "", NewCont->begin());
+          PN->takeName(SI);
+          PN->addIncoming(SI->getTrueValue(), NewTrue);
+          PN->addIncoming(SI->getFalseValue(), BB);
+
+          // Use the PHI instead of the select.
+          SI->replaceAllUsesWith(PN);
+          NewCont->getInstList().erase(SI);
+
+          Changed = true;
+          break; // This block is done with.
+        }
+    }
+  return Changed;
+}
diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp
new file mode 100644
index 0000000..633633d
--- /dev/null
+++ b/lib/Transforms/Utils/LowerSwitch.cpp
@@ -0,0 +1,324 @@
+//===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file was developed by the LLVM research group and is distributed under
+// the University of Illinois Open Source License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// +// +// The LowerSwitch transformation rewrites switch statements with a sequence of +// branches, which allows targets to get away with not implementing the switch +// statement until it is convenient. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" +#include "llvm/Constants.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Compiler.h" +#include +using namespace llvm; + +namespace { + /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch + /// instructions. Note that this cannot be a BasicBlock pass because it + /// modifies the CFG! + class VISIBILITY_HIDDEN LowerSwitch : public FunctionPass { + public: + static char ID; // Pass identification, replacement for typeid + LowerSwitch() : FunctionPass((intptr_t) &ID) {} + + virtual bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + // This is a cluster of orthogonal Transforms + AU.addPreserved(); + AU.addPreservedID(PromoteMemoryToRegisterID); + AU.addPreservedID(LowerSelectID); + AU.addPreservedID(LowerInvokePassID); + AU.addPreservedID(LowerAllocationsID); + } + + struct CaseRange { + Constant* Low; + Constant* High; + BasicBlock* BB; + + CaseRange() : Low(0), High(0), BB(0) { } + CaseRange(Constant* low, Constant* high, BasicBlock* bb) : + Low(low), High(high), BB(bb) { } + }; + + typedef std::vector CaseVector; + typedef std::vector::iterator CaseItr; + private: + void processSwitchInst(SwitchInst *SI); + + BasicBlock* switchConvert(CaseItr Begin, CaseItr End, Value* Val, + BasicBlock* OrigBlock, BasicBlock* Default); + BasicBlock* newLeafBlock(CaseRange& Leaf, Value* Val, + BasicBlock* OrigBlock, BasicBlock* Default); + unsigned 
Clusterify(CaseVector& Cases, SwitchInst *SI); + }; + + /// The comparison function for sorting the switch case values in the vector. + /// WARNING: Case ranges should be disjoint! + struct CaseCmp { + bool operator () (const LowerSwitch::CaseRange& C1, + const LowerSwitch::CaseRange& C2) { + + const ConstantInt* CI1 = cast(C1.Low); + const ConstantInt* CI2 = cast(C2.High); + return CI1->getValue().slt(CI2->getValue()); + } + }; + + char LowerSwitch::ID = 0; + RegisterPass + X("lowerswitch", "Lower SwitchInst's to branches"); +} + +// Publically exposed interface to pass... +const PassInfo *llvm::LowerSwitchID = X.getPassInfo(); +// createLowerSwitchPass - Interface to this file... +FunctionPass *llvm::createLowerSwitchPass() { + return new LowerSwitch(); +} + +bool LowerSwitch::runOnFunction(Function &F) { + bool Changed = false; + + for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { + BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks + + if (SwitchInst *SI = dyn_cast(Cur->getTerminator())) { + Changed = true; + processSwitchInst(SI); + } + } + + return Changed; +} + +// operator<< - Used for debugging purposes. +// +static std::ostream& operator<<(std::ostream &O, + const LowerSwitch::CaseVector &C) { + O << "["; + + for (LowerSwitch::CaseVector::const_iterator B = C.begin(), + E = C.end(); B != E; ) { + O << *B->Low << " -" << *B->High; + if (++B != E) O << ", "; + } + + return O << "]"; +} + +static OStream& operator<<(OStream &O, const LowerSwitch::CaseVector &C) { + if (O.stream()) *O.stream() << C; + return O; +} + +// switchConvert - Convert the switch statement into a binary lookup of +// the case values. The function recursively builds this tree. 
+// +BasicBlock* LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, + Value* Val, BasicBlock* OrigBlock, + BasicBlock* Default) +{ + unsigned Size = End - Begin; + + if (Size == 1) + return newLeafBlock(*Begin, Val, OrigBlock, Default); + + unsigned Mid = Size / 2; + std::vector LHS(Begin, Begin + Mid); + DOUT << "LHS: " << LHS << "\n"; + std::vector RHS(Begin + Mid, End); + DOUT << "RHS: " << RHS << "\n"; + + CaseRange& Pivot = *(Begin + Mid); + DEBUG( DOUT << "Pivot ==> " + << cast(Pivot.Low)->getValue().toStringSigned(10) + << " -" + << cast(Pivot.High)->getValue().toStringSigned(10) + << "\n"); + + BasicBlock* LBranch = switchConvert(LHS.begin(), LHS.end(), Val, + OrigBlock, Default); + BasicBlock* RBranch = switchConvert(RHS.begin(), RHS.end(), Val, + OrigBlock, Default); + + // Create a new node that checks if the value is < pivot. Go to the + // left branch if it is and right branch if not. + Function* F = OrigBlock->getParent(); + BasicBlock* NewNode = new BasicBlock("NodeBlock"); + Function::iterator FI = OrigBlock; + F->getBasicBlockList().insert(++FI, NewNode); + + ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, Val, Pivot.Low, "Pivot"); + NewNode->getInstList().push_back(Comp); + new BranchInst(LBranch, RBranch, Comp, NewNode); + return NewNode; +} + +// newLeafBlock - Create a new leaf block for the binary lookup tree. It +// checks if the switch's value == the case's value. If not, then it +// jumps to the default branch. At this point in the tree, the value +// can't be another valid case value, so the jump to the "default" branch +// is warranted. 
+// +BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, + BasicBlock* OrigBlock, + BasicBlock* Default) +{ + Function* F = OrigBlock->getParent(); + BasicBlock* NewLeaf = new BasicBlock("LeafBlock"); + Function::iterator FI = OrigBlock; + F->getBasicBlockList().insert(++FI, NewLeaf); + + // Emit comparison + ICmpInst* Comp = NULL; + if (Leaf.Low == Leaf.High) { + // Make the seteq instruction... + Comp = new ICmpInst(ICmpInst::ICMP_EQ, Val, Leaf.Low, + "SwitchLeaf", NewLeaf); + } else { + // Make range comparison + if (cast(Leaf.Low)->isMinValue(true /*isSigned*/)) { + // Val >= Min && Val <= Hi --> Val <= Hi + Comp = new ICmpInst(ICmpInst::ICMP_SLE, Val, Leaf.High, + "SwitchLeaf", NewLeaf); + } else if (cast(Leaf.Low)->isZero()) { + // Val >= 0 && Val <= Hi --> Val <=u Hi + Comp = new ICmpInst(ICmpInst::ICMP_ULE, Val, Leaf.High, + "SwitchLeaf", NewLeaf); + } else { + // Emit V-Lo <=u Hi-Lo + Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); + Instruction* Add = BinaryOperator::createAdd(Val, NegLo, + Val->getName()+".off", + NewLeaf); + Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); + Comp = new ICmpInst(ICmpInst::ICMP_ULE, Add, UpperBound, + "SwitchLeaf", NewLeaf); + } + } + + // Make the conditional branch... + BasicBlock* Succ = Leaf.BB; + new BranchInst(Succ, Default, Comp, NewLeaf); + + // If there were any PHI nodes in this successor, rewrite one entry + // from OrigBlock to come from NewLeaf. 
+ for (BasicBlock::iterator I = Succ->begin(); isa(I); ++I) { + PHINode* PN = cast(I); + // Remove all but one incoming entries from the cluster + uint64_t Range = cast(Leaf.High)->getSExtValue() - + cast(Leaf.Low)->getSExtValue(); + for (uint64_t j = 0; j < Range; ++j) { + PN->removeIncomingValue(OrigBlock); + } + + int BlockIdx = PN->getBasicBlockIndex(OrigBlock); + assert(BlockIdx != -1 && "Switch didn't go to this successor??"); + PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); + } + + return NewLeaf; +} + +// Clusterify - Transform simple list of Cases into list of CaseRange's +unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { + unsigned numCmps = 0; + + // Start with "simple" cases + for (unsigned i = 1; i < SI->getNumSuccessors(); ++i) + Cases.push_back(CaseRange(SI->getSuccessorValue(i), + SI->getSuccessorValue(i), + SI->getSuccessor(i))); + sort(Cases.begin(), Cases.end(), CaseCmp()); + + // Merge case into clusters + if (Cases.size()>=2) + for (CaseItr I=Cases.begin(), J=++(Cases.begin()), E=Cases.end(); J!=E; ) { + int64_t nextValue = cast(J->Low)->getSExtValue(); + int64_t currentValue = cast(I->High)->getSExtValue(); + BasicBlock* nextBB = J->BB; + BasicBlock* currentBB = I->BB; + + // If the two neighboring cases go to the same destination, merge them + // into a single case. + if ((nextValue-currentValue==1) && (currentBB == nextBB)) { + I->High = J->High; + J = Cases.erase(J); + } else { + I = J++; + } + } + + for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { + if (I->Low != I->High) + // A range counts double, since it requires two compares. + ++numCmps; + } + + return numCmps; +} + +// processSwitchInst - Replace the specified switch instruction with a sequence +// of chained if-then insts in a balanced binary search. 
+// +void LowerSwitch::processSwitchInst(SwitchInst *SI) { + BasicBlock *CurBlock = SI->getParent(); + BasicBlock *OrigBlock = CurBlock; + Function *F = CurBlock->getParent(); + Value *Val = SI->getOperand(0); // The value we are switching on... + BasicBlock* Default = SI->getDefaultDest(); + + // If there is only the default destination, don't bother with the code below. + if (SI->getNumOperands() == 2) { + new BranchInst(SI->getDefaultDest(), CurBlock); + CurBlock->getInstList().erase(SI); + return; + } + + // Create a new, empty default block so that the new hierarchy of + // if-then statements go to this and the PHI nodes are happy. + BasicBlock* NewDefault = new BasicBlock("NewDefault"); + F->getBasicBlockList().insert(Default, NewDefault); + + new BranchInst(Default, NewDefault); + + // If there is an entry in any PHI nodes for the default edge, make sure + // to update them as well. + for (BasicBlock::iterator I = Default->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + int BlockIdx = PN->getBasicBlockIndex(OrigBlock); + assert(BlockIdx != -1 && "Switch didn't go to this successor??"); + PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); + } + + // Prepare cases vector. + CaseVector Cases; + unsigned numCmps = Clusterify(Cases, SI); + + DOUT << "Clusterify finished. Total clusters: " << Cases.size() + << ". Total compares: " << numCmps << "\n"; + DOUT << "Cases: " << Cases << "\n"; + + BasicBlock* SwitchBlock = switchConvert(Cases.begin(), Cases.end(), Val, + OrigBlock, NewDefault); + + // Branch to our shiny new if-then stuff... + new BranchInst(SwitchBlock, OrigBlock); + + // We are now done with the switch instruction, delete it. 
+ CurBlock->getInstList().erase(SI); +} diff --git a/lib/Transforms/Utils/Makefile b/lib/Transforms/Utils/Makefile new file mode 100644 index 0000000..26fc426 --- /dev/null +++ b/lib/Transforms/Utils/Makefile @@ -0,0 +1,15 @@ +##===- lib/Transforms/Utils/Makefile -----------------------*- Makefile -*-===## +# +# The LLVM Compiler Infrastructure +# +# This file was developed by the LLVM research group and is distributed under +# the University of Illinois Open Source License. See LICENSE.TXT for details. +# +##===----------------------------------------------------------------------===## + +LEVEL = ../../.. +LIBRARYNAME = LLVMTransformUtils +BUILD_ARCHIVE = 1 + +include $(LEVEL)/Makefile.common + diff --git a/lib/Transforms/Utils/Mem2Reg.cpp b/lib/Transforms/Utils/Mem2Reg.cpp new file mode 100644 index 0000000..d67b3de --- /dev/null +++ b/lib/Transforms/Utils/Mem2Reg.cpp @@ -0,0 +1,93 @@ +//===- Mem2Reg.cpp - The -mem2reg pass, a wrapper around the Utils lib ----===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass is a simple pass wrapper around the PromoteMemToReg function call +// exposed by the Utils library. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "mem2reg" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/PromoteMemToReg.h" +#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Instructions.h" +#include "llvm/Function.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/Compiler.h" +using namespace llvm; + +STATISTIC(NumPromoted, "Number of alloca's promoted"); + +namespace { + struct VISIBILITY_HIDDEN PromotePass : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + PromotePass() : FunctionPass((intptr_t)&ID) {} + + // runOnFunction - To run this pass, first we calculate the alloca + // instructions that are safe for promotion, then we promote each one. + // + virtual bool runOnFunction(Function &F); + + // getAnalysisUsage - We need dominance frontiers + // + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + AU.addRequired(); + AU.setPreservesCFG(); + // This is a cluster of orthogonal Transforms + AU.addPreserved(); + AU.addPreservedID(LowerSelectID); + AU.addPreservedID(LowerSwitchID); + AU.addPreservedID(LowerInvokePassID); + AU.addPreservedID(LowerAllocationsID); + } + }; + + char PromotePass::ID = 0; + RegisterPass X("mem2reg", "Promote Memory to Register"); +} // end of anonymous namespace + +bool PromotePass::runOnFunction(Function &F) { + std::vector Allocas; + + BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function + + bool Changed = false; + + DominatorTree &DT = getAnalysis(); + DominanceFrontier &DF = getAnalysis(); + + while (1) { + Allocas.clear(); + + // Find allocas that are safe to promote, by looking at all instructions in + // the entry node + for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I) + if (AllocaInst *AI = dyn_cast(I)) // Is it an alloca? 
+ if (isAllocaPromotable(AI)) + Allocas.push_back(AI); + + if (Allocas.empty()) break; + + PromoteMemToReg(Allocas, DT, DF); + NumPromoted += Allocas.size(); + Changed = true; + } + + return Changed; +} + +// Publically exposed interface to pass... +const PassInfo *llvm::PromoteMemoryToRegisterID = X.getPassInfo(); +// createPromoteMemoryToRegister - Provide an entry point to create this pass. +// +FunctionPass *llvm::createPromoteMemoryToRegisterPass() { + return new PromotePass(); +} diff --git a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp new file mode 100644 index 0000000..259a5a2 --- /dev/null +++ b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -0,0 +1,835 @@ +//===- PromoteMemoryToRegister.cpp - Convert allocas to registers ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file promote memory references to be register references. It promotes +// alloca instructions which only have loads and stores as uses. An alloca is +// transformed by using dominator frontiers to place PHI nodes, then traversing +// the function in depth-first order to rewrite loads and stores as appropriate. +// This is just the standard SSA construction algorithm to construct "pruned" +// SSA form. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/PromoteMemToReg.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include +using namespace llvm; + +// Provide DenseMapKeyInfo for all pointers. +namespace llvm { +template<> +struct DenseMapKeyInfo > { + static inline std::pair getEmptyKey() { + return std::make_pair((BasicBlock*)-1, ~0U); + } + static inline std::pair getTombstoneKey() { + return std::make_pair((BasicBlock*)-2, 0U); + } + static unsigned getHashValue(const std::pair &Val) { + return DenseMapKeyInfo::getHashValue(Val.first) + Val.second*2; + } + static bool isPod() { return true; } +}; +} + +/// isAllocaPromotable - Return true if this alloca is legal for promotion. +/// This is true if there are only loads and stores to the alloca. +/// +bool llvm::isAllocaPromotable(const AllocaInst *AI) { + // FIXME: If the memory unit is of pointer or integer type, we can permit + // assignments to subsections of the memory unit. + + // Only allow direct loads and stores... + for (Value::use_const_iterator UI = AI->use_begin(), UE = AI->use_end(); + UI != UE; ++UI) // Loop over all of the uses of the alloca + if (isa(*UI)) { + // noop + } else if (const StoreInst *SI = dyn_cast(*UI)) { + if (SI->getOperand(0) == AI) + return false; // Don't allow a store OF the AI, only INTO the AI. + } else { + return false; // Not a load or store. 
+ } + + return true; +} + +namespace { + + // Data package used by RenamePass() + class VISIBILITY_HIDDEN RenamePassData { + public: + RenamePassData(BasicBlock *B, BasicBlock *P, + const std::vector &V) : BB(B), Pred(P), Values(V) {} + BasicBlock *BB; + BasicBlock *Pred; + std::vector Values; + }; + + struct VISIBILITY_HIDDEN PromoteMem2Reg { + /// Allocas - The alloca instructions being promoted. + /// + std::vector Allocas; + SmallVector &RetryList; + DominatorTree &DT; + DominanceFrontier &DF; + + /// AST - An AliasSetTracker object to update. If null, don't update it. + /// + AliasSetTracker *AST; + + /// AllocaLookup - Reverse mapping of Allocas. + /// + std::map AllocaLookup; + + /// NewPhiNodes - The PhiNodes we're adding. + /// + DenseMap, PHINode*> NewPhiNodes; + + /// PhiToAllocaMap - For each PHI node, keep track of which entry in Allocas + /// it corresponds to. + DenseMap PhiToAllocaMap; + + /// PointerAllocaValues - If we are updating an AliasSetTracker, then for + /// each alloca that is of pointer type, we keep track of what to copyValue + /// to the inserted PHI nodes here. + /// + std::vector PointerAllocaValues; + + /// Visited - The set of basic blocks the renamer has already visited. + /// + SmallPtrSet Visited; + + /// BBNumbers - Contains a stable numbering of basic blocks to avoid + /// non-determinstic behavior. + DenseMap BBNumbers; + + /// RenamePassWorkList - Worklist used by RenamePass() + std::vector RenamePassWorkList; + + public: + PromoteMem2Reg(const std::vector &A, + SmallVector &Retry, DominatorTree &dt, + DominanceFrontier &df, AliasSetTracker *ast) + : Allocas(A), RetryList(Retry), DT(dt), DF(df), AST(ast) {} + + void run(); + + /// properlyDominates - Return true if I1 properly dominates I2. 
+ /// + bool properlyDominates(Instruction *I1, Instruction *I2) const { + if (InvokeInst *II = dyn_cast(I1)) + I1 = II->getNormalDest()->begin(); + return DT.properlyDominates(I1->getParent(), I2->getParent()); + } + + /// dominates - Return true if BB1 dominates BB2 using the DominatorTree. + /// + bool dominates(BasicBlock *BB1, BasicBlock *BB2) const { + return DT.dominates(BB1, BB2); + } + + private: + void MarkDominatingPHILive(BasicBlock *BB, unsigned AllocaNum, + SmallPtrSet &DeadPHINodes); + bool PromoteLocallyUsedAlloca(BasicBlock *BB, AllocaInst *AI); + void PromoteLocallyUsedAllocas(BasicBlock *BB, + const std::vector &AIs); + + void RenamePass(BasicBlock *BB, BasicBlock *Pred, + std::vector &IncVals); + bool QueuePhiNode(BasicBlock *BB, unsigned AllocaIdx, unsigned &Version, + SmallPtrSet &InsertedPHINodes); + }; + +} // end of anonymous namespace + +void PromoteMem2Reg::run() { + Function &F = *DF.getRoot()->getParent(); + + // LocallyUsedAllocas - Keep track of all of the alloca instructions which are + // only used in a single basic block. These instructions can be efficiently + // promoted by performing a single linear scan over that one block. Since + // individual basic blocks are sometimes large, we group together all allocas + // that are live in a single basic block by the basic block they are live in. + std::map > LocallyUsedAllocas; + + if (AST) PointerAllocaValues.resize(Allocas.size()); + + for (unsigned AllocaNum = 0; AllocaNum != Allocas.size(); ++AllocaNum) { + AllocaInst *AI = Allocas[AllocaNum]; + + assert(isAllocaPromotable(AI) && + "Cannot promote non-promotable alloca!"); + assert(AI->getParent()->getParent() == &F && + "All allocas should be in the same function, which is same as DF!"); + + if (AI->use_empty()) { + // If there are no uses of the alloca, just delete it now. 
+ if (AST) AST->deleteValue(AI); + AI->eraseFromParent(); + + // Remove the alloca from the Allocas list, since it has been processed + Allocas[AllocaNum] = Allocas.back(); + Allocas.pop_back(); + --AllocaNum; + continue; + } + + // Calculate the set of read and write-locations for each alloca. This is + // analogous to finding the 'uses' and 'definitions' of each variable. + std::vector DefiningBlocks; + std::vector UsingBlocks; + + StoreInst *OnlyStore = 0; + BasicBlock *OnlyBlock = 0; + bool OnlyUsedInOneBlock = true; + + // As we scan the uses of the alloca instruction, keep track of stores, and + // decide whether all of the loads and stores to the alloca are within the + // same basic block. + Value *AllocaPointerVal = 0; + for (Value::use_iterator U =AI->use_begin(), E = AI->use_end(); U != E;++U){ + Instruction *User = cast(*U); + if (StoreInst *SI = dyn_cast(User)) { + // Remember the basic blocks which define new values for the alloca + DefiningBlocks.push_back(SI->getParent()); + AllocaPointerVal = SI->getOperand(0); + OnlyStore = SI; + } else { + LoadInst *LI = cast(User); + // Otherwise it must be a load instruction, keep track of variable reads + UsingBlocks.push_back(LI->getParent()); + AllocaPointerVal = LI; + } + + if (OnlyUsedInOneBlock) { + if (OnlyBlock == 0) + OnlyBlock = User->getParent(); + else if (OnlyBlock != User->getParent()) + OnlyUsedInOneBlock = false; + } + } + + // If the alloca is only read and written in one basic block, just perform a + // linear sweep over the block to eliminate it. + if (OnlyUsedInOneBlock) { + LocallyUsedAllocas[OnlyBlock].push_back(AI); + + // Remove the alloca from the Allocas list, since it will be processed. + Allocas[AllocaNum] = Allocas.back(); + Allocas.pop_back(); + --AllocaNum; + continue; + } + + // If there is only a single store to this value, replace any loads of + // it that are directly dominated by the definition with the value stored. 
+ if (DefiningBlocks.size() == 1) { + // Be aware of loads before the store. + std::set ProcessedBlocks; + for (unsigned i = 0, e = UsingBlocks.size(); i != e; ++i) + // If the store dominates the block and if we haven't processed it yet, + // do so now. + if (dominates(OnlyStore->getParent(), UsingBlocks[i])) + if (ProcessedBlocks.insert(UsingBlocks[i]).second) { + BasicBlock *UseBlock = UsingBlocks[i]; + + // If the use and store are in the same block, do a quick scan to + // verify that there are no uses before the store. + if (UseBlock == OnlyStore->getParent()) { + BasicBlock::iterator I = UseBlock->begin(); + for (; &*I != OnlyStore; ++I) { // scan block for store. + if (isa(I) && I->getOperand(0) == AI) + break; + } + if (&*I != OnlyStore) break; // Do not handle this case. + } + + // Otherwise, if this is a different block or if all uses happen + // after the store, do a simple linear scan to replace loads with + // the stored value. + for (BasicBlock::iterator I = UseBlock->begin(),E = UseBlock->end(); + I != E; ) { + if (LoadInst *LI = dyn_cast(I++)) { + if (LI->getOperand(0) == AI) { + LI->replaceAllUsesWith(OnlyStore->getOperand(0)); + if (AST && isa(LI->getType())) + AST->deleteValue(LI); + LI->eraseFromParent(); + } + } + } + + // Finally, remove this block from the UsingBlock set. + UsingBlocks[i] = UsingBlocks.back(); + --i; --e; + } + + // Finally, after the scan, check to see if the store is all that is left. + if (UsingBlocks.empty()) { + // The alloca has been processed, move on. + Allocas[AllocaNum] = Allocas.back(); + Allocas.pop_back(); + --AllocaNum; + continue; + } + } + + + if (AST) + PointerAllocaValues[AllocaNum] = AllocaPointerVal; + + // If we haven't computed a numbering for the BB's in the function, do so + // now. + if (BBNumbers.empty()) { + unsigned ID = 0; + for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) + BBNumbers[I] = ID++; + } + + // Compute the locations where PhiNodes need to be inserted. 
Look at the + // dominance frontier of EACH basic-block we have a write in. + // + unsigned CurrentVersion = 0; + SmallPtrSet InsertedPHINodes; + std::vector > DFBlocks; + while (!DefiningBlocks.empty()) { + BasicBlock *BB = DefiningBlocks.back(); + DefiningBlocks.pop_back(); + + // Look up the DF for this write, add it to PhiNodes + DominanceFrontier::const_iterator it = DF.find(BB); + if (it != DF.end()) { + const DominanceFrontier::DomSetType &S = it->second; + + // In theory we don't need the indirection through the DFBlocks vector. + // In practice, the order of calling QueuePhiNode would depend on the + // (unspecified) ordering of basic blocks in the dominance frontier, + // which would give PHI nodes non-determinstic subscripts. Fix this by + // processing blocks in order of the occurance in the function. + for (DominanceFrontier::DomSetType::const_iterator P = S.begin(), + PE = S.end(); P != PE; ++P) + DFBlocks.push_back(std::make_pair(BBNumbers[*P], *P)); + + // Sort by which the block ordering in the function. + std::sort(DFBlocks.begin(), DFBlocks.end()); + + for (unsigned i = 0, e = DFBlocks.size(); i != e; ++i) { + BasicBlock *BB = DFBlocks[i].second; + if (QueuePhiNode(BB, AllocaNum, CurrentVersion, InsertedPHINodes)) + DefiningBlocks.push_back(BB); + } + DFBlocks.clear(); + } + } + + // Now that we have inserted PHI nodes along the Iterated Dominance Frontier + // of the writes to the variable, scan through the reads of the variable, + // marking PHI nodes which are actually necessary as alive (by removing them + // from the InsertedPHINodes set). This is not perfect: there may PHI + // marked alive because of loads which are dominated by stores, but there + // will be no unmarked PHI nodes which are actually used. + // + for (unsigned i = 0, e = UsingBlocks.size(); i != e; ++i) + MarkDominatingPHILive(UsingBlocks[i], AllocaNum, InsertedPHINodes); + UsingBlocks.clear(); + + // If there are any PHI nodes which are now known to be dead, remove them! 
+ for (SmallPtrSet::iterator I = InsertedPHINodes.begin(), + E = InsertedPHINodes.end(); I != E; ++I) { + PHINode *PN = *I; + bool Erased=NewPhiNodes.erase(std::make_pair(PN->getParent(), AllocaNum)); + Erased=Erased; + assert(Erased && "PHI already removed?"); + + if (AST && isa(PN->getType())) + AST->deleteValue(PN); + PN->eraseFromParent(); + PhiToAllocaMap.erase(PN); + } + + // Keep the reverse mapping of the 'Allocas' array. + AllocaLookup[Allocas[AllocaNum]] = AllocaNum; + } + + // Process all allocas which are only used in a single basic block. + for (std::map >::iterator I = + LocallyUsedAllocas.begin(), E = LocallyUsedAllocas.end(); I != E; ++I){ + const std::vector &LocAllocas = I->second; + assert(!LocAllocas.empty() && "empty alloca list??"); + + // It's common for there to only be one alloca in the list. Handle it + // efficiently. + if (LocAllocas.size() == 1) { + // If we can do the quick promotion pass, do so now. + if (PromoteLocallyUsedAlloca(I->first, LocAllocas[0])) + RetryList.push_back(LocAllocas[0]); // Failed, retry later. + } else { + // Locally promote anything possible. Note that if this is unable to + // promote a particular alloca, it puts the alloca onto the Allocas vector + // for global processing. + PromoteLocallyUsedAllocas(I->first, LocAllocas); + } + } + + if (Allocas.empty()) + return; // All of the allocas must have been trivial! + + // Set the incoming values for the basic block to be null values for all of + // the alloca's. We do this in case there is a load of a value that has not + // been stored yet. In this case, it will get this null value. 
+ // + std::vector Values(Allocas.size()); + for (unsigned i = 0, e = Allocas.size(); i != e; ++i) + Values[i] = UndefValue::get(Allocas[i]->getAllocatedType()); + + // Walks all basic blocks in the function performing the SSA rename algorithm + // and inserting the phi nodes we marked as necessary + // + RenamePassWorkList.clear(); + RenamePassWorkList.push_back(RenamePassData(F.begin(), 0, Values)); + while(!RenamePassWorkList.empty()) { + RenamePassData RPD = RenamePassWorkList.back(); + RenamePassWorkList.pop_back(); + // RenamePass may add new worklist entries. + RenamePass(RPD.BB, RPD.Pred, RPD.Values); + } + + // The renamer uses the Visited set to avoid infinite loops. Clear it now. + Visited.clear(); + + // Remove the allocas themselves from the function. + for (unsigned i = 0, e = Allocas.size(); i != e; ++i) { + Instruction *A = Allocas[i]; + + // If there are any uses of the alloca instructions left, they must be in + // sections of dead code that were not processed on the dominance frontier. + // Just delete the users now. + // + if (!A->use_empty()) + A->replaceAllUsesWith(UndefValue::get(A->getType())); + if (AST) AST->deleteValue(A); + A->eraseFromParent(); + } + + + // Loop over all of the PHI nodes and see if there are any that we can get + // rid of because they merge all of the same incoming values. This can + // happen due to undef values coming into the PHI nodes. This process is + // iterative, because eliminating one PHI node can cause others to be removed. + bool EliminatedAPHI = true; + while (EliminatedAPHI) { + EliminatedAPHI = false; + + for (DenseMap, PHINode*>::iterator I = + NewPhiNodes.begin(), E = NewPhiNodes.end(); I != E;) { + PHINode *PN = I->second; + + // If this PHI node merges one value and/or undefs, get the value. 
+ if (Value *V = PN->hasConstantValue(true)) { + if (!isa(V) || + properlyDominates(cast(V), PN)) { + if (AST && isa(PN->getType())) + AST->deleteValue(PN); + PN->replaceAllUsesWith(V); + PN->eraseFromParent(); + NewPhiNodes.erase(I++); + EliminatedAPHI = true; + continue; + } + } + ++I; + } + } + + // At this point, the renamer has added entries to PHI nodes for all reachable + // code. Unfortunately, there may be unreachable blocks which the renamer + // hasn't traversed. If this is the case, the PHI nodes may not + // have incoming values for all predecessors. Loop over all PHI nodes we have + // created, inserting undef values if they are missing any incoming values. + // + for (DenseMap, PHINode*>::iterator I = + NewPhiNodes.begin(), E = NewPhiNodes.end(); I != E; ++I) { + // We want to do this once per basic block. As such, only process a block + // when we find the PHI that is the first entry in the block. + PHINode *SomePHI = I->second; + BasicBlock *BB = SomePHI->getParent(); + if (&BB->front() != SomePHI) + continue; + + // Count the number of preds for BB. + SmallVector Preds(pred_begin(BB), pred_end(BB)); + + // Only do work here if there the PHI nodes are missing incoming values. We + // know that all PHI nodes that were inserted in a block will have the same + // number of incoming values, so we can just check any of them. + if (SomePHI->getNumIncomingValues() == Preds.size()) + continue; + + // Ok, now we know that all of the PHI nodes are missing entries for some + // basic blocks. Start by sorting the incoming predecessors for efficient + // access. + std::sort(Preds.begin(), Preds.end()); + + // Now we loop through all BB's which have entries in SomePHI and remove + // them from the Preds list. + for (unsigned i = 0, e = SomePHI->getNumIncomingValues(); i != e; ++i) { + // Do a log(n) search of the Preds list for the entry we want. 
+ SmallVector::iterator EntIt = + std::lower_bound(Preds.begin(), Preds.end(), + SomePHI->getIncomingBlock(i)); + assert(EntIt != Preds.end() && *EntIt == SomePHI->getIncomingBlock(i)&& + "PHI node has entry for a block which is not a predecessor!"); + + // Remove the entry + Preds.erase(EntIt); + } + + // At this point, the blocks left in the preds list must have dummy + // entries inserted into every PHI nodes for the block. Update all the phi + // nodes in this block that we are inserting (there could be phis before + // mem2reg runs). + unsigned NumBadPreds = SomePHI->getNumIncomingValues(); + BasicBlock::iterator BBI = BB->begin(); + while ((SomePHI = dyn_cast(BBI++)) && + SomePHI->getNumIncomingValues() == NumBadPreds) { + Value *UndefVal = UndefValue::get(SomePHI->getType()); + for (unsigned pred = 0, e = Preds.size(); pred != e; ++pred) + SomePHI->addIncoming(UndefVal, Preds[pred]); + } + } + + NewPhiNodes.clear(); +} + +// MarkDominatingPHILive - Mem2Reg wants to construct "pruned" SSA form, not +// "minimal" SSA form. To do this, it inserts all of the PHI nodes on the IDF +// as usual (inserting the PHI nodes in the DeadPHINodes set), then processes +// each read of the variable. For each block that reads the variable, this +// function is called, which removes used PHI nodes from the DeadPHINodes set. +// After all of the reads have been processed, any PHI nodes left in the +// DeadPHINodes set are removed. +// +void PromoteMem2Reg::MarkDominatingPHILive(BasicBlock *BB, unsigned AllocaNum, + SmallPtrSet &DeadPHINodes) { + // Scan the immediate dominators of this block looking for a block which has a + // PHI node for Alloca num. If we find it, mark the PHI node as being alive! 
+ DomTreeNode *IDomNode = DT.getNode(BB); + for (DomTreeNode *IDom = IDomNode; IDom; IDom = IDom->getIDom()) { + BasicBlock *DomBB = IDom->getBlock(); + DenseMap, PHINode*>::iterator + I = NewPhiNodes.find(std::make_pair(DomBB, AllocaNum)); + if (I != NewPhiNodes.end()) { + // Ok, we found an inserted PHI node which dominates this value. + PHINode *DominatingPHI = I->second; + + // Find out if we previously thought it was dead. If so, mark it as being + // live by removing it from the DeadPHINodes set. + if (DeadPHINodes.erase(DominatingPHI)) { + // Now that we have marked the PHI node alive, also mark any PHI nodes + // which it might use as being alive as well. + for (pred_iterator PI = pred_begin(DomBB), PE = pred_end(DomBB); + PI != PE; ++PI) + MarkDominatingPHILive(*PI, AllocaNum, DeadPHINodes); + } + } + } +} + +/// PromoteLocallyUsedAlloca - Many allocas are only used within a single basic +/// block. If this is the case, avoid traversing the CFG and inserting a lot of +/// potentially useless PHI nodes by just performing a single linear pass over +/// the basic block using the Alloca. +/// +/// If we cannot promote this alloca (because it is read before it is written), +/// return true. This is necessary in cases where, due to control flow, the +/// alloca is potentially undefined on some control flow paths. e.g. code like +/// this is potentially correct: +/// +/// for (...) { if (c) { A = undef; undef = B; } } +/// +/// ... so long as A is not used before undef is set. +/// +bool PromoteMem2Reg::PromoteLocallyUsedAlloca(BasicBlock *BB, AllocaInst *AI) { + assert(!AI->use_empty() && "There are no uses of the alloca!"); + + // Handle degenerate cases quickly. + if (AI->hasOneUse()) { + Instruction *U = cast(AI->use_back()); + if (LoadInst *LI = dyn_cast(U)) { + // Must be a load of uninitialized value. 
+ LI->replaceAllUsesWith(UndefValue::get(AI->getAllocatedType())); + if (AST && isa(LI->getType())) + AST->deleteValue(LI); + } else { + // Otherwise it must be a store which is never read. + assert(isa(U)); + } + BB->getInstList().erase(U); + } else { + // Uses of the uninitialized memory location shall get undef. + Value *CurVal = 0; + + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ) { + Instruction *Inst = I++; + if (LoadInst *LI = dyn_cast(Inst)) { + if (LI->getOperand(0) == AI) { + if (!CurVal) return true; // Could not locally promote! + + // Loads just returns the "current value"... + LI->replaceAllUsesWith(CurVal); + if (AST && isa(LI->getType())) + AST->deleteValue(LI); + BB->getInstList().erase(LI); + } + } else if (StoreInst *SI = dyn_cast(Inst)) { + if (SI->getOperand(1) == AI) { + // Store updates the "current value"... + CurVal = SI->getOperand(0); + BB->getInstList().erase(SI); + } + } + } + } + + // After traversing the basic block, there should be no more uses of the + // alloca, remove it now. + assert(AI->use_empty() && "Uses of alloca from more than one BB??"); + if (AST) AST->deleteValue(AI); + AI->getParent()->getInstList().erase(AI); + return false; +} + +/// PromoteLocallyUsedAllocas - This method is just like +/// PromoteLocallyUsedAlloca, except that it processes multiple alloca +/// instructions in parallel. This is important in cases where we have large +/// basic blocks, as we don't want to rescan the entire basic block for each +/// alloca which is locally used in it (which might be a lot). +void PromoteMem2Reg:: +PromoteLocallyUsedAllocas(BasicBlock *BB, const std::vector &AIs) { + std::map CurValues; + for (unsigned i = 0, e = AIs.size(); i != e; ++i) + CurValues[AIs[i]] = 0; // Insert with null value + + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ) { + Instruction *Inst = I++; + if (LoadInst *LI = dyn_cast(Inst)) { + // Is this a load of an alloca we are tracking? 
+ if (AllocaInst *AI = dyn_cast(LI->getOperand(0))) { + std::map::iterator AIt = CurValues.find(AI); + if (AIt != CurValues.end()) { + // If loading an uninitialized value, allow the inter-block case to + // handle it. Due to control flow, this might actually be ok. + if (AIt->second == 0) { // Use of locally uninitialized value?? + RetryList.push_back(AI); // Retry elsewhere. + CurValues.erase(AIt); // Stop tracking this here. + if (CurValues.empty()) return; + } else { + // Loads just returns the "current value"... + LI->replaceAllUsesWith(AIt->second); + if (AST && isa(LI->getType())) + AST->deleteValue(LI); + BB->getInstList().erase(LI); + } + } + } + } else if (StoreInst *SI = dyn_cast(Inst)) { + if (AllocaInst *AI = dyn_cast(SI->getOperand(1))) { + std::map::iterator AIt = CurValues.find(AI); + if (AIt != CurValues.end()) { + // Store updates the "current value"... + AIt->second = SI->getOperand(0); + BB->getInstList().erase(SI); + } + } + } + } +} + + + +// QueuePhiNode - queues a phi-node to be added to a basic-block for a specific +// Alloca returns true if there wasn't already a phi-node for that variable +// +bool PromoteMem2Reg::QueuePhiNode(BasicBlock *BB, unsigned AllocaNo, + unsigned &Version, + SmallPtrSet &InsertedPHINodes) { + // Look up the basic-block in question. + PHINode *&PN = NewPhiNodes[std::make_pair(BB, AllocaNo)]; + + // If the BB already has a phi node added for the i'th alloca then we're done! + if (PN) return false; + + // Create a PhiNode using the dereferenced type... and add the phi-node to the + // BasicBlock. + PN = new PHINode(Allocas[AllocaNo]->getAllocatedType(), + Allocas[AllocaNo]->getName() + "." 
+ + utostr(Version++), BB->begin()); + PhiToAllocaMap[PN] = AllocaNo; + + InsertedPHINodes.insert(PN); + + if (AST && isa(PN->getType())) + AST->copyValue(PointerAllocaValues[AllocaNo], PN); + + return true; +} + + +// RenamePass - Recursively traverse the CFG of the function, renaming loads and +// stores to the allocas which we are promoting. IncomingVals indicates what +// value each Alloca contains on exit from the predecessor block Pred. +// +void PromoteMem2Reg::RenamePass(BasicBlock *BB, BasicBlock *Pred, + std::vector &IncomingVals) { + // If we are inserting any phi nodes into this BB, they will already be in the + // block. + if (PHINode *APN = dyn_cast(BB->begin())) { + // Pred may have multiple edges to BB. If so, we want to add N incoming + // values to each PHI we are inserting on the first time we see the edge. + // Check to see if APN already has incoming values from Pred. This also + // prevents us from modifying PHI nodes that are not currently being + // inserted. + bool HasPredEntries = false; + for (unsigned i = 0, e = APN->getNumIncomingValues(); i != e; ++i) { + if (APN->getIncomingBlock(i) == Pred) { + HasPredEntries = true; + break; + } + } + + // If we have PHI nodes to update, compute the number of edges from Pred to + // BB. + if (!HasPredEntries) { + TerminatorInst *PredTerm = Pred->getTerminator(); + unsigned NumEdges = 0; + for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) { + if (PredTerm->getSuccessor(i) == BB) + ++NumEdges; + } + assert(NumEdges && "Must be at least one edge from Pred to BB!"); + + // Add entries for all the phis. + BasicBlock::iterator PNI = BB->begin(); + do { + unsigned AllocaNo = PhiToAllocaMap[APN]; + + // Add N incoming values to the PHI node. + for (unsigned i = 0; i != NumEdges; ++i) + APN->addIncoming(IncomingVals[AllocaNo], Pred); + + // The currently active variable for this block is now the PHI. + IncomingVals[AllocaNo] = APN; + + // Get the next phi node. 
+ ++PNI; + APN = dyn_cast(PNI); + if (APN == 0) break; + + // Verify it doesn't already have entries for Pred. If it does, it is + // not being inserted by this mem2reg invocation. + HasPredEntries = false; + for (unsigned i = 0, e = APN->getNumIncomingValues(); i != e; ++i) { + if (APN->getIncomingBlock(i) == Pred) { + HasPredEntries = true; + break; + } + } + } while (!HasPredEntries); + } + } + + // Don't revisit blocks. + if (!Visited.insert(BB)) return; + + for (BasicBlock::iterator II = BB->begin(); !isa(II); ) { + Instruction *I = II++; // get the instruction, increment iterator + + if (LoadInst *LI = dyn_cast(I)) { + if (AllocaInst *Src = dyn_cast(LI->getPointerOperand())) { + std::map::iterator AI = AllocaLookup.find(Src); + if (AI != AllocaLookup.end()) { + Value *V = IncomingVals[AI->second]; + + // walk the use list of this load and replace all uses with r + LI->replaceAllUsesWith(V); + if (AST && isa(LI->getType())) + AST->deleteValue(LI); + BB->getInstList().erase(LI); + } + } + } else if (StoreInst *SI = dyn_cast(I)) { + // Delete this instruction and mark the name as the current holder of the + // value + if (AllocaInst *Dest = dyn_cast(SI->getPointerOperand())) { + std::map::iterator ai = AllocaLookup.find(Dest); + if (ai != AllocaLookup.end()) { + // what value were we writing? + IncomingVals[ai->second] = SI->getOperand(0); + BB->getInstList().erase(SI); + } + } + } + } + + // Recurse to our successors. + TerminatorInst *TI = BB->getTerminator(); + for (unsigned i = 0; i != TI->getNumSuccessors(); i++) + RenamePassWorkList.push_back(RenamePassData(TI->getSuccessor(i), BB, IncomingVals)); +} + +/// PromoteMemToReg - Promote the specified list of alloca instructions into +/// scalar registers, inserting PHI nodes as appropriate. This function makes +/// use of DominanceFrontier information. This function does not modify the CFG +/// of the function at all. All allocas must be from the same function. 
+/// +/// If AST is specified, the specified tracker is updated to reflect changes +/// made to the IR. +/// +void llvm::PromoteMemToReg(const std::vector &Allocas, + DominatorTree &DT, DominanceFrontier &DF, + AliasSetTracker *AST) { + // If there is nothing to do, bail out... + if (Allocas.empty()) return; + + SmallVector RetryList; + PromoteMem2Reg(Allocas, RetryList, DT, DF, AST).run(); + + // PromoteMem2Reg may not have been able to promote all of the allocas in one + // pass, run it again if needed. + std::vector NewAllocas; + while (!RetryList.empty()) { + // If we need to retry some allocas, this is due to there being no store + // before a read in a local block. To counteract this, insert a store of + // undef into the alloca right after the alloca itself. + for (unsigned i = 0, e = RetryList.size(); i != e; ++i) { + BasicBlock::iterator BBI = RetryList[i]; + + new StoreInst(UndefValue::get(RetryList[i]->getAllocatedType()), + RetryList[i], ++BBI); + } + + NewAllocas.assign(RetryList.begin(), RetryList.end()); + RetryList.clear(); + PromoteMem2Reg(NewAllocas, RetryList, DT, DF, AST).run(); + NewAllocas.clear(); + } +} diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp new file mode 100644 index 0000000..6c34d02 --- /dev/null +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -0,0 +1,1905 @@ +//===- SimplifyCFG.cpp - Code to perform CFG simplification ---------------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Peephole optimize the CFG. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "simplifycfg" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/Type.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Debug.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include +#include +#include +#include +using namespace llvm; + +/// SafeToMergeTerminators - Return true if it is safe to merge these two +/// terminator instructions together. +/// +static bool SafeToMergeTerminators(TerminatorInst *SI1, TerminatorInst *SI2) { + if (SI1 == SI2) return false; // Can't merge with self! + + // It is not safe to merge these two switch instructions if they have a common + // successor, and if that successor has a PHI node, and if *that* PHI node has + // conflicting incoming values from the two switch blocks. + BasicBlock *SI1BB = SI1->getParent(); + BasicBlock *SI2BB = SI2->getParent(); + SmallPtrSet SI1Succs(succ_begin(SI1BB), succ_end(SI1BB)); + + for (succ_iterator I = succ_begin(SI2BB), E = succ_end(SI2BB); I != E; ++I) + if (SI1Succs.count(*I)) + for (BasicBlock::iterator BBI = (*I)->begin(); + isa(BBI); ++BBI) { + PHINode *PN = cast(BBI); + if (PN->getIncomingValueForBlock(SI1BB) != + PN->getIncomingValueForBlock(SI2BB)) + return false; + } + + return true; +} + +/// AddPredecessorToBlock - Update PHI nodes in Succ to indicate that there will +/// now be entries in it from the 'NewPred' block. The values that will be +/// flowing into the PHI nodes will be the same as those coming in from +/// ExistPred, an existing predecessor of Succ. 
+static void AddPredecessorToBlock(BasicBlock *Succ, BasicBlock *NewPred, + BasicBlock *ExistPred) { + assert(std::find(succ_begin(ExistPred), succ_end(ExistPred), Succ) != + succ_end(ExistPred) && "ExistPred is not a predecessor of Succ!"); + if (!isa(Succ->begin())) return; // Quick exit if nothing to do + + for (BasicBlock::iterator I = Succ->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + Value *V = PN->getIncomingValueForBlock(ExistPred); + PN->addIncoming(V, NewPred); + } +} + +// CanPropagatePredecessorsForPHIs - Return true if we can fold BB, an +// almost-empty BB ending in an unconditional branch to Succ, into succ. +// +// Assumption: Succ is the single successor for BB. +// +static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) { + assert(*succ_begin(BB) == Succ && "Succ is not successor of BB!"); + + // Check to see if one of the predecessors of BB is already a predecessor of + // Succ. If so, we cannot do the transformation if there are any PHI nodes + // with incompatible values coming in from the two edges! + // + if (isa(Succ->front())) { + SmallPtrSet BBPreds(pred_begin(BB), pred_end(BB)); + for (pred_iterator PI = pred_begin(Succ), PE = pred_end(Succ); + PI != PE; ++PI) + if (BBPreds.count(*PI)) { + // Loop over all of the PHI nodes checking to see if there are + // incompatible values coming in. + for (BasicBlock::iterator I = Succ->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + // Loop up the entries in the PHI node for BB and for *PI if the + // values coming in are non-equal, we cannot merge these two blocks + // (instead we should insert a conditional move or something, then + // merge the blocks). + if (PN->getIncomingValueForBlock(BB) != + PN->getIncomingValueForBlock(*PI)) + return false; // Values are not equal... 
+ } + } + } + + // Finally, if BB has PHI nodes that are used by things other than the PHIs in + // Succ and Succ has predecessors that are not Succ and not Pred, we cannot + // fold these blocks, as we don't know whether BB dominates Succ or not to + // update the PHI nodes correctly. + if (!isa(BB->begin()) || Succ->getSinglePredecessor()) return true; + + // If the predecessors of Succ are only BB and Succ itself, handle it. + bool IsSafe = true; + for (pred_iterator PI = pred_begin(Succ), E = pred_end(Succ); PI != E; ++PI) + if (*PI != Succ && *PI != BB) { + IsSafe = false; + break; + } + if (IsSafe) return true; + + // If the PHI nodes in BB are only used by instructions in Succ, we are ok if + // BB and Succ have no common predecessors. + for (BasicBlock::iterator I = BB->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + for (Value::use_iterator UI = PN->use_begin(), E = PN->use_end(); UI != E; + ++UI) + if (cast(*UI)->getParent() != Succ) + return false; + } + + // Scan the predecessor sets of BB and Succ, making sure there are no common + // predecessors. Common predecessors would cause us to build a phi node with + // differing incoming values, which is not legal. + SmallPtrSet BBPreds(pred_begin(BB), pred_end(BB)); + for (pred_iterator PI = pred_begin(Succ), E = pred_end(Succ); PI != E; ++PI) + if (BBPreds.count(*PI)) + return false; + + return true; +} + +/// TryToSimplifyUncondBranchFromEmptyBlock - BB contains an unconditional +/// branch to Succ, and contains no instructions other than PHI nodes and the +/// branch. If possible, eliminate BB. +static bool TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, + BasicBlock *Succ) { + // If our successor has PHI nodes, then we need to update them to include + // entries for BB's predecessors, not for BB itself. Be careful though, + // if this transformation fails (returns true) then we cannot do this + // transformation! 
+ // + if (!CanPropagatePredecessorsForPHIs(BB, Succ)) return false; + + DOUT << "Killing Trivial BB: \n" << *BB; + + if (isa(Succ->begin())) { + // If there is more than one pred of succ, and there are PHI nodes in + // the successor, then we need to add incoming edges for the PHI nodes + // + const std::vector BBPreds(pred_begin(BB), pred_end(BB)); + + // Loop over all of the PHI nodes in the successor of BB. + for (BasicBlock::iterator I = Succ->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + Value *OldVal = PN->removeIncomingValue(BB, false); + assert(OldVal && "No entry in PHI for Pred BB!"); + + // If this incoming value is one of the PHI nodes in BB, the new entries + // in the PHI node are the entries from the old PHI. + if (isa(OldVal) && cast(OldVal)->getParent() == BB) { + PHINode *OldValPN = cast(OldVal); + for (unsigned i = 0, e = OldValPN->getNumIncomingValues(); i != e; ++i) + PN->addIncoming(OldValPN->getIncomingValue(i), + OldValPN->getIncomingBlock(i)); + } else { + for (std::vector::const_iterator PredI = BBPreds.begin(), + End = BBPreds.end(); PredI != End; ++PredI) { + // Add an incoming value for each of the new incoming values... + PN->addIncoming(OldVal, *PredI); + } + } + } + } + + if (isa(&BB->front())) { + std::vector + OldSuccPreds(pred_begin(Succ), pred_end(Succ)); + + // Move all PHI nodes in BB to Succ if they are alive, otherwise + // delete them. + while (PHINode *PN = dyn_cast(&BB->front())) + if (PN->use_empty()) { + // Just remove the dead phi. This happens if Succ's PHIs were the only + // users of the PHI nodes. + PN->eraseFromParent(); + } else { + // The instruction is alive, so this means that Succ must have + // *ONLY* had BB as a predecessor, and the PHI node is still valid + // now. Simply move it into Succ, because we know that BB + // strictly dominated Succ. 
+ Succ->getInstList().splice(Succ->begin(), + BB->getInstList(), BB->begin()); + + // We need to add new entries for the PHI node to account for + // predecessors of Succ that the PHI node does not take into + // account. At this point, since we know that BB dominated succ, + // this means that we should any newly added incoming edges should + // use the PHI node as the value for these edges, because they are + // loop back edges. + for (unsigned i = 0, e = OldSuccPreds.size(); i != e; ++i) + if (OldSuccPreds[i] != BB) + PN->addIncoming(PN, OldSuccPreds[i]); + } + } + + // Everything that jumped to BB now goes to Succ. + BB->replaceAllUsesWith(Succ); + if (!Succ->hasName()) Succ->takeName(BB); + BB->eraseFromParent(); // Delete the old basic block. + return true; +} + +/// GetIfCondition - Given a basic block (BB) with two predecessors (and +/// presumably PHI nodes in it), check to see if the merge at this block is due +/// to an "if condition". If so, return the boolean condition that determines +/// which entry into BB will be taken. Also, return by references the block +/// that will be entered from if the condition is true, and the block that will +/// be entered if the condition is false. +/// +/// +static Value *GetIfCondition(BasicBlock *BB, + BasicBlock *&IfTrue, BasicBlock *&IfFalse) { + assert(std::distance(pred_begin(BB), pred_end(BB)) == 2 && + "Function can only handle blocks with 2 predecessors!"); + BasicBlock *Pred1 = *pred_begin(BB); + BasicBlock *Pred2 = *++pred_begin(BB); + + // We can only handle branches. Other control flow will be lowered to + // branches if possible anyway. + if (!isa(Pred1->getTerminator()) || + !isa(Pred2->getTerminator())) + return 0; + BranchInst *Pred1Br = cast(Pred1->getTerminator()); + BranchInst *Pred2Br = cast(Pred2->getTerminator()); + + // Eliminate code duplication by ensuring that Pred1Br is conditional if + // either are. 
+ if (Pred2Br->isConditional()) { + // If both branches are conditional, we don't have an "if statement". In + // reality, we could transform this case, but since the condition will be + // required anyway, we stand no chance of eliminating it, so the xform is + // probably not profitable. + if (Pred1Br->isConditional()) + return 0; + + std::swap(Pred1, Pred2); + std::swap(Pred1Br, Pred2Br); + } + + if (Pred1Br->isConditional()) { + // If we found a conditional branch predecessor, make sure that it branches + // to BB and Pred2Br. If it doesn't, this isn't an "if statement". + if (Pred1Br->getSuccessor(0) == BB && + Pred1Br->getSuccessor(1) == Pred2) { + IfTrue = Pred1; + IfFalse = Pred2; + } else if (Pred1Br->getSuccessor(0) == Pred2 && + Pred1Br->getSuccessor(1) == BB) { + IfTrue = Pred2; + IfFalse = Pred1; + } else { + // We know that one arm of the conditional goes to BB, so the other must + // go somewhere unrelated, and this must not be an "if statement". + return 0; + } + + // The only thing we have to watch out for here is to make sure that Pred2 + // doesn't have incoming edges from other blocks. If it does, the condition + // doesn't dominate BB. + if (++pred_begin(Pred2) != pred_end(Pred2)) + return 0; + + return Pred1Br->getCondition(); + } + + // Ok, if we got here, both predecessors end with an unconditional branch to + // BB. Don't panic! If both blocks only have a single (identical) + // predecessor, and THAT is a conditional branch, then we're all ok! + if (pred_begin(Pred1) == pred_end(Pred1) || + ++pred_begin(Pred1) != pred_end(Pred1) || + pred_begin(Pred2) == pred_end(Pred2) || + ++pred_begin(Pred2) != pred_end(Pred2) || + *pred_begin(Pred1) != *pred_begin(Pred2)) + return 0; + + // Otherwise, if this is a conditional branch, then we can use it! 
+ BasicBlock *CommonPred = *pred_begin(Pred1); + if (BranchInst *BI = dyn_cast(CommonPred->getTerminator())) { + assert(BI->isConditional() && "Two successors but not conditional?"); + if (BI->getSuccessor(0) == Pred1) { + IfTrue = Pred1; + IfFalse = Pred2; + } else { + IfTrue = Pred2; + IfFalse = Pred1; + } + return BI->getCondition(); + } + return 0; +} + + +// If we have a merge point of an "if condition" as accepted above, return true +// if the specified value dominates the block. We don't handle the true +// generality of domination here, just a special case which works well enough +// for us. +// +// If AggressiveInsts is non-null, and if V does not dominate BB, we check to +// see if V (which must be an instruction) is cheap to compute and is +// non-trapping. If both are true, the instruction is inserted into the set and +// true is returned. +static bool DominatesMergePoint(Value *V, BasicBlock *BB, + std::set *AggressiveInsts) { + Instruction *I = dyn_cast(V); + if (!I) { + // Non-instructions all dominate instructions, but not all constantexprs + // can be executed unconditionally. + if (ConstantExpr *C = dyn_cast(V)) + if (C->canTrap()) + return false; + return true; + } + BasicBlock *PBB = I->getParent(); + + // We don't want to allow weird loops that might have the "if condition" in + // the bottom of this block. + if (PBB == BB) return false; + + // If this instruction is defined in a block that contains an unconditional + // branch to BB, then it must be in the 'conditional' part of the "if + // statement". + if (BranchInst *BI = dyn_cast(PBB->getTerminator())) + if (BI->isUnconditional() && BI->getSuccessor(0) == BB) { + if (!AggressiveInsts) return false; + // Okay, it looks like the instruction IS in the "condition". Check to + // see if its a cheap instruction to unconditionally compute, and if it + // only uses stuff defined outside of the condition. If so, hoist it out. 
+ switch (I->getOpcode()) { + default: return false; // Cannot hoist this out safely. + case Instruction::Load: + // We can hoist loads that are non-volatile and obviously cannot trap. + if (cast(I)->isVolatile()) + return false; + if (!isa(I->getOperand(0)) && + !isa(I->getOperand(0))) + return false; + + // Finally, we have to check to make sure there are no instructions + // before the load in its basic block, as we are going to hoist the loop + // out to its predecessor. + if (PBB->begin() != BasicBlock::iterator(I)) + return false; + break; + case Instruction::Add: + case Instruction::Sub: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::ICmp: + case Instruction::FCmp: + break; // These are all cheap and non-trapping instructions. + } + + // Okay, we can only really hoist these out if their operands are not + // defined in the conditional region. + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (!DominatesMergePoint(I->getOperand(i), BB, 0)) + return false; + // Okay, it's safe to do this! Remember this instruction. + AggressiveInsts->insert(I); + } + + return true; +} + +// GatherConstantSetEQs - Given a potentially 'or'd together collection of +// icmp_eq instructions that compare a value against a constant, return the +// value being compared, and stick the constant into the Values vector. 
+static Value *GatherConstantSetEQs(Value *V, std::vector &Values){ + if (Instruction *Inst = dyn_cast(V)) + if (Inst->getOpcode() == Instruction::ICmp && + cast(Inst)->getPredicate() == ICmpInst::ICMP_EQ) { + if (ConstantInt *C = dyn_cast(Inst->getOperand(1))) { + Values.push_back(C); + return Inst->getOperand(0); + } else if (ConstantInt *C = dyn_cast(Inst->getOperand(0))) { + Values.push_back(C); + return Inst->getOperand(1); + } + } else if (Inst->getOpcode() == Instruction::Or) { + if (Value *LHS = GatherConstantSetEQs(Inst->getOperand(0), Values)) + if (Value *RHS = GatherConstantSetEQs(Inst->getOperand(1), Values)) + if (LHS == RHS) + return LHS; + } + return 0; +} + +// GatherConstantSetNEs - Given a potentially 'and'd together collection of +// setne instructions that compare a value against a constant, return the value +// being compared, and stick the constant into the Values vector. +static Value *GatherConstantSetNEs(Value *V, std::vector &Values){ + if (Instruction *Inst = dyn_cast(V)) + if (Inst->getOpcode() == Instruction::ICmp && + cast(Inst)->getPredicate() == ICmpInst::ICMP_NE) { + if (ConstantInt *C = dyn_cast(Inst->getOperand(1))) { + Values.push_back(C); + return Inst->getOperand(0); + } else if (ConstantInt *C = dyn_cast(Inst->getOperand(0))) { + Values.push_back(C); + return Inst->getOperand(1); + } + } else if (Inst->getOpcode() == Instruction::And) { + if (Value *LHS = GatherConstantSetNEs(Inst->getOperand(0), Values)) + if (Value *RHS = GatherConstantSetNEs(Inst->getOperand(1), Values)) + if (LHS == RHS) + return LHS; + } + return 0; +} + + + +/// GatherValueComparisons - If the specified Cond is an 'and' or 'or' of a +/// bunch of comparisons of one value against constants, return the value and +/// the constants being compared. 
+static bool GatherValueComparisons(Instruction *Cond, Value *&CompVal, + std::vector &Values) { + if (Cond->getOpcode() == Instruction::Or) { + CompVal = GatherConstantSetEQs(Cond, Values); + + // Return true to indicate that the condition is true if the CompVal is + // equal to one of the constants. + return true; + } else if (Cond->getOpcode() == Instruction::And) { + CompVal = GatherConstantSetNEs(Cond, Values); + + // Return false to indicate that the condition is false if the CompVal is + // equal to one of the constants. + return false; + } + return false; +} + +/// ErasePossiblyDeadInstructionTree - If the specified instruction is dead and +/// has no side effects, nuke it. If it uses any instructions that become dead +/// because the instruction is now gone, nuke them too. +static void ErasePossiblyDeadInstructionTree(Instruction *I) { + if (!isInstructionTriviallyDead(I)) return; + + std::vector InstrsToInspect; + InstrsToInspect.push_back(I); + + while (!InstrsToInspect.empty()) { + I = InstrsToInspect.back(); + InstrsToInspect.pop_back(); + + if (!isInstructionTriviallyDead(I)) continue; + + // If I is in the work list multiple times, remove previous instances. + for (unsigned i = 0, e = InstrsToInspect.size(); i != e; ++i) + if (InstrsToInspect[i] == I) { + InstrsToInspect.erase(InstrsToInspect.begin()+i); + --i, --e; + } + + // Add operands of dead instruction to worklist. + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (Instruction *OpI = dyn_cast(I->getOperand(i))) + InstrsToInspect.push_back(OpI); + + // Remove dead instruction. + I->eraseFromParent(); + } +} + +// isValueEqualityComparison - Return true if the specified terminator checks to +// see if a value is equal to constant integer value. +static Value *isValueEqualityComparison(TerminatorInst *TI) { + if (SwitchInst *SI = dyn_cast(TI)) { + // Do not permit merging of large switch instructions into their + // predecessors unless there is only one predecessor. 
+ if (SI->getNumSuccessors() * std::distance(pred_begin(SI->getParent()), + pred_end(SI->getParent())) > 128) + return 0; + + return SI->getCondition(); + } + if (BranchInst *BI = dyn_cast(TI)) + if (BI->isConditional() && BI->getCondition()->hasOneUse()) + if (ICmpInst *ICI = dyn_cast(BI->getCondition())) + if ((ICI->getPredicate() == ICmpInst::ICMP_EQ || + ICI->getPredicate() == ICmpInst::ICMP_NE) && + isa(ICI->getOperand(1))) + return ICI->getOperand(0); + return 0; +} + +// Given a value comparison instruction, decode all of the 'cases' that it +// represents and return the 'default' block. +static BasicBlock * +GetValueEqualityComparisonCases(TerminatorInst *TI, + std::vector > &Cases) { + if (SwitchInst *SI = dyn_cast(TI)) { + Cases.reserve(SI->getNumCases()); + for (unsigned i = 1, e = SI->getNumCases(); i != e; ++i) + Cases.push_back(std::make_pair(SI->getCaseValue(i), SI->getSuccessor(i))); + return SI->getDefaultDest(); + } + + BranchInst *BI = cast(TI); + ICmpInst *ICI = cast(BI->getCondition()); + Cases.push_back(std::make_pair(cast(ICI->getOperand(1)), + BI->getSuccessor(ICI->getPredicate() == + ICmpInst::ICMP_NE))); + return BI->getSuccessor(ICI->getPredicate() == ICmpInst::ICMP_EQ); +} + + +// EliminateBlockCases - Given a vector of bb/value pairs, remove any entries +// in the list that match the specified block. +static void EliminateBlockCases(BasicBlock *BB, + std::vector > &Cases) { + for (unsigned i = 0, e = Cases.size(); i != e; ++i) + if (Cases[i].second == BB) { + Cases.erase(Cases.begin()+i); + --i; --e; + } +} + +// ValuesOverlap - Return true if there are any keys in C1 that exist in C2 as +// well. +static bool +ValuesOverlap(std::vector > &C1, + std::vector > &C2) { + std::vector > *V1 = &C1, *V2 = &C2; + + // Make V1 be smaller than V2. + if (V1->size() > V2->size()) + std::swap(V1, V2); + + if (V1->size() == 0) return false; + if (V1->size() == 1) { + // Just scan V2. 
+ ConstantInt *TheVal = (*V1)[0].first; + for (unsigned i = 0, e = V2->size(); i != e; ++i) + if (TheVal == (*V2)[i].first) + return true; + } + + // Otherwise, just sort both lists and compare element by element. + std::sort(V1->begin(), V1->end()); + std::sort(V2->begin(), V2->end()); + unsigned i1 = 0, i2 = 0, e1 = V1->size(), e2 = V2->size(); + while (i1 != e1 && i2 != e2) { + if ((*V1)[i1].first == (*V2)[i2].first) + return true; + if ((*V1)[i1].first < (*V2)[i2].first) + ++i1; + else + ++i2; + } + return false; +} + +// SimplifyEqualityComparisonWithOnlyPredecessor - If TI is known to be a +// terminator instruction and its block is known to only have a single +// predecessor block, check to see if that predecessor is also a value +// comparison with the same value, and if that comparison determines the outcome +// of this comparison. If so, simplify TI. This does a very limited form of +// jump threading. +static bool SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, + BasicBlock *Pred) { + Value *PredVal = isValueEqualityComparison(Pred->getTerminator()); + if (!PredVal) return false; // Not a value comparison in predecessor. + + Value *ThisVal = isValueEqualityComparison(TI); + assert(ThisVal && "This isn't a value comparison!!"); + if (ThisVal != PredVal) return false; // Different predicates. + + // Find out information about when control will move from Pred to TI's block. + std::vector > PredCases; + BasicBlock *PredDef = GetValueEqualityComparisonCases(Pred->getTerminator(), + PredCases); + EliminateBlockCases(PredDef, PredCases); // Remove default from cases. + + // Find information about how control leaves this block. + std::vector > ThisCases; + BasicBlock *ThisDef = GetValueEqualityComparisonCases(TI, ThisCases); + EliminateBlockCases(ThisDef, ThisCases); // Remove default from cases. + + // If TI's block is the default block from Pred's comparison, potentially + // simplify TI based on this knowledge. 
+ if (PredDef == TI->getParent()) { + // If we are here, we know that the value is none of those cases listed in + // PredCases. If there are any cases in ThisCases that are in PredCases, we + // can simplify TI. + if (ValuesOverlap(PredCases, ThisCases)) { + if (BranchInst *BTI = dyn_cast(TI)) { + // Okay, one of the successors of this condbr is dead. Convert it to a + // uncond br. + assert(ThisCases.size() == 1 && "Branch can only have one case!"); + Value *Cond = BTI->getCondition(); + // Insert the new branch. + Instruction *NI = new BranchInst(ThisDef, TI); + + // Remove PHI node entries for the dead edge. + ThisCases[0].second->removePredecessor(TI->getParent()); + + DOUT << "Threading pred instr: " << *Pred->getTerminator() + << "Through successor TI: " << *TI << "Leaving: " << *NI << "\n"; + + TI->eraseFromParent(); // Nuke the old one. + // If condition is now dead, nuke it. + if (Instruction *CondI = dyn_cast(Cond)) + ErasePossiblyDeadInstructionTree(CondI); + return true; + + } else { + SwitchInst *SI = cast(TI); + // Okay, TI has cases that are statically dead, prune them away. + SmallPtrSet DeadCases; + for (unsigned i = 0, e = PredCases.size(); i != e; ++i) + DeadCases.insert(PredCases[i].first); + + DOUT << "Threading pred instr: " << *Pred->getTerminator() + << "Through successor TI: " << *TI; + + for (unsigned i = SI->getNumCases()-1; i != 0; --i) + if (DeadCases.count(SI->getCaseValue(i))) { + SI->getSuccessor(i)->removePredecessor(TI->getParent()); + SI->removeCase(i); + } + + DOUT << "Leaving: " << *TI << "\n"; + return true; + } + } + + } else { + // Otherwise, TI's block must correspond to some matched value. Find out + // which value (or set of values) this is. + ConstantInt *TIV = 0; + BasicBlock *TIBB = TI->getParent(); + for (unsigned i = 0, e = PredCases.size(); i != e; ++i) + if (PredCases[i].second == TIBB) + if (TIV == 0) + TIV = PredCases[i].first; + else + return false; // Cannot handle multiple values coming to this block. 
+ assert(TIV && "No edge from pred to succ?"); + + // Okay, we found the one constant that our value can be if we get into TI's + // BB. Find out which successor will unconditionally be branched to. + BasicBlock *TheRealDest = 0; + for (unsigned i = 0, e = ThisCases.size(); i != e; ++i) + if (ThisCases[i].first == TIV) { + TheRealDest = ThisCases[i].second; + break; + } + + // If not handled by any explicit cases, it is handled by the default case. + if (TheRealDest == 0) TheRealDest = ThisDef; + + // Remove PHI node entries for dead edges. + BasicBlock *CheckEdge = TheRealDest; + for (succ_iterator SI = succ_begin(TIBB), e = succ_end(TIBB); SI != e; ++SI) + if (*SI != CheckEdge) + (*SI)->removePredecessor(TIBB); + else + CheckEdge = 0; + + // Insert the new branch. + Instruction *NI = new BranchInst(TheRealDest, TI); + + DOUT << "Threading pred instr: " << *Pred->getTerminator() + << "Through successor TI: " << *TI << "Leaving: " << *NI << "\n"; + Instruction *Cond = 0; + if (BranchInst *BI = dyn_cast(TI)) + Cond = dyn_cast(BI->getCondition()); + TI->eraseFromParent(); // Nuke the old one. + + if (Cond) ErasePossiblyDeadInstructionTree(Cond); + return true; + } + return false; +} + +// FoldValueComparisonIntoPredecessors - The specified terminator is a value +// equality comparison instruction (either a switch or a branch on "X == c"). +// See if any of the predecessors of the terminator block are value comparisons +// on the same value. If so, and if safe to do so, fold them together. +static bool FoldValueComparisonIntoPredecessors(TerminatorInst *TI) { + BasicBlock *BB = TI->getParent(); + Value *CV = isValueEqualityComparison(TI); // CondVal + assert(CV && "Not a comparison?"); + bool Changed = false; + + std::vector Preds(pred_begin(BB), pred_end(BB)); + while (!Preds.empty()) { + BasicBlock *Pred = Preds.back(); + Preds.pop_back(); + + // See if the predecessor is a comparison with the same value. 
+ TerminatorInst *PTI = Pred->getTerminator(); + Value *PCV = isValueEqualityComparison(PTI); // PredCondVal + + if (PCV == CV && SafeToMergeTerminators(TI, PTI)) { + // Figure out which 'cases' to copy from SI to PSI. + std::vector > BBCases; + BasicBlock *BBDefault = GetValueEqualityComparisonCases(TI, BBCases); + + std::vector > PredCases; + BasicBlock *PredDefault = GetValueEqualityComparisonCases(PTI, PredCases); + + // Based on whether the default edge from PTI goes to BB or not, fill in + // PredCases and PredDefault with the new switch cases we would like to + // build. + std::vector NewSuccessors; + + if (PredDefault == BB) { + // If this is the default destination from PTI, only the edges in TI + // that don't occur in PTI, or that branch to BB will be activated. + std::set PTIHandled; + for (unsigned i = 0, e = PredCases.size(); i != e; ++i) + if (PredCases[i].second != BB) + PTIHandled.insert(PredCases[i].first); + else { + // The default destination is BB, we don't need explicit targets. + std::swap(PredCases[i], PredCases.back()); + PredCases.pop_back(); + --i; --e; + } + + // Reconstruct the new switch statement we will be building. + if (PredDefault != BBDefault) { + PredDefault->removePredecessor(Pred); + PredDefault = BBDefault; + NewSuccessors.push_back(BBDefault); + } + for (unsigned i = 0, e = BBCases.size(); i != e; ++i) + if (!PTIHandled.count(BBCases[i].first) && + BBCases[i].second != BBDefault) { + PredCases.push_back(BBCases[i]); + NewSuccessors.push_back(BBCases[i].second); + } + + } else { + // If this is not the default destination from PSI, only the edges + // in SI that occur in PSI with a destination of BB will be + // activated. 
+ std::set PTIHandled; + for (unsigned i = 0, e = PredCases.size(); i != e; ++i) + if (PredCases[i].second == BB) { + PTIHandled.insert(PredCases[i].first); + std::swap(PredCases[i], PredCases.back()); + PredCases.pop_back(); + --i; --e; + } + + // Okay, now we know which constants were sent to BB from the + // predecessor. Figure out where they will all go now. + for (unsigned i = 0, e = BBCases.size(); i != e; ++i) + if (PTIHandled.count(BBCases[i].first)) { + // If this is one we are capable of getting... + PredCases.push_back(BBCases[i]); + NewSuccessors.push_back(BBCases[i].second); + PTIHandled.erase(BBCases[i].first);// This constant is taken care of + } + + // If there are any constants vectored to BB that TI doesn't handle, + // they must go to the default destination of TI. + for (std::set::iterator I = PTIHandled.begin(), + E = PTIHandled.end(); I != E; ++I) { + PredCases.push_back(std::make_pair(*I, BBDefault)); + NewSuccessors.push_back(BBDefault); + } + } + + // Okay, at this point, we know which new successor Pred will get. Make + // sure we update the number of entries in the PHI nodes for these + // successors. + for (unsigned i = 0, e = NewSuccessors.size(); i != e; ++i) + AddPredecessorToBlock(NewSuccessors[i], Pred, BB); + + // Now that the successors are updated, create the new Switch instruction. + SwitchInst *NewSI = new SwitchInst(CV, PredDefault, PredCases.size(),PTI); + for (unsigned i = 0, e = PredCases.size(); i != e; ++i) + NewSI->addCase(PredCases[i].first, PredCases[i].second); + + Instruction *DeadCond = 0; + if (BranchInst *BI = dyn_cast(PTI)) + // If PTI is a branch, remember the condition. + DeadCond = dyn_cast(BI->getCondition()); + Pred->getInstList().erase(PTI); + + // If the condition is dead now, remove the instruction tree. + if (DeadCond) ErasePossiblyDeadInstructionTree(DeadCond); + + // Okay, last check. If BB is still a successor of PSI, then we must + // have an infinite loop case. 
If so, add an infinitely looping block + // to handle the case to preserve the behavior of the code. + BasicBlock *InfLoopBlock = 0; + for (unsigned i = 0, e = NewSI->getNumSuccessors(); i != e; ++i) + if (NewSI->getSuccessor(i) == BB) { + if (InfLoopBlock == 0) { + // Insert it at the end of the loop, because it's either code, + // or it won't matter if it's hot. :) + InfLoopBlock = new BasicBlock("infloop", BB->getParent()); + new BranchInst(InfLoopBlock, InfLoopBlock); + } + NewSI->setSuccessor(i, InfLoopBlock); + } + + Changed = true; + } + } + return Changed; +} + +/// HoistThenElseCodeToIf - Given a conditional branch that goes to BB1 and +/// BB2, hoist any common code in the two blocks up into the branch block. The +/// caller of this function guarantees that BI's block dominates BB1 and BB2. +static bool HoistThenElseCodeToIf(BranchInst *BI) { + // This does very trivial matching, with limited scanning, to find identical + // instructions in the two blocks. In particular, we don't want to get into + // O(M*N) situations here where M and N are the sizes of BB1 and BB2. As + // such, we currently just scan for obviously identical instructions in an + // identical order. + BasicBlock *BB1 = BI->getSuccessor(0); // The true destination. + BasicBlock *BB2 = BI->getSuccessor(1); // The false destination + + Instruction *I1 = BB1->begin(), *I2 = BB2->begin(); + if (I1->getOpcode() != I2->getOpcode() || isa(I1) || + isa(I1) || !I1->isIdenticalTo(I2)) + return false; + + // If we get here, we can hoist at least one instruction. + BasicBlock *BIParent = BI->getParent(); + + do { + // If we are hoisting the terminator instruction, don't move one (making a + // broken BB), instead clone it, and remove BI. + if (isa(I1)) + goto HoistTerminator; + + // For a normal instruction, we just move one to right before the branch, + // then replace all uses of the other with the first. Finally, we remove + // the now redundant second instruction. 
+ BIParent->getInstList().splice(BI, BB1->getInstList(), I1); + if (!I2->use_empty()) + I2->replaceAllUsesWith(I1); + BB2->getInstList().erase(I2); + + I1 = BB1->begin(); + I2 = BB2->begin(); + } while (I1->getOpcode() == I2->getOpcode() && I1->isIdenticalTo(I2)); + + return true; + +HoistTerminator: + // Okay, it is safe to hoist the terminator. + Instruction *NT = I1->clone(); + BIParent->getInstList().insert(BI, NT); + if (NT->getType() != Type::VoidTy) { + I1->replaceAllUsesWith(NT); + I2->replaceAllUsesWith(NT); + NT->takeName(I1); + } + + // Hoisting one of the terminators from our successor is a great thing. + // Unfortunately, the successors of the if/else blocks may have PHI nodes in + // them. If they do, all PHI entries for BB1/BB2 must agree for all PHI + // nodes, so we insert select instruction to compute the final result. + std::map, SelectInst*> InsertedSelects; + for (succ_iterator SI = succ_begin(BB1), E = succ_end(BB1); SI != E; ++SI) { + PHINode *PN; + for (BasicBlock::iterator BBI = SI->begin(); + (PN = dyn_cast(BBI)); ++BBI) { + Value *BB1V = PN->getIncomingValueForBlock(BB1); + Value *BB2V = PN->getIncomingValueForBlock(BB2); + if (BB1V != BB2V) { + // These values do not agree. Insert a select instruction before NT + // that determines the right value. + SelectInst *&SI = InsertedSelects[std::make_pair(BB1V, BB2V)]; + if (SI == 0) + SI = new SelectInst(BI->getCondition(), BB1V, BB2V, + BB1V->getName()+"."+BB2V->getName(), NT); + // Make the PHI node use the select for all incoming values for BB1/BB2 + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingBlock(i) == BB1 || PN->getIncomingBlock(i) == BB2) + PN->setIncomingValue(i, SI); + } + } + } + + // Update any PHI nodes in our new successors. 
+ for (succ_iterator SI = succ_begin(BB1), E = succ_end(BB1); SI != E; ++SI) + AddPredecessorToBlock(*SI, BIParent, BB1); + + BI->eraseFromParent(); + return true; +} + +/// BlockIsSimpleEnoughToThreadThrough - Return true if we can thread a branch +/// across this block. +static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { + BranchInst *BI = cast(BB->getTerminator()); + unsigned Size = 0; + + // If this basic block contains anything other than a PHI (which controls the + // branch) and branch itself, bail out. FIXME: improve this in the future. + for (BasicBlock::iterator BBI = BB->begin(); &*BBI != BI; ++BBI, ++Size) { + if (Size > 10) return false; // Don't clone large BB's. + + // We can only support instructions that are do not define values that are + // live outside of the current basic block. + for (Value::use_iterator UI = BBI->use_begin(), E = BBI->use_end(); + UI != E; ++UI) { + Instruction *U = cast(*UI); + if (U->getParent() != BB || isa(U)) return false; + } + + // Looks ok, continue checking. + } + + return true; +} + +/// FoldCondBranchOnPHI - If we have a conditional branch on a PHI node value +/// that is defined in the same block as the branch and if any PHI entries are +/// constants, thread edges corresponding to that entry to be branches to their +/// ultimate destination. +static bool FoldCondBranchOnPHI(BranchInst *BI) { + BasicBlock *BB = BI->getParent(); + PHINode *PN = dyn_cast(BI->getCondition()); + // NOTE: we currently cannot transform this case if the PHI node is used + // outside of the block. + if (!PN || PN->getParent() != BB || !PN->hasOneUse()) + return false; + + // Degenerate case of a single entry PHI. + if (PN->getNumIncomingValues() == 1) { + if (PN->getIncomingValue(0) != PN) + PN->replaceAllUsesWith(PN->getIncomingValue(0)); + else + PN->replaceAllUsesWith(UndefValue::get(PN->getType())); + PN->eraseFromParent(); + return true; + } + + // Now we know that this block has multiple preds and two succs. 
+ if (!BlockIsSimpleEnoughToThreadThrough(BB)) return false; + + // Okay, this is a simple enough basic block. See if any phi values are + // constants. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + ConstantInt *CB; + if ((CB = dyn_cast(PN->getIncomingValue(i))) && + CB->getType() == Type::Int1Ty) { + // Okay, we now know that all edges from PredBB should be revectored to + // branch to RealDest. + BasicBlock *PredBB = PN->getIncomingBlock(i); + BasicBlock *RealDest = BI->getSuccessor(!CB->getZExtValue()); + + if (RealDest == BB) continue; // Skip self loops. + + // The dest block might have PHI nodes, other predecessors and other + // difficult cases. Instead of being smart about this, just insert a new + // block that jumps to the destination block, effectively splitting + // the edge we are about to create. + BasicBlock *EdgeBB = new BasicBlock(RealDest->getName()+".critedge", + RealDest->getParent(), RealDest); + new BranchInst(RealDest, EdgeBB); + PHINode *PN; + for (BasicBlock::iterator BBI = RealDest->begin(); + (PN = dyn_cast(BBI)); ++BBI) { + Value *V = PN->getIncomingValueForBlock(BB); + PN->addIncoming(V, EdgeBB); + } + + // BB may have instructions that are being threaded over. Clone these + // instructions into EdgeBB. We know that there will be no uses of the + // cloned instructions outside of EdgeBB. + BasicBlock::iterator InsertPt = EdgeBB->begin(); + std::map TranslateMap; // Track translated values. + for (BasicBlock::iterator BBI = BB->begin(); &*BBI != BI; ++BBI) { + if (PHINode *PN = dyn_cast(BBI)) { + TranslateMap[PN] = PN->getIncomingValueForBlock(PredBB); + } else { + // Clone the instruction. + Instruction *N = BBI->clone(); + if (BBI->hasName()) N->setName(BBI->getName()+".c"); + + // Update operands due to translation. 
+ for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) { + std::map::iterator PI = + TranslateMap.find(N->getOperand(i)); + if (PI != TranslateMap.end()) + N->setOperand(i, PI->second); + } + + // Check for trivial simplification. + if (Constant *C = ConstantFoldInstruction(N)) { + TranslateMap[BBI] = C; + delete N; // Constant folded away, don't need actual inst + } else { + // Insert the new instruction into its new home. + EdgeBB->getInstList().insert(InsertPt, N); + if (!BBI->use_empty()) + TranslateMap[BBI] = N; + } + } + } + + // Loop over all of the edges from PredBB to BB, changing them to branch + // to EdgeBB instead. + TerminatorInst *PredBBTI = PredBB->getTerminator(); + for (unsigned i = 0, e = PredBBTI->getNumSuccessors(); i != e; ++i) + if (PredBBTI->getSuccessor(i) == BB) { + BB->removePredecessor(PredBB); + PredBBTI->setSuccessor(i, EdgeBB); + } + + // Recurse, simplifying any other constants. + return FoldCondBranchOnPHI(BI) | true; + } + } + + return false; +} + +/// FoldTwoEntryPHINode - Given a BB that starts with the specified two-entry +/// PHI node, see if we can eliminate it. +static bool FoldTwoEntryPHINode(PHINode *PN) { + // Ok, this is a two entry PHI node. Check to see if this is a simple "if + // statement", which has a very simple dominance structure. Basically, we + // are trying to find the condition that is being branched on, which + // subsequently causes this merge to happen. We really want control + // dependence information for this check, but simplifycfg can't keep it up + // to date, and this catches most of the cases we care about anyway. + // + BasicBlock *BB = PN->getParent(); + BasicBlock *IfTrue, *IfFalse; + Value *IfCond = GetIfCondition(BB, IfTrue, IfFalse); + if (!IfCond) return false; + + // Okay, we found that we can merge this two-entry phi node into a select. + // Doing so would require us to fold *all* two entry phi nodes in this block. 
+ // At some point this becomes non-profitable (particularly if the target + // doesn't support cmov's). Only do this transformation if there are two or + // fewer PHI nodes in this block. + unsigned NumPhis = 0; + for (BasicBlock::iterator I = BB->begin(); isa(I); ++NumPhis, ++I) + if (NumPhis > 2) + return false; + + DOUT << "FOUND IF CONDITION! " << *IfCond << " T: " + << IfTrue->getName() << " F: " << IfFalse->getName() << "\n"; + + // Loop over the PHI's seeing if we can promote them all to select + // instructions. While we are at it, keep track of the instructions + // that need to be moved to the dominating block. + std::set AggressiveInsts; + + BasicBlock::iterator AfterPHIIt = BB->begin(); + while (isa(AfterPHIIt)) { + PHINode *PN = cast(AfterPHIIt++); + if (PN->getIncomingValue(0) == PN->getIncomingValue(1)) { + if (PN->getIncomingValue(0) != PN) + PN->replaceAllUsesWith(PN->getIncomingValue(0)); + else + PN->replaceAllUsesWith(UndefValue::get(PN->getType())); + } else if (!DominatesMergePoint(PN->getIncomingValue(0), BB, + &AggressiveInsts) || + !DominatesMergePoint(PN->getIncomingValue(1), BB, + &AggressiveInsts)) { + return false; + } + } + + // If we all PHI nodes are promotable, check to make sure that all + // instructions in the predecessor blocks can be promoted as well. If + // not, we won't be able to get rid of the control flow, so it's not + // worth promoting to select instructions. + BasicBlock *DomBlock = 0, *IfBlock1 = 0, *IfBlock2 = 0; + PN = cast(BB->begin()); + BasicBlock *Pred = PN->getIncomingBlock(0); + if (cast(Pred->getTerminator())->isUnconditional()) { + IfBlock1 = Pred; + DomBlock = *pred_begin(Pred); + for (BasicBlock::iterator I = Pred->begin(); + !isa(I); ++I) + if (!AggressiveInsts.count(I)) { + // This is not an aggressive instruction that we can promote. + // Because of this, we won't be able to get rid of the control + // flow, so the xform is not worth it. 
+ return false; + } + } + + Pred = PN->getIncomingBlock(1); + if (cast(Pred->getTerminator())->isUnconditional()) { + IfBlock2 = Pred; + DomBlock = *pred_begin(Pred); + for (BasicBlock::iterator I = Pred->begin(); + !isa(I); ++I) + if (!AggressiveInsts.count(I)) { + // This is not an aggressive instruction that we can promote. + // Because of this, we won't be able to get rid of the control + // flow, so the xform is not worth it. + return false; + } + } + + // If we can still promote the PHI nodes after this gauntlet of tests, + // do all of the PHI's now. + + // Move all 'aggressive' instructions, which are defined in the + // conditional parts of the if's up to the dominating block. + if (IfBlock1) { + DomBlock->getInstList().splice(DomBlock->getTerminator(), + IfBlock1->getInstList(), + IfBlock1->begin(), + IfBlock1->getTerminator()); + } + if (IfBlock2) { + DomBlock->getInstList().splice(DomBlock->getTerminator(), + IfBlock2->getInstList(), + IfBlock2->begin(), + IfBlock2->getTerminator()); + } + + while (PHINode *PN = dyn_cast(BB->begin())) { + // Change the PHI node into a select instruction. + Value *TrueVal = + PN->getIncomingValue(PN->getIncomingBlock(0) == IfFalse); + Value *FalseVal = + PN->getIncomingValue(PN->getIncomingBlock(0) == IfTrue); + + Value *NV = new SelectInst(IfCond, TrueVal, FalseVal, "", AfterPHIIt); + PN->replaceAllUsesWith(NV); + NV->takeName(PN); + + BB->getInstList().erase(PN); + } + return true; +} + +namespace { + /// ConstantIntOrdering - This class implements a stable ordering of constant + /// integers that does not depend on their address. This is important for + /// applications that sort ConstantInt's to ensure uniqueness. + struct ConstantIntOrdering { + bool operator()(const ConstantInt *LHS, const ConstantInt *RHS) const { + return LHS->getValue().ult(RHS->getValue()); + } + }; +} + +// SimplifyCFG - This function is used to do simplification of a CFG. 
For +// example, it adjusts branches to branches to eliminate the extra hop, it +// eliminates unreachable basic blocks, and does other "peephole" optimization +// of the CFG. It returns true if a modification was made. +// +// WARNING: The entry node of a function may not be simplified. +// +bool llvm::SimplifyCFG(BasicBlock *BB) { + bool Changed = false; + Function *M = BB->getParent(); + + assert(BB && BB->getParent() && "Block not embedded in function!"); + assert(BB->getTerminator() && "Degenerate basic block encountered!"); + assert(&BB->getParent()->getEntryBlock() != BB && + "Can't Simplify entry block!"); + + // Remove basic blocks that have no predecessors... which are unreachable. + if (pred_begin(BB) == pred_end(BB) || + *pred_begin(BB) == BB && ++pred_begin(BB) == pred_end(BB)) { + DOUT << "Removing BB: \n" << *BB; + + // Loop through all of our successors and make sure they know that one + // of their predecessors is going away. + for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; ++SI) + SI->removePredecessor(BB); + + while (!BB->empty()) { + Instruction &I = BB->back(); + // If this instruction is used, replace uses with an arbitrary + // value. Because control flow can't get here, we don't care + // what we replace the value with. Note that since this block is + // unreachable, and all values contained within it must dominate their + // uses, that all uses will eventually be removed. + if (!I.use_empty()) + // Make all users of this instruction use undef instead + I.replaceAllUsesWith(UndefValue::get(I.getType())); + + // Remove the instruction from the basic block + BB->getInstList().pop_back(); + } + M->getBasicBlockList().erase(BB); + return true; + } + + // Check to see if we can constant propagate this terminator instruction + // away... + Changed |= ConstantFoldTerminator(BB); + + // If this is a returning block with only PHI nodes in it, fold the return + // instruction into any unconditional branch predecessors. 
+ // + // If any predecessor is a conditional branch that just selects among + // different return values, fold the replace the branch/return with a select + // and return. + if (ReturnInst *RI = dyn_cast(BB->getTerminator())) { + BasicBlock::iterator BBI = BB->getTerminator(); + if (BBI == BB->begin() || isa(--BBI)) { + // Find predecessors that end with branches. + std::vector UncondBranchPreds; + std::vector CondBranchPreds; + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { + TerminatorInst *PTI = (*PI)->getTerminator(); + if (BranchInst *BI = dyn_cast(PTI)) + if (BI->isUnconditional()) + UncondBranchPreds.push_back(*PI); + else + CondBranchPreds.push_back(BI); + } + + // If we found some, do the transformation! + if (!UncondBranchPreds.empty()) { + while (!UncondBranchPreds.empty()) { + BasicBlock *Pred = UncondBranchPreds.back(); + DOUT << "FOLDING: " << *BB + << "INTO UNCOND BRANCH PRED: " << *Pred; + UncondBranchPreds.pop_back(); + Instruction *UncondBranch = Pred->getTerminator(); + // Clone the return and add it to the end of the predecessor. + Instruction *NewRet = RI->clone(); + Pred->getInstList().push_back(NewRet); + + // If the return instruction returns a value, and if the value was a + // PHI node in "BB", propagate the right value into the return. + if (NewRet->getNumOperands() == 1) + if (PHINode *PN = dyn_cast(NewRet->getOperand(0))) + if (PN->getParent() == BB) + NewRet->setOperand(0, PN->getIncomingValueForBlock(Pred)); + // Update any PHI nodes in the returning block to realize that we no + // longer branch to them. + BB->removePredecessor(Pred); + Pred->getInstList().erase(UncondBranch); + } + + // If we eliminated all predecessors of the block, delete the block now. + if (pred_begin(BB) == pred_end(BB)) + // We know there are no successors, so just nuke the block. + M->getBasicBlockList().erase(BB); + + return true; + } + + // Check out all of the conditional branches going to this return + // instruction. 
If any of them just select between returns, change the + // branch itself into a select/return pair. + while (!CondBranchPreds.empty()) { + BranchInst *BI = CondBranchPreds.back(); + CondBranchPreds.pop_back(); + BasicBlock *TrueSucc = BI->getSuccessor(0); + BasicBlock *FalseSucc = BI->getSuccessor(1); + BasicBlock *OtherSucc = TrueSucc == BB ? FalseSucc : TrueSucc; + + // Check to see if the non-BB successor is also a return block. + if (isa(OtherSucc->getTerminator())) { + // Check to see if there are only PHI instructions in this block. + BasicBlock::iterator OSI = OtherSucc->getTerminator(); + if (OSI == OtherSucc->begin() || isa(--OSI)) { + // Okay, we found a branch that is going to two return nodes. If + // there is no return value for this function, just change the + // branch into a return. + if (RI->getNumOperands() == 0) { + TrueSucc->removePredecessor(BI->getParent()); + FalseSucc->removePredecessor(BI->getParent()); + new ReturnInst(0, BI); + BI->getParent()->getInstList().erase(BI); + return true; + } + + // Otherwise, figure out what the true and false return values are + // so we can insert a new select instruction. + Value *TrueValue = TrueSucc->getTerminator()->getOperand(0); + Value *FalseValue = FalseSucc->getTerminator()->getOperand(0); + + // Unwrap any PHI nodes in the return blocks. + if (PHINode *TVPN = dyn_cast(TrueValue)) + if (TVPN->getParent() == TrueSucc) + TrueValue = TVPN->getIncomingValueForBlock(BI->getParent()); + if (PHINode *FVPN = dyn_cast(FalseValue)) + if (FVPN->getParent() == FalseSucc) + FalseValue = FVPN->getIncomingValueForBlock(BI->getParent()); + + // In order for this transformation to be safe, we must be able to + // unconditionally execute both operands to the return. This is + // normally the case, but we could have a potentially-trapping + // constant expression that prevents this transformation from being + // safe. 
+ if ((!isa(TrueValue) || + !cast(TrueValue)->canTrap()) && + (!isa(TrueValue) || + !cast(TrueValue)->canTrap())) { + TrueSucc->removePredecessor(BI->getParent()); + FalseSucc->removePredecessor(BI->getParent()); + + // Insert a new select instruction. + Value *NewRetVal; + Value *BrCond = BI->getCondition(); + if (TrueValue != FalseValue) + NewRetVal = new SelectInst(BrCond, TrueValue, + FalseValue, "retval", BI); + else + NewRetVal = TrueValue; + + DOUT << "\nCHANGING BRANCH TO TWO RETURNS INTO SELECT:" + << "\n " << *BI << "Select = " << *NewRetVal + << "TRUEBLOCK: " << *TrueSucc << "FALSEBLOCK: "<< *FalseSucc; + + new ReturnInst(NewRetVal, BI); + BI->eraseFromParent(); + if (Instruction *BrCondI = dyn_cast(BrCond)) + if (isInstructionTriviallyDead(BrCondI)) + BrCondI->eraseFromParent(); + return true; + } + } + } + } + } + } else if (isa(BB->begin())) { + // Check to see if the first instruction in this block is just an unwind. + // If so, replace any invoke instructions which use this as an exception + // destination with call instructions, and any unconditional branch + // predecessor with an unwind. + // + std::vector Preds(pred_begin(BB), pred_end(BB)); + while (!Preds.empty()) { + BasicBlock *Pred = Preds.back(); + if (BranchInst *BI = dyn_cast(Pred->getTerminator())) { + if (BI->isUnconditional()) { + Pred->getInstList().pop_back(); // nuke uncond branch + new UnwindInst(Pred); // Use unwind. + Changed = true; + } + } else if (InvokeInst *II = dyn_cast(Pred->getTerminator())) + if (II->getUnwindDest() == BB) { + // Insert a new branch instruction before the invoke, because this + // is now a fall through... + BranchInst *BI = new BranchInst(II->getNormalDest(), II); + Pred->getInstList().remove(II); // Take out of symbol table + + // Insert the call now... 
+ SmallVector Args(II->op_begin()+3, II->op_end()); + CallInst *CI = new CallInst(II->getCalledValue(), + &Args[0], Args.size(), II->getName(), BI); + CI->setCallingConv(II->getCallingConv()); + // If the invoke produced a value, the Call now does instead + II->replaceAllUsesWith(CI); + delete II; + Changed = true; + } + + Preds.pop_back(); + } + + // If this block is now dead, remove it. + if (pred_begin(BB) == pred_end(BB)) { + // We know there are no successors, so just nuke the block. + M->getBasicBlockList().erase(BB); + return true; + } + + } else if (SwitchInst *SI = dyn_cast(BB->getTerminator())) { + if (isValueEqualityComparison(SI)) { + // If we only have one predecessor, and if it is a branch on this value, + // see if that predecessor totally determines the outcome of this switch. + if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) + if (SimplifyEqualityComparisonWithOnlyPredecessor(SI, OnlyPred)) + return SimplifyCFG(BB) || 1; + + // If the block only contains the switch, see if we can fold the block + // away into any preds. + if (SI == &BB->front()) + if (FoldValueComparisonIntoPredecessors(SI)) + return SimplifyCFG(BB) || 1; + } + } else if (BranchInst *BI = dyn_cast(BB->getTerminator())) { + if (BI->isUnconditional()) { + BasicBlock::iterator BBI = BB->begin(); // Skip over phi nodes... + while (isa(*BBI)) ++BBI; + + BasicBlock *Succ = BI->getSuccessor(0); + if (BBI->isTerminator() && // Terminator is the only non-phi instruction! + Succ != BB) // Don't hurt infinite loops! + if (TryToSimplifyUncondBranchFromEmptyBlock(BB, Succ)) + return 1; + + } else { // Conditional branch + if (isValueEqualityComparison(BI)) { + // If we only have one predecessor, and if it is a branch on this value, + // see if that predecessor totally determines the outcome of this + // switch. 
+ if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) + if (SimplifyEqualityComparisonWithOnlyPredecessor(BI, OnlyPred)) + return SimplifyCFG(BB) || 1; + + // This block must be empty, except for the setcond inst, if it exists. + BasicBlock::iterator I = BB->begin(); + if (&*I == BI || + (&*I == cast(BI->getCondition()) && + &*++I == BI)) + if (FoldValueComparisonIntoPredecessors(BI)) + return SimplifyCFG(BB) | true; + } + + // If this is a branch on a phi node in the current block, thread control + // through this block if any PHI node entries are constants. + if (PHINode *PN = dyn_cast(BI->getCondition())) + if (PN->getParent() == BI->getParent()) + if (FoldCondBranchOnPHI(BI)) + return SimplifyCFG(BB) | true; + + // If this basic block is ONLY a setcc and a branch, and if a predecessor + // branches to us and one of our successors, fold the setcc into the + // predecessor and use logical operations to pick the right destination. + BasicBlock *TrueDest = BI->getSuccessor(0); + BasicBlock *FalseDest = BI->getSuccessor(1); + if (Instruction *Cond = dyn_cast(BI->getCondition())) { + BasicBlock::iterator CondIt = Cond; + if ((isa(Cond) || isa(Cond)) && + Cond->getParent() == BB && &BB->front() == Cond && + &*++CondIt == BI && Cond->hasOneUse() && + TrueDest != BB && FalseDest != BB) + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI!=E; ++PI) + if (BranchInst *PBI = dyn_cast((*PI)->getTerminator())) + if (PBI->isConditional() && SafeToMergeTerminators(BI, PBI)) { + BasicBlock *PredBlock = *PI; + if (PBI->getSuccessor(0) == FalseDest || + PBI->getSuccessor(1) == TrueDest) { + // Invert the predecessors condition test (xor it with true), + // which allows us to write this code once. 
+ Value *NewCond = + BinaryOperator::createNot(PBI->getCondition(), + PBI->getCondition()->getName()+".not", PBI); + PBI->setCondition(NewCond); + BasicBlock *OldTrue = PBI->getSuccessor(0); + BasicBlock *OldFalse = PBI->getSuccessor(1); + PBI->setSuccessor(0, OldFalse); + PBI->setSuccessor(1, OldTrue); + } + + if ((PBI->getSuccessor(0) == TrueDest && FalseDest != BB) || + (PBI->getSuccessor(1) == FalseDest && TrueDest != BB)) { + // Clone Cond into the predecessor basic block, and or/and the + // two conditions together. + Instruction *New = Cond->clone(); + PredBlock->getInstList().insert(PBI, New); + New->takeName(Cond); + Cond->setName(New->getName()+".old"); + Instruction::BinaryOps Opcode = + PBI->getSuccessor(0) == TrueDest ? + Instruction::Or : Instruction::And; + Value *NewCond = + BinaryOperator::create(Opcode, PBI->getCondition(), + New, "bothcond", PBI); + PBI->setCondition(NewCond); + if (PBI->getSuccessor(0) == BB) { + AddPredecessorToBlock(TrueDest, PredBlock, BB); + PBI->setSuccessor(0, TrueDest); + } + if (PBI->getSuccessor(1) == BB) { + AddPredecessorToBlock(FalseDest, PredBlock, BB); + PBI->setSuccessor(1, FalseDest); + } + return SimplifyCFG(BB) | 1; + } + } + } + + // Scan predessor blocks for conditional branches. + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) + if (BranchInst *PBI = dyn_cast((*PI)->getTerminator())) + if (PBI != BI && PBI->isConditional()) { + + // If this block ends with a branch instruction, and if there is a + // predecessor that ends on a branch of the same condition, make + // this conditional branch redundant. + if (PBI->getCondition() == BI->getCondition() && + PBI->getSuccessor(0) != PBI->getSuccessor(1)) { + // Okay, the outcome of this conditional branch is statically + // knowable. If this block had a single pred, handle specially. + if (BB->getSinglePredecessor()) { + // Turn this into a branch on constant. 
+ bool CondIsTrue = PBI->getSuccessor(0) == BB; + BI->setCondition(ConstantInt::get(Type::Int1Ty, CondIsTrue)); + return SimplifyCFG(BB); // Nuke the branch on constant. + } + + // Otherwise, if there are multiple predecessors, insert a PHI + // that merges in the constant and simplify the block result. + if (BlockIsSimpleEnoughToThreadThrough(BB)) { + PHINode *NewPN = new PHINode(Type::Int1Ty, + BI->getCondition()->getName()+".pr", + BB->begin()); + for (PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) + if ((PBI = dyn_cast((*PI)->getTerminator())) && + PBI != BI && PBI->isConditional() && + PBI->getCondition() == BI->getCondition() && + PBI->getSuccessor(0) != PBI->getSuccessor(1)) { + bool CondIsTrue = PBI->getSuccessor(0) == BB; + NewPN->addIncoming(ConstantInt::get(Type::Int1Ty, + CondIsTrue), *PI); + } else { + NewPN->addIncoming(BI->getCondition(), *PI); + } + + BI->setCondition(NewPN); + // This will thread the branch. + return SimplifyCFG(BB) | true; + } + } + + // If this is a conditional branch in an empty block, and if any + // predecessors is a conditional branch to one of our destinations, + // fold the conditions into logical ops and one cond br. + if (&BB->front() == BI) { + int PBIOp, BIOp; + if (PBI->getSuccessor(0) == BI->getSuccessor(0)) { + PBIOp = BIOp = 0; + } else if (PBI->getSuccessor(0) == BI->getSuccessor(1)) { + PBIOp = 0; BIOp = 1; + } else if (PBI->getSuccessor(1) == BI->getSuccessor(0)) { + PBIOp = 1; BIOp = 0; + } else if (PBI->getSuccessor(1) == BI->getSuccessor(1)) { + PBIOp = BIOp = 1; + } else { + PBIOp = BIOp = -1; + } + + // Check to make sure that the other destination of this branch + // isn't BB itself. If so, this is an infinite loop that will + // keep getting unwound. + if (PBIOp != -1 && PBI->getSuccessor(PBIOp) == BB) + PBIOp = BIOp = -1; + + // Do not perform this transformation if it would require + // insertion of a large number of select instructions. 
For targets + // without predication/cmovs, this is a big pessimization. + if (PBIOp != -1) { + BasicBlock *CommonDest = PBI->getSuccessor(PBIOp); + + unsigned NumPhis = 0; + for (BasicBlock::iterator II = CommonDest->begin(); + isa(II); ++II, ++NumPhis) { + if (NumPhis > 2) { + // Disable this xform. + PBIOp = -1; + break; + } + } + } + + // Finally, if everything is ok, fold the branches to logical ops. + if (PBIOp != -1) { + BasicBlock *CommonDest = PBI->getSuccessor(PBIOp); + BasicBlock *OtherDest = BI->getSuccessor(BIOp ^ 1); + + // If OtherDest *is* BB, then this is a basic block with just + // a conditional branch in it, where one edge (OtherDesg) goes + // back to the block. We know that the program doesn't get + // stuck in the infinite loop, so the condition must be such + // that OtherDest isn't branched through. Forward to CommonDest, + // and avoid an infinite loop at optimizer time. + if (OtherDest == BB) + OtherDest = CommonDest; + + DOUT << "FOLDING BRs:" << *PBI->getParent() + << "AND: " << *BI->getParent(); + + // BI may have other predecessors. Because of this, we leave + // it alone, but modify PBI. + + // Make sure we get to CommonDest on True&True directions. + Value *PBICond = PBI->getCondition(); + if (PBIOp) + PBICond = BinaryOperator::createNot(PBICond, + PBICond->getName()+".not", + PBI); + Value *BICond = BI->getCondition(); + if (BIOp) + BICond = BinaryOperator::createNot(BICond, + BICond->getName()+".not", + PBI); + // Merge the conditions. + Value *Cond = + BinaryOperator::createOr(PBICond, BICond, "brmerge", PBI); + + // Modify PBI to branch on the new condition to the new dests. + PBI->setCondition(Cond); + PBI->setSuccessor(0, CommonDest); + PBI->setSuccessor(1, OtherDest); + + // OtherDest may have phi nodes. If so, add an entry from PBI's + // block that are identical to the entries for BI's block. 
+ PHINode *PN; + for (BasicBlock::iterator II = OtherDest->begin(); + (PN = dyn_cast(II)); ++II) { + Value *V = PN->getIncomingValueForBlock(BB); + PN->addIncoming(V, PBI->getParent()); + } + + // We know that the CommonDest already had an edge from PBI to + // it. If it has PHIs though, the PHIs may have different + // entries for BB and PBI's BB. If so, insert a select to make + // them agree. + for (BasicBlock::iterator II = CommonDest->begin(); + (PN = dyn_cast(II)); ++II) { + Value * BIV = PN->getIncomingValueForBlock(BB); + unsigned PBBIdx = PN->getBasicBlockIndex(PBI->getParent()); + Value *PBIV = PN->getIncomingValue(PBBIdx); + if (BIV != PBIV) { + // Insert a select in PBI to pick the right value. + Value *NV = new SelectInst(PBICond, PBIV, BIV, + PBIV->getName()+".mux", PBI); + PN->setIncomingValue(PBBIdx, NV); + } + } + + DOUT << "INTO: " << *PBI->getParent(); + + // This basic block is probably dead. We know it has at least + // one fewer predecessor. + return SimplifyCFG(BB) | true; + } + } + } + } + } else if (isa(BB->getTerminator())) { + // If there are any instructions immediately before the unreachable that can + // be removed, do so. + Instruction *Unreachable = BB->getTerminator(); + while (Unreachable != BB->begin()) { + BasicBlock::iterator BBI = Unreachable; + --BBI; + if (isa(BBI)) break; + // Delete this instruction + BB->getInstList().erase(BBI); + Changed = true; + } + + // If the unreachable instruction is the first in the block, take a gander + // at all of the predecessors of this instruction, and simplify them. 
+ if (&BB->front() == Unreachable) { + std::vector Preds(pred_begin(BB), pred_end(BB)); + for (unsigned i = 0, e = Preds.size(); i != e; ++i) { + TerminatorInst *TI = Preds[i]->getTerminator(); + + if (BranchInst *BI = dyn_cast(TI)) { + if (BI->isUnconditional()) { + if (BI->getSuccessor(0) == BB) { + new UnreachableInst(TI); + TI->eraseFromParent(); + Changed = true; + } + } else { + if (BI->getSuccessor(0) == BB) { + new BranchInst(BI->getSuccessor(1), BI); + BI->eraseFromParent(); + } else if (BI->getSuccessor(1) == BB) { + new BranchInst(BI->getSuccessor(0), BI); + BI->eraseFromParent(); + Changed = true; + } + } + } else if (SwitchInst *SI = dyn_cast(TI)) { + for (unsigned i = 1, e = SI->getNumCases(); i != e; ++i) + if (SI->getSuccessor(i) == BB) { + BB->removePredecessor(SI->getParent()); + SI->removeCase(i); + --i; --e; + Changed = true; + } + // If the default value is unreachable, figure out the most popular + // destination and make it the default. + if (SI->getSuccessor(0) == BB) { + std::map Popularity; + for (unsigned i = 1, e = SI->getNumCases(); i != e; ++i) + Popularity[SI->getSuccessor(i)]++; + + // Find the most popular block. + unsigned MaxPop = 0; + BasicBlock *MaxBlock = 0; + for (std::map::iterator + I = Popularity.begin(), E = Popularity.end(); I != E; ++I) { + if (I->second > MaxPop) { + MaxPop = I->second; + MaxBlock = I->first; + } + } + if (MaxBlock) { + // Make this the new default, allowing us to delete any explicit + // edges to it. + SI->setSuccessor(0, MaxBlock); + Changed = true; + + // If MaxBlock has phinodes in it, remove MaxPop-1 entries from + // it. 
+ if (isa(MaxBlock->begin())) + for (unsigned i = 0; i != MaxPop-1; ++i) + MaxBlock->removePredecessor(SI->getParent()); + + for (unsigned i = 1, e = SI->getNumCases(); i != e; ++i) + if (SI->getSuccessor(i) == MaxBlock) { + SI->removeCase(i); + --i; --e; + } + } + } + } else if (InvokeInst *II = dyn_cast(TI)) { + if (II->getUnwindDest() == BB) { + // Convert the invoke to a call instruction. This would be a good + // place to note that the call does not throw though. + BranchInst *BI = new BranchInst(II->getNormalDest(), II); + II->removeFromParent(); // Take out of symbol table + + // Insert the call now... + SmallVector Args(II->op_begin()+3, II->op_end()); + CallInst *CI = new CallInst(II->getCalledValue(), + &Args[0], Args.size(), + II->getName(), BI); + CI->setCallingConv(II->getCallingConv()); + // If the invoke produced a value, the Call does now instead. + II->replaceAllUsesWith(CI); + delete II; + Changed = true; + } + } + } + + // If this block is now dead, remove it. + if (pred_begin(BB) == pred_end(BB)) { + // We know there are no successors, so just nuke the block. + M->getBasicBlockList().erase(BB); + return true; + } + } + } + + // Merge basic blocks into their predecessor if there is only one distinct + // pred, and if there is only one distinct successor of the predecessor, and + // if there are no PHI nodes. + // + pred_iterator PI(pred_begin(BB)), PE(pred_end(BB)); + BasicBlock *OnlyPred = *PI++; + for (; PI != PE; ++PI) // Search all predecessors, see if they are all same + if (*PI != OnlyPred) { + OnlyPred = 0; // There are multiple different predecessors... + break; + } + + BasicBlock *OnlySucc = 0; + if (OnlyPred && OnlyPred != BB && // Don't break self loops + OnlyPred->getTerminator()->getOpcode() != Instruction::Invoke) { + // Check to see if there is only one distinct successor... 
+ succ_iterator SI(succ_begin(OnlyPred)), SE(succ_end(OnlyPred)); + OnlySucc = BB; + for (; SI != SE; ++SI) + if (*SI != OnlySucc) { + OnlySucc = 0; // There are multiple distinct successors! + break; + } + } + + if (OnlySucc) { + DOUT << "Merging: " << *BB << "into: " << *OnlyPred; + + // Resolve any PHI nodes at the start of the block. They are all + // guaranteed to have exactly one entry if they exist, unless there are + // multiple duplicate (but guaranteed to be equal) entries for the + // incoming edges. This occurs when there are multiple edges from + // OnlyPred to OnlySucc. + // + while (PHINode *PN = dyn_cast(&BB->front())) { + PN->replaceAllUsesWith(PN->getIncomingValue(0)); + BB->getInstList().pop_front(); // Delete the phi node. + } + + // Delete the unconditional branch from the predecessor. + OnlyPred->getInstList().pop_back(); + + // Move all definitions in the successor to the predecessor. + OnlyPred->getInstList().splice(OnlyPred->end(), BB->getInstList()); + + // Make all PHI nodes that referred to BB now refer to Pred as their + // source. + BB->replaceAllUsesWith(OnlyPred); + + // Inherit predecessors name if it exists. + if (!OnlyPred->hasName()) + OnlyPred->takeName(BB); + + // Erase basic block from the function. + M->getBasicBlockList().erase(BB); + + return true; + } + + // Otherwise, if this block only has a single predecessor, and if that block + // is a conditional branch, see if we can hoist any code from this block up + // into our predecessor. + if (OnlyPred) + if (BranchInst *BI = dyn_cast(OnlyPred->getTerminator())) + if (BI->isConditional()) { + // Get the other block. + BasicBlock *OtherBB = BI->getSuccessor(BI->getSuccessor(0) == BB); + PI = pred_begin(OtherBB); + ++PI; + if (PI == pred_end(OtherBB)) { + // We have a conditional branch to two blocks that are only reachable + // from the condbr. We know that the condbr dominates the two blocks, + // so see if there is any identical code in the "then" and "else" + // blocks. 
If so, we can hoist it up to the branching block. + Changed |= HoistThenElseCodeToIf(BI); + } + } + + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) + if (BranchInst *BI = dyn_cast((*PI)->getTerminator())) + // Change br (X == 0 | X == 1), T, F into a switch instruction. + if (BI->isConditional() && isa(BI->getCondition())) { + Instruction *Cond = cast(BI->getCondition()); + // If this is a bunch of seteq's or'd together, or if it's a bunch of + // 'setne's and'ed together, collect them. + Value *CompVal = 0; + std::vector Values; + bool TrueWhenEqual = GatherValueComparisons(Cond, CompVal, Values); + if (CompVal && CompVal->getType()->isInteger()) { + // There might be duplicate constants in the list, which the switch + // instruction can't handle, remove them now. + std::sort(Values.begin(), Values.end(), ConstantIntOrdering()); + Values.erase(std::unique(Values.begin(), Values.end()), Values.end()); + + // Figure out which block is which destination. + BasicBlock *DefaultBB = BI->getSuccessor(1); + BasicBlock *EdgeBB = BI->getSuccessor(0); + if (!TrueWhenEqual) std::swap(DefaultBB, EdgeBB); + + // Create the new switch instruction now. + SwitchInst *New = new SwitchInst(CompVal, DefaultBB,Values.size(),BI); + + // Add all of the 'cases' to the switch instruction. + for (unsigned i = 0, e = Values.size(); i != e; ++i) + New->addCase(Values[i], EdgeBB); + + // We added edges from PI to the EdgeBB. As such, if there were any + // PHI nodes in EdgeBB, they need entries to be added corresponding to + // the number of edges added. + for (BasicBlock::iterator BBI = EdgeBB->begin(); + isa(BBI); ++BBI) { + PHINode *PN = cast(BBI); + Value *InVal = PN->getIncomingValueForBlock(*PI); + for (unsigned i = 0, e = Values.size()-1; i != e; ++i) + PN->addIncoming(InVal, *PI); + } + + // Erase the old branch instruction. + (*PI)->getInstList().erase(BI); + + // Erase the potentially condition tree that was used to computed the + // branch condition. 
+ ErasePossiblyDeadInstructionTree(Cond); + return true; + } + } + + // If there is a trivial two-entry PHI node in this basic block, and we can + // eliminate it, do so now. + if (PHINode *PN = dyn_cast(BB->begin())) + if (PN->getNumIncomingValues() == 2) + Changed |= FoldTwoEntryPHINode(PN); + + return Changed; +} diff --git a/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp b/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp new file mode 100644 index 0000000..b545ad3 --- /dev/null +++ b/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp @@ -0,0 +1,138 @@ +//===- UnifyFunctionExitNodes.cpp - Make all functions have a single exit -===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass is used to ensure that functions have at most one return +// instruction in them. Additionally, it keeps track of which node is the new +// exit node of the CFG. If there are no exit nodes in the CFG, the getExitNode +// method will return a null pointer. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/BasicBlock.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Type.h" +using namespace llvm; + +char UnifyFunctionExitNodes::ID = 0; +static RegisterPass +X("mergereturn", "Unify function exit nodes"); + +int UnifyFunctionExitNodes::stub; + +Pass *llvm::createUnifyFunctionExitNodesPass() { + return new UnifyFunctionExitNodes(); +} + +void UnifyFunctionExitNodes::getAnalysisUsage(AnalysisUsage &AU) const{ + // We preserve the non-critical-edgeness property + AU.addPreservedID(BreakCriticalEdgesID); + // This is a cluster of orthogonal Transforms + AU.addPreservedID(PromoteMemoryToRegisterID); + AU.addPreservedID(LowerSelectID); + AU.addPreservedID(LowerSwitchID); +} + +// UnifyAllExitNodes - Unify all exit nodes of the CFG by creating a new +// BasicBlock, and converting all returns to unconditional branches to this +// new basic block. The singular exit node is returned. +// +// If there are no return stmts in the Function, a null pointer is returned. +// +bool UnifyFunctionExitNodes::runOnFunction(Function &F) { + // Loop over all of the blocks in a function, tracking all of the blocks that + // return. + // + std::vector ReturningBlocks; + std::vector UnwindingBlocks; + std::vector UnreachableBlocks; + for(Function::iterator I = F.begin(), E = F.end(); I != E; ++I) + if (isa(I->getTerminator())) + ReturningBlocks.push_back(I); + else if (isa(I->getTerminator())) + UnwindingBlocks.push_back(I); + else if (isa(I->getTerminator())) + UnreachableBlocks.push_back(I); + + // Handle unwinding blocks first. 
+ if (UnwindingBlocks.empty()) { + UnwindBlock = 0; + } else if (UnwindingBlocks.size() == 1) { + UnwindBlock = UnwindingBlocks.front(); + } else { + UnwindBlock = new BasicBlock("UnifiedUnwindBlock", &F); + new UnwindInst(UnwindBlock); + + for (std::vector::iterator I = UnwindingBlocks.begin(), + E = UnwindingBlocks.end(); I != E; ++I) { + BasicBlock *BB = *I; + BB->getInstList().pop_back(); // Remove the unwind insn + new BranchInst(UnwindBlock, BB); + } + } + + // Then unreachable blocks. + if (UnreachableBlocks.empty()) { + UnreachableBlock = 0; + } else if (UnreachableBlocks.size() == 1) { + UnreachableBlock = UnreachableBlocks.front(); + } else { + UnreachableBlock = new BasicBlock("UnifiedUnreachableBlock", &F); + new UnreachableInst(UnreachableBlock); + + for (std::vector::iterator I = UnreachableBlocks.begin(), + E = UnreachableBlocks.end(); I != E; ++I) { + BasicBlock *BB = *I; + BB->getInstList().pop_back(); // Remove the unreachable inst. + new BranchInst(UnreachableBlock, BB); + } + } + + // Now handle return blocks. + if (ReturningBlocks.empty()) { + ReturnBlock = 0; + return false; // No blocks return + } else if (ReturningBlocks.size() == 1) { + ReturnBlock = ReturningBlocks.front(); // Already has a single return block + return false; + } + + // Otherwise, we need to insert a new basic block into the function, add a PHI + // node (if the function returns a value), and convert all of the return + // instructions into unconditional branches. + // + BasicBlock *NewRetBlock = new BasicBlock("UnifiedReturnBlock", &F); + + PHINode *PN = 0; + if (F.getReturnType() != Type::VoidTy) { + // If the function doesn't return void... add a PHI node to the block... + PN = new PHINode(F.getReturnType(), "UnifiedRetVal"); + NewRetBlock->getInstList().push_back(PN); + } + new ReturnInst(PN, NewRetBlock); + + // Loop over all of the blocks, replacing the return instruction with an + // unconditional branch. 
+ // + for (std::vector::iterator I = ReturningBlocks.begin(), + E = ReturningBlocks.end(); I != E; ++I) { + BasicBlock *BB = *I; + + // Add an incoming element to the PHI node for every return instruction that + // is merging into this new block... + if (PN) PN->addIncoming(BB->getTerminator()->getOperand(0), BB); + + BB->getInstList().pop_back(); // Remove the return insn + new BranchInst(NewRetBlock, BB); + } + ReturnBlock = NewRetBlock; + return true; +} diff --git a/lib/Transforms/Utils/ValueMapper.cpp b/lib/Transforms/Utils/ValueMapper.cpp new file mode 100644 index 0000000..0b8c5c2 --- /dev/null +++ b/lib/Transforms/Utils/ValueMapper.cpp @@ -0,0 +1,118 @@ +//===- ValueMapper.cpp - Interface shared by lib/Transforms/Utils ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the MapValue function, which is shared by various parts of +// the lib/Transforms/Utils library. +// +//===----------------------------------------------------------------------===// + +#include "ValueMapper.h" +#include "llvm/Constants.h" +#include "llvm/GlobalValue.h" +#include "llvm/Instruction.h" +using namespace llvm; + +Value *llvm::MapValue(const Value *V, ValueMapTy &VM) { + Value *&VMSlot = VM[V]; + if (VMSlot) return VMSlot; // Does it exist in the map yet? + + // NOTE: VMSlot can be invalidated by any reference to VM, which can grow the + // DenseMap. This includes any recursive calls to MapValue. + + // Global values do not need to be seeded into the ValueMap if they are using + // the identity mapping. 
+ if (isa(V) || isa(V)) + return VMSlot = const_cast(V); + + if (Constant *C = const_cast(dyn_cast(V))) { + if (isa(C) || isa(C) || + isa(C) || isa(C) || + isa(C)) + return VMSlot = C; // Primitive constants map directly + else if (ConstantArray *CA = dyn_cast(C)) { + for (unsigned i = 0, e = CA->getNumOperands(); i != e; ++i) { + Value *MV = MapValue(CA->getOperand(i), VM); + if (MV != CA->getOperand(i)) { + // This array must contain a reference to a global, make a new array + // and return it. + // + std::vector Values; + Values.reserve(CA->getNumOperands()); + for (unsigned j = 0; j != i; ++j) + Values.push_back(CA->getOperand(j)); + Values.push_back(cast(MV)); + for (++i; i != e; ++i) + Values.push_back(cast(MapValue(CA->getOperand(i), VM))); + return VM[V] = ConstantArray::get(CA->getType(), Values); + } + } + return VM[V] = C; + + } else if (ConstantStruct *CS = dyn_cast(C)) { + for (unsigned i = 0, e = CS->getNumOperands(); i != e; ++i) { + Value *MV = MapValue(CS->getOperand(i), VM); + if (MV != CS->getOperand(i)) { + // This struct must contain a reference to a global, make a new struct + // and return it. 
+ // + std::vector Values; + Values.reserve(CS->getNumOperands()); + for (unsigned j = 0; j != i; ++j) + Values.push_back(CS->getOperand(j)); + Values.push_back(cast(MV)); + for (++i; i != e; ++i) + Values.push_back(cast(MapValue(CS->getOperand(i), VM))); + return VM[V] = ConstantStruct::get(CS->getType(), Values); + } + } + return VM[V] = C; + + } else if (ConstantExpr *CE = dyn_cast(C)) { + std::vector Ops; + for (unsigned i = 0, e = CE->getNumOperands(); i != e; ++i) + Ops.push_back(cast(MapValue(CE->getOperand(i), VM))); + return VM[V] = CE->getWithOperands(Ops); + } else if (ConstantVector *CP = dyn_cast(C)) { + for (unsigned i = 0, e = CP->getNumOperands(); i != e; ++i) { + Value *MV = MapValue(CP->getOperand(i), VM); + if (MV != CP->getOperand(i)) { + // This vector value must contain a reference to a global, make a new + // vector constant and return it. + // + std::vector Values; + Values.reserve(CP->getNumOperands()); + for (unsigned j = 0; j != i; ++j) + Values.push_back(CP->getOperand(j)); + Values.push_back(cast(MV)); + for (++i; i != e; ++i) + Values.push_back(cast(MapValue(CP->getOperand(i), VM))); + return VM[V] = ConstantVector::get(Values); + } + } + return VM[V] = C; + + } else { + assert(0 && "Unknown type of constant!"); + } + } + + return 0; +} + +/// RemapInstruction - Convert the instruction operands from referencing the +/// current values into those specified by ValueMap. 
+/// +void llvm::RemapInstruction(Instruction *I, ValueMapTy &ValueMap) { + for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) { + const Value *Op = I->getOperand(op); + Value *V = MapValue(Op, ValueMap); + assert(V && "Referenced value not in value map!"); + I->setOperand(op, V); + } +} diff --git a/lib/Transforms/Utils/ValueMapper.h b/lib/Transforms/Utils/ValueMapper.h new file mode 100644 index 0000000..51319db --- /dev/null +++ b/lib/Transforms/Utils/ValueMapper.h @@ -0,0 +1,29 @@ +//===- ValueMapper.h - Interface shared by lib/Transforms/Utils -*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file was developed by the LLVM research group and is distributed under +// the University of Illinois Open Source License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the MapValue interface which is used by various parts of +// the Transforms/Utils library to implement cloning and linking facilities. +// +//===----------------------------------------------------------------------===// + +#ifndef VALUEMAPPER_H +#define VALUEMAPPER_H + +#include "llvm/ADT/DenseMap.h" + +namespace llvm { + class Value; + class Instruction; + typedef DenseMap ValueMapTy; + + Value *MapValue(const Value *V, ValueMapTy &VM); + void RemapInstruction(Instruction *I, ValueMapTy &VM); +} // End llvm namespace + +#endif -- cgit v1.1