diff options
Diffstat (limited to 'lib/Transforms/IPO/MergeFunctions.cpp')
-rw-r--r-- | lib/Transforms/IPO/MergeFunctions.cpp | 350 |
1 files changed, 251 insertions, 99 deletions
diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp index cac824b..9cfbcc8 100644 --- a/lib/Transforms/IPO/MergeFunctions.cpp +++ b/lib/Transforms/IPO/MergeFunctions.cpp @@ -45,10 +45,11 @@ #define DEBUG_TYPE "mergefunc" #include "llvm/Transforms/IPO.h" -#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/FoldingSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Constants.h" #include "llvm/InlineAsm.h" #include "llvm/Instructions.h" @@ -59,47 +60,127 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/IRBuilder.h" +#include "llvm/Support/ValueHandle.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetData.h" -#include <map> #include <vector> using namespace llvm; STATISTIC(NumFunctionsMerged, "Number of functions merged"); +STATISTIC(NumThunksWritten, "Number of thunks generated"); +STATISTIC(NumDoubleWeak, "Number of new functions created"); + +/// ProfileFunction - Creates a hash-code for the function which is the same +/// for any two functions that will compare equal, without looking at the +/// instructions inside the function. +static unsigned ProfileFunction(const Function *F) { + const FunctionType *FTy = F->getFunctionType(); + + FoldingSetNodeID ID; + ID.AddInteger(F->size()); + ID.AddInteger(F->getCallingConv()); + ID.AddBoolean(F->hasGC()); + ID.AddBoolean(FTy->isVarArg()); + ID.AddInteger(FTy->getReturnType()->getTypeID()); + for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) + ID.AddInteger(FTy->getParamType(i)->getTypeID()); + return ID.ComputeHash(); +} namespace { - /// MergeFunctions finds functions which will generate identical machine code, - /// by considering all pointer types to be equivalent. Once identified, - /// MergeFunctions will fold them by replacing a call to one to a call to a - /// bitcast of the other. 
- /// - class MergeFunctions : public ModulePass { - public: - static char ID; - MergeFunctions() : ModulePass(ID) {} - - bool runOnModule(Module &M); - - private: - /// PairwiseCompareAndMerge - Given a list of functions, compare each pair - /// and merge the pairs of equivalent functions. - bool PairwiseCompareAndMerge(std::vector<Function *> &FnVec); - - /// MergeTwoFunctions - Merge two equivalent functions. Upon completion, - /// FnVec[j] should never be visited again. - void MergeTwoFunctions(std::vector<Function *> &FnVec, - unsigned i, unsigned j) const; - - /// WriteThunk - Replace G with a simple tail call to bitcast(F). Also - /// replace direct uses of G with bitcast(F). - void WriteThunk(Function *F, Function *G) const; - - TargetData *TD; + +class ComparableFunction { +public: + static const ComparableFunction EmptyKey; + static const ComparableFunction TombstoneKey; + + ComparableFunction(Function *Func, TargetData *TD) + : Func(Func), Hash(ProfileFunction(Func)), TD(TD) {} + + Function *getFunc() const { return Func; } + unsigned getHash() const { return Hash; } + TargetData *getTD() const { return TD; } + + // Drops AssertingVH reference to the function. Outside of debug mode, this + // does nothing. 
+ void release() { + assert(Func && + "Attempted to release function twice, or release empty/tombstone!"); + Func = NULL; + } + +private: + explicit ComparableFunction(unsigned Hash) + : Func(NULL), Hash(Hash), TD(NULL) {} + + AssertingVH<Function> Func; + unsigned Hash; + TargetData *TD; +}; + +const ComparableFunction ComparableFunction::EmptyKey = ComparableFunction(0); +const ComparableFunction ComparableFunction::TombstoneKey = + ComparableFunction(1); + +} + +namespace llvm { + template <> + struct DenseMapInfo<ComparableFunction> { + static ComparableFunction getEmptyKey() { + return ComparableFunction::EmptyKey; + } + static ComparableFunction getTombstoneKey() { + return ComparableFunction::TombstoneKey; + } + static unsigned getHashValue(const ComparableFunction &CF) { + return CF.getHash(); + } + static bool isEqual(const ComparableFunction &LHS, + const ComparableFunction &RHS); }; } +namespace { + +/// MergeFunctions finds functions which will generate identical machine code, +/// by considering all pointer types to be equivalent. Once identified, +/// MergeFunctions will fold them by replacing a call to one to a call to a +/// bitcast of the other. +/// +class MergeFunctions : public ModulePass { +public: + static char ID; + MergeFunctions() : ModulePass(ID) { + initializeMergeFunctionsPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M); + +private: + typedef DenseSet<ComparableFunction> FnSetType; + + + /// Insert a ComparableFunction into the FnSet, or merge it away if it's + /// equal to one that's already present. + bool Insert(FnSetType &FnSet, ComparableFunction &NewF); + + /// MergeTwoFunctions - Merge two equivalent functions. Upon completion, G + /// may be deleted, or may be converted into a thunk. In either case, it + /// should never be visited again. + void MergeTwoFunctions(Function *F, Function *G) const; + + /// WriteThunk - Replace G with a simple tail call to bitcast(F). 
Also + /// replace direct uses of G with bitcast(F). Deletes G. + void WriteThunk(Function *F, Function *G) const; + + TargetData *TD; +}; + +} // end anonymous namespace + char MergeFunctions::ID = 0; -INITIALIZE_PASS(MergeFunctions, "mergefunc", "Merge Functions", false, false); +INITIALIZE_PASS(MergeFunctions, "mergefunc", "Merge Functions", false, false) ModulePass *llvm::createMergeFunctionsPass() { return new MergeFunctions(); @@ -112,7 +193,8 @@ namespace { /// side of claiming that two functions are different). class FunctionComparator { public: - FunctionComparator(TargetData *TD, Function *F1, Function *F2) + FunctionComparator(const TargetData *TD, const Function *F1, + const Function *F2) : F1(F1), F2(F2), TD(TD), IDMap1Count(0), IDMap2Count(0) {} /// Compare - test whether the two functions have equivalent behaviour. @@ -144,9 +226,9 @@ private: bool isEquivalentType(const Type *Ty1, const Type *Ty2) const; // The two functions undergoing comparison. - Function *F1, *F2; + const Function *F1, *F2; - TargetData *TD; + const TargetData *TD; typedef DenseMap<const Value *, unsigned long> IDMap; IDMap Map1, Map2; @@ -154,22 +236,6 @@ private: }; } -/// Compute a hash guaranteed to be equal for two equivalent functions, but -/// very likely to be different for different functions. -static unsigned long ProfileFunction(const Function *F) { - const FunctionType *FTy = F->getFunctionType(); - - FoldingSetNodeID ID; - ID.AddInteger(F->size()); - ID.AddInteger(F->getCallingConv()); - ID.AddBoolean(F->hasGC()); - ID.AddBoolean(FTy->isVarArg()); - ID.AddInteger(FTy->getReturnType()->getTypeID()); - for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) - ID.AddInteger(FTy->getParamType(i)->getTypeID()); - return ID.ComputeHash(); -} - /// isEquivalentType - any two pointers in the same address space are /// equivalent. Otherwise, standard type equivalence rules apply. 
bool FunctionComparator::isEquivalentType(const Type *Ty1, @@ -449,14 +515,14 @@ bool FunctionComparator::Compare() { return false; assert(F1->arg_size() == F2->arg_size() && - "Identical functions have a different number of args."); + "Identically typed functions have different numbers of args!"); // Visit the arguments so that they get enumerated in the order they're // passed in. for (Function::const_arg_iterator f1i = F1->arg_begin(), f2i = F2->arg_begin(), f1e = F1->arg_end(); f1i != f1e; ++f1i, ++f2i) { if (!Enumerate(f1i, f2i)) - llvm_unreachable("Arguments repeat"); + llvm_unreachable("Arguments repeat!"); } // We do a CFG-ordered walk since the actual ordering of the blocks in the @@ -493,7 +559,7 @@ bool FunctionComparator::Compare() { } /// WriteThunk - Replace G with a simple tail call to bitcast(F). Also replace -/// direct uses of G with bitcast(F). +/// direct uses of G with bitcast(F). Deletes G. void MergeFunctions::WriteThunk(Function *F, Function *G) const { if (!G->mayBeOverridden()) { // Redirect direct callers of G to F. @@ -508,7 +574,7 @@ void MergeFunctions::WriteThunk(Function *F, Function *G) const { } } - // If G was internal then we may have replaced all uses if G with F. If so, + // If G was internal then we may have replaced all uses of G with F. If so, // stop here and delete G. There's no need for a thunk. if (G->hasLocalLinkage() && G->use_empty()) { G->eraseFromParent(); @@ -542,22 +608,16 @@ void MergeFunctions::WriteThunk(Function *F, Function *G) const { NewG->takeName(G); G->replaceAllUsesWith(NewG); G->eraseFromParent(); + + DEBUG(dbgs() << "WriteThunk: " << NewG->getName() << '\n'); + ++NumThunksWritten; } /// MergeTwoFunctions - Merge two equivalent functions. Upon completion, -/// FnVec[j] is deleted but not removed from the vector. 
-void MergeFunctions::MergeTwoFunctions(std::vector<Function *> &FnVec, - unsigned i, unsigned j) const { - Function *F = FnVec[i]; - Function *G = FnVec[j]; - - if (F->isWeakForLinker() && !G->isWeakForLinker()) { - std::swap(FnVec[i], FnVec[j]); - std::swap(F, G); - } - - if (F->isWeakForLinker()) { - assert(G->isWeakForLinker()); +/// Function G is deleted. +void MergeFunctions::MergeTwoFunctions(Function *F, Function *G) const { + if (F->mayBeOverridden()) { + assert(G->mayBeOverridden()); // Make them both thunks to the same internal function. Function *H = Function::Create(F->getFunctionType(), F->getLinkage(), "", @@ -573,6 +633,8 @@ void MergeFunctions::MergeTwoFunctions(std::vector<Function *> &FnVec, F->setAlignment(MaxAlignment); F->setLinkage(GlobalValue::InternalLinkage); + + ++NumDoubleWeak; } else { WriteThunk(F, G); } @@ -580,54 +642,144 @@ void MergeFunctions::MergeTwoFunctions(std::vector<Function *> &FnVec, ++NumFunctionsMerged; } -/// PairwiseCompareAndMerge - Given a list of functions, compare each pair and -/// merge the pairs of equivalent functions. -bool MergeFunctions::PairwiseCompareAndMerge(std::vector<Function *> &FnVec) { - bool Changed = false; - for (int i = 0, e = FnVec.size(); i != e; ++i) { - for (int j = i + 1; j != e; ++j) { - bool isEqual = FunctionComparator(TD, FnVec[i], FnVec[j]).Compare(); - - DEBUG(dbgs() << " " << FnVec[i]->getName() - << (isEqual ? " == " : " != ") << FnVec[j]->getName() << "\n"); - - if (isEqual) { - MergeTwoFunctions(FnVec, i, j); - Changed = true; - FnVec.erase(FnVec.begin() + j); - --j, --e; - } - } - } - return Changed; +// Insert - Insert a ComparableFunction into the FnSet, or merge it away if +// equal to one that's already inserted. 
+bool MergeFunctions::Insert(FnSetType &FnSet, ComparableFunction &NewF) { + std::pair<FnSetType::iterator, bool> Result = FnSet.insert(NewF); + if (Result.second) + return false; + + const ComparableFunction &OldF = *Result.first; + + // Never thunk a strong function to a weak function. + assert(!OldF.getFunc()->mayBeOverridden() || + NewF.getFunc()->mayBeOverridden()); + + DEBUG(dbgs() << " " << OldF.getFunc()->getName() << " == " + << NewF.getFunc()->getName() << '\n'); + + Function *DeleteF = NewF.getFunc(); + NewF.release(); + MergeTwoFunctions(OldF.getFunc(), DeleteF); + return true; } -bool MergeFunctions::runOnModule(Module &M) { - bool Changed = false; +// IsThunk - This method determines whether or not a given Function is a thunk +// like the ones emitted by this pass and therefore not subject to further +// merging. +static bool IsThunk(const Function *F) { + // The safe direction to fail is to return true. In that case, the function + // will be removed from merging analysis. If we fail by including functions + // then we may try to merge unmergeable things (i.e., identical weak functions) + // which will push us into an infinite loop. + + assert(!F->isDeclaration() && "Expected a function definition."); + + const BasicBlock *BB = &F->front(); + // A thunk is: + // bitcast-inst* + // optional-reg tail call @thunkee(args...*) + // ret void|optional-reg + // where the args are in the same order as the arguments. + + // Put this at the top since it triggers most often. + const ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator()); + if (!RI) return false; + + // Verify that the sequence of bitcast-inst's are all casts of arguments and + // that there aren't any extras (i.e. no repeated casts). 
+ int LastArgNo = -1; + BasicBlock::const_iterator I = BB->begin(); + while (const BitCastInst *BCI = dyn_cast<BitCastInst>(I)) { + const Argument *A = dyn_cast<Argument>(BCI->getOperand(0)); + if (!A) return false; + if ((int)A->getArgNo() <= LastArgNo) return false; + LastArgNo = A->getArgNo(); + ++I; + } - std::map<unsigned long, std::vector<Function *> > FnMap; + // Verify that we have a direct tail call and that the calling conventions + // and number of arguments match. + const CallInst *CI = dyn_cast<CallInst>(I++); + if (!CI || !CI->isTailCall() || !CI->getCalledFunction() || + CI->getCallingConv() != CI->getCalledFunction()->getCallingConv() || + CI->getNumArgOperands() != F->arg_size()) + return false; - for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) { - if (F->isDeclaration() || F->hasAvailableExternallyLinkage()) - continue; + // Verify that the call instruction has the same arguments as this function + // and that they're all either the incoming argument or a cast of the right + // argument. + for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) { + const Value *V = CI->getArgOperand(i); + const Argument *A = dyn_cast<Argument>(V); + if (!A) { + const BitCastInst *BCI = dyn_cast<BitCastInst>(V); + if (!BCI) return false; + A = cast<Argument>(BCI->getOperand(0)); + } + if (A->getArgNo() != i) return false; + } - FnMap[ProfileFunction(F)].push_back(F); + // Verify that the terminator is a ret void (if we're void) or a ret of the + // call's return, or a ret of a bitcast of the call's return. 
+ if (const BitCastInst *BCI = dyn_cast<BitCastInst>(I)) { + ++I; + if (BCI->getOperand(0) != CI) return false; } + if (RI != I) return false; + if (RI->getNumOperands() == 0) + return CI->getType()->isVoidTy(); + return RI->getReturnValue() == CI; +} +bool MergeFunctions::runOnModule(Module &M) { + bool Changed = false; TD = getAnalysisIfAvailable<TargetData>(); bool LocalChanged; do { + DEBUG(dbgs() << "size of module: " << M.size() << '\n'); LocalChanged = false; - DEBUG(dbgs() << "size: " << FnMap.size() << "\n"); - for (std::map<unsigned long, std::vector<Function *> >::iterator - I = FnMap.begin(), E = FnMap.end(); I != E; ++I) { - std::vector<Function *> &FnVec = I->second; - DEBUG(dbgs() << "hash (" << I->first << "): " << FnVec.size() << "\n"); - LocalChanged |= PairwiseCompareAndMerge(FnVec); + FnSetType FnSet; + + // Insert only strong functions and merge them. Strong function merging + // always deletes one of them. + for (Module::iterator I = M.begin(), E = M.end(); I != E;) { + Function *F = I++; + if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage() && + !F->mayBeOverridden() && !IsThunk(F)) { + ComparableFunction CF = ComparableFunction(F, TD); + LocalChanged |= Insert(FnSet, CF); + } } + + // Insert only weak functions and merge them. By doing these second we + // create thunks to the strong function when possible. When two weak + // functions are identical, we create a new strong function with two weak + // thunks to it which are identical but not mergeable. 
+ for (Module::iterator I = M.begin(), E = M.end(); I != E;) { + Function *F = I++; + if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage() && + F->mayBeOverridden() && !IsThunk(F)) { + ComparableFunction CF = ComparableFunction(F, TD); + LocalChanged |= Insert(FnSet, CF); + } + } + DEBUG(dbgs() << "size of FnSet: " << FnSet.size() << '\n'); Changed |= LocalChanged; } while (LocalChanged); return Changed; } + +bool DenseMapInfo<ComparableFunction>::isEqual(const ComparableFunction &LHS, + const ComparableFunction &RHS) { + if (LHS.getFunc() == RHS.getFunc() && + LHS.getHash() == RHS.getHash()) + return true; + if (!LHS.getFunc() || !RHS.getFunc()) + return false; + assert(LHS.getTD() == RHS.getTD() && + "Comparing functions for different targets"); + return FunctionComparator(LHS.getTD(), + LHS.getFunc(), RHS.getFunc()).Compare(); +} |