diff options
-rw-r--r-- | lib/Transforms/Utils/CodeExtractor.cpp | 97 |
1 files changed, 52 insertions, 45 deletions
diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index e008fd7..59f9876 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -19,6 +19,7 @@ #include "llvm/Instructions.h" #include "llvm/Module.h" #include "llvm/Pass.h" +#include "llvm/Analysis/Dominators.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/Verifier.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -38,13 +39,16 @@ namespace { return I; } - struct CodeExtractor { + class CodeExtractor { typedef std::vector<Value*> Values; typedef std::vector<std::pair<unsigned, unsigned> > PhiValChangesTy; typedef std::map<PHINode*, PhiValChangesTy> PhiVal2ArgTy; PhiVal2ArgTy PhiVal2Arg; std::set<BasicBlock*> BlocksToExtract; + DominatorSet *DS; public: + CodeExtractor(DominatorSet *ds = 0) : DS(ds) {} + Function *ExtractCodeRegion(const std::vector<BasicBlock*> &code); private: @@ -191,8 +195,10 @@ void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs, // Consider uses of this instruction (outputs) for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); UI != E; ++UI) - if (!BlocksToExtract.count(cast<Instruction>(*UI)->getParent())) - outputs.push_back(*UI); + if (!BlocksToExtract.count(cast<Instruction>(*UI)->getParent())) { + outputs.push_back(I); + break; + } } // for: insts } // for: basic blocks } @@ -257,18 +263,11 @@ Function *CodeExtractor::constructFunction(const Values &inputs, paramTy.push_back(value->getType()); } - // Add the types of the output values to the function's argument list, but - // make them pointer types for scalars - for (Values::const_iterator i = outputs.begin(), - e = outputs.end(); i != e; ++i) { - const Value *value = *i; - DEBUG(std::cerr << "instr used in func: " << value << "\n"); - const Type *valueType = value->getType(); - // Convert scalar types into a pointer of that type - if (valueType->isPrimitiveType()) { - valueType = PointerType::get(valueType); - } - paramTy.push_back(valueType); + // Add the types of the output values to the function's argument list. + for (Values::const_iterator I = outputs.begin(), E = outputs.end(); + I != E; ++I) { + DEBUG(std::cerr << "instr used in func: " << *I << "\n"); + paramTy.push_back(PointerType::get((*I)->getType())); } DEBUG(std::cerr << "Function type: " << retTy << " f("); @@ -285,15 +284,26 @@ Function *CodeExtractor::constructFunction(const Values &inputs, oldFunction->getName() + "_code", M); newFunction->getBasicBlockList().push_back(newRootNode); - for (unsigned i = 0, e = inputs.size(); i != e; ++i) { + // Create an iterator to name all of the arguments we inserted. + Function::aiterator AI = newFunction->abegin(); + + // Rewrite all users of the inputs in the extracted region to use the + // arguments instead. + for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI) { + AI->setName(inputs[i]->getName()); std::vector<User*> Users(inputs[i]->use_begin(), inputs[i]->use_end()); for (std::vector<User*>::iterator use = Users.begin(), useE = Users.end(); use != useE; ++use) if (Instruction* inst = dyn_cast<Instruction>(*use)) if (BlocksToExtract.count(inst->getParent())) - inst->replaceUsesOfWith(inputs[i], getFunctionArg(newFunction, i)); + inst->replaceUsesOfWith(inputs[i], AI); } + // Set names for all of the output arguments. + for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI) + AI->setName(outputs[i]->getName()+".out"); + + // Rewrite branches to basic blocks outside of the loop to new dummy blocks // within the new function. This must be done before we lose track of which // blocks were originally in the code region. @@ -332,34 +342,26 @@ void CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, Values &inputs, - Values &outputs) -{ + Values &outputs) { // Emit a call to the new function, passing allocated memory for outputs and // just plain inputs for non-scalars std::vector<Value*> params(inputs); - for (Values::const_iterator i = outputs.begin(), e = outputs.end(); i != e; - ++i) { - Value *Output = *i; + for (unsigned i = 0, e = outputs.size(); i != e; ++i) { + Value *Output = outputs[i]; // Create allocas for scalar outputs - if (Output->getType()->isPrimitiveType()) { - AllocaInst *alloca = - new AllocaInst((*i)->getType(), 0, Output->getName()+".loc", - codeReplacer->getParent()->begin()->begin()); - params.push_back(alloca); - - LoadInst *load = new LoadInst(alloca, Output->getName()+".reload"); - codeReplacer->getInstList().push_back(load); - std::vector<User*> Users((*i)->use_begin(), (*i)->use_end()); - for (std::vector<User*>::iterator use = Users.begin(), useE =Users.end(); - use != useE; ++use) { - if (Instruction* inst = dyn_cast<Instruction>(*use)) { - if (!BlocksToExtract.count(inst->getParent())) - inst->replaceUsesOfWith(*i, load); - } - } - } else { - params.push_back(*i); + AllocaInst *alloca = + new AllocaInst(outputs[i]->getType(), 0, Output->getName()+".loc", + codeReplacer->getParent()->begin()->begin()); + params.push_back(alloca); + + LoadInst *load = new LoadInst(alloca, Output->getName()+".reload"); + codeReplacer->getInstList().push_back(load); + std::vector<User*> Users(outputs[i]->use_begin(), outputs[i]->use_end()); + for (unsigned u = 0, e = Users.size(); u != e; ++u) { + Instruction *inst = cast<Instruction>(Users[u]); + if (!BlocksToExtract.count(inst->getParent())) + inst->replaceUsesOfWith(outputs[i], load); } } @@ -400,7 +402,11 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, // Restore values just before we exit // FIXME: Use a GetElementPtr to bunch the outputs in a struct for (unsigned out = 0, e = outputs.size(); out != e; ++out) - new StoreInst(outputs[out], getFunctionArg(newFunction, out),NTRet); + if (!DS || + DS->dominates(cast<Instruction>(outputs[out])->getParent(), + TI->getParent())) + new StoreInst(outputs[out], getFunctionArg(newFunction, out), + NTRet); } // rewrite the original branch instruction with this new target @@ -502,14 +508,15 @@ Function *CodeExtractor::ExtractCodeRegion(const std::vector<BasicBlock*> &code) /// ExtractCodeRegion - slurp a sequence of basic blocks into a brand new /// function /// -Function* llvm::ExtractCodeRegion(const std::vector<BasicBlock*> &code) { - return CodeExtractor().ExtractCodeRegion(code); +Function* llvm::ExtractCodeRegion(DominatorSet &DS, + const std::vector<BasicBlock*> &code) { + return CodeExtractor(&DS).ExtractCodeRegion(code); } /// ExtractBasicBlock - slurp a natural loop into a brand new function /// -Function* llvm::ExtractLoop(Loop *L) { - return CodeExtractor().ExtractCodeRegion(L->getBlocks()); +Function* llvm::ExtractLoop(DominatorSet &DS, Loop *L) { + return CodeExtractor(&DS).ExtractCodeRegion(L->getBlocks()); } /// ExtractBasicBlock - slurp a basic block into a brand new function |