Diffstat (limited to 'lib/Analysis')
48 files changed, 5491 insertions, 2365 deletions
diff --git a/lib/Analysis/AliasAnalysis.cpp b/lib/Analysis/AliasAnalysis.cpp index 5cde979..5171a45 100644 --- a/lib/Analysis/AliasAnalysis.cpp +++ b/lib/Analysis/AliasAnalysis.cpp @@ -196,17 +196,21 @@ AliasAnalysis::getModRefInfo(ImmutableCallSite CS1, ImmutableCallSite CS2) { if (!Arg->getType()->isPointerTy()) continue; ModRefResult ArgMask; - Location CS1Loc = - getArgLocation(CS1, (unsigned) std::distance(CS1.arg_begin(), I), - ArgMask); - if ((getModRefInfo(CS2, CS1Loc) & ArgMask) != NoModRef) { - R = Mask; + Location CS1Loc = getArgLocation( + CS1, (unsigned)std::distance(CS1.arg_begin(), I), ArgMask); + // ArgMask indicates what CS1 might do to CS1Loc; if CS1 might Mod + // CS1Loc, then we care about either a Mod or a Ref by CS2. If CS1 + // might Ref, then we care only about a Mod by CS2. + ModRefResult ArgR = getModRefInfo(CS2, CS1Loc); + if (((ArgMask & Mod) != NoModRef && (ArgR & ModRef) != NoModRef) || + ((ArgMask & Ref) != NoModRef && (ArgR & Mod) != NoModRef)) + R = ModRefResult((R | ArgMask) & Mask); + + if (R == Mask) break; - } } } - if (R == NoModRef) - return R; + return R; } // If this is the end of the chain, don't forward. @@ -247,61 +251,73 @@ AliasAnalysis::getModRefBehavior(const Function *F) { //===----------------------------------------------------------------------===// AliasAnalysis::Location AliasAnalysis::getLocation(const LoadInst *LI) { + AAMDNodes AATags; + LI->getAAMetadata(AATags); + return Location(LI->getPointerOperand(), - getTypeStoreSize(LI->getType()), - LI->getMetadata(LLVMContext::MD_tbaa)); + getTypeStoreSize(LI->getType()), AATags); } AliasAnalysis::Location AliasAnalysis::getLocation(const StoreInst *SI) { + AAMDNodes AATags; + SI->getAAMetadata(AATags); + return Location(SI->getPointerOperand(), - getTypeStoreSize(SI->getValueOperand()->getType()), - SI->getMetadata(LLVMContext::MD_tbaa)); + getTypeStoreSize(SI->getValueOperand()->getType()), AATags); } AliasAnalysis::Location AliasAnalysis::getLocation(const VAArgInst *VI) { - return Location(VI->getPointerOperand(), - UnknownSize, - VI->getMetadata(LLVMContext::MD_tbaa)); + AAMDNodes AATags; + VI->getAAMetadata(AATags); + + return Location(VI->getPointerOperand(), UnknownSize, AATags); } AliasAnalysis::Location AliasAnalysis::getLocation(const AtomicCmpXchgInst *CXI) { + AAMDNodes AATags; + CXI->getAAMetadata(AATags); + return Location(CXI->getPointerOperand(), getTypeStoreSize(CXI->getCompareOperand()->getType()), - CXI->getMetadata(LLVMContext::MD_tbaa)); + AATags); } AliasAnalysis::Location AliasAnalysis::getLocation(const AtomicRMWInst *RMWI) { + AAMDNodes AATags; + RMWI->getAAMetadata(AATags); + return Location(RMWI->getPointerOperand(), - getTypeStoreSize(RMWI->getValOperand()->getType()), - RMWI->getMetadata(LLVMContext::MD_tbaa)); + getTypeStoreSize(RMWI->getValOperand()->getType()), AATags); } -AliasAnalysis::Location +AliasAnalysis::Location AliasAnalysis::getLocationForSource(const MemTransferInst *MTI) { uint64_t Size = UnknownSize; if (ConstantInt *C = dyn_cast<ConstantInt>(MTI->getLength())) Size = C->getValue().getZExtValue(); - // memcpy/memmove can have TBAA tags. For memcpy, they apply + // memcpy/memmove can have AA tags. For memcpy, they apply // to both the source and the destination. 
- MDNode *TBAATag = MTI->getMetadata(LLVMContext::MD_tbaa); + AAMDNodes AATags; + MTI->getAAMetadata(AATags); - return Location(MTI->getRawSource(), Size, TBAATag); + return Location(MTI->getRawSource(), Size, AATags); } -AliasAnalysis::Location +AliasAnalysis::Location AliasAnalysis::getLocationForDest(const MemIntrinsic *MTI) { uint64_t Size = UnknownSize; if (ConstantInt *C = dyn_cast<ConstantInt>(MTI->getLength())) Size = C->getValue().getZExtValue(); - // memcpy/memmove can have TBAA tags. For memcpy, they apply + // memcpy/memmove can have AA tags. For memcpy, they apply // to both the source and the destination. - MDNode *TBAATag = MTI->getMetadata(LLVMContext::MD_tbaa); - - return Location(MTI->getRawDest(), Size, TBAATag); + AAMDNodes AATags; + MTI->getAAMetadata(AATags); + + return Location(MTI->getRawDest(), Size, AATags); } @@ -383,53 +399,6 @@ AliasAnalysis::getModRefInfo(const AtomicRMWInst *RMW, const Location &Loc) { return ModRef; } -namespace { - /// Only find pointer captures which happen before the given instruction. Uses - /// the dominator tree to determine whether one instruction is before another. - /// Only support the case where the Value is defined in the same basic block - /// as the given instruction and the use. - struct CapturesBefore : public CaptureTracker { - CapturesBefore(const Instruction *I, DominatorTree *DT) - : BeforeHere(I), DT(DT), Captured(false) {} - - void tooManyUses() override { Captured = true; } - - bool shouldExplore(const Use *U) override { - Instruction *I = cast<Instruction>(U->getUser()); - BasicBlock *BB = I->getParent(); - // We explore this usage only if the usage can reach "BeforeHere". - // If use is not reachable from entry, there is no need to explore. - if (BeforeHere != I && !DT->isReachableFromEntry(BB)) - return false; - // If the value is defined in the same basic block as use and BeforeHere, - // there is no need to explore the use if BeforeHere dominates use. - // Check whether there is a path from I to BeforeHere. - if (BeforeHere != I && DT->dominates(BeforeHere, I) && - !isPotentiallyReachable(I, BeforeHere, DT)) - return false; - return true; - } - - bool captured(const Use *U) override { - Instruction *I = cast<Instruction>(U->getUser()); - BasicBlock *BB = I->getParent(); - // Same logic as in shouldExplore. - if (BeforeHere != I && !DT->isReachableFromEntry(BB)) - return false; - if (BeforeHere != I && DT->dominates(BeforeHere, I) && - !isPotentiallyReachable(I, BeforeHere, DT)) - return false; - Captured = true; - return true; - } - - const Instruction *BeforeHere; - DominatorTree *DT; - - bool Captured; - }; -} - // FIXME: this is really just shoring-up a deficiency in alias analysis. // BasicAA isn't willing to spend linear time determining whether an alloca // was captured before or after this particular call, while we are. However, @@ -449,9 +418,9 @@ AliasAnalysis::callCapturesBefore(const Instruction *I, if (!CS.getInstruction() || CS.getInstruction() == Object) return AliasAnalysis::ModRef; - CapturesBefore CB(I, DT); - llvm::PointerMayBeCaptured(Object, &CB); - if (CB.Captured) + if (llvm::PointerMayBeCapturedBefore(Object, /* ReturnCaptures */ true, + /* StoreCaptures */ true, I, DT, + /* include Object */ true)) return AliasAnalysis::ModRef; unsigned ArgNo = 0; @@ -470,7 +439,7 @@ AliasAnalysis::callCapturesBefore(const Instruction *I, // assume that the call could touch the pointer, even though it doesn't // escape. 
if (isNoAlias(AliasAnalysis::Location(*CI), - AliasAnalysis::Location(Object))) + AliasAnalysis::Location(Object))) continue; if (CS.doesNotAccessMemory(ArgNo)) continue; @@ -577,3 +546,13 @@ bool llvm::isIdentifiedObject(const Value *V) { return A->hasNoAliasAttr() || A->hasByValAttr(); return false; } + +/// isIdentifiedFunctionLocal - Return true if V is unambiguously identified +/// at the function-level. Different IdentifiedFunctionLocals can't alias. +/// Further, an IdentifiedFunctionLocal cannot alias with any function +/// arguments other than itself, which is not necessarily true for +/// IdentifiedObjects. +bool llvm::isIdentifiedFunctionLocal(const Value *V) +{ + return isa<AllocaInst>(V) || isNoAliasCall(V) || isNoAliasArgument(V); +} diff --git a/lib/Analysis/AliasAnalysisEvaluator.cpp b/lib/Analysis/AliasAnalysisEvaluator.cpp index d9fa5a5..fe4bd4c 100644 --- a/lib/Analysis/AliasAnalysisEvaluator.cpp +++ b/lib/Analysis/AliasAnalysisEvaluator.cpp @@ -43,7 +43,7 @@ static cl::opt<bool> PrintMod("print-mod", cl::ReallyHidden); static cl::opt<bool> PrintRef("print-ref", cl::ReallyHidden); static cl::opt<bool> PrintModRef("print-modref", cl::ReallyHidden); -static cl::opt<bool> EvalTBAA("evaluate-tbaa", cl::ReallyHidden); +static cl::opt<bool> EvalAAMD("evaluate-aa-metadata", cl::ReallyHidden); namespace { class AAEval : public FunctionPass { @@ -153,9 +153,9 @@ bool AAEval::runOnFunction(Function &F) { for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) { if (I->getType()->isPointerTy()) // Add all pointer instructions. Pointers.insert(&*I); - if (EvalTBAA && isa<LoadInst>(&*I)) + if (EvalAAMD && isa<LoadInst>(&*I)) Loads.insert(&*I); - if (EvalTBAA && isa<StoreInst>(&*I)) + if (EvalAAMD && isa<StoreInst>(&*I)) Stores.insert(&*I); Instruction &Inst = *I; if (CallSite CS = cast<Value>(&Inst)) { @@ -213,7 +213,7 @@ bool AAEval::runOnFunction(Function &F) { } } - if (EvalTBAA) { + if (EvalAAMD) { // iterate over all pairs of load, store for (SetVector<Value *>::iterator I1 = Loads.begin(), E = Loads.end(); I1 != E; ++I1) { diff --git a/lib/Analysis/AliasSetTracker.cpp b/lib/Analysis/AliasSetTracker.cpp index a45fe23..45442b0 100644 --- a/lib/Analysis/AliasSetTracker.cpp +++ b/lib/Analysis/AliasSetTracker.cpp @@ -47,18 +47,21 @@ void AliasSet::mergeSetIn(AliasSet &AS, AliasSetTracker &AST) { // If the pointers are not a must-alias pair, this set becomes a may alias. if (AA.alias(AliasAnalysis::Location(L->getValue(), L->getSize(), - L->getTBAAInfo()), + L->getAAInfo()), AliasAnalysis::Location(R->getValue(), R->getSize(), - R->getTBAAInfo())) + R->getAAInfo())) != AliasAnalysis::MustAlias) AliasTy = MayAlias; } + bool ASHadUnknownInsts = !AS.UnknownInsts.empty(); if (UnknownInsts.empty()) { // Merge call sites...
- if (!AS.UnknownInsts.empty()) + if (ASHadUnknownInsts) { std::swap(UnknownInsts, AS.UnknownInsts); - } else if (!AS.UnknownInsts.empty()) { + addRef(); + } + } else if (ASHadUnknownInsts) { UnknownInsts.insert(UnknownInsts.end(), AS.UnknownInsts.begin(), AS.UnknownInsts.end()); AS.UnknownInsts.clear(); } @@ -76,6 +79,8 @@ void AliasSet::mergeSetIn(AliasSet &AS, AliasSetTracker &AST) { AS.PtrListEnd = &AS.PtrList; assert(*AS.PtrListEnd == nullptr && "End of list is not null?"); } + if (ASHadUnknownInsts) + AS.dropRef(AST); } void AliasSetTracker::removeAliasSet(AliasSet *AS) { @@ -92,7 +97,7 @@ void AliasSet::removeFromTracker(AliasSetTracker &AST) { } void AliasSet::addPointer(AliasSetTracker &AST, PointerRec &Entry, - uint64_t Size, const MDNode *TBAAInfo, + uint64_t Size, const AAMDNodes &AAInfo, bool KnownMustAlias) { assert(!Entry.hasAliasSet() && "Entry already in set!"); @@ -102,17 +107,17 @@ void AliasSet::addPointer(AliasSetTracker &AST, PointerRec &Entry, AliasAnalysis &AA = AST.getAliasAnalysis(); AliasAnalysis::AliasResult Result = AA.alias(AliasAnalysis::Location(P->getValue(), P->getSize(), - P->getTBAAInfo()), - AliasAnalysis::Location(Entry.getValue(), Size, TBAAInfo)); + P->getAAInfo()), + AliasAnalysis::Location(Entry.getValue(), Size, AAInfo)); if (Result != AliasAnalysis::MustAlias) AliasTy = MayAlias; else // First entry of must alias must have maximum size! - P->updateSizeAndTBAAInfo(Size, TBAAInfo); + P->updateSizeAndAAInfo(Size, AAInfo); assert(Result != AliasAnalysis::NoAlias && "Cannot be part of must set!"); } Entry.setAliasSet(this); - Entry.updateSizeAndTBAAInfo(Size, TBAAInfo); + Entry.updateSizeAndAAInfo(Size, AAInfo); // Add it to the end of the list... assert(*PtrListEnd == nullptr && "End of list is not null?"); @@ -123,6 +128,8 @@ void AliasSet::addPointer(AliasSetTracker &AST, PointerRec &Entry, } void AliasSet::addUnknownInst(Instruction *I, AliasAnalysis &AA) { + if (UnknownInsts.empty()) + addRef(); UnknownInsts.push_back(I); if (!I->mayWriteToMemory()) { @@ -140,7 +147,7 @@ void AliasSet::addUnknownInst(Instruction *I, AliasAnalysis &AA) { /// alias one of the members in the set. /// bool AliasSet::aliasesPointer(const Value *Ptr, uint64_t Size, - const MDNode *TBAAInfo, + const AAMDNodes &AAInfo, AliasAnalysis &AA) const { if (AliasTy == MustAlias) { assert(UnknownInsts.empty() && "Illegal must alias set!"); @@ -151,23 +158,23 @@ bool AliasSet::aliasesPointer(const Value *Ptr, uint64_t Size, assert(SomePtr && "Empty must-alias set??"); return AA.alias(AliasAnalysis::Location(SomePtr->getValue(), SomePtr->getSize(), - SomePtr->getTBAAInfo()), - AliasAnalysis::Location(Ptr, Size, TBAAInfo)); + SomePtr->getAAInfo()), + AliasAnalysis::Location(Ptr, Size, AAInfo)); } // If this is a may-alias set, we have to check all of the pointers in the set // to be sure it doesn't alias the set... for (iterator I = begin(), E = end(); I != E; ++I) - if (AA.alias(AliasAnalysis::Location(Ptr, Size, TBAAInfo), + if (AA.alias(AliasAnalysis::Location(Ptr, Size, AAInfo), AliasAnalysis::Location(I.getPointer(), I.getSize(), - I.getTBAAInfo()))) + I.getAAInfo()))) return true; // Check the unknown instructions... 
if (!UnknownInsts.empty()) { for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) if (AA.getModRefInfo(UnknownInsts[i], - AliasAnalysis::Location(Ptr, Size, TBAAInfo)) != + AliasAnalysis::Location(Ptr, Size, AAInfo)) != AliasAnalysis::NoModRef) return true; } @@ -190,7 +197,7 @@ bool AliasSet::aliasesUnknownInst(Instruction *Inst, AliasAnalysis &AA) const { for (iterator I = begin(), E = end(); I != E; ++I) if (AA.getModRefInfo(Inst, AliasAnalysis::Location(I.getPointer(), I.getSize(), - I.getTBAAInfo())) != + I.getAAInfo())) != AliasAnalysis::NoModRef) return true; @@ -216,15 +223,16 @@ void AliasSetTracker::clear() { /// AliasSet *AliasSetTracker::findAliasSetForPointer(const Value *Ptr, uint64_t Size, - const MDNode *TBAAInfo) { + const AAMDNodes &AAInfo) { AliasSet *FoundSet = nullptr; - for (iterator I = begin(), E = end(); I != E; ++I) { - if (I->Forward || !I->aliasesPointer(Ptr, Size, TBAAInfo, AA)) continue; + for (iterator I = begin(), E = end(); I != E;) { + iterator Cur = I++; + if (Cur->Forward || !Cur->aliasesPointer(Ptr, Size, AAInfo, AA)) continue; if (!FoundSet) { // If this is the first alias set ptr can go into. - FoundSet = I; // Remember it. + FoundSet = Cur; // Remember it. } else { // Otherwise, we must merge the sets. - FoundSet->mergeSetIn(*I, *this); // Merge in contents. + FoundSet->mergeSetIn(*Cur, *this); // Merge in contents. } } @@ -235,25 +243,30 @@ AliasSet *AliasSetTracker::findAliasSetForPointer(const Value *Ptr, /// this alias set, false otherwise. This does not modify the AST object or /// alias sets. bool AliasSetTracker::containsPointer(Value *Ptr, uint64_t Size, - const MDNode *TBAAInfo) const { + const AAMDNodes &AAInfo) const { for (const_iterator I = begin(), E = end(); I != E; ++I) - if (!I->Forward && I->aliasesPointer(Ptr, Size, TBAAInfo, AA)) + if (!I->Forward && I->aliasesPointer(Ptr, Size, AAInfo, AA)) return true; return false; } - +bool AliasSetTracker::containsUnknown(Instruction *Inst) const { + for (const_iterator I = begin(), E = end(); I != E; ++I) + if (!I->Forward && I->aliasesUnknownInst(Inst, AA)) + return true; + return false; +} AliasSet *AliasSetTracker::findAliasSetForUnknownInst(Instruction *Inst) { AliasSet *FoundSet = nullptr; - for (iterator I = begin(), E = end(); I != E; ++I) { - if (I->Forward || !I->aliasesUnknownInst(Inst, AA)) + for (iterator I = begin(), E = end(); I != E;) { + iterator Cur = I++; + if (Cur->Forward || !Cur->aliasesUnknownInst(Inst, AA)) continue; - if (!FoundSet) // If this is the first alias set ptr can go into. - FoundSet = I; // Remember it. - else if (!I->Forward) // Otherwise, we must merge the sets. - FoundSet->mergeSetIn(*I, *this); // Merge in contents. + FoundSet = Cur; // Remember it. + else if (!Cur->Forward) // Otherwise, we must merge the sets. + FoundSet->mergeSetIn(*Cur, *this); // Merge in contents. } return FoundSet; } @@ -264,67 +277,75 @@ AliasSet *AliasSetTracker::findAliasSetForUnknownInst(Instruction *Inst) { /// getAliasSetForPointer - Return the alias set that the specified pointer /// lives in. AliasSet &AliasSetTracker::getAliasSetForPointer(Value *Pointer, uint64_t Size, - const MDNode *TBAAInfo, + const AAMDNodes &AAInfo, bool *New) { AliasSet::PointerRec &Entry = getEntryFor(Pointer); // Check to see if the pointer is already known. if (Entry.hasAliasSet()) { - Entry.updateSizeAndTBAAInfo(Size, TBAAInfo); + Entry.updateSizeAndAAInfo(Size, AAInfo); // Return the set! 
return *Entry.getAliasSet(*this)->getForwardedTarget(*this); } - if (AliasSet *AS = findAliasSetForPointer(Pointer, Size, TBAAInfo)) { + if (AliasSet *AS = findAliasSetForPointer(Pointer, Size, AAInfo)) { // Add it to the alias set it aliases. - AS->addPointer(*this, Entry, Size, TBAAInfo); + AS->addPointer(*this, Entry, Size, AAInfo); return *AS; } if (New) *New = true; // Otherwise create a new alias set to hold the loaded pointer. AliasSets.push_back(new AliasSet()); - AliasSets.back().addPointer(*this, Entry, Size, TBAAInfo); + AliasSets.back().addPointer(*this, Entry, Size, AAInfo); return AliasSets.back(); } -bool AliasSetTracker::add(Value *Ptr, uint64_t Size, const MDNode *TBAAInfo) { +bool AliasSetTracker::add(Value *Ptr, uint64_t Size, const AAMDNodes &AAInfo) { bool NewPtr; - addPointer(Ptr, Size, TBAAInfo, AliasSet::NoModRef, NewPtr); + addPointer(Ptr, Size, AAInfo, AliasSet::NoModRef, NewPtr); return NewPtr; } bool AliasSetTracker::add(LoadInst *LI) { if (LI->getOrdering() > Monotonic) return addUnknown(LI); + + AAMDNodes AAInfo; + LI->getAAMetadata(AAInfo); + AliasSet::AccessType ATy = AliasSet::Refs; bool NewPtr; AliasSet &AS = addPointer(LI->getOperand(0), AA.getTypeStoreSize(LI->getType()), - LI->getMetadata(LLVMContext::MD_tbaa), - ATy, NewPtr); + AAInfo, ATy, NewPtr); if (LI->isVolatile()) AS.setVolatile(); return NewPtr; } bool AliasSetTracker::add(StoreInst *SI) { if (SI->getOrdering() > Monotonic) return addUnknown(SI); + + AAMDNodes AAInfo; + SI->getAAMetadata(AAInfo); + AliasSet::AccessType ATy = AliasSet::Mods; bool NewPtr; Value *Val = SI->getOperand(0); AliasSet &AS = addPointer(SI->getOperand(1), AA.getTypeStoreSize(Val->getType()), - SI->getMetadata(LLVMContext::MD_tbaa), - ATy, NewPtr); + AAInfo, ATy, NewPtr); if (SI->isVolatile()) AS.setVolatile(); return NewPtr; } bool AliasSetTracker::add(VAArgInst *VAAI) { + AAMDNodes AAInfo; + VAAI->getAAMetadata(AAInfo); + bool NewPtr; addPointer(VAAI->getOperand(0), AliasAnalysis::UnknownSize, - VAAI->getMetadata(LLVMContext::MD_tbaa), - AliasSet::ModRef, NewPtr); + AAInfo, AliasSet::ModRef, NewPtr); return NewPtr; } @@ -382,7 +403,7 @@ void AliasSetTracker::add(const AliasSetTracker &AST) { bool X; for (AliasSet::iterator ASI = AS.begin(), E = AS.end(); ASI != E; ++ASI) { AliasSet &NewAS = addPointer(ASI.getPointer(), ASI.getSize(), - ASI.getTBAAInfo(), + ASI.getAAInfo(), (AliasSet::AccessType)AS.AccessTy, X); if (AS.isVolatile()) NewAS.setVolatile(); } @@ -393,6 +414,8 @@ void AliasSetTracker::add(const AliasSetTracker &AST) { /// tracker. void AliasSetTracker::remove(AliasSet &AS) { // Drop all call sites. + if (!AS.UnknownInsts.empty()) + AS.dropRef(*this); AS.UnknownInsts.clear(); // Clear the alias set. 
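The change running through the AliasAnalysis.cpp and AliasSetTracker.cpp hunks above swaps the single !tbaa MDNode for an AAMDNodes bundle populated by Instruction::getAAMetadata, so alias.scope and noalias metadata now travel alongside TBAA. A minimal standalone sketch of the new idiom (the helper name is illustrative, not part of the patch):

#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Sketch: build a Location for a load the way this patch now does it,
// attaching the full AA metadata bundle instead of only the !tbaa node.
static AliasAnalysis::Location getLoadLocation(const LoadInst *LI,
                                               AliasAnalysis &AA) {
  AAMDNodes AATags;
  LI->getAAMetadata(AATags); // fills the tbaa, alias.scope, and noalias slots
  return AliasAnalysis::Location(LI->getPointerOperand(),
                                 AA.getTypeStoreSize(LI->getType()), AATags);
}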
@@ -419,8 +442,8 @@ void AliasSetTracker::remove(AliasSet &AS) { } bool -AliasSetTracker::remove(Value *Ptr, uint64_t Size, const MDNode *TBAAInfo) { - AliasSet *AS = findAliasSetForPointer(Ptr, Size, TBAAInfo); +AliasSetTracker::remove(Value *Ptr, uint64_t Size, const AAMDNodes &AAInfo) { + AliasSet *AS = findAliasSetForPointer(Ptr, Size, AAInfo); if (!AS) return false; remove(*AS); return true; @@ -428,8 +451,11 @@ AliasSetTracker::remove(Value *Ptr, uint64_t Size, const MDNode *TBAAInfo) { bool AliasSetTracker::remove(LoadInst *LI) { uint64_t Size = AA.getTypeStoreSize(LI->getType()); - const MDNode *TBAAInfo = LI->getMetadata(LLVMContext::MD_tbaa); - AliasSet *AS = findAliasSetForPointer(LI->getOperand(0), Size, TBAAInfo); + + AAMDNodes AAInfo; + LI->getAAMetadata(AAInfo); + + AliasSet *AS = findAliasSetForPointer(LI->getOperand(0), Size, AAInfo); if (!AS) return false; remove(*AS); return true; @@ -437,17 +463,22 @@ bool AliasSetTracker::remove(LoadInst *LI) { bool AliasSetTracker::remove(StoreInst *SI) { uint64_t Size = AA.getTypeStoreSize(SI->getOperand(0)->getType()); - const MDNode *TBAAInfo = SI->getMetadata(LLVMContext::MD_tbaa); - AliasSet *AS = findAliasSetForPointer(SI->getOperand(1), Size, TBAAInfo); + + AAMDNodes AAInfo; + SI->getAAMetadata(AAInfo); + + AliasSet *AS = findAliasSetForPointer(SI->getOperand(1), Size, AAInfo); if (!AS) return false; remove(*AS); return true; } bool AliasSetTracker::remove(VAArgInst *VAAI) { + AAMDNodes AAInfo; + VAAI->getAAMetadata(AAInfo); + AliasSet *AS = findAliasSetForPointer(VAAI->getOperand(0), - AliasAnalysis::UnknownSize, - VAAI->getMetadata(LLVMContext::MD_tbaa)); + AliasAnalysis::UnknownSize, AAInfo); if (!AS) return false; remove(*AS); return true; @@ -489,10 +520,10 @@ void AliasSetTracker::deleteValue(Value *PtrVal) { if (Instruction *Inst = dyn_cast<Instruction>(PtrVal)) { if (Inst->mayReadOrWriteMemory()) { // Scan all the alias sets to see if this call site is contained. 
- for (iterator I = begin(), E = end(); I != E; ++I) { - if (I->Forward) continue; - - I->removeUnknownInst(Inst); + for (iterator I = begin(), E = end(); I != E;) { + iterator Cur = I++; + if (!Cur->Forward) + Cur->removeUnknownInst(*this, Inst); } } } @@ -536,7 +567,7 @@ void AliasSetTracker::copyValue(Value *From, Value *To) { I = PointerMap.find_as(From); AliasSet *AS = I->second->getAliasSet(*this); AS->addPointer(*this, Entry, I->second->getSize(), - I->second->getTBAAInfo(), + I->second->getAAInfo(), true); } diff --git a/lib/Analysis/Analysis.cpp b/lib/Analysis/Analysis.cpp index ade940a..f64bf0e 100644 --- a/lib/Analysis/Analysis.cpp +++ b/lib/Analysis/Analysis.cpp @@ -34,6 +34,7 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializeCFGPrinterPass(Registry); initializeCFGOnlyViewerPass(Registry); initializeCFGOnlyPrinterPass(Registry); + initializeCFLAliasAnalysisPass(Registry); initializeDependenceAnalysisPass(Registry); initializeDelinearizationPass(Registry); initializeDominanceFrontierPass(Registry); @@ -57,7 +58,7 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializeMemoryDependenceAnalysisPass(Registry); initializeModuleDebugInfoPrinterPass(Registry); initializePostDominatorTreePass(Registry); - initializeRegionInfoPass(Registry); + initializeRegionInfoPassPass(Registry); initializeRegionViewerPass(Registry); initializeRegionPrinterPass(Registry); initializeRegionOnlyViewerPass(Registry); @@ -66,6 +67,7 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializeScalarEvolutionAliasAnalysisPass(Registry); initializeTargetTransformInfoAnalysisGroup(Registry); initializeTypeBasedAliasAnalysisPass(Registry); + initializeScopedNoAliasAAPass(Registry); } void LLVMInitializeAnalysis(LLVMPassRegistryRef R) { diff --git a/lib/Analysis/Android.mk b/lib/Analysis/Android.mk index 4e435a1..8770fa7 100644 --- a/lib/Analysis/Android.mk +++ b/lib/Analysis/Android.mk @@ -7,12 +7,15 @@ analysis_SRC_FILES := \ AliasDebugger.cpp \ AliasSetTracker.cpp \ Analysis.cpp \ + AssumptionTracker.cpp \ BasicAliasAnalysis.cpp \ BlockFrequencyInfo.cpp \ BlockFrequencyInfoImpl.cpp \ BranchProbabilityInfo.cpp \ CFG.cpp \ CFGPrinter.cpp \ + CFLAliasAnalysis.cpp \ + CGSCCPassManager.cpp \ CaptureTracking.cpp \ CodeMetrics.cpp \ ConstantFolding.cpp \ @@ -21,7 +24,7 @@ analysis_SRC_FILES := \ DependenceAnalysis.cpp \ DomPrinter.cpp \ DominanceFrontier.cpp \ - CGSCCPassManager.cpp \ + FunctionTargetTransformInfo.cpp \ IVUsers.cpp \ InstCount.cpp \ InstructionSimplify.cpp \ @@ -51,6 +54,7 @@ analysis_SRC_FILES := \ ScalarEvolutionAliasAnalysis.cpp \ ScalarEvolutionExpander.cpp \ ScalarEvolutionNormalization.cpp \ + ScopedNoAliasAA.cpp \ SparsePropagation.cpp \ TargetTransformInfo.cpp \ Trace.cpp \ diff --git a/lib/Analysis/AssumptionTracker.cpp b/lib/Analysis/AssumptionTracker.cpp new file mode 100644 index 0000000..775ce1d --- /dev/null +++ b/lib/Analysis/AssumptionTracker.cpp @@ -0,0 +1,110 @@ +//===- AssumptionTracker.cpp - Track @llvm.assume -------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains a pass that keeps track of @llvm.assume intrinsics in +// the functions of a module. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AssumptionTracker.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +using namespace llvm; +using namespace llvm::PatternMatch; + +void AssumptionTracker::FunctionCallbackVH::deleted() { + AT->forgetCachedAssumptions(cast<Function>(getValPtr())); + // 'this' now dangles! +} + +void AssumptionTracker::forgetCachedAssumptions(Function *F) { + auto I = CachedAssumeCalls.find_as(F); + if (I != CachedAssumeCalls.end()) + CachedAssumeCalls.erase(I); +} + +void AssumptionTracker::CallCallbackVH::deleted() { + assert(F && "delete callback called on dummy handle"); + FunctionCallsMap::iterator I = AT->CachedAssumeCalls.find_as(F); + assert(I != AT->CachedAssumeCalls.end() && + "Function cleared from the map without removing the values?"); + + I->second->erase(*this); + // 'this' now dangles! +} + +AssumptionTracker::FunctionCallsMap::iterator +AssumptionTracker::scanFunction(Function *F) { + auto IP = CachedAssumeCalls.insert(std::make_pair( + FunctionCallbackVH(F, this), llvm::make_unique<CallHandleSet>())); + assert(IP.second && "Scanning function already in the map?"); + + FunctionCallsMap::iterator I = IP.first; + + // Go through all instructions in all blocks, add all calls to @llvm.assume + // to our cache. + for (BasicBlock &B : *F) + for (Instruction &II : B) + if (match(&II, m_Intrinsic<Intrinsic::assume>())) + I->second->insert(CallCallbackVH(&II, this)); + + return I; +} + +void AssumptionTracker::verifyAnalysis() const { +#ifndef NDEBUG + for (const auto &I : CachedAssumeCalls) { + for (const BasicBlock &B : cast<Function>(*I.first)) + for (const Instruction &II : B) { + if (match(&II, m_Intrinsic<Intrinsic::assume>())) { + assert(I.second->find_as(&II) != I.second->end() && + "Assumption in scanned function not in cache"); + } + } + } +#endif +} + +void AssumptionTracker::registerAssumption(CallInst *CI) { + assert(match(CI, m_Intrinsic<Intrinsic::assume>()) && + "Registered call does not call @llvm.assume"); + assert(CI->getParent() && + "Cannot register @llvm.assume call not in a basic block"); + + Function *F = CI->getParent()->getParent(); + assert(F && "Cannot register @llvm.assume call not in a function"); + + FunctionCallsMap::iterator I = CachedAssumeCalls.find_as(F); + if (I == CachedAssumeCalls.end()) { + // If this function has not already been scanned, then don't do anything + // here. This intrinsic will be found, if it still exists, if the list of + // assumptions in this function is requested at some later point. This + // maintains the following invariant: if a function is present in the + // cache, then its list of assumption intrinsic calls is complete. 
+ return; + } + + I->second->insert(CallCallbackVH(CI, this)); +} + +AssumptionTracker::AssumptionTracker() : ImmutablePass(ID) { + initializeAssumptionTrackerPass(*PassRegistry::getPassRegistry()); +} + +AssumptionTracker::~AssumptionTracker() {} + +INITIALIZE_PASS(AssumptionTracker, "assumption-tracker", "Assumption Tracker", + false, true) +char AssumptionTracker::ID = 0; + diff --git a/lib/Analysis/BasicAliasAnalysis.cpp b/lib/Analysis/BasicAliasAnalysis.cpp index c50dd4a..9aba0d3 100644 --- a/lib/Analysis/BasicAliasAnalysis.cpp +++ b/lib/Analysis/BasicAliasAnalysis.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -156,17 +157,6 @@ static bool isObjectSize(const Value *V, uint64_t Size, return ObjectSize != AliasAnalysis::UnknownSize && ObjectSize == Size; } -/// isIdentifiedFunctionLocal - Return true if V is umabigously identified -/// at the function-level. Different IdentifiedFunctionLocals can't alias. -/// Further, an IdentifiedFunctionLocal can not alias with any function -/// arguments other than itself, which is not necessarily true for -/// IdentifiedObjects. -static bool isIdentifiedFunctionLocal(const Value *V) -{ - return isa<AllocaInst>(V) || isNoAliasCall(V) || isNoAliasArgument(V); -} - - //===----------------------------------------------------------------------===// // GetElementPtr Instruction Decomposition and Analysis //===----------------------------------------------------------------------===// @@ -205,7 +195,9 @@ namespace { /// represented in the result. static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, ExtensionKind &Extension, - const DataLayout &DL, unsigned Depth) { + const DataLayout &DL, unsigned Depth, + AssumptionTracker *AT, + DominatorTree *DT) { assert(V->getType()->isIntegerTy() && "Not an integer value"); // Limit our recursion depth. @@ -215,6 +207,14 @@ static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, return V; } + if (ConstantInt *Const = dyn_cast<ConstantInt>(V)) { + // if it's a constant, just convert it to an offset + // and remove the variable. + Offset += Const->getValue(); + assert(Scale == 0 && "Constant values don't have a scale"); + return V; + } + if (BinaryOperator *BOp = dyn_cast<BinaryOperator>(V)) { if (ConstantInt *RHSC = dyn_cast<ConstantInt>(BOp->getOperand(1))) { switch (BOp->getOpcode()) { @@ -222,23 +222,24 @@ static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, case Instruction::Or: // X|C == X+C if all the bits in C are unset in X. Otherwise we can't // analyze it. - if (!MaskedValueIsZero(BOp->getOperand(0), RHSC->getValue(), &DL)) + if (!MaskedValueIsZero(BOp->getOperand(0), RHSC->getValue(), &DL, 0, + AT, BOp, DT)) break; // FALL THROUGH. 
case Instruction::Add: V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, Extension, - DL, Depth+1); + DL, Depth+1, AT, DT); Offset += RHSC->getValue(); return V; case Instruction::Mul: V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, Extension, - DL, Depth+1); + DL, Depth+1, AT, DT); Offset *= RHSC->getValue(); Scale *= RHSC->getValue(); return V; case Instruction::Shl: V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, Extension, - DL, Depth+1); + DL, Depth+1, AT, DT); Offset <<= RHSC->getValue().getLimitedValue(); Scale <<= RHSC->getValue().getLimitedValue(); return V; @@ -259,9 +260,12 @@ static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, Extension = isa<SExtInst>(V) ? EK_SignExt : EK_ZeroExt; Value *Result = GetLinearExpression(CastOp, Scale, Offset, Extension, - DL, Depth+1); + DL, Depth+1, AT, DT); Scale = Scale.zext(OldWidth); - Offset = Offset.zext(OldWidth); + + // We have to sign-extend even if Extension == EK_ZeroExt as we can't + // decompose a sign extension (i.e. zext(x - 1) != zext(x) - zext(-1)). + Offset = Offset.sext(OldWidth); return Result; } @@ -289,7 +293,8 @@ static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, static const Value * DecomposeGEPExpression(const Value *V, int64_t &BaseOffs, SmallVectorImpl<VariableGEPIndex> &VarIndices, - bool &MaxLookupReached, const DataLayout *DL) { + bool &MaxLookupReached, const DataLayout *DL, + AssumptionTracker *AT, DominatorTree *DT) { // Limit recursion depth to limit compile time in crazy cases. unsigned MaxLookup = MaxLookupSearchDepth; MaxLookupReached = false; @@ -309,7 +314,8 @@ DecomposeGEPExpression(const Value *V, int64_t &BaseOffs, return V; } - if (Op->getOpcode() == Instruction::BitCast) { + if (Op->getOpcode() == Instruction::BitCast || + Op->getOpcode() == Instruction::AddrSpaceCast) { V = Op->getOperand(0); continue; } @@ -319,7 +325,10 @@ DecomposeGEPExpression(const Value *V, int64_t &BaseOffs, // If it's not a GEP, hand it off to SimplifyInstruction to see if it // can come up with something. This matches what GetUnderlyingObject does. if (const Instruction *I = dyn_cast<Instruction>(V)) - // TODO: Get a DominatorTree and use it here. + // TODO: Get a DominatorTree and AssumptionTracker and use them here + // (these are both now available in this function, but this should be + // updated when GetUnderlyingObject is updated). TLI should be + // provided also. if (const Value *Simplified = SimplifyInstruction(const_cast<Instruction *>(I), DL)) { V = Simplified; @@ -378,7 +387,7 @@ DecomposeGEPExpression(const Value *V, int64_t &BaseOffs, // Use GetLinearExpression to decompose the index into a C1*V+C2 form. APInt IndexScale(Width, 0), IndexOffset(Width, 0); Index = GetLinearExpression(Index, IndexScale, IndexOffset, Extension, - *DL, 0); + *DL, 0, AT, DT); // The GEP index scale ("Scale") scales C1*V+C2, yielding (C1*V+C2)*Scale. // This gives us an aggregate computation of (C1*Scale)*V + C2*Scale. 
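The sign-extension comment in the hunk above is easy to misread; here is a tiny standalone check (an illustration, not part of the patch) of the zext(x - 1) != zext(x) - zext(-1) counterexample using APInt:

#include "llvm/ADT/APInt.h"
#include <cassert>
using namespace llvm;

// For an i8 x == 0 widened to i16: zext16(x - 1) is 255, while
// zext16(x) - zext16(-1) wraps to 65281. Zero-extension does not
// distribute over the decomposed subtraction, which is why the
// decomposed offset above must be sign-extended instead.
static void zextDoesNotDistribute() {
  APInt X(8, 0), One(8, 1);
  APInt Whole = (X - One).zext(16);           // 255
  APInt Split = X.zext(16) - (-One).zext(16); // 65281
  assert(Whole != Split && "cannot decompose through a zext");
}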
@@ -459,6 +468,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AliasAnalysis>(); + AU.addRequired<AssumptionTracker>(); AU.addRequired<TargetLibraryInfo>(); } @@ -466,8 +476,8 @@ namespace { assert(AliasCache.empty() && "AliasCache must be cleared after use!"); assert(notDifferentParent(LocA.Ptr, LocB.Ptr) && "BasicAliasAnalysis doesn't support interprocedural queries."); - AliasResult Alias = aliasCheck(LocA.Ptr, LocA.Size, LocA.TBAATag, - LocB.Ptr, LocB.Size, LocB.TBAATag); + AliasResult Alias = aliasCheck(LocA.Ptr, LocA.Size, LocA.AATags, + LocB.Ptr, LocB.Size, LocB.AATags); // AliasCache rarely has more than 1 or 2 elements, always use // shrink_and_clear so it quickly returns to the inline capacity of the // SmallDenseMap if it ever grows larger. @@ -481,10 +491,7 @@ namespace { const Location &Loc) override; ModRefResult getModRefInfo(ImmutableCallSite CS1, - ImmutableCallSite CS2) override { - // The AliasAnalysis base class has some smarts, lets use them. - return AliasAnalysis::getModRefInfo(CS1, CS2); - } + ImmutableCallSite CS2) override; /// pointsToConstantMemory - Chase pointers until we find a (constant /// global) or not. @@ -554,28 +561,28 @@ namespace { // aliasGEP - Provide a bunch of ad-hoc rules to disambiguate a GEP // instruction against another. AliasResult aliasGEP(const GEPOperator *V1, uint64_t V1Size, - const MDNode *V1TBAAInfo, + const AAMDNodes &V1AAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo, + const AAMDNodes &V2AAInfo, const Value *UnderlyingV1, const Value *UnderlyingV2); // aliasPHI - Provide a bunch of ad-hoc rules to disambiguate a PHI // instruction against another. AliasResult aliasPHI(const PHINode *PN, uint64_t PNSize, - const MDNode *PNTBAAInfo, + const AAMDNodes &PNAAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo); + const AAMDNodes &V2AAInfo); /// aliasSelect - Disambiguate a Select instruction against another value. 
AliasResult aliasSelect(const SelectInst *SI, uint64_t SISize, - const MDNode *SITBAAInfo, + const AAMDNodes &SIAAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo); + const AAMDNodes &V2AAInfo); AliasResult aliasCheck(const Value *V1, uint64_t V1Size, - const MDNode *V1TBAATag, + AAMDNodes V1AATag, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAATag); + AAMDNodes V2AATag); }; } // End of anonymous namespace @@ -584,6 +591,7 @@ char BasicAliasAnalysis::ID = 0; INITIALIZE_AG_PASS_BEGIN(BasicAliasAnalysis, AliasAnalysis, "basicaa", "Basic Alias Analysis (stateless AA impl)", false, true, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_AG_PASS_END(BasicAliasAnalysis, AliasAnalysis, "basicaa", "Basic Alias Analysis (stateless AA impl)", @@ -606,7 +614,7 @@ BasicAliasAnalysis::pointsToConstantMemory(const Location &Loc, bool OrLocal) { Worklist.push_back(Loc.Ptr); do { const Value *V = GetUnderlyingObject(Worklist.pop_back_val(), DL); - if (!Visited.insert(V)) { + if (!Visited.insert(V).second) { Visited.clear(); return AliasAnalysis::pointsToConstantMemory(Loc, OrLocal); } @@ -798,6 +806,14 @@ BasicAliasAnalysis::getArgLocation(ImmutableCallSite CS, unsigned ArgIdx, return Loc; } +static bool isAssumeIntrinsic(ImmutableCallSite CS) { + const IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction()); + if (II && II->getIntrinsicID() == Intrinsic::assume) + return true; + + return false; +} + /// getModRefInfo - Check to see if the specified callsite can clobber the /// specified memory object. Since we only look at local properties of this /// function, we really can't say much about this query. We do, however, use @@ -850,10 +866,29 @@ BasicAliasAnalysis::getModRefInfo(ImmutableCallSite CS, return NoModRef; } + // While the assume intrinsic is marked as arbitrarily writing so that + // proper control dependencies will be maintained, it never aliases any + // particular memory location. + if (isAssumeIntrinsic(CS)) + return NoModRef; + // The AliasAnalysis base class has some smarts, lets use them. return AliasAnalysis::getModRefInfo(CS, Loc); } +AliasAnalysis::ModRefResult +BasicAliasAnalysis::getModRefInfo(ImmutableCallSite CS1, + ImmutableCallSite CS2) { + // While the assume intrinsic is marked as arbitrarily writing so that + // proper control dependencies will be maintained, it never aliases any + // particular memory location. + if (isAssumeIntrinsic(CS1) || isAssumeIntrinsic(CS2)) + return NoModRef; + + // The AliasAnalysis base class has some smarts, lets use them. + return AliasAnalysis::getModRefInfo(CS1, CS2); +} + /// aliasGEP - Provide a bunch of ad-hoc rules to disambiguate a GEP instruction /// against another pointer. We know that V1 is a GEP, but we don't know /// anything about V2. UnderlyingV1 is GetUnderlyingObject(GEP1, DL), @@ -861,30 +896,35 @@ BasicAliasAnalysis::getModRefInfo(ImmutableCallSite CS, /// AliasAnalysis::AliasResult BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, - const MDNode *V1TBAAInfo, + const AAMDNodes &V1AAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo, + const AAMDNodes &V2AAInfo, const Value *UnderlyingV1, const Value *UnderlyingV2) { int64_t GEP1BaseOffset; bool GEP1MaxLookupReached; SmallVector<VariableGEPIndex, 4> GEP1VariableIndices; + AssumptionTracker *AT = &getAnalysis<AssumptionTracker>(); + DominatorTreeWrapperPass *DTWP = + getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + DominatorTree *DT = DTWP ? 
&DTWP->getDomTree() : nullptr; + // If we have two gep instructions with must-alias or not-alias'ing base // pointers, figure out if the indexes to the GEP tell us anything about the // derived pointer. if (const GEPOperator *GEP2 = dyn_cast<GEPOperator>(V2)) { // Do the base pointers alias? - AliasResult BaseAlias = aliasCheck(UnderlyingV1, UnknownSize, nullptr, - UnderlyingV2, UnknownSize, nullptr); + AliasResult BaseAlias = aliasCheck(UnderlyingV1, UnknownSize, AAMDNodes(), + UnderlyingV2, UnknownSize, AAMDNodes()); // Check for geps of non-aliasing underlying pointers where the offsets are // identical. if ((BaseAlias == MayAlias) && V1Size == V2Size) { // Do the base pointers alias assuming type and size. AliasResult PreciseBaseAlias = aliasCheck(UnderlyingV1, V1Size, - V1TBAAInfo, UnderlyingV2, - V2Size, V2TBAAInfo); + V1AAInfo, UnderlyingV2, + V2Size, V2AAInfo); if (PreciseBaseAlias == NoAlias) { // See if the computed offset from the common pointer tells us about the // relation of the resulting pointer. @@ -893,10 +933,10 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, SmallVector<VariableGEPIndex, 4> GEP2VariableIndices; const Value *GEP2BasePtr = DecomposeGEPExpression(GEP2, GEP2BaseOffset, GEP2VariableIndices, - GEP2MaxLookupReached, DL); + GEP2MaxLookupReached, DL, AT, DT); const Value *GEP1BasePtr = DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, - GEP1MaxLookupReached, DL); + GEP1MaxLookupReached, DL, AT, DT); // DecomposeGEPExpression and GetUnderlyingObject should return the // same result except when DecomposeGEPExpression has no DataLayout. if (GEP1BasePtr != UnderlyingV1 || GEP2BasePtr != UnderlyingV2) { @@ -925,14 +965,14 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, // about the relation of the resulting pointer. const Value *GEP1BasePtr = DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, - GEP1MaxLookupReached, DL); + GEP1MaxLookupReached, DL, AT, DT); int64_t GEP2BaseOffset; bool GEP2MaxLookupReached; SmallVector<VariableGEPIndex, 4> GEP2VariableIndices; const Value *GEP2BasePtr = DecomposeGEPExpression(GEP2, GEP2BaseOffset, GEP2VariableIndices, - GEP2MaxLookupReached, DL); + GEP2MaxLookupReached, DL, AT, DT); // DecomposeGEPExpression and GetUnderlyingObject should return the // same result except when DecomposeGEPExpression has no DataLayout. @@ -959,8 +999,8 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, if (V1Size == UnknownSize && V2Size == UnknownSize) return MayAlias; - AliasResult R = aliasCheck(UnderlyingV1, UnknownSize, nullptr, - V2, V2Size, V2TBAAInfo); + AliasResult R = aliasCheck(UnderlyingV1, UnknownSize, AAMDNodes(), + V2, V2Size, V2AAInfo); if (R != MustAlias) // If V2 may alias GEP base pointer, conservatively returns MayAlias. // If V2 is known not to alias GEP base pointer, then the two values @@ -971,7 +1011,7 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, const Value *GEP1BasePtr = DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, - GEP1MaxLookupReached, DL); + GEP1MaxLookupReached, DL, AT, DT); // DecomposeGEPExpression and GetUnderlyingObject should return the // same result except when DecomposeGEPExpression has no DataLayout. @@ -1022,12 +1062,45 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, } } - // Try to distinguish something like &A[i][1] against &A[42][0]. - // Grab the least significant bit set in any of the scales. 
if (!GEP1VariableIndices.empty()) { uint64_t Modulo = 0; - for (unsigned i = 0, e = GEP1VariableIndices.size(); i != e; ++i) - Modulo |= (uint64_t)GEP1VariableIndices[i].Scale; + bool AllPositive = true; + for (unsigned i = 0, e = GEP1VariableIndices.size(); i != e; ++i) { + + // Try to distinguish something like &A[i][1] against &A[42][0]. + // Grab the least significant bit set in any of the scales. We + // don't need std::abs here (even if the scale's negative) as we'll + // be ^'ing Modulo with itself later. + Modulo |= (uint64_t) GEP1VariableIndices[i].Scale; + + if (AllPositive) { + // If the Value could change between cycles, then any reasoning about + // the Value this cycle may not hold in the next cycle. We'll just + // give up if we can't determine conditions that hold for every cycle: + const Value *V = GEP1VariableIndices[i].V; + + bool SignKnownZero, SignKnownOne; + ComputeSignBit( + const_cast<Value *>(V), + SignKnownZero, SignKnownOne, + DL, 0, AT, nullptr, DT); + + // Zero-extension widens the variable, and so forces the sign + // bit to zero. + bool IsZExt = GEP1VariableIndices[i].Extension == EK_ZeroExt; + SignKnownZero |= IsZExt; + SignKnownOne &= !IsZExt; + + // If the variable begins with a zero then we know it's + // positive, regardless of whether the value is signed or + // unsigned. + int64_t Scale = GEP1VariableIndices[i].Scale; + AllPositive = + (SignKnownZero && Scale >= 0) || + (SignKnownOne && Scale < 0); + } + } + Modulo = Modulo ^ (Modulo & (Modulo - 1)); // We can compute the difference between the two addresses @@ -1037,6 +1110,12 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, if (V1Size != UnknownSize && V2Size != UnknownSize && ModOffset >= V2Size && V1Size <= Modulo - ModOffset) return NoAlias; + + // If we know all the variables are positive, then GEP1 >= GEP1BasePtr. + // If GEP1BasePtr > V2 (GEP1BaseOffset > 0) then we know the pointers + // don't alias if V2Size can fit in the gap between V2 and GEP1BasePtr. + if (AllPositive && GEP1BaseOffset > 0 && V2Size <= (uint64_t) GEP1BaseOffset) + return NoAlias; } // Statically, we can see that the base objects are the same, but the @@ -1066,33 +1145,33 @@ MergeAliasResults(AliasAnalysis::AliasResult A, AliasAnalysis::AliasResult B) { /// instruction against another. AliasAnalysis::AliasResult BasicAliasAnalysis::aliasSelect(const SelectInst *SI, uint64_t SISize, - const MDNode *SITBAAInfo, + const AAMDNodes &SIAAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo) { + const AAMDNodes &V2AAInfo) { // If the values are Selects with the same condition, we can do a more precise // check: just check for aliases between the values on corresponding arms. if (const SelectInst *SI2 = dyn_cast<SelectInst>(V2)) if (SI->getCondition() == SI2->getCondition()) { AliasResult Alias = - aliasCheck(SI->getTrueValue(), SISize, SITBAAInfo, - SI2->getTrueValue(), V2Size, V2TBAAInfo); + aliasCheck(SI->getTrueValue(), SISize, SIAAInfo, + SI2->getTrueValue(), V2Size, V2AAInfo); if (Alias == MayAlias) return MayAlias; AliasResult ThisAlias = - aliasCheck(SI->getFalseValue(), SISize, SITBAAInfo, - SI2->getFalseValue(), V2Size, V2TBAAInfo); + aliasCheck(SI->getFalseValue(), SISize, SIAAInfo, + SI2->getFalseValue(), V2Size, V2AAInfo); return MergeAliasResults(ThisAlias, Alias); } // If both arms of the Select node NoAlias or MustAlias V2, then returns // NoAlias / MustAlias. Otherwise, returns MayAlias. 
AliasResult Alias = - aliasCheck(V2, V2Size, V2TBAAInfo, SI->getTrueValue(), SISize, SITBAAInfo); + aliasCheck(V2, V2Size, V2AAInfo, SI->getTrueValue(), SISize, SIAAInfo); if (Alias == MayAlias) return MayAlias; AliasResult ThisAlias = - aliasCheck(V2, V2Size, V2TBAAInfo, SI->getFalseValue(), SISize, SITBAAInfo); + aliasCheck(V2, V2Size, V2AAInfo, SI->getFalseValue(), SISize, SIAAInfo); return MergeAliasResults(ThisAlias, Alias); } @@ -1100,9 +1179,9 @@ BasicAliasAnalysis::aliasSelect(const SelectInst *SI, uint64_t SISize, // against another. AliasAnalysis::AliasResult BasicAliasAnalysis::aliasPHI(const PHINode *PN, uint64_t PNSize, - const MDNode *PNTBAAInfo, + const AAMDNodes &PNAAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo) { + const AAMDNodes &V2AAInfo) { // Track phi nodes we have visited. We use this information when we determine // value equivalence. VisitedPhiBBs.insert(PN->getParent()); @@ -1112,8 +1191,8 @@ BasicAliasAnalysis::aliasPHI(const PHINode *PN, uint64_t PNSize, // on corresponding edges. if (const PHINode *PN2 = dyn_cast<PHINode>(V2)) if (PN2->getParent() == PN->getParent()) { - LocPair Locs(Location(PN, PNSize, PNTBAAInfo), - Location(V2, V2Size, V2TBAAInfo)); + LocPair Locs(Location(PN, PNSize, PNAAInfo), + Location(V2, V2Size, V2AAInfo)); if (PN > V2) std::swap(Locs.first, Locs.second); // Analyse the PHIs' inputs under the assumption that the PHIs are @@ -1131,9 +1210,9 @@ BasicAliasAnalysis::aliasPHI(const PHINode *PN, uint64_t PNSize, for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { AliasResult ThisAlias = - aliasCheck(PN->getIncomingValue(i), PNSize, PNTBAAInfo, + aliasCheck(PN->getIncomingValue(i), PNSize, PNAAInfo, PN2->getIncomingValueForBlock(PN->getIncomingBlock(i)), - V2Size, V2TBAAInfo); + V2Size, V2AAInfo); Alias = MergeAliasResults(ThisAlias, Alias); if (Alias == MayAlias) break; @@ -1156,12 +1235,12 @@ BasicAliasAnalysis::aliasPHI(const PHINode *PN, uint64_t PNSize, // sides are PHI nodes. In which case, this is O(m x n) time where 'm' // and 'n' are the number of PHI sources. return MayAlias; - if (UniqueSrc.insert(PV1)) + if (UniqueSrc.insert(PV1).second) V1Srcs.push_back(PV1); } - AliasResult Alias = aliasCheck(V2, V2Size, V2TBAAInfo, - V1Srcs[0], PNSize, PNTBAAInfo); + AliasResult Alias = aliasCheck(V2, V2Size, V2AAInfo, + V1Srcs[0], PNSize, PNAAInfo); // Early exit if the check of the first PHI source against V2 is MayAlias. // Other results are not possible. if (Alias == MayAlias) @@ -1172,8 +1251,8 @@ BasicAliasAnalysis::aliasPHI(const PHINode *PN, uint64_t PNSize, for (unsigned i = 1, e = V1Srcs.size(); i != e; ++i) { Value *V = V1Srcs[i]; - AliasResult ThisAlias = aliasCheck(V2, V2Size, V2TBAAInfo, - V, PNSize, PNTBAAInfo); + AliasResult ThisAlias = aliasCheck(V2, V2Size, V2AAInfo, + V, PNSize, PNAAInfo); Alias = MergeAliasResults(ThisAlias, Alias); if (Alias == MayAlias) break; @@ -1187,9 +1266,9 @@ BasicAliasAnalysis::aliasPHI(const PHINode *PN, uint64_t PNSize, // AliasAnalysis::AliasResult BasicAliasAnalysis::aliasCheck(const Value *V1, uint64_t V1Size, - const MDNode *V1TBAAInfo, + AAMDNodes V1AAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo) { + AAMDNodes V2AAInfo) { // If either of the memory references is empty, it doesn't matter what the // pointer values are. if (V1Size == 0 || V2Size == 0) @@ -1269,8 +1348,8 @@ BasicAliasAnalysis::aliasCheck(const Value *V1, uint64_t V1Size, // Check the cache before climbing up use-def chains. 
This also terminates // otherwise infinitely recursive queries. - LocPair Locs(Location(V1, V1Size, V1TBAAInfo), - Location(V2, V2Size, V2TBAAInfo)); + LocPair Locs(Location(V1, V1Size, V1AAInfo), + Location(V2, V2Size, V2AAInfo)); if (V1 > V2) std::swap(Locs.first, Locs.second); std::pair<AliasCacheTy::iterator, bool> Pair = @@ -1284,32 +1363,32 @@ BasicAliasAnalysis::aliasCheck(const Value *V1, uint64_t V1Size, std::swap(V1, V2); std::swap(V1Size, V2Size); std::swap(O1, O2); - std::swap(V1TBAAInfo, V2TBAAInfo); + std::swap(V1AAInfo, V2AAInfo); } if (const GEPOperator *GV1 = dyn_cast<GEPOperator>(V1)) { - AliasResult Result = aliasGEP(GV1, V1Size, V1TBAAInfo, V2, V2Size, V2TBAAInfo, O1, O2); + AliasResult Result = aliasGEP(GV1, V1Size, V1AAInfo, V2, V2Size, V2AAInfo, O1, O2); if (Result != MayAlias) return AliasCache[Locs] = Result; } if (isa<PHINode>(V2) && !isa<PHINode>(V1)) { std::swap(V1, V2); std::swap(V1Size, V2Size); - std::swap(V1TBAAInfo, V2TBAAInfo); + std::swap(V1AAInfo, V2AAInfo); } if (const PHINode *PN = dyn_cast<PHINode>(V1)) { - AliasResult Result = aliasPHI(PN, V1Size, V1TBAAInfo, - V2, V2Size, V2TBAAInfo); + AliasResult Result = aliasPHI(PN, V1Size, V1AAInfo, + V2, V2Size, V2AAInfo); if (Result != MayAlias) return AliasCache[Locs] = Result; } if (isa<SelectInst>(V2) && !isa<SelectInst>(V1)) { std::swap(V1, V2); std::swap(V1Size, V2Size); - std::swap(V1TBAAInfo, V2TBAAInfo); + std::swap(V1AAInfo, V2AAInfo); } if (const SelectInst *S1 = dyn_cast<SelectInst>(V1)) { - AliasResult Result = aliasSelect(S1, V1Size, V1TBAAInfo, - V2, V2Size, V2TBAAInfo); + AliasResult Result = aliasSelect(S1, V1Size, V1AAInfo, + V2, V2Size, V2AAInfo); if (Result != MayAlias) return AliasCache[Locs] = Result; } @@ -1322,8 +1401,8 @@ BasicAliasAnalysis::aliasCheck(const Value *V1, uint64_t V1Size, return AliasCache[Locs] = PartialAlias; AliasResult Result = - AliasAnalysis::alias(Location(V1, V1Size, V1TBAAInfo), - Location(V2, V2Size, V2TBAAInfo)); + AliasAnalysis::alias(Location(V1, V1Size, V1AAInfo), + Location(V2, V2Size, V2AAInfo)); return AliasCache[Locs] = Result; } @@ -1348,10 +1427,8 @@ bool BasicAliasAnalysis::isValueEqualInPotentialCycles(const Value *V, // Make sure that the visited phis cannot reach the Value. This ensures that // the Values cannot come from different iterations of a potential cycle the // phi nodes could be involved in. - for (SmallPtrSet<const BasicBlock *, 8>::iterator PI = VisitedPhiBBs.begin(), - PE = VisitedPhiBBs.end(); - PI != PE; ++PI) - if (isPotentiallyReachable((*PI)->begin(), Inst, DT, LI)) + for (auto *P : VisitedPhiBBs) + if (isPotentiallyReachable(P->begin(), Inst, DT, LI)) return false; return true; diff --git a/lib/Analysis/BlockFrequencyInfoImpl.cpp b/lib/Analysis/BlockFrequencyInfoImpl.cpp index 4fd2c11..06b8acd 100644 --- a/lib/Analysis/BlockFrequencyInfoImpl.cpp +++ b/lib/Analysis/BlockFrequencyInfoImpl.cpp @@ -14,18 +14,12 @@ #include "llvm/Analysis/BlockFrequencyInfoImpl.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/Support/raw_ostream.h" -#include <deque> using namespace llvm; using namespace llvm::bfi_detail; #define DEBUG_TYPE "block-freq" -//===----------------------------------------------------------------------===// -// -// BlockMass implementation. 
-// -//===----------------------------------------------------------------------===// ScaledNumber<uint64_t> BlockMass::toScaled() const { if (isFull()) return ScaledNumber<uint64_t>(1, 0); @@ -46,11 +40,6 @@ raw_ostream &BlockMass::print(raw_ostream &OS) const { return OS; } -//===----------------------------------------------------------------------===// -// -// BlockFrequencyInfoImpl implementation. -// -//===----------------------------------------------------------------------===// namespace { typedef BlockFrequencyInfoImplBase::BlockNode BlockNode; @@ -87,7 +76,8 @@ struct DitheringDistributer { BlockMass takeMass(uint32_t Weight); }; -} + +} // end namespace DitheringDistributer::DitheringDistributer(Distribution &Dist, const BlockMass &Mass) { @@ -121,11 +111,7 @@ void Distribution::add(const BlockNode &Node, uint64_t Amount, Total = NewTotal; // Save the weight. - Weight W; - W.TargetNode = Node; - W.Amount = Amount; - W.Type = Type; - Weights.push_back(W); + Weights.push_back(Weight(Type, Node, Amount)); } static void combineWeight(Weight &W, const Weight &OtherW) { @@ -615,7 +601,8 @@ static void findIrreducibleHeaders( break; } } - assert(Headers.size() >= 2 && "Should be irreducible"); + assert(Headers.size() >= 2 && + "Expected irreducible CFG; -loop-info is likely invalid"); if (Headers.size() == InSCC.size()) { // Every block is a header. std::sort(Headers.begin(), Headers.end()); diff --git a/lib/Analysis/CFG.cpp b/lib/Analysis/CFG.cpp index 8ef5302..25e7bc0 100644 --- a/lib/Analysis/CFG.cpp +++ b/lib/Analysis/CFG.cpp @@ -45,7 +45,7 @@ void llvm::FindFunctionBackedges(const Function &F, bool FoundNew = false; while (I != succ_end(ParentBB)) { BB = *I++; - if (Visited.insert(BB)) { + if (Visited.insert(BB).second) { FoundNew = true; break; } @@ -141,7 +141,7 @@ static bool isPotentiallyReachableInner(SmallVectorImpl<BasicBlock *> &Worklist, SmallSet<const BasicBlock*, 64> Visited; do { BasicBlock *BB = Worklist.pop_back_val(); - if (!Visited.insert(BB)) + if (!Visited.insert(BB).second) continue; if (BB == StopBB) return true; diff --git a/lib/Analysis/CFGPrinter.cpp b/lib/Analysis/CFGPrinter.cpp index c2c19d6..89787f82 100644 --- a/lib/Analysis/CFGPrinter.cpp +++ b/lib/Analysis/CFGPrinter.cpp @@ -79,11 +79,11 @@ namespace { bool runOnFunction(Function &F) override { std::string Filename = "cfg." + F.getName().str() + ".dot"; errs() << "Writing '" << Filename << "'..."; - - std::string ErrorInfo; - raw_fd_ostream File(Filename.c_str(), ErrorInfo, sys::fs::F_Text); - if (ErrorInfo.empty()) + std::error_code EC; + raw_fd_ostream File(Filename, EC, sys::fs::F_Text); + + if (!EC) WriteGraph(File, (const Function*)&F); else errs() << " error opening file for writing!"; @@ -114,10 +114,10 @@ namespace { std::string Filename = "cfg." 
+ F.getName().str() + ".dot"; errs() << "Writing '" << Filename << "'..."; - std::string ErrorInfo; - raw_fd_ostream File(Filename.c_str(), ErrorInfo, sys::fs::F_Text); - - if (ErrorInfo.empty()) + std::error_code EC; + raw_fd_ostream File(Filename, EC, sys::fs::F_Text); + + if (!EC) WriteGraph(File, (const Function*)&F, true); else errs() << " error opening file for writing!"; diff --git a/lib/Analysis/CFLAliasAnalysis.cpp b/lib/Analysis/CFLAliasAnalysis.cpp new file mode 100644 index 0000000..5f1b3d3 --- /dev/null +++ b/lib/Analysis/CFLAliasAnalysis.cpp @@ -0,0 +1,1013 @@ +//===- CFLAliasAnalysis.cpp - CFL-Based Alias Analysis Implementation ------==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a CFL-based context-insensitive alias analysis +// algorithm. It does not depend on types. The algorithm is a mixture of the one +// described in "Demand-driven alias analysis for C" by Xin Zheng and Radu +// Rugina, and "Fast algorithms for Dyck-CFL-reachability with applications to +// Alias Analysis" by Zhang Q, Lyu M R, Yuan H, and Su Z. -- to summarize the +// papers, we build a graph of the uses of a variable, where each node is a +// memory location, and each edge is an action that happened on that memory +// location. The "actions" can be one of Dereference, Reference, or +// Assign. + +// Two variables are considered as aliasing iff you can reach one value's node +// from the other value's node and the language formed by concatenating all of +// the edge labels (actions) conforms to a context-free grammar. +// +// Because this algorithm requires a graph search on each query, we execute the +// algorithm outlined in "Fast algorithms..." (mentioned above) +// in order to transform the graph into sets of variables that may alias in +// ~nlogn time (n = number of variables), which makes queries take constant +// time. +//===----------------------------------------------------------------------===// +
#include "StratifiedSets.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/None.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/ErrorHandling.h" +#include <algorithm> +#include <cassert> +#include <forward_list> +#include <tuple> + +using namespace llvm; + +// Try to go from a Value* to a Function*. Never returns nullptr. +static Optional<Function *> parentFunctionOfValue(Value *); + +// Returns possible functions called by the Inst* into the given +// SmallVectorImpl. Returns true if targets found, false otherwise. +// This is templated because InvokeInst/CallInst give us the same +// set of functions that we care about, and I don't like repeating +// myself. +template <typename Inst> +static bool getPossibleTargets(Inst *, SmallVectorImpl<Function *> &); + +// Some instructions need to have their users tracked.
Instructions like +// `add` require you to get the users of the Instruction* itself; other +// instructions like `store` require you to get the users of the first +// operand. This function gets the "proper" value to track for each +// type of instruction we support. +static Optional<Value *> getTargetValue(Instruction *); + +// There are certain instructions (e.g. FenceInst) that we ignore. +// This notes that we should ignore those. +static bool hasUsefulEdges(Instruction *); + +const StratifiedIndex StratifiedLink::SetSentinel = + std::numeric_limits<StratifiedIndex>::max(); + +namespace { +// StratifiedInfo Attribute things. +typedef unsigned StratifiedAttr; +LLVM_CONSTEXPR unsigned MaxStratifiedAttrIndex = NumStratifiedAttrs; +LLVM_CONSTEXPR unsigned AttrAllIndex = 0; +LLVM_CONSTEXPR unsigned AttrGlobalIndex = 1; +LLVM_CONSTEXPR unsigned AttrFirstArgIndex = 2; +LLVM_CONSTEXPR unsigned AttrLastArgIndex = MaxStratifiedAttrIndex; +LLVM_CONSTEXPR unsigned AttrMaxNumArgs = AttrLastArgIndex - AttrFirstArgIndex; + +LLVM_CONSTEXPR StratifiedAttr AttrNone = 0; +LLVM_CONSTEXPR StratifiedAttr AttrAll = ~AttrNone; + +// \brief StratifiedSets call for knowledge of "direction", so this is how we +// represent that locally. +enum class Level { Same, Above, Below }; + +// \brief Edges can be one of three "weights" -- each weight must have an inverse +// weight (Assign has Assign; Reference has Dereference). +enum class EdgeType { + // The weight assigned when assigning from or to a value. For example, in: + // %b = getelementptr %a, 0 + // ...The relationships are %b assign %a, and %a assign %b. This used to be + // two edges, but having a distinction bought us nothing. + Assign, + + // The edge used when we have an edge going from some handle to a Value. + // Examples of this include: + // %b = load %a (%b Dereference %a) + // %b = extractelement %a, 0 (%a Dereference %b) + Dereference, + + // The edge used when our edge goes from a value to a handle that may have + // contained it at some point. Examples: + // %b = load %a (%a Reference %b) + // %b = extractelement %a, 0 (%b Reference %a) + Reference +}; + +// \brief Encodes the notion of a "use" +struct Edge { + // \brief Which value the edge is coming from + Value *From; + + // \brief Which value the edge is pointing to + Value *To; + + // \brief Edge weight + EdgeType Weight; + + // \brief Whether we aliased any external values along the way that may be + // invisible to the analysis (e.g. landingpad for exceptions, calls for + // interprocedural analysis) + StratifiedAttrs AdditionalAttrs; + + Edge(Value *From, Value *To, EdgeType W, StratifiedAttrs A) + : From(From), To(To), Weight(W), AdditionalAttrs(A) {} +}; + +// \brief Information we have about a function and would like to keep around +struct FunctionInfo { + StratifiedSets<Value *> Sets; + // Lots of functions have < 4 returns. Adjust as necessary.
+ SmallVector<Value *, 4> ReturnedValues; + + FunctionInfo(StratifiedSets<Value *> &&S, + SmallVector<Value *, 4> &&RV) + : Sets(std::move(S)), ReturnedValues(std::move(RV)) {} +}; + +struct CFLAliasAnalysis; + +struct FunctionHandle : public CallbackVH { + FunctionHandle(Function *Fn, CFLAliasAnalysis *CFLAA) + : CallbackVH(Fn), CFLAA(CFLAA) { + assert(Fn != nullptr); + assert(CFLAA != nullptr); + } + + virtual ~FunctionHandle() {} + + void deleted() override { removeSelfFromCache(); } + void allUsesReplacedWith(Value *) override { removeSelfFromCache(); } + +private: + CFLAliasAnalysis *CFLAA; + + void removeSelfFromCache(); +}; + +struct CFLAliasAnalysis : public ImmutablePass, public AliasAnalysis { +private: + /// \brief Cached mapping of Functions to their StratifiedSets. + /// If a function's sets are currently being built, it is marked + /// in the cache as an Optional without a value. This way, if we + /// have any kind of recursion, it is discernable from a function + /// that simply has empty sets. + DenseMap<Function *, Optional<FunctionInfo>> Cache; + std::forward_list<FunctionHandle> Handles; + +public: + static char ID; + + CFLAliasAnalysis() : ImmutablePass(ID) { + initializeCFLAliasAnalysisPass(*PassRegistry::getPassRegistry()); + } + + virtual ~CFLAliasAnalysis() {} + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AliasAnalysis::getAnalysisUsage(AU); + } + + void *getAdjustedAnalysisPointer(const void *ID) override { + if (ID == &AliasAnalysis::ID) + return (AliasAnalysis *)this; + return this; + } + + /// \brief Inserts the given Function into the cache. + void scan(Function *Fn); + + void evict(Function *Fn) { Cache.erase(Fn); } + + /// \brief Ensures that the given function is available in the cache. + /// Returns the appropriate entry from the cache. + const Optional<FunctionInfo> &ensureCached(Function *Fn) { + auto Iter = Cache.find(Fn); + if (Iter == Cache.end()) { + scan(Fn); + Iter = Cache.find(Fn); + assert(Iter != Cache.end()); + assert(Iter->second.hasValue()); + } + return Iter->second; + } + + AliasResult query(const Location &LocA, const Location &LocB); + + AliasResult alias(const Location &LocA, const Location &LocB) override { + if (LocA.Ptr == LocB.Ptr) { + if (LocA.Size == LocB.Size) { + return MustAlias; + } else { + return PartialAlias; + } + } + + // Comparisons between global variables and other constants should be + // handled by BasicAA. 
+ if (isa<Constant>(LocA.Ptr) && isa<Constant>(LocB.Ptr)) { + return MayAlias; + } + + return query(LocA, LocB); + } + + void initializePass() override { InitializeAliasAnalysis(this); } +}; + +void FunctionHandle::removeSelfFromCache() { + assert(CFLAA != nullptr); + auto *Val = getValPtr(); + CFLAA->evict(cast<Function>(Val)); + setValPtr(nullptr); +} + +// \brief Gets the edges our graph should have, based on an Instruction* +class GetEdgesVisitor : public InstVisitor<GetEdgesVisitor, void> { + CFLAliasAnalysis &AA; + SmallVectorImpl<Edge> &Output; + +public: + GetEdgesVisitor(CFLAliasAnalysis &AA, SmallVectorImpl<Edge> &Output) + : AA(AA), Output(Output) {} + + void visitInstruction(Instruction &) { + llvm_unreachable("Unsupported instruction encountered"); + } + + void visitCastInst(CastInst &Inst) { + Output.push_back(Edge(&Inst, Inst.getOperand(0), EdgeType::Assign, + AttrNone)); + } + + void visitBinaryOperator(BinaryOperator &Inst) { + auto *Op1 = Inst.getOperand(0); + auto *Op2 = Inst.getOperand(1); + Output.push_back(Edge(&Inst, Op1, EdgeType::Assign, AttrNone)); + Output.push_back(Edge(&Inst, Op2, EdgeType::Assign, AttrNone)); + } + + void visitAtomicCmpXchgInst(AtomicCmpXchgInst &Inst) { + auto *Ptr = Inst.getPointerOperand(); + auto *Val = Inst.getNewValOperand(); + Output.push_back(Edge(Ptr, Val, EdgeType::Dereference, AttrNone)); + } + + void visitAtomicRMWInst(AtomicRMWInst &Inst) { + auto *Ptr = Inst.getPointerOperand(); + auto *Val = Inst.getValOperand(); + Output.push_back(Edge(Ptr, Val, EdgeType::Dereference, AttrNone)); + } + + void visitPHINode(PHINode &Inst) { + for (unsigned I = 0, E = Inst.getNumIncomingValues(); I != E; ++I) { + Value *Val = Inst.getIncomingValue(I); + Output.push_back(Edge(&Inst, Val, EdgeType::Assign, AttrNone)); + } + } + + void visitGetElementPtrInst(GetElementPtrInst &Inst) { + auto *Op = Inst.getPointerOperand(); + Output.push_back(Edge(&Inst, Op, EdgeType::Assign, AttrNone)); + for (auto I = Inst.idx_begin(), E = Inst.idx_end(); I != E; ++I) + Output.push_back(Edge(&Inst, *I, EdgeType::Assign, AttrNone)); + } + + void visitSelectInst(SelectInst &Inst) { + auto *Condition = Inst.getCondition(); + Output.push_back(Edge(&Inst, Condition, EdgeType::Assign, AttrNone)); + auto *TrueVal = Inst.getTrueValue(); + Output.push_back(Edge(&Inst, TrueVal, EdgeType::Assign, AttrNone)); + auto *FalseVal = Inst.getFalseValue(); + Output.push_back(Edge(&Inst, FalseVal, EdgeType::Assign, AttrNone)); + } + + void visitAllocaInst(AllocaInst &) {} + + void visitLoadInst(LoadInst &Inst) { + auto *Ptr = Inst.getPointerOperand(); + auto *Val = &Inst; + Output.push_back(Edge(Val, Ptr, EdgeType::Reference, AttrNone)); + } + + void visitStoreInst(StoreInst &Inst) { + auto *Ptr = Inst.getPointerOperand(); + auto *Val = Inst.getValueOperand(); + Output.push_back(Edge(Ptr, Val, EdgeType::Dereference, AttrNone)); + } + + void visitVAArgInst(VAArgInst &Inst) { + // We can't fully model va_arg here. For *Ptr = Inst.getOperand(0), it does + // two things: + // 1. Loads a value from *((T*)*Ptr). + // 2. Increments (stores to) *Ptr by some target-specific amount. + // For now, we'll handle this like a landingpad instruction (by placing the + // result in its own group, and having that group alias externals). 
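+    // As an illustration (the exact IR shape of va_arg varies by target):
+    // for `%v = va_arg i8** %ap, i32`, %v is read through %ap and %ap itself
+    // is advanced, so the single self-edge with AttrAll below conservatively
+    // covers both effects.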
+ auto *Val = &Inst; + Output.push_back(Edge(Val, Val, EdgeType::Assign, AttrAll)); + } + + static bool isFunctionExternal(Function *Fn) { + return Fn->isDeclaration() || !Fn->hasLocalLinkage(); + } + + // Gets whether the sets at Index1 are above, below, or equal to the sets at + // Index2. Returns None if they are not in the same set chain. + static Optional<Level> getIndexRelation(const StratifiedSets<Value *> &Sets, + StratifiedIndex Index1, + StratifiedIndex Index2) { + if (Index1 == Index2) + return Level::Same; + + const auto *Current = &Sets.getLink(Index1); + while (Current->hasBelow()) { + if (Current->Below == Index2) + return Level::Below; + Current = &Sets.getLink(Current->Below); + } + + Current = &Sets.getLink(Index1); + while (Current->hasAbove()) { + if (Current->Above == Index2) + return Level::Above; + Current = &Sets.getLink(Current->Above); + } + + return NoneType(); + } + + bool + tryInterproceduralAnalysis(const SmallVectorImpl<Function *> &Fns, + Value *FuncValue, + const iterator_range<User::op_iterator> &Args) { + const unsigned ExpectedMaxArgs = 8; + const unsigned MaxSupportedArgs = 50; + assert(Fns.size() > 0); + + // I put this here to give us an upper bound on time taken by IPA. Is it + // really (realistically) needed? Keep in mind that we do have an n^2 algo. + if (std::distance(Args.begin(), Args.end()) > (int) MaxSupportedArgs) + return false; + + // Exit early if we'll fail anyway + for (auto *Fn : Fns) { + if (isFunctionExternal(Fn) || Fn->isVarArg()) + return false; + auto &MaybeInfo = AA.ensureCached(Fn); + if (!MaybeInfo.hasValue()) + return false; + } + + SmallVector<Value *, ExpectedMaxArgs> Arguments(Args.begin(), Args.end()); + SmallVector<StratifiedInfo, ExpectedMaxArgs> Parameters; + for (auto *Fn : Fns) { + auto &Info = *AA.ensureCached(Fn); + auto &Sets = Info.Sets; + auto &RetVals = Info.ReturnedValues; + + Parameters.clear(); + for (auto &Param : Fn->args()) { + auto MaybeInfo = Sets.find(&Param); + // Did a new parameter somehow get added to the function/slip by? + if (!MaybeInfo.hasValue()) + return false; + Parameters.push_back(*MaybeInfo); + } + + // Adding an edge from argument -> return value for each parameter that + // may alias the return value + for (unsigned I = 0, E = Parameters.size(); I != E; ++I) { + auto &ParamInfo = Parameters[I]; + auto &ArgVal = Arguments[I]; + bool AddEdge = false; + StratifiedAttrs Externals; + for (unsigned X = 0, XE = RetVals.size(); X != XE; ++X) { + auto MaybeInfo = Sets.find(RetVals[X]); + if (!MaybeInfo.hasValue()) + return false; + + auto &RetInfo = *MaybeInfo; + auto RetAttrs = Sets.getLink(RetInfo.Index).Attrs; + auto ParamAttrs = Sets.getLink(ParamInfo.Index).Attrs; + auto MaybeRelation = + getIndexRelation(Sets, ParamInfo.Index, RetInfo.Index); + if (MaybeRelation.hasValue()) { + AddEdge = true; + Externals |= RetAttrs | ParamAttrs; + } + } + if (AddEdge) + Output.push_back(Edge(FuncValue, ArgVal, EdgeType::Assign, + StratifiedAttrs().flip())); + } + + if (Parameters.size() != Arguments.size()) + return false; + + // Adding edges between arguments for arguments that may end up aliasing + // each other.
This is necessary for functions such as + // void foo(int** a, int** b) { *a = *b; } + // (Technically, the proper sets for this would be those below + // Arguments[I] and Arguments[X], but our algorithm will produce + // extremely similar, and equally correct, results either way) + for (unsigned I = 0, E = Arguments.size(); I != E; ++I) { + auto &MainVal = Arguments[I]; + auto &MainInfo = Parameters[I]; + auto &MainAttrs = Sets.getLink(MainInfo.Index).Attrs; + for (unsigned X = I + 1; X != E; ++X) { + auto &SubInfo = Parameters[X]; + auto &SubVal = Arguments[X]; + auto &SubAttrs = Sets.getLink(SubInfo.Index).Attrs; + auto MaybeRelation = + getIndexRelation(Sets, MainInfo.Index, SubInfo.Index); + + if (!MaybeRelation.hasValue()) + continue; + + auto NewAttrs = SubAttrs | MainAttrs; + Output.push_back(Edge(MainVal, SubVal, EdgeType::Assign, NewAttrs)); + } + } + } + return true; + } + + template <typename InstT> void visitCallLikeInst(InstT &Inst) { + SmallVector<Function *, 4> Targets; + if (getPossibleTargets(&Inst, Targets)) { + if (tryInterproceduralAnalysis(Targets, &Inst, Inst.arg_operands())) + return; + // Cleanup from interprocedural analysis + Output.clear(); + } + + for (Value *V : Inst.arg_operands()) + Output.push_back(Edge(&Inst, V, EdgeType::Assign, AttrAll)); + } + + void visitCallInst(CallInst &Inst) { visitCallLikeInst(Inst); } + + void visitInvokeInst(InvokeInst &Inst) { visitCallLikeInst(Inst); } + + // Because vectors/aggregates are immutable and unaddressable, + // there's nothing we can do to coax a value out of them, other + // than calling Extract{Element,Value}. We can effectively treat + // them as pointers to arbitrary memory locations we can store in + // and load from. + void visitExtractElementInst(ExtractElementInst &Inst) { + auto *Ptr = Inst.getVectorOperand(); + auto *Val = &Inst; + Output.push_back(Edge(Val, Ptr, EdgeType::Reference, AttrNone)); + } + + void visitInsertElementInst(InsertElementInst &Inst) { + auto *Vec = Inst.getOperand(0); + auto *Val = Inst.getOperand(1); + Output.push_back(Edge(&Inst, Vec, EdgeType::Assign, AttrNone)); + Output.push_back(Edge(&Inst, Val, EdgeType::Dereference, AttrNone)); + } + + void visitLandingPadInst(LandingPadInst &Inst) { + // Exceptions come from "nowhere", from our analysis' perspective. + // So we place the instruction its own group, noting that said group may + // alias externals + Output.push_back(Edge(&Inst, &Inst, EdgeType::Assign, AttrAll)); + } + + void visitInsertValueInst(InsertValueInst &Inst) { + auto *Agg = Inst.getOperand(0); + auto *Val = Inst.getOperand(1); + Output.push_back(Edge(&Inst, Agg, EdgeType::Assign, AttrNone)); + Output.push_back(Edge(&Inst, Val, EdgeType::Dereference, AttrNone)); + } + + void visitExtractValueInst(ExtractValueInst &Inst) { + auto *Ptr = Inst.getAggregateOperand(); + Output.push_back(Edge(&Inst, Ptr, EdgeType::Reference, AttrNone)); + } + + void visitShuffleVectorInst(ShuffleVectorInst &Inst) { + auto *From1 = Inst.getOperand(0); + auto *From2 = Inst.getOperand(1); + Output.push_back(Edge(&Inst, From1, EdgeType::Assign, AttrNone)); + Output.push_back(Edge(&Inst, From2, EdgeType::Assign, AttrNone)); + } +}; + +// For a given instruction, we need to know which Value* to get the +// users of in order to build our graph. In some cases (i.e. add), +// we simply need the Instruction*. In other cases (i.e. store), +// finding the users of the Instruction* is useless; we need to find +// the users of the first operand. This handles determining which +// value to follow for us. 
+// +// Note: we *need* to keep this in sync with GetEdgesVisitor. Add +// something to GetEdgesVisitor, add it here -- remove something from +// GetEdgesVisitor, remove it here. +class GetTargetValueVisitor + : public InstVisitor<GetTargetValueVisitor, Value *> { +public: + Value *visitInstruction(Instruction &Inst) { return &Inst; } + + Value *visitStoreInst(StoreInst &Inst) { return Inst.getPointerOperand(); } + + Value *visitAtomicCmpXchgInst(AtomicCmpXchgInst &Inst) { + return Inst.getPointerOperand(); + } + + Value *visitAtomicRMWInst(AtomicRMWInst &Inst) { + return Inst.getPointerOperand(); + } + + Value *visitInsertElementInst(InsertElementInst &Inst) { + return Inst.getOperand(0); + } + + Value *visitInsertValueInst(InsertValueInst &Inst) { + return Inst.getAggregateOperand(); + } +}; + +// Set building requires a weighted bidirectional graph. +template <typename EdgeTypeT> class WeightedBidirectionalGraph { +public: + typedef std::size_t Node; + +private: + const static Node StartNode = Node(0); + + struct Edge { + EdgeTypeT Weight; + Node Other; + + Edge(const EdgeTypeT &W, const Node &N) + : Weight(W), Other(N) {} + + bool operator==(const Edge &E) const { + return Weight == E.Weight && Other == E.Other; + } + + bool operator!=(const Edge &E) const { return !operator==(E); } + }; + + struct NodeImpl { + std::vector<Edge> Edges; + }; + + std::vector<NodeImpl> NodeImpls; + + bool inbounds(Node NodeIndex) const { return NodeIndex < NodeImpls.size(); } + + const NodeImpl &getNode(Node N) const { return NodeImpls[N]; } + NodeImpl &getNode(Node N) { return NodeImpls[N]; } + +public: + // ----- Various Edge iterators for the graph ----- // + + // \brief Iterator for edges. Because this graph is bidirectional, we don't + // allow modification of the edges using this iterator. Additionally, the + // iterator becomes invalid if you add edges to or from the node you're + // getting the edges of. + struct EdgeIterator : public std::iterator<std::forward_iterator_tag, + std::tuple<EdgeTypeT, Node>> { + EdgeIterator(const typename std::vector<Edge>::const_iterator &Iter) + : Current(Iter) {} + + EdgeIterator(NodeImpl &Impl) : Current(Impl.begin()) {} + + EdgeIterator &operator++() { + ++Current; + return *this; + } + + EdgeIterator operator++(int) { + EdgeIterator Copy(Current); + operator++(); + return Copy; + } + + std::tuple<EdgeTypeT, Node> &operator*() { + Store = std::make_tuple(Current->Weight, Current->Other); + return Store; + } + + bool operator==(const EdgeIterator &Other) const { + return Current == Other.Current; + } + + bool operator!=(const EdgeIterator &Other) const { + return !operator==(Other); + } + + private: + typename std::vector<Edge>::const_iterator Current; + std::tuple<EdgeTypeT, Node> Store; + }; + + // Wrapper for EdgeIterator with begin()/end() calls.
+ struct EdgeIterable { + EdgeIterable(const std::vector<Edge> &Edges) + : BeginIter(Edges.begin()), EndIter(Edges.end()) {} + + EdgeIterator begin() { return EdgeIterator(BeginIter); } + + EdgeIterator end() { return EdgeIterator(EndIter); } + + private: + typename std::vector<Edge>::const_iterator BeginIter; + typename std::vector<Edge>::const_iterator EndIter; + }; + + // ----- Actual graph-related things ----- // + + WeightedBidirectionalGraph() {} + + WeightedBidirectionalGraph(WeightedBidirectionalGraph<EdgeTypeT> &&Other) + : NodeImpls(std::move(Other.NodeImpls)) {} + + WeightedBidirectionalGraph<EdgeTypeT> & + operator=(WeightedBidirectionalGraph<EdgeTypeT> &&Other) { + NodeImpls = std::move(Other.NodeImpls); + return *this; + } + + Node addNode() { + auto Index = NodeImpls.size(); + auto NewNode = Node(Index); + NodeImpls.push_back(NodeImpl()); + return NewNode; + } + + void addEdge(Node From, Node To, const EdgeTypeT &Weight, + const EdgeTypeT &ReverseWeight) { + assert(inbounds(From)); + assert(inbounds(To)); + auto &FromNode = getNode(From); + auto &ToNode = getNode(To); + FromNode.Edges.push_back(Edge(Weight, To)); + ToNode.Edges.push_back(Edge(ReverseWeight, From)); + } + + EdgeIterable edgesFor(const Node &N) const { + const auto &Node = getNode(N); + return EdgeIterable(Node.Edges); + } + + bool empty() const { return NodeImpls.empty(); } + std::size_t size() const { return NodeImpls.size(); } + + // \brief Gets an arbitrary node in the graph as a starting point for + // traversal. + Node getEntryNode() { + assert(inbounds(StartNode)); + return StartNode; + } +}; + +typedef WeightedBidirectionalGraph<std::pair<EdgeType, StratifiedAttrs>> GraphT; +typedef DenseMap<Value *, GraphT::Node> NodeMapT; +} + +// -- Setting up/registering CFLAA pass -- // +char CFLAliasAnalysis::ID = 0; + +INITIALIZE_AG_PASS(CFLAliasAnalysis, AliasAnalysis, "cfl-aa", + "CFL-Based AA implementation", false, true, false) + +ImmutablePass *llvm::createCFLAliasAnalysisPass() { + return new CFLAliasAnalysis(); +} + +//===----------------------------------------------------------------------===// +// Function declarations that require types defined in the namespace above +//===----------------------------------------------------------------------===// + +// Given an argument number, returns the appropriate Attr index to set. +static StratifiedAttr argNumberToAttrIndex(StratifiedAttr); + +// Given a Value, potentially return which AttrIndex it maps to. +static Optional<StratifiedAttr> valueToAttrIndex(Value *Val); + +// Gets the inverse of a given EdgeType. +static EdgeType flipWeight(EdgeType); + +// Gets edges of the given Instruction*, writing them to the SmallVector*. +static void argsToEdges(CFLAliasAnalysis &, Instruction *, + SmallVectorImpl<Edge> &); + +// Gets the "Level" that one should travel in StratifiedSets +// given an EdgeType. +static Level directionOfEdgeType(EdgeType); + +// Builds the graph needed for constructing the StratifiedSets for the +// given function +static void buildGraphFrom(CFLAliasAnalysis &, Function *, + SmallVectorImpl<Value *> &, NodeMapT &, GraphT &); + +// Builds the graph + StratifiedSets for a function. 
+static FunctionInfo buildSetsFrom(CFLAliasAnalysis &, Function *); + +static Optional<Function *> parentFunctionOfValue(Value *Val) { + if (auto *Inst = dyn_cast<Instruction>(Val)) { + auto *Bb = Inst->getParent(); + return Bb->getParent(); + } + + if (auto *Arg = dyn_cast<Argument>(Val)) + return Arg->getParent(); + return NoneType(); +} + +template <typename Inst> +static bool getPossibleTargets(Inst *Call, + SmallVectorImpl<Function *> &Output) { + if (auto *Fn = Call->getCalledFunction()) { + Output.push_back(Fn); + return true; + } + + // TODO: If the call is indirect, we might be able to enumerate all potential + // targets of the call and return them, rather than just failing. + return false; +} + +static Optional<Value *> getTargetValue(Instruction *Inst) { + GetTargetValueVisitor V; + return V.visit(Inst); +} + +static bool hasUsefulEdges(Instruction *Inst) { + bool IsNonInvokeTerminator = + isa<TerminatorInst>(Inst) && !isa<InvokeInst>(Inst); + return !isa<CmpInst>(Inst) && !isa<FenceInst>(Inst) && !IsNonInvokeTerminator; +} + +static Optional<StratifiedAttr> valueToAttrIndex(Value *Val) { + if (isa<GlobalValue>(Val)) + return AttrGlobalIndex; + + if (auto *Arg = dyn_cast<Argument>(Val)) + if (!Arg->hasNoAliasAttr()) + return argNumberToAttrIndex(Arg->getArgNo()); + return NoneType(); +} + +static StratifiedAttr argNumberToAttrIndex(unsigned ArgNum) { + if (ArgNum > AttrMaxNumArgs) + return AttrAllIndex; + return ArgNum + AttrFirstArgIndex; +} + +static EdgeType flipWeight(EdgeType Initial) { + switch (Initial) { + case EdgeType::Assign: + return EdgeType::Assign; + case EdgeType::Dereference: + return EdgeType::Reference; + case EdgeType::Reference: + return EdgeType::Dereference; + } + llvm_unreachable("Incomplete coverage of EdgeType enum"); +} + +static void argsToEdges(CFLAliasAnalysis &Analysis, Instruction *Inst, + SmallVectorImpl<Edge> &Output) { + GetEdgesVisitor v(Analysis, Output); + v.visit(Inst); +} + +static Level directionOfEdgeType(EdgeType Weight) { + switch (Weight) { + case EdgeType::Reference: + return Level::Above; + case EdgeType::Dereference: + return Level::Below; + case EdgeType::Assign: + return Level::Same; + } + llvm_unreachable("Incomplete switch coverage"); +} + +// Aside: We may remove graph construction entirely, because it doesn't really +// buy us much that we don't already have. I'd like to add interprocedural +// analysis prior to this however, in case that somehow requires the graph +// produced by this for efficient execution +static void buildGraphFrom(CFLAliasAnalysis &Analysis, Function *Fn, + SmallVectorImpl<Value *> &ReturnedValues, + NodeMapT &Map, GraphT &Graph) { + const auto findOrInsertNode = [&Map, &Graph](Value *Val) { + auto Pair = Map.insert(std::make_pair(Val, GraphT::Node())); + auto &Iter = Pair.first; + if (Pair.second) { + auto NewNode = Graph.addNode(); + Iter->second = NewNode; + } + return Iter->second; + }; + + SmallVector<Edge, 8> Edges; + for (auto &Bb : Fn->getBasicBlockList()) { + for (auto &Inst : Bb.getInstList()) { + // We don't want the edges of most "return" instructions, but we *do* want + // to know what can be returned. + if (auto *Ret = dyn_cast<ReturnInst>(&Inst)) + ReturnedValues.push_back(Ret); + + if (!hasUsefulEdges(&Inst)) + continue; + + Edges.clear(); + argsToEdges(Analysis, &Inst, Edges); + + // In the case of an unused alloca (or similar), edges may be empty. Note + // that it exists so we can potentially answer NoAlias. 
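+      // For instance, a dead `%p = alloca i32` produces no edges, but giving
+      // %p a node still lets a later alias query against %p be answered.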
+ if (Edges.empty()) { + auto MaybeVal = getTargetValue(&Inst); + assert(MaybeVal.hasValue()); + auto *Target = *MaybeVal; + findOrInsertNode(Target); + continue; + } + + for (const Edge &E : Edges) { + auto To = findOrInsertNode(E.To); + auto From = findOrInsertNode(E.From); + auto FlippedWeight = flipWeight(E.Weight); + auto Attrs = E.AdditionalAttrs; + Graph.addEdge(From, To, std::make_pair(E.Weight, Attrs), + std::make_pair(FlippedWeight, Attrs)); + } + } + } +} + +static FunctionInfo buildSetsFrom(CFLAliasAnalysis &Analysis, Function *Fn) { + NodeMapT Map; + GraphT Graph; + SmallVector<Value *, 4> ReturnedValues; + + buildGraphFrom(Analysis, Fn, ReturnedValues, Map, Graph); + + DenseMap<GraphT::Node, Value *> NodeValueMap; + NodeValueMap.resize(Map.size()); + for (const auto &Pair : Map) + NodeValueMap.insert(std::make_pair(Pair.second, Pair.first)); + + const auto findValueOrDie = [&NodeValueMap](GraphT::Node Node) { + auto ValIter = NodeValueMap.find(Node); + assert(ValIter != NodeValueMap.end()); + return ValIter->second; + }; + + StratifiedSetsBuilder<Value *> Builder; + + SmallVector<GraphT::Node, 16> Worklist; + for (auto &Pair : Map) { + Worklist.clear(); + + auto *Value = Pair.first; + Builder.add(Value); + auto InitialNode = Pair.second; + Worklist.push_back(InitialNode); + while (!Worklist.empty()) { + auto Node = Worklist.pop_back_val(); + auto *CurValue = findValueOrDie(Node); + if (isa<Constant>(CurValue) && !isa<GlobalValue>(CurValue)) + continue; + + for (const auto &EdgeTuple : Graph.edgesFor(Node)) { + auto Weight = std::get<0>(EdgeTuple); + auto Label = Weight.first; + auto &OtherNode = std::get<1>(EdgeTuple); + auto *OtherValue = findValueOrDie(OtherNode); + + if (isa<Constant>(OtherValue) && !isa<GlobalValue>(OtherValue)) + continue; + + bool Added; + switch (directionOfEdgeType(Label)) { + case Level::Above: + Added = Builder.addAbove(CurValue, OtherValue); + break; + case Level::Below: + Added = Builder.addBelow(CurValue, OtherValue); + break; + case Level::Same: + Added = Builder.addWith(CurValue, OtherValue); + break; + } + + if (Added) { + auto Aliasing = Weight.second; + if (auto MaybeCurIndex = valueToAttrIndex(CurValue)) + Aliasing.set(*MaybeCurIndex); + if (auto MaybeOtherIndex = valueToAttrIndex(OtherValue)) + Aliasing.set(*MaybeOtherIndex); + Builder.noteAttributes(CurValue, Aliasing); + Builder.noteAttributes(OtherValue, Aliasing); + Worklist.push_back(OtherNode); + } + } + } + } + + // There are times when we end up with parameters not in our graph (i.e. if + // it's only used as the condition of a branch). Other bits of code depend on + // things that were present during construction being present in the graph. + // So, we add all present arguments here. 
+ for (auto &Arg : Fn->args()) { + Builder.add(&Arg); + } + + return FunctionInfo(Builder.build(), std::move(ReturnedValues)); +} + +void CFLAliasAnalysis::scan(Function *Fn) { + auto InsertPair = Cache.insert(std::make_pair(Fn, Optional<FunctionInfo>())); + (void)InsertPair; + assert(InsertPair.second && + "Trying to scan a function that has already been cached"); + + FunctionInfo Info(buildSetsFrom(*this, Fn)); + Cache[Fn] = std::move(Info); + Handles.push_front(FunctionHandle(Fn, this)); +} + +AliasAnalysis::AliasResult +CFLAliasAnalysis::query(const AliasAnalysis::Location &LocA, + const AliasAnalysis::Location &LocB) { + auto *ValA = const_cast<Value *>(LocA.Ptr); + auto *ValB = const_cast<Value *>(LocB.Ptr); + + Function *Fn = nullptr; + auto MaybeFnA = parentFunctionOfValue(ValA); + auto MaybeFnB = parentFunctionOfValue(ValB); + if (!MaybeFnA.hasValue() && !MaybeFnB.hasValue()) { + llvm_unreachable("Don't know how to extract the parent function " + "from values A or B"); + } + + if (MaybeFnA.hasValue()) { + Fn = *MaybeFnA; + assert((!MaybeFnB.hasValue() || *MaybeFnB == *MaybeFnA) && + "Interprocedural queries not supported"); + } else { + Fn = *MaybeFnB; + } + + assert(Fn != nullptr); + auto &MaybeInfo = ensureCached(Fn); + assert(MaybeInfo.hasValue()); + + auto &Sets = MaybeInfo->Sets; + auto MaybeA = Sets.find(ValA); + if (!MaybeA.hasValue()) + return AliasAnalysis::MayAlias; + + auto MaybeB = Sets.find(ValB); + if (!MaybeB.hasValue()) + return AliasAnalysis::MayAlias; + + auto SetA = *MaybeA; + auto SetB = *MaybeB; + + if (SetA.Index == SetB.Index) + return AliasAnalysis::PartialAlias; + + auto AttrsA = Sets.getLink(SetA.Index).Attrs; + auto AttrsB = Sets.getLink(SetB.Index).Attrs; + // Stratified set attributes are used as markers to signify whether a member + // of a StratifiedSet (or a member of a set above the current set) has + // interacted with either arguments or globals. "Interacted with" means + // its value may be different depending on the value of an argument or + // global. The thought behind this is that, because arguments and globals + // may alias each other, if AttrsA and AttrsB have touched args/globals, + // we must conservatively say that they alias. However, if at least one of + // the sets has no values that could legally be altered by changing the value + // of an argument or global, then we don't have to be as conservative.
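+  // For example, a value loaded through an argument carries that argument's
+  // attribute, so two such values must conservatively be reported MayAlias
+  // below; a value rooted only at a local alloca carries no attributes, so
+  // NoAlias remains sound for it.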
+ if (AttrsA.any() && AttrsB.any()) + return AliasAnalysis::MayAlias; + + return AliasAnalysis::NoAlias; +} diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index d1632fd..4e9664f 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -5,12 +5,14 @@ add_llvm_library(LLVMAnalysis AliasDebugger.cpp AliasSetTracker.cpp Analysis.cpp + AssumptionTracker.cpp BasicAliasAnalysis.cpp BlockFrequencyInfo.cpp BlockFrequencyInfoImpl.cpp BranchProbabilityInfo.cpp CFG.cpp CFGPrinter.cpp + CFLAliasAnalysis.cpp CGSCCPassManager.cpp CaptureTracking.cpp CostModel.cpp @@ -20,6 +22,7 @@ add_llvm_library(LLVMAnalysis DependenceAnalysis.cpp DomPrinter.cpp DominanceFrontier.cpp + FunctionTargetTransformInfo.cpp IVUsers.cpp InstCount.cpp InstructionSimplify.cpp @@ -53,6 +56,7 @@ add_llvm_library(LLVMAnalysis TargetTransformInfo.cpp Trace.cpp TypeBasedAliasAnalysis.cpp + ScopedNoAliasAA.cpp ValueTracking.cpp ) diff --git a/lib/Analysis/CaptureTracking.cpp b/lib/Analysis/CaptureTracking.cpp index 3708e60..a271729 100644 --- a/lib/Analysis/CaptureTracking.cpp +++ b/lib/Analysis/CaptureTracking.cpp @@ -20,8 +20,10 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/CFG.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" using namespace llvm; @@ -49,6 +51,65 @@ namespace { bool Captured; }; + + /// Only find pointer captures which happen before the given instruction. Uses + /// the dominator tree to determine whether one instruction is before another. + /// Only support the case where the Value is defined in the same basic block + /// as the given instruction and the use. + struct CapturesBefore : public CaptureTracker { + CapturesBefore(bool ReturnCaptures, const Instruction *I, DominatorTree *DT, + bool IncludeI) + : BeforeHere(I), DT(DT), ReturnCaptures(ReturnCaptures), + IncludeI(IncludeI), Captured(false) {} + + void tooManyUses() override { Captured = true; } + + bool shouldExplore(const Use *U) override { + Instruction *I = cast<Instruction>(U->getUser()); + if (BeforeHere == I && !IncludeI) + return false; + + BasicBlock *BB = I->getParent(); + // We explore this usage only if the usage can reach "BeforeHere". + // If use is not reachable from entry, there is no need to explore. + if (BeforeHere != I && !DT->isReachableFromEntry(BB)) + return false; + // If the value is defined in the same basic block as use and BeforeHere, + // there is no need to explore the use if BeforeHere dominates use. + // Check whether there is a path from I to BeforeHere. + if (BeforeHere != I && DT->dominates(BeforeHere, I) && + !isPotentiallyReachable(I, BeforeHere, DT)) + return false; + return true; + } + + bool captured(const Use *U) override { + if (isa<ReturnInst>(U->getUser()) && !ReturnCaptures) + return false; + + Instruction *I = cast<Instruction>(U->getUser()); + if (BeforeHere == I && !IncludeI) + return false; + + BasicBlock *BB = I->getParent(); + // Same logic as in shouldExplore. 
+ if (BeforeHere != I && !DT->isReachableFromEntry(BB)) + return false; + if (BeforeHere != I && DT->dominates(BeforeHere, I) && + !isPotentiallyReachable(I, BeforeHere, DT)) + return false; + Captured = true; + return true; + } + + const Instruction *BeforeHere; + DominatorTree *DT; + + bool ReturnCaptures; + bool IncludeI; + + bool Captured; + }; } /// PointerMayBeCaptured - Return true if this pointer value may be captured @@ -74,6 +135,32 @@ bool llvm::PointerMayBeCaptured(const Value *V, return SCT.Captured; } +/// PointerMayBeCapturedBefore - Return true if this pointer value may be +/// captured by the enclosing function (which is required to exist). If a +/// DominatorTree is provided, only captures which happen before the given +/// instruction are considered. This routine can be expensive, so consider +/// caching the results. The boolean ReturnCaptures specifies whether +/// returning the value (or part of it) from the function counts as capturing +/// it or not. The boolean StoreCaptures specifies whether storing the value +/// (or part of it) into memory anywhere automatically counts as capturing it +/// or not. +bool llvm::PointerMayBeCapturedBefore(const Value *V, bool ReturnCaptures, + bool StoreCaptures, const Instruction *I, + DominatorTree *DT, bool IncludeI) { + assert(!isa<GlobalValue>(V) && + "It doesn't make sense to ask whether a global is captured."); + + if (!DT) + return PointerMayBeCaptured(V, ReturnCaptures, StoreCaptures); + + // TODO: See comment in PointerMayBeCaptured regarding what could be done + // with StoreCaptures. + + CapturesBefore CB(ReturnCaptures, I, DT, IncludeI); + PointerMayBeCaptured(V, &CB); + return CB.Captured; +} + /// TODO: Write a new FunctionPass AliasAnalysis so that it can keep /// a cache. Then we can move the code from BasicAliasAnalysis into /// that path, and remove this threshold. @@ -152,7 +239,7 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker) { if (Count++ >= Threshold) return Tracker->tooManyUses(); - if (Visited.insert(&UU)) + if (Visited.insert(&UU).second) if (Tracker->shouldExplore(&UU)) Worklist.push_back(&UU); } diff --git a/lib/Analysis/CodeMetrics.cpp b/lib/Analysis/CodeMetrics.cpp index 4c8a093..f29e4a2 100644 --- a/lib/Analysis/CodeMetrics.cpp +++ b/lib/Analysis/CodeMetrics.cpp @@ -11,23 +11,101 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "code-metrics" using namespace llvm; +static void completeEphemeralValues(SmallVector<const Value *, 16> &WorkSet, + SmallPtrSetImpl<const Value*> &EphValues) { + SmallPtrSet<const Value *, 32> Visited; + + // Make sure that all of the items in WorkSet are in our EphValues set. + EphValues.insert(WorkSet.begin(), WorkSet.end()); + + // Note: We don't speculate PHIs here, so we'll miss instruction chains kept + // alive only by ephemeral values. + + while (!WorkSet.empty()) { + const Value *V = WorkSet.front(); + WorkSet.erase(WorkSet.begin()); + + if (!Visited.insert(V).second) + continue; + + // If all uses of this value are ephemeral, then so is this value.
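+    // For example, given
+    //   %c = icmp ult i32 %n, 64
+    //   call void @llvm.assume(i1 %c)
+    // the assume is %c's only user, so %c becomes ephemeral, and the walk
+    // below then considers the operands feeding %c as well.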
+ bool FoundNEUse = false; + for (const User *I : V->users()) + if (!EphValues.count(I)) { + FoundNEUse = true; + break; + } + + if (FoundNEUse) + continue; + + EphValues.insert(V); + DEBUG(dbgs() << "Ephemeral Value: " << *V << "\n"); + + if (const User *U = dyn_cast<User>(V)) + for (const Value *J : U->operands()) { + if (isSafeToSpeculativelyExecute(J)) + WorkSet.push_back(J); + } + } +} + +// Find all ephemeral values. +void CodeMetrics::collectEphemeralValues(const Loop *L, AssumptionTracker *AT, + SmallPtrSetImpl<const Value*> &EphValues) { + SmallVector<const Value *, 16> WorkSet; + + for (auto &I : AT->assumptions(L->getHeader()->getParent())) { + // Filter out call sites outside of the loop so we don't do a function's + // worth of work for each of its loops (and, in the common case, ephemeral + // values in the loop are likely due to @llvm.assume calls in the loop). + if (!L->contains(I->getParent())) + continue; + + WorkSet.push_back(I); + } + + completeEphemeralValues(WorkSet, EphValues); +} + +void CodeMetrics::collectEphemeralValues(const Function *F, AssumptionTracker *AT, + SmallPtrSetImpl<const Value*> &EphValues) { + SmallVector<const Value *, 16> WorkSet; + + for (auto &I : AT->assumptions(const_cast<Function*>(F))) + WorkSet.push_back(I); + + completeEphemeralValues(WorkSet, EphValues); +} + /// analyzeBasicBlock - Fill in the current structure with information gleaned /// from the specified block. void CodeMetrics::analyzeBasicBlock(const BasicBlock *BB, - const TargetTransformInfo &TTI) { + const TargetTransformInfo &TTI, + SmallPtrSetImpl<const Value*> &EphValues) { ++NumBlocks; unsigned NumInstsBeforeThisBB = NumInsts; for (BasicBlock::const_iterator II = BB->begin(), E = BB->end(); II != E; ++II) { + // Skip ephemeral values. + if (EphValues.count(II)) + continue; + // Special handling for calls. if (isa<CallInst>(II) || isa<InvokeInst>(II)) { ImmutableCallSite CS(cast<Instruction>(II)); diff --git a/lib/Analysis/ConstantFolding.cpp b/lib/Analysis/ConstantFolding.cpp index eb3e2c6..fd8f2ae 100644 --- a/lib/Analysis/ConstantFolding.cpp +++ b/lib/Analysis/ConstantFolding.cpp @@ -47,15 +47,16 @@ using namespace llvm; // Constant Folding internal helper functions //===----------------------------------------------------------------------===// -/// FoldBitCast - Constant fold bitcast, symbolically evaluating it with -/// DataLayout. This always returns a non-null constant, but it may be a +/// Constant fold bitcast, symbolically evaluating it with DataLayout. +/// This always returns a non-null constant, but it may be a /// ConstantExpr if unfoldable. static Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &TD) { // Catch the obvious splat cases. if (C->isNullValue() && !DestTy->isX86_MMXTy()) return Constant::getNullValue(DestTy); - if (C->isAllOnesValue() && !DestTy->isX86_MMXTy()) + if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() && + !DestTy->isPtrOrPtrVectorTy()) // Don't get ones for ptr types! return Constant::getAllOnesValue(DestTy); // Handle a vector->integer cast. @@ -197,7 +198,7 @@ static Constant *FoldBitCast(Constant *C, Type *DestTy, // Handle: bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>) unsigned Ratio = NumDstElt/NumSrcElt; - unsigned DstBitSize = DstEltTy->getPrimitiveSizeInBits(); + unsigned DstBitSize = TD.getTypeSizeInBits(DstEltTy); // Loop over each source value, expanding into multiple results.
for (unsigned i = 0; i != NumSrcElt; ++i) { @@ -213,6 +214,15 @@ static Constant *FoldBitCast(Constant *C, Type *DestTy, ConstantInt::get(Src->getType(), ShiftAmt)); ShiftAmt += isLittleEndian ? DstBitSize : -DstBitSize; + // Truncate the element to an integer the same size as the pointer and + // convert it back to a pointer using an inttoptr. + if (DstEltTy->isPointerTy()) { + IntegerType *DstIntTy = Type::getIntNTy(C->getContext(), DstBitSize); + Constant *CE = ConstantExpr::getTrunc(Elt, DstIntTy); + Result.push_back(ConstantExpr::getIntToPtr(CE, DstEltTy)); + continue; + } + + // Truncate and remember this piece. Result.push_back(ConstantExpr::getTrunc(Elt, DstEltTy)); } @@ -222,9 +232,8 @@ static Constant *FoldBitCast(Constant *C, Type *DestTy, } -/// IsConstantOffsetFromGlobal - If this constant is actually a constant offset -/// from a global, return the global and the constant. Because of -/// constantexprs, this function is recursive. +/// If this constant is a constant offset from a global, return the global and +/// the constant. Because of constantexprs, this function is recursive. static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV, APInt &Offset, const DataLayout &TD) { // Trivial case, constant is the global. @@ -240,7 +249,8 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV, // Look through ptr->int and ptr->ptr casts. if (CE->getOpcode() == Instruction::PtrToInt || - CE->getOpcode() == Instruction::BitCast) + CE->getOpcode() == Instruction::BitCast || + CE->getOpcode() == Instruction::AddrSpaceCast) return IsConstantOffsetFromGlobal(CE->getOperand(0), GV, Offset, TD); // i32* getelementptr ([5 x i32]* @a, i32 0, i32 5) @@ -263,10 +273,10 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV, return true; } -/// ReadDataFromGlobal - Recursive helper to read bits out of global. C is the -/// constant being copied out of. ByteOffset is an offset into C. CurPtr is the -/// pointer to copy results into and BytesLeft is the number of bytes left in -/// the CurPtr buffer. TD is the target data. +/// Recursive helper to read bits out of global. C is the constant being copied +/// out of. ByteOffset is an offset into C. CurPtr is the pointer to copy +/// results into and BytesLeft is the number of bytes left in +/// the CurPtr buffer. TD is the target data. static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, unsigned char *CurPtr, unsigned BytesLeft, const DataLayout &TD) { @@ -517,9 +527,8 @@ static Constant *ConstantFoldLoadThroughBitcast(ConstantExpr *CE, return nullptr; } -/// ConstantFoldLoadFromConstPtr - Return the value that a load from C would -/// produce if it is constant and determinable. If this is not determinable, -/// return null. +/// Return the value that a load from C would produce if it is constant and +/// determinable. If this is not determinable, return null. Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, const DataLayout *TD) { // First, try the easy cases: @@ -609,7 +618,7 @@ static Constant *ConstantFoldLoadInst(const LoadInst *LI, const DataLayout *TD){ return nullptr; } -/// SymbolicallyEvaluateBinop - One of Op0/Op1 is a constant expression. +/// One of Op0/Op1 is a constant expression. /// Attempt to symbolically evaluate the result of a binary operator merging /// these together. If target data info is available, it is provided as DL, /// otherwise DL is null.
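Aside, for readers tracing these entry points: a minimal sketch of driving the folder directly, assuming the 3.5-era signatures visible in this patch (passing null for TD/TLI simply disables the target- and library-dependent folding cases):

    #include "llvm/Analysis/ConstantFolding.h"
    #include "llvm/IR/Constants.h"
    #include "llvm/IR/Instruction.h"
    #include "llvm/IR/LLVMContext.h"

    // Folds `add i32 2, 3` down to `i32 5` through the same public entry
    // point that ConstantFoldConstantExpressionImpl uses below.
    static llvm::Constant *foldExampleAdd(llvm::LLVMContext &Ctx) {
      llvm::Type *I32 = llvm::Type::getInt32Ty(Ctx);
      llvm::Constant *Ops[] = {llvm::ConstantInt::get(I32, 2),
                               llvm::ConstantInt::get(I32, 3)};
      return llvm::ConstantFoldInstOperands(llvm::Instruction::Add, I32, Ops,
                                            /*TD=*/nullptr, /*TLI=*/nullptr);
    }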
@@ -666,9 +675,8 @@ static Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0, return nullptr; } -/// CastGEPIndices - If array indices are not pointer-sized integers, -/// explicitly cast them so that they aren't implicitly casted by the -/// getelementptr. +/// If array indices are not pointer-sized integers, explicitly cast them so +/// that they aren't implicitly casted by the getelementptr. static Constant *CastGEPIndices(ArrayRef<Constant *> Ops, Type *ResultTy, const DataLayout *TD, const TargetLibraryInfo *TLI) { @@ -723,8 +731,7 @@ static Constant* StripPtrCastKeepAS(Constant* Ptr) { return Ptr; } -/// SymbolicallyEvaluateGEP - If we can symbolically evaluate the specified GEP -/// constant expression, do so. +/// If we can symbolically evaluate the GEP constant expression, do so. static Constant *SymbolicallyEvaluateGEP(ArrayRef<Constant *> Ops, Type *ResultTy, const DataLayout *TD, const TargetLibraryInfo *TLI) { @@ -886,7 +893,7 @@ static Constant *SymbolicallyEvaluateGEP(ArrayRef<Constant *> Ops, // Constant Folding public APIs //===----------------------------------------------------------------------===// -/// ConstantFoldInstruction - Try to constant fold the specified instruction. +/// Try to constant fold the specified instruction. /// If successful, the constant result is returned, if not, null is returned. /// Note that this fails if not all of the operands are constant. Otherwise, /// this function can only fail when attempting to fold instructions like loads @@ -966,7 +973,7 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I, static Constant * ConstantFoldConstantExpressionImpl(const ConstantExpr *CE, const DataLayout *TD, const TargetLibraryInfo *TLI, - SmallPtrSet<ConstantExpr *, 4> &FoldedOps) { + SmallPtrSetImpl<ConstantExpr *> &FoldedOps) { SmallVector<Constant *, 8> Ops; for (User::const_op_iterator i = CE->op_begin(), e = CE->op_end(); i != e; ++i) { @@ -974,7 +981,7 @@ ConstantFoldConstantExpressionImpl(const ConstantExpr *CE, const DataLayout *TD, // Recursively fold the ConstantExpr's operands. If we have already folded // a ConstantExpr, we don't have to process it again. if (ConstantExpr *NewCE = dyn_cast<ConstantExpr>(NewC)) { - if (FoldedOps.insert(NewCE)) + if (FoldedOps.insert(NewCE).second) NewC = ConstantFoldConstantExpressionImpl(NewCE, TD, TLI, FoldedOps); } Ops.push_back(NewC); @@ -986,7 +993,7 @@ ConstantFoldConstantExpressionImpl(const ConstantExpr *CE, const DataLayout *TD, return ConstantFoldInstOperands(CE->getOpcode(), CE->getType(), Ops, TD, TLI); } -/// ConstantFoldConstantExpression - Attempt to fold the constant expression +/// Attempt to fold the constant expression /// using the specified DataLayout. If successful, the constant result is /// returned, if not, null is returned. Constant *llvm::ConstantFoldConstantExpression(const ConstantExpr *CE, @@ -996,7 +1003,7 @@ Constant *llvm::ConstantFoldConstantExpression(const ConstantExpr *CE, return ConstantFoldConstantExpressionImpl(CE, TD, TLI, FoldedOps); } -/// ConstantFoldInstOperands - Attempt to constant fold an instruction with the +/// Attempt to constant fold an instruction with the /// specified opcode and operands. If successful, the constant result is /// returned, if not, null is returned.
Note that this function can fail when /// attempting to fold instructions like loads and stores, which have no @@ -1101,10 +1108,9 @@ Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, Type *DestTy, } } -/// ConstantFoldCompareInstOperands - Attempt to constant fold a compare +/// Attempt to constant fold a compare /// instruction (icmp/fcmp) with the specified operands. If it fails, it /// returns a constant expression of the specified operands. -/// Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate, Constant *Ops0, Constant *Ops1, const DataLayout *TD, @@ -1191,9 +1197,9 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate, } -/// ConstantFoldLoadThroughGEPConstantExpr - Given a constant and a -/// getelementptr constantexpr, return the constant value being addressed by the -/// constant expression, or null if something is funny and we can't decide. +/// Given a constant and a getelementptr constantexpr, return the constant value +/// being addressed by the constant expression, or null if something is funny +/// and we can't decide. Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C, ConstantExpr *CE) { if (!CE->getOperand(1)->isNullValue()) @@ -1209,10 +1215,9 @@ Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C, return C; } -/// ConstantFoldLoadThroughGEPIndices - Given a constant and getelementptr -/// indices (with an *implied* zero pointer index that is not in the list), -/// return the constant value being addressed by a virtual load, or null if -/// something is funny and we can't decide. +/// Given a constant and getelementptr indices (with an *implied* zero pointer +/// index that is not in the list), return the constant value being addressed by +/// a virtual load, or null if something is funny and we can't decide. Constant *llvm::ConstantFoldLoadThroughGEPIndices(Constant *C, ArrayRef<Constant*> Indices) { // Loop over all of the operands, tracking down which value we are @@ -1230,11 +1235,12 @@ Constant *llvm::ConstantFoldLoadThroughGEPIndices(Constant *C, // Constant Folding for Calls // -/// canConstantFoldCallTo - Return true if its even possible to fold a call to -/// the specified function. +/// Return true if it's even possible to fold a call to the specified function. bool llvm::canConstantFoldCallTo(const Function *F) { switch (F->getIntrinsicID()) { case Intrinsic::fabs: + case Intrinsic::minnum: + case Intrinsic::maxnum: case Intrinsic::log: case Intrinsic::log2: case Intrinsic::log10: @@ -1320,7 +1326,7 @@ static Constant *GetConstantFoldFPValue(double V, Type *Ty) { } namespace { -/// llvm_fenv_clearexcept - Clear the floating-point exception state. +/// Clear the floating-point exception state. static inline void llvm_fenv_clearexcept() { #if defined(HAVE_FENV_H) && HAVE_DECL_FE_ALL_EXCEPT feclearexcept(FE_ALL_EXCEPT); @@ -1328,7 +1334,7 @@ static inline void llvm_fenv_clearexcept() { errno = 0; } -/// llvm_fenv_testexcept - Test if a floating-point exception was raised. +/// Test if a floating-point exception was raised. static inline bool llvm_fenv_testexcept() { int errno_val = errno; if (errno_val == ERANGE || errno_val == EDOM) @@ -1365,14 +1371,13 @@ static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), return GetConstantFoldFPValue(V, Ty); } -/// ConstantFoldConvertToInt - Attempt to an SSE floating point to integer -/// conversion of a constant floating point. If roundTowardZero is false, the -/// default IEEE rounding is used (toward nearest, ties to even). 
This matches -/// the behavior of the non-truncating SSE instructions in the default rounding -/// mode. The desired integer type Ty is used to select how many bits are -/// available for the result. Returns null if the conversion cannot be -/// performed, otherwise returns the Constant value resulting from the -/// conversion. +/// Attempt to fold an SSE floating point to integer conversion of a constant +/// floating point. If roundTowardZero is false, the default IEEE rounding is +/// used (toward nearest, ties to even). This matches the behavior of the +/// non-truncating SSE instructions in the default rounding mode. The desired +/// integer type Ty is used to select how many bits are available for the +/// result. Returns null if the conversion cannot be performed, otherwise +/// returns the Constant value resulting from the conversion. static Constant *ConstantFoldConvertToInt(const APFloat &Val, bool roundTowardZero, Type *Ty) { // All of these conversion intrinsics form an integer of at most 64bits. @@ -1519,8 +1524,14 @@ static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy())) { if (V >= -0.0) return ConstantFoldFP(sqrt, V, Ty); - else // Undefined - return Constant::getNullValue(Ty); + else { + // Unlike the sqrt definitions in C/C++, POSIX, and IEEE-754 - which + // all guarantee or favor returning NaN - the square root of a + // negative number is not defined for the LLVM sqrt intrinsic. + // This is because the intrinsic should only be emitted in place of + // libm's sqrt function when using "no-nans-fp-math". + return UndefValue::get(Ty); + } } break; case 's': @@ -1626,6 +1637,19 @@ static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, V1.copySign(V2); return ConstantFP::get(Ty->getContext(), V1); } + + if (IntrinsicID == Intrinsic::minnum) { + const APFloat &C1 = Op1->getValueAPF(); + const APFloat &C2 = Op2->getValueAPF(); + return ConstantFP::get(Ty->getContext(), minnum(C1, C2)); + } + + if (IntrinsicID == Intrinsic::maxnum) { + const APFloat &C1 = Op1->getValueAPF(); + const APFloat &C2 = Op2->getValueAPF(); + return ConstantFP::get(Ty->getContext(), maxnum(C1, C2)); + } + if (!TLI) return nullptr; if (Name == "pow" && TLI->has(LibFunc::pow)) @@ -1761,7 +1785,7 @@ static Constant *ConstantFoldVectorCall(StringRef Name, unsigned IntrinsicID, return ConstantVector::get(Result); } -/// ConstantFoldCall - Attempt to constant fold a call to the specified function +/// Attempt to constant fold a call to the specified function /// with the specified arguments, returning null if unsuccessful. 
Constant * llvm::ConstantFoldCall(Function *F, ArrayRef<Constant *> Operands, diff --git a/lib/Analysis/DependenceAnalysis.cpp b/lib/Analysis/DependenceAnalysis.cpp index d0784f1..092df5c 100644 --- a/lib/Analysis/DependenceAnalysis.cpp +++ b/lib/Analysis/DependenceAnalysis.cpp @@ -163,16 +163,15 @@ void dumpExampleDependence(raw_ostream &OS, Function *F, DstI != DstE; ++DstI) { if (isa<StoreInst>(*DstI) || isa<LoadInst>(*DstI)) { OS << "da analyze - "; - if (Dependence *D = DA->depends(&*SrcI, &*DstI, true)) { + if (auto D = DA->depends(&*SrcI, &*DstI, true)) { D->dump(OS); for (unsigned Level = 1; Level <= D->getLevels(); Level++) { if (D->isSplitable(Level)) { OS << "da analyze - split level = " << Level; - OS << ", iteration = " << *DA->getSplitIteration(D, Level); + OS << ", iteration = " << *DA->getSplitIteration(*D, Level); OS << "!\n"; } } - delete D; } else OS << "none!\n"; @@ -782,6 +781,25 @@ void DependenceAnalysis::collectCommonLoops(const SCEV *Expression, } } +void DependenceAnalysis::unifySubscriptType(Subscript *Pair) { + const SCEV *Src = Pair->Src; + const SCEV *Dst = Pair->Dst; + IntegerType *SrcTy = dyn_cast<IntegerType>(Src->getType()); + IntegerType *DstTy = dyn_cast<IntegerType>(Dst->getType()); + if (SrcTy == nullptr || DstTy == nullptr) { + assert(SrcTy == DstTy && "This function only unifies integer types and " + "expects Src and Dst to share the same type " + "otherwise."); + return; + } + if (SrcTy->getBitWidth() > DstTy->getBitWidth()) { + // Sign-extend Dst to typeof(Src) if typeof(Src) is wider than typeof(Dst). + Pair->Dst = SE->getSignExtendExpr(Dst, SrcTy); + } else if (SrcTy->getBitWidth() < DstTy->getBitWidth()) { + // Sign-extend Src to typeof(Dst) if typeof(Dst) is wider than typeof(Src). + Pair->Src = SE->getSignExtendExpr(Src, DstTy); + } +} // removeMatchingExtensions - Examines a subscript pair.
// If the source and destination are identically sign (or zero) @@ -794,9 +812,11 @@ void DependenceAnalysis::removeMatchingExtensions(Subscript *Pair) { (isa<SCEVSignExtendExpr>(Src) && isa<SCEVSignExtendExpr>(Dst))) { const SCEVCastExpr *SrcCast = cast<SCEVCastExpr>(Src); const SCEVCastExpr *DstCast = cast<SCEVCastExpr>(Dst); - if (SrcCast->getType() == DstCast->getType()) { - Pair->Src = SrcCast->getOperand(); - Pair->Dst = DstCast->getOperand(); + const SCEV *SrcCastOp = SrcCast->getOperand(); + const SCEV *DstCastOp = DstCast->getOperand(); + if (SrcCastOp->getType() == DstCastOp->getType()) { + Pair->Src = SrcCastOp; + Pair->Dst = DstCastOp; } } } @@ -2957,15 +2977,11 @@ const SCEV *DependenceAnalysis::addToCoefficient(const SCEV *Expr, AddRec->getNoWrapFlags()); } if (SE->isLoopInvariant(AddRec, TargetLoop)) - return SE->getAddRecExpr(AddRec, - Value, - TargetLoop, - SCEV::FlagAnyWrap); - return SE->getAddRecExpr(addToCoefficient(AddRec->getStart(), - TargetLoop, Value), - AddRec->getStepRecurrence(*SE), - AddRec->getLoop(), - AddRec->getNoWrapFlags()); + return SE->getAddRecExpr(AddRec, Value, TargetLoop, SCEV::FlagAnyWrap); + return SE->getAddRecExpr( + addToCoefficient(AddRec->getStart(), TargetLoop, Value), + AddRec->getStepRecurrence(*SE), AddRec->getLoop(), + AddRec->getNoWrapFlags()); } @@ -3183,7 +3199,7 @@ void DependenceAnalysis::updateDirection(Dependence::DVEntry &Level, bool DependenceAnalysis::tryDelinearize(const SCEV *SrcSCEV, const SCEV *DstSCEV, SmallVectorImpl<Subscript> &Pair, - const SCEV *ElementSize) const { + const SCEV *ElementSize) { const SCEVUnknown *SrcBase = dyn_cast<SCEVUnknown>(SE->getPointerBase(SrcSCEV)); const SCEVUnknown *DstBase = @@ -3238,6 +3254,7 @@ bool DependenceAnalysis::tryDelinearize(const SCEV *SrcSCEV, for (int i = 0; i < size; ++i) { Pair[i].Src = SrcSubscripts[i]; Pair[i].Dst = DstSubscripts[i]; + unifySubscriptType(&Pair[i]); // FIXME: we should record the bounds SrcSizes[i] and DstSizes[i] that the // delinearization has found, and add these constraints to the dependence @@ -3277,9 +3294,9 @@ static void dumpSmallBitVector(SmallBitVector &BV) { // // Care is required to keep the routine below, getSplitIteration(), // up to date with respect to this routine. -Dependence *DependenceAnalysis::depends(Instruction *Src, - Instruction *Dst, - bool PossiblyLoopIndependent) { +std::unique_ptr<Dependence> +DependenceAnalysis::depends(Instruction *Src, Instruction *Dst, + bool PossiblyLoopIndependent) { if (Src == Dst) PossiblyLoopIndependent = false; @@ -3291,7 +3308,7 @@ Dependence *DependenceAnalysis::depends(Instruction *Src, if (!isLoadOrStore(Src) || !isLoadOrStore(Dst)) { // can only analyze simple loads and stores, i.e., no calls, invokes, etc. DEBUG(dbgs() << "can only handle simple loads and stores\n"); - return new Dependence(Src, Dst); + return make_unique<Dependence>(Src, Dst); } Value *SrcPtr = getPointerOperand(Src); @@ -3302,7 +3319,7 @@ Dependence *DependenceAnalysis::depends(Instruction *Src, case AliasAnalysis::PartialAlias: // cannot analyse objects if we don't understand their aliasing. DEBUG(dbgs() << "can't analyze may or partial alias\n"); - return new Dependence(Src, Dst); + return make_unique<Dependence>(Src, Dst); case AliasAnalysis::NoAlias: // If the objects noalias, they are distinct, accesses are independent. 
DEBUG(dbgs() << "no alias\n"); @@ -3346,6 +3363,7 @@ Dependence *DependenceAnalysis::depends(Instruction *Src, ++SrcIdx, ++DstIdx, ++P) { Pair[P].Src = SE->getSCEV(*SrcIdx); Pair[P].Dst = SE->getSCEV(*DstIdx); + unifySubscriptType(&Pair[P]); } } else { @@ -3675,9 +3693,9 @@ Dependence *DependenceAnalysis::depends(Instruction *Src, return nullptr; } - FullDependence *Final = new FullDependence(Result); + auto Final = make_unique<FullDependence>(Result); Result.DV = nullptr; - return Final; + return std::move(Final); } @@ -3729,13 +3747,12 @@ Dependence *DependenceAnalysis::depends(Instruction *Src, // // breaks the dependence and allows us to vectorize/parallelize // both loops. -const SCEV *DependenceAnalysis::getSplitIteration(const Dependence *Dep, +const SCEV *DependenceAnalysis::getSplitIteration(const Dependence &Dep, unsigned SplitLevel) { - assert(Dep && "expected a pointer to a Dependence"); - assert(Dep->isSplitable(SplitLevel) && + assert(Dep.isSplitable(SplitLevel) && "Dep should be splitable at SplitLevel"); - Instruction *Src = Dep->getSrc(); - Instruction *Dst = Dep->getDst(); + Instruction *Src = Dep.getSrc(); + Instruction *Dst = Dep.getDst(); assert(Src->mayReadFromMemory() || Src->mayWriteToMemory()); assert(Dst->mayReadFromMemory() || Dst->mayWriteToMemory()); assert(isLoadOrStore(Src)); diff --git a/lib/Analysis/DominanceFrontier.cpp b/lib/Analysis/DominanceFrontier.cpp index 74594f8..7ba91bc 100644 --- a/lib/Analysis/DominanceFrontier.cpp +++ b/lib/Analysis/DominanceFrontier.cpp @@ -8,133 +8,50 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/DominanceFrontier.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/Analysis/DominanceFrontierImpl.h" + using namespace llvm; +namespace llvm { +template class DominanceFrontierBase<BasicBlock>; +template class ForwardDominanceFrontierBase<BasicBlock>; +} + char DominanceFrontier::ID = 0; + INITIALIZE_PASS_BEGIN(DominanceFrontier, "domfrontier", "Dominance Frontier Construction", true, true) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(DominanceFrontier, "domfrontier", "Dominance Frontier Construction", true, true) -namespace { - class DFCalculateWorkObject { - public: - DFCalculateWorkObject(BasicBlock *B, BasicBlock *P, - const DomTreeNode *N, - const DomTreeNode *PN) - : currentBB(B), parentBB(P), Node(N), parentNode(PN) {} - BasicBlock *currentBB; - BasicBlock *parentBB; - const DomTreeNode *Node; - const DomTreeNode *parentNode; - }; +DominanceFrontier::DominanceFrontier() + : FunctionPass(ID), + Base() { + initializeDominanceFrontierPass(*PassRegistry::getPassRegistry()); } -void DominanceFrontier::anchor() { } - -const DominanceFrontier::DomSetType & -DominanceFrontier::calculate(const DominatorTree &DT, - const DomTreeNode *Node) { - BasicBlock *BB = Node->getBlock(); - DomSetType *Result = nullptr; - - std::vector<DFCalculateWorkObject> workList; - SmallPtrSet<BasicBlock *, 32> visited; - - workList.push_back(DFCalculateWorkObject(BB, nullptr, Node, nullptr)); - do { - DFCalculateWorkObject *currentW = &workList.back(); - assert (currentW && "Missing work object."); - - BasicBlock *currentBB = currentW->currentBB; - BasicBlock *parentBB = currentW->parentBB; - const DomTreeNode *currentNode = currentW->Node; - const DomTreeNode *parentNode = currentW->parentNode; - assert (currentBB && "Invalid work object. 
Missing current Basic Block"); - assert (currentNode && "Invalid work object. Missing current Node"); - DomSetType &S = Frontiers[currentBB]; - - // Visit each block only once. - if (visited.count(currentBB) == 0) { - visited.insert(currentBB); - - // Loop over CFG successors to calculate DFlocal[currentNode] - for (succ_iterator SI = succ_begin(currentBB), SE = succ_end(currentBB); - SI != SE; ++SI) { - // Does Node immediately dominate this successor? - if (DT[*SI]->getIDom() != currentNode) - S.insert(*SI); - } - } - - // At this point, S is DFlocal. Now we union in DFup's of our children... - // Loop through and visit the nodes that Node immediately dominates (Node's - // children in the IDomTree) - bool visitChild = false; - for (DomTreeNode::const_iterator NI = currentNode->begin(), - NE = currentNode->end(); NI != NE; ++NI) { - DomTreeNode *IDominee = *NI; - BasicBlock *childBB = IDominee->getBlock(); - if (visited.count(childBB) == 0) { - workList.push_back(DFCalculateWorkObject(childBB, currentBB, - IDominee, currentNode)); - visitChild = true; - } - } - - // If all children are visited or there is any child then pop this block - // from the workList. - if (!visitChild) { - - if (!parentBB) { - Result = &S; - break; - } - - DomSetType::const_iterator CDFI = S.begin(), CDFE = S.end(); - DomSetType &parentSet = Frontiers[parentBB]; - for (; CDFI != CDFE; ++CDFI) { - if (!DT.properlyDominates(parentNode, DT[*CDFI])) - parentSet.insert(*CDFI); - } - workList.pop_back(); - } +void DominanceFrontier::releaseMemory() { + Base.releaseMemory(); +} - } while (!workList.empty()); +bool DominanceFrontier::runOnFunction(Function &) { + releaseMemory(); + Base.analyze(getAnalysis<DominatorTreeWrapperPass>().getDomTree()); + return false; +} - return *Result; +void DominanceFrontier::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<DominatorTreeWrapperPass>(); } -void DominanceFrontierBase::print(raw_ostream &OS, const Module* ) const { - for (const_iterator I = begin(), E = end(); I != E; ++I) { - OS << " DomFrontier for BB "; - if (I->first) - I->first->printAsOperand(OS, false); - else - OS << " <<exit node>>"; - OS << " is:\t"; - - const std::set<BasicBlock*> &BBs = I->second; - - for (std::set<BasicBlock*>::const_iterator I = BBs.begin(), E = BBs.end(); - I != E; ++I) { - OS << ' '; - if (*I) - (*I)->printAsOperand(OS, false); - else - OS << "<<exit node>>"; - } - OS << "\n"; - } +void DominanceFrontier::print(raw_ostream &OS, const Module *) const { + Base.print(OS); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void DominanceFrontierBase::dump() const { +void DominanceFrontier::dump() const { print(dbgs()); } #endif - diff --git a/lib/Analysis/FunctionTargetTransformInfo.cpp b/lib/Analysis/FunctionTargetTransformInfo.cpp new file mode 100644 index 0000000..a686bec --- /dev/null +++ b/lib/Analysis/FunctionTargetTransformInfo.cpp @@ -0,0 +1,50 @@ +//===- llvm/Analysis/FunctionTargetTransformInfo.h --------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass wraps a TargetTransformInfo in a FunctionPass so that it can +// forward along the current Function so that we can make target specific +// decisions based on the particular subtarget specified for each Function. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/InitializePasses.h" +#include "llvm/Analysis/FunctionTargetTransformInfo.h" + +using namespace llvm; + +#define DEBUG_TYPE "function-tti" +static const char ftti_name[] = "Function TargetTransformInfo"; +INITIALIZE_PASS_BEGIN(FunctionTargetTransformInfo, "function_tti", ftti_name, false, true) +INITIALIZE_AG_DEPENDENCY(TargetTransformInfo) +INITIALIZE_PASS_END(FunctionTargetTransformInfo, "function_tti", ftti_name, false, true) +char FunctionTargetTransformInfo::ID = 0; + +namespace llvm { +FunctionPass *createFunctionTargetTransformInfoPass() { + return new FunctionTargetTransformInfo(); +} +} + +FunctionTargetTransformInfo::FunctionTargetTransformInfo() + : FunctionPass(ID), Fn(nullptr), TTI(nullptr) { + initializeFunctionTargetTransformInfoPass(*PassRegistry::getPassRegistry()); +} + +void FunctionTargetTransformInfo::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<TargetTransformInfo>(); +} + +void FunctionTargetTransformInfo::releaseMemory() {} + +bool FunctionTargetTransformInfo::runOnFunction(Function &F) { + Fn = &F; + TTI = &getAnalysis<TargetTransformInfo>(); + return false; +} diff --git a/lib/Analysis/IPA/CallGraph.cpp b/lib/Analysis/IPA/CallGraph.cpp index caec253..67cf7f8 100644 --- a/lib/Analysis/IPA/CallGraph.cpp +++ b/lib/Analysis/IPA/CallGraph.cpp @@ -267,7 +267,7 @@ INITIALIZE_PASS(CallGraphWrapperPass, "basiccg", "CallGraph Construction", char CallGraphWrapperPass::ID = 0; -void CallGraphWrapperPass::releaseMemory() { G.reset(nullptr); } +void CallGraphWrapperPass::releaseMemory() { G.reset(); } void CallGraphWrapperPass::print(raw_ostream &OS, const Module *) const { if (!G) { @@ -282,6 +282,3 @@ void CallGraphWrapperPass::print(raw_ostream &OS, const Module *) const { #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void CallGraphWrapperPass::dump() const { print(dbgs(), nullptr); } #endif - -// Enuse that users of CallGraph.h also link with this file -DEFINING_FILE_FOR(CallGraph) diff --git a/lib/Analysis/IPA/CallGraphSCCPass.cpp b/lib/Analysis/IPA/CallGraphSCCPass.cpp index c27edbf..665aa7f 100644 --- a/lib/Analysis/IPA/CallGraphSCCPass.cpp +++ b/lib/Analysis/IPA/CallGraphSCCPass.cpp @@ -243,7 +243,14 @@ bool CGPassManager::RefreshCallGraph(CallGraphSCC &CurSCC, assert(!CallSites.count(I->first) && "Call site occurs in node multiple times"); - CallSites.insert(std::make_pair(I->first, I->second)); + + CallSite CS(I->first); + if (CS) { + Function *Callee = CS.getCalledFunction(); + // Ignore intrinsics because they're not really function calls. + if (!Callee || !(Callee->isIntrinsic())) + CallSites.insert(std::make_pair(I->first, I->second)); + } ++I; } diff --git a/lib/Analysis/IPA/InlineCost.cpp b/lib/Analysis/IPA/InlineCost.cpp index 8807529..85db278 100644 --- a/lib/Analysis/IPA/InlineCost.cpp +++ b/lib/Analysis/IPA/InlineCost.cpp @@ -17,7 +17,9 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CallSite.h" @@ -49,6 +51,9 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { /// The TargetTransformInfo available for this compilation. const TargetTransformInfo &TTI; + /// The cache of @llvm.assume intrinsics. 
+ AssumptionTracker *AT; + // The called function. Function &F; @@ -104,7 +109,7 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { ConstantInt *stripAndComputeInBoundsConstantOffsets(Value *&V); // Custom analysis routines. - bool analyzeBlock(BasicBlock *BB); + bool analyzeBlock(BasicBlock *BB, SmallPtrSetImpl<const Value *> &EphValues); // Disable several entry points to the visitor so we don't accidentally use // them by declaring but not defining them here. @@ -141,8 +146,8 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { public: CallAnalyzer(const DataLayout *DL, const TargetTransformInfo &TTI, - Function &Callee, int Threshold) - : DL(DL), TTI(TTI), F(Callee), Threshold(Threshold), Cost(0), + AssumptionTracker *AT, Function &Callee, int Threshold) + : DL(DL), TTI(TTI), AT(AT), F(Callee), Threshold(Threshold), Cost(0), IsCallerRecursive(false), IsRecursiveCall(false), ExposesReturnsTwice(false), HasDynamicAlloca(false), ContainsNoDuplicateCall(false), HasReturn(false), HasIndirectBr(false), @@ -778,7 +783,7 @@ bool CallAnalyzer::visitCallSite(CallSite CS) { // during devirtualization and so we want to give it a hefty bonus for // inlining, but cap that bonus in the event that inlining wouldn't pan // out. Pretend to inline the function, with a custom threshold. - CallAnalyzer CA(DL, TTI, *F, InlineConstants::IndirectCallThreshold); + CallAnalyzer CA(DL, TTI, AT, *F, InlineConstants::IndirectCallThreshold); if (CA.analyzeCall(CS)) { // We were able to inline the indirect call! Subtract the cost from the // bonus we want to apply, but don't go below zero. @@ -881,7 +886,8 @@ bool CallAnalyzer::visitInstruction(Instruction &I) { /// aborts early if the threshold has been exceeded or an impossible to inline /// construct has been detected. It returns false if inlining is no longer /// viable, and true if inlining remains viable. -bool CallAnalyzer::analyzeBlock(BasicBlock *BB) { +bool CallAnalyzer::analyzeBlock(BasicBlock *BB, + SmallPtrSetImpl<const Value *> &EphValues) { for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { // FIXME: Currently, the number of instructions in a function regardless of // our ability to simplify them during inline to constants or dead code, @@ -893,6 +899,10 @@ bool CallAnalyzer::analyzeBlock(BasicBlock *BB) { if (isa<DbgInfoIntrinsic>(I)) continue; + // Skip ephemeral values. + if (EphValues.count(I)) + continue; + ++NumInstructions; if (isa<ExtractElementInst>(I) || I->getType()->isVectorTy()) ++NumVectorInstructions; @@ -967,7 +977,7 @@ ConstantInt *CallAnalyzer::stripAndComputeInBoundsConstantOffsets(Value *&V) { break; } assert(V->getType()->isPointerTy() && "Unexpected operand type!"); - } while (Visited.insert(V)); + } while (Visited.insert(V).second); Type *IntPtrTy = DL->getIntPtrType(V->getContext()); return cast<ConstantInt>(ConstantInt::get(IntPtrTy, Offset)); @@ -1096,6 +1106,12 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { NumConstantOffsetPtrArgs = ConstantOffsetPtrs.size(); NumAllocaArgs = SROAArgValues.size(); + // FIXME: If a caller has multiple calls to a callee, we end up recomputing + // the ephemeral values multiple times (and they're completely determined by + // the callee, so this is purely duplicate work). + SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(&F, AT, EphValues); + // The worklist of live basic blocks in the callee *after* inlining. 
We avoid // adding basic blocks of the callee which can be proven to be dead for this // particular call site in order to get more accurate cost estimates. This @@ -1129,7 +1145,7 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { // Analyze the cost of this block. If we blow through the threshold, this // returns false, and we can bail on out. - if (!analyzeBlock(BB)) { + if (!analyzeBlock(BB, EphValues)) { if (IsRecursiveCall || ExposesReturnsTwice || HasDynamicAlloca || HasIndirectBr) return false; @@ -1217,6 +1233,7 @@ void CallAnalyzer::dump() { INITIALIZE_PASS_BEGIN(InlineCostAnalysis, "inline-cost", "Inline Cost Analysis", true, true) INITIALIZE_AG_DEPENDENCY(TargetTransformInfo) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_END(InlineCostAnalysis, "inline-cost", "Inline Cost Analysis", true, true) @@ -1228,12 +1245,14 @@ InlineCostAnalysis::~InlineCostAnalysis() {} void InlineCostAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); + AU.addRequired<AssumptionTracker>(); AU.addRequired<TargetTransformInfo>(); CallGraphSCCPass::getAnalysisUsage(AU); } bool InlineCostAnalysis::runOnSCC(CallGraphSCC &SCC) { TTI = &getAnalysis<TargetTransformInfo>(); + AT = &getAnalysis<AssumptionTracker>(); return false; } @@ -1290,7 +1309,7 @@ InlineCost InlineCostAnalysis::getInlineCost(CallSite CS, Function *Callee, DEBUG(llvm::dbgs() << " Analyzing call of " << Callee->getName() << "...\n"); - CallAnalyzer CA(Callee->getDataLayout(), *TTI, *Callee, Threshold); + CallAnalyzer CA(Callee->getDataLayout(), *TTI, AT, *Callee, Threshold); bool ShouldInline = CA.analyzeCall(CS); DEBUG(CA.dump()); diff --git a/lib/Analysis/IVUsers.cpp b/lib/Analysis/IVUsers.cpp index 24655aa..6b5f370 100644 --- a/lib/Analysis/IVUsers.cpp +++ b/lib/Analysis/IVUsers.cpp @@ -84,7 +84,7 @@ static bool isInteresting(const SCEV *S, const Instruction *I, const Loop *L, /// form. static bool isSimplifiedLoopNest(BasicBlock *BB, const DominatorTree *DT, const LoopInfo *LI, - SmallPtrSet<Loop*,16> &SimpleLoopNests) { + SmallPtrSetImpl<Loop*> &SimpleLoopNests) { Loop *NearestLoop = nullptr; for (DomTreeNode *Rung = DT->getNode(BB); Rung; Rung = Rung->getIDom()) { @@ -112,10 +112,10 @@ static bool isSimplifiedLoopNest(BasicBlock *BB, const DominatorTree *DT, /// reducible SCEV, recursively add its users to the IVUsesByStride set and /// return true. Otherwise, return false. bool IVUsers::AddUsersImpl(Instruction *I, - SmallPtrSet<Loop*,16> &SimpleLoopNests) { + SmallPtrSetImpl<Loop*> &SimpleLoopNests) { // Add this IV user to the Processed set before returning false to ensure that // all IV users are members of the set. See IVUsers::isIVUserOrOperand. - if (!Processed.insert(I)) + if (!Processed.insert(I).second) return true; // Instruction already handled. if (!SE->isSCEVable(I->getType())) @@ -145,7 +145,7 @@ bool IVUsers::AddUsersImpl(Instruction *I, SmallPtrSet<Instruction *, 4> UniqueUsers; for (Use &U : I->uses()) { Instruction *User = cast<Instruction>(U.getUser()); - if (!UniqueUsers.insert(User)) + if (!UniqueUsers.insert(User).second) continue; // Do not infinitely recurse on PHI nodes. 
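The Processed.insert(I).second and UniqueUsers.insert(User).second changes above (and the matching edits throughout this patch) track an API change: SmallPtrSet::insert now returns an iterator/bool pair rather than a bare bool, and taking parameters as SmallPtrSetImpl<T> lets callers choose the inline size. A minimal sketch of the new idiom, with Worklist and process() as illustrative stand-ins:

SmallPtrSet<Instruction *, 16> Seen;
for (Instruction *I : Worklist) {
  // insert().second is true only the first time I enters the set.
  if (!Seen.insert(I).second)
    continue; // duplicate; already handled
  process(I);
}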
diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index bd42af1..f151a3a 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -41,14 +41,20 @@ enum { RecursionLimit = 3 }; STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumReassoc, "Number of reassociations"); +namespace { struct Query { const DataLayout *DL; const TargetLibraryInfo *TLI; const DominatorTree *DT; + AssumptionTracker *AT; + const Instruction *CxtI; Query(const DataLayout *DL, const TargetLibraryInfo *tli, - const DominatorTree *dt) : DL(DL), TLI(tli), DT(dt) {} + const DominatorTree *dt, AssumptionTracker *at = nullptr, + const Instruction *cxti = nullptr) + : DL(DL), TLI(tli), DT(dt), AT(at), CxtI(cxti) {} }; +} // end anonymous namespace static Value *SimplifyAndInst(Value *, Value *, const Query &, unsigned); static Value *SimplifyBinOp(unsigned, Value *, Value *, const Query &, @@ -575,9 +581,10 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, Query (DL, TLI, DT), - RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, + Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } /// \brief Compute the base pointer and cumulative constant offsets for V. @@ -624,7 +631,7 @@ static Constant *stripAndComputeConstantOffsets(const DataLayout *DL, } assert(V->getType()->getScalarType()->isPointerTy() && "Unexpected operand type!"); - } while (Visited.insert(V)); + } while (Visited.insert(V).second); Constant *OffsetIntPtr = ConstantInt::get(IntPtrTy, Offset); if (V->getType()->isVectorTy()) @@ -676,6 +683,18 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, if (Op0 == Op1) return Constant::getNullValue(Op0->getType()); + // X - (0 - Y) -> X if the second sub is NUW. + // If Y != 0, 0 - Y is a poison value. + // If Y == 0, 0 - Y simplifies to 0. + if (BinaryOperator::isNeg(Op1)) { + if (const auto *BO = dyn_cast<BinaryOperator>(Op1)) { + assert(BO->getOpcode() == Instruction::Sub && + "Expected a subtraction operator!"); + if (BO->hasNoUnsignedWrap()) + return Op0; + } + } + // (X + Y) - Z -> X + (Y - Z) or Y + (X - Z) if everything simplifies. // For example, (X + Y) - Y -> X; (Y + X) - Y -> X Value *X = nullptr, *Y = nullptr, *Z = Op1; @@ -769,9 +788,10 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, Query (DL, TLI, DT), - RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, + Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } /// Given operands for an FAdd, see if we can fold the result. 
If not, this @@ -947,28 +967,37 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFAddInst(Op0, Op1, FMF, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyFAddInst(Op0, Op1, FMF, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFSubInst(Op0, Op1, FMF, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyFSubInst(Op0, Op1, FMF, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFMulInst(Op0, Op1, FMF, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyFMulInst(Op0, Op1, FMF, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyMulInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyMulInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyDiv - Given operands for an SDiv or UDiv, see if we can @@ -1028,6 +1057,16 @@ static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, (!isSigned && match(Op0, m_URem(m_Value(), m_Specific(Op1))))) return Constant::getNullValue(Op0->getType()); + // (X /u C1) /u C2 -> 0 if C1 * C2 overflow + ConstantInt *C1, *C2; + if (!isSigned && match(Op0, m_UDiv(m_Value(X), m_ConstantInt(C1))) && + match(Op1, m_ConstantInt(C2))) { + bool Overflow; + C1->getValue().umul_ov(C2->getValue(), Overflow); + if (Overflow) + return Constant::getNullValue(Op0->getType()); + } + // If the operation is with the result of a select instruction, check whether // operating on either branch of the select always yields the same value. 
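// A concrete instance of the (X /u C1) /u C2 -> 0 fold above, assuming i8
// operands: x /u 16 is at most 15, and 15 /u 32 == 0. The umul_ov probe used
// by the code detects exactly this, since 16 * 32 == 512 does not fit in
// eight bits:
bool Overflow;
APInt C1(/*numBits=*/8, 16), C2(/*numBits=*/8, 32);
C1.umul_ov(C2, Overflow);
// Overflow is now true, so (x /u 16) /u 32 simplifies to the zero constant.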
if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) @@ -1055,8 +1094,11 @@ static Value *SimplifySDivInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifySDivInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifySDivInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifySDivInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyUDivInst - Given operands for a UDiv, see if we can @@ -1071,8 +1113,11 @@ static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyUDivInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyUDivInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } static Value *SimplifyFDivInst(Value *Op0, Value *Op1, const Query &Q, @@ -1090,8 +1135,11 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFDivInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyFDivInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyRem - Given operands for an SRem or URem, see if we can @@ -1133,6 +1181,13 @@ static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (Op0 == Op1) return Constant::getNullValue(Op0->getType()); + // (X % Y) % Y -> X % Y + if ((Opcode == Instruction::SRem && + match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) || + (Opcode == Instruction::URem && + match(Op0, m_URem(m_Value(), m_Specific(Op1))))) + return Op0; + // If the operation is with the result of a select instruction, check whether // operating on either branch of the select always yields the same value. 
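// The (X % Y) % Y fold above, checked with plain C++ semantics (x and y are
// arbitrary unsigned values, y nonzero); srem behaves the same way because a
// signed remainder's magnitude is strictly below |y|:
unsigned remOfRem(unsigned x, unsigned y) {
  unsigned r1 = x % y; // r1 already lies in [0, y)
  return r1 % y;       // so the second rem returns r1 unchanged
}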
if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) @@ -1160,8 +1215,11 @@ static Value *SimplifySRemInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifySRemInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifySRemInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyURemInst - Given operands for a URem, see if we can @@ -1176,8 +1234,11 @@ static Value *SimplifyURemInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyURemInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyURemInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } static Value *SimplifyFRemInst(Value *Op0, Value *Op1, const Query &, @@ -1195,8 +1256,11 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, const Query &, Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFRemInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyFRemInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// isUndefShift - Returns true if a shift by \c Amount always yields undef. @@ -1264,6 +1328,32 @@ static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, return nullptr; } +/// \brief Given operands for an Shl, LShr or AShr, see if we can +/// fold the result. If not, this returns null. +static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1, + bool isExact, const Query &Q, + unsigned MaxRecurse) { + if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse)) + return V; + + // X >> X -> 0 + if (Op0 == Op1) + return Constant::getNullValue(Op0->getType()); + + // The low bit cannot be shifted out of an exact shift if it is set. + if (isExact) { + unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); + APInt Op0KnownZero(BitWidth, 0); + APInt Op0KnownOne(BitWidth, 0); + computeKnownBits(Op0, Op0KnownZero, Op0KnownOne, Q.DL, /*Depth=*/0, Q.AT, Q.CxtI, + Q.DT); + if (Op0KnownOne[0]) + return Op0; + } + + return nullptr; +} + /// SimplifyShlInst - Given operands for an Shl, see if we can /// fold the result. If not, this returns null. static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, @@ -1284,8 +1374,9 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, Value *llvm::SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Query (DL, TLI, DT), + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -1293,12 +1384,9 @@ Value *llvm::SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, /// fold the result. If not, this returns null. 
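// Why the isExact case in SimplifyRightShift above is sound: "exact" promises
// that no set bits are shifted out. If computeKnownBits proves the low bit of
// Op0 is one, any nonzero shift amount would shift that one out, so the only
// exact shift is a shift by zero and the result is Op0 itself. In IR terms
// (a sketch):
//
//   %v = or i32 %x, 1            ; low bit known to be one
//   %s = lshr exact i32 %v, %n   ; simplifies to %v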
static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, const Query &Q, unsigned MaxRecurse) { - if (Value *V = SimplifyShift(Instruction::LShr, Op0, Op1, Q, MaxRecurse)) - return V; - - // X >> X -> 0 - if (Op0 == Op1) - return Constant::getNullValue(Op0->getType()); + if (Value *V = SimplifyRightShift(Instruction::LShr, Op0, Op1, isExact, Q, + MaxRecurse)) + return V; // undef >>l X -> 0 if (match(Op0, m_Undef())) @@ -1306,8 +1394,7 @@ static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, // (X << A) >> A -> X Value *X; - if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1))) && - cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap()) + if (match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1)))) return X; return nullptr; @@ -1316,8 +1403,10 @@ static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, Value *llvm::SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyLShrInst(Op0, Op1, isExact, Query (DL, TLI, DT), + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyLShrInst(Op0, Op1, isExact, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -1325,13 +1414,10 @@ Value *llvm::SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, /// fold the result. If not, this returns null. static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, const Query &Q, unsigned MaxRecurse) { - if (Value *V = SimplifyShift(Instruction::AShr, Op0, Op1, Q, MaxRecurse)) + if (Value *V = SimplifyRightShift(Instruction::AShr, Op0, Op1, isExact, Q, + MaxRecurse)) return V; - // X >> X -> 0 - if (Op0 == Op1) - return Constant::getNullValue(Op0->getType()); - // all ones >>a X -> all ones if (match(Op0, m_AllOnes())) return Op0; @@ -1342,21 +1428,75 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, // (X << A) >> A -> X Value *X; - if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1))) && - cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap()) + if (match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1)))) return X; + // Arithmetic shifting an all-sign-bit value is a no-op. + unsigned NumSignBits = ComputeNumSignBits(Op0, Q.DL, 0, Q.AT, Q.CxtI, Q.DT); + if (NumSignBits == Op0->getType()->getScalarSizeInBits()) + return Op0; + return nullptr; } Value *llvm::SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyAShrInst(Op0, Op1, isExact, Query (DL, TLI, DT), + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyAShrInst(Op0, Op1, isExact, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } +// Simplify (and (icmp ...) (icmp ...)) to false when we can tell that the range +// of possible values cannot be satisfied.
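// A worked instance of this fold, taking CI1 = 1 and CI2 = 3 (so the Delta
// computed below is 2 and CI1 is strictly positive):
//
//   %a  = add i32 %v, 1
//   %c0 = icmp ult i32 %a, 3   ; needs %v + 1 to be 0, 1 or 2 unsigned
//   %c1 = icmp sgt i32 %v, 1   ; needs %v >= 2 signed
//   %r  = and i1 %c0, %c1      ; simplifies to false
//
// Whenever %v >= 2 as a signed value, %v + 1 is at least 3 as an unsigned
// value (even at INT_MAX, where the add wraps to 0x80000000), so the two
// conditions can never hold together.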
+static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { + ICmpInst::Predicate Pred0, Pred1; + ConstantInt *CI1, *CI2; + Value *V; + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_ConstantInt(CI1)), + m_ConstantInt(CI2)))) + return nullptr; + + if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Specific(CI1)))) + return nullptr; + + Type *ITy = Op0->getType(); + + auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); + bool isNSW = AddInst->hasNoSignedWrap(); + bool isNUW = AddInst->hasNoUnsignedWrap(); + + const APInt &CI1V = CI1->getValue(); + const APInt &CI2V = CI2->getValue(); + const APInt Delta = CI2V - CI1V; + if (CI1V.isStrictlyPositive()) { + if (Delta == 2) { + if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_SGT) + return getFalse(ITy); + if (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT && isNSW) + return getFalse(ITy); + } + if (Delta == 1) { + if (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_SGT) + return getFalse(ITy); + if (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGT && isNSW) + return getFalse(ITy); + } + } + if (CI1V.getBoolValue() && isNUW) { + if (Delta == 2) + if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT) + return getFalse(ITy); + if (Delta == 1) + if (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGT) + return getFalse(ITy); + } + + return nullptr; +} + /// SimplifyAndInst - Given operands for an And, see if we can /// fold the result. If not, this returns null. static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, @@ -1407,12 +1547,21 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, // A & (-A) = A if A is a power of two or zero. if (match(Op0, m_Neg(m_Specific(Op1))) || match(Op1, m_Neg(m_Specific(Op0)))) { - if (isKnownToBeAPowerOfTwo(Op0, /*OrZero*/true)) + if (isKnownToBeAPowerOfTwo(Op0, /*OrZero*/true, 0, Q.AT, Q.CxtI, Q.DT)) return Op0; - if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true)) + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true, 0, Q.AT, Q.CxtI, Q.DT)) return Op1; } + if (auto *ICILHS = dyn_cast<ICmpInst>(Op0)) { + if (auto *ICIRHS = dyn_cast<ICmpInst>(Op1)) { + if (Value *V = SimplifyAndOfICmps(ICILHS, ICIRHS)) + return V; + if (Value *V = SimplifyAndOfICmps(ICIRHS, ICILHS)) + return V; + } + } + // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::And, Op0, Op1, Q, MaxRecurse)) @@ -1447,8 +1596,58 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyAndInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyAndInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); +} + +// Simplify (or (icmp ...) (icmp ...)) to true when we can tell that the union +// contains all possible values. 
+static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { + ICmpInst::Predicate Pred0, Pred1; + ConstantInt *CI1, *CI2; + Value *V; + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_ConstantInt(CI1)), + m_ConstantInt(CI2)))) + return nullptr; + + if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Specific(CI1)))) + return nullptr; + + Type *ITy = Op0->getType(); + + auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); + bool isNSW = AddInst->hasNoSignedWrap(); + bool isNUW = AddInst->hasNoUnsignedWrap(); + + const APInt &CI1V = CI1->getValue(); + const APInt &CI2V = CI2->getValue(); + const APInt Delta = CI2V - CI1V; + if (CI1V.isStrictlyPositive()) { + if (Delta == 2) { + if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_SLE) + return getTrue(ITy); + if (Pred0 == ICmpInst::ICMP_SGE && Pred1 == ICmpInst::ICMP_SLE && isNSW) + return getTrue(ITy); + } + if (Delta == 1) { + if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_SLE) + return getTrue(ITy); + if (Pred0 == ICmpInst::ICMP_SGT && Pred1 == ICmpInst::ICMP_SLE && isNSW) + return getTrue(ITy); + } + } + if (CI1V.getBoolValue() && isNUW) { + if (Delta == 2) + if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_ULE) + return getTrue(ITy); + if (Delta == 1) + if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_ULE) + return getTrue(ITy); + } + + return nullptr; } /// SimplifyOrInst - Given operands for an Or, see if we can @@ -1508,6 +1707,15 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, (A == Op0 || B == Op0)) return Constant::getAllOnesValue(Op0->getType()); + if (auto *ICILHS = dyn_cast<ICmpInst>(Op0)) { + if (auto *ICIRHS = dyn_cast<ICmpInst>(Op1)) { + if (Value *V = SimplifyOrOfICmps(ICILHS, ICIRHS)) + return V; + if (Value *V = SimplifyOrOfICmps(ICIRHS, ICILHS)) + return V; + } + } + // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q, MaxRecurse)) @@ -1540,18 +1748,22 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, if ((C2->getValue() & (C2->getValue() + 1)) == 0 && // C2 == 0+1+ match(A, m_Add(m_Value(V1), m_Value(V2)))) { // Add commutes, try both ways. - if (V1 == B && MaskedValueIsZero(V2, C2->getValue())) + if (V1 == B && MaskedValueIsZero(V2, C2->getValue(), Q.DL, + 0, Q.AT, Q.CxtI, Q.DT)) return A; - if (V2 == B && MaskedValueIsZero(V1, C2->getValue())) + if (V2 == B && MaskedValueIsZero(V1, C2->getValue(), Q.DL, + 0, Q.AT, Q.CxtI, Q.DT)) return A; } // Or commutes, try both ways. if ((C1->getValue() & (C1->getValue() + 1)) == 0 && match(B, m_Add(m_Value(V1), m_Value(V2)))) { // Add commutes, try both ways. 
- if (V1 == A && MaskedValueIsZero(V2, C1->getValue())) + if (V1 == A && MaskedValueIsZero(V2, C1->getValue(), Q.DL, + 0, Q.AT, Q.CxtI, Q.DT)) return B; - if (V2 == A && MaskedValueIsZero(V1, C1->getValue())) + if (V2 == A && MaskedValueIsZero(V1, C1->getValue(), Q.DL, + 0, Q.AT, Q.CxtI, Q.DT)) return B; } } @@ -1568,8 +1780,10 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyOrInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyOrInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyXorInst - Given operands for a Xor, see if we can @@ -1623,8 +1837,10 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyXorInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyXorInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyXorInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } static Type *GetCompareTy(Value *Op) { @@ -1878,40 +2094,46 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getTrue(ITy); case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_ULE: - if (isKnownNonZero(LHS, Q.DL)) + if (isKnownNonZero(LHS, Q.DL, 0, Q.AT, Q.CxtI, Q.DT)) return getFalse(ITy); break; case ICmpInst::ICMP_NE: case ICmpInst::ICMP_UGT: - if (isKnownNonZero(LHS, Q.DL)) + if (isKnownNonZero(LHS, Q.DL, 0, Q.AT, Q.CxtI, Q.DT)) return getTrue(ITy); break; case ICmpInst::ICMP_SLT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (LHSKnownNegative) return getTrue(ITy); if (LHSKnownNonNegative) return getFalse(ITy); break; case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (LHSKnownNegative) return getTrue(ITy); - if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL)) + if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT)) return getFalse(ITy); break; case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (LHSKnownNegative) return getFalse(ITy); if (LHSKnownNonNegative) return getTrue(ITy); break; case ICmpInst::ICMP_SGT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (LHSKnownNegative) return getFalse(ITy); - if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL)) + if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT)) return getTrue(ITy); break; } @@ -1958,13 +2180,39 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Lower = (-Upper) + 1; } } else if (match(LHS, m_SDiv(m_Value(), m_ConstantInt(CI2)))) { - // 'sdiv x, CI2' produces [INT_MIN / CI2, INT_MAX / CI2]. 
APInt IntMin = APInt::getSignedMinValue(Width); APInt IntMax = APInt::getSignedMaxValue(Width); - APInt Val = CI2->getValue().abs(); - if (!Val.isMinValue()) { + APInt Val = CI2->getValue(); + if (Val.isAllOnesValue()) { + // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] + // where CI2 != -1 and CI2 != 0 and CI2 != 1 + Lower = IntMin + 1; + Upper = IntMax + 1; + } else if (Val.countLeadingZeros() < Width - 1) { + // 'sdiv x, CI2' produces [INT_MIN / CI2, INT_MAX / CI2] + // where CI2 != -1 and CI2 != 0 and CI2 != 1 Lower = IntMin.sdiv(Val); - Upper = IntMax.sdiv(Val) + 1; + Upper = IntMax.sdiv(Val); + if (Lower.sgt(Upper)) + std::swap(Lower, Upper); + Upper = Upper + 1; + assert(Upper != Lower && "Upper part of range has wrapped!"); + } + } else if (match(LHS, m_NUWShl(m_ConstantInt(CI2), m_Value()))) { + // 'shl nuw CI2, x' produces [CI2, CI2 << CLZ(CI2)] + Lower = CI2->getValue(); + Upper = Lower.shl(Lower.countLeadingZeros()) + 1; + } else if (match(LHS, m_NSWShl(m_ConstantInt(CI2), m_Value()))) { + if (CI2->isNegative()) { + // 'shl nsw CI2, x' produces [CI2 << CLO(CI2)-1, CI2] + unsigned ShiftAmount = CI2->getValue().countLeadingOnes() - 1; + Lower = CI2->getValue().shl(ShiftAmount); + Upper = CI2->getValue() + 1; + } else { + // 'shl nsw CI2, x' produces [CI2, CI2 << CLZ(CI2)-1] + unsigned ShiftAmount = CI2->getValue().countLeadingZeros() - 1; + Lower = CI2->getValue(); + Upper = CI2->getValue().shl(ShiftAmount) + 1; } } else if (match(LHS, m_LShr(m_Value(), m_ConstantInt(CI2)))) { // 'lshr x, CI2' produces [0, UINT_MAX >> CI2]. @@ -2174,25 +2422,6 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } - // If a bit is known to be zero for A and known to be one for B, - // then A and B cannot be equal. - if (ICmpInst::isEquality(Pred)) { - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - uint32_t BitWidth = CI->getBitWidth(); - APInt LHSKnownZero(BitWidth, 0); - APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); - APInt RHSKnownZero(BitWidth, 0); - APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); - if (((LHSKnownOne & RHSKnownZero) != 0) || - ((LHSKnownZero & RHSKnownOne) != 0)) - return (Pred == ICmpInst::ICMP_EQ) - ? ConstantInt::getFalse(CI->getContext()) - : ConstantInt::getTrue(CI->getContext()); - } - } - // Special logic for binary operators. 
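// Worked instances of the new shl ranges above, in i8 for concreteness:
//
//   'shl nuw i8 3, %x': the nuw flag caps %x at countLeadingZeros(3) == 6,
//   so the result is one of 3, 6, 12, 24, 48, 96, 192 -- exactly the closed
//   range [3, 3 << 6] that the code derives.
//
//   'shl nsw i8 -2, %x': countLeadingOnes(0xFE) - 1 == 6 caps the shift, so
//   the result lies in [-2 << 6, -2] == [-128, -2].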
BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS); BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS); @@ -2286,7 +2515,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL); + ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (!KnownNonNegative) break; // fall-through @@ -2296,7 +2526,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getFalse(ITy); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL); + ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (!KnownNonNegative) break; // fall-through @@ -2315,7 +2546,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL); + ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (!KnownNonNegative) break; // fall-through @@ -2325,7 +2557,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getTrue(ITy); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL); + ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (!KnownNonNegative) break; // fall-through @@ -2345,6 +2578,41 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getTrue(ITy); } + // handle: + // CI2 << X == CI + // CI2 << X != CI + // + // where CI2 is a power of 2 and CI isn't + if (auto *CI = dyn_cast<ConstantInt>(RHS)) { + const APInt *CI2Val, *CIVal = &CI->getValue(); + if (LBO && match(LBO, m_Shl(m_APInt(CI2Val), m_Value())) && + CI2Val->isPowerOf2()) { + if (!CIVal->isPowerOf2()) { + // CI2 << X can equal zero in some circumstances, + // this simplification is unsafe if CI is zero. + // + // We know it is safe if: + // - The shift is nsw, we can't shift out the one bit. + // - The shift is nuw, we can't shift out the one bit. + // - CI2 is one + // - CI isn't zero + if (LBO->hasNoSignedWrap() || LBO->hasNoUnsignedWrap() || + *CI2Val == 1 || !CI->isZero()) { + if (Pred == ICmpInst::ICMP_EQ) + return ConstantInt::getFalse(RHS->getContext()); + if (Pred == ICmpInst::ICMP_NE) + return ConstantInt::getTrue(RHS->getContext()); + } + } + if (CIVal->isSignBit() && *CI2Val == 1) { + if (Pred == ICmpInst::ICMP_UGT) + return ConstantInt::getFalse(RHS->getContext()); + if (Pred == ICmpInst::ICMP_ULE) + return ConstantInt::getTrue(RHS->getContext()); + } + } + } + if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() && LBO->getOperand(1) == RBO->getOperand(1)) { switch (LBO->getOpcode()) { @@ -2592,6 +2860,23 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } + // If a bit is known to be zero for A and known to be one for B, + // then A and B cannot be equal. + if (ICmpInst::isEquality(Pred)) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + uint32_t BitWidth = CI->getBitWidth(); + APInt LHSKnownZero(BitWidth, 0); + APInt LHSKnownOne(BitWidth, 0); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, Q.DL, /*Depth=*/0, Q.AT, + Q.CxtI, Q.DT); + const APInt &RHSVal = CI->getValue(); + if (((LHSKnownZero & RHSVal) != 0) || ((LHSKnownOne & ~RHSVal) != 0)) + return Pred == ICmpInst::ICMP_EQ + ? 
ConstantInt::getFalse(CI->getContext()) + : ConstantInt::getTrue(CI->getContext()); + } + } + // If the comparison is with the result of a select instruction, check whether // comparing with either branch of the select always yields the same value. if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS)) @@ -2610,8 +2895,10 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Value *llvm::SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyICmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT), + const DominatorTree *DT, + AssumptionTracker *AT, + Instruction *CxtI) { + return ::SimplifyICmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -2707,8 +2994,10 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFCmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT), + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyFCmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -2746,9 +3035,11 @@ static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, Value *llvm::SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifySelectInst(Cond, TrueVal, FalseVal, Query (DL, TLI, DT), - RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifySelectInst(Cond, TrueVal, FalseVal, + Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } /// SimplifyGEPInst - Given operands for an GetElementPtrInst, see if we can @@ -2756,29 +3047,72 @@ Value *llvm::SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, static Value *SimplifyGEPInst(ArrayRef<Value *> Ops, const Query &Q, unsigned) { // The type of the GEP pointer operand. PointerType *PtrTy = cast<PointerType>(Ops[0]->getType()->getScalarType()); + unsigned AS = PtrTy->getAddressSpace(); // getelementptr P -> P. if (Ops.size() == 1) return Ops[0]; - if (isa<UndefValue>(Ops[0])) { - // Compute the (pointer) type returned by the GEP instruction. - Type *LastType = GetElementPtrInst::getIndexedType(PtrTy, Ops.slice(1)); - Type *GEPTy = PointerType::get(LastType, PtrTy->getAddressSpace()); - if (VectorType *VT = dyn_cast<VectorType>(Ops[0]->getType())) - GEPTy = VectorType::get(GEPTy, VT->getNumElements()); + // Compute the (pointer) type returned by the GEP instruction. + Type *LastType = GetElementPtrInst::getIndexedType(PtrTy, Ops.slice(1)); + Type *GEPTy = PointerType::get(LastType, AS); + if (VectorType *VT = dyn_cast<VectorType>(Ops[0]->getType())) + GEPTy = VectorType::get(GEPTy, VT->getNumElements()); + + if (isa<UndefValue>(Ops[0])) return UndefValue::get(GEPTy); - } if (Ops.size() == 2) { // getelementptr P, 0 -> P. if (match(Ops[1], m_Zero())) return Ops[0]; - // getelementptr P, N -> P if P points to a type of zero size. - if (Q.DL) { - Type *Ty = PtrTy->getElementType(); - if (Ty->isSized() && Q.DL->getTypeAllocSize(Ty) == 0) + + Type *Ty = PtrTy->getElementType(); + if (Q.DL && Ty->isSized()) { + Value *P; + uint64_t C; + uint64_t TyAllocSize = Q.DL->getTypeAllocSize(Ty); + // getelementptr P, N -> P if P points to a type of zero size. 
+ if (TyAllocSize == 0) return Ops[0]; + + // The following transforms are only safe if the ptrtoint cast + // doesn't truncate the pointers. + if (Ops[1]->getType()->getScalarSizeInBits() == + Q.DL->getPointerSizeInBits(AS)) { + auto PtrToIntOrZero = [GEPTy](Value *P) -> Value * { + if (match(P, m_Zero())) + return Constant::getNullValue(GEPTy); + Value *Temp; + if (match(P, m_PtrToInt(m_Value(Temp)))) + if (Temp->getType() == GEPTy) + return Temp; + return nullptr; + }; + + // getelementptr V, (sub P, V) -> P if P points to a type of size 1. + if (TyAllocSize == 1 && + match(Ops[1], m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))))) + if (Value *R = PtrToIntOrZero(P)) + return R; + + // getelementptr V, (ashr (sub P, V), C) -> Q + // if P points to a type of size 1 << C. + if (match(Ops[1], + m_AShr(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))), + m_ConstantInt(C))) && + TyAllocSize == 1ULL << C) + if (Value *R = PtrToIntOrZero(P)) + return R; + + // getelementptr V, (sdiv (sub P, V), C) -> Q + // if P points to a type of size C. + if (match(Ops[1], + m_SDiv(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))), + m_SpecificInt(TyAllocSize)))) + if (Value *R = PtrToIntOrZero(P)) + return R; + } } } @@ -2792,8 +3126,9 @@ static Value *SimplifyGEPInst(ArrayRef<Value *> Ops, const Query &Q, unsigned) { Value *llvm::SimplifyGEPInst(ArrayRef<Value *> Ops, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyGEPInst(Ops, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyGEPInst(Ops, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } /// SimplifyInsertValueInst - Given operands for an InsertValueInst, see if we @@ -2829,8 +3164,11 @@ Value *llvm::SimplifyInsertValueInst(Value *Agg, Value *Val, ArrayRef<unsigned> Idxs, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyInsertValueInst(Agg, Val, Idxs, Query (DL, TLI, DT), + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyInsertValueInst(Agg, Val, Idxs, + Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -2877,8 +3215,11 @@ static Value *SimplifyTruncInst(Value *Op, Type *Ty, const Query &Q, unsigned) { Value *llvm::SimplifyTruncInst(Value *Op, Type *Ty, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyTruncInst(Op, Ty, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyTruncInst(Op, Ty, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } //=== Helper functions for higher up the class hierarchy. 
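The three getelementptr folds added above all recognize a pointer difference that has been scaled back down to an element count. In IR terms, a sketch of the simplest one (element size 1, 64-bit pointers assumed):

  %vi = ptrtoint i8* %v to i64
  %pi = ptrtoint i8* %p to i64
  %d  = sub i64 %pi, %vi
  %g  = getelementptr i8* %v, i64 %d   ; simplifies to %p

The ashr and sdiv variants accept the same shape once the byte difference has been divided by an element size of 1 << C or C respectively, and the ptrtoint width check guards against index types narrower than a pointer, where the subtraction could lose bits.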
@@ -2950,8 +3291,10 @@ static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyBinOp(Opcode, LHS, RHS, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyBinOp(Opcode, LHS, RHS, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyCmpInst - Given operands for a CmpInst, see if we can @@ -2965,8 +3308,9 @@ static Value *SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, Value *llvm::SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyCmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT), + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyCmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -3041,23 +3385,26 @@ static Value *SimplifyCall(Value *V, IterTy ArgBegin, IterTy ArgEnd, Value *llvm::SimplifyCall(Value *V, User::op_iterator ArgBegin, User::op_iterator ArgEnd, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyCall(V, ArgBegin, ArgEnd, Query(DL, TLI, DT), + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyCall(V, ArgBegin, ArgEnd, Query(DL, TLI, DT, AT, CxtI), RecursionLimit); } Value *llvm::SimplifyCall(Value *V, ArrayRef<Value *> Args, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyCall(V, Args.begin(), Args.end(), Query(DL, TLI, DT), - RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyCall(V, Args.begin(), Args.end(), + Query(DL, TLI, DT, AT, CxtI), RecursionLimit); } /// SimplifyInstruction - See if we can compute a simplified version of this /// instruction. If not, this returns null. 
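// A sketch of a typical caller of the widened entry point defined below; I is
// an instruction being visited, and DL/TLI/DT/AT are whatever (possibly null)
// analyses the calling pass has on hand:
if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AT)) {
  I->replaceAllUsesWith(V);
  I->eraseFromParent();
}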
Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { + const DominatorTree *DT, + AssumptionTracker *AT) { Value *Result; switch (I->getOpcode()) { @@ -3066,109 +3413,122 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout *DL, break; case Instruction::FAdd: Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT); + I->getFastMathFlags(), DL, TLI, DT, AT, I); break; case Instruction::Add: Result = SimplifyAddInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), cast<BinaryOperator>(I)->hasNoUnsignedWrap(), - DL, TLI, DT); + DL, TLI, DT, AT, I); break; case Instruction::FSub: Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT); + I->getFastMathFlags(), DL, TLI, DT, AT, I); break; case Instruction::Sub: Result = SimplifySubInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), cast<BinaryOperator>(I)->hasNoUnsignedWrap(), - DL, TLI, DT); + DL, TLI, DT, AT, I); break; case Instruction::FMul: Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT); + I->getFastMathFlags(), DL, TLI, DT, AT, I); break; case Instruction::Mul: - Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::SDiv: - Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::UDiv: - Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::FDiv: - Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::SRem: - Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::URem: - Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::FRem: - Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::Shl: Result = SimplifyShlInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), cast<BinaryOperator>(I)->hasNoUnsignedWrap(), - DL, TLI, DT); + DL, TLI, DT, AT, I); break; case Instruction::LShr: Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->isExact(), - DL, TLI, DT); + DL, TLI, DT, AT, I); break; case Instruction::AShr: Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->isExact(), - DL, TLI, DT); + DL, TLI, DT, AT, I); break; case Instruction::And: - Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::Or: - Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = 
SimplifyOrInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, + AT, I); break; case Instruction::Xor: - Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::ICmp: Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), - I->getOperand(0), I->getOperand(1), DL, TLI, DT); + I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::FCmp: Result = SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), - I->getOperand(0), I->getOperand(1), DL, TLI, DT); + I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::Select: Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1), - I->getOperand(2), DL, TLI, DT); + I->getOperand(2), DL, TLI, DT, AT, I); break; case Instruction::GetElementPtr: { SmallVector<Value*, 8> Ops(I->op_begin(), I->op_end()); - Result = SimplifyGEPInst(Ops, DL, TLI, DT); + Result = SimplifyGEPInst(Ops, DL, TLI, DT, AT, I); break; } case Instruction::InsertValue: { InsertValueInst *IV = cast<InsertValueInst>(I); Result = SimplifyInsertValueInst(IV->getAggregateOperand(), IV->getInsertedValueOperand(), - IV->getIndices(), DL, TLI, DT); + IV->getIndices(), DL, TLI, DT, AT, I); break; } case Instruction::PHI: - Result = SimplifyPHINode(cast<PHINode>(I), Query (DL, TLI, DT)); + Result = SimplifyPHINode(cast<PHINode>(I), Query (DL, TLI, DT, AT, I)); break; case Instruction::Call: { CallSite CS(cast<CallInst>(I)); Result = SimplifyCall(CS.getCalledValue(), CS.arg_begin(), CS.arg_end(), - DL, TLI, DT); + DL, TLI, DT, AT, I); break; } case Instruction::Trunc: - Result = SimplifyTruncInst(I->getOperand(0), I->getType(), DL, TLI, DT); + Result = SimplifyTruncInst(I->getOperand(0), I->getType(), DL, TLI, DT, + AT, I); break; } @@ -3192,7 +3552,8 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout *DL, static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { + const DominatorTree *DT, + AssumptionTracker *AT) { bool Simplified = false; SmallSetVector<Instruction *, 8> Worklist; @@ -3219,7 +3580,7 @@ static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, I = Worklist[Idx]; // See if this instruction simplifies. 
- SimpleV = SimplifyInstruction(I, DL, TLI, DT); + SimpleV = SimplifyInstruction(I, DL, TLI, DT, AT); if (!SimpleV) continue; @@ -3245,15 +3606,17 @@ static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, bool llvm::recursivelySimplifyInstruction(Instruction *I, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return replaceAndRecursivelySimplifyImpl(I, nullptr, DL, TLI, DT); + const DominatorTree *DT, + AssumptionTracker *AT) { + return replaceAndRecursivelySimplifyImpl(I, nullptr, DL, TLI, DT, AT); } bool llvm::replaceAndRecursivelySimplify(Instruction *I, Value *SimpleV, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { + const DominatorTree *DT, + AssumptionTracker *AT) { assert(I != SimpleV && "replaceAndRecursivelySimplify(X,X) is not valid!"); assert(SimpleV && "Must provide a simplified value."); - return replaceAndRecursivelySimplifyImpl(I, SimpleV, DL, TLI, DT); + return replaceAndRecursivelySimplifyImpl(I, SimpleV, DL, TLI, DT, AT); } diff --git a/lib/Analysis/JumpInstrTableInfo.cpp b/lib/Analysis/JumpInstrTableInfo.cpp index b5b4265..7aae2a5 100644 --- a/lib/Analysis/JumpInstrTableInfo.cpp +++ b/lib/Analysis/JumpInstrTableInfo.cpp @@ -17,6 +17,7 @@ #include "llvm/Analysis/Passes.h" #include "llvm/IR/Function.h" #include "llvm/IR/Type.h" +#include "llvm/Support/MathExtras.h" using namespace llvm; @@ -28,7 +29,21 @@ ImmutablePass *llvm::createJumpInstrTableInfoPass() { return new JumpInstrTableInfo(); } -JumpInstrTableInfo::JumpInstrTableInfo() : ImmutablePass(ID), Tables() { +ModulePass *llvm::createJumpInstrTableInfoPass(unsigned Bound) { + // This cast is always safe, since Bound is always in a subset of uint64_t. + uint64_t B = static_cast<uint64_t>(Bound); + return new JumpInstrTableInfo(B); +} + +JumpInstrTableInfo::JumpInstrTableInfo(uint64_t ByteAlign) + : ImmutablePass(ID), Tables(), ByteAlignment(ByteAlign) { + if (!llvm::isPowerOf2_64(ByteAlign)) { + // Note that we don't explicitly handle overflow here, since we handle the 0 + // case explicitly when a caller actually tries to create jumptable entries, + // and this is the return value on overflow. 
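For intuition about the rounding on the next line: llvm::NextPowerOf2 returns the power of two strictly greater than its argument, and wraps to zero when the top bit of the argument is set, which is the overflow case the comment above refers to. A hypothetical statement of the resulting invariant:

// Post-condition of the constructor (hypothetical check, not in the patch):
// ByteAlignment is 0 (overflow) or a power of two. For example, 4 stays 4
// because it is already a power of two; 5 and 6 round up to 8; 12 becomes 16.
assert(ByteAlignment == 0 || llvm::isPowerOf2_64(ByteAlignment));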
+ ByteAlignment = llvm::NextPowerOf2(ByteAlign); + } + initializeJumpInstrTableInfoPass(*PassRegistry::getPassRegistry()); } diff --git a/lib/Analysis/LazyCallGraph.cpp b/lib/Analysis/LazyCallGraph.cpp index e073616..767da4e 100644 --- a/lib/Analysis/LazyCallGraph.cpp +++ b/lib/Analysis/LazyCallGraph.cpp @@ -48,7 +48,7 @@ static void findCallees( } for (Value *Op : C->operand_values()) - if (Visited.insert(cast<Constant>(Op))) + if (Visited.insert(cast<Constant>(Op)).second) Worklist.push_back(cast<Constant>(Op)); } } @@ -66,7 +66,7 @@ LazyCallGraph::Node::Node(LazyCallGraph &G, Function &F) for (Instruction &I : BB) for (Value *Op : I.operand_values()) if (Constant *C = dyn_cast<Constant>(Op)) - if (Visited.insert(C)) + if (Visited.insert(C).second) Worklist.push_back(C); // We've collected all the constant (and thus potentially function or @@ -113,7 +113,7 @@ LazyCallGraph::LazyCallGraph(Module &M) : NextDFSNumber(0) { SmallPtrSet<Constant *, 16> Visited; for (GlobalVariable &GV : M.globals()) if (GV.hasInitializer()) - if (Visited.insert(GV.getInitializer())) + if (Visited.insert(GV.getInitializer()).second) Worklist.push_back(GV.getInitializer()); DEBUG(dbgs() << " Adding functions referenced by global initializers to the " @@ -688,7 +688,7 @@ static void printNodes(raw_ostream &OS, LazyCallGraph::Node &N, SmallPtrSetImpl<LazyCallGraph::Node *> &Printed) { // Recurse depth first through the nodes. for (LazyCallGraph::Node &ChildN : N) - if (Printed.insert(&ChildN)) + if (Printed.insert(&ChildN).second) printNodes(OS, ChildN, Printed); OS << " Call edges in function: " << N.getFunction().getName() << "\n"; @@ -717,7 +717,7 @@ PreservedAnalyses LazyCallGraphPrinterPass::run(Module *M, SmallPtrSet<LazyCallGraph::Node *, 16> Printed; for (LazyCallGraph::Node &N : G) - if (Printed.insert(&N)) + if (Printed.insert(&N).second) printNodes(OS, N, Printed); for (LazyCallGraph::SCC &SCC : G.postorder_sccs()) diff --git a/lib/Analysis/LazyValueInfo.cpp b/lib/Analysis/LazyValueInfo.cpp index 9f919f7..c712c9f 100644 --- a/lib/Analysis/LazyValueInfo.cpp +++ b/lib/Analysis/LazyValueInfo.cpp @@ -15,12 +15,14 @@ #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" @@ -38,6 +40,7 @@ using namespace PatternMatch; char LazyValueInfo::ID = 0; INITIALIZE_PASS_BEGIN(LazyValueInfo, "lazy-value-info", "Lazy Value Information Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_END(LazyValueInfo, "lazy-value-info", "Lazy Value Information Analysis", false, true) @@ -338,6 +341,13 @@ namespace { /// during a query. It basically emulates the callstack of the naive /// recursive value lookup process. std::stack<std::pair<BasicBlock*, Value*> > BlockValueStack; + + /// A pointer to the cache of @llvm.assume calls. + AssumptionTracker *AT; + /// An optional DL pointer. + const DataLayout *DL; + /// An optional DT pointer. 
+ DominatorTree *DT; friend struct LVIValueHandle; @@ -364,7 +374,8 @@ namespace { LVILatticeVal getBlockValue(Value *Val, BasicBlock *BB); bool getEdgeValue(Value *V, BasicBlock *F, BasicBlock *T, - LVILatticeVal &Result); + LVILatticeVal &Result, + Instruction *CxtI = nullptr); bool hasBlockValue(Value *Val, BasicBlock *BB); // These methods process one work item and may add more. A false value @@ -377,6 +388,8 @@ namespace { PHINode *PN, BasicBlock *BB); bool solveBlockValueConstantRange(LVILatticeVal &BBLV, Instruction *BBI, BasicBlock *BB); + void mergeAssumeBlockValueConstantRange(Value *Val, LVILatticeVal &BBLV, + Instruction *BBI); void solve(); @@ -387,11 +400,18 @@ namespace { public: /// getValueInBlock - This is the query interface to determine the lattice /// value for the specified Value* at the end of the specified block. - LVILatticeVal getValueInBlock(Value *V, BasicBlock *BB); + LVILatticeVal getValueInBlock(Value *V, BasicBlock *BB, + Instruction *CxtI = nullptr); + + /// getValueAt - This is the query interface to determine the lattice + /// value for the specified Value* at the specified instruction (generally + /// from an assume intrinsic). + LVILatticeVal getValueAt(Value *V, Instruction *CxtI); /// getValueOnEdge - This is the query interface to determine the lattice /// value for the specified Value* that is true on the specified edge. - LVILatticeVal getValueOnEdge(Value *V, BasicBlock *FromBB,BasicBlock *ToBB); + LVILatticeVal getValueOnEdge(Value *V, BasicBlock *FromBB,BasicBlock *ToBB, + Instruction *CxtI = nullptr); /// threadEdge - This is the update interface to inform the cache that an /// edge from PredBB to OldSucc has been threaded to be from PredBB to @@ -408,6 +428,10 @@ namespace { ValueCache.clear(); OverDefinedCache.clear(); } + + LazyValueInfoCache(AssumptionTracker *AT, + const DataLayout *DL = nullptr, + DominatorTree *DT = nullptr) : AT(AT), DL(DL), DT(DT) {} }; } // end anonymous namespace @@ -500,7 +524,6 @@ bool LazyValueInfoCache::solveBlockValue(Value *Val, BasicBlock *BB) { // cache needs updating, i.e. whether we have solved a new value or not. OverDefinedCacheUpdater ODCacheUpdater(Val, BB, BBLV, this); - // If we've already computed this block's value, return it. if (!BBLV.isUndefined()) { DEBUG(dbgs() << " reuse BB '" << BB->getName() << "' val=" << BBLV <<'\n'); @@ -669,7 +692,10 @@ bool LazyValueInfoCache::solveBlockValuePHINode(LVILatticeVal &BBLV, BasicBlock *PhiBB = PN->getIncomingBlock(i); Value *PhiVal = PN->getIncomingValue(i); LVILatticeVal EdgeResult; - EdgesMissing |= !getEdgeValue(PhiVal, PhiBB, BB, EdgeResult); + // Note that we can provide PN as the context value to getEdgeValue, even + // though the results will be cached, because PN is the value being used as + // the cache key in the caller. + EdgesMissing |= !getEdgeValue(PhiVal, PhiBB, BB, EdgeResult, PN); if (EdgesMissing) continue; @@ -694,6 +720,36 @@ bool LazyValueInfoCache::solveBlockValuePHINode(LVILatticeVal &BBLV, return true; } +static bool getValueFromFromCondition(Value *Val, ICmpInst *ICI, + LVILatticeVal &Result, + bool isTrueDest = true); + +// If we can determine a constant range for the value Val at the context +// provided by the instruction BBI, merge it into BBLV. +void LazyValueInfoCache::mergeAssumeBlockValueConstantRange( + Value *Val, LVILatticeVal &BBLV, Instruction *BBI) { + BBI = BBI ? 
BBI : dyn_cast<Instruction>(Val); + if (!BBI) + return; + + for (auto &I : AT->assumptions(BBI->getParent()->getParent())) { + if (!isValidAssumeForContext(I, BBI, DL, DT)) + continue; + + Value *C = I->getArgOperand(0); + if (ICmpInst *ICI = dyn_cast<ICmpInst>(C)) { + LVILatticeVal Result; + if (getValueFromFromCondition(Val, ICI, Result)) { + if (BBLV.isOverdefined()) + BBLV = Result; + else + BBLV.mergeIn(Result); + } + } + } +} + bool LazyValueInfoCache::solveBlockValueConstantRange(LVILatticeVal &BBLV, Instruction *BBI, BasicBlock *BB) { @@ -704,6 +760,7 @@ bool LazyValueInfoCache::solveBlockValueConstantRange(LVILatticeVal &BBLV, } LVILatticeVal LHSVal = getBlockValue(BBI->getOperand(0), BB); + mergeAssumeBlockValueConstantRange(BBI->getOperand(0), LHSVal, BBI); if (!LHSVal.isConstantRange()) { BBLV.markOverdefined(); return true; @@ -775,6 +832,47 @@ bool LazyValueInfoCache::solveBlockValueConstantRange(LVILatticeVal &BBLV, return true; } +bool getValueFromFromCondition(Value *Val, ICmpInst *ICI, + LVILatticeVal &Result, bool isTrueDest) { + if (ICI && isa<Constant>(ICI->getOperand(1))) { + if (ICI->isEquality() && ICI->getOperand(0) == Val) { + // We know that V has the RHS constant if this is a true SETEQ or + // false SETNE. + if (isTrueDest == (ICI->getPredicate() == ICmpInst::ICMP_EQ)) + Result = LVILatticeVal::get(cast<Constant>(ICI->getOperand(1))); + else + Result = LVILatticeVal::getNot(cast<Constant>(ICI->getOperand(1))); + return true; + } + + // Recognize the range checking idiom that InstCombine produces. + // (X-C1) u< C2 --> [C1, C1+C2) + ConstantInt *NegOffset = nullptr; + if (ICI->getPredicate() == ICmpInst::ICMP_ULT) + match(ICI->getOperand(0), m_Add(m_Specific(Val), + m_ConstantInt(NegOffset))); + + ConstantInt *CI = dyn_cast<ConstantInt>(ICI->getOperand(1)); + if (CI && (ICI->getOperand(0) == Val || NegOffset)) { + // Calculate the range of values that would satisfy the comparison. + ConstantRange CmpRange(CI->getValue()); + ConstantRange TrueValues = + ConstantRange::makeICmpRegion(ICI->getPredicate(), CmpRange); + + if (NegOffset) // Apply the offset from above. + TrueValues = TrueValues.subtract(NegOffset->getValue()); + + // If we're interested in the false dest, invert the condition. + if (!isTrueDest) TrueValues = TrueValues.inverse(); + + Result = LVILatticeVal::getRange(TrueValues); + return true; + } + } + + return false; +} + /// \brief Compute the value of Val on the edge BBFrom -> BBTo. Returns false if /// Val is not constrained on the edge. static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, @@ -801,41 +899,8 @@ static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, // If the condition of the branch is an equality comparison, we may be // able to infer the value. ICmpInst *ICI = dyn_cast<ICmpInst>(BI->getCondition()); - if (ICI && isa<Constant>(ICI->getOperand(1))) { - if (ICI->isEquality() && ICI->getOperand(0) == Val) { - // We know that V has the RHS constant if this is a true SETEQ or - // false SETNE. - if (isTrueDest == (ICI->getPredicate() == ICmpInst::ICMP_EQ)) - Result = LVILatticeVal::get(cast<Constant>(ICI->getOperand(1))); - else - Result = LVILatticeVal::getNot(cast<Constant>(ICI->getOperand(1))); - return true; - } - - // Recognize the range checking idiom that InstCombine produces. 
- // (X-C1) u< C2 --> [C1, C1+C2) - ConstantInt *NegOffset = nullptr; - if (ICI->getPredicate() == ICmpInst::ICMP_ULT) - match(ICI->getOperand(0), m_Add(m_Specific(Val), - m_ConstantInt(NegOffset))); - - ConstantInt *CI = dyn_cast<ConstantInt>(ICI->getOperand(1)); - if (CI && (ICI->getOperand(0) == Val || NegOffset)) { - // Calculate the range of values that would satisfy the comparison. - ConstantRange CmpRange(CI->getValue()); - ConstantRange TrueValues = - ConstantRange::makeICmpRegion(ICI->getPredicate(), CmpRange); - - if (NegOffset) // Apply the offset from above. - TrueValues = TrueValues.subtract(NegOffset->getValue()); - - // If we're interested in the false dest, invert the condition. - if (!isTrueDest) TrueValues = TrueValues.inverse(); - - Result = LVILatticeVal::getRange(TrueValues); - return true; - } - } + if (getValueFromFromCondition(Val, ICI, Result, isTrueDest)) + return true; } } @@ -869,7 +934,8 @@ static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, /// \brief Compute the value of Val on the edge BBFrom -> BBTo, or the value at /// the basic block if the edge does not constrain Val. bool LazyValueInfoCache::getEdgeValue(Value *Val, BasicBlock *BBFrom, - BasicBlock *BBTo, LVILatticeVal &Result) { + BasicBlock *BBTo, LVILatticeVal &Result, + Instruction *CxtI) { // If already a constant, there is nothing to compute. if (Constant *VC = dyn_cast<Constant>(Val)) { Result = LVILatticeVal::get(VC); @@ -891,6 +957,10 @@ bool LazyValueInfoCache::getEdgeValue(Value *Val, BasicBlock *BBFrom, // Try to intersect ranges of the BB and the constraint on the edge. LVILatticeVal InBlock = getBlockValue(Val, BBFrom); + mergeAssumeBlockValueConstantRange(Val, InBlock, BBFrom->getTerminator()); + // See note on the use of the CxtI with mergeAssumeBlockValueConstantRange, + // and caching, below. + mergeAssumeBlockValueConstantRange(Val, InBlock, CxtI); if (!InBlock.isConstantRange()) return true; @@ -907,30 +977,54 @@ bool LazyValueInfoCache::getEdgeValue(Value *Val, BasicBlock *BBFrom, // if we couldn't compute the value on the edge, use the value from the BB Result = getBlockValue(Val, BBFrom); + mergeAssumeBlockValueConstantRange(Val, Result, BBFrom->getTerminator()); + // We can use the context instruction (generically the ultimate instruction + // the calling pass is trying to simplify) here, even though the result of + // this function is generally cached when called from the solve* functions + // (and that cached result might be used with queries using a different + // context instruction), because when this function is called from the solve* + // functions, the context instruction is not provided. When called from + // LazyValueInfoCache::getValueOnEdge, the context instruction is provided, + // but then the result is not cached. 
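To make the range-check idiom concrete (the logic now lives in getValueFromFromCondition, shown earlier): for (X - 5) u< 10 the satisfying values of X are exactly [5, 15). A worked sketch using the same ConstantRange calls as that function; the bit width and constants are illustrative:

// X - 5 is matched as X + NegOffset with NegOffset == -5 (as an i32).
APInt NegOffset = APInt(32, 0) - APInt(32, 5);
ConstantRange CmpRange(APInt(32, 10)); // the single value 10
// Values V satisfying "V u< 10" form the half-open range [0, 10).
ConstantRange TrueValues =
    ConstantRange::makeICmpRegion(ICmpInst::ICMP_ULT, CmpRange);
// Translate from X + NegOffset back to X: [0, 10) - (-5) == [5, 15).
TrueValues = TrueValues.subtract(NegOffset);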
+ mergeAssumeBlockValueConstantRange(Val, Result, CxtI); return true; } -LVILatticeVal LazyValueInfoCache::getValueInBlock(Value *V, BasicBlock *BB) { +LVILatticeVal LazyValueInfoCache::getValueInBlock(Value *V, BasicBlock *BB, + Instruction *CxtI) { DEBUG(dbgs() << "LVI Getting block end value " << *V << " at '" << BB->getName() << "'\n"); BlockValueStack.push(std::make_pair(BB, V)); solve(); LVILatticeVal Result = getBlockValue(V, BB); + mergeAssumeBlockValueConstantRange(V, Result, CxtI); + + DEBUG(dbgs() << " Result = " << Result << "\n"); + return Result; +} + +LVILatticeVal LazyValueInfoCache::getValueAt(Value *V, Instruction *CxtI) { + DEBUG(dbgs() << "LVI Getting value " << *V << " at '" + << CxtI->getName() << "'\n"); + + LVILatticeVal Result; + mergeAssumeBlockValueConstantRange(V, Result, CxtI); DEBUG(dbgs() << " Result = " << Result << "\n"); return Result; } LVILatticeVal LazyValueInfoCache:: -getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB) { +getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB, + Instruction *CxtI) { DEBUG(dbgs() << "LVI Getting edge value " << *V << " from '" << FromBB->getName() << "' to '" << ToBB->getName() << "'\n"); LVILatticeVal Result; - if (!getEdgeValue(V, FromBB, ToBB, Result)) { + if (!getEdgeValue(V, FromBB, ToBB, Result, CxtI)) { solve(); - bool WasFastQuery = getEdgeValue(V, FromBB, ToBB, Result); + bool WasFastQuery = getEdgeValue(V, FromBB, ToBB, Result, CxtI); (void)WasFastQuery; assert(WasFastQuery && "More work to do after problem solved?"); } @@ -1004,39 +1098,51 @@ void LazyValueInfoCache::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc, //===----------------------------------------------------------------------===// /// getCache - This lazily constructs the LazyValueInfoCache. -static LazyValueInfoCache &getCache(void *&PImpl) { +static LazyValueInfoCache &getCache(void *&PImpl, + AssumptionTracker *AT, + const DataLayout *DL = nullptr, + DominatorTree *DT = nullptr) { if (!PImpl) - PImpl = new LazyValueInfoCache(); + PImpl = new LazyValueInfoCache(AT, DL, DT); return *static_cast<LazyValueInfoCache*>(PImpl); } bool LazyValueInfo::runOnFunction(Function &F) { - if (PImpl) - getCache(PImpl).clear(); + AT = &getAnalysis<AssumptionTracker>(); + + DominatorTreeWrapperPass *DTWP = + getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + DT = DTWP ? &DTWP->getDomTree() : nullptr; DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; TLI = &getAnalysis<TargetLibraryInfo>(); + if (PImpl) + getCache(PImpl, AT, DL, DT).clear(); + // Fully lazy. return false; } void LazyValueInfo::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); + AU.addRequired<AssumptionTracker>(); AU.addRequired<TargetLibraryInfo>(); } void LazyValueInfo::releaseMemory() { // If the cache was allocated, free it. if (PImpl) { - delete &getCache(PImpl); + delete &getCache(PImpl, AT); PImpl = nullptr; } } -Constant *LazyValueInfo::getConstant(Value *V, BasicBlock *BB) { - LVILatticeVal Result = getCache(PImpl).getValueInBlock(V, BB); +Constant *LazyValueInfo::getConstant(Value *V, BasicBlock *BB, + Instruction *CxtI) { + LVILatticeVal Result = + getCache(PImpl, AT, DL, DT).getValueInBlock(V, BB, CxtI); if (Result.isConstant()) return Result.getConstant(); @@ -1051,8 +1157,10 @@ Constant *LazyValueInfo::getConstant(Value *V, BasicBlock *BB) { /// getConstantOnEdge - Determine whether the specified value is known to be a /// constant on the specified edge. Return null if not. 
Constant *LazyValueInfo::getConstantOnEdge(Value *V, BasicBlock *FromBB, - BasicBlock *ToBB) { - LVILatticeVal Result = getCache(PImpl).getValueOnEdge(V, FromBB, ToBB); + BasicBlock *ToBB, + Instruction *CxtI) { + LVILatticeVal Result = + getCache(PImpl, AT, DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI); if (Result.isConstant()) return Result.getConstant(); @@ -1064,51 +1172,47 @@ Constant *LazyValueInfo::getConstantOnEdge(Value *V, BasicBlock *FromBB, return nullptr; } -/// getPredicateOnEdge - Determine whether the specified value comparison -/// with a constant is known to be true or false on the specified CFG edge. -/// Pred is a CmpInst predicate. -LazyValueInfo::Tristate -LazyValueInfo::getPredicateOnEdge(unsigned Pred, Value *V, Constant *C, - BasicBlock *FromBB, BasicBlock *ToBB) { - LVILatticeVal Result = getCache(PImpl).getValueOnEdge(V, FromBB, ToBB); - +static LazyValueInfo::Tristate +getPredicateResult(unsigned Pred, Constant *C, LVILatticeVal &Result, + const DataLayout *DL, TargetLibraryInfo *TLI) { + // If we know the value is a constant, evaluate the conditional. Constant *Res = nullptr; if (Result.isConstant()) { Res = ConstantFoldCompareInstOperands(Pred, Result.getConstant(), C, DL, TLI); if (ConstantInt *ResCI = dyn_cast<ConstantInt>(Res)) - return ResCI->isZero() ? False : True; - return Unknown; + return ResCI->isZero() ? LazyValueInfo::False : LazyValueInfo::True; + return LazyValueInfo::Unknown; } if (Result.isConstantRange()) { ConstantInt *CI = dyn_cast<ConstantInt>(C); - if (!CI) return Unknown; + if (!CI) return LazyValueInfo::Unknown; ConstantRange CR = Result.getConstantRange(); if (Pred == ICmpInst::ICMP_EQ) { if (!CR.contains(CI->getValue())) - return False; + return LazyValueInfo::False; if (CR.isSingleElement() && CR.contains(CI->getValue())) - return True; + return LazyValueInfo::True; } else if (Pred == ICmpInst::ICMP_NE) { if (!CR.contains(CI->getValue())) - return True; + return LazyValueInfo::True; if (CR.isSingleElement() && CR.contains(CI->getValue())) - return False; + return LazyValueInfo::False; } // Handle more complex predicates. ConstantRange TrueValues = ICmpInst::makeConstantRange((ICmpInst::Predicate)Pred, CI->getValue()); if (TrueValues.contains(CR)) - return True; + return LazyValueInfo::True; if (TrueValues.inverse().contains(CR)) - return False; - return Unknown; + return LazyValueInfo::False; + return LazyValueInfo::Unknown; } if (Result.isNotConstant()) { @@ -1120,26 +1224,48 @@ LazyValueInfo::getPredicateOnEdge(unsigned Pred, Value *V, Constant *C, Result.getNotConstant(), C, DL, TLI); if (Res->isNullValue()) - return False; + return LazyValueInfo::False; } else if (Pred == ICmpInst::ICMP_NE) { // !C1 != C -> true iff C1 == C. Res = ConstantFoldCompareInstOperands(ICmpInst::ICMP_NE, Result.getNotConstant(), C, DL, TLI); if (Res->isNullValue()) - return True; + return LazyValueInfo::True; } - return Unknown; + return LazyValueInfo::Unknown; } - return Unknown; + return LazyValueInfo::Unknown; +} + +/// getPredicateOnEdge - Determine whether the specified value comparison +/// with a constant is known to be true or false on the specified CFG edge. +/// Pred is a CmpInst predicate. 
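The refactoring above pulls the tristate evaluation out into getPredicateResult so that the edge-based query and the new instruction-based query can share it. A sketch of how a client pass might use the instruction-based form (the helper is hypothetical; getPredicateAt is the entry point defined just below):

// Fold a compare against a constant using only facts valid at the compare
// itself, e.g. those established by a dominating @llvm.assume.
bool tryFoldCmpAt(ICmpInst *Cmp, LazyValueInfo &LVI) {
  Constant *C = dyn_cast<Constant>(Cmp->getOperand(1));
  if (!C)
    return false;
  LazyValueInfo::Tristate T =
      LVI.getPredicateAt(Cmp->getPredicate(), Cmp->getOperand(0), C, Cmp);
  if (T == LazyValueInfo::Unknown)
    return false;
  Cmp->replaceAllUsesWith(
      ConstantInt::get(Cmp->getType(), T == LazyValueInfo::True));
  return true;
}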
+LazyValueInfo::Tristate +LazyValueInfo::getPredicateOnEdge(unsigned Pred, Value *V, Constant *C, + BasicBlock *FromBB, BasicBlock *ToBB, + Instruction *CxtI) { + LVILatticeVal Result = + getCache(PImpl, AT, DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI); + + return getPredicateResult(Pred, C, Result, DL, TLI); +} + +LazyValueInfo::Tristate +LazyValueInfo::getPredicateAt(unsigned Pred, Value *V, Constant *C, + Instruction *CxtI) { + LVILatticeVal Result = + getCache(PImpl, AT, DL, DT).getValueAt(V, CxtI); + + return getPredicateResult(Pred, C, Result, DL, TLI); } void LazyValueInfo::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc, BasicBlock *NewSucc) { - if (PImpl) getCache(PImpl).threadEdge(PredBB, OldSucc, NewSucc); + if (PImpl) getCache(PImpl, AT, DL, DT).threadEdge(PredBB, OldSucc, NewSucc); } void LazyValueInfo::eraseBlock(BasicBlock *BB) { - if (PImpl) getCache(PImpl).eraseBlock(BB); + if (PImpl) getCache(PImpl, AT, DL, DT).eraseBlock(BB); } diff --git a/lib/Analysis/LibCallSemantics.cpp b/lib/Analysis/LibCallSemantics.cpp index 7d4e254..23639e7 100644 --- a/lib/Analysis/LibCallSemantics.cpp +++ b/lib/Analysis/LibCallSemantics.cpp @@ -18,7 +18,7 @@ #include "llvm/IR/Function.h" using namespace llvm; -/// getMap - This impl pointer in ~LibCallInfo is actually a StringMap. This +/// This impl pointer in ~LibCallInfo is actually a StringMap. This /// helper does the cast. static StringMap<const LibCallFunctionInfo*> *getMap(void *Ptr) { return static_cast<StringMap<const LibCallFunctionInfo*> *>(Ptr); @@ -38,7 +38,7 @@ const LibCallLocationInfo &LibCallInfo::getLocationInfo(unsigned LocID) const { } -/// getFunctionInfo - Return the LibCallFunctionInfo object corresponding to +/// Return the LibCallFunctionInfo object corresponding to /// the specified function if we have it. If not, return null. 
const LibCallFunctionInfo * LibCallInfo::getFunctionInfo(const Function *F) const { diff --git a/lib/Analysis/Lint.cpp b/lib/Analysis/Lint.cpp index b14f329..8ee9b8a 100644 --- a/lib/Analysis/Lint.cpp +++ b/lib/Analysis/Lint.cpp @@ -37,6 +37,7 @@ #include "llvm/Analysis/Lint.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" @@ -96,11 +97,12 @@ namespace { Value *findValue(Value *V, bool OffsetOk) const; Value *findValueImpl(Value *V, bool OffsetOk, - SmallPtrSet<Value *, 4> &Visited) const; + SmallPtrSetImpl<Value *> &Visited) const; public: Module *Mod; AliasAnalysis *AA; + AssumptionTracker *AT; DominatorTree *DT; const DataLayout *DL; TargetLibraryInfo *TLI; @@ -118,6 +120,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); AU.addRequired<AliasAnalysis>(); + AU.addRequired<AssumptionTracker>(); AU.addRequired<TargetLibraryInfo>(); AU.addRequired<DominatorTreeWrapperPass>(); } @@ -151,6 +154,7 @@ namespace { char Lint::ID = 0; INITIALIZE_PASS_BEGIN(Lint, "lint", "Statically lint-checks LLVM IR", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_AG_DEPENDENCY(AliasAnalysis) @@ -175,6 +179,7 @@ INITIALIZE_PASS_END(Lint, "lint", "Statically lint-checks LLVM IR", bool Lint::runOnFunction(Function &F) { Mod = F.getParent(); AA = &getAnalysis<AliasAnalysis>(); + AT = &getAnalysis<AssumptionTracker>(); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; @@ -504,7 +509,8 @@ void Lint::visitShl(BinaryOperator &I) { "Undefined result: Shift count out of range", &I); } -static bool isZero(Value *V, const DataLayout *DL) { +static bool isZero(Value *V, const DataLayout *DL, DominatorTree *DT, + AssumptionTracker *AT) { // Assume undef could be zero. if (isa<UndefValue>(V)) return true; @@ -513,7 +519,8 @@ static bool isZero(Value *V, const DataLayout *DL) { if (!VecTy) { unsigned BitWidth = V->getType()->getIntegerBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, DL); + computeKnownBits(V, KnownZero, KnownOne, DL, + 0, AT, dyn_cast<Instruction>(V), DT); return KnownZero.isAllOnesValue(); } @@ -543,22 +550,22 @@ static bool isZero(Value *V, const DataLayout *DL) { } void Lint::visitSDiv(BinaryOperator &I) { - Assert1(!isZero(I.getOperand(1), DL), + Assert1(!isZero(I.getOperand(1), DL, DT, AT), "Undefined behavior: Division by zero", &I); } void Lint::visitUDiv(BinaryOperator &I) { - Assert1(!isZero(I.getOperand(1), DL), + Assert1(!isZero(I.getOperand(1), DL, DT, AT), "Undefined behavior: Division by zero", &I); } void Lint::visitSRem(BinaryOperator &I) { - Assert1(!isZero(I.getOperand(1), DL), + Assert1(!isZero(I.getOperand(1), DL, DT, AT), "Undefined behavior: Division by zero", &I); } void Lint::visitURem(BinaryOperator &I) { - Assert1(!isZero(I.getOperand(1), DL), + Assert1(!isZero(I.getOperand(1), DL, DT, AT), "Undefined behavior: Division by zero", &I); } @@ -622,9 +629,9 @@ Value *Lint::findValue(Value *V, bool OffsetOk) const { /// findValueImpl - Implementation helper for findValue. 
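The Lint hunks above let the division-by-zero checks consult assumptions and the dominator tree through computeKnownBits. For the scalar case the test reduces to asking whether every bit of the divisor is known to be zero; a condensed sketch with the same call shape as the patch (the wrapper name is illustrative):

// A scalar value is provably zero iff all of its bits are known zero, i.e.
// the KnownZero mask is all ones.
bool knownToBeZero(Value *V, const DataLayout *DL, DominatorTree *DT,
                   AssumptionTracker *AT) {
  unsigned BitWidth = V->getType()->getIntegerBitWidth();
  APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
  computeKnownBits(V, KnownZero, KnownOne, DL, /*Depth=*/0, AT,
                   dyn_cast<Instruction>(V), DT);
  return KnownZero.isAllOnesValue();
}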
Value *Lint::findValueImpl(Value *V, bool OffsetOk, - SmallPtrSet<Value *, 4> &Visited) const { + SmallPtrSetImpl<Value *> &Visited) const { // Detect self-referential values. - if (!Visited.insert(V)) + if (!Visited.insert(V).second) return UndefValue::get(V->getType()); // TODO: Look through sext or zext cast, when the result is known to @@ -638,7 +645,8 @@ Value *Lint::findValueImpl(Value *V, bool OffsetOk, BasicBlock *BB = L->getParent(); SmallPtrSet<BasicBlock *, 4> VisitedBlocks; for (;;) { - if (!VisitedBlocks.insert(BB)) break; + if (!VisitedBlocks.insert(BB).second) + break; if (Value *U = FindAvailableLoadedValue(L->getPointerOperand(), BB, BBI, 6, AA)) return findValueImpl(U, OffsetOk, Visited); @@ -678,7 +686,7 @@ Value *Lint::findValueImpl(Value *V, bool OffsetOk, // As a last resort, try SimplifyInstruction or constant folding. if (Instruction *Inst = dyn_cast<Instruction>(V)) { - if (Value *W = SimplifyInstruction(Inst, DL, TLI, DT)) + if (Value *W = SimplifyInstruction(Inst, DL, TLI, DT, AT)) return findValueImpl(W, OffsetOk, Visited); } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { if (Value *W = ConstantFoldConstantExpression(CE, DL, TLI)) diff --git a/lib/Analysis/Loads.cpp b/lib/Analysis/Loads.cpp index 005d309..bb0d60e 100644 --- a/lib/Analysis/Loads.cpp +++ b/lib/Analysis/Loads.cpp @@ -22,25 +22,29 @@ #include "llvm/IR/Operator.h" using namespace llvm; -/// AreEquivalentAddressValues - Test if A and B will obviously have the same -/// value. This includes recognizing that %t0 and %t1 will have the same +/// \brief Test if A and B will obviously have the same value. +/// +/// This includes recognizing that %t0 and %t1 will have the same /// value in code like this: +/// \code /// %t0 = getelementptr \@a, 0, 3 /// store i32 0, i32* %t0 /// %t1 = getelementptr \@a, 0, 3 /// %t2 = load i32* %t1 +/// \endcode /// static bool AreEquivalentAddressValues(const Value *A, const Value *B) { // Test if the values are trivially equivalent. - if (A == B) return true; + if (A == B) + return true; // Test if the values come from identical arithmetic instructions. // Use isIdenticalToWhenDefined instead of isIdenticalTo because // this function is only used when one address use dominates the // other, which means that they'll always either have the same // value or one of them will have an undefined value. - if (isa<BinaryOperator>(A) || isa<CastInst>(A) || - isa<PHINode>(A) || isa<GetElementPtrInst>(A)) + if (isa<BinaryOperator>(A) || isa<CastInst>(A) || isa<PHINode>(A) || + isa<GetElementPtrInst>(A)) if (const Instruction *BI = dyn_cast<Instruction>(B)) if (cast<Instruction>(A)->isIdenticalToWhenDefined(BI)) return true; @@ -49,15 +53,19 @@ static bool AreEquivalentAddressValues(const Value *A, const Value *B) { return false; } -/// isSafeToLoadUnconditionally - Return true if we know that executing a load -/// from this value cannot trap. If it is not obviously safe to load from the -/// specified pointer, we do a quick local scan of the basic block containing -/// ScanFrom, to determine if the address is already accessed. +/// \brief Check if executing a load of this pointer value cannot trap. +/// +/// If it is not obviously safe to load from the specified pointer, we do +/// a quick local scan of the basic block containing \c ScanFrom, to determine +/// if the address is already accessed. +/// +/// This uses the pointee type to determine how many bytes need to be safe to +/// load from the pointer. 
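The in-bounds reasoning at the heart of the definition that follows is plain byte arithmetic against the underlying allocation. A distilled sketch (the helper is illustrative; it assumes DataLayout was available and the base was an alloca or a non-overridable global):

// A load of LoadSize bytes at Base+ByteOffset is safe when it ends inside
// the allocation and the offset respects the requested alignment.
static bool loadIsInBounds(uint64_t ByteOffset, uint64_t LoadSize,
                           uint64_t AllocSize, unsigned Align) {
  return ByteOffset + LoadSize <= AllocSize &&
         (Align == 0 || (ByteOffset % Align) == 0);
}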
bool llvm::isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom, - unsigned Align, const DataLayout *TD) { + unsigned Align, const DataLayout *DL) { int64_t ByteOffset = 0; Value *Base = V; - Base = GetPointerBaseWithConstantOffset(V, ByteOffset, TD); + Base = GetPointerBaseWithConstantOffset(V, ByteOffset, DL); if (ByteOffset < 0) // out of bounds return false; @@ -69,26 +77,29 @@ bool llvm::isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom, BaseType = AI->getAllocatedType(); BaseAlign = AI->getAlignment(); } else if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Base)) { - // Global variables are safe to load from but their size cannot be - // guaranteed if they are overridden. + // Global variables are not necessarily safe to load from if they are + // overridden. Their size may change or they may be weak and require a test + // to determine if they were in fact provided. if (!GV->mayBeOverridden()) { BaseType = GV->getType()->getElementType(); BaseAlign = GV->getAlignment(); } } - if (BaseType && BaseType->isSized()) { - if (TD && BaseAlign == 0) - BaseAlign = TD->getPrefTypeAlignment(BaseType); + PointerType *AddrTy = cast<PointerType>(V->getType()); + uint64_t LoadSize = DL ? DL->getTypeStoreSize(AddrTy->getElementType()) : 0; - if (Align <= BaseAlign) { - if (!TD) - return true; // Loading directly from an alloca or global is OK. + // If we found a base allocated type from either an alloca or global variable, + // try to see if we are definitively within the allocated region. We need to + // know the size of the base type and the loaded type to do anything in this + // case, so only try this when we have the DataLayout available. + if (BaseType && BaseType->isSized() && DL) { + if (BaseAlign == 0) + BaseAlign = DL->getPrefTypeAlignment(BaseType); + if (Align <= BaseAlign) { // Check if the load is within the bounds of the underlying object. - PointerType *AddrTy = cast<PointerType>(V->getType()); - uint64_t LoadSize = TD->getTypeStoreSize(AddrTy->getElementType()); - if (ByteOffset + LoadSize <= TD->getTypeAllocSize(BaseType) && + if (ByteOffset + LoadSize <= DL->getTypeAllocSize(BaseType) && (Align == 0 || (ByteOffset % Align) == 0)) return true; } @@ -101,6 +112,10 @@ bool llvm::isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom, // the load entirely). BasicBlock::iterator BBI = ScanFrom, E = ScanFrom->getParent()->begin(); + // We can at least always strip pointer casts even though we can't use the + // base here. + V = V->stripPointerCasts(); + while (BBI != E) { --BBI; @@ -110,46 +125,62 @@ bool llvm::isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom, !isa<DbgInfoIntrinsic>(BBI)) return false; - if (LoadInst *LI = dyn_cast<LoadInst>(BBI)) { - if (AreEquivalentAddressValues(LI->getOperand(0), V)) return true; - } else if (StoreInst *SI = dyn_cast<StoreInst>(BBI)) { - if (AreEquivalentAddressValues(SI->getOperand(1), V)) return true; - } + Value *AccessedPtr; + if (LoadInst *LI = dyn_cast<LoadInst>(BBI)) + AccessedPtr = LI->getPointerOperand(); + else if (StoreInst *SI = dyn_cast<StoreInst>(BBI)) + AccessedPtr = SI->getPointerOperand(); + else + continue; + + // Handle trivial cases even w/o DataLayout or other work. 
+ if (AccessedPtr == V) + return true; + + if (!DL) + continue; + + auto *AccessedTy = cast<PointerType>(AccessedPtr->getType()); + if (AreEquivalentAddressValues(AccessedPtr->stripPointerCasts(), V) && + LoadSize <= DL->getTypeStoreSize(AccessedTy->getElementType())) + return true; } return false; } -/// FindAvailableLoadedValue - Scan the ScanBB block backwards (starting at the -/// instruction before ScanFrom) checking to see if we have the value at the +/// \brief Scan the ScanBB block backwards to see if we have the value at the /// memory address *Ptr locally available within a small number of instructions. -/// If the value is available, return it. /// -/// If not, return the iterator for the last validated instruction that the -/// value would be live through. If we scanned the entire block and didn't find -/// something that invalidates *Ptr or provides it, ScanFrom would be left at -/// begin() and this returns null. ScanFrom could also be left +/// The scan starts from \c ScanFrom. \c MaxInstsToScan specifies the maximum +/// instructions to scan in the block. If it is set to \c 0, it will scan the whole +/// block. +/// +/// If the value is available, this function returns it. If not, it returns the +/// iterator for the last validated instruction that the value would be live +/// through. If we scanned the entire block and didn't find something that +/// invalidates \c *Ptr or provides it, \c ScanFrom is left at the last +/// instruction processed and this returns null. /// -/// MaxInstsToScan specifies the maximum instructions to scan in the block. If -/// it is set to 0, it will scan the whole block. You can also optionally -/// specify an alias analysis implementation, which makes this more precise. +/// You can also optionally specify an alias analysis implementation, which +/// makes this more precise. /// -/// If TBAATag is non-null and a load or store is found, the TBAA tag from the -/// load or store is recorded there. If there is no TBAA tag or if no access -/// is found, it is left unmodified. +/// If \c AATags is non-null and a load or store is found, the AA tags from the +/// load or store are recorded there. If there are no AA tags or if no access is +/// found, it is left unmodified. Value *llvm::FindAvailableLoadedValue(Value *Ptr, BasicBlock *ScanBB, BasicBlock::iterator &ScanFrom, unsigned MaxInstsToScan, - AliasAnalysis *AA, - MDNode **TBAATag) { - if (MaxInstsToScan == 0) MaxInstsToScan = ~0U; + AliasAnalysis *AA, AAMDNodes *AATags) { + if (MaxInstsToScan == 0) + MaxInstsToScan = ~0U; + + Type *AccessTy = cast<PointerType>(Ptr->getType())->getElementType(); // If we're using alias analysis to disambiguate get the size of *Ptr. - uint64_t AccessSize = 0; - if (AA) { - Type *AccessTy = cast<PointerType>(Ptr->getType())->getElementType(); - AccessSize = AA->getTypeStoreSize(AccessTy); - } - + uint64_t AccessSize = AA ? AA->getTypeStoreSize(AccessTy) : 0; + + Value *StrippedPtr = Ptr->stripPointerCasts(); + while (ScanFrom != ScanBB->begin()) { // We must ignore debug info directives when counting (otherwise they // would affect codegen). @@ -159,62 +190,71 @@ Value *llvm::FindAvailableLoadedValue(Value *Ptr, BasicBlock *ScanBB, // Restore ScanFrom to expected value in case next test succeeds ScanFrom++; - + // Don't scan huge blocks. - if (MaxInstsToScan-- == 0) return nullptr; - + if (MaxInstsToScan-- == 0) + return nullptr; + --ScanFrom; // If this is a load of Ptr, the loaded value is available. 
// (This is true even if the load is volatile or atomic, although // those cases are unlikely.) if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) - if (AreEquivalentAddressValues(LI->getOperand(0), Ptr)) { - if (TBAATag) *TBAATag = LI->getMetadata(LLVMContext::MD_tbaa); + if (AreEquivalentAddressValues( + LI->getPointerOperand()->stripPointerCasts(), StrippedPtr) && + CastInst::isBitCastable(LI->getType(), AccessTy)) { + if (AATags) + LI->getAAMetadata(*AATags); return LI; } - + if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + Value *StorePtr = SI->getPointerOperand()->stripPointerCasts(); // If this is a store through Ptr, the value is available! // (This is true even if the store is volatile or atomic, although // those cases are unlikely.) - if (AreEquivalentAddressValues(SI->getOperand(1), Ptr)) { - if (TBAATag) *TBAATag = SI->getMetadata(LLVMContext::MD_tbaa); + if (AreEquivalentAddressValues(StorePtr, StrippedPtr) && + CastInst::isBitCastable(SI->getValueOperand()->getType(), AccessTy)) { + if (AATags) + SI->getAAMetadata(*AATags); return SI->getOperand(0); } - - // If Ptr is an alloca and this is a store to a different alloca, ignore - // the store. This is a trivial form of alias analysis that is important - // for reg2mem'd code. - if ((isa<AllocaInst>(Ptr) || isa<GlobalVariable>(Ptr)) && - (isa<AllocaInst>(SI->getOperand(1)) || - isa<GlobalVariable>(SI->getOperand(1)))) + + // If both StrippedPtr and StorePtr reach all the way to an alloca or + // global and they are different, ignore the store. This is a trivial form + // of alias analysis that is important for reg2mem'd code. + if ((isa<AllocaInst>(StrippedPtr) || isa<GlobalVariable>(StrippedPtr)) && + (isa<AllocaInst>(StorePtr) || isa<GlobalVariable>(StorePtr)) && + StrippedPtr != StorePtr) continue; - + // If we have alias analysis and it says the store won't modify the loaded // value, ignore the store. if (AA && - (AA->getModRefInfo(SI, Ptr, AccessSize) & AliasAnalysis::Mod) == 0) + (AA->getModRefInfo(SI, StrippedPtr, AccessSize) & + AliasAnalysis::Mod) == 0) continue; - + // Otherwise the store may or may not alias the pointer; bail out. ++ScanFrom; return nullptr; } - + // If this is some other instruction that may clobber Ptr, bail out. if (Inst->mayWriteToMemory()) { // If alias analysis claims that it really won't modify the load, // ignore it. if (AA && - (AA->getModRefInfo(Inst, Ptr, AccessSize) & AliasAnalysis::Mod) == 0) + (AA->getModRefInfo(Inst, StrippedPtr, AccessSize) & + AliasAnalysis::Mod) == 0) continue; - + // May modify the pointer, bail out. ++ScanFrom; return nullptr; } } - + // Got to the start of the block; we didn't find it, but we are done for this // block. return nullptr; diff --git a/lib/Analysis/LoopInfo.cpp b/lib/Analysis/LoopInfo.cpp index 46c0eaa..b1f62c4 100644 --- a/lib/Analysis/LoopInfo.cpp +++ b/lib/Analysis/LoopInfo.cpp @@ -24,6 +24,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -307,7 +308,8 @@ bool Loop::isAnnotatedParallel() const { // directly or indirectly through another list metadata (in case of // nested parallel loops). The loop identifier metadata refers to // itself so we can check both cases with the same routine. 
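One detail in the LoopInfo hunk that follows: switching to the fixed metadata kind ID avoids a string-map lookup on every query. The two forms below are equivalent for any instruction; the enum form is simply cheaper (the fragment reuses II from the surrounding loop):

// Both name the same metadata node; the enum form skips the name-to-kind
// lookup that the string form performs on each call.
MDNode *ByID = II->getMetadata(LLVMContext::MD_mem_parallel_loop_access);
MDNode *ByName = II->getMetadata("llvm.mem.parallel_loop_access");
assert(ByID == ByName && "fixed kind IDs resolve to the same metadata");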
- MDNode *loopIdMD = II->getMetadata("llvm.mem.parallel_loop_access"); + MDNode *loopIdMD = + II->getMetadata(LLVMContext::MD_mem_parallel_loop_access); if (!loopIdMD) return false; diff --git a/lib/Analysis/LoopPass.cpp b/lib/Analysis/LoopPass.cpp index 7bd866e..190abc7 100644 --- a/lib/Analysis/LoopPass.cpp +++ b/lib/Analysis/LoopPass.cpp @@ -76,6 +76,9 @@ void LPPassManager::deleteLoopFromQueue(Loop *L) { LI->updateUnloop(L); + // Notify passes that the loop is being deleted. + deleteSimpleAnalysisLoop(L); + // If L is current loop then skip rest of the passes and let // runOnFunction remove L from LQ. Otherwise, remove L from LQ now // and continue applying other passes on CurrentLoop. @@ -164,6 +167,14 @@ void LPPassManager::deleteSimpleAnalysisValue(Value *V, Loop *L) { } } +/// Invoke deleteAnalysisLoop hook for all passes. +void LPPassManager::deleteSimpleAnalysisLoop(Loop *L) { + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + LoopPass *LP = getContainedPass(Index); + LP->deleteAnalysisLoop(L); + } +} + // Recurse through all subloops and all loops into LQ. static void addLoopIntoQueue(Loop *L, std::deque<Loop *> &LQ) { diff --git a/lib/Analysis/MemoryBuiltins.cpp b/lib/Analysis/MemoryBuiltins.cpp index 64d339f..08b41fe 100644 --- a/lib/Analysis/MemoryBuiltins.cpp +++ b/lib/Analysis/MemoryBuiltins.cpp @@ -332,7 +332,11 @@ const CallInst *llvm::isFreeCall(const Value *I, const TargetLibraryInfo *TLI) { TLIFn == LibFunc::ZdlPv || // operator delete(void*) TLIFn == LibFunc::ZdaPv) // operator delete[](void*) ExpectedNumParams = 1; - else if (TLIFn == LibFunc::ZdlPvRKSt9nothrow_t || // delete(void*, nothrow) + else if (TLIFn == LibFunc::ZdlPvj || // delete(void*, uint) + TLIFn == LibFunc::ZdlPvm || // delete(void*, ulong) + TLIFn == LibFunc::ZdlPvRKSt9nothrow_t || // delete(void*, nothrow) + TLIFn == LibFunc::ZdaPvj || // delete[](void*, uint) + TLIFn == LibFunc::ZdaPvm || // delete[](void*, ulong) TLIFn == LibFunc::ZdaPvRKSt9nothrow_t) // delete[](void*, nothrow) ExpectedNumParams = 2; else @@ -412,7 +416,7 @@ SizeOffsetType ObjectSizeOffsetVisitor::compute(Value *V) { if (Instruction *I = dyn_cast<Instruction>(V)) { // If we have already seen this instruction, bail out. Cycles can happen in // unreachable code after constant propagation. - if (!SeenInsts.insert(I)) + if (!SeenInsts.insert(I).second) return unknown(); if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) @@ -648,7 +652,7 @@ SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute_(Value *V) { // Record the pointers that were handled in this run, so that they can be // cleaned later if something fails. We also use this set to break cycles that // can occur in dead code. 
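The change just below is one instance of a pattern repeated throughout this diff: SmallPtrSet::insert now returns std::pair<iterator, bool> rather than bool, so call sites test .second to ask whether the element was newly inserted. A minimal illustration (V is any pointer value):

SmallPtrSet<Value *, 8> Seen;
// .second is true only on first insertion, so the body runs once per
// distinct value even if V is enqueued several times.
if (Seen.insert(V).second) {
  // ... first visit of V ...
}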
- if (!SeenVals.insert(V)) { + if (!SeenVals.insert(V).second) { Result = unknown(); } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { Result = visitGEPOperator(*GEP); diff --git a/lib/Analysis/MemoryDependenceAnalysis.cpp b/lib/Analysis/MemoryDependenceAnalysis.cpp index 9eaf109..187eada 100644 --- a/lib/Analysis/MemoryDependenceAnalysis.cpp +++ b/lib/Analysis/MemoryDependenceAnalysis.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/PHITransAddr.h" @@ -48,13 +49,17 @@ STATISTIC(NumCacheCompleteNonLocalPtr, "Number of block queries that were completely cached"); // Limit for the number of instructions to scan in a block. -static const int BlockScanLimit = 100; +static const unsigned int BlockScanLimit = 100; + +// Limit on the number of memdep results to process. +static const unsigned int NumResultsLimit = 100; char MemoryDependenceAnalysis::ID = 0; // Register this pass... INITIALIZE_PASS_BEGIN(MemoryDependenceAnalysis, "memdep", "Memory Dependence Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_AG_DEPENDENCY(AliasAnalysis) INITIALIZE_PASS_END(MemoryDependenceAnalysis, "memdep", "Memory Dependence Analysis", false, true) @@ -83,11 +88,13 @@ void MemoryDependenceAnalysis::releaseMemory() { /// void MemoryDependenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); + AU.addRequired<AssumptionTracker>(); AU.addRequiredTransitive<AliasAnalysis>(); } bool MemoryDependenceAnalysis::runOnFunction(Function &) { AA = &getAnalysis<AliasAnalysis>(); + AT = &getAnalysis<AssumptionTracker>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; DominatorTreeWrapperPass *DTWP = @@ -158,29 +165,32 @@ AliasAnalysis::ModRefResult GetLocation(const Instruction *Inst, return AliasAnalysis::Mod; } - if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + AAMDNodes AAInfo; + switch (II->getIntrinsicID()) { case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: case Intrinsic::invariant_start: + II->getAAMetadata(AAInfo); Loc = AliasAnalysis::Location(II->getArgOperand(1), cast<ConstantInt>(II->getArgOperand(0)) - ->getZExtValue(), - II->getMetadata(LLVMContext::MD_tbaa)); + ->getZExtValue(), AAInfo); // These intrinsics don't really modify the memory, but returning Mod // will allow them to be handled conservatively. return AliasAnalysis::Mod; case Intrinsic::invariant_end: + II->getAAMetadata(AAInfo); Loc = AliasAnalysis::Location(II->getArgOperand(2), cast<ConstantInt>(II->getArgOperand(1)) - ->getZExtValue(), - II->getMetadata(LLVMContext::MD_tbaa)); + ->getZExtValue(), AAInfo); // These intrinsics don't really modify the memory, but returning Mod // will allow them to be handled conservatively. return AliasAnalysis::Mod; default: break; } + } // Otherwise, just do the coarse-grained thing that always works. if (Inst->mayWriteToMemory()) @@ -367,6 +377,36 @@ getPointerDependencyFrom(const AliasAnalysis::Location &MemLoc, bool isLoad, int64_t MemLocOffset = 0; unsigned Limit = BlockScanLimit; bool isInvariantLoad = false; + + // We must be careful with atomic accesses, as they may allow another thread + // to touch this location, clobbering it. 
We are conservative: if the + QueryInst is not a simple (non-atomic) memory access, we automatically + return getClobber. + If it is simple, we know, based on the results of + "Compiler testing via a theory of sound optimisations in the C11/C++11 + memory model" in PLDI 2013, that a non-atomic location can only be + clobbered between a pair of a release and an acquire action, with no + access to the location in between. + Here is an example giving the general intuition behind this rule. + In the following code: + store x 0; + release action; [1] + acquire action; [4] + %val = load x; + It is unsafe to replace %val by 0 because another thread may be running: + acquire action; [2] + store x 42; + release action; [3] + with synchronization from 1 to 2 and from 3 to 4, resulting in %val + being 42. A key property of this program, however, is that if either + 1 or 4 were missing, there would be a race between the store of 42 and + either the store of 0 or the load (making the whole program racy). + The paper mentioned above shows that the same property is respected + by every program that can detect any optimization of that kind: either + it is racy (undefined) or there is a release followed by an acquire + between the pair of accesses under consideration. + bool HasSeenAcquire = false; + if (isLoad && QueryInst) { LoadInst *LI = dyn_cast<LoadInst>(QueryInst); if (LI && LI->getMetadata(LLVMContext::MD_invariant_load) != nullptr) @@ -404,10 +444,37 @@ getPointerDependencyFrom(const AliasAnalysis::Location &MemLoc, bool isLoad, // Values depend on loads if the pointers are must aliased. This means that // a load depends on another must aliased load from the same value. + // One exception is atomic loads: a value can depend on an atomic load that it + // does not alias with when this atomic load indicates that another thread may + // be accessing the location. if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { // Atomic loads have complications involved. + // A Monotonic (or higher) load is OK if the query inst is itself not atomic. + // An Acquire (or higher) load sets the HasSeenAcquire flag, so that any + // release store will know to return getClobber. // FIXME: This is overly conservative. - if (!LI->isUnordered()) + if (!LI->isUnordered()) { + if (!QueryInst) + return MemDepResult::getClobber(LI); + if (auto *QueryLI = dyn_cast<LoadInst>(QueryInst)) { + if (!QueryLI->isSimple()) + return MemDepResult::getClobber(LI); + } else if (auto *QuerySI = dyn_cast<StoreInst>(QueryInst)) { + if (!QuerySI->isSimple()) + return MemDepResult::getClobber(LI); + } else if (QueryInst->mayReadOrWriteMemory()) { + return MemDepResult::getClobber(LI); + } + + if (isAtLeastAcquire(LI->getOrdering())) + HasSeenAcquire = true; + } + + // FIXME: this is overly conservative. + // While volatile accesses cannot be eliminated, they do not have to clobber + // non-aliasing locations, as normal accesses can, for example, be reordered + // with volatile accesses. + if (LI->isVolatile()) return MemDepResult::getClobber(LI); AliasAnalysis::Location LoadLoc = AA->getLocation(LI); @@ -466,8 +533,32 @@ getPointerDependencyFrom(const AliasAnalysis::Location &MemLoc, bool isLoad, if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { // Atomic stores have complications involved. + // A Monotonic store is OK if the query inst is itself not atomic. + // A Release (or higher) store further requires that no acquire load + // has been seen. 
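The load-side gate above and the store-side gate below share one shape; a condensed paraphrase (the helper is hypothetical, but the logic mirrors the patch):

// A non-unordered (atomic) access is reported as a clobber unless the query
// instruction is a simple load, a simple store, or an instruction that
// never reads or writes memory.
static bool queryToleratesOrderedAccess(const Instruction *QueryInst) {
  if (!QueryInst)
    return false;
  if (auto *LI = dyn_cast<LoadInst>(QueryInst))
    return LI->isSimple();
  if (auto *SI = dyn_cast<StoreInst>(QueryInst))
    return SI->isSimple();
  return !QueryInst->mayReadOrWriteMemory();
}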
// FIXME: This is overly conservative. - if (!SI->isUnordered()) + if (!SI->isUnordered()) { + if (!QueryInst) + return MemDepResult::getClobber(SI); + if (auto *QueryLI = dyn_cast<LoadInst>(QueryInst)) { + if (!QueryLI->isSimple()) + return MemDepResult::getClobber(SI); + } else if (auto *QuerySI = dyn_cast<StoreInst>(QueryInst)) { + if (!QuerySI->isSimple()) + return MemDepResult::getClobber(SI); + } else if (QueryInst->mayReadOrWriteMemory()) { + return MemDepResult::getClobber(SI); + } + + if (HasSeenAcquire && isAtLeastRelease(SI->getOrdering())) + return MemDepResult::getClobber(SI); + } + + // FIXME: this is overly conservative. + // While volatile accesses cannot be eliminated, they do not have to clobber + // non-aliasing locations, as normal accesses can, for example, be reordered + // with volatile accesses. + if (SI->isVolatile()) return MemDepResult::getClobber(SI); // If alias analysis can tell that this store is guaranteed to not modify @@ -685,7 +776,7 @@ MemoryDependenceAnalysis::getNonLocalCallDependency(CallSite QueryCS) { DirtyBlocks.pop_back(); // Already processed this block? - if (!Visited.insert(DirtyBB)) + if (!Visited.insert(DirtyBB).second) continue; // Do a binary search to see if we already have an entry for this block in @@ -775,7 +866,7 @@ getNonLocalPointerDependency(const AliasAnalysis::Location &Loc, bool isLoad, "Can't get pointer deps of a non-pointer!"); Result.clear(); - PHITransAddr Address(const_cast<Value *>(Loc.Ptr), DL); + PHITransAddr Address(const_cast<Value *>(Loc.Ptr), DL, AT); // This is the set of blocks we've inspected, and the pointer we consider in // each block. Because of critical edges, we currently bail out if querying @@ -861,7 +952,7 @@ GetNonLocalInfoForBlock(const AliasAnalysis::Location &Loc, return Dep; } -/// SortNonLocalDepInfoCache - Sort the a NonLocalDepInfo cache, given a certain +/// SortNonLocalDepInfoCache - Sort the NonLocalDepInfo cache, given a certain /// number of elements in the array that are already properly ordered. This is /// optimized for the case when only a few entries are added. static void @@ -922,10 +1013,10 @@ getNonLocalPointerDepFromBB(const PHITransAddr &Pointer, // Set up a temporary NLPI value. If the map doesn't yet have an entry for // CacheKey, this value will be inserted as the associated value. Otherwise, // it'll be ignored, and we'll have to check to see if the cached size and - // tbaa tag are consistent with the current query. + // aa tags are consistent with the current query. NonLocalPointerInfo InitialNLPI; InitialNLPI.Size = Loc.Size; - InitialNLPI.TBAATag = Loc.TBAATag; + InitialNLPI.AATags = Loc.AATags; // Get the NLPI for CacheKey, inserting one into the map if it doesn't // already have one. @@ -955,21 +1046,21 @@ getNonLocalPointerDepFromBB(const PHITransAddr &Pointer, SkipFirstBlock); } - // If the query's TBAATag is inconsistent with the cached one, + // If the query's AATags are inconsistent with the cached one, // conservatively throw out the cached data and restart the query with // no tag if needed. 
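AAMDNodes, which replaces the lone TBAA node throughout this diff, carries the !tbaa, !alias.scope and !noalias nodes together and compares them member-wise, which is what makes the consistency test below well-defined. A hypothetical illustration (QueryLoad and CachedLoad are stand-ins):

AAMDNodes QueryTags, CachedTags;
QueryLoad->getAAMetadata(QueryTags);   // fills in all three tag kinds
CachedLoad->getAAMetadata(CachedTags);
if (QueryTags != CachedTags) {
  // Mismatch: the cached entries must be flushed, and a query that carried
  // tags is retried untagged via Loc.getWithoutAATags(), as the code below
  // does.
}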
- if (CacheInfo->TBAATag != Loc.TBAATag) { - if (CacheInfo->TBAATag) { + if (CacheInfo->AATags != Loc.AATags) { + if (CacheInfo->AATags) { CacheInfo->Pair = BBSkipFirstBlockPair(); - CacheInfo->TBAATag = nullptr; + CacheInfo->AATags = AAMDNodes(); for (NonLocalDepInfo::iterator DI = CacheInfo->NonLocalDeps.begin(), DE = CacheInfo->NonLocalDeps.end(); DI != DE; ++DI) if (Instruction *Inst = DI->getResult().getInst()) RemoveFromReverseMap(ReverseNonLocalPtrDeps, Inst, CacheKey); CacheInfo->NonLocalDeps.clear(); } - if (Loc.TBAATag) - return getNonLocalPointerDepFromBB(Pointer, Loc.getWithoutTBAATag(), + if (Loc.AATags) + return getNonLocalPointerDepFromBB(Pointer, Loc.getWithoutAATags(), isLoad, StartBB, Result, Visited, SkipFirstBlock); } @@ -1045,6 +1136,25 @@ getNonLocalPointerDepFromBB(const PHITransAddr &Pointer, while (!Worklist.empty()) { BasicBlock *BB = Worklist.pop_back_val(); + // If we do process a large number of blocks it becomes very expensive and + // likely it isn't worth worrying about + if (Result.size() > NumResultsLimit) { + Worklist.clear(); + // Sort it now (if needed) so that recursive invocations of + // getNonLocalPointerDepFromBB and other routines that could reuse the + // cache value will only see properly sorted cache arrays. + if (Cache && NumSortedEntries != Cache->size()) { + SortNonLocalDepInfoCache(*Cache, NumSortedEntries); + NumSortedEntries = Cache->size(); + } + // Since we bail out, the "Cache" set won't contain all of the + // results for the query. This is ok (we can still use it to accelerate + // specific block queries) but we can't do the fastpath "return all + // results from the set". Clear out the indicator for this. + CacheInfo->Pair = BBSkipFirstBlockPair(); + return true; + } + // Skip the first block if we have it. if (!SkipFirstBlock) { // Analyze the dependency of *Pointer in FromBB. See if we already have @@ -1369,14 +1479,11 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { ReverseDepMapType::iterator ReverseDepIt = ReverseLocalDeps.find(RemInst); if (ReverseDepIt != ReverseLocalDeps.end()) { - SmallPtrSet<Instruction*, 4> &ReverseDeps = ReverseDepIt->second; // RemInst can't be the terminator if it has local stuff depending on it. - assert(!ReverseDeps.empty() && !isa<TerminatorInst>(RemInst) && + assert(!ReverseDepIt->second.empty() && !isa<TerminatorInst>(RemInst) && "Nothing can locally depend on a terminator"); - for (SmallPtrSet<Instruction*, 4>::iterator I = ReverseDeps.begin(), - E = ReverseDeps.end(); I != E; ++I) { - Instruction *InstDependingOnRemInst = *I; + for (Instruction *InstDependingOnRemInst : ReverseDepIt->second) { assert(InstDependingOnRemInst != RemInst && "Already removed our local dep info"); @@ -1402,12 +1509,10 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { ReverseDepIt = ReverseNonLocalDeps.find(RemInst); if (ReverseDepIt != ReverseNonLocalDeps.end()) { - SmallPtrSet<Instruction*, 4> &Set = ReverseDepIt->second; - for (SmallPtrSet<Instruction*, 4>::iterator I = Set.begin(), E = Set.end(); - I != E; ++I) { - assert(*I != RemInst && "Already removed NonLocalDep info for RemInst"); + for (Instruction *I : ReverseDepIt->second) { + assert(I != RemInst && "Already removed NonLocalDep info for RemInst"); - PerInstNLInfo &INLD = NonLocalDeps[*I]; + PerInstNLInfo &INLD = NonLocalDeps[I]; // The information is now dirty! 
INLD.second = true; @@ -1419,7 +1524,7 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { DI->setResult(NewDirtyVal); if (Instruction *NextI = NewDirtyVal.getInst()) - ReverseDepsToAdd.push_back(std::make_pair(NextI, *I)); + ReverseDepsToAdd.push_back(std::make_pair(NextI, I)); } } @@ -1438,12 +1543,9 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { ReverseNonLocalPtrDepTy::iterator ReversePtrDepIt = ReverseNonLocalPtrDeps.find(RemInst); if (ReversePtrDepIt != ReverseNonLocalPtrDeps.end()) { - SmallPtrSet<ValueIsLoadPair, 4> &Set = ReversePtrDepIt->second; SmallVector<std::pair<Instruction*, ValueIsLoadPair>,8> ReversePtrDepsToAdd; - for (SmallPtrSet<ValueIsLoadPair, 4>::iterator I = Set.begin(), - E = Set.end(); I != E; ++I) { - ValueIsLoadPair P = *I; + for (ValueIsLoadPair P : ReversePtrDepIt->second) { assert(P.getPointer() != RemInst && "Already removed NonLocalPointerDeps info for RemInst"); @@ -1484,8 +1586,10 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { DEBUG(verifyRemoved(RemInst)); } /// verifyRemoved - Verify that the specified instruction does not occur -/// in our internal data structures. +/// in our internal data structures. This function verifies by asserting in +/// debug builds. void MemoryDependenceAnalysis::verifyRemoved(Instruction *D) const { +#ifndef NDEBUG for (LocalDepMapType::const_iterator I = LocalDeps.begin(), E = LocalDeps.end(); I != E; ++I) { assert(I->first != D && "Inst occurs in data structures"); @@ -1514,18 +1618,16 @@ void MemoryDependenceAnalysis::verifyRemoved(Instruction *D) const { for (ReverseDepMapType::const_iterator I = ReverseLocalDeps.begin(), E = ReverseLocalDeps.end(); I != E; ++I) { assert(I->first != D && "Inst occurs in data structures"); - for (SmallPtrSet<Instruction*, 4>::const_iterator II = I->second.begin(), - EE = I->second.end(); II != EE; ++II) - assert(*II != D && "Inst occurs in data structures"); + for (Instruction *Inst : I->second) + assert(Inst != D && "Inst occurs in data structures"); } for (ReverseDepMapType::const_iterator I = ReverseNonLocalDeps.begin(), E = ReverseNonLocalDeps.end(); I != E; ++I) { assert(I->first != D && "Inst occurs in data structures"); - for (SmallPtrSet<Instruction*, 4>::const_iterator II = I->second.begin(), - EE = I->second.end(); II != EE; ++II) - assert(*II != D && "Inst occurs in data structures"); + for (Instruction *Inst : I->second) + assert(Inst != D && "Inst occurs in data structures"); } for (ReverseNonLocalPtrDepTy::const_iterator @@ -1533,11 +1635,10 @@ void MemoryDependenceAnalysis::verifyRemoved(Instruction *D) const { E = ReverseNonLocalPtrDeps.end(); I != E; ++I) { assert(I->first != D && "Inst occurs in rev NLPD map"); - for (SmallPtrSet<ValueIsLoadPair, 4>::const_iterator II = I->second.begin(), - E = I->second.end(); II != E; ++II) - assert(*II != ValueIsLoadPair(D, false) && - *II != ValueIsLoadPair(D, true) && + for (ValueIsLoadPair P : I->second) + assert(P != ValueIsLoadPair(D, false) && + P != ValueIsLoadPair(D, true) && "Inst occurs in ReverseNonLocalPtrDeps map"); } - +#endif } diff --git a/lib/Analysis/NoAliasAnalysis.cpp b/lib/Analysis/NoAliasAnalysis.cpp index 139fa38..c214d3c 100644 --- a/lib/Analysis/NoAliasAnalysis.cpp +++ b/lib/Analysis/NoAliasAnalysis.cpp @@ -57,8 +57,9 @@ namespace { Location getArgLocation(ImmutableCallSite CS, unsigned ArgIdx, ModRefResult &Mask) override { Mask = ModRef; - return Location(CS.getArgument(ArgIdx), UnknownSize, - 
CS.getInstruction()->getMetadata(LLVMContext::MD_tbaa)); + AAMDNodes AATags; + CS->getAAMetadata(AATags); + return Location(CS.getArgument(ArgIdx), UnknownSize, AATags); } ModRefResult getModRefInfo(ImmutableCallSite CS, diff --git a/lib/Analysis/PHITransAddr.cpp b/lib/Analysis/PHITransAddr.cpp index bfe8642..b3d060a 100644 --- a/lib/Analysis/PHITransAddr.cpp +++ b/lib/Analysis/PHITransAddr.cpp @@ -228,7 +228,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, return GEP; // Simplify the GEP to handle 'gep x, 0' -> x etc. - if (Value *V = SimplifyGEPInst(GEPOps, DL, TLI, DT)) { + if (Value *V = SimplifyGEPInst(GEPOps, DL, TLI, DT, AT)) { for (unsigned i = 0, e = GEPOps.size(); i != e; ++i) RemoveInstInputs(GEPOps[i], InstInputs); @@ -283,7 +283,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, } // See if the add simplifies away. - if (Value *Res = SimplifyAddInst(LHS, RHS, isNSW, isNUW, DL, TLI, DT)) { + if (Value *Res = SimplifyAddInst(LHS, RHS, isNSW, isNUW, DL, TLI, DT, AT)) { // If we simplified the operands, the LHS is no longer an input, but Res // is. RemoveInstInputs(LHS, InstInputs); @@ -369,7 +369,7 @@ InsertPHITranslatedSubExpr(Value *InVal, BasicBlock *CurBB, SmallVectorImpl<Instruction*> &NewInsts) { // See if we have a version of this value already available and dominating // PredBB. If so, there is no need to insert a new instance of it. - PHITransAddr Tmp(InVal, DL); + PHITransAddr Tmp(InVal, DL, AT); if (!Tmp.PHITranslateValue(CurBB, PredBB, &DT)) return Tmp.getAddr(); diff --git a/lib/Analysis/PtrUseVisitor.cpp b/lib/Analysis/PtrUseVisitor.cpp index 1b0f359..68c7535 100644 --- a/lib/Analysis/PtrUseVisitor.cpp +++ b/lib/Analysis/PtrUseVisitor.cpp @@ -17,7 +17,7 @@ using namespace llvm; void detail::PtrUseVisitorBase::enqueueUsers(Instruction &I) { for (Use &U : I.uses()) { - if (VisitedUses.insert(&U)) { + if (VisitedUses.insert(&U).second) { UseToVisit NewU = { UseToVisit::UseAndIsOffsetKnownPair(&U, IsOffsetKnown), Offset diff --git a/lib/Analysis/RegionInfo.cpp b/lib/Analysis/RegionInfo.cpp index 7f88ae1..08ebf0d 100644 --- a/lib/Analysis/RegionInfo.cpp +++ b/lib/Analysis/RegionInfo.cpp @@ -10,6 +10,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/RegionInfo.h" +#include "llvm/Analysis/RegionInfoImpl.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/LoopInfo.h" @@ -25,21 +26,26 @@ using namespace llvm; #define DEBUG_TYPE "region" +namespace llvm { +template class RegionBase<RegionTraits<Function>>; +template class RegionNodeBase<RegionTraits<Function>>; +template class RegionInfoBase<RegionTraits<Function>>; +} + +STATISTIC(numRegions, "The # of regions"); +STATISTIC(numSimpleRegions, "The # of simple regions"); + // Always verify if expensive checking is enabled. 
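
// The change below moves the flag's storage into the analysis class; the
// cl::location mechanism it relies on binds an option to an externally
// owned variable. A minimal sketch with invented names:
//
//   static bool MyVerifyFlag = false;
//   static cl::opt<bool, true> // second template argument: external storage
//   MyVerifyFlagX("my-verify", cl::location(MyVerifyFlag),
//                 cl::desc("Enable expensive verification"));
//
// Parsing "-my-verify" writes directly into MyVerifyFlag, which is why a
// static member such as RegionInfoBase<...>::VerifyRegionInfo can serve as
// the backing store.
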
-#ifdef XDEBUG -static bool VerifyRegionInfo = true; -#else -static bool VerifyRegionInfo = false; -#endif static cl::opt<bool,true> -VerifyRegionInfoX("verify-region-info", cl::location(VerifyRegionInfo), - cl::desc("Verify region info (time consuming)")); +VerifyRegionInfoX( + "verify-region-info", + cl::location(RegionInfoBase<RegionTraits<Function>>::VerifyRegionInfo), + cl::desc("Verify region info (time consuming)")); -STATISTIC(numRegions, "The # of regions"); -STATISTIC(numSimpleRegions, "The # of simple regions"); -static cl::opt<enum Region::PrintStyle> printStyle("print-region-style", +static cl::opt<Region::PrintStyle, true> printStyleX("print-region-style", + cl::location(RegionInfo::printStyle), cl::Hidden, cl::desc("style of printing regions"), cl::values( @@ -49,812 +55,110 @@ static cl::opt<enum Region::PrintStyle> printStyle("print-region-style", clEnumValN(Region::PrintRN, "rn", "print regions in detail with element_iterator"), clEnumValEnd)); -//===----------------------------------------------------------------------===// -/// Region Implementation -Region::Region(BasicBlock *Entry, BasicBlock *Exit, RegionInfo* RInfo, - DominatorTree *dt, Region *Parent) - : RegionNode(Parent, Entry, 1), RI(RInfo), DT(dt), exit(Exit) {} - -Region::~Region() { - // Free the cached nodes. - for (BBNodeMapT::iterator it = BBNodeMap.begin(), - ie = BBNodeMap.end(); it != ie; ++it) - delete it->second; - - // Only clean the cache for this Region. Caches of child Regions will be - // cleaned when the child Regions are deleted. - BBNodeMap.clear(); -} - -void Region::replaceEntry(BasicBlock *BB) { - entry.setPointer(BB); -} - -void Region::replaceExit(BasicBlock *BB) { - assert(exit && "No exit to replace!"); - exit = BB; -} - -void Region::replaceEntryRecursive(BasicBlock *NewEntry) { - std::vector<Region *> RegionQueue; - BasicBlock *OldEntry = getEntry(); - - RegionQueue.push_back(this); - while (!RegionQueue.empty()) { - Region *R = RegionQueue.back(); - RegionQueue.pop_back(); - - R->replaceEntry(NewEntry); - for (Region::const_iterator RI = R->begin(), RE = R->end(); RI != RE; ++RI) - if ((*RI)->getEntry() == OldEntry) - RegionQueue.push_back(RI->get()); - } -} - -void Region::replaceExitRecursive(BasicBlock *NewExit) { - std::vector<Region *> RegionQueue; - BasicBlock *OldExit = getExit(); - - RegionQueue.push_back(this); - while (!RegionQueue.empty()) { - Region *R = RegionQueue.back(); - RegionQueue.pop_back(); - - R->replaceExit(NewExit); - for (Region::const_iterator RI = R->begin(), RE = R->end(); RI != RE; ++RI) - if ((*RI)->getExit() == OldExit) - RegionQueue.push_back(RI->get()); - } -} - -bool Region::contains(const BasicBlock *B) const { - BasicBlock *BB = const_cast<BasicBlock*>(B); - - if (!DT->getNode(BB)) - return false; - - BasicBlock *entry = getEntry(), *exit = getExit(); - - // Toplevel region. - if (!exit) - return true; - - return (DT->dominates(entry, BB) - && !(DT->dominates(exit, BB) && DT->dominates(entry, exit))); -} - -bool Region::contains(const Loop *L) const { - // BBs that are not part of any loop are element of the Loop - // described by the NULL pointer. This loop is not part of any region, - // except if the region describes the whole function. 
- if (!L) - return getExit() == nullptr; - - if (!contains(L->getHeader())) - return false; - - SmallVector<BasicBlock *, 8> ExitingBlocks; - L->getExitingBlocks(ExitingBlocks); - - for (SmallVectorImpl<BasicBlock*>::iterator BI = ExitingBlocks.begin(), - BE = ExitingBlocks.end(); BI != BE; ++BI) - if (!contains(*BI)) - return false; - - return true; -} - -Loop *Region::outermostLoopInRegion(Loop *L) const { - if (!contains(L)) - return nullptr; - - while (L && contains(L->getParentLoop())) { - L = L->getParentLoop(); - } - - return L; -} - -Loop *Region::outermostLoopInRegion(LoopInfo *LI, BasicBlock* BB) const { - assert(LI && BB && "LI and BB cannot be null!"); - Loop *L = LI->getLoopFor(BB); - return outermostLoopInRegion(L); -} - -BasicBlock *Region::getEnteringBlock() const { - BasicBlock *entry = getEntry(); - BasicBlock *Pred; - BasicBlock *enteringBlock = nullptr; - - for (pred_iterator PI = pred_begin(entry), PE = pred_end(entry); PI != PE; - ++PI) { - Pred = *PI; - if (DT->getNode(Pred) && !contains(Pred)) { - if (enteringBlock) - return nullptr; - - enteringBlock = Pred; - } - } - - return enteringBlock; -} - -BasicBlock *Region::getExitingBlock() const { - BasicBlock *exit = getExit(); - BasicBlock *Pred; - BasicBlock *exitingBlock = nullptr; - - if (!exit) - return nullptr; - - for (pred_iterator PI = pred_begin(exit), PE = pred_end(exit); PI != PE; - ++PI) { - Pred = *PI; - if (contains(Pred)) { - if (exitingBlock) - return nullptr; - - exitingBlock = Pred; - } - } - - return exitingBlock; -} - -bool Region::isSimple() const { - return !isTopLevelRegion() && getEnteringBlock() && getExitingBlock(); -} - -std::string Region::getNameStr() const { - std::string exitName; - std::string entryName; - - if (getEntry()->getName().empty()) { - raw_string_ostream OS(entryName); - - getEntry()->printAsOperand(OS, false); - } else - entryName = getEntry()->getName(); - - if (getExit()) { - if (getExit()->getName().empty()) { - raw_string_ostream OS(exitName); - - getExit()->printAsOperand(OS, false); - } else - exitName = getExit()->getName(); - } else - exitName = "<Function Return>"; - - return entryName + " => " + exitName; -} - -void Region::verifyBBInRegion(BasicBlock *BB) const { - if (!contains(BB)) - llvm_unreachable("Broken region found!"); - - BasicBlock *entry = getEntry(), *exit = getExit(); - - for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) - if (!contains(*SI) && exit != *SI) - llvm_unreachable("Broken region found!"); - - if (entry != BB) - for (pred_iterator SI = pred_begin(BB), SE = pred_end(BB); SI != SE; ++SI) - if (!contains(*SI)) - llvm_unreachable("Broken region found!"); -} - -void Region::verifyWalk(BasicBlock *BB, std::set<BasicBlock*> *visited) const { - BasicBlock *exit = getExit(); - - visited->insert(BB); - - verifyBBInRegion(BB); - - for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) - if (*SI != exit && visited->find(*SI) == visited->end()) - verifyWalk(*SI, visited); -} - -void Region::verifyRegion() const { - // Only do verification when user wants to, otherwise this expensive - // check will be invoked by PassManager. 
- if (!VerifyRegionInfo) return; - - std::set<BasicBlock*> visited; - verifyWalk(getEntry(), &visited); -} - -void Region::verifyRegionNest() const { - for (Region::const_iterator RI = begin(), RE = end(); RI != RE; ++RI) - (*RI)->verifyRegionNest(); - - verifyRegion(); -} - -Region::element_iterator Region::element_begin() { - return GraphTraits<Region*>::nodes_begin(this); -} - -Region::element_iterator Region::element_end() { - return GraphTraits<Region*>::nodes_end(this); -} - -Region::const_element_iterator Region::element_begin() const { - return GraphTraits<const Region*>::nodes_begin(this); -} - -Region::const_element_iterator Region::element_end() const { - return GraphTraits<const Region*>::nodes_end(this); -} - -Region* Region::getSubRegionNode(BasicBlock *BB) const { - Region *R = RI->getRegionFor(BB); - - if (!R || R == this) - return nullptr; - - // If we pass the BB out of this region, that means our code is broken. - assert(contains(R) && "BB not in current region!"); - - while (contains(R->getParent()) && R->getParent() != this) - R = R->getParent(); - - if (R->getEntry() != BB) - return nullptr; - - return R; -} - -RegionNode* Region::getBBNode(BasicBlock *BB) const { - assert(contains(BB) && "Can get BB node out of this region!"); - - BBNodeMapT::const_iterator at = BBNodeMap.find(BB); - - if (at != BBNodeMap.end()) - return at->second; - - RegionNode *NewNode = new RegionNode(const_cast<Region*>(this), BB); - BBNodeMap.insert(std::make_pair(BB, NewNode)); - return NewNode; -} - -RegionNode* Region::getNode(BasicBlock *BB) const { - assert(contains(BB) && "Can get BB node out of this region!"); - if (Region* Child = getSubRegionNode(BB)) - return Child->getNode(); - - return getBBNode(BB); -} - -void Region::transferChildrenTo(Region *To) { - for (iterator I = begin(), E = end(); I != E; ++I) { - (*I)->parent = To; - To->children.push_back(std::move(*I)); - } - children.clear(); -} - -void Region::addSubRegion(Region *SubRegion, bool moveChildren) { - assert(!SubRegion->parent && "SubRegion already has a parent!"); - assert(std::find_if(begin(), end(), [&](const std::unique_ptr<Region> &R) { - return R.get() == SubRegion; - }) == children.end() && - "Subregion already exists!"); - - SubRegion->parent = this; - children.push_back(std::unique_ptr<Region>(SubRegion)); - - if (!moveChildren) - return; - - assert(SubRegion->children.size() == 0 - && "SubRegions that contain children are not supported"); - - for (element_iterator I = element_begin(), E = element_end(); I != E; ++I) - if (!(*I)->isSubRegion()) { - BasicBlock *BB = (*I)->getNodeAs<BasicBlock>(); - - if (SubRegion->contains(BB)) - RI->setRegionFor(BB, SubRegion); - } - - std::vector<std::unique_ptr<Region>> Keep; - for (iterator I = begin(), E = end(); I != E; ++I) - if (SubRegion->contains(I->get()) && I->get() != SubRegion) { - (*I)->parent = SubRegion; - SubRegion->children.push_back(std::move(*I)); - } else - Keep.push_back(std::move(*I)); - - children.clear(); - children.insert(children.begin(), - std::move_iterator<RegionSet::iterator>(Keep.begin()), - std::move_iterator<RegionSet::iterator>(Keep.end())); -} - - -Region *Region::removeSubRegion(Region *Child) { - assert(Child->parent == this && "Child is not a child of this region!"); - Child->parent = nullptr; - RegionSet::iterator I = std::find_if( - children.begin(), children.end(), - [&](const std::unique_ptr<Region> &R) { return R.get() == Child; }); - assert(I != children.end() && "Region does not exit. 
Unable to remove."); - children.erase(children.begin()+(I-begin())); - return Child; -} - -unsigned Region::getDepth() const { - unsigned Depth = 0; - - for (Region *R = parent; R != nullptr; R = R->parent) - ++Depth; - - return Depth; -} -Region *Region::getExpandedRegion() const { - unsigned NumSuccessors = exit->getTerminator()->getNumSuccessors(); - if (NumSuccessors == 0) - return nullptr; - - for (pred_iterator PI = pred_begin(getExit()), PE = pred_end(getExit()); - PI != PE; ++PI) - if (!DT->dominates(getEntry(), *PI)) - return nullptr; - - Region *R = RI->getRegionFor(exit); - - if (R->getEntry() != exit) { - if (exit->getTerminator()->getNumSuccessors() == 1) - return new Region(getEntry(), *succ_begin(exit), RI, DT); - else - return nullptr; - } - - while (R->getParent() && R->getParent()->getEntry() == exit) - R = R->getParent(); - - if (!DT->dominates(getEntry(), R->getExit())) - for (pred_iterator PI = pred_begin(getExit()), PE = pred_end(getExit()); - PI != PE; ++PI) - if (!DT->dominates(R->getExit(), *PI)) - return nullptr; - - return new Region(getEntry(), R->getExit(), RI, DT); -} - -void Region::print(raw_ostream &OS, bool print_tree, unsigned level, - enum PrintStyle Style) const { - if (print_tree) - OS.indent(level*2) << "[" << level << "] " << getNameStr(); - else - OS.indent(level*2) << getNameStr(); - - OS << "\n"; - - - if (Style != PrintNone) { - OS.indent(level*2) << "{\n"; - OS.indent(level*2 + 2); - - if (Style == PrintBB) { - for (const auto &BB : blocks()) - OS << BB->getName() << ", "; // TODO: remove the last "," - } else if (Style == PrintRN) { - for (const_element_iterator I = element_begin(), E = element_end(); I!=E; ++I) - OS << **I << ", "; // TODO: remove the last ", - } - - OS << "\n"; - } +//===----------------------------------------------------------------------===// +// Region implementation +// - if (print_tree) - for (const_iterator RI = begin(), RE = end(); RI != RE; ++RI) - (*RI)->print(OS, print_tree, level+1, Style); +Region::Region(BasicBlock *Entry, BasicBlock *Exit, + RegionInfo* RI, + DominatorTree *DT, Region *Parent) : + RegionBase<RegionTraits<Function>>(Entry, Exit, RI, DT, Parent) { - if (Style != PrintNone) - OS.indent(level*2) << "} \n"; } -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void Region::dump() const { - print(dbgs(), true, getDepth(), printStyle.getValue()); -} -#endif - -void Region::clearNodeCache() { - // Free the cached nodes. - for (BBNodeMapT::iterator I = BBNodeMap.begin(), - IE = BBNodeMap.end(); I != IE; ++I) - delete I->second; - - BBNodeMap.clear(); - for (Region::iterator RI = begin(), RE = end(); RI != RE; ++RI) - (*RI)->clearNodeCache(); -} +Region::~Region() { } //===----------------------------------------------------------------------===// // RegionInfo implementation // -bool RegionInfo::isCommonDomFrontier(BasicBlock *BB, BasicBlock *entry, - BasicBlock *exit) const { - for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); PI != PE; ++PI) { - BasicBlock *P = *PI; - if (DT->dominates(entry, P) && !DT->dominates(exit, P)) - return false; - } - return true; -} - -bool RegionInfo::isRegion(BasicBlock *entry, BasicBlock *exit) const { - assert(entry && exit && "entry and exit must not be null!"); - typedef DominanceFrontier::DomSetType DST; - - DST *entrySuccs = &DF->find(entry)->second; - - // Exit is the header of a loop that contains the entry. In this case, - // the dominance frontier must only contain the exit. 
- if (!DT->dominates(entry, exit)) { - for (DST::iterator SI = entrySuccs->begin(), SE = entrySuccs->end(); - SI != SE; ++SI) - if (*SI != exit && *SI != entry) - return false; - - return true; - } - - DST *exitSuccs = &DF->find(exit)->second; - - // Do not allow edges leaving the region. - for (DST::iterator SI = entrySuccs->begin(), SE = entrySuccs->end(); - SI != SE; ++SI) { - if (*SI == exit || *SI == entry) - continue; - if (exitSuccs->find(*SI) == exitSuccs->end()) - return false; - if (!isCommonDomFrontier(*SI, entry, exit)) - return false; - } - - // Do not allow edges pointing into the region. - for (DST::iterator SI = exitSuccs->begin(), SE = exitSuccs->end(); - SI != SE; ++SI) - if (DT->properlyDominates(entry, *SI) && *SI != exit) - return false; +RegionInfo::RegionInfo() : + RegionInfoBase<RegionTraits<Function>>() { - - return true; -} - -void RegionInfo::insertShortCut(BasicBlock *entry, BasicBlock *exit, - BBtoBBMap *ShortCut) const { - assert(entry && exit && "entry and exit must not be null!"); - - BBtoBBMap::iterator e = ShortCut->find(exit); - - if (e == ShortCut->end()) - // No further region at exit available. - (*ShortCut)[entry] = exit; - else { - // We found a region e that starts at exit. Therefore (entry, e->second) - // is also a region, that is larger than (entry, exit). Insert the - // larger one. - BasicBlock *BB = e->second; - (*ShortCut)[entry] = BB; - } } -DomTreeNode* RegionInfo::getNextPostDom(DomTreeNode* N, - BBtoBBMap *ShortCut) const { - BBtoBBMap::iterator e = ShortCut->find(N->getBlock()); - - if (e == ShortCut->end()) - return N->getIDom(); +RegionInfo::~RegionInfo() { - return PDT->getNode(e->second)->getIDom(); -} - -bool RegionInfo::isTrivialRegion(BasicBlock *entry, BasicBlock *exit) const { - assert(entry && exit && "entry and exit must not be null!"); - - unsigned num_successors = succ_end(entry) - succ_begin(entry); - - if (num_successors <= 1 && exit == *(succ_begin(entry))) - return true; - - return false; } void RegionInfo::updateStatistics(Region *R) { ++numRegions; // TODO: Slow. Should only be enabled if -stats is used. - if (R->isSimple()) ++numSimpleRegions; -} - -Region *RegionInfo::createRegion(BasicBlock *entry, BasicBlock *exit) { - assert(entry && exit && "entry and exit must not be null!"); - - if (isTrivialRegion(entry, exit)) - return nullptr; - - Region *region = new Region(entry, exit, this, DT); - BBtoRegion.insert(std::make_pair(entry, region)); - - #ifdef XDEBUG - region->verifyRegion(); - #else - DEBUG(region->verifyRegion()); - #endif - - updateStatistics(region); - return region; -} - -void RegionInfo::findRegionsWithEntry(BasicBlock *entry, BBtoBBMap *ShortCut) { - assert(entry); - - DomTreeNode *N = PDT->getNode(entry); - - if (!N) - return; - - Region *lastRegion= nullptr; - BasicBlock *lastExit = entry; - - // As only a BasicBlock that postdominates entry can finish a region, walk the - // post dominance tree upwards. - while ((N = getNextPostDom(N, ShortCut))) { - BasicBlock *exit = N->getBlock(); - - if (!exit) - break; - - if (isRegion(entry, exit)) { - Region *newRegion = createRegion(entry, exit); - - if (lastRegion) - newRegion->addSubRegion(lastRegion); - - lastRegion = newRegion; - lastExit = exit; - } - - // This can never be a region, so stop the search. - if (!DT->dominates(entry, exit)) - break; - } - - // Tried to create regions from entry to lastExit. Next time take a - // shortcut from entry to lastExit. 
- if (lastExit != entry) - insertShortCut(entry, lastExit, ShortCut); + if (R->isSimple()) + ++numSimpleRegions; } -void RegionInfo::scanForRegions(Function &F, BBtoBBMap *ShortCut) { - BasicBlock *entry = &(F.getEntryBlock()); - DomTreeNode *N = DT->getNode(entry); - - // Iterate over the dominance tree in post order to start with the small - // regions from the bottom of the dominance tree. If the small regions are - // detected first, detection of bigger regions is faster, as we can jump - // over the small regions. - for (po_iterator<DomTreeNode*> FI = po_begin(N), FE = po_end(N); FI != FE; - ++FI) { - findRegionsWithEntry(FI->getBlock(), ShortCut); - } -} +void RegionInfo::recalculate(Function &F, DominatorTree *DT_, + PostDominatorTree *PDT_, DominanceFrontier *DF_) { + DT = DT_; + PDT = PDT_; + DF = DF_; -Region *RegionInfo::getTopMostParent(Region *region) { - while (region->parent) - region = region->getParent(); - - return region; + TopLevelRegion = new Region(&F.getEntryBlock(), nullptr, + this, DT, nullptr); + updateStatistics(TopLevelRegion); + calculate(F); } -void RegionInfo::buildRegionsTree(DomTreeNode *N, Region *region) { - BasicBlock *BB = N->getBlock(); - - // Passed region exit - while (BB == region->getExit()) - region = region->getParent(); - - BBtoRegionMap::iterator it = BBtoRegion.find(BB); - - // This basic block is a start block of a region. It is already in the - // BBtoRegion relation. Only the child basic blocks have to be updated. - if (it != BBtoRegion.end()) { - Region *newRegion = it->second; - region->addSubRegion(getTopMostParent(newRegion)); - region = newRegion; - } else { - BBtoRegion[BB] = region; - } +//===----------------------------------------------------------------------===// +// RegionInfoPass implementation +// - for (DomTreeNode::iterator CI = N->begin(), CE = N->end(); CI != CE; ++CI) - buildRegionsTree(*CI, region); +RegionInfoPass::RegionInfoPass() : FunctionPass(ID) { + initializeRegionInfoPassPass(*PassRegistry::getPassRegistry()); } -void RegionInfo::releaseMemory() { - BBtoRegion.clear(); - if (TopLevelRegion) - delete TopLevelRegion; - TopLevelRegion = nullptr; -} +RegionInfoPass::~RegionInfoPass() { -RegionInfo::RegionInfo() : FunctionPass(ID) { - initializeRegionInfoPass(*PassRegistry::getPassRegistry()); - TopLevelRegion = nullptr; } -RegionInfo::~RegionInfo() { +bool RegionInfoPass::runOnFunction(Function &F) { releaseMemory(); -} -void RegionInfo::Calculate(Function &F) { - // ShortCut a function where for every BB the exit of the largest region - // starting with BB is stored. These regions can be threated as single BBS. - // This improves performance on linear CFGs. 
- BBtoBBMap ShortCut; + auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto PDT = &getAnalysis<PostDominatorTree>(); + auto DF = &getAnalysis<DominanceFrontier>(); - scanForRegions(F, &ShortCut); - BasicBlock *BB = &F.getEntryBlock(); - buildRegionsTree(DT->getNode(BB), TopLevelRegion); + RI.recalculate(F, DT, PDT, DF); + return false; } -bool RegionInfo::runOnFunction(Function &F) { - releaseMemory(); - - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - PDT = &getAnalysis<PostDominatorTree>(); - DF = &getAnalysis<DominanceFrontier>(); - - TopLevelRegion = new Region(&F.getEntryBlock(), nullptr, this, DT, nullptr); - updateStatistics(TopLevelRegion); - - Calculate(F); +void RegionInfoPass::releaseMemory() { + RI.releaseMemory(); +} - return false; +void RegionInfoPass::verifyAnalysis() const { + RI.verifyAnalysis(); } -void RegionInfo::getAnalysisUsage(AnalysisUsage &AU) const { +void RegionInfoPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequiredTransitive<DominatorTreeWrapperPass>(); AU.addRequired<PostDominatorTree>(); AU.addRequired<DominanceFrontier>(); } -void RegionInfo::print(raw_ostream &OS, const Module *) const { - OS << "Region tree:\n"; - TopLevelRegion->print(OS, true, 0, printStyle.getValue()); - OS << "End region tree\n"; -} - -void RegionInfo::verifyAnalysis() const { - // Only do verification when user wants to, otherwise this expensive check - // will be invoked by PMDataManager::verifyPreservedAnalysis when - // a regionpass (marked PreservedAll) finish. - if (!VerifyRegionInfo) return; - - TopLevelRegion->verifyRegionNest(); -} - -// Region pass manager support. -Region *RegionInfo::getRegionFor(BasicBlock *BB) const { - BBtoRegionMap::const_iterator I= - BBtoRegion.find(BB); - return I != BBtoRegion.end() ? I->second : nullptr; -} - -void RegionInfo::setRegionFor(BasicBlock *BB, Region *R) { - BBtoRegion[BB] = R; -} - -Region *RegionInfo::operator[](BasicBlock *BB) const { - return getRegionFor(BB); +void RegionInfoPass::print(raw_ostream &OS, const Module *) const { + RI.print(OS); } -BasicBlock *RegionInfo::getMaxRegionExit(BasicBlock *BB) const { - BasicBlock *Exit = nullptr; - - while (true) { - // Get largest region that starts at BB. - Region *R = getRegionFor(BB); - while (R && R->getParent() && R->getParent()->getEntry() == BB) - R = R->getParent(); - - // Get the single exit of BB. - if (R && R->getEntry() == BB) - Exit = R->getExit(); - else if (++succ_begin(BB) == succ_end(BB)) - Exit = *succ_begin(BB); - else // No single exit exists. - return Exit; - - // Get largest region that starts at Exit. - Region *ExitR = getRegionFor(Exit); - while (ExitR && ExitR->getParent() - && ExitR->getParent()->getEntry() == Exit) - ExitR = ExitR->getParent(); - - for (pred_iterator PI = pred_begin(Exit), PE = pred_end(Exit); PI != PE; - ++PI) - if (!R->contains(*PI) && !ExitR->contains(*PI)) - break; - - // This stops infinite cycles. 
- if (DT->dominates(Exit, BB)) - break; - - BB = Exit; - } - - return Exit; -} - -Region* -RegionInfo::getCommonRegion(Region *A, Region *B) const { - assert (A && B && "One of the Regions is NULL"); - - if (A->contains(B)) return A; - - while (!B->contains(A)) - B = B->getParent(); - - return B; -} - -Region* -RegionInfo::getCommonRegion(SmallVectorImpl<Region*> &Regions) const { - Region* ret = Regions.back(); - Regions.pop_back(); - - for (SmallVectorImpl<Region*>::const_iterator I = Regions.begin(), - E = Regions.end(); I != E; ++I) - ret = getCommonRegion(ret, *I); - - return ret; -} - -Region* -RegionInfo::getCommonRegion(SmallVectorImpl<BasicBlock*> &BBs) const { - Region* ret = getRegionFor(BBs.back()); - BBs.pop_back(); - - for (SmallVectorImpl<BasicBlock*>::const_iterator I = BBs.begin(), - E = BBs.end(); I != E; ++I) - ret = getCommonRegion(ret, getRegionFor(*I)); - - return ret; +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void RegionInfoPass::dump() const { + RI.dump(); } +#endif -void RegionInfo::splitBlock(BasicBlock* NewBB, BasicBlock *OldBB) -{ - Region *R = getRegionFor(OldBB); - - setRegionFor(NewBB, R); - - while (R->getEntry() == OldBB && !R->isTopLevelRegion()) { - R->replaceEntry(NewBB); - R = R->getParent(); - } - - setRegionFor(OldBB, R); -} +char RegionInfoPass::ID = 0; -char RegionInfo::ID = 0; -INITIALIZE_PASS_BEGIN(RegionInfo, "regions", +INITIALIZE_PASS_BEGIN(RegionInfoPass, "regions", "Detect single entry single exit regions", true, true) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(PostDominatorTree) INITIALIZE_PASS_DEPENDENCY(DominanceFrontier) -INITIALIZE_PASS_END(RegionInfo, "regions", +INITIALIZE_PASS_END(RegionInfoPass, "regions", "Detect single entry single exit regions", true, true) // Create methods available outside of this file, to use them @@ -863,7 +167,7 @@ INITIALIZE_PASS_END(RegionInfo, "regions", namespace llvm { FunctionPass *createRegionInfoPass() { - return new RegionInfo(); + return new RegionInfoPass(); } } diff --git a/lib/Analysis/RegionPass.cpp b/lib/Analysis/RegionPass.cpp index 71de144..de34b72 100644 --- a/lib/Analysis/RegionPass.cpp +++ b/lib/Analysis/RegionPass.cpp @@ -45,14 +45,14 @@ static void addRegionIntoQueue(Region &R, std::deque<Region *> &RQ) { /// Pass Manager itself does not invalidate any analysis info. void RGPassManager::getAnalysisUsage(AnalysisUsage &Info) const { - Info.addRequired<RegionInfo>(); + Info.addRequired<RegionInfoPass>(); Info.setPreservesAll(); } /// run - Execute all of the passes scheduled for execution. Keep track of /// whether any of the passes modifies the function, and if so, return true. bool RGPassManager::runOnFunction(Function &F) { - RI = &getAnalysis<RegionInfo>(); + RI = &getAnalysis<RegionInfoPass>().getRegionInfo(); bool Changed = false; // Collect inherited analysis from Module level pass manager. 
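With RegionInfo now a plain analysis object wrapped by RegionInfoPass, client passes reach it through the wrapper, exactly as RGPassManager does above. A minimal sketch of such a client under the new API (MyRegionCounter and its body are invented for illustration; the usual FunctionPass and RegionInfo headers are assumed):

  struct MyRegionCounter : public FunctionPass {
    static char ID;
    MyRegionCounter() : FunctionPass(ID) {}

    bool runOnFunction(Function &F) override {
      // Fetch the analysis result through the pass wrapper.
      RegionInfo &RI = getAnalysis<RegionInfoPass>().getRegionInfo();
      Region *R = RI.getRegionFor(&F.getEntryBlock());
      (void)R; // analysis-only: nothing is modified
      return false;
    }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      AU.setPreservesAll();
      AU.addRequired<RegionInfoPass>();
    }
  };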
diff --git a/lib/Analysis/RegionPrinter.cpp b/lib/Analysis/RegionPrinter.cpp index 893210a..ad83113 100644 --- a/lib/Analysis/RegionPrinter.cpp +++ b/lib/Analysis/RegionPrinter.cpp @@ -56,23 +56,24 @@ struct DOTGraphTraits<RegionNode*> : public DefaultDOTGraphTraits { }; template<> -struct DOTGraphTraits<RegionInfo*> : public DOTGraphTraits<RegionNode*> { +struct DOTGraphTraits<RegionInfoPass*> : public DOTGraphTraits<RegionNode*> { - DOTGraphTraits (bool isSimple=false) + DOTGraphTraits (bool isSimple = false) : DOTGraphTraits<RegionNode*>(isSimple) {} - static std::string getGraphName(RegionInfo *DT) { + static std::string getGraphName(RegionInfoPass *DT) { return "Region Graph"; } - std::string getNodeLabel(RegionNode *Node, RegionInfo *G) { + std::string getNodeLabel(RegionNode *Node, RegionInfoPass *G) { + RegionInfo &RI = G->getRegionInfo(); return DOTGraphTraits<RegionNode*>::getNodeLabel(Node, - G->getTopLevelRegion()); + reinterpret_cast<RegionNode*>(RI.getTopLevelRegion())); } std::string getEdgeAttributes(RegionNode *srcNode, - GraphTraits<RegionInfo*>::ChildIteratorType CI, RegionInfo *RI) { - + GraphTraits<RegionInfo*>::ChildIteratorType CI, RegionInfoPass *G) { + RegionInfo &RI = G->getRegionInfo(); RegionNode *destNode = *CI; if (srcNode->isSubRegion() || destNode->isSubRegion()) @@ -82,7 +83,7 @@ struct DOTGraphTraits<RegionInfo*> : public DOTGraphTraits<RegionNode*> { BasicBlock *srcBB = srcNode->getNodeAs<BasicBlock>(); BasicBlock *destBB = destNode->getNodeAs<BasicBlock>(); - Region *R = RI->getRegionFor(destBB); + Region *R = RI.getRegionFor(destBB); while (R && R->getParent()) if (R->getParent()->getEntry() == destBB) @@ -98,7 +99,8 @@ struct DOTGraphTraits<RegionInfo*> : public DOTGraphTraits<RegionNode*> { // Print the cluster of the subregions. This groups the single basic blocks // and adds a different background color for each group. 
- static void printRegionCluster(const Region &R, GraphWriter<RegionInfo*> &GW, + static void printRegionCluster(const Region &R, + GraphWriter<RegionInfoPass*> &GW, unsigned depth = 0) { raw_ostream &O = GW.getOStream(); O.indent(2 * depth) << "subgraph cluster_" << static_cast<const void*>(&R) @@ -119,22 +121,23 @@ struct DOTGraphTraits<RegionInfo*> : public DOTGraphTraits<RegionNode*> { for (Region::const_iterator RI = R.begin(), RE = R.end(); RI != RE; ++RI) printRegionCluster(**RI, GW, depth + 1); - RegionInfo *RI = R.getRegionInfo(); + const RegionInfo &RI = *static_cast<const RegionInfo*>(R.getRegionInfo()); for (const auto &BB : R.blocks()) - if (RI->getRegionFor(BB) == &R) + if (RI.getRegionFor(BB) == &R) O.indent(2 * (depth + 1)) << "Node" - << static_cast<const void*>(RI->getTopLevelRegion()->getBBNode(BB)) + << static_cast<const void*>(RI.getTopLevelRegion()->getBBNode(BB)) << ";\n"; O.indent(2 * depth) << "}\n"; } - static void addCustomGraphFeatures(const RegionInfo* RI, - GraphWriter<RegionInfo*> &GW) { + static void addCustomGraphFeatures(const RegionInfoPass* RIP, + GraphWriter<RegionInfoPass*> &GW) { + const RegionInfo &RI = RIP->getRegionInfo(); raw_ostream &O = GW.getOStream(); O << "\tcolorscheme = \"paired12\"\n"; - printRegionCluster(*RI->getTopLevelRegion(), GW, 4); + printRegionCluster(*RI.getTopLevelRegion(), GW, 4); } }; } //end namespace llvm @@ -142,28 +145,28 @@ struct DOTGraphTraits<RegionInfo*> : public DOTGraphTraits<RegionNode*> { namespace { struct RegionViewer - : public DOTGraphTraitsViewer<RegionInfo, false> { + : public DOTGraphTraitsViewer<RegionInfoPass, false> { static char ID; - RegionViewer() : DOTGraphTraitsViewer<RegionInfo, false>("reg", ID){ + RegionViewer() : DOTGraphTraitsViewer<RegionInfoPass, false>("reg", ID){ initializeRegionViewerPass(*PassRegistry::getPassRegistry()); } }; char RegionViewer::ID = 0; struct RegionOnlyViewer - : public DOTGraphTraitsViewer<RegionInfo, true> { + : public DOTGraphTraitsViewer<RegionInfoPass, true> { static char ID; - RegionOnlyViewer() : DOTGraphTraitsViewer<RegionInfo, true>("regonly", ID) { + RegionOnlyViewer() : DOTGraphTraitsViewer<RegionInfoPass, true>("regonly", ID) { initializeRegionOnlyViewerPass(*PassRegistry::getPassRegistry()); } }; char RegionOnlyViewer::ID = 0; struct RegionPrinter - : public DOTGraphTraitsPrinter<RegionInfo, false> { + : public DOTGraphTraitsPrinter<RegionInfoPass, false> { static char ID; RegionPrinter() : - DOTGraphTraitsPrinter<RegionInfo, false>("reg", ID) { + DOTGraphTraitsPrinter<RegionInfoPass, false>("reg", ID) { initializeRegionPrinterPass(*PassRegistry::getPassRegistry()); } }; @@ -175,7 +178,7 @@ INITIALIZE_PASS(RegionPrinter, "dot-regions", INITIALIZE_PASS(RegionViewer, "view-regions", "View regions of function", true, true) - + INITIALIZE_PASS(RegionOnlyViewer, "view-regions-only", "View regions of function (with no function bodies)", true, true) @@ -183,10 +186,10 @@ INITIALIZE_PASS(RegionOnlyViewer, "view-regions-only", namespace { struct RegionOnlyPrinter - : public DOTGraphTraitsPrinter<RegionInfo, true> { + : public DOTGraphTraitsPrinter<RegionInfoPass, true> { static char ID; RegionOnlyPrinter() : - DOTGraphTraitsPrinter<RegionInfo, true>("reg", ID) { + DOTGraphTraitsPrinter<RegionInfoPass, true>("reg", ID) { initializeRegionOnlyPrinterPass(*PassRegistry::getPassRegistry()); } }; diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 06dbde5..68549ef 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ 
b/lib/Analysis/ScalarEvolution.cpp @@ -1,4 +1,4 @@ -//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===// +//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===// // // The LLVM Compiler Infrastructure // @@ -59,9 +59,11 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -78,6 +80,7 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -113,6 +116,7 @@ VerifySCEV("verify-scev", INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution", "Scalar Evolution Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(LoopInfo) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) @@ -671,7 +675,321 @@ static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops, } } +static const APInt srem(const SCEVConstant *C1, const SCEVConstant *C2) { + APInt A = C1->getValue()->getValue(); + APInt B = C2->getValue()->getValue(); + uint32_t ABW = A.getBitWidth(); + uint32_t BBW = B.getBitWidth(); + + if (ABW > BBW) + B = B.sext(ABW); + else if (ABW < BBW) + A = A.sext(BBW); + + return APIntOps::srem(A, B); +} + +static const APInt sdiv(const SCEVConstant *C1, const SCEVConstant *C2) { + APInt A = C1->getValue()->getValue(); + APInt B = C2->getValue()->getValue(); + uint32_t ABW = A.getBitWidth(); + uint32_t BBW = B.getBitWidth(); + + if (ABW > BBW) + B = B.sext(ABW); + else if (ABW < BBW) + A = A.sext(BBW); + + return APIntOps::sdiv(A, B); +} + +static const APInt urem(const SCEVConstant *C1, const SCEVConstant *C2) { + APInt A = C1->getValue()->getValue(); + APInt B = C2->getValue()->getValue(); + uint32_t ABW = A.getBitWidth(); + uint32_t BBW = B.getBitWidth(); + + if (ABW > BBW) + B = B.zext(ABW); + else if (ABW < BBW) + A = A.zext(BBW); + + return APIntOps::urem(A, B); +} + +static const APInt udiv(const SCEVConstant *C1, const SCEVConstant *C2) { + APInt A = C1->getValue()->getValue(); + APInt B = C2->getValue()->getValue(); + uint32_t ABW = A.getBitWidth(); + uint32_t BBW = B.getBitWidth(); + + if (ABW > BBW) + B = B.zext(ABW); + else if (ABW < BBW) + A = A.zext(BBW); + + return APIntOps::udiv(A, B); +} + +namespace { +struct FindSCEVSize { + int Size; + FindSCEVSize() : Size(0) {} + + bool follow(const SCEV *S) { + ++Size; + // Keep looking at all operands of S. + return true; + } + bool isDone() const { + return false; + } +}; +} + +// Returns the size of the SCEV S. +static inline int sizeOfSCEV(const SCEV *S) { + FindSCEVSize F; + SCEVTraversal<FindSCEVSize> ST(F); + ST.visitAll(S); + return F.Size; +} +namespace { + +template <typename Derived> +struct SCEVDivision : public SCEVVisitor<Derived, void> { +public: + // Computes the Quotient and Remainder of the division of Numerator by + // Denominator. 
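
// For example (illustrative only, not from the patch): dividing the affine
// SCEV {0,+,4}<L> by the constant 4 through the SCEVSDivision subclass
// defined further down yields Quotient = {0,+,1}<L> and Remainder = 0,
// because the start and the step are divided recursively and both divide
// evenly:
//
//   const SCEV *Q, *R;
//   SCEVSDivision::divide(SE, AddRec, Four, &Q, &R);
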
+ static void divide(ScalarEvolution &SE, const SCEV *Numerator, + const SCEV *Denominator, const SCEV **Quotient, + const SCEV **Remainder) { + assert(Numerator && Denominator && "Uninitialized SCEV"); + + Derived D(SE, Numerator, Denominator); + + // Check for the trivial case here to avoid having to check for it in the + // rest of the code. + if (Numerator == Denominator) { + *Quotient = D.One; + *Remainder = D.Zero; + return; + } + + if (Numerator->isZero()) { + *Quotient = D.Zero; + *Remainder = D.Zero; + return; + } + + // Split the Denominator when it is a product. + if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) { + const SCEV *Q, *R; + *Quotient = Numerator; + for (const SCEV *Op : T->operands()) { + divide(SE, *Quotient, Op, &Q, &R); + *Quotient = Q; + + // Bail out when the Numerator is not divisible by one of the terms of + // the Denominator. + if (!R->isZero()) { + *Quotient = D.Zero; + *Remainder = Numerator; + return; + } + } + *Remainder = D.Zero; + return; + } + + D.visit(Numerator); + *Quotient = D.Quotient; + *Remainder = D.Remainder; + } + + // Except in the trivial case described above, we do not know how to divide + // Expr by Denominator for the following functions with empty implementation. + void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} + void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} + void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} + void visitUDivExpr(const SCEVUDivExpr *Numerator) {} + void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} + void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} + void visitUnknown(const SCEVUnknown *Numerator) {} + void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} + + void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { + const SCEV *StartQ, *StartR, *StepQ, *StepR; + assert(Numerator->isAffine() && "Numerator should be affine"); + divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); + divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); + Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), + Numerator->getNoWrapFlags()); + Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), + Numerator->getNoWrapFlags()); + } + + void visitAddExpr(const SCEVAddExpr *Numerator) { + SmallVector<const SCEV *, 2> Qs, Rs; + Type *Ty = Denominator->getType(); + + for (const SCEV *Op : Numerator->operands()) { + const SCEV *Q, *R; + divide(SE, Op, Denominator, &Q, &R); + + // Bail out if types do not match. + if (Ty != Q->getType() || Ty != R->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } + + Qs.push_back(Q); + Rs.push_back(R); + } + + if (Qs.size() == 1) { + Quotient = Qs[0]; + Remainder = Rs[0]; + return; + } + + Quotient = SE.getAddExpr(Qs); + Remainder = SE.getAddExpr(Rs); + } + + void visitMulExpr(const SCEVMulExpr *Numerator) { + SmallVector<const SCEV *, 2> Qs; + Type *Ty = Denominator->getType(); + + bool FoundDenominatorTerm = false; + for (const SCEV *Op : Numerator->operands()) { + // Bail out if types do not match. + if (Ty != Op->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } + + if (FoundDenominatorTerm) { + Qs.push_back(Op); + continue; + } + + // Check whether Denominator divides one of the product operands. + const SCEV *Q, *R; + divide(SE, Op, Denominator, &Q, &R); + if (!R->isZero()) { + Qs.push_back(Op); + continue; + } + + // Bail out if types do not match. 
+      if (Ty != Q->getType()) {
+        Quotient = Zero;
+        Remainder = Numerator;
+        return;
+      }
+
+      FoundDenominatorTerm = true;
+      Qs.push_back(Q);
+    }
+
+    if (FoundDenominatorTerm) {
+      Remainder = Zero;
+      if (Qs.size() == 1)
+        Quotient = Qs[0];
+      else
+        Quotient = SE.getMulExpr(Qs);
+      return;
+    }
+
+    if (!isa<SCEVUnknown>(Denominator)) {
+      Quotient = Zero;
+      Remainder = Numerator;
+      return;
+    }
+
+    // The Remainder is obtained by replacing Denominator by 0 in Numerator.
+    ValueToValueMap RewriteMap;
+    RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
+        cast<SCEVConstant>(Zero)->getValue();
+    Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
+
+    if (Remainder->isZero()) {
+      // The Quotient is obtained by replacing Denominator by 1 in Numerator.
+      RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
+          cast<SCEVConstant>(One)->getValue();
+      Quotient =
+          SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
+      return;
+    }
+
+    // Quotient is (Numerator - Remainder) divided by Denominator.
+    const SCEV *Q, *R;
+    const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
+    if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) {
+      // This SCEV does not seem to simplify: fail the division here.
+      Quotient = Zero;
+      Remainder = Numerator;
+      return;
+    }
+    divide(SE, Diff, Denominator, &Q, &R);
+    assert(R == Zero &&
+           "Denominator should evenly divide (Numerator - Remainder)");
+    Quotient = Q;
+  }
+
+private:
+  SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
+               const SCEV *Denominator)
+      : SE(S), Denominator(Denominator) {
+    Zero = SE.getConstant(Denominator->getType(), 0);
+    One = SE.getConstant(Denominator->getType(), 1);
+
+    // By default, we don't know how to divide Expr by Denominator.
+    // Providing the default here simplifies the rest of the code.
+    Quotient = Zero;
+    Remainder = Numerator;
+  }
+
+  ScalarEvolution &SE;
+  const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
+
+  friend struct SCEVSDivision;
+  friend struct SCEVUDivision;
+};
+
+struct SCEVSDivision : public SCEVDivision<SCEVSDivision> {
+  SCEVSDivision(ScalarEvolution &S, const SCEV *Numerator,
+                const SCEV *Denominator)
+      : SCEVDivision(S, Numerator, Denominator) {}
+
+  void visitConstant(const SCEVConstant *Numerator) {
+    if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
+      Quotient = SE.getConstant(sdiv(Numerator, D));
+      Remainder = SE.getConstant(srem(Numerator, D));
+      return;
+    }
+  }
+};
+
+struct SCEVUDivision : public SCEVDivision<SCEVUDivision> {
+  SCEVUDivision(ScalarEvolution &S, const SCEV *Numerator,
+                const SCEV *Denominator)
+      : SCEVDivision(S, Numerator, Denominator) {}
+
+  void visitConstant(const SCEVConstant *Numerator) {
+    if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
+      Quotient = SE.getConstant(udiv(Numerator, D));
+      Remainder = SE.getConstant(urem(Numerator, D));
+      return;
+    }
+  }
+};
+
+}

 //===----------------------------------------------------------------------===//
 // Simple SCEV method implementations
@@ -2061,71 +2379,66 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
   // Okay, if there weren't any loop invariants to be folded, check to see if
   // there are multiple AddRec's with the same loop induction variable being
   // multiplied together. If so, we can fold them.
+
+  // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
+  // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
+  //        choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
+  //   ]]],+,...up to x=2n}.
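
// A worked instance of the formula above (added for illustration): two
// affine recurrences {A,+,B}<L> and {C,+,D}<L> take the values A+B*n and
// C+D*n at iteration n, so their product is A*C + (A*D+B*C)*n + B*D*n^2,
// which rebuilds as the quadratic recurrence
//
//   {A*C,+,A*D+B*C+B*D,+,2*B*D}<L>
//
// since the first difference at iteration n is (A*D+B*C) + B*D*(2*n+1),
// giving A*D+B*C+B*D at n = 0, and the second difference is the constant
// 2*B*D.
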
+ // Note that the arguments to choose() are always integers with values + // known at compile time, never SCEV objects. + // + // The implementation avoids pointless extra computations when the two + // addrec's are of different length (mathematically, it's equivalent to + // an infinite stream of zeros on the right). + bool OpsModified = false; for (unsigned OtherIdx = Idx+1; - OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); + OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); ++OtherIdx) { - if (AddRecLoop != cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) + const SCEVAddRecExpr *OtherAddRec = + dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]); + if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) continue; - // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L> - // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ - // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z - // ]]],+,...up to x=2n}. - // Note that the arguments to choose() are always integers with values - // known at compile time, never SCEV objects. - // - // The implementation avoids pointless extra computations when the two - // addrec's are of different length (mathematically, it's equivalent to - // an infinite stream of zeros on the right). - bool OpsModified = false; - for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); - ++OtherIdx) { - const SCEVAddRecExpr *OtherAddRec = - dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]); - if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) - continue; - - bool Overflow = false; - Type *Ty = AddRec->getType(); - bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; - SmallVector<const SCEV*, 7> AddRecOps; - for (int x = 0, xe = AddRec->getNumOperands() + - OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { - const SCEV *Term = getConstant(Ty, 0); - for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { - uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); - for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), - ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); - z < ze && !Overflow; ++z) { - uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); - uint64_t Coeff; - if (LargerThan64Bits) - Coeff = umul_ov(Coeff1, Coeff2, Overflow); - else - Coeff = Coeff1*Coeff2; - const SCEV *CoeffTerm = getConstant(Ty, Coeff); - const SCEV *Term1 = AddRec->getOperand(y-z); - const SCEV *Term2 = OtherAddRec->getOperand(z); - Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2)); - } + bool Overflow = false; + Type *Ty = AddRec->getType(); + bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; + SmallVector<const SCEV*, 7> AddRecOps; + for (int x = 0, xe = AddRec->getNumOperands() + + OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { + const SCEV *Term = getConstant(Ty, 0); + for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { + uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); + for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), + ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); + z < ze && !Overflow; ++z) { + uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); + uint64_t Coeff; + if (LargerThan64Bits) + Coeff = umul_ov(Coeff1, Coeff2, Overflow); + else + Coeff = Coeff1*Coeff2; + const SCEV *CoeffTerm = getConstant(Ty, Coeff); + const SCEV *Term1 = AddRec->getOperand(y-z); + const SCEV *Term2 = OtherAddRec->getOperand(z); + Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2)); } - AddRecOps.push_back(Term); - } - if (!Overflow) { - const SCEV *NewAddRec = getAddRecExpr(AddRecOps, 
AddRec->getLoop(), - SCEV::FlagAnyWrap); - if (Ops.size() == 2) return NewAddRec; - Ops[Idx] = NewAddRec; - Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; - OpsModified = true; - AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec); - if (!AddRec) - break; } + AddRecOps.push_back(Term); + } + if (!Overflow) { + const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), + SCEV::FlagAnyWrap); + if (Ops.size() == 2) return NewAddRec; + Ops[Idx] = NewAddRec; + Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; + OpsModified = true; + AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec); + if (!AddRec) + break; } - if (OpsModified) - return getMulExpr(Ops); } + if (OpsModified) + return getMulExpr(Ops); // Otherwise couldn't fold anything into this recurrence. Move onto the // next one. @@ -3082,7 +3395,8 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { Visited.insert(PN); while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast<Value *>(I)); @@ -3263,7 +3577,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. - if (Value *V = SimplifyInstruction(PN, DL, TLI, DT)) + if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AT)) if (LI->replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); @@ -3395,7 +3709,7 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. unsigned BitWidth = getTypeSizeInBits(U->getType()); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones); + computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AT, nullptr, DT); return Zeros.countTrailingOnes(); } @@ -3403,6 +3717,31 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { return 0; } +/// GetRangeFromMetadata - Helper method to assign a range to V from +/// metadata present in the IR. +static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { + if (Instruction *I = dyn_cast<Instruction>(V)) { + if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) { + ConstantRange TotalRange( + cast<IntegerType>(I->getType())->getBitWidth(), false); + + unsigned NumRanges = MD->getNumOperands() / 2; + assert(NumRanges >= 1); + + for (unsigned i = 0; i < NumRanges; ++i) { + ConstantInt *Lower = cast<ConstantInt>(MD->getOperand(2*i + 0)); + ConstantInt *Upper = cast<ConstantInt>(MD->getOperand(2*i + 1)); + ConstantRange Range(Lower->getValue(), Upper->getValue()); + TotalRange = TotalRange.unionWith(Range); + } + + return TotalRange; + } + } + + return None; +} + /// getUnsignedRange - Determine the unsigned range for a particular SCEV. /// ConstantRange @@ -3532,9 +3871,14 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) { } if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { + // Check if the IR explicitly contains !range metadata. + Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue()); + if (MDRange.hasValue()) + ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue()); + // For a SCEVUnknown, ask ValueTracking. 
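
// As an example of the metadata path above (hypothetical IR, not from the
// patch): a value annotated with !range !{i8 0, i8 2, i8 64, i8 66}
// describes the set [0,2) and [64,66); GetRangeFromMetadata unions the
// pairs into the smallest covering ConstantRange, here [0,66), and the
// known-bits facts computed below are intersected with it, so both sources
// of information sharpen the final unsigned range.
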
APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, DL); + computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AT, nullptr, DT); if (Ones == ~Zeros + 1) return setUnsignedRange(U, ConservativeResult); return setUnsignedRange(U, @@ -3683,10 +4027,15 @@ ScalarEvolution::getSignedRange(const SCEV *S) { } if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { + // Check if the IR explicitly contains !range metadata. + Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue()); + if (MDRange.hasValue()) + ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue()); + // For a SCEVUnknown, ask ValueTracking. if (!U->getValue()->getType()->isIntegerTy() && !DL) return setSignedRange(U, ConservativeResult); - unsigned NS = ComputeNumSignBits(U->getValue(), DL); + unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AT, nullptr, DT); if (NS <= 1) return setSignedRange(U, ConservativeResult); return setSignedRange(U, ConservativeResult.intersectWith( @@ -3793,7 +4142,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { unsigned TZ = A.countTrailingZeros(); unsigned BitWidth = A.getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, + 0, AT, nullptr, DT); APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); @@ -4070,6 +4420,14 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Iteration Count Computation Code // +unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { + if (BasicBlock *ExitingBB = L->getExitingBlock()) + return getSmallConstantTripCount(L, ExitingBB); + + // No trip count information for multiple exits. + return 0; +} + /// getSmallConstantTripCount - Returns the maximum trip count of this loop as a /// normal unsigned value. Returns 0 if the trip count is unknown or not /// constant. Will also return 0 if the maximum trip count is very large (>= @@ -4080,19 +4438,13 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { /// before taking the branch. For loops with multiple exits, it may not be the /// number times that the loop header executes because the loop may exit /// prematurely via another branch. -/// -/// FIXME: We conservatively call getBackedgeTakenCount(L) instead of -/// getExitCount(L, ExitingBlock) to compute a safe trip count considering all -/// loop exits. getExitCount() may return an exact count for this branch -/// assuming no-signed-wrap. The number of well-defined iterations may actually -/// be higher than this trip count if this exit test is skipped and the loop -/// exits via a different branch. Ideally, getExitCount() would know whether it -/// depends on a NSW assumption, and we would only fall back to a conservative -/// trip count in that case. 
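
// With the overloads introduced in this change, a typical caller can stay
// exit-agnostic; a small sketch (SE and L assumed in scope):
//
//   if (unsigned TC = SE.getSmallConstantTripCount(L)) {
//     // TC counts executions of the unique exiting block's test; 0 means
//     // the count is unknown, very large, or the loop has multiple
//     // exiting blocks.
//   }
//
// Passing an explicit exiting block instead asserts that the block really
// exits the loop and returns a count specific to that exit.
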
-unsigned ScalarEvolution:: -getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) { +unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, + BasicBlock *ExitingBlock) { + assert(ExitingBlock && "Must pass a non-null exiting block!"); + assert(L->isLoopExiting(ExitingBlock) && + "Exiting block must actually branch out of the loop!"); const SCEVConstant *ExitCount = - dyn_cast<SCEVConstant>(getBackedgeTakenCount(L)); + dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock)); if (!ExitCount) return 0; @@ -4106,6 +4458,14 @@ getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) { return ((unsigned)ExitConst->getZExtValue()) + 1; } +unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { + if (BasicBlock *ExitingBB = L->getExitingBlock()) + return getSmallConstantTripMultiple(L, ExitingBB); + + // No trip multiple information for multiple exits. + return 0; +} + /// getSmallConstantTripMultiple - Returns the largest constant divisor of the /// trip count of this loop as a normal unsigned value, if possible. This /// means that the actual trip count is always a multiple of the returned @@ -4118,9 +4478,13 @@ getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) { /// /// As explained in the comments for getSmallConstantTripCount, this assumes /// that control exits the loop via ExitingBlock. -unsigned ScalarEvolution:: -getSmallConstantTripMultiple(Loop *L, BasicBlock * /*ExitingBlock*/) { - const SCEV *ExitCount = getBackedgeTakenCount(L); +unsigned +ScalarEvolution::getSmallConstantTripMultiple(Loop *L, + BasicBlock *ExitingBlock) { + assert(ExitingBlock && "Must pass a non-null exiting block!"); + assert(L->isLoopExiting(ExitingBlock) && + "Exiting block must actually branch out of the loop!"); + const SCEV *ExitCount = getExitCount(L, ExitingBlock); if (ExitCount == getCouldNotCompute()) return 1; @@ -4230,7 +4594,8 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { SmallPtrSet<Instruction *, 8> Visited; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast<Value *>(I)); @@ -4282,7 +4647,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) { SmallPtrSet<Instruction *, 8> Visited; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast<Value *>(I)); @@ -4316,7 +4682,8 @@ void ScalarEvolution::forgetValue(Value *V) { SmallPtrSet<Instruction *, 8> Visited; while (!Worklist.empty()) { I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast<Value *>(I)); @@ -4467,20 +4834,12 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { // non-exiting iterations. Partition the loop exits into two kinds: // LoopMustExits and LoopMayExits. // - // A LoopMustExit meets two requirements: - // - // (a) Its ExitLimit.MustExit flag must be set which indicates that the exit - // test condition cannot be skipped (the tested variable has unit stride or - // the test is less-than or greater-than, rather than a strict inequality). - // - // (b) It must dominate the loop latch, hence must be tested on every loop - // iteration. 
- // - // If any computable LoopMustExit is found, then MaxBECount is the minimum - // EL.Max of computable LoopMustExits. Otherwise, MaxBECount is - // conservatively the maximum EL.Max, where CouldNotCompute is considered - // greater than any computable EL.Max. - if (EL.MustExit && EL.Max != getCouldNotCompute() && Latch && + // If the exit dominates the loop latch, it is a LoopMustExit otherwise it + // is a LoopMayExit. If any computable LoopMustExit is found, then + // MaxBECount is the minimum EL.Max of computable LoopMustExits. Otherwise, + // MaxBECount is conservatively the maximum EL.Max, where CouldNotCompute is + // considered greater than any computable EL.Max. + if (EL.Max != getCouldNotCompute() && Latch && DT->dominates(ExitBB, Latch)) { if (!MustExitMaxBECount) MustExitMaxBECount = EL.Max; @@ -4567,18 +4926,19 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { return getCouldNotCompute(); } + bool IsOnlyExit = (L->getExitingBlock() != nullptr); TerminatorInst *Term = ExitingBlock->getTerminator(); if (BranchInst *BI = dyn_cast<BranchInst>(Term)) { assert(BI->isConditional() && "If unconditional, it can't be in loop!"); // Proceed to the next level to examine the exit condition expression. return ComputeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1), - /*IsSubExpr=*/false); + /*ControlsExit=*/IsOnlyExit); } if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) return ComputeExitLimitFromSingleExitSwitch(L, SI, Exit, - /*IsSubExpr=*/false); + /*ControlsExit=*/IsOnlyExit); return getCouldNotCompute(); } @@ -4587,28 +4947,27 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { /// backedge of the specified loop will execute if its exit condition /// were a conditional branch of ExitCond, TBB, and FBB. /// -/// @param IsSubExpr is true if ExitCond does not directly control the exit -/// branch. In this case, we cannot assume that the loop only exits when the -/// condition is true and cannot infer that failing to meet the condition prior -/// to integer wraparound results in undefined behavior. +/// @param ControlsExit is true if ExitCond directly controls the exit +/// branch. In this case, we can assume that the loop exits only if the +/// condition is true and can infer that failing to meet the condition prior to +/// integer wraparound results in undefined behavior. ScalarEvolution::ExitLimit ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool IsSubExpr) { + bool ControlsExit) { // Check if the controlling expression for this loop is an And or Or. if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) { if (BO->getOpcode() == Instruction::And) { // Recurse on the operands of the and. bool EitherMayExit = L->contains(TBB); ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); - bool MustExit = false; if (EitherMayExit) { // Both conditions must be true for the loop to continue executing. // Choose the less conservative count. 
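Two things are worth making explicit about the And/Or recursion above. First, each operand is recursed on with ControlsExit && !EitherMayExit: once either sub-condition can terminate the loop by itself, neither sub-condition alone controls the exit, so no-wrap reasoning must be dropped for the sub-queries. Second, when either operand may exit, the loop leaves as soon as the first operand fires, so upper bounds combine with an unsigned minimum in which CouldNotCompute behaves like positive infinity. A toy model of that combination (kUnknown is a hypothetical stand-in for SCEVCouldNotCompute):

#include <algorithm>
#include <cstdint>

const uint64_t kUnknown = UINT64_MAX; // models SCEVCouldNotCompute

// Combine the max backedge-taken counts of the operands of `while (A && B)`
// when either operand alone may terminate the loop: whichever condition
// fires first ends the loop, so an upper bound from either side applies.
uint64_t combineMaxCounts(uint64_t Max0, uint64_t Max1) {
  if (Max0 == kUnknown)
    return Max1; // unknown acts as +infinity under min()
  if (Max1 == kUnknown)
    return Max0;
  return std::min(Max0, Max1);
}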
@@ -4623,7 +4982,6 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; else MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); - MustExit = EL0.MustExit || EL1.MustExit; } else { // Both conditions must be true at the same time for the loop to exit. // For now, be conservative. @@ -4632,21 +4990,19 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; if (EL0.Exact == EL1.Exact) BECount = EL0.Exact; - MustExit = EL0.MustExit && EL1.MustExit; } - return ExitLimit(BECount, MaxBECount, MustExit); + return ExitLimit(BECount, MaxBECount); } if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. bool EitherMayExit = L->contains(FBB); ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); - bool MustExit = false; if (EitherMayExit) { // Both conditions must be false for the loop to continue executing. // Choose the less conservative count. @@ -4661,7 +5017,6 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; else MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); - MustExit = EL0.MustExit || EL1.MustExit; } else { // Both conditions must be false at the same time for the loop to exit. // For now, be conservative. @@ -4670,17 +5025,16 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; if (EL0.Exact == EL1.Exact) BECount = EL0.Exact; - MustExit = EL0.MustExit && EL1.MustExit; } - return ExitLimit(BECount, MaxBECount, MustExit); + return ExitLimit(BECount, MaxBECount); } } // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) - return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, IsSubExpr); + return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); // Check for a constant condition. 
These are normally stripped out by // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to @@ -4707,7 +5061,7 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool IsSubExpr) { + bool ControlsExit) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Cond; @@ -4759,7 +5113,7 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, switch (Cond) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, IsSubExpr); + ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); if (EL.hasAnyInfo()) return EL; break; } @@ -4772,14 +5126,14 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = Cond == ICmpInst::ICMP_SLT; - ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, IsSubExpr); + ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: { // while (X > Y) bool IsSigned = Cond == ICmpInst::ICMP_SGT; - ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, IsSubExpr); + ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit); if (EL.hasAnyInfo()) return EL; break; } @@ -4801,7 +5155,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock, - bool IsSubExpr) { + bool ControlsExit) { assert(!L->contains(ExitingBlock) && "Not an exiting block!"); // Give up if the exit is the default dest of a switch. @@ -4814,7 +5168,7 @@ ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L, const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); // while (X != Y) --> while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, IsSubExpr); + ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); if (EL.hasAnyInfo()) return EL; @@ -5687,7 +6041,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { /// effectively V != 0. We know and take advantage of the fact that this /// expression only being used in a comparison by zero context. ScalarEvolution::ExitLimit -ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr) { +ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // If the value is a constant if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { // If the value is already zero, the branch will execute zero times. @@ -5781,37 +6135,30 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr) { else MaxBECount = getConstant(CountDown ? CR.getUnsignedMax() : -CR.getUnsignedMin()); - return ExitLimit(Distance, MaxBECount, /*MustExit=*/true); + return ExitLimit(Distance, MaxBECount); } - // If the recurrence is known not to wraparound, unsigned divide computes the - // back edge count. (Ideally we would have an "isexact" bit for udiv). We know - // that the value will either become zero (and thus the loop terminates), that - // the loop will terminate through some other exit condition first, or that - // the loop has undefined behavior. This means we can't "miss" the exit - // value, even with nonunit stride, and exit later via the same branch. Note - // that we can skip this exit if loop later exits via a different - // branch. 
Hence MustExit=false. - // - // This is only valid for expressions that directly compute the loop exit. It - // is invalid for subexpressions in which the loop may exit through this - // branch even if this subexpression is false. In that case, the trip count - // computed by this udiv could be smaller than the number of well-defined - // iterations. - if (!IsSubExpr && AddRec->getNoWrapFlags(SCEV::FlagNW)) { + // If the step exactly divides the distance then unsigned divide computes the + // backedge count. + const SCEV *Q, *R; + ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); + SCEVUDivision::divide(SE, Distance, Step, &Q, &R); + if (R->isZero()) { const SCEV *Exact = - getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - return ExitLimit(Exact, Exact, /*MustExit=*/false); + getUDivExactExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); + return ExitLimit(Exact, Exact); } - // If Step is a power of two that evenly divides Start we know that the loop - // will always terminate. Start may not be a constant so we just have the - // number of trailing zeros available. This is safe even in presence of - // overflow as the recurrence will overflow to exactly 0. - const APInt &StepV = StepC->getValue()->getValue(); - if (StepV.isPowerOf2() && - GetMinTrailingZeros(getNegativeSCEV(Start)) >= StepV.countTrailingZeros()) - return getUDivExactExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); + // If the condition controls loop exit (the loop exits only if the expression + // is true) and the addition is no-wrap we can use unsigned divide to + // compute the backedge count. In this case, the step may not divide the + // distance, but we don't care because if the condition is "missed" the loop + // will have undefined behavior due to wrapping. + if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) { + const SCEV *Exact = + getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); + return ExitLimit(Exact, Exact); + } // Then, try to solve the above equation provided that Start is constant. if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) @@ -6309,19 +6656,30 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return true; + if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true; + BasicBlock *Latch = L->getLoopLatch(); if (!Latch) return false; BranchInst *LoopContinuePredicate = dyn_cast<BranchInst>(Latch->getTerminator()); - if (!LoopContinuePredicate || - LoopContinuePredicate->isUnconditional()) - return false; + if (LoopContinuePredicate && LoopContinuePredicate->isConditional() && + isImpliedCond(Pred, LHS, RHS, + LoopContinuePredicate->getCondition(), + LoopContinuePredicate->getSuccessor(0) != L->getHeader())) + return true; + + // Check conditions due to any @llvm.assume intrinsics. + for (auto &CI : AT->assumptions(F)) { + if (!DT->dominates(CI, Latch->getTerminator())) + continue; + + if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) + return true; + } - return isImpliedCond(Pred, LHS, RHS, - LoopContinuePredicate->getCondition(), - LoopContinuePredicate->getSuccessor(0) != L->getHeader()); + return false; } /// isLoopEntryGuardedByCond - Test whether entry to the loop is protected @@ -6335,6 +6693,8 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). 
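The exact-division case above needs no wrap reasoning at all: for {Start,+,Step}, the distance to zero is -Start in modular arithmetic, and when Step divides that distance exactly, the quotient itself is the backedge-taken count. A scalar model in 64-bit unsigned arithmetic (the function name is hypothetical; the real code divides symbolic SCEVs with SCEVUDivision::divide):

#include <cstdint>
#include <optional>

// For IV(n) = Start + n*Step (mod 2^64), return the first n with
// IV(n) == 0 when Step exactly divides the distance -Start; otherwise
// report "not handled here" so a caller can try the remaining strategies.
std::optional<uint64_t> stepsToZero(uint64_t Start, uint64_t Step) {
  if (Step == 0)
    return std::nullopt;
  uint64_t Distance = 0 - Start; // intentional unsigned wraparound
  if (Distance % Step != 0)
    return std::nullopt;
  return Distance / Step; // exact udiv: no no-wrap assumption required
}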
if (!L) return false; + if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true; + // Starting at the loop predecessor, climb up the predecessor chain, as long // as there are predecessors that can be found that have unique successors // leading to the original header. @@ -6355,6 +6715,15 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, return true; } + // Check conditions due to any @llvm.assume intrinsics. + for (auto &CI : AT->assumptions(F)) { + if (!DT->dominates(CI, L->getHeader())) + continue; + + if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) + return true; + } + return false; } @@ -6469,6 +6838,66 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, RHS, LHS, FoundLHS, FoundRHS); } + // Check if we can make progress by sharpening ranges. + if (FoundPred == ICmpInst::ICMP_NE && + (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) { + + const SCEVConstant *C = nullptr; + const SCEV *V = nullptr; + + if (isa<SCEVConstant>(FoundLHS)) { + C = cast<SCEVConstant>(FoundLHS); + V = FoundRHS; + } else { + C = cast<SCEVConstant>(FoundRHS); + V = FoundLHS; + } + + // The guarding predicate tells us that C != V. If the known range + // of V is [C, t), we can sharpen the range to [C + 1, t). The + // range we consider has to correspond to the same signedness as the + // predicate we're interested in folding. + + APInt Min = ICmpInst::isSigned(Pred) ? + getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin(); + + if (Min == C->getValue()->getValue()) { + // Given (V >= Min && V != Min) we conclude V >= (Min + 1). + // This is true even if (Min + 1) wraps around -- in case of + // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)). + + APInt SharperMin = Min + 1; + + switch (Pred) { + case ICmpInst::ICMP_SGE: + case ICmpInst::ICMP_UGE: + // We know V `Pred` SharperMin. If this implies LHS `Pred` + // RHS, we're done. + if (isImpliedCondOperands(Pred, LHS, RHS, V, + getConstant(SharperMin))) + return true; + + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_UGT: + // We know from the range information that (V `Pred` Min || + // V == Min). We know from the guarding condition that !(V + // == Min). This gives us + // + // V `Pred` Min || V == Min && !(V == Min) + // => V `Pred` Min + // + // If V `Pred` Min implies LHS `Pred` RHS, we're done. + + if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min))) + return true; + + default: + // No change + break; + } + } + } + // Check whether the actual condition is beyond sufficient. if (FoundPred == ICmpInst::ICMP_EQ) if (ICmpInst::isTrueWhenEqual(Pred)) @@ -6614,13 +7043,13 @@ const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, /// specified less-than comparison will execute. If not computable, return /// CouldNotCompute. /// -/// @param IsSubExpr is true when the LHS < RHS condition does not directly -/// control the branch. In this case, we can only compute an iteration count for -/// a subexpression that cannot overflow before evaluating true. +/// @param ControlsExit is true when the LHS < RHS condition directly controls +/// the branch (loop exits only if condition is true). In this case, we can use +/// NoWrapFlags to skip overflow checks.
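The sharpening step above is ordinary interval reasoning: if range analysis proves V >= Min and the guarding condition proves V != Min, then V >= Min + 1, and the wraparound caveat in the comment is benign because when Min is the largest value the two premises cannot hold at once. A small unsigned-arithmetic check of the claim (hypothetical helper, uint64_t in place of APInt):

#include <cassert>
#include <cstdint>

// Bump a known unsigned lower bound Min of V to Min + 1, given a guard that
// proved V != Min. If Min == UINT64_MAX, the precondition V >= Min && V != Min
// is unsatisfiable, so the conclusion holds vacuously.
uint64_t sharpenedMin(uint64_t Min, uint64_t V) {
  assert(V >= Min && V != Min && "range info plus guarding condition");
  uint64_t SharperMin = Min + 1;
  assert(V >= SharperMin && "V > Min implies V >= Min + 1");
  (void)V;
  return SharperMin;
}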
ScalarEvolution::ExitLimit ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool IsSubExpr) { + bool ControlsExit) { // We handle only IV < Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); @@ -6631,7 +7060,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, if (!IV || IV->getLoop() != L || !IV->isAffine()) return getCouldNotCompute(); - bool NoWrap = !IsSubExpr && + bool NoWrap = ControlsExit && IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); const SCEV *Stride = IV->getStepRecurrence(*this); @@ -6651,9 +7080,19 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, : ICmpInst::ICMP_ULT; const SCEV *Start = IV->getStart(); const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) - End = IsSigned ? getSMaxExpr(RHS, Start) - : getUMaxExpr(RHS, Start); + if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) { + const SCEV *Diff = getMinusSCEV(RHS, Start); + // If we have NoWrap set, then we can assume that the increment won't + // overflow, in which case if RHS - Start is a constant, we don't need to + // do a max operation since we can just figure it out statically + if (NoWrap && isa<SCEVConstant>(Diff)) { + APInt D = dyn_cast<const SCEVConstant>(Diff)->getValue()->getValue(); + if (D.isNegative()) + End = Start; + } else + End = IsSigned ? getSMaxExpr(RHS, Start) + : getUMaxExpr(RHS, Start); + } const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); @@ -6684,13 +7123,13 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, if (isa<SCEVCouldNotCompute>(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, /*MustExit=*/true); + return ExitLimit(BECount, MaxBECount); } ScalarEvolution::ExitLimit ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool IsSubExpr) { + bool ControlsExit) { // We handle only IV > Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); @@ -6701,7 +7140,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (!IV || IV->getLoop() != L || !IV->isAffine()) return getCouldNotCompute(); - bool NoWrap = !IsSubExpr && + bool NoWrap = ControlsExit && IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this)); @@ -6722,9 +7161,19 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const SCEV *Start = IV->getStart(); const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) - End = IsSigned ? getSMinExpr(RHS, Start) - : getUMinExpr(RHS, Start); + if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) { + const SCEV *Diff = getMinusSCEV(RHS, Start); + // If we have NoWrap set, then we can assume that the increment won't + // overflow, in which case if RHS - Start is a constant, we don't need to + // do a max operation since we can just figure it out statically + if (NoWrap && isa<SCEVConstant>(Diff)) { + APInt D = dyn_cast<const SCEVConstant>(Diff)->getValue()->getValue(); + if (!D.isNegative()) + End = Start; + } else + End = IsSigned ? 
getSMinExpr(RHS, Start) + : getUMinExpr(RHS, Start); + } const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false); @@ -6756,7 +7205,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (isa<SCEVCouldNotCompute>(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, /*MustExit=*/true); + return ExitLimit(BECount, MaxBECount); } /// getNumIterationsInRange - Return the number of iterations of this loop that @@ -6984,268 +7433,6 @@ void SCEVAddRecExpr::collectParametricTerms( }); } -static const APInt srem(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.sext(ABW); - else if (ABW < BBW) - A = A.sext(BBW); - - return APIntOps::srem(A, B); -} - -static const APInt sdiv(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.sext(ABW); - else if (ABW < BBW) - A = A.sext(BBW); - - return APIntOps::sdiv(A, B); -} - -namespace { -struct FindSCEVSize { - int Size; - FindSCEVSize() : Size(0) {} - - bool follow(const SCEV *S) { - ++Size; - // Keep looking at all operands of S. - return true; - } - bool isDone() const { - return false; - } -}; -} - -// Returns the size of the SCEV S. -static inline int sizeOfSCEV(const SCEV *S) { - FindSCEVSize F; - SCEVTraversal<FindSCEVSize> ST(F); - ST.visitAll(S); - return F.Size; -} - -namespace { - -struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> { -public: - // Computes the Quotient and Remainder of the division of Numerator by - // Denominator. - static void divide(ScalarEvolution &SE, const SCEV *Numerator, - const SCEV *Denominator, const SCEV **Quotient, - const SCEV **Remainder) { - assert(Numerator && Denominator && "Uninitialized SCEV"); - - SCEVDivision D(SE, Numerator, Denominator); - - // Check for the trivial case here to avoid having to check for it in the - // rest of the code. - if (Numerator == Denominator) { - *Quotient = D.One; - *Remainder = D.Zero; - return; - } - - if (Numerator->isZero()) { - *Quotient = D.Zero; - *Remainder = D.Zero; - return; - } - - // Split the Denominator when it is a product. - if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) { - const SCEV *Q, *R; - *Quotient = Numerator; - for (const SCEV *Op : T->operands()) { - divide(SE, *Quotient, Op, &Q, &R); - *Quotient = Q; - - // Bail out when the Numerator is not divisible by one of the terms of - // the Denominator. - if (!R->isZero()) { - *Quotient = D.Zero; - *Remainder = Numerator; - return; - } - } - *Remainder = D.Zero; - return; - } - - D.visit(Numerator); - *Quotient = D.Quotient; - *Remainder = D.Remainder; - } - - SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, const SCEV *Denominator) - : SE(S), Denominator(Denominator) { - Zero = SE.getConstant(Denominator->getType(), 0); - One = SE.getConstant(Denominator->getType(), 1); - - // By default, we don't know how to divide Expr by Denominator. - // Providing the default here simplifies the rest of the code. - Quotient = Zero; - Remainder = Numerator; - } - - // Except in the trivial case described above, we do not know how to divide - // Expr by Denominator for the following functions with empty implementation. 
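The "split the Denominator when it is a product" rule in divide() reduces division by d1*d2*... to a chain of divisions by the individual factors, giving up (quotient zero, remainder equal to the numerator) as soon as one factor fails to divide exactly. The same idea over plain integers, as a hedged sketch with a hypothetical name:

#include <cstdint>
#include <optional>
#include <vector>

// Divide Numerator by the product of Factors, one factor at a time; return
// nothing if any factor leaves a remainder, mirroring the bail-out above.
std::optional<int64_t> divideByProduct(int64_t Numerator,
                                       const std::vector<int64_t> &Factors) {
  int64_t Quotient = Numerator;
  for (int64_t F : Factors) {
    if (F == 0 || Quotient % F != 0)
      return std::nullopt; // a nonzero remainder ends the whole division
    Quotient /= F;
  }
  return Quotient;
}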
- void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} - void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} - void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} - void visitUDivExpr(const SCEVUDivExpr *Numerator) {} - void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} - void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} - void visitUnknown(const SCEVUnknown *Numerator) {} - void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} - - void visitConstant(const SCEVConstant *Numerator) { - if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) { - Quotient = SE.getConstant(sdiv(Numerator, D)); - Remainder = SE.getConstant(srem(Numerator, D)); - return; - } - } - - void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { - const SCEV *StartQ, *StartR, *StepQ, *StepR; - assert(Numerator->isAffine() && "Numerator should be affine"); - divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); - divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); - Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), - Numerator->getNoWrapFlags()); - Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), - Numerator->getNoWrapFlags()); - } - - void visitAddExpr(const SCEVAddExpr *Numerator) { - SmallVector<const SCEV *, 2> Qs, Rs; - Type *Ty = Denominator->getType(); - - for (const SCEV *Op : Numerator->operands()) { - const SCEV *Q, *R; - divide(SE, Op, Denominator, &Q, &R); - - // Bail out if types do not match. - if (Ty != Q->getType() || Ty != R->getType()) { - Quotient = Zero; - Remainder = Numerator; - return; - } - - Qs.push_back(Q); - Rs.push_back(R); - } - - if (Qs.size() == 1) { - Quotient = Qs[0]; - Remainder = Rs[0]; - return; - } - - Quotient = SE.getAddExpr(Qs); - Remainder = SE.getAddExpr(Rs); - } - - void visitMulExpr(const SCEVMulExpr *Numerator) { - SmallVector<const SCEV *, 2> Qs; - Type *Ty = Denominator->getType(); - - bool FoundDenominatorTerm = false; - for (const SCEV *Op : Numerator->operands()) { - // Bail out if types do not match. - if (Ty != Op->getType()) { - Quotient = Zero; - Remainder = Numerator; - return; - } - - if (FoundDenominatorTerm) { - Qs.push_back(Op); - continue; - } - - // Check whether Denominator divides one of the product operands. - const SCEV *Q, *R; - divide(SE, Op, Denominator, &Q, &R); - if (!R->isZero()) { - Qs.push_back(Op); - continue; - } - - // Bail out if types do not match. - if (Ty != Q->getType()) { - Quotient = Zero; - Remainder = Numerator; - return; - } - - FoundDenominatorTerm = true; - Qs.push_back(Q); - } - - if (FoundDenominatorTerm) { - Remainder = Zero; - if (Qs.size() == 1) - Quotient = Qs[0]; - else - Quotient = SE.getMulExpr(Qs); - return; - } - - if (!isa<SCEVUnknown>(Denominator)) { - Quotient = Zero; - Remainder = Numerator; - return; - } - - // The Remainder is obtained by replacing Denominator by 0 in Numerator. - ValueToValueMap RewriteMap; - RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = - cast<SCEVConstant>(Zero)->getValue(); - Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); - - if (Remainder->isZero()) { - // The Quotient is obtained by replacing Denominator by 1 in Numerator. - RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = - cast<SCEVConstant>(One)->getValue(); - Quotient = - SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); - return; - } - - // Quotient is (Numerator - Remainder) divided by Denominator. 
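visitMulExpr's fallback treats the numerator as an expression N(x) in the unknown denominator x: substituting x := 0 yields the remainder, and when that remainder is zero, so N(x) = Q*x with Q independent of x, substituting x := 1 reads off the quotient; with a nonzero remainder, the quotient comes from dividing Numerator - Remainder instead. The scalar intuition, checked on N(x) = 3x + 5 (a sketch only; the real code rewrites SCEVs with SCEVParameterRewriter):

#include <cassert>

void substitutionTrickDemo() {
  auto N = [](long X) { return 3 * X + 5; };
  long R = N(0); // remainder: substitute x := 0, giving 5
  auto M = [&](long X) { return N(X) - R; }; // 3x, now exactly divisible
  long Q = M(1); // quotient: substitute x := 1, giving 3
  assert(R == 5 && Q == 3);
  (void)R;
  (void)Q;
}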
- const SCEV *Q, *R; - const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); - if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) { - // This SCEV does not seem to simplify: fail the division here. - Quotient = Zero; - Remainder = Numerator; - return; - } - divide(SE, Diff, Denominator, &Q, &R); - assert(R == Zero && - "(Numerator - Remainder) should evenly divide Denominator"); - Quotient = Q; - } - -private: - ScalarEvolution &SE; - const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; -}; -} - static bool findArrayDimensionsRec(ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &Terms, SmallVectorImpl<const SCEV *> &Sizes) { @@ -7270,7 +7457,7 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE, for (const SCEV *&Term : Terms) { // Normalize the terms before the next call to findArrayDimensionsRec. const SCEV *Q, *R; - SCEVDivision::divide(SE, Term, Step, &Q, &R); + SCEVSDivision::divide(SE, Term, Step, &Q, &R); // Bail out when GCD does not evenly divide one of the terms. if (!R->isZero()) @@ -7407,7 +7594,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, // Divide all terms by the element size. for (const SCEV *&Term : Terms) { const SCEV *Q, *R; - SCEVDivision::divide(SE, Term, ElementSize, &Q, &R); + SCEVSDivision::divide(SE, Term, ElementSize, &Q, &R); Term = Q; } @@ -7454,7 +7641,7 @@ void SCEVAddRecExpr::computeAccessFunctions( int Last = Sizes.size() - 1; for (int i = Last; i >= 0; i--) { const SCEV *Q, *R; - SCEVDivision::divide(SE, Res, Sizes[i], &Q, &R); + SCEVSDivision::divide(SE, Res, Sizes[i], &Q, &R); DEBUG({ dbgs() << "Res: " << *Res << "\n"; @@ -7609,7 +7796,7 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) { // that until everything else is done. if (U == Old) continue; - if (!Visited.insert(U)) + if (!Visited.insert(U).second) continue; if (PHINode *PN = dyn_cast<PHINode>(U)) SE->ConstantEvolutionLoopExitValue.erase(PN); @@ -7638,6 +7825,7 @@ ScalarEvolution::ScalarEvolution() bool ScalarEvolution::runOnFunction(Function &F) { this->F = &F; + AT = &getAnalysis<AssumptionTracker>(); LI = &getAnalysis<LoopInfo>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; @@ -7678,6 +7866,7 @@ void ScalarEvolution::releaseMemory() { void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); + AU.addRequired<AssumptionTracker>(); AU.addRequiredTransitive<LoopInfo>(); AU.addRequiredTransitive<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfo>(); diff --git a/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp b/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp index 6933f74..5c339ee 100644 --- a/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp +++ b/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp @@ -162,10 +162,10 @@ ScalarEvolutionAliasAnalysis::alias(const Location &LocA, if ((AO && AO != LocA.Ptr) || (BO && BO != LocB.Ptr)) if (alias(Location(AO ? AO : LocA.Ptr, AO ? +UnknownSize : LocA.Size, - AO ? nullptr : LocA.TBAATag), + AO ? AAMDNodes() : LocA.AATags), Location(BO ? BO : LocB.Ptr, BO ? +UnknownSize : LocB.Size, - BO ? nullptr : LocB.TBAATag)) == NoAlias) + BO ? AAMDNodes() : LocB.AATags)) == NoAlias) return NoAlias; // Forward the query to the next analysis. 
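The ScalarEvolutionAliasAnalysis hunk above preserves an invariant worth spelling out: once the query is retargeted from the original pointer to its underlying object, the original access size and AA metadata no longer describe what is being asked about, so both must widen to UnknownSize and an empty AAMDNodes. A minimal model of that rewrite (struct and names hypothetical):

#include <cstdint>
#include <string>

// Toy stand-in for AliasAnalysis::Location: a pointer name, an access size,
// and an AA-metadata summary.
struct Loc {
  std::string Ptr;
  uint64_t Size;      // UINT64_MAX models UnknownSize
  std::string AATags; // empty models AAMDNodes()
};

// If a distinct underlying object was found, query it with a widened size
// and no tags; otherwise keep the original location unchanged.
Loc retarget(const Loc &Orig, const std::string &UnderlyingObj) {
  if (UnderlyingObj.empty() || UnderlyingObj == Orig.Ptr)
    return Orig;
  return Loc{UnderlyingObj, UINT64_MAX, std::string()};
}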
diff --git a/lib/Analysis/ScalarEvolutionExpander.cpp b/lib/Analysis/ScalarEvolutionExpander.cpp index 8c75b0d..bee3685 100644 --- a/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/lib/Analysis/ScalarEvolutionExpander.cpp @@ -1443,8 +1443,12 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { Constant *One = ConstantInt::get(Ty, 1); for (pred_iterator HPI = HPB; HPI != HPE; ++HPI) { BasicBlock *HP = *HPI; - if (!PredSeen.insert(HP)) + if (!PredSeen.insert(HP).second) { + // There must be an incoming value for each predecessor, even the + // duplicates! + CanonicalIV->addIncoming(CanonicalIV->getIncomingValueForBlock(HP), HP); continue; + } if (L->contains(HP)) { // Insert a unit add instruction right before the terminator @@ -1707,7 +1711,7 @@ unsigned SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, // Fold constant phis. They may be congruent to other constant phis and // would confuse the logic below that expects proper IVs. - if (Value *V = SimplifyInstruction(Phi, SE.DL, SE.TLI, SE.DT)) { + if (Value *V = SimplifyInstruction(Phi, SE.DL, SE.TLI, SE.DT, SE.AT)) { Phi->replaceAllUsesWith(V); DeadInsts.push_back(Phi); ++NumElim; diff --git a/lib/Analysis/ScalarEvolutionNormalization.cpp b/lib/Analysis/ScalarEvolutionNormalization.cpp index 3ccefb0..b238fe4 100644 --- a/lib/Analysis/ScalarEvolutionNormalization.cpp +++ b/lib/Analysis/ScalarEvolutionNormalization.cpp @@ -126,7 +126,7 @@ TransformImpl(const SCEV *S, Instruction *User, Value *OperandValToReplace) { // Normalized form: {-2,+,1,+,2} // Denormalized form: {1,+,3,+,2} // - // However, denormalization would use the a different step expression than + // However, denormalization would use a different step expression than // normalization (see getPostIncExpr), generating the wrong final // expression: {-2,+,1,+,2} + {1,+,2} => {-1,+,3,+,2} if (AR->isAffine() && diff --git a/lib/Analysis/ScopedNoAliasAA.cpp b/lib/Analysis/ScopedNoAliasAA.cpp new file mode 100644 index 0000000..f6c300a --- /dev/null +++ b/lib/Analysis/ScopedNoAliasAA.cpp @@ -0,0 +1,245 @@ +//===- ScopedNoAliasAA.cpp - Scoped No-Alias Alias Analysis ---------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the ScopedNoAlias alias-analysis pass, which implements +// metadata-based scoped no-alias support. +// +// Alias-analysis scopes are defined by an id (which can be a string or some +// other metadata node), a domain node, and an optional descriptive string. +// A domain is defined by an id (which can be a string or some other metadata +// node), and an optional descriptive string. +// +// !dom0 = metadata !{ metadata !"domain of foo()" } +// !scope1 = metadata !{ metadata !scope1, metadata !dom0, metadata !"scope 1" } +// !scope2 = metadata !{ metadata !scope2, metadata !dom0, metadata !"scope 2" } +// +// Loads and stores can be tagged with an alias-analysis scope, and also, with +// a noalias tag for a specific scope: +// +// ... = load %ptr1, !alias.scope !{ !scope1 } +// ... 
= load %ptr2, !alias.scope !{ !scope1, !scope2 }, !noalias !{ !scope1 } +// +// When evaluating an aliasing query, if one of the instructions is associated +// with a set of noalias scopes in some domain that is a superset of the alias +// scopes in that domain of some other instruction, then the two memory +// accesses are assumed not to alias. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +using namespace llvm; + +// A handy option for disabling scoped no-alias functionality. The same effect +// can also be achieved by stripping the associated metadata tags from IR, but +// this option is sometimes more convenient. +static cl::opt<bool> +EnableScopedNoAlias("enable-scoped-noalias", cl::init(true)); + +namespace { +/// AliasScopeNode - This is a simple wrapper around an MDNode which provides +/// a higher-level interface by hiding the details of how alias analysis +/// information is encoded in its operands. +class AliasScopeNode { + const MDNode *Node; + +public: + AliasScopeNode() : Node(0) {} + explicit AliasScopeNode(const MDNode *N) : Node(N) {} + + /// getNode - Get the MDNode for this AliasScopeNode. + const MDNode *getNode() const { return Node; } + + /// getDomain - Get the MDNode for this AliasScopeNode's domain. + const MDNode *getDomain() const { + if (Node->getNumOperands() < 2) + return nullptr; + return dyn_cast_or_null<MDNode>(Node->getOperand(1)); + } +}; + +/// ScopedNoAliasAA - This is a simple alias analysis +/// implementation that uses scoped-noalias metadata to answer queries. +class ScopedNoAliasAA : public ImmutablePass, public AliasAnalysis { +public: + static char ID; // Class identification, replacement for typeinfo + ScopedNoAliasAA() : ImmutablePass(ID) { + initializeScopedNoAliasAAPass(*PassRegistry::getPassRegistry()); + } + + void initializePass() override { InitializeAliasAnalysis(this); } + + /// getAdjustedAnalysisPointer - This method is used when a pass implements + /// an analysis interface through multiple inheritance. If needed, it + /// should override this to adjust the this pointer as needed for the + /// specified pass info. + void *getAdjustedAnalysisPointer(const void *PI) override { + if (PI == &AliasAnalysis::ID) + return (AliasAnalysis*)this; + return this; + } + +protected: + bool mayAliasInScopes(const MDNode *Scopes, const MDNode *NoAlias) const; + void collectMDInDomain(const MDNode *List, const MDNode *Domain, + SmallPtrSetImpl<const MDNode *> &Nodes) const; + +private: + void getAnalysisUsage(AnalysisUsage &AU) const override; + AliasResult alias(const Location &LocA, const Location &LocB) override; + bool pointsToConstantMemory(const Location &Loc, bool OrLocal) override; + ModRefBehavior getModRefBehavior(ImmutableCallSite CS) override; + ModRefBehavior getModRefBehavior(const Function *F) override; + ModRefResult getModRefInfo(ImmutableCallSite CS, + const Location &Loc) override; + ModRefResult getModRefInfo(ImmutableCallSite CS1, + ImmutableCallSite CS2) override; +}; +} // End of anonymous namespace + +// Register this pass...
+char ScopedNoAliasAA::ID = 0; +INITIALIZE_AG_PASS(ScopedNoAliasAA, AliasAnalysis, "scoped-noalias", + "Scoped NoAlias Alias Analysis", false, true, false) + +ImmutablePass *llvm::createScopedNoAliasAAPass() { + return new ScopedNoAliasAA(); +} + +void +ScopedNoAliasAA::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AliasAnalysis::getAnalysisUsage(AU); +} + +void +ScopedNoAliasAA::collectMDInDomain(const MDNode *List, const MDNode *Domain, + SmallPtrSetImpl<const MDNode *> &Nodes) const { + for (unsigned i = 0, ie = List->getNumOperands(); i != ie; ++i) + if (const MDNode *MD = dyn_cast<MDNode>(List->getOperand(i))) + if (AliasScopeNode(MD).getDomain() == Domain) + Nodes.insert(MD); +} + +bool +ScopedNoAliasAA::mayAliasInScopes(const MDNode *Scopes, + const MDNode *NoAlias) const { + if (!Scopes || !NoAlias) + return true; + + // Collect the set of scope domains relevant to the noalias scopes. + SmallPtrSet<const MDNode *, 16> Domains; + for (unsigned i = 0, ie = NoAlias->getNumOperands(); i != ie; ++i) + if (const MDNode *NAMD = dyn_cast<MDNode>(NoAlias->getOperand(i))) + if (const MDNode *Domain = AliasScopeNode(NAMD).getDomain()) + Domains.insert(Domain); + + // We alias unless, for some domain, the set of noalias scopes in that domain + // is a superset of the set of alias scopes in that domain. + for (const MDNode *Domain : Domains) { + SmallPtrSet<const MDNode *, 16> NANodes, ScopeNodes; + collectMDInDomain(NoAlias, Domain, NANodes); + collectMDInDomain(Scopes, Domain, ScopeNodes); + if (!ScopeNodes.size()) + continue; + + // To not alias, all of the nodes in ScopeNodes must be in NANodes. + bool FoundAll = true; + for (const MDNode *SMD : ScopeNodes) + if (!NANodes.count(SMD)) { + FoundAll = false; + break; + } + + if (FoundAll) + return false; + } + + return true; +} + +AliasAnalysis::AliasResult +ScopedNoAliasAA::alias(const Location &LocA, const Location &LocB) { + if (!EnableScopedNoAlias) + return AliasAnalysis::alias(LocA, LocB); + + // Get the attached MDNodes. + const MDNode *AScopes = LocA.AATags.Scope, + *BScopes = LocB.AATags.Scope; + + const MDNode *ANoAlias = LocA.AATags.NoAlias, + *BNoAlias = LocB.AATags.NoAlias; + + if (!mayAliasInScopes(AScopes, BNoAlias)) + return NoAlias; + + if (!mayAliasInScopes(BScopes, ANoAlias)) + return NoAlias; + + // If they may alias, chain to the next AliasAnalysis. 
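The superset test in mayAliasInScopes reads independently of the MDNode plumbing: group each instruction's scope lists by domain, and report "no alias" exactly when some domain's noalias list covers every alias scope the other access carries there. Note that alias() runs the check twice, once per direction, because each access contributes its own noalias list. A standalone model using strings for scope ids (types hypothetical; the real code walks MDNode operands):

#include <map>
#include <set>
#include <string>

using ScopeSet = std::set<std::string>; // scope ids within one domain
using ByDomain = std::map<std::string, ScopeSet>;

// Presume aliasing unless, in some domain, the noalias set is a superset of
// the other access's alias scopes in that same domain.
bool mayAliasModel(const ByDomain &Scopes, const ByDomain &NoAlias) {
  for (const auto &DomainAndNA : NoAlias) {
    auto It = Scopes.find(DomainAndNA.first);
    if (It == Scopes.end() || It->second.empty())
      continue; // no alias scopes in this domain, so it proves nothing
    bool FoundAll = true;
    for (const std::string &S : It->second)
      if (!DomainAndNA.second.count(S)) {
        FoundAll = false;
        break;
      }
    if (FoundAll)
      return false; // noalias covers every scope: accesses do not alias
  }
  return true;
}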
+ return AliasAnalysis::alias(LocA, LocB); +} + +bool ScopedNoAliasAA::pointsToConstantMemory(const Location &Loc, + bool OrLocal) { + return AliasAnalysis::pointsToConstantMemory(Loc, OrLocal); +} + +AliasAnalysis::ModRefBehavior +ScopedNoAliasAA::getModRefBehavior(ImmutableCallSite CS) { + return AliasAnalysis::getModRefBehavior(CS); +} + +AliasAnalysis::ModRefBehavior +ScopedNoAliasAA::getModRefBehavior(const Function *F) { + return AliasAnalysis::getModRefBehavior(F); +} + +AliasAnalysis::ModRefResult +ScopedNoAliasAA::getModRefInfo(ImmutableCallSite CS, const Location &Loc) { + if (!EnableScopedNoAlias) + return AliasAnalysis::getModRefInfo(CS, Loc); + + if (!mayAliasInScopes(Loc.AATags.Scope, CS.getInstruction()->getMetadata( + LLVMContext::MD_noalias))) + return NoModRef; + + if (!mayAliasInScopes( + CS.getInstruction()->getMetadata(LLVMContext::MD_alias_scope), + Loc.AATags.NoAlias)) + return NoModRef; + + return AliasAnalysis::getModRefInfo(CS, Loc); +} + +AliasAnalysis::ModRefResult +ScopedNoAliasAA::getModRefInfo(ImmutableCallSite CS1, ImmutableCallSite CS2) { + if (!EnableScopedNoAlias) + return AliasAnalysis::getModRefInfo(CS1, CS2); + + if (!mayAliasInScopes( + CS1.getInstruction()->getMetadata(LLVMContext::MD_alias_scope), + CS2.getInstruction()->getMetadata(LLVMContext::MD_noalias))) + return NoModRef; + + if (!mayAliasInScopes( + CS2.getInstruction()->getMetadata(LLVMContext::MD_alias_scope), + CS1.getInstruction()->getMetadata(LLVMContext::MD_noalias))) + return NoModRef; + + return AliasAnalysis::getModRefInfo(CS1, CS2); +} + diff --git a/lib/Analysis/StratifiedSets.h b/lib/Analysis/StratifiedSets.h new file mode 100644 index 0000000..fd3fbc0 --- /dev/null +++ b/lib/Analysis/StratifiedSets.h @@ -0,0 +1,692 @@ +//===- StratifiedSets.h - Abstract stratified sets implementation. --------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ADT_STRATIFIEDSETS_H +#define LLVM_ADT_STRATIFIEDSETS_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Compiler.h" +#include <bitset> +#include <cassert> +#include <cmath> +#include <limits> +#include <type_traits> +#include <utility> +#include <vector> + +namespace llvm { +// \brief An index into Stratified Sets. +typedef unsigned StratifiedIndex; +// NOTE: ^ This can't be a short -- bootstrapping clang has a case where +// ~1M sets exist. + +// \brief Container of information related to a value in a StratifiedSet. +struct StratifiedInfo { + StratifiedIndex Index; + // For field sensitivity, etc. we can tack attributes on to this struct. +}; + +// The number of attributes that StratifiedAttrs should contain. Attributes are +// described below, and 32 was an arbitrary choice because it fits nicely in 32 +// bits (because we use a bitset for StratifiedAttrs). +static const unsigned NumStratifiedAttrs = 32; + +// These are attributes that the users of StratifiedSets/StratifiedSetBuilders +// may use for various purposes. These also have the special property that +// they are merged down. So, if set A is above set B, and one decides to set an +// attribute in set A, then the attribute will automatically be set in set B.
+typedef std::bitset<NumStratifiedAttrs> StratifiedAttrs; + +// \brief A "link" between two StratifiedSets. +struct StratifiedLink { + // \brief This is a value used to signify "does not exist" where + // the StratifiedIndex type is used. This is used instead of + // Optional<StratifiedIndex> because Optional<StratifiedIndex> would + // eat up a considerable amount of extra memory, after struct + // padding/alignment is taken into account. + static const StratifiedIndex SetSentinel; + + // \brief The index for the set "above" current + StratifiedIndex Above; + + // \brief The link for the set "below" current + StratifiedIndex Below; + + // \brief Attributes for these StratifiedSets. + StratifiedAttrs Attrs; + + StratifiedLink() : Above(SetSentinel), Below(SetSentinel) {} + + bool hasBelow() const { return Below != SetSentinel; } + bool hasAbove() const { return Above != SetSentinel; } + + void clearBelow() { Below = SetSentinel; } + void clearAbove() { Above = SetSentinel; } +}; + +// \brief These are stratified sets, as described in "Fast algorithms for +// Dyck-CFL-reachability with applications to Alias Analysis" by Zhang Q, Lyu M +// R, Yuan H, and Su Z. -- in short, this is meant to represent different sets +// of Value*s. If two Value*s are in the same set, or if both sets have +// overlapping attributes, then the Value*s are said to alias. +// +// Sets may be related by position, meaning that one set may be considered as +// above or below another. In CFL Alias Analysis, this gives us an indication +// of how two variables are related; if the set of variable A is below a set +// containing variable B, then at some point, a variable that has interacted +// with B (or B itself) was either used in order to extract the variable A, or +// was used as storage of variable A. +// +// Sets may also have attributes (as noted above). These attributes are +// generally used for noting whether a variable in the set has interacted with +// a variable whose origins we don't quite know (i.e. globals/arguments), or if +// the variable may have had operations performed on it (modified in a function +// call). All attributes that exist in a set A must exist in all sets marked as +// below set A. +template <typename T> class StratifiedSets { +public: + StratifiedSets() {} + + StratifiedSets(DenseMap<T, StratifiedInfo> Map, + std::vector<StratifiedLink> Links) + : Values(std::move(Map)), Links(std::move(Links)) {} + + StratifiedSets(StratifiedSets<T> &&Other) { *this = std::move(Other); } + + StratifiedSets &operator=(StratifiedSets<T> &&Other) { + Values = std::move(Other.Values); + Links = std::move(Other.Links); + return *this; + } + + Optional<StratifiedInfo> find(const T &Elem) const { + auto Iter = Values.find(Elem); + if (Iter == Values.end()) { + return NoneType(); + } + return Iter->second; + } + + const StratifiedLink &getLink(StratifiedIndex Index) const { + assert(inbounds(Index)); + return Links[Index]; + } + +private: + DenseMap<T, StratifiedInfo> Values; + std::vector<StratifiedLink> Links; + + bool inbounds(StratifiedIndex Idx) const { return Idx < Links.size(); } +}; + +// \brief Generic Builder class that produces StratifiedSets instances. +// +// The goal of this builder is to efficiently produce correct StratifiedSets +// instances. To this end, we use a few tricks: +// > Set chains (A method for linking sets together) +// > Set remaps (A method for marking a set as an alias [irony?] 
of another) +// +// ==== Set chains ==== +// This builder has a notion of some value A being above, below, or with some +// other value B: +// > The `A above B` relationship implies that there is a reference edge going +// from A to B. Namely, it notes that A can store anything in B's set. +// > The `A below B` relationship is the opposite of `A above B`. It implies +// that there's a dereference edge going from A to B. +// > The `A with B` relationship states that there's an assignment edge going +// from A to B, and that A and B should be treated as equals. +// +// As an example, take the following code snippet: +// +// %a = alloca i32, align 4 +// %ap = alloca i32*, align 8 +// %app = alloca i32**, align 8 +// store %a, %ap +// store %ap, %app +// %aw = getelementptr %ap, 0 +// +// Given this, the following relations exist: +// - %a below %ap & %ap above %a +// - %ap below %app & %app above %ap +// - %aw with %ap & %ap with %aw +// +// These relations produce the following sets: +// [{%a}, {%ap, %aw}, {%app}] +// +// ...Which states that the only MayAlias relationship in the above program is +// between %ap and %aw. +// +// Life gets more complicated when we actually have logic in our programs. So, +// we must either remove this logic from our programs, or make concessions for +// it in our AA algorithms. In this case, we have decided to select the latter +// option. +// +// First complication: Conditionals +// Motivation: +// %ad = alloca int, align 4 +// %a = alloca int*, align 8 +// %b = alloca int*, align 8 +// %bp = alloca int**, align 8 +// %c = call i1 @SomeFunc() +// %k = select %c, %a, %b +// store %ad, %a +// store %b, %bp +// +// %k has 'with' edges to both %a and %b, which ordinarily would not be linked +// together. So, we merge the set that contains %a with the set that contains +// %b. We then recursively merge the set above %a with the set above %b, and +// the set below %a with the set below %b, etc. Ultimately, the sets for this +// program would end up like: {%ad}, {%a, %b, %k}, {%bp}, where {%ad} is below +// {%a, %b, %k}, which is below {%bp}. +// +// Second complication: Arbitrary casts +// Motivation: +// %ip = alloca int*, align 8 +// %ipp = alloca int**, align 8 +// %i = bitcast %ipp to int +// store %ip, %ipp +// store %i, %ip +// +// This is impossible to construct with any of the rules above, because a set +// containing both {%i, %ipp} is supposed to exist, the set with %i is supposed +// to be below the set with %ip, and the set with %ip is supposed to be below +// the set with %ipp. Because we don't allow circular relationships like this, +// we merge all concerned sets into one. So, the above code would generate a +// single StratifiedSet: {%ip, %ipp, %i}. +// +// ==== Set remaps ==== +// More of an implementation detail than anything -- when merging sets, we need +// to update the numbers of all of the elements mapped to those sets. Rather +// than doing this at each merge, we note in the BuilderLink structure that a +// remap has occurred, and use this information so we can defer renumbering set +// elements until build time. +template <typename T> class StratifiedSetsBuilder { + // \brief Represents a Stratified Set, with information about the Stratified + // Set above it, the set below it, and whether the current set has been + // remapped to another.
+ struct BuilderLink { + const StratifiedIndex Number; + + BuilderLink(StratifiedIndex N) : Number(N) { + Remap = StratifiedLink::SetSentinel; + } + + bool hasAbove() const { + assert(!isRemapped()); + return Link.hasAbove(); + } + + bool hasBelow() const { + assert(!isRemapped()); + return Link.hasBelow(); + } + + void setBelow(StratifiedIndex I) { + assert(!isRemapped()); + Link.Below = I; + } + + void setAbove(StratifiedIndex I) { + assert(!isRemapped()); + Link.Above = I; + } + + void clearBelow() { + assert(!isRemapped()); + Link.clearBelow(); + } + + void clearAbove() { + assert(!isRemapped()); + Link.clearAbove(); + } + + StratifiedIndex getBelow() const { + assert(!isRemapped()); + assert(hasBelow()); + return Link.Below; + } + + StratifiedIndex getAbove() const { + assert(!isRemapped()); + assert(hasAbove()); + return Link.Above; + } + + StratifiedAttrs &getAttrs() { + assert(!isRemapped()); + return Link.Attrs; + } + + void setAttr(unsigned index) { + assert(!isRemapped()); + assert(index < NumStratifiedAttrs); + Link.Attrs.set(index); + } + + void setAttrs(const StratifiedAttrs &other) { + assert(!isRemapped()); + Link.Attrs |= other; + } + + bool isRemapped() const { return Remap != StratifiedLink::SetSentinel; } + + // \brief For initial remapping to another set + void remapTo(StratifiedIndex Other) { + assert(!isRemapped()); + Remap = Other; + } + + StratifiedIndex getRemapIndex() const { + assert(isRemapped()); + return Remap; + } + + // \brief Should only be called when we're already remapped. + void updateRemap(StratifiedIndex Other) { + assert(isRemapped()); + Remap = Other; + } + + // \brief Prefer the above functions to calling things directly on what's + // returned from this -- they guard against unexpected calls when the + // current BuilderLink is remapped. + const StratifiedLink &getLink() const { return Link; } + + private: + StratifiedLink Link; + StratifiedIndex Remap; + }; + + // \brief This function performs all of the set unioning/value renumbering + // that we've been putting off, and generates a vector<StratifiedLink> that + // may be placed in a StratifiedSets instance. + void finalizeSets(std::vector<StratifiedLink> &StratLinks) { + DenseMap<StratifiedIndex, StratifiedIndex> Remaps; + for (auto &Link : Links) { + if (Link.isRemapped()) { + continue; + } + + StratifiedIndex Number = StratLinks.size(); + Remaps.insert(std::make_pair(Link.Number, Number)); + StratLinks.push_back(Link.getLink()); + } + + for (auto &Link : StratLinks) { + if (Link.hasAbove()) { + auto &Above = linksAt(Link.Above); + auto Iter = Remaps.find(Above.Number); + assert(Iter != Remaps.end()); + Link.Above = Iter->second; + } + + if (Link.hasBelow()) { + auto &Below = linksAt(Link.Below); + auto Iter = Remaps.find(Below.Number); + assert(Iter != Remaps.end()); + Link.Below = Iter->second; + } + } + + for (auto &Pair : Values) { + auto &Info = Pair.second; + auto &Link = linksAt(Info.Index); + auto Iter = Remaps.find(Link.Number); + assert(Iter != Remaps.end()); + Info.Index = Iter->second; + } + } + + // \brief There's a guarantee in StratifiedLink where all bits set in a + // Link.externals will be set in all Link.externals "below" it. 
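The guarantee described above is exactly what propagateAttrs establishes below: climb to the top of each chain, then OR the attribute bits downward so that every set carries at least the attributes of every set above it. A compact model over a plain vector of links (names hypothetical):

#include <bitset>
#include <vector>

struct ChainLink {
  int Below = -1;        // index of the set below, -1 if none
  std::bitset<32> Attrs; // matches NumStratifiedAttrs = 32
};

// Starting from the topmost set of a chain, merge attributes downward so
// attrs(set) includes attrs(every set above it).
void propagateDown(std::vector<ChainLink> &Chain, int TopIndex) {
  for (int I = TopIndex; Chain[I].Below != -1; I = Chain[I].Below)
    Chain[Chain[I].Below].Attrs |= Chain[I].Attrs;
}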
+ static void propagateAttrs(std::vector<StratifiedLink> &Links) { + const auto getHighestParentAbove = [&Links](StratifiedIndex Idx) { + const auto *Link = &Links[Idx]; + while (Link->hasAbove()) { + Idx = Link->Above; + Link = &Links[Idx]; + } + return Idx; + }; + + SmallSet<StratifiedIndex, 16> Visited; + for (unsigned I = 0, E = Links.size(); I < E; ++I) { + auto CurrentIndex = getHighestParentAbove(I); + if (!Visited.insert(CurrentIndex).second) { + continue; + } + + while (Links[CurrentIndex].hasBelow()) { + auto &CurrentBits = Links[CurrentIndex].Attrs; + auto NextIndex = Links[CurrentIndex].Below; + auto &NextBits = Links[NextIndex].Attrs; + NextBits |= CurrentBits; + CurrentIndex = NextIndex; + } + } + } + +public: + // \brief Builds a StratifiedSet from the information we've been given since + // either construction or the prior build() call. + StratifiedSets<T> build() { + std::vector<StratifiedLink> StratLinks; + finalizeSets(StratLinks); + propagateAttrs(StratLinks); + Links.clear(); + return StratifiedSets<T>(std::move(Values), std::move(StratLinks)); + } + + std::size_t size() const { return Values.size(); } + std::size_t numSets() const { return Links.size(); } + + bool has(const T &Elem) const { return get(Elem).hasValue(); } + + bool add(const T &Main) { + if (get(Main).hasValue()) + return false; + + auto NewIndex = getNewUnlinkedIndex(); + return addAtMerging(Main, NewIndex); + } + + // \brief Restructures the stratified sets as necessary to make "ToAdd" in a + // set above "Main". There are some cases where this is not possible (see + // above), so we merge them such that ToAdd and Main are in the same set. + bool addAbove(const T &Main, const T &ToAdd) { + assert(has(Main)); + auto Index = *indexOf(Main); + if (!linksAt(Index).hasAbove()) + addLinkAbove(Index); + + auto Above = linksAt(Index).getAbove(); + return addAtMerging(ToAdd, Above); + } + + // \brief Restructures the stratified sets as necessary to make "ToAdd" in a + // set below "Main". There are some cases where this is not possible (see + // above), so we merge them such that ToAdd and Main are in the same set. + bool addBelow(const T &Main, const T &ToAdd) { + assert(has(Main)); + auto Index = *indexOf(Main); + if (!linksAt(Index).hasBelow()) + addLinkBelow(Index); + + auto Below = linksAt(Index).getBelow(); + return addAtMerging(ToAdd, Below); + } + + bool addWith(const T &Main, const T &ToAdd) { + assert(has(Main)); + auto MainIndex = *indexOf(Main); + return addAtMerging(ToAdd, MainIndex); + } + + void noteAttribute(const T &Main, unsigned AttrNum) { + assert(has(Main)); + assert(AttrNum < StratifiedLink::SetSentinel); + auto *Info = *get(Main); + auto &Link = linksAt(Info->Index); + Link.setAttr(AttrNum); + } + + void noteAttributes(const T &Main, const StratifiedAttrs &NewAttrs) { + assert(has(Main)); + auto *Info = *get(Main); + auto &Link = linksAt(Info->Index); + Link.setAttrs(NewAttrs); + } + + StratifiedAttrs getAttributes(const T &Main) { + assert(has(Main)); + auto *Info = *get(Main); + auto *Link = &linksAt(Info->Index); + auto Attrs = Link->getAttrs(); + while (Link->hasAbove()) { + Link = &linksAt(Link->getAbove()); + Attrs |= Link->getAttrs(); + } + + return Attrs; + } + + bool getAttribute(const T &Main, unsigned AttrNum) { + assert(AttrNum < StratifiedLink::SetSentinel); + auto Attrs = getAttributes(Main); + return Attrs[AttrNum]; + } + + // \brief Gets the attributes that have been applied to the set that Main + // belongs to. 
It ignores attributes in any sets above the one that Main
+  // resides in.
+  StratifiedAttrs getRawAttributes(const T &Main) {
+    assert(has(Main));
+    auto *Info = *get(Main);
+    auto &Link = linksAt(Info->Index);
+    return Link.getAttrs();
+  }
+
+  // \brief Gets an attribute from the attributes that have been applied to the
+  // set that Main belongs to. It ignores attributes in any sets above the one
+  // that Main resides in.
+  bool getRawAttribute(const T &Main, unsigned AttrNum) {
+    assert(AttrNum < StratifiedLink::SetSentinel);
+    auto Attrs = getRawAttributes(Main);
+    return Attrs[AttrNum];
+  }
+
+private:
+  DenseMap<T, StratifiedInfo> Values;
+  std::vector<BuilderLink> Links;
+
+  // \brief Adds the given element at the given index, merging sets if
+  // necessary.
+  bool addAtMerging(const T &ToAdd, StratifiedIndex Index) {
+    StratifiedInfo Info = {Index};
+    auto Pair = Values.insert(std::make_pair(ToAdd, Info));
+    if (Pair.second)
+      return true;
+
+    auto &Iter = Pair.first;
+    auto &IterSet = linksAt(Iter->second.Index);
+    auto &ReqSet = linksAt(Index);
+
+    // Failed to add where we wanted to. Merge the sets.
+    if (&IterSet != &ReqSet)
+      merge(IterSet.Number, ReqSet.Number);
+
+    return false;
+  }
+
+  // \brief Gets the BuilderLink at the given index, taking set remapping into
+  // account.
+  BuilderLink &linksAt(StratifiedIndex Index) {
+    auto *Start = &Links[Index];
+    if (!Start->isRemapped())
+      return *Start;
+
+    auto *Current = Start;
+    while (Current->isRemapped())
+      Current = &Links[Current->getRemapIndex()];
+
+    auto NewRemap = Current->Number;
+
+    // Run through everything that has yet to be updated, and update them to
+    // remap to NewRemap.
+    Current = Start;
+    while (Current->isRemapped()) {
+      auto *Next = &Links[Current->getRemapIndex()];
+      Current->updateRemap(NewRemap);
+      Current = Next;
+    }
+
+    return *Current;
+  }
+
+  // \brief Merges two sets into one another. Assumes that these sets are not
+  // already one and the same.
+  void merge(StratifiedIndex Idx1, StratifiedIndex Idx2) {
+    assert(inbounds(Idx1) && inbounds(Idx2));
+    assert(&linksAt(Idx1) != &linksAt(Idx2) &&
+           "Merging a set into itself is not allowed");
+
+    // CASE 1: If the set at `Idx1` is above or below `Idx2`, we need to merge
+    // both of the given sets, and all sets between them, into one.
+    if (tryMergeUpwards(Idx1, Idx2))
+      return;
+
+    if (tryMergeUpwards(Idx2, Idx1))
+      return;
+
+    // CASE 2: The set at `Idx1` is not in the same chain as the set at `Idx2`.
+    // We therefore need to merge the two chains together.
+    mergeDirect(Idx1, Idx2);
+  }
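
The remap chains that linksAt chases are ordinary union-find forwarding pointers, and its second loop is path compression. A minimal standalone version over plain indices, where Parent[I] == I plays the role of !isRemapped() (findLive is a hypothetical name, not an LLVM API):

    #include <cstddef>
    #include <vector>

    // Follow forwarding pointers to the live set, then point every node on
    // the walked chain directly at it so later lookups are near O(1).
    std::size_t findLive(std::vector<std::size_t> &Parent, std::size_t I) {
      std::size_t Root = I;
      while (Parent[Root] != Root)
        Root = Parent[Root];
      while (Parent[I] != Root) {
        std::size_t Next = Parent[I];
        Parent[I] = Root;
        I = Next;
      }
      return Root;
    }

The compression is what keeps repeated linksAt calls cheap; the merge routines below lean on that by calling it freely.
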
+
+  // \brief Merges two sets assuming that the set at `Idx1` is unreachable from
+  // traversing above or below the set at `Idx2`.
+  void mergeDirect(StratifiedIndex Idx1, StratifiedIndex Idx2) {
+    assert(inbounds(Idx1) && inbounds(Idx2));
+
+    auto *LinksInto = &linksAt(Idx1);
+    auto *LinksFrom = &linksAt(Idx2);
+    // Merging everything above LinksInto then proceeding to merge everything
+    // below LinksInto becomes problematic, so we go as far "up" as possible!
+    while (LinksInto->hasAbove() && LinksFrom->hasAbove()) {
+      LinksInto = &linksAt(LinksInto->getAbove());
+      LinksFrom = &linksAt(LinksFrom->getAbove());
+    }
+
+    if (LinksFrom->hasAbove()) {
+      LinksInto->setAbove(LinksFrom->getAbove());
+      auto &NewAbove = linksAt(LinksInto->getAbove());
+      NewAbove.setBelow(LinksInto->Number);
+    }
+
+    // Merging strategy:
+    //  > If neither has links below, stop.
+    //  > If only `LinksInto` has links below, stop.
+    //  > If only `LinksFrom` has links below, reset `LinksInto.Below` to
+    //    match `LinksFrom.Below`.
+    //  > If both have links below, deal with those next.
+    while (LinksInto->hasBelow() && LinksFrom->hasBelow()) {
+      auto &FromAttrs = LinksFrom->getAttrs();
+      LinksInto->setAttrs(FromAttrs);
+
+      // Remap needs to happen after getBelow(), but before the
+      // reassignment of LinksFrom.
+      auto *NewLinksFrom = &linksAt(LinksFrom->getBelow());
+      LinksFrom->remapTo(LinksInto->Number);
+      LinksFrom = NewLinksFrom;
+      LinksInto = &linksAt(LinksInto->getBelow());
+    }
+
+    if (LinksFrom->hasBelow()) {
+      LinksInto->setBelow(LinksFrom->getBelow());
+      auto &NewBelow = linksAt(LinksInto->getBelow());
+      NewBelow.setAbove(LinksInto->Number);
+    }
+
+    LinksFrom->remapTo(LinksInto->Number);
+  }
+
+  // \brief Checks to see if LowerIndex is at a level lower than UpperIndex.
+  // If so, it will merge LowerIndex with UpperIndex (and all of the sets
+  // between) and return true. Otherwise, it will return false.
+  bool tryMergeUpwards(StratifiedIndex LowerIndex, StratifiedIndex UpperIndex) {
+    assert(inbounds(LowerIndex) && inbounds(UpperIndex));
+    auto *Lower = &linksAt(LowerIndex);
+    auto *Upper = &linksAt(UpperIndex);
+    if (Lower == Upper)
+      return true;
+
+    SmallVector<BuilderLink *, 8> Found;
+    auto *Current = Lower;
+    auto Attrs = Current->getAttrs();
+    while (Current->hasAbove() && Current != Upper) {
+      Found.push_back(Current);
+      Attrs |= Current->getAttrs();
+      Current = &linksAt(Current->getAbove());
+    }
+
+    if (Current != Upper)
+      return false;
+
+    Upper->setAttrs(Attrs);
+
+    if (Lower->hasBelow()) {
+      auto NewBelowIndex = Lower->getBelow();
+      Upper->setBelow(NewBelowIndex);
+      auto &NewBelow = linksAt(NewBelowIndex);
+      NewBelow.setAbove(UpperIndex);
+    } else {
+      Upper->clearBelow();
+    }
+
+    for (const auto &Ptr : Found)
+      Ptr->remapTo(Upper->Number);
+
+    return true;
+  }
+
+  Optional<const StratifiedInfo *> get(const T &Val) const {
+    auto Result = Values.find(Val);
+    if (Result == Values.end())
+      return NoneType();
+    return &Result->second;
+  }
+
+  Optional<StratifiedInfo *> get(const T &Val) {
+    auto Result = Values.find(Val);
+    if (Result == Values.end())
+      return NoneType();
+    return &Result->second;
+  }
+
+  Optional<StratifiedIndex> indexOf(const T &Val) {
+    auto MaybeVal = get(Val);
+    if (!MaybeVal.hasValue())
+      return NoneType();
+    auto *Info = *MaybeVal;
+    auto &Link = linksAt(Info->Index);
+    return Link.Number;
+  }
+
+  StratifiedIndex addLinkBelow(StratifiedIndex Set) {
+    auto At = addLinks();
+    Links[Set].setBelow(At);
+    Links[At].setAbove(Set);
+    return At;
+  }
+
+  StratifiedIndex addLinkAbove(StratifiedIndex Set) {
+    auto At = addLinks();
+    Links[At].setBelow(Set);
+    Links[Set].setAbove(At);
+    return At;
+  }
+
+  StratifiedIndex getNewUnlinkedIndex() { return addLinks(); }
+
+  StratifiedIndex addLinks() {
+    auto Link = Links.size();
+    Links.push_back(BuilderLink(Link));
+    return Link;
+  }
+
+  bool inbounds(StratifiedIndex N) const { return N < Links.size(); }
+};
+}
+#endif // LLVM_ADT_STRATIFIEDSETS_H
diff --git a/lib/Analysis/TargetTransformInfo.cpp b/lib/Analysis/TargetTransformInfo.cpp
index cdb0b79..c1ffb9d 100644
--- a/lib/Analysis/TargetTransformInfo.cpp
+++ b/lib/Analysis/TargetTransformInfo.cpp
@@ -87,9 +87,10 @@ bool TargetTransformInfo::isLoweredToCall(const Function *F) const {
   return PrevTTI->isLoweredToCall(F);
 }
 
-void TargetTransformInfo::getUnrollingPreferences(Loop *L,
-                                                  UnrollingPreferences &UP) const {
-  PrevTTI->getUnrollingPreferences(L, UP);
+void
+TargetTransformInfo::getUnrollingPreferences(const Function *F, Loop *L, + UnrollingPreferences &UP) const { + PrevTTI->getUnrollingPreferences(F, L, UP); } bool TargetTransformInfo::isLegalAddImmediate(int64_t Imm) const { @@ -167,15 +168,16 @@ unsigned TargetTransformInfo::getRegisterBitWidth(bool Vector) const { return PrevTTI->getRegisterBitWidth(Vector); } -unsigned TargetTransformInfo::getMaximumUnrollFactor() const { - return PrevTTI->getMaximumUnrollFactor(); +unsigned TargetTransformInfo::getMaxInterleaveFactor() const { + return PrevTTI->getMaxInterleaveFactor(); } -unsigned TargetTransformInfo::getArithmeticInstrCost(unsigned Opcode, - Type *Ty, - OperandValueKind Op1Info, - OperandValueKind Op2Info) const { - return PrevTTI->getArithmeticInstrCost(Opcode, Ty, Op1Info, Op2Info); +unsigned TargetTransformInfo::getArithmeticInstrCost( + unsigned Opcode, Type *Ty, OperandValueKind Op1Info, + OperandValueKind Op2Info, OperandValueProperties Opd1PropInfo, + OperandValueProperties Opd2PropInfo) const { + return PrevTTI->getArithmeticInstrCost(Opcode, Ty, Op1Info, Op2Info, + Opd1PropInfo, Opd2PropInfo); } unsigned TargetTransformInfo::getShuffleCost(ShuffleKind Kind, Type *Tp, @@ -230,6 +232,11 @@ unsigned TargetTransformInfo::getReductionCost(unsigned Opcode, Type *Ty, return PrevTTI->getReductionCost(Opcode, Ty, IsPairwise); } +unsigned TargetTransformInfo::getCostOfKeepingLiveOverCall(ArrayRef<Type*> Tys) + const { + return PrevTTI->getCostOfKeepingLiveOverCall(Tys); +} + namespace { struct NoTTI final : ImmutablePass, TargetTransformInfo { @@ -239,7 +246,7 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { initializeNoTTIPass(*PassRegistry::getPassRegistry()); } - virtual void initializePass() override { + void initializePass() override { // Note that this subclass is special, and must *not* call initializeTTI as // it does not chain. TopTTI = this; @@ -248,7 +255,7 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { DL = DLP ? &DLP->getDataLayout() : nullptr; } - virtual void getAnalysisUsage(AnalysisUsage &AU) const override { + void getAnalysisUsage(AnalysisUsage &AU) const override { // Note that this subclass is special, and must *not* call // TTI::getAnalysisUsage as it breaks the recursion. } @@ -257,7 +264,7 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { static char ID; /// Provide necessary pointer adjustments for the two base classes. - virtual void *getAdjustedAnalysisPointer(const void *ID) override { + void *getAdjustedAnalysisPointer(const void *ID) override { if (ID == &TargetTransformInfo::ID) return (TargetTransformInfo*)this; return this; @@ -385,6 +392,8 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { // FIXME: This is wrong for libc intrinsics. return TCC_Basic; + case Intrinsic::annotation: + case Intrinsic::assume: case Intrinsic::dbg_declare: case Intrinsic::dbg_value: case Intrinsic::invariant_start: @@ -466,6 +475,8 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { // These will all likely lower to a single selection DAG node. 
if (Name == "copysign" || Name == "copysignf" || Name == "copysignl" || Name == "fabs" || Name == "fabsf" || Name == "fabsl" || Name == "sin" || + Name == "fmin" || Name == "fminf" || Name == "fminl" || + Name == "fmax" || Name == "fmaxf" || Name == "fmaxl" || Name == "sinf" || Name == "sinl" || Name == "cos" || Name == "cosf" || Name == "cosl" || Name == "sqrt" || Name == "sqrtf" || Name == "sqrtl") return false; @@ -480,8 +491,8 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { return true; } - void getUnrollingPreferences(Loop *, UnrollingPreferences &) const override { - } + void getUnrollingPreferences(const Function *, Loop *, + UnrollingPreferences &) const override {} bool isLegalAddImmediate(int64_t Imm) const override { return false; @@ -558,12 +569,13 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { return 32; } - unsigned getMaximumUnrollFactor() const override { + unsigned getMaxInterleaveFactor() const override { return 1; } unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty, OperandValueKind, - OperandValueKind) const override { + OperandValueKind, OperandValueProperties, + OperandValueProperties) const override { return 1; } @@ -612,6 +624,11 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { unsigned getReductionCost(unsigned, Type *, bool) const override { return 1; } + + unsigned getCostOfKeepingLiveOverCall(ArrayRef<Type*> Tys) const override { + return 0; + } + }; } // end anonymous namespace diff --git a/lib/Analysis/TypeBasedAliasAnalysis.cpp b/lib/Analysis/TypeBasedAliasAnalysis.cpp index f36f6f8..f347eb5 100644 --- a/lib/Analysis/TypeBasedAliasAnalysis.cpp +++ b/lib/Analysis/TypeBasedAliasAnalysis.cpp @@ -454,9 +454,9 @@ TypeBasedAliasAnalysis::alias(const Location &LocA, // Get the attached MDNodes. If either value lacks a tbaa MDNode, we must // be conservative. - const MDNode *AM = LocA.TBAATag; + const MDNode *AM = LocA.AATags.TBAA; if (!AM) return AliasAnalysis::alias(LocA, LocB); - const MDNode *BM = LocB.TBAATag; + const MDNode *BM = LocB.AATags.TBAA; if (!BM) return AliasAnalysis::alias(LocA, LocB); // If they may alias, chain to the next AliasAnalysis. 
@@ -472,7 +472,7 @@ bool TypeBasedAliasAnalysis::pointsToConstantMemory(const Location &Loc, if (!EnableTBAA) return AliasAnalysis::pointsToConstantMemory(Loc, OrLocal); - const MDNode *M = Loc.TBAATag; + const MDNode *M = Loc.AATags.TBAA; if (!M) return AliasAnalysis::pointsToConstantMemory(Loc, OrLocal); // If this is an "immutable" type, we can assume the pointer is pointing @@ -513,9 +513,9 @@ TypeBasedAliasAnalysis::getModRefInfo(ImmutableCallSite CS, if (!EnableTBAA) return AliasAnalysis::getModRefInfo(CS, Loc); - if (const MDNode *L = Loc.TBAATag) + if (const MDNode *L = Loc.AATags.TBAA) if (const MDNode *M = - CS.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) + CS.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) if (!Aliases(L, M)) return NoModRef; @@ -529,9 +529,9 @@ TypeBasedAliasAnalysis::getModRefInfo(ImmutableCallSite CS1, return AliasAnalysis::getModRefInfo(CS1, CS2); if (const MDNode *M1 = - CS1.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) + CS1.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) if (const MDNode *M2 = - CS2.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) + CS2.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) if (!Aliases(M1, M2)) return NoModRef; @@ -611,3 +611,24 @@ MDNode *MDNode::getMostGenericTBAA(MDNode *A, MDNode *B) { Value *Ops[3] = { Ret, Ret, ConstantInt::get(Int64, 0) }; return MDNode::get(A->getContext(), Ops); } + +void Instruction::getAAMetadata(AAMDNodes &N, bool Merge) const { + if (Merge) + N.TBAA = + MDNode::getMostGenericTBAA(N.TBAA, getMetadata(LLVMContext::MD_tbaa)); + else + N.TBAA = getMetadata(LLVMContext::MD_tbaa); + + if (Merge) + N.Scope = + MDNode::intersect(N.Scope, getMetadata(LLVMContext::MD_alias_scope)); + else + N.Scope = getMetadata(LLVMContext::MD_alias_scope); + + if (Merge) + N.NoAlias = + MDNode::intersect(N.NoAlias, getMetadata(LLVMContext::MD_noalias)); + else + N.NoAlias = getMetadata(LLVMContext::MD_noalias); +} + diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index 5264745..e9bbf83 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" @@ -20,6 +21,7 @@ #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalVariable.h" @@ -29,6 +31,7 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include <cstring> using namespace llvm; @@ -36,8 +39,8 @@ using namespace llvm::PatternMatch; const unsigned MaxDepth = 6; -/// getBitWidth - Returns the bitwidth of the given scalar or pointer type (if -/// unknown returns 0). For vector types, returns the element type's bitwidth. +/// Returns the bitwidth of the given scalar or pointer type (if unknown returns +/// 0). For vector types, returns the element type's bitwidth. static unsigned getBitWidth(Type *Ty, const DataLayout *TD) { if (unsigned BitWidth = Ty->getScalarSizeInBits()) return BitWidth; @@ -45,10 +48,125 @@ static unsigned getBitWidth(Type *Ty, const DataLayout *TD) { return TD ? 
TD->getPointerTypeSizeInBits(Ty) : 0; } +// Many of these functions have internal versions that take an assumption +// exclusion set. This is because of the potential for mutual recursion to +// cause computeKnownBits to repeatedly visit the same assume intrinsic. The +// classic case of this is assume(x = y), which will attempt to determine +// bits in x from bits in y, which will attempt to determine bits in y from +// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call +// isKnownNonZero, which calls computeKnownBits and ComputeSignBit and +// isKnownToBeAPowerOfTwo (all of which can call computeKnownBits), and so on. +typedef SmallPtrSet<const Value *, 8> ExclInvsSet; + +namespace { +// Simplifying using an assume can only be done in a particular control-flow +// context (the context instruction provides that context). If an assume and +// the context instruction are not in the same block then the DT helps in +// figuring out if we can use it. +struct Query { + ExclInvsSet ExclInvs; + AssumptionTracker *AT; + const Instruction *CxtI; + const DominatorTree *DT; + + Query(AssumptionTracker *AT = nullptr, const Instruction *CxtI = nullptr, + const DominatorTree *DT = nullptr) + : AT(AT), CxtI(CxtI), DT(DT) {} + + Query(const Query &Q, const Value *NewExcl) + : ExclInvs(Q.ExclInvs), AT(Q.AT), CxtI(Q.CxtI), DT(Q.DT) { + ExclInvs.insert(NewExcl); + } +}; +} // end anonymous namespace + +// Given the provided Value and, potentially, a context instruction, return +// the preferred context instruction (if any). +static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) { + // If we've been provided with a context instruction, then use that (provided + // it has been inserted). + if (CxtI && CxtI->getParent()) + return CxtI; + + // If the value is really an already-inserted instruction, then use that. 
+ CxtI = dyn_cast<Instruction>(V); + if (CxtI && CxtI->getParent()) + return CxtI; + + return nullptr; +} + +static void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, + const DataLayout *TD, unsigned Depth, + const Query &Q); + +void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, + const DataLayout *TD, unsigned Depth, + AssumptionTracker *AT, const Instruction *CxtI, + const DominatorTree *DT) { + ::computeKnownBits(V, KnownZero, KnownOne, TD, Depth, + Query(AT, safeCxtI(V, CxtI), DT)); +} + +static void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + const DataLayout *TD, unsigned Depth, + const Query &Q); + +void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + const DataLayout *TD, unsigned Depth, + AssumptionTracker *AT, const Instruction *CxtI, + const DominatorTree *DT) { + ::ComputeSignBit(V, KnownZero, KnownOne, TD, Depth, + Query(AT, safeCxtI(V, CxtI), DT)); +} + +static bool isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, + const Query &Q); + +bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, + AssumptionTracker *AT, + const Instruction *CxtI, + const DominatorTree *DT) { + return ::isKnownToBeAPowerOfTwo(V, OrZero, Depth, + Query(AT, safeCxtI(V, CxtI), DT)); +} + +static bool isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth, + const Query &Q); + +bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth, + AssumptionTracker *AT, const Instruction *CxtI, + const DominatorTree *DT) { + return ::isKnownNonZero(V, TD, Depth, Query(AT, safeCxtI(V, CxtI), DT)); +} + +static bool MaskedValueIsZero(Value *V, const APInt &Mask, + const DataLayout *TD, unsigned Depth, + const Query &Q); + +bool llvm::MaskedValueIsZero(Value *V, const APInt &Mask, + const DataLayout *TD, unsigned Depth, + AssumptionTracker *AT, const Instruction *CxtI, + const DominatorTree *DT) { + return ::MaskedValueIsZero(V, Mask, TD, Depth, + Query(AT, safeCxtI(V, CxtI), DT)); +} + +static unsigned ComputeNumSignBits(Value *V, const DataLayout *TD, + unsigned Depth, const Query &Q); + +unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, + unsigned Depth, AssumptionTracker *AT, + const Instruction *CxtI, + const DominatorTree *DT) { + return ::ComputeNumSignBits(V, TD, Depth, Query(AT, safeCxtI(V, CxtI), DT)); +} + static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2, - const DataLayout *TD, unsigned Depth) { + const DataLayout *TD, unsigned Depth, + const Query &Q) { if (!Add) { if (ConstantInt *CLHS = dyn_cast<ConstantInt>(Op0)) { // We know that the top bits of C-X are clear if X contains less bits @@ -59,7 +177,7 @@ static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, unsigned NLZ = (CLHS->getValue()+1).countLeadingZeros(); // NLZ can't be BitWidth with no sign bit APInt MaskV = APInt::getHighBitsSet(BitWidth, NLZ+1); - llvm::computeKnownBits(Op1, KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(Op1, KnownZero2, KnownOne2, TD, Depth+1, Q); // If all of the MaskV bits are known to be zero, then we know the // output top bits are zero, because we now know that the output is @@ -75,55 +193,51 @@ static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, unsigned BitWidth = KnownZero.getBitWidth(); - // If one of the operands has trailing zeros, then the bits that the - // other operand has in those bit positions will be 
preserved in the - // result. For an add, this works with either operand. For a subtract, - // this only works if the known zeros are in the right operand. + // If an initial sequence of bits in the result is not needed, the + // corresponding bits in the operands are not needed. APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - llvm::computeKnownBits(Op0, LHSKnownZero, LHSKnownOne, TD, Depth+1); - unsigned LHSKnownZeroOut = LHSKnownZero.countTrailingOnes(); - - llvm::computeKnownBits(Op1, KnownZero2, KnownOne2, TD, Depth+1); - unsigned RHSKnownZeroOut = KnownZero2.countTrailingOnes(); - - // Determine which operand has more trailing zeros, and use that - // many bits from the other operand. - if (LHSKnownZeroOut > RHSKnownZeroOut) { - if (Add) { - APInt Mask = APInt::getLowBitsSet(BitWidth, LHSKnownZeroOut); - KnownZero |= KnownZero2 & Mask; - KnownOne |= KnownOne2 & Mask; - } else { - // If the known zeros are in the left operand for a subtract, - // fall back to the minimum known zeros in both operands. - KnownZero |= APInt::getLowBitsSet(BitWidth, - std::min(LHSKnownZeroOut, - RHSKnownZeroOut)); - } - } else if (RHSKnownZeroOut >= LHSKnownZeroOut) { - APInt Mask = APInt::getLowBitsSet(BitWidth, RHSKnownZeroOut); - KnownZero |= LHSKnownZero & Mask; - KnownOne |= LHSKnownOne & Mask; + computeKnownBits(Op0, LHSKnownZero, LHSKnownOne, TD, Depth+1, Q); + computeKnownBits(Op1, KnownZero2, KnownOne2, TD, Depth+1, Q); + + // Carry in a 1 for a subtract, rather than a 0. + APInt CarryIn(BitWidth, 0); + if (!Add) { + // Sum = LHS + ~RHS + 1 + std::swap(KnownZero2, KnownOne2); + CarryIn.setBit(0); } + APInt PossibleSumZero = ~LHSKnownZero + ~KnownZero2 + CarryIn; + APInt PossibleSumOne = LHSKnownOne + KnownOne2 + CarryIn; + + // Compute known bits of the carry. + APInt CarryKnownZero = ~(PossibleSumZero ^ LHSKnownZero ^ KnownZero2); + APInt CarryKnownOne = PossibleSumOne ^ LHSKnownOne ^ KnownOne2; + + // Compute set of known bits (where all three relevant bits are known). + APInt LHSKnown = LHSKnownZero | LHSKnownOne; + APInt RHSKnown = KnownZero2 | KnownOne2; + APInt CarryKnown = CarryKnownZero | CarryKnownOne; + APInt Known = LHSKnown & RHSKnown & CarryKnown; + + assert((PossibleSumZero & Known) == (PossibleSumOne & Known) && + "known bits of sum differ"); + + // Compute known bits of the result. + KnownZero = ~PossibleSumOne & Known; + KnownOne = PossibleSumOne & Known; + // Are we still trying to solve for the sign bit? - if (!KnownZero.isNegative() && !KnownOne.isNegative()) { + if (!Known.isNegative()) { if (NSW) { - if (Add) { - // Adding two positive numbers can't wrap into negative - if (LHSKnownZero.isNegative() && KnownZero2.isNegative()) - KnownZero |= APInt::getSignBit(BitWidth); - // and adding two negative numbers can't wrap into positive. - else if (LHSKnownOne.isNegative() && KnownOne2.isNegative()) - KnownOne |= APInt::getSignBit(BitWidth); - } else { - // Subtracting a negative number from a positive one can't wrap - if (LHSKnownZero.isNegative() && KnownOne2.isNegative()) - KnownZero |= APInt::getSignBit(BitWidth); - // neither can subtracting a positive number from a negative one. - else if (LHSKnownOne.isNegative() && KnownZero2.isNegative()) - KnownOne |= APInt::getSignBit(BitWidth); - } + // Adding two non-negative numbers, or subtracting a negative number from + // a non-negative one, can't wrap into negative. 
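
The carry trick above is easiest to see on concrete masks. Below is a standalone sketch for addition only, mirroring the PossibleSumZero/PossibleSumOne computation with 32-bit integers (KnownBits32 and knownBitsOfAdd are illustrative names; the patch additionally swaps and inverts the RHS masks and feeds in a carry of 1 to model subtraction). The NSW sign-bit handling then continues below.

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    struct KnownBits32 { uint32_t Zero, One; }; // known-0 / known-1 masks

    KnownBits32 knownBitsOfAdd(KnownBits32 L, KnownBits32 R) {
      // Largest/smallest feasible sums: unknown bits taken as ones (via
      // ~Zero) or as zeros (via One). Unsigned wrap-around is intended.
      uint32_t PossibleSumZero = ~L.Zero + ~R.Zero;
      uint32_t PossibleSumOne = L.One + R.One;

      // A column's carry-in is known when it agrees across both extremes.
      uint32_t CarryKnownZero = ~(PossibleSumZero ^ L.Zero ^ R.Zero);
      uint32_t CarryKnownOne = PossibleSumOne ^ L.One ^ R.One;

      // A result bit is known where both operand bits and the carry are known.
      uint32_t Known = (L.Zero | L.One) & (R.Zero | R.One) &
                       (CarryKnownZero | CarryKnownOne);
      assert((PossibleSumZero & Known) == (PossibleSumOne & Known));

      return KnownBits32{~PossibleSumOne & Known, PossibleSumOne & Known};
    }

    int main() {
      KnownBits32 X{0x3u, 0x0u};  // low two bits known 0, rest unknown
      KnownBits32 Y{~0x3u, 0x3u}; // known to be exactly 3
      KnownBits32 S = knownBitsOfAdd(X, Y);
      std::printf("Zero=%08x One=%08x\n", S.Zero, S.One); // low two bits known 1
      return 0;
    }
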
+ if (LHSKnownZero.isNegative() && KnownZero2.isNegative()) + KnownZero |= APInt::getSignBit(BitWidth); + // Adding two negative numbers, or subtracting a non-negative number from + // a negative one, can't wrap into non-negative. + else if (LHSKnownOne.isNegative() && KnownOne2.isNegative()) + KnownOne |= APInt::getSignBit(BitWidth); } } } @@ -131,10 +245,11 @@ static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, static void computeKnownBitsMul(Value *Op0, Value *Op1, bool NSW, APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2, - const DataLayout *TD, unsigned Depth) { + const DataLayout *TD, unsigned Depth, + const Query &Q) { unsigned BitWidth = KnownZero.getBitWidth(); - computeKnownBits(Op1, KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(Op0, KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(Op1, KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(Op0, KnownZero2, KnownOne2, TD, Depth+1, Q); bool isKnownNegative = false; bool isKnownNonNegative = false; @@ -155,9 +270,9 @@ static void computeKnownBitsMul(Value *Op0, Value *Op1, bool NSW, // negative or zero. if (!isKnownNonNegative) isKnownNegative = (isKnownNegativeOp1 && isKnownNonNegativeOp0 && - isKnownNonZero(Op0, TD, Depth)) || + isKnownNonZero(Op0, TD, Depth, Q)) || (isKnownNegativeOp0 && isKnownNonNegativeOp1 && - isKnownNonZero(Op1, TD, Depth)); + isKnownNonZero(Op1, TD, Depth, Q)); } } @@ -209,6 +324,410 @@ void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, KnownZero = APInt::getHighBitsSet(BitWidth, MinLeadingZeros); } +static bool isEphemeralValueOf(Instruction *I, const Value *E) { + SmallVector<const Value *, 16> WorkSet(1, I); + SmallPtrSet<const Value *, 32> Visited; + SmallPtrSet<const Value *, 16> EphValues; + + while (!WorkSet.empty()) { + const Value *V = WorkSet.pop_back_val(); + if (!Visited.insert(V).second) + continue; + + // If all uses of this value are ephemeral, then so is this value. + bool FoundNEUse = false; + for (const User *I : V->users()) + if (!EphValues.count(I)) { + FoundNEUse = true; + break; + } + + if (!FoundNEUse) { + if (V == E) + return true; + + EphValues.insert(V); + if (const User *U = dyn_cast<User>(V)) + for (User::const_op_iterator J = U->op_begin(), JE = U->op_end(); + J != JE; ++J) { + if (isSafeToSpeculativelyExecute(*J)) + WorkSet.push_back(*J); + } + } + } + + return false; +} + +// Is this an intrinsic that cannot be speculated but also cannot trap? +static bool isAssumeLikeIntrinsic(const Instruction *I) { + if (const CallInst *CI = dyn_cast<CallInst>(I)) + if (Function *F = CI->getCalledFunction()) + switch (F->getIntrinsicID()) { + default: break; + // FIXME: This list is repeated from NoTTI::getIntrinsicCost. + case Intrinsic::assume: + case Intrinsic::dbg_declare: + case Intrinsic::dbg_value: + case Intrinsic::invariant_start: + case Intrinsic::invariant_end: + case Intrinsic::lifetime_start: + case Intrinsic::lifetime_end: + case Intrinsic::objectsize: + case Intrinsic::ptr_annotation: + case Intrinsic::var_annotation: + return true; + } + + return false; +} + +static bool isValidAssumeForContext(Value *V, const Query &Q, + const DataLayout *DL) { + Instruction *Inv = cast<Instruction>(V); + + // There are two restrictions on the use of an assume: + // 1. The assume must dominate the context (or the control flow must + // reach the assume whenever it reaches the context). + // 2. 
The context must not be in the assume's set of ephemeral values + // (otherwise we will use the assume to prove that the condition + // feeding the assume is trivially true, thus causing the removal of + // the assume). + + if (Q.DT) { + if (Q.DT->dominates(Inv, Q.CxtI)) { + return true; + } else if (Inv->getParent() == Q.CxtI->getParent()) { + // The context comes first, but they're both in the same block. Make sure + // there is nothing in between that might interrupt the control flow. + for (BasicBlock::const_iterator I = + std::next(BasicBlock::const_iterator(Q.CxtI)), + IE(Inv); I != IE; ++I) + if (!isSafeToSpeculativelyExecute(I, DL) && + !isAssumeLikeIntrinsic(I)) + return false; + + return !isEphemeralValueOf(Inv, Q.CxtI); + } + + return false; + } + + // When we don't have a DT, we do a limited search... + if (Inv->getParent() == Q.CxtI->getParent()->getSinglePredecessor()) { + return true; + } else if (Inv->getParent() == Q.CxtI->getParent()) { + // Search forward from the assume until we reach the context (or the end + // of the block); the common case is that the assume will come first. + for (BasicBlock::iterator I = std::next(BasicBlock::iterator(Inv)), + IE = Inv->getParent()->end(); I != IE; ++I) + if (I == Q.CxtI) + return true; + + // The context must come first... + for (BasicBlock::const_iterator I = + std::next(BasicBlock::const_iterator(Q.CxtI)), + IE(Inv); I != IE; ++I) + if (!isSafeToSpeculativelyExecute(I, DL) && + !isAssumeLikeIntrinsic(I)) + return false; + + return !isEphemeralValueOf(Inv, Q.CxtI); + } + + return false; +} + +bool llvm::isValidAssumeForContext(const Instruction *I, + const Instruction *CxtI, + const DataLayout *DL, + const DominatorTree *DT) { + return ::isValidAssumeForContext(const_cast<Instruction*>(I), + Query(nullptr, CxtI, DT), DL); +} + +template<typename LHS, typename RHS> +inline match_combine_or<CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>, + CmpClass_match<RHS, LHS, ICmpInst, ICmpInst::Predicate>> +m_c_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) { + return m_CombineOr(m_ICmp(Pred, L, R), m_ICmp(Pred, R, L)); +} + +template<typename LHS, typename RHS> +inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::And>, + BinaryOp_match<RHS, LHS, Instruction::And>> +m_c_And(const LHS &L, const RHS &R) { + return m_CombineOr(m_And(L, R), m_And(R, L)); +} + +template<typename LHS, typename RHS> +inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::Or>, + BinaryOp_match<RHS, LHS, Instruction::Or>> +m_c_Or(const LHS &L, const RHS &R) { + return m_CombineOr(m_Or(L, R), m_Or(R, L)); +} + +template<typename LHS, typename RHS> +inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::Xor>, + BinaryOp_match<RHS, LHS, Instruction::Xor>> +m_c_Xor(const LHS &L, const RHS &R) { + return m_CombineOr(m_Xor(L, R), m_Xor(R, L)); +} + +static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, + APInt &KnownOne, + const DataLayout *DL, + unsigned Depth, const Query &Q) { + // Use of assumptions is context-sensitive. If we don't have a context, we + // cannot use them! 
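
Each matcher arm below reduces to a small mask rule once the RHS's known bits are computed. A standalone sketch of the first two rules, assume(v == a) and assume((v & b) == a), over plain 32-bit masks (applyAssumeEq and applyAssumeMaskedEq are illustrative helpers, not LLVM APIs); the real code performs the same updates on APInt, guarded by isValidAssumeForContext:

    #include <cstdint>

    struct KnownBits32 { uint32_t Zero, One; };

    // assume(v == a): every bit already known in a becomes known in v.
    void applyAssumeEq(KnownBits32 &V, const KnownBits32 &A) {
      V.Zero |= A.Zero;
      V.One |= A.One;
    }

    // assume((v & b) == a): only columns where the mask b is known-one say
    // anything about v itself.
    void applyAssumeMaskedEq(KnownBits32 &V, const KnownBits32 &A,
                             const KnownBits32 &B) {
      V.Zero |= A.Zero & B.One;
      V.One |= A.One & B.One;
    }
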
+ if (!Q.AT || !Q.CxtI) + return; + + unsigned BitWidth = KnownZero.getBitWidth(); + + Function *F = const_cast<Function*>(Q.CxtI->getParent()->getParent()); + for (auto &CI : Q.AT->assumptions(F)) { + CallInst *I = CI; + if (Q.ExclInvs.count(I)) + continue; + + if (match(I, m_Intrinsic<Intrinsic::assume>(m_Specific(V))) && + isValidAssumeForContext(I, Q, DL)) { + assert(BitWidth == 1 && "assume operand is not i1?"); + KnownZero.clearAllBits(); + KnownOne.setAllBits(); + return; + } + + Value *A, *B; + auto m_V = m_CombineOr(m_Specific(V), + m_CombineOr(m_PtrToInt(m_Specific(V)), + m_BitCast(m_Specific(V)))); + + CmpInst::Predicate Pred; + ConstantInt *C; + // assume(v = a) + if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_V, m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + KnownZero |= RHSKnownZero; + KnownOne |= RHSKnownOne; + // assume(v & b = a) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt MaskKnownZero(BitWidth, 0), MaskKnownOne(BitWidth, 0); + computeKnownBits(B, MaskKnownZero, MaskKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in the mask that are known to be one, we can propagate + // known bits from the RHS to V. + KnownZero |= RHSKnownZero & MaskKnownOne; + KnownOne |= RHSKnownOne & MaskKnownOne; + // assume(~(v & b) = a) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))), + m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt MaskKnownZero(BitWidth, 0), MaskKnownOne(BitWidth, 0); + computeKnownBits(B, MaskKnownZero, MaskKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in the mask that are known to be one, we can propagate + // inverted known bits from the RHS to V. + KnownZero |= RHSKnownOne & MaskKnownOne; + KnownOne |= RHSKnownZero & MaskKnownOne; + // assume(v | b = a) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); + computeKnownBits(B, BKnownZero, BKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in B that are known to be zero, we can propagate known + // bits from the RHS to V. 
+ KnownZero |= RHSKnownZero & BKnownZero; + KnownOne |= RHSKnownOne & BKnownZero; + // assume(~(v | b) = a) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))), + m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); + computeKnownBits(B, BKnownZero, BKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in B that are known to be zero, we can propagate + // inverted known bits from the RHS to V. + KnownZero |= RHSKnownOne & BKnownZero; + KnownOne |= RHSKnownZero & BKnownZero; + // assume(v ^ b = a) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); + computeKnownBits(B, BKnownZero, BKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in B that are known to be zero, we can propagate known + // bits from the RHS to V. For those bits in B that are known to be one, + // we can propagate inverted known bits from the RHS to V. + KnownZero |= RHSKnownZero & BKnownZero; + KnownOne |= RHSKnownOne & BKnownZero; + KnownZero |= RHSKnownOne & BKnownOne; + KnownOne |= RHSKnownZero & BKnownOne; + // assume(~(v ^ b) = a) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))), + m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); + computeKnownBits(B, BKnownZero, BKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in B that are known to be zero, we can propagate + // inverted known bits from the RHS to V. For those bits in B that are + // known to be one, we can propagate known bits from the RHS to V. + KnownZero |= RHSKnownOne & BKnownZero; + KnownOne |= RHSKnownZero & BKnownZero; + KnownZero |= RHSKnownZero & BKnownOne; + KnownOne |= RHSKnownOne & BKnownOne; + // assume(v << c = a) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)), + m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them to known + // bits in V shifted to the right by C. + KnownZero |= RHSKnownZero.lshr(C->getZExtValue()); + KnownOne |= RHSKnownOne.lshr(C->getZExtValue()); + // assume(~(v << c) = a) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))), + m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them inverted + // to known bits in V shifted to the right by C. 
+ KnownZero |= RHSKnownOne.lshr(C->getZExtValue()); + KnownOne |= RHSKnownZero.lshr(C->getZExtValue()); + // assume(v >> c = a) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_CombineOr(m_LShr(m_V, m_ConstantInt(C)), + m_AShr(m_V, + m_ConstantInt(C))), + m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them to known + // bits in V shifted to the right by C. + KnownZero |= RHSKnownZero << C->getZExtValue(); + KnownOne |= RHSKnownOne << C->getZExtValue(); + // assume(~(v >> c) = a) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_Not(m_CombineOr( + m_LShr(m_V, m_ConstantInt(C)), + m_AShr(m_V, m_ConstantInt(C)))), + m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them inverted + // to known bits in V shifted to the right by C. + KnownZero |= RHSKnownOne << C->getZExtValue(); + KnownOne |= RHSKnownZero << C->getZExtValue(); + // assume(v >=_s c) where c is non-negative + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_ICmp(Pred, m_V, m_Value(A)))) && + Pred == ICmpInst::ICMP_SGE && + isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + + if (RHSKnownZero.isNegative()) { + // We know that the sign bit is zero. + KnownZero |= APInt::getSignBit(BitWidth); + } + // assume(v >_s c) where c is at least -1. + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_ICmp(Pred, m_V, m_Value(A)))) && + Pred == ICmpInst::ICMP_SGT && + isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + + if (RHSKnownOne.isAllOnesValue() || RHSKnownZero.isNegative()) { + // We know that the sign bit is zero. + KnownZero |= APInt::getSignBit(BitWidth); + } + // assume(v <=_s c) where c is negative + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_ICmp(Pred, m_V, m_Value(A)))) && + Pred == ICmpInst::ICMP_SLE && + isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + + if (RHSKnownOne.isNegative()) { + // We know that the sign bit is one. + KnownOne |= APInt::getSignBit(BitWidth); + } + // assume(v <_s c) where c is non-positive + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_ICmp(Pred, m_V, m_Value(A)))) && + Pred == ICmpInst::ICMP_SLT && + isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + + if (RHSKnownZero.isAllOnesValue() || RHSKnownOne.isNegative()) { + // We know that the sign bit is one. 
+ KnownOne |= APInt::getSignBit(BitWidth); + } + // assume(v <=_u c) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_ICmp(Pred, m_V, m_Value(A)))) && + Pred == ICmpInst::ICMP_ULE && + isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + + // Whatever high bits in c are zero are known to be zero. + KnownZero |= + APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()); + // assume(v <_u c) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_ICmp(Pred, m_V, m_Value(A)))) && + Pred == ICmpInst::ICMP_ULT && + isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + + // Whatever high bits in c are zero are known to be zero (if c is a power + // of 2, then one more). + if (isKnownToBeAPowerOfTwo(A, false, Depth+1, Query(Q, I))) + KnownZero |= + APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()+1); + else + KnownZero |= + APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()); + } + } +} + /// Determine which bits of V are known to be either zero or one and return /// them in the KnownZero/KnownOne bit sets. /// @@ -224,8 +743,9 @@ void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, /// where V is a vector, known zero, and known one values are the /// same width as the vector element, and the bit is set only if it is true /// for all of the elements in the vector. -void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, - const DataLayout *TD, unsigned Depth) { +void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, + const DataLayout *TD, unsigned Depth, + const Query &Q) { assert(V && "No Value?"); assert(Depth <= MaxDepth && "Limit Search Depth"); unsigned BitWidth = KnownZero.getBitWidth(); @@ -270,6 +790,17 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, return; } + // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has + // the bits of its aliasee. + if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { + if (GA->mayBeOverridden()) { + KnownZero.clearAllBits(); KnownOne.clearAllBits(); + } else { + computeKnownBits(GA->getAliasee(), KnownZero, KnownOne, TD, Depth+1, Q); + } + return; + } + // The address of an aligned GlobalValue has trailing zeros. if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) { unsigned Align = GV->getAlignment(); @@ -295,25 +826,11 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, KnownOne.clearAllBits(); return; } - // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has - // the bits of its aliasee. - if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { - if (GA->mayBeOverridden()) { - KnownZero.clearAllBits(); KnownOne.clearAllBits(); - } else { - computeKnownBits(GA->getAliasee(), KnownZero, KnownOne, TD, Depth+1); - } - return; - } if (Argument *A = dyn_cast<Argument>(V)) { - unsigned Align = 0; + unsigned Align = A->getType()->isPointerTy() ? A->getParamAlignment() : 0; - if (A->hasByValOrInAllocaAttr()) { - // Get alignment information off byval/inalloca arguments if specified in - // the IR. - Align = A->getParamAlignment(); - } else if (TD && A->hasStructRetAttr()) { + if (!Align && TD && A->hasStructRetAttr()) { // An sret parameter has at least the ABI alignment of the return type. 
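
Both the parameter-alignment path and the sret path feed the same fact: a pointer aligned to Align bytes has log2(Align) trailing zero bits. A one-function sketch of that translation (knownZeroLowBitsFromAlign is a made-up name); the sret leg continues below.

    #include <cassert>
    #include <cstdint>

    // Mask of low bits guaranteed zero for an Align-byte-aligned address,
    // e.g. Align = 8 yields 0b111.
    uint32_t knownZeroLowBitsFromAlign(uint32_t Align) {
      assert(Align != 0 && (Align & (Align - 1)) == 0 && "not a power of two");
      return Align - 1;
    }
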
Type *EltTy = cast<PointerType>(A->getType())->getElementType(); if (EltTy->isSized()) @@ -322,6 +839,10 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, if (Align) KnownZero = APInt::getLowBitsSet(BitWidth, countTrailingZeros(Align)); + + // Don't give up yet... there might be an assumption that provides more + // information... + computeKnownBitsFromAssume(V, KnownZero, KnownOne, TD, Depth, Q); return; } @@ -331,6 +852,9 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, if (Depth == MaxDepth) return; // Limit search depth. + // Check whether a nearby assume intrinsic can determine some known bits. + computeKnownBitsFromAssume(V, KnownZero, KnownOne, TD, Depth, Q); + Operator *I = dyn_cast<Operator>(V); if (!I) return; @@ -343,8 +867,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; case Instruction::And: { // If either the LHS or the RHS are Zero, the result is zero. - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1, Q); // Output known-1 bits are only known if set in both the LHS & RHS. KnownOne &= KnownOne2; @@ -353,8 +877,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; } case Instruction::Or: { - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1, Q); // Output known-0 bits are only known if clear in both the LHS & RHS. KnownZero &= KnownZero2; @@ -363,8 +887,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; } case Instruction::Xor: { - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1, Q); // Output known-0 bits are known if clear or set in both the LHS & RHS. APInt KnownZeroOut = (KnownZero & KnownZero2) | (KnownOne & KnownOne2); @@ -376,19 +900,20 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, case Instruction::Mul: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, - KnownZero, KnownOne, KnownZero2, KnownOne2, TD, Depth); + KnownZero, KnownOne, KnownZero2, KnownOne2, TD, + Depth, Q); break; } case Instruction::UDiv: { // For the purposes of computing leading zeros we can conservatively // treat a udiv as a logical right shift by the power of 2 known to // be less than the denominator. 
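
Numerically, the bound computed below says: if the numerator has NumLeadZ known leading zeros and the divisor has a known-one bit at position P, the divisor is at least 2^P, so the quotient gains P further leading zeros. A sketch under an invented helper name:

    #include <algorithm>

    unsigned knownLeadingZerosOfUDiv(unsigned BitWidth, unsigned NumLeadZ,
                                     unsigned HighestKnownOneOfDivisor) {
      return std::min(BitWidth, NumLeadZ + HighestKnownOneOfDivisor);
    }

    // e.g. BitWidth = 32, numerator < 2^16 (16 leading zeros), divisor has
    // bit 7 known one (so divisor >= 128): quotient < 2^9, 23 leading zeros.
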
- computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1, Q); unsigned LeadZ = KnownZero2.countLeadingOnes(); KnownOne2.clearAllBits(); KnownZero2.clearAllBits(); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1, Q); unsigned RHSUnknownLeadingOnes = KnownOne2.countLeadingZeros(); if (RHSUnknownLeadingOnes != BitWidth) LeadZ = std::min(BitWidth, @@ -398,9 +923,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; } case Instruction::Select: - computeKnownBits(I->getOperand(2), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, - Depth+1); + computeKnownBits(I->getOperand(2), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1, Q); // Only known if known in both the LHS and RHS. KnownOne &= KnownOne2; @@ -415,6 +939,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; // Can't work with floating point. case Instruction::PtrToInt: case Instruction::IntToPtr: + case Instruction::AddrSpaceCast: // Pointers could be different sizes. // We can't handle these if we don't know the pointer size. if (!TD) break; // FALL THROUGH and handle them the same as zext/trunc. @@ -435,7 +960,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, assert(SrcBitWidth && "SrcBitWidth can't be zero"); KnownZero = KnownZero.zextOrTrunc(SrcBitWidth); KnownOne = KnownOne.zextOrTrunc(SrcBitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero = KnownZero.zextOrTrunc(BitWidth); KnownOne = KnownOne.zextOrTrunc(BitWidth); // Any top bits are known to be zero. @@ -449,7 +974,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // TODO: For now, not handling conversions like: // (bitcast i64 %x to <2 x i32>) !I->getType()->isVectorTy()) { - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); break; } break; @@ -460,7 +985,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, KnownZero = KnownZero.trunc(SrcBitWidth); KnownOne = KnownOne.trunc(SrcBitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero = KnownZero.zext(BitWidth); KnownOne = KnownOne.zext(BitWidth); @@ -476,11 +1001,10 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // (shl X, C1) & C2 == 0 iff (X & C2 >>u C1) == 0 if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero <<= ShiftAmt; KnownOne <<= ShiftAmt; KnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt); // low bits known 0 - break; } break; case Instruction::LShr: @@ -490,12 +1014,11 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); // Unsigned shift right. 
- computeKnownBits(I->getOperand(0), KnownZero,KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); // high bits known zero. KnownZero |= APInt::getHighBitsSet(BitWidth, ShiftAmt); - break; } break; case Instruction::AShr: @@ -505,7 +1028,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); // Signed shift right. - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); @@ -514,21 +1037,20 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, KnownZero |= HighBits; else if (KnownOne[BitWidth-ShiftAmt-1]) // New bits are known one. KnownOne |= HighBits; - break; } break; case Instruction::Sub: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, KnownZero, KnownOne, KnownZero2, KnownOne2, TD, - Depth); + Depth, Q); break; } case Instruction::Add: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, KnownZero, KnownOne, KnownZero2, KnownOne2, TD, - Depth); + Depth, Q); break; } case Instruction::SRem: @@ -536,7 +1058,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, APInt RA = Rem->getValue().abs(); if (RA.isPowerOf2()) { APInt LowBits = RA - 1; - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, + Depth+1, Q); // The low bits of the first operand are unchanged by the srem. KnownZero = KnownZero2 & LowBits; @@ -561,7 +1084,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, if (KnownZero.isNonNegative()) { APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, TD, - Depth+1); + Depth+1, Q); // If it's known zero, our sign bit is also zero. if (LHSKnownZero.isNegative()) KnownZero.setBit(BitWidth - 1); @@ -574,7 +1097,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, if (RA.isPowerOf2()) { APInt LowBits = (RA - 1); computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, - Depth+1); + Depth+1, Q); KnownZero |= ~LowBits; KnownOne &= LowBits; break; @@ -583,8 +1106,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // Since the result is less than or equal to either operand, any leading // zero bits in either operand must also exist in the result. - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1, Q); unsigned Leaders = std::max(KnownZero.countLeadingOnes(), KnownZero2.countLeadingOnes()); @@ -608,7 +1131,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // to determine if we can prove known low zero bits. 
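
The arithmetic behind the subscript walk that follows: the address is Base plus a sum of Index * TypeAllocSize terms; the trailing-zero count of a sum is at least the minimum of the terms' counts, and the count of each product is the sum of its factors' counts. A standalone sketch (ctz64 and knownTrailZerosOfGEP are illustrative names):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    unsigned ctz64(uint64_t V) { return V ? __builtin_ctzll(V) : 64; }

    unsigned knownTrailZerosOfGEP(unsigned BaseTrailZ,
                                  const std::vector<uint64_t> &TypeSizes,
                                  const std::vector<unsigned> &IndexTrailZ) {
      unsigned TrailZ = BaseTrailZ;
      for (std::size_t I = 0; I != TypeSizes.size(); ++I)
        TrailZ = std::min(TrailZ, ctz64(TypeSizes[I]) + IndexTrailZ[I]);
      return TrailZ;
    }

    // e.g. 16-aligned base (4 trailing zeros), i32 elements (size 4, ctz 2)
    // with an even index (1): min(4, 2 + 1) = 3, an 8-byte-aligned address.
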
APInt LocalKnownZero(BitWidth, 0), LocalKnownOne(BitWidth, 0); computeKnownBits(I->getOperand(0), LocalKnownZero, LocalKnownOne, TD, - Depth+1); + Depth+1, Q); unsigned TrailZ = LocalKnownZero.countTrailingOnes(); gep_type_iterator GTI = gep_type_begin(I); @@ -644,7 +1167,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits(); uint64_t TypeSize = TD ? TD->getTypeAllocSize(IndexedTy) : 1; LocalKnownZero = LocalKnownOne = APInt(GEPOpiBits, 0); - computeKnownBits(Index, LocalKnownZero, LocalKnownOne, TD, Depth+1); + computeKnownBits(Index, LocalKnownZero, LocalKnownOne, TD, Depth+1, Q); TrailZ = std::min(TrailZ, unsigned(countTrailingZeros(TypeSize) + LocalKnownZero.countTrailingOnes())); @@ -686,11 +1209,11 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; // Ok, we have a PHI of the form L op= R. Check for low // zero bits. - computeKnownBits(R, KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(R, KnownZero2, KnownOne2, TD, Depth+1, Q); // We need to take the minimum number of known bits APInt KnownZero3(KnownZero), KnownOne3(KnownOne); - computeKnownBits(L, KnownZero3, KnownOne3, TD, Depth+1); + computeKnownBits(L, KnownZero3, KnownOne3, TD, Depth+1, Q); KnownZero = APInt::getLowBitsSet(BitWidth, std::min(KnownZero2.countTrailingOnes(), @@ -722,7 +1245,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // Recurse, but cap the recursion to one level, because we don't // want to waste time spinning around in loops. computeKnownBits(P->getIncomingValue(i), KnownZero2, KnownOne2, TD, - MaxDepth-1); + MaxDepth-1, Q); KnownZero &= KnownZero2; KnownOne &= KnownOne2; // If all bits have been ruled out, there's no need to check @@ -774,19 +1297,19 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, case Intrinsic::sadd_with_overflow: computeKnownBitsAddSub(true, II->getArgOperand(0), II->getArgOperand(1), false, KnownZero, - KnownOne, KnownZero2, KnownOne2, TD, Depth); + KnownOne, KnownZero2, KnownOne2, TD, Depth, Q); break; case Intrinsic::usub_with_overflow: case Intrinsic::ssub_with_overflow: computeKnownBitsAddSub(false, II->getArgOperand(0), II->getArgOperand(1), false, KnownZero, - KnownOne, KnownZero2, KnownOne2, TD, Depth); + KnownOne, KnownZero2, KnownOne2, TD, Depth, Q); break; case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false, KnownZero, KnownOne, - KnownZero2, KnownOne2, TD, Depth); + KnownZero2, KnownOne2, TD, Depth, Q); break; } } @@ -796,10 +1319,11 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); } -/// ComputeSignBit - Determine whether the sign bit is known to be zero or -/// one. Convenience wrapper around computeKnownBits. -void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, - const DataLayout *TD, unsigned Depth) { +/// Determine whether the sign bit is known to be zero or one. +/// Convenience wrapper around computeKnownBits. 
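
Reduced to plain masks, the wrapper defined below does nothing more than read the top bit of what computeKnownBits produces; a sketch:

    #include <cstdint>

    struct KnownBits32 { uint32_t Zero, One; };

    void signBitFromKnownBits(const KnownBits32 &K, bool &KnownZeroSign,
                              bool &KnownOneSign) {
      KnownZeroSign = ((K.Zero >> 31) & 1) != 0; // known non-negative
      KnownOneSign = ((K.One >> 31) & 1) != 0;   // known negative
    }
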
+void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + const DataLayout *TD, unsigned Depth, + const Query &Q) { unsigned BitWidth = getBitWidth(V->getType(), TD); if (!BitWidth) { KnownZero = false; @@ -808,16 +1332,17 @@ void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, } APInt ZeroBits(BitWidth, 0); APInt OneBits(BitWidth, 0); - computeKnownBits(V, ZeroBits, OneBits, TD, Depth); + computeKnownBits(V, ZeroBits, OneBits, TD, Depth, Q); KnownOne = OneBits[BitWidth - 1]; KnownZero = ZeroBits[BitWidth - 1]; } -/// isKnownToBeAPowerOfTwo - Return true if the given value is known to have exactly one +/// Return true if the given value is known to have exactly one /// bit set when defined. For vectors return true if every element is known to -/// be a power of two when defined. Supports values with integer or pointer +/// be a power of two when defined. Supports values with integer or pointer /// types and vectors of integers. -bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { +bool isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, + const Query &Q) { if (Constant *C = dyn_cast<Constant>(V)) { if (C->isNullValue()) return OrZero; @@ -844,19 +1369,20 @@ bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { // A shift of a power of two is a power of two or zero. if (OrZero && (match(V, m_Shl(m_Value(X), m_Value())) || match(V, m_Shr(m_Value(X), m_Value())))) - return isKnownToBeAPowerOfTwo(X, /*OrZero*/true, Depth); + return isKnownToBeAPowerOfTwo(X, /*OrZero*/true, Depth, Q); if (ZExtInst *ZI = dyn_cast<ZExtInst>(V)) - return isKnownToBeAPowerOfTwo(ZI->getOperand(0), OrZero, Depth); + return isKnownToBeAPowerOfTwo(ZI->getOperand(0), OrZero, Depth, Q); if (SelectInst *SI = dyn_cast<SelectInst>(V)) - return isKnownToBeAPowerOfTwo(SI->getTrueValue(), OrZero, Depth) && - isKnownToBeAPowerOfTwo(SI->getFalseValue(), OrZero, Depth); + return + isKnownToBeAPowerOfTwo(SI->getTrueValue(), OrZero, Depth, Q) && + isKnownToBeAPowerOfTwo(SI->getFalseValue(), OrZero, Depth, Q); if (OrZero && match(V, m_And(m_Value(X), m_Value(Y)))) { // A power of two and'd with anything is a power of two or zero. - if (isKnownToBeAPowerOfTwo(X, /*OrZero*/true, Depth) || - isKnownToBeAPowerOfTwo(Y, /*OrZero*/true, Depth)) + if (isKnownToBeAPowerOfTwo(X, /*OrZero*/true, Depth, Q) || + isKnownToBeAPowerOfTwo(Y, /*OrZero*/true, Depth, Q)) return true; // X & (-X) is always a power of two or zero. 
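// Aside: the two bit identities the power-of-two logic above and below
// leans on, sketched for plain unsigned integers. x & (x - 1) clears the
// lowest set bit, so it is zero exactly when x has at most one bit set;
// x & -x isolates that lowest bit.
#include <cstdint>

static bool isPowerOfTwoOrZero(uint64_t X) {
  return (X & (X - 1)) == 0; // zero qualifies: 0 & ~0ULL == 0
}

static uint64_t lowestSetBit(uint64_t X) {
  return X & (0 - X); // a power of two, or zero when X == 0
}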
if (match(X, m_Neg(m_Specific(Y))) || match(Y, m_Neg(m_Specific(X)))) @@ -871,19 +1397,19 @@ bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { if (OrZero || VOBO->hasNoUnsignedWrap() || VOBO->hasNoSignedWrap()) { if (match(X, m_And(m_Specific(Y), m_Value())) || match(X, m_And(m_Value(), m_Specific(Y)))) - if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth)) + if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth, Q)) return true; if (match(Y, m_And(m_Specific(X), m_Value())) || match(Y, m_And(m_Value(), m_Specific(X)))) - if (isKnownToBeAPowerOfTwo(X, OrZero, Depth)) + if (isKnownToBeAPowerOfTwo(X, OrZero, Depth, Q)) return true; unsigned BitWidth = V->getType()->getScalarSizeInBits(); APInt LHSZeroBits(BitWidth, 0), LHSOneBits(BitWidth, 0); - computeKnownBits(X, LHSZeroBits, LHSOneBits, nullptr, Depth); + computeKnownBits(X, LHSZeroBits, LHSOneBits, nullptr, Depth, Q); APInt RHSZeroBits(BitWidth, 0), RHSOneBits(BitWidth, 0); - computeKnownBits(Y, RHSZeroBits, RHSOneBits, nullptr, Depth); + computeKnownBits(Y, RHSZeroBits, RHSOneBits, nullptr, Depth, Q); // If i8 V is a power of two or zero: // ZeroBits: 1 1 1 0 1 1 1 1 // ~ZeroBits: 0 0 0 1 0 0 0 0 @@ -900,7 +1426,8 @@ bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { // copying a sign bit (sdiv int_min, 2). if (match(V, m_Exact(m_LShr(m_Value(), m_Value()))) || match(V, m_Exact(m_UDiv(m_Value(), m_Value())))) { - return isKnownToBeAPowerOfTwo(cast<Operator>(V)->getOperand(0), OrZero, Depth); + return isKnownToBeAPowerOfTwo(cast<Operator>(V)->getOperand(0), OrZero, + Depth, Q); } return false; @@ -913,7 +1440,7 @@ bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { /// /// Currently this routine does not support vector GEPs. static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout *DL, - unsigned Depth) { + unsigned Depth, const Query &Q) { if (!GEP->isInBounds() || GEP->getPointerAddressSpace() != 0) return false; @@ -922,7 +1449,7 @@ static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout *DL, // If the base pointer is non-null, we cannot walk to a null address with an // inbounds GEP in address space zero. - if (isKnownNonZero(GEP->getPointerOperand(), DL, Depth)) + if (isKnownNonZero(GEP->getPointerOperand(), DL, Depth, Q)) return true; // Past this, if we don't have DataLayout, we can't do much. @@ -965,18 +1492,36 @@ static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout *DL, if (Depth++ >= MaxDepth) continue; - if (isKnownNonZero(GTI.getOperand(), DL, Depth)) + if (isKnownNonZero(GTI.getOperand(), DL, Depth, Q)) return true; } return false; } -/// isKnownNonZero - Return true if the given value is known to be non-zero -/// when defined. For vectors return true if every element is known to be -/// non-zero when defined. Supports values with integer or pointer type and -/// vectors of integers. -bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { +/// Does the 'Range' metadata (which must be a valid MD_range operand list) +/// ensure that the value it's attached to is never Value?
+static bool rangeMetadataExcludesValue(MDNode* Ranges, + const APInt& Value) { + const unsigned NumRanges = Ranges->getNumOperands() / 2; + assert(NumRanges >= 1); + for (unsigned i = 0; i < NumRanges; ++i) { + ConstantInt *Lower = cast<ConstantInt>(Ranges->getOperand(2*i + 0)); + ConstantInt *Upper = cast<ConstantInt>(Ranges->getOperand(2*i + 1)); + ConstantRange Range(Lower->getValue(), Upper->getValue()); + if (Range.contains(Value)) + return false; + } + return true; +} + +/// Return true if the given value is known to be non-zero when defined. +/// For vectors return true if every element is known to be non-zero when +/// defined. Supports values with integer or pointer type and vectors of +/// integers. +bool isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth, + const Query &Q) { if (Constant *C = dyn_cast<Constant>(V)) { if (C->isNullValue()) return false; @@ -987,6 +1532,18 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { return false; } + if (Instruction* I = dyn_cast<Instruction>(V)) { + if (MDNode *Ranges = I->getMetadata(LLVMContext::MD_range)) { + // If the possible ranges don't contain zero, then the value is + // definitely non-zero. + if (IntegerType* Ty = dyn_cast<IntegerType>(V->getType())) { + const APInt ZeroValue(Ty->getBitWidth(), 0); + if (rangeMetadataExcludesValue(Ranges, ZeroValue)) + return true; + } + } + } + // The remaining tests are all recursive, so bail out if we hit the limit. if (Depth++ >= MaxDepth) return false; @@ -996,7 +1553,7 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { if (isKnownNonNull(V)) return true; if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) - if (isGEPKnownNonNull(GEP, TD, Depth)) + if (isGEPKnownNonNull(GEP, TD, Depth, Q)) return true; } @@ -1005,11 +1562,12 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { // X | Y != 0 if X != 0 or Y != 0. Value *X = nullptr, *Y = nullptr; if (match(V, m_Or(m_Value(X), m_Value(Y)))) - return isKnownNonZero(X, TD, Depth) || isKnownNonZero(Y, TD, Depth); + return isKnownNonZero(X, TD, Depth, Q) || + isKnownNonZero(Y, TD, Depth, Q); // ext X != 0 if X != 0. if (isa<SExtInst>(V) || isa<ZExtInst>(V)) - return isKnownNonZero(cast<Instruction>(V)->getOperand(0), TD, Depth); + return isKnownNonZero(cast<Instruction>(V)->getOperand(0), TD, Depth, Q); // shl X, Y != 0 if X is odd. Note that the value of the shift is undefined // if the lowest bit is shifted off the end. @@ -1017,11 +1575,11 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { // shl nuw can't remove any non-zero bits. OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); if (BO->hasNoUnsignedWrap()) - return isKnownNonZero(X, TD, Depth); + return isKnownNonZero(X, TD, Depth, Q); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(X, KnownZero, KnownOne, TD, Depth); + computeKnownBits(X, KnownZero, KnownOne, TD, Depth, Q); if (KnownOne[0]) return true; } @@ -1031,28 +1589,29 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { // shr exact can only shift out zero bits. 
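// Aside: a standalone restatement of the !range check added above. MD_range
// encodes a list of half-open [Lo, Hi) pairs; a value is excluded iff no
// pair contains it. The in-tree ConstantRange also handles wrapped ranges
// (Lo > Hi), which this simplified sketch assumes away.
#include <cstdint>
#include <utility>
#include <vector>

static bool rangesExcludeValue(
    const std::vector<std::pair<uint64_t, uint64_t>> &Ranges, uint64_t V) {
  for (const auto &R : Ranges)
    if (V >= R.first && V < R.second) // contained in [Lo, Hi)
      return false;
  return true;
}
// e.g. an i8 load annotated !range {1, 256} is provably non-zero:
// rangesExcludeValue({{1, 256}}, 0) == true.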
PossiblyExactOperator *BO = cast<PossiblyExactOperator>(V); if (BO->isExact()) - return isKnownNonZero(X, TD, Depth); + return isKnownNonZero(X, TD, Depth, Q); bool XKnownNonNegative, XKnownNegative; - ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth); + ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth, Q); if (XKnownNegative) return true; } // div exact can only produce a zero if the dividend is zero. else if (match(V, m_Exact(m_IDiv(m_Value(X), m_Value())))) { - return isKnownNonZero(X, TD, Depth); + return isKnownNonZero(X, TD, Depth, Q); } // X + Y. else if (match(V, m_Add(m_Value(X), m_Value(Y)))) { bool XKnownNonNegative, XKnownNegative; bool YKnownNonNegative, YKnownNegative; - ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth); - ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, TD, Depth); + ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth, Q); + ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, TD, Depth, Q); // If X and Y are both non-negative (as signed values) then their sum is not // zero unless both X and Y are zero. if (XKnownNonNegative && YKnownNonNegative) - if (isKnownNonZero(X, TD, Depth) || isKnownNonZero(Y, TD, Depth)) + if (isKnownNonZero(X, TD, Depth, Q) || + isKnownNonZero(Y, TD, Depth, Q)) return true; // If X and Y are both negative (as signed values) then their sum is not @@ -1063,20 +1622,22 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { APInt Mask = APInt::getSignedMaxValue(BitWidth); // The sign bit of X is set. If some other bit is set then X is not equal // to INT_MIN. - computeKnownBits(X, KnownZero, KnownOne, TD, Depth); + computeKnownBits(X, KnownZero, KnownOne, TD, Depth, Q); if ((KnownOne & Mask) != 0) return true; // The sign bit of Y is set. If some other bit is set then Y is not equal // to INT_MIN. - computeKnownBits(Y, KnownZero, KnownOne, TD, Depth); + computeKnownBits(Y, KnownZero, KnownOne, TD, Depth, Q); if ((KnownOne & Mask) != 0) return true; } // The sum of a non-negative number and a power of two is not zero. - if (XKnownNonNegative && isKnownToBeAPowerOfTwo(Y, /*OrZero*/false, Depth)) + if (XKnownNonNegative && + isKnownToBeAPowerOfTwo(Y, /*OrZero*/false, Depth, Q)) return true; - if (YKnownNonNegative && isKnownToBeAPowerOfTwo(X, /*OrZero*/false, Depth)) + if (YKnownNonNegative && + isKnownToBeAPowerOfTwo(X, /*OrZero*/false, Depth, Q)) return true; } // X * Y. @@ -1085,51 +1646,53 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { // If X and Y are non-zero then so is X * Y as long as the multiplication // does not overflow. if ((BO->hasNoSignedWrap() || BO->hasNoUnsignedWrap()) && - isKnownNonZero(X, TD, Depth) && isKnownNonZero(Y, TD, Depth)) + isKnownNonZero(X, TD, Depth, Q) && + isKnownNonZero(Y, TD, Depth, Q)) return true; } // (C ? X : Y) != 0 if X != 0 and Y != 0. else if (SelectInst *SI = dyn_cast<SelectInst>(V)) { - if (isKnownNonZero(SI->getTrueValue(), TD, Depth) && - isKnownNonZero(SI->getFalseValue(), TD, Depth)) + if (isKnownNonZero(SI->getTrueValue(), TD, Depth, Q) && + isKnownNonZero(SI->getFalseValue(), TD, Depth, Q)) return true; } if (!BitWidth) return false; APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, TD, Depth); + computeKnownBits(V, KnownZero, KnownOne, TD, Depth, Q); return KnownOne != 0; } -/// MaskedValueIsZero - Return true if 'V & Mask' is known to be zero. We use -/// this predicate to simplify operations downstream. 
Mask is known to be zero -/// for bits that V cannot have. +/// Return true if 'V & Mask' is known to be zero. We use this predicate to +/// simplify operations downstream. Mask is known to be zero for bits that V +/// cannot have. /// /// This function is defined on values with integer type, values with pointer /// type (but only if TD is non-null), and vectors of integers. In the case /// where V is a vector, the mask, known zero, and known one values are the /// same width as the vector element, and the bit is set only if it is true /// for all of the elements in the vector. -bool llvm::MaskedValueIsZero(Value *V, const APInt &Mask, - const DataLayout *TD, unsigned Depth) { +bool MaskedValueIsZero(Value *V, const APInt &Mask, + const DataLayout *TD, unsigned Depth, + const Query &Q) { APInt KnownZero(Mask.getBitWidth(), 0), KnownOne(Mask.getBitWidth(), 0); - computeKnownBits(V, KnownZero, KnownOne, TD, Depth); + computeKnownBits(V, KnownZero, KnownOne, TD, Depth, Q); return (KnownZero & Mask) == Mask; } -/// ComputeNumSignBits - Return the number of times the sign bit of the -/// register is replicated into the other bits. We know that at least 1 bit -/// is always equal to the sign bit (itself), but other cases can give us -/// information. For example, immediately after an "ashr X, 2", we know that -/// the top 3 bits are all equal to each other, so we return 3. +/// Return the number of times the sign bit of the register is replicated into +/// the other bits. We know that at least 1 bit is always equal to the sign bit +/// (itself), but other cases can give us information. For example, immediately +/// after an "ashr X, 2", we know that the top 3 bits are all equal to each +/// other, so we return 3. /// /// 'Op' must have a scalar integer type. /// -unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, - unsigned Depth) { +unsigned ComputeNumSignBits(Value *V, const DataLayout *TD, + unsigned Depth, const Query &Q) { assert((TD || V->getType()->isIntOrIntVectorTy()) && "ComputeNumSignBits requires a DataLayout object to operate " "on non-integer values!"); @@ -1150,10 +1713,10 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, default: break; case Instruction::SExt: Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits(); - return ComputeNumSignBits(U->getOperand(0), TD, Depth+1) + Tmp; + return ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q) + Tmp; case Instruction::AShr: { - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); // ashr X, C -> adds C sign bits. Vectors too. const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { @@ -1166,7 +1729,7 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { // shl destroys sign bits. - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); Tmp2 = ShAmt->getZExtValue(); if (Tmp2 >= TyBits || // Bad shift. Tmp2 >= Tmp) break; // Shifted all sign bits out. @@ -1178,9 +1741,9 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, case Instruction::Or: case Instruction::Xor: // NOT is handled here. // Logical binary ops preserve the number of sign bits at the worst. 
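// Aside: the constant-shift arithmetic from the AShr and Shl cases above,
// restated in isolation. ashr by C replicates the sign bit C more times
// (capped at the bit width); shl by C destroys up to C known sign bits but
// never drops below 1, since the sign bit itself always counts.
#include <algorithm>

static unsigned signBitsAfterAShr(unsigned Known, unsigned C,
                                  unsigned BitWidth) {
  return std::min(BitWidth, Known + C);
}

static unsigned signBitsAfterShl(unsigned Known, unsigned C) {
  return Known > C ? Known - C : 1;
}
// e.g. for i32: signBitsAfterAShr(1, 2, 32) == 3, matching the "ashr X, 2"
// example in the function comment above.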
- Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); if (Tmp != 1) { - Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1, Q); FirstAnswer = std::min(Tmp, Tmp2); // We computed what we know about the sign bits as our first // answer. Now proceed to the generic code that uses @@ -1189,22 +1752,22 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, break; case Instruction::Select: - Tmp = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(1), TD, Depth+1, Q); if (Tmp == 1) return 1; // Early out. - Tmp2 = ComputeNumSignBits(U->getOperand(2), TD, Depth+1); + Tmp2 = ComputeNumSignBits(U->getOperand(2), TD, Depth+1, Q); return std::min(Tmp, Tmp2); case Instruction::Add: // Add can have at most one carry bit. Thus we know that the output // is, at worst, one more bit than the inputs. - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); if (Tmp == 1) return 1; // Early out. // Special case decrementing a value (ADD X, -1): if (ConstantInt *CRHS = dyn_cast<ConstantInt>(U->getOperand(1))) if (CRHS->isAllOnesValue()) { APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); // If the input is known to be 0 or 1, the output is 0/-1, which is all // sign bits set. @@ -1217,19 +1780,19 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, return Tmp; } - Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1, Q); if (Tmp2 == 1) return 1; return std::min(Tmp, Tmp2)-1; case Instruction::Sub: - Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1, Q); if (Tmp2 == 1) return 1; // Handle NEG. if (ConstantInt *CLHS = dyn_cast<ConstantInt>(U->getOperand(0))) if (CLHS->isNullValue()) { APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); - computeKnownBits(U->getOperand(1), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(U->getOperand(1), KnownZero, KnownOne, TD, Depth+1, Q); // If the input is known to be 0 or 1, the output is 0/-1, which is all // sign bits set. if ((KnownZero | APInt(TyBits, 1)).isAllOnesValue()) @@ -1245,7 +1808,7 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, // Sub can have at most one carry bit. Thus we know that the output // is, at worst, one more bit than the inputs. - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); if (Tmp == 1) return 1; // Early out. return std::min(Tmp, Tmp2)-1; @@ -1256,11 +1819,12 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, // Take the minimum of all incoming values. This can't infinitely loop // because of our depth threshold. 
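// Aside: the carry rule used by the Add case above (and the Sub case just
// below) in isolation. A sum or difference admits at most one carry into
// the sign-extended bits, so it keeps min(lhs, rhs) - 1 known sign bits,
// floored at 1.
#include <algorithm>

static unsigned signBitsAfterAddSub(unsigned LHS, unsigned RHS) {
  unsigned Min = std::min(LHS, RHS);
  return Min > 1 ? Min - 1 : 1;
}
// e.g. adding two i32 values that each carry 8 known sign bits yields a
// value with at least 7: signBitsAfterAddSub(8, 8) == 7.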
- Tmp = ComputeNumSignBits(PN->getIncomingValue(0), TD, Depth+1); + Tmp = ComputeNumSignBits(PN->getIncomingValue(0), TD, Depth+1, Q); for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i) { if (Tmp == 1) return Tmp; Tmp = std::min(Tmp, - ComputeNumSignBits(PN->getIncomingValue(i), TD, Depth+1)); + ComputeNumSignBits(PN->getIncomingValue(i), TD, + Depth+1, Q)); } return Tmp; } @@ -1275,7 +1839,7 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, // use this information. APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); APInt Mask; - computeKnownBits(V, KnownZero, KnownOne, TD, Depth); + computeKnownBits(V, KnownZero, KnownOne, TD, Depth, Q); if (KnownZero.isNegative()) { // sign bit is 0 Mask = KnownZero; @@ -1295,9 +1859,9 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, return std::max(FirstAnswer, std::min(TyBits, Mask.countLeadingZeros())); } -/// ComputeMultiple - This function computes the integer multiple of Base that -/// equals V. If successful, it returns true and returns the multiple in -/// Multiple. If unsuccessful, it returns false. It looks +/// This function computes the integer multiple of Base that equals V. +/// If successful, it returns true and returns the multiple in +/// Multiple. If unsuccessful, it returns false. It looks /// through SExt instructions only if LookThroughSExt is true. bool llvm::ComputeMultiple(Value *V, unsigned Base, Value *&Multiple, bool LookThroughSExt, unsigned Depth) { @@ -1415,8 +1979,8 @@ bool llvm::ComputeMultiple(Value *V, unsigned Base, Value *&Multiple, return false; } -/// CannotBeNegativeZero - Return true if we can prove that the specified FP -/// value is never equal to -0.0. +/// Return true if we can prove that the specified FP value is never equal to +/// -0.0. /// /// NOTE: this function will need to be revisited when we support non-default /// rounding modes! @@ -1469,8 +2033,8 @@ bool llvm::CannotBeNegativeZero(const Value *V, unsigned Depth) { return false; } -/// isBytewiseValue - If the specified value can be set by repeating the same -/// byte in memory, return the i8 value that it is represented with. This is +/// If the specified value can be set by repeating the same byte in memory, +/// return the i8 value that it is represented with. This is /// true for all i8 values obviously, but is also true for i32 0, i32 -1, /// i16 0xF0F0, double 0.0 etc. If the value can't be handled with a repeated /// byte store (e.g. i16 0x1234), return null. @@ -1618,7 +2182,7 @@ static Value *BuildSubAggregate(Value *From, ArrayRef<unsigned> idx_range, return BuildSubAggregate(From, To, IndexedType, Idxs, IdxSkip, InsertBefore); } -/// FindInsertedValue - Given an aggregrate and an sequence of indices, see if +/// Given an aggregate and a sequence of indices, see if /// the scalar value indexed is already around as a register, for example if it /// were inserted directly into the aggregate. /// @@ -1708,9 +2272,8 @@ Value *llvm::FindInsertedValue(Value *V, ArrayRef<unsigned> idx_range, return nullptr; } -/// GetPointerBaseWithConstantOffset - Analyze the specified pointer to see if -/// it can be expressed as a base pointer plus a constant offset. Return the -/// base and offset to the caller. +/// Analyze the specified pointer to see if it can be expressed as a base +/// pointer plus a constant offset. Return the base and offset to the caller.
Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout *DL) { // Without DataLayout, conservatively assume 64-bit offsets, which is @@ -1731,7 +2294,8 @@ Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, } Ptr = GEP->getPointerOperand(); - } else if (Operator::getOpcode(Ptr) == Instruction::BitCast) { + } else if (Operator::getOpcode(Ptr) == Instruction::BitCast || + Operator::getOpcode(Ptr) == Instruction::AddrSpaceCast) { Ptr = cast<Operator>(Ptr)->getOperand(0); } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(Ptr)) { if (GA->mayBeOverridden()) @@ -1746,9 +2310,9 @@ Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, } -/// getConstantStringInfo - This function computes the length of a -/// null-terminated C string pointed to by V. If successful, it returns true -/// and returns the string in Str. If unsuccessful, it returns false. +/// This function computes the length of a null-terminated C string pointed to +/// by V. If successful, it returns true and returns the string in Str. +/// If unsuccessful, it returns false. bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, uint64_t Offset, bool TrimAtNul) { assert(V); @@ -1832,16 +2396,16 @@ bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, // nodes. // TODO: See if we can integrate these two together. -/// GetStringLengthH - If we can compute the length of the string pointed to by +/// If we can compute the length of the string pointed to by /// the specified pointer, return 'len+1'. If we can't, return 0. -static uint64_t GetStringLengthH(Value *V, SmallPtrSet<PHINode*, 32> &PHIs) { +static uint64_t GetStringLengthH(Value *V, SmallPtrSetImpl<PHINode*> &PHIs) { // Look through noop bitcast instructions. V = V->stripPointerCasts(); // If this is a PHI node, there are two cases: either we have already seen it // or we haven't. if (PHINode *PN = dyn_cast<PHINode>(V)) { - if (!PHIs.insert(PN)) + if (!PHIs.insert(PN).second) return ~0ULL; // already in the set. // If it was new, see if all the input strings are the same length. @@ -1881,7 +2445,7 @@ static uint64_t GetStringLengthH(Value *V, SmallPtrSet<PHINode*, 32> &PHIs) { return StrData.size()+1; } -/// GetStringLength - If we can compute the length of the string pointed to by +/// If we can compute the length of the string pointed to by /// the specified pointer, return 'len+1'. If we can't, return 0. uint64_t llvm::GetStringLength(Value *V) { if (!V->getType()->isPointerTy()) return 0; @@ -1900,7 +2464,8 @@ llvm::GetUnderlyingObject(Value *V, const DataLayout *TD, unsigned MaxLookup) { for (unsigned Count = 0; MaxLookup == 0 || Count < MaxLookup; ++Count) { if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { V = GEP->getPointerOperand(); - } else if (Operator::getOpcode(V) == Instruction::BitCast) { + } else if (Operator::getOpcode(V) == Instruction::BitCast || + Operator::getOpcode(V) == Instruction::AddrSpaceCast) { V = cast<Operator>(V)->getOperand(0); } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { if (GA->mayBeOverridden()) @@ -1909,7 +2474,7 @@ llvm::GetUnderlyingObject(Value *V, const DataLayout *TD, unsigned MaxLookup) { } else { // See if InstructionSimplify knows any relevant tricks. if (Instruction *I = dyn_cast<Instruction>(V)) - // TODO: Acquire a DominatorTree and use it. + // TODO: Acquire a DominatorTree and AssumptionTracker and use them. 
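// Aside: the visited-set idiom this patch migrates to, above in
// GetStringLengthH and again in GetUnderlyingObjects below. insert() now
// returns an (iterator, bool) pair as std::set does, so cycle detection
// reads the .second flag; the same shape with std::set:
#include <set>

struct Node { Node *Next; };

static bool walkAcyclic(Node *N) {
  std::set<Node *> Visited;
  for (; N; N = N->Next)
    if (!Visited.insert(N).second)
      return false; // already visited: the chain loops
  return true;      // reached a terminator
}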
if (Value *Simplified = SimplifyInstruction(I, TD, nullptr)) { V = Simplified; continue; @@ -1934,7 +2499,7 @@ llvm::GetUnderlyingObjects(Value *V, Value *P = Worklist.pop_back_val(); P = GetUnderlyingObject(P, TD, MaxLookup); - if (!Visited.insert(P)) + if (!Visited.insert(P).second) continue; if (SelectInst *SI = dyn_cast<SelectInst>(P)) { @@ -1953,9 +2518,7 @@ llvm::GetUnderlyingObjects(Value *V, } while (!Worklist.empty()); } -/// onlyUsedByLifetimeMarkers - Return true if the only users of this pointer -/// are lifetime markers. -/// +/// Return true if the only users of this pointer are lifetime markers. bool llvm::onlyUsedByLifetimeMarkers(const Value *V) { for (const User *U : V->users()) { const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U); @@ -1983,23 +2546,31 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V, default: return true; case Instruction::UDiv: - case Instruction::URem: - // x / y is undefined if y == 0, but calculations like x / 3 are safe. - return isKnownNonZero(Inst->getOperand(1), TD); + case Instruction::URem: { + // x / y is undefined if y == 0. + const APInt *Denominator; + if (match(Inst->getOperand(1), m_APInt(Denominator))) + return *Denominator != 0; + return false; + } case Instruction::SDiv: case Instruction::SRem: { - Value *Op = Inst->getOperand(1); - // x / y is undefined if y == 0 - if (!isKnownNonZero(Op, TD)) - return false; - // x / y might be undefined if y == -1 - unsigned BitWidth = getBitWidth(Op->getType(), TD); - if (BitWidth == 0) - return false; - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(Op, KnownZero, KnownOne, TD); - return !!KnownZero; + // x / y is undefined if y == 0, or if x == INT_MIN and y == -1. + const APInt *X, *Y; + if (match(Inst->getOperand(1), m_APInt(Y))) { + if (*Y != 0) { + if (*Y == -1) { + // The denominator is -1: safe only if the numerator is known not + // to be MinSignedValue, since INT_MIN / -1 overflows. + if (match(Inst->getOperand(0), m_APInt(X))) + return !X->isMinSignedValue(); + // The numerator *might* be MinSignedValue. + return false; + } + // The denominator is not 0 or -1; it's safe to proceed. + return true; + } + } + return false; } case Instruction::Load: { const LoadInst *LI = cast<LoadInst>(Inst); @@ -2010,41 +2581,44 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V, return LI->getPointerOperand()->isDereferenceablePointer(TD); } case Instruction::Call: { - if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { - switch (II->getIntrinsicID()) { - // These synthetic intrinsics have no side-effects and just mark - // information about their operands. - // FIXME: There are other no-op synthetic instructions that potentially - // should be considered at least *safe* to speculate... - case Intrinsic::dbg_declare: - case Intrinsic::dbg_value: - return true; - - case Intrinsic::bswap: - case Intrinsic::ctlz: - case Intrinsic::ctpop: - case Intrinsic::cttz: - case Intrinsic::objectsize: - case Intrinsic::sadd_with_overflow: - case Intrinsic::smul_with_overflow: - case Intrinsic::ssub_with_overflow: - case Intrinsic::uadd_with_overflow: - case Intrinsic::umul_with_overflow: - case Intrinsic::usub_with_overflow: - return true; - // Sqrt should be OK, since the llvm sqrt intrinsic isn't defined to set - // errno like libm sqrt would. - case Intrinsic::sqrt: - case Intrinsic::fma: - case Intrinsic::fmuladd: - return true; - // TODO: some fp intrinsics are marked as having the same error handling - // as libm. They're safe to speculate when they won't error. - // TODO: are convert_{from,to}_fp16 safe?
- // TODO: can we list target-specific intrinsics here? - default: break; - } - } + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + switch (II->getIntrinsicID()) { + // These synthetic intrinsics have no side-effects and just mark + // information about their operands. + // FIXME: There are other no-op synthetic instructions that potentially + // should be considered at least *safe* to speculate... + case Intrinsic::dbg_declare: + case Intrinsic::dbg_value: + return true; + + case Intrinsic::bswap: + case Intrinsic::ctlz: + case Intrinsic::ctpop: + case Intrinsic::cttz: + case Intrinsic::objectsize: + case Intrinsic::sadd_with_overflow: + case Intrinsic::smul_with_overflow: + case Intrinsic::ssub_with_overflow: + case Intrinsic::uadd_with_overflow: + case Intrinsic::umul_with_overflow: + case Intrinsic::usub_with_overflow: + return true; + // Sqrt should be OK, since the llvm sqrt intrinsic isn't defined to set + // errno like libm sqrt would. + case Intrinsic::sqrt: + case Intrinsic::fma: + case Intrinsic::fmuladd: + case Intrinsic::fabs: + case Intrinsic::minnum: + case Intrinsic::maxnum: + return true; + // TODO: some fp intrinsics are marked as having the same error handling + // as libm. They're safe to speculate when they won't error. + // TODO: are convert_{from,to}_fp16 safe? + // TODO: can we list target-specific intrinsics here? + default: break; + } + } return false; // The called function could have undefined behavior or // side-effects, even if marked readnone nounwind. } @@ -2067,8 +2641,7 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V, } } -/// isKnownNonNull - Return true if we know that the specified value is never -/// null. +/// Return true if we know that the specified value is never null. bool llvm::isKnownNonNull(const Value *V, const TargetLibraryInfo *TLI) { // Alloca never returns null, malloc might. if (isa<AllocaInst>(V)) return true; @@ -2081,8 +2654,12 @@ bool llvm::isKnownNonNull(const Value *V, const TargetLibraryInfo *TLI) { if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) return !GV->hasExternalWeakLinkage(); + // A Load tagged with nonnull metadata is never null. + if (const LoadInst *LI = dyn_cast<LoadInst>(V)) + return LI->getMetadata(LLVMContext::MD_nonnull); + if (ImmutableCallSite CS = V) - if (CS.paramHasAttr(0, Attribute::NonNull)) + if (CS.isReturnNonNull()) return true; // operator new never returns null.
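// Aside: the signed-division safety rule from isSafeToSpeculativelyExecute
// above, restated over plain ints. sdiv/srem trap when the denominator is
// 0, or when the numerator is INT_MIN and the denominator is -1; any other
// constant denominator is safe to hoist.
#include <climits>

static bool safeToSpeculateSDiv(bool NumeratorKnown, int Numerator,
                                int Denominator) {
  if (Denominator == 0)
    return false;            // division by zero traps
  if (Denominator != -1)
    return true;             // any other nonzero constant is safe
  // Denominator is -1: safe only if the numerator cannot be INT_MIN.
  return NumeratorKnown && Numerator != INT_MIN;
}
// e.g. x / 3 is always safe to hoist; x / -1 is safe only once x is known
// not to be INT_MIN.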