diff options
Diffstat (limited to 'lib/Transforms/Scalar/RewriteStatepointsForGC.cpp')
-rw-r--r-- | lib/Transforms/Scalar/RewriteStatepointsForGC.cpp | 314 |
1 files changed, 179 insertions, 135 deletions
diff --git a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index ca9ab54..f5d21ff 100644 --- a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -548,9 +548,6 @@ public: } PhiState(Value *b) : status(Base), base(b) {} PhiState() : status(Unknown), base(nullptr) {} - PhiState(const PhiState &other) : status(other.status), base(other.base) { - assert(status != Base || base); - } Status getStatus() const { return status; } Value *getBase() const { return base; } @@ -684,12 +681,19 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache, states[def] = PhiState(); // Recursively fill in all phis & selects reachable from the initial one // for which we don't already know a definite base value for - // PERF: Yes, this is as horribly inefficient as it looks. + // TODO: This should be rewritten with a worklist bool done = false; while (!done) { done = true; + // Since we're adding elements to 'states' as we run, we can't keep + // iterators into the set. + SmallVector<Value*, 16> Keys; + Keys.reserve(states.size()); for (auto Pair : states) { - Value *v = Pair.first; + Value *V = Pair.first; + Keys.push_back(V); + } + for (Value *v : Keys) { assert(!isKnownBaseResult(v) && "why did it get added?"); if (PHINode *phi = dyn_cast<PHINode>(v)) { assert(phi->getNumIncomingValues() > 0 && @@ -730,10 +734,12 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache, // have reached conflict state. The current version seems too conservative. bool progress = true; - size_t oldSize = 0; while (progress) { - oldSize = states.size(); +#ifndef NDEBUG + size_t oldSize = states.size(); +#endif progress = false; + // We're only changing keys in this loop, thus safe to keep iterators for (auto Pair : states) { MeetPhiStates calculateMeet(states); Value *v = Pair.first; @@ -768,46 +774,58 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache, } // Insert Phis for all conflicts + // We want to keep naming deterministic in the loop that follows, so + // sort the keys before iteration. This is useful in allowing us to + // write stable tests. Note that there is no invalidation issue here. + SmallVector<Value*, 16> Keys; + Keys.reserve(states.size()); for (auto Pair : states) { - Instruction *v = cast<Instruction>(Pair.first); - PhiState state = Pair.second; + Value *V = Pair.first; + Keys.push_back(V); + } + std::sort(Keys.begin(), Keys.end(), order_by_name); + // TODO: adjust naming patterns to avoid this order of iteration dependency + for (Value *V : Keys) { + Instruction *v = cast<Instruction>(V); + PhiState state = states[V]; assert(!isKnownBaseResult(v) && "why did it get added?"); assert(!state.isUnknown() && "Optimistic algorithm didn't complete!"); - if (state.isConflict()) { - if (isa<PHINode>(v)) { - int num_preds = - std::distance(pred_begin(v->getParent()), pred_end(v->getParent())); - assert(num_preds > 0 && "how did we reach here"); - PHINode *phi = PHINode::Create(v->getType(), num_preds, "base_phi", v); - NewInsertedDefs.insert(phi); - // Add metadata marking this as a base value - auto *const_1 = ConstantInt::get( - Type::getInt32Ty( - v->getParent()->getParent()->getParent()->getContext()), - 1); - auto MDConst = ConstantAsMetadata::get(const_1); - MDNode *md = MDNode::get( - v->getParent()->getParent()->getParent()->getContext(), MDConst); - phi->setMetadata("is_base_value", md); - states[v] = PhiState(PhiState::Conflict, phi); - } else if (SelectInst *sel = dyn_cast<SelectInst>(v)) { - // The undef will be replaced later - UndefValue *undef = UndefValue::get(sel->getType()); - SelectInst *basesel = SelectInst::Create(sel->getCondition(), undef, - undef, "base_select", sel); - NewInsertedDefs.insert(basesel); - // Add metadata marking this as a base value - auto *const_1 = ConstantInt::get( - Type::getInt32Ty( - v->getParent()->getParent()->getParent()->getContext()), - 1); - auto MDConst = ConstantAsMetadata::get(const_1); - MDNode *md = MDNode::get( - v->getParent()->getParent()->getParent()->getContext(), MDConst); - basesel->setMetadata("is_base_value", md); - states[v] = PhiState(PhiState::Conflict, basesel); - } else - llvm_unreachable("unknown conflict type"); + if (!state.isConflict()) + continue; + + if (isa<PHINode>(v)) { + int num_preds = + std::distance(pred_begin(v->getParent()), pred_end(v->getParent())); + assert(num_preds > 0 && "how did we reach here"); + PHINode *phi = PHINode::Create(v->getType(), num_preds, "base_phi", v); + NewInsertedDefs.insert(phi); + // Add metadata marking this as a base value + auto *const_1 = ConstantInt::get( + Type::getInt32Ty( + v->getParent()->getParent()->getParent()->getContext()), + 1); + auto MDConst = ConstantAsMetadata::get(const_1); + MDNode *md = MDNode::get( + v->getParent()->getParent()->getParent()->getContext(), MDConst); + phi->setMetadata("is_base_value", md); + states[v] = PhiState(PhiState::Conflict, phi); + } else { + SelectInst *sel = cast<SelectInst>(v); + // The undef will be replaced later + UndefValue *undef = UndefValue::get(sel->getType()); + SelectInst *basesel = SelectInst::Create(sel->getCondition(), undef, + undef, "base_select", sel); + NewInsertedDefs.insert(basesel); + // Add metadata marking this as a base value + auto *const_1 = ConstantInt::get( + Type::getInt32Ty( + v->getParent()->getParent()->getParent()->getContext()), + 1); + auto MDConst = ConstantAsMetadata::get(const_1); + MDNode *md = MDNode::get( + v->getParent()->getParent()->getParent()->getContext(), MDConst); + basesel->setMetadata("is_base_value", md); + states[v] = PhiState(PhiState::Conflict, basesel); } } @@ -818,97 +836,98 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache, assert(!isKnownBaseResult(v) && "why did it get added?"); assert(!state.isUnknown() && "Optimistic algorithm didn't complete!"); - if (state.isConflict()) { - if (PHINode *basephi = dyn_cast<PHINode>(state.getBase())) { - PHINode *phi = cast<PHINode>(v); - unsigned NumPHIValues = phi->getNumIncomingValues(); - for (unsigned i = 0; i < NumPHIValues; i++) { - Value *InVal = phi->getIncomingValue(i); - BasicBlock *InBB = phi->getIncomingBlock(i); - - // If we've already seen InBB, add the same incoming value - // we added for it earlier. The IR verifier requires phi - // nodes with multiple entries from the same basic block - // to have the same incoming value for each of those - // entries. If we don't do this check here and basephi - // has a different type than base, we'll end up adding two - // bitcasts (and hence two distinct values) as incoming - // values for the same basic block. - - int blockIndex = basephi->getBasicBlockIndex(InBB); - if (blockIndex != -1) { - Value *oldBase = basephi->getIncomingValue(blockIndex); - basephi->addIncoming(oldBase, InBB); + if (!state.isConflict()) + continue; + + if (PHINode *basephi = dyn_cast<PHINode>(state.getBase())) { + PHINode *phi = cast<PHINode>(v); + unsigned NumPHIValues = phi->getNumIncomingValues(); + for (unsigned i = 0; i < NumPHIValues; i++) { + Value *InVal = phi->getIncomingValue(i); + BasicBlock *InBB = phi->getIncomingBlock(i); + + // If we've already seen InBB, add the same incoming value + // we added for it earlier. The IR verifier requires phi + // nodes with multiple entries from the same basic block + // to have the same incoming value for each of those + // entries. If we don't do this check here and basephi + // has a different type than base, we'll end up adding two + // bitcasts (and hence two distinct values) as incoming + // values for the same basic block. + + int blockIndex = basephi->getBasicBlockIndex(InBB); + if (blockIndex != -1) { + Value *oldBase = basephi->getIncomingValue(blockIndex); + basephi->addIncoming(oldBase, InBB); #ifndef NDEBUG - Value *base = findBaseOrBDV(InVal, cache); - if (!isKnownBaseResult(base)) { - // Either conflict or base. - assert(states.count(base)); - base = states[base].getBase(); - assert(base != nullptr && "unknown PhiState!"); - assert(NewInsertedDefs.count(base) && - "should have already added this in a prev. iteration!"); - } - - // In essense this assert states: the only way two - // values incoming from the same basic block may be - // different is by being different bitcasts of the same - // value. A cleanup that remains TODO is changing - // findBaseOrBDV to return an llvm::Value of the correct - // type (and still remain pure). This will remove the - // need to add bitcasts. - assert(base->stripPointerCasts() == oldBase->stripPointerCasts() && - "sanity -- findBaseOrBDV should be pure!"); -#endif - continue; - } - - // Find either the defining value for the PHI or the normal base for - // a non-phi node Value *base = findBaseOrBDV(InVal, cache); if (!isKnownBaseResult(base)) { // Either conflict or base. assert(states.count(base)); base = states[base].getBase(); assert(base != nullptr && "unknown PhiState!"); + assert(NewInsertedDefs.count(base) && + "should have already added this in a prev. iteration!"); } - assert(base && "can't be null"); - // Must use original input BB since base may not be Instruction - // The cast is needed since base traversal may strip away bitcasts - if (base->getType() != basephi->getType()) { - base = new BitCastInst(base, basephi->getType(), "cast", - InBB->getTerminator()); - NewInsertedDefs.insert(base); - } - basephi->addIncoming(base, InBB); + + // In essense this assert states: the only way two + // values incoming from the same basic block may be + // different is by being different bitcasts of the same + // value. A cleanup that remains TODO is changing + // findBaseOrBDV to return an llvm::Value of the correct + // type (and still remain pure). This will remove the + // need to add bitcasts. + assert(base->stripPointerCasts() == oldBase->stripPointerCasts() && + "sanity -- findBaseOrBDV should be pure!"); +#endif + continue; } - assert(basephi->getNumIncomingValues() == NumPHIValues); - } else if (SelectInst *basesel = dyn_cast<SelectInst>(state.getBase())) { - SelectInst *sel = cast<SelectInst>(v); - // Operand 1 & 2 are true, false path respectively. TODO: refactor to - // something more safe and less hacky. - for (int i = 1; i <= 2; i++) { - Value *InVal = sel->getOperand(i); - // Find either the defining value for the PHI or the normal base for - // a non-phi node - Value *base = findBaseOrBDV(InVal, cache); - if (!isKnownBaseResult(base)) { - // Either conflict or base. - assert(states.count(base)); - base = states[base].getBase(); - assert(base != nullptr && "unknown PhiState!"); - } - assert(base && "can't be null"); - // Must use original input BB since base may not be Instruction - // The cast is needed since base traversal may strip away bitcasts - if (base->getType() != basesel->getType()) { - base = new BitCastInst(base, basesel->getType(), "cast", basesel); - NewInsertedDefs.insert(base); - } - basesel->setOperand(i, base); + + // Find either the defining value for the PHI or the normal base for + // a non-phi node + Value *base = findBaseOrBDV(InVal, cache); + if (!isKnownBaseResult(base)) { + // Either conflict or base. + assert(states.count(base)); + base = states[base].getBase(); + assert(base != nullptr && "unknown PhiState!"); } - } else - llvm_unreachable("unexpected conflict type"); + assert(base && "can't be null"); + // Must use original input BB since base may not be Instruction + // The cast is needed since base traversal may strip away bitcasts + if (base->getType() != basephi->getType()) { + base = new BitCastInst(base, basephi->getType(), "cast", + InBB->getTerminator()); + NewInsertedDefs.insert(base); + } + basephi->addIncoming(base, InBB); + } + assert(basephi->getNumIncomingValues() == NumPHIValues); + } else { + SelectInst *basesel = cast<SelectInst>(state.getBase()); + SelectInst *sel = cast<SelectInst>(v); + // Operand 1 & 2 are true, false path respectively. TODO: refactor to + // something more safe and less hacky. + for (int i = 1; i <= 2; i++) { + Value *InVal = sel->getOperand(i); + // Find either the defining value for the PHI or the normal base for + // a non-phi node + Value *base = findBaseOrBDV(InVal, cache); + if (!isKnownBaseResult(base)) { + // Either conflict or base. + assert(states.count(base)); + base = states[base].getBase(); + assert(base != nullptr && "unknown PhiState!"); + } + assert(base && "can't be null"); + // Must use original input BB since base may not be Instruction + // The cast is needed since base traversal may strip away bitcasts + if (base->getType() != basesel->getType()) { + base = new BitCastInst(base, basesel->getType(), "cast", basesel); + NewInsertedDefs.insert(base); + } + basesel->setOperand(i, base); + } } } @@ -964,7 +983,13 @@ static void findBasePointers(const StatepointLiveSetTy &live, DenseMap<llvm::Value *, llvm::Value *> &PointerToBase, DominatorTree *DT, DefiningValueMapTy &DVCache, DenseSet<llvm::Value *> &NewInsertedDefs) { - for (Value *ptr : live) { + // For the naming of values inserted to be deterministic - which makes for + // much cleaner and more stable tests - we need to assign an order to the + // live values. DenseSets do not provide a deterministic order across runs. + SmallVector<Value*, 64> Temp; + Temp.insert(Temp.end(), live.begin(), live.end()); + std::sort(Temp.begin(), Temp.end(), order_by_name); + for (Value *ptr : Temp) { Value *base = findBasePointer(ptr, DVCache, NewInsertedDefs); assert(base && "failed to find base pointer"); PointerToBase[ptr] = base; @@ -993,10 +1018,19 @@ static void findBasePointers(DominatorTree &DT, DefiningValueMapTy &DVCache, findBasePointers(result.liveset, PointerToBase, &DT, DVCache, NewInsertedDefs); if (PrintBasePointers) { + // Note: Need to print these in a stable order since this is checked in + // some tests. errs() << "Base Pairs (w/o Relocation):\n"; + SmallVector<Value*, 64> Temp; + Temp.reserve(PointerToBase.size()); for (auto Pair : PointerToBase) { - errs() << " derived %" << Pair.first->getName() << " base %" - << Pair.second->getName() << "\n"; + Temp.push_back(Pair.first); + } + std::sort(Temp.begin(), Temp.end(), order_by_name); + for (Value *Ptr : Temp) { + Value *Base = PointerToBase[Ptr]; + errs() << " derived %" << Ptr->getName() << " base %" + << Base->getName() << "\n"; } } @@ -1131,11 +1165,11 @@ static AttributeSet legalizeCallAttributes(AttributeSet AS) { /// statepointToken - statepoint instruction to which relocates should be /// bound. /// Builder - Llvm IR builder to be used to construct new calls. -void CreateGCRelocates(ArrayRef<llvm::Value *> liveVariables, - const int liveStart, - ArrayRef<llvm::Value *> basePtrs, - Instruction *statepointToken, IRBuilder<> Builder) { - +static void CreateGCRelocates(ArrayRef<llvm::Value *> liveVariables, + const int liveStart, + ArrayRef<llvm::Value *> basePtrs, + Instruction *statepointToken, + IRBuilder<> Builder) { SmallVector<Instruction *, 64> NewDefs; NewDefs.reserve(liveVariables.size()); @@ -1559,8 +1593,18 @@ static void relocationViaAlloca( // store must be inserted after load, otherwise store will be in alloca's // use list and an extra load will be inserted before it StoreInst *store = new StoreInst(def, alloca); - if (isa<Instruction>(def)) { - store->insertAfter(cast<Instruction>(def)); + if (Instruction *inst = dyn_cast<Instruction>(def)) { + if (InvokeInst *invoke = dyn_cast<InvokeInst>(inst)) { + // InvokeInst is a TerminatorInst so the store need to be inserted + // into its normal destination block. + BasicBlock *normalDest = invoke->getNormalDest(); + store->insertBefore(normalDest->getFirstNonPHI()); + } else { + assert(!inst->isTerminator() && + "The only TerminatorInst that can produce a value is " + "InvokeInst which is handled above."); + store->insertAfter(inst); + } } else { assert((isa<Argument>(def) || isa<GlobalVariable>(def) || (isa<Constant>(def) && cast<Constant>(def)->isNullValue())) && |