aboutsummaryrefslogtreecommitdiffstats
path: root/lib/CodeGen/BranchFolding.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/CodeGen/BranchFolding.cpp')
-rw-r--r--lib/CodeGen/BranchFolding.cpp92
1 files changed, 78 insertions, 14 deletions
diff --git a/lib/CodeGen/BranchFolding.cpp b/lib/CodeGen/BranchFolding.cpp
index 7503e57..2128da1 100644
--- a/lib/CodeGen/BranchFolding.cpp
+++ b/lib/CodeGen/BranchFolding.cpp
@@ -20,6 +20,8 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
+#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
@@ -32,8 +34,8 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetInstrInfo.h"
-#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetRegisterInfo.h"
+#include "llvm/Target/TargetSubtargetInfo.h"
#include <algorithm>
using namespace llvm;
@@ -70,6 +72,8 @@ namespace {
bool runOnMachineFunction(MachineFunction &MF) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<MachineBlockFrequencyInfo>();
+ AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<TargetPassConfig>();
MachineFunctionPass::getAnalysisUsage(AU);
}
@@ -91,22 +95,24 @@ bool BranchFolderPass::runOnMachineFunction(MachineFunction &MF) {
// HW that requires structurized CFG.
bool EnableTailMerge = !MF.getTarget().requiresStructuredCFG() &&
PassConfig->getEnableTailMerge();
- BranchFolder Folder(EnableTailMerge, /*CommonHoist=*/true);
- return Folder.OptimizeFunction(MF,
- MF.getTarget().getInstrInfo(),
- MF.getTarget().getRegisterInfo(),
+ BranchFolder Folder(EnableTailMerge, /*CommonHoist=*/true,
+ getAnalysis<MachineBlockFrequencyInfo>(),
+ getAnalysis<MachineBranchProbabilityInfo>());
+ return Folder.OptimizeFunction(MF, MF.getSubtarget().getInstrInfo(),
+ MF.getSubtarget().getRegisterInfo(),
getAnalysisIfAvailable<MachineModuleInfo>());
}
-
-BranchFolder::BranchFolder(bool defaultEnableTailMerge, bool CommonHoist) {
+BranchFolder::BranchFolder(bool defaultEnableTailMerge, bool CommonHoist,
+ const MachineBlockFrequencyInfo &FreqInfo,
+ const MachineBranchProbabilityInfo &ProbInfo)
+ : EnableHoistCommonCode(CommonHoist), MBBFreqInfo(FreqInfo),
+ MBPI(ProbInfo) {
switch (FlagEnableTailMerge) {
case cl::BOU_UNSET: EnableTailMerge = defaultEnableTailMerge; break;
case cl::BOU_TRUE: EnableTailMerge = true; break;
case cl::BOU_FALSE: EnableTailMerge = false; break;
}
-
- EnableHoistCommonCode = CommonHoist;
}
/// RemoveDeadBlock - Remove the specified dead machine basic block from the
@@ -388,10 +394,8 @@ void BranchFolder::MaintainLiveIns(MachineBasicBlock *CurMBB,
RS->enterBasicBlock(CurMBB);
if (!CurMBB->empty())
RS->forward(std::prev(CurMBB->end()));
- BitVector RegsLiveAtExit(TRI->getNumRegs());
- RS->getRegsUsed(RegsLiveAtExit, false);
- for (unsigned int i = 0, e = TRI->getNumRegs(); i != e; i++)
- if (RegsLiveAtExit[i])
+ for (unsigned int i = 1, e = TRI->getNumRegs(); i != e; i++)
+ if (RS->isRegUsed(i, false))
NewMBB->addLiveIn(i);
}
}
@@ -435,6 +439,9 @@ MachineBasicBlock *BranchFolder::SplitMBBAt(MachineBasicBlock &CurMBB,
// Splice the code over.
NewMBB->splice(NewMBB->end(), &CurMBB, BBI1, CurMBB.end());
+ // NewMBB inherits CurMBB's block frequency.
+ MBBFreqInfo.setBlockFreq(NewMBB, MBBFreqInfo.getBlockFreq(&CurMBB));
+
// For targets that use the register scavenger, we must maintain LiveIns.
MaintainLiveIns(&CurMBB, NewMBB);
@@ -504,6 +511,21 @@ BranchFolder::MergePotentialsElt::operator<(const MergePotentialsElt &o) const {
#endif
}
+BlockFrequency
+BranchFolder::MBFIWrapper::getBlockFreq(const MachineBasicBlock *MBB) const {
+ auto I = MergedBBFreq.find(MBB);
+
+ if (I != MergedBBFreq.end())
+ return I->second;
+
+ return MBFI.getBlockFreq(MBB);
+}
+
+void BranchFolder::MBFIWrapper::setBlockFreq(const MachineBasicBlock *MBB,
+ BlockFrequency F) {
+ MergedBBFreq[MBB] = F;
+}
+
/// CountTerminators - Count the number of terminators in the given
/// block and set I to the position of the first non-terminator, if there
/// is one, or MBB->end() otherwise.
@@ -806,6 +828,10 @@ bool BranchFolder::TryTailMergeBlocks(MachineBasicBlock *SuccBB,
}
MachineBasicBlock *MBB = SameTails[commonTailIndex].getBlock();
+
+ // Recompute commont tail MBB's edge weights and block frequency.
+ setCommonTailEdgeWeights(*MBB);
+
// MBB is common tail. Adjust all other BB's to jump to this one.
// Traversal must be forwards so erases work.
DEBUG(dbgs() << "\nUsing common tail in BB#" << MBB->getNumber()
@@ -890,7 +916,7 @@ bool BranchFolder::TailMergeBlocks(MachineFunction &MF) {
continue;
// Visit each predecessor only once.
- if (!UniquePreds.insert(PBB))
+ if (!UniquePreds.insert(PBB).second)
continue;
// Skip blocks which may jump to a landing pad. Can't tail merge these.
@@ -968,6 +994,44 @@ bool BranchFolder::TailMergeBlocks(MachineFunction &MF) {
return MadeChange;
}
+void BranchFolder::setCommonTailEdgeWeights(MachineBasicBlock &TailMBB) {
+ SmallVector<BlockFrequency, 2> EdgeFreqLs(TailMBB.succ_size());
+ BlockFrequency AccumulatedMBBFreq;
+
+ // Aggregate edge frequency of successor edge j:
+ // edgeFreq(j) = sum (freq(bb) * edgeProb(bb, j)),
+ // where bb is a basic block that is in SameTails.
+ for (const auto &Src : SameTails) {
+ const MachineBasicBlock *SrcMBB = Src.getBlock();
+ BlockFrequency BlockFreq = MBBFreqInfo.getBlockFreq(SrcMBB);
+ AccumulatedMBBFreq += BlockFreq;
+
+ // It is not necessary to recompute edge weights if TailBB has less than two
+ // successors.
+ if (TailMBB.succ_size() <= 1)
+ continue;
+
+ auto EdgeFreq = EdgeFreqLs.begin();
+
+ for (auto SuccI = TailMBB.succ_begin(), SuccE = TailMBB.succ_end();
+ SuccI != SuccE; ++SuccI, ++EdgeFreq)
+ *EdgeFreq += BlockFreq * MBPI.getEdgeProbability(SrcMBB, *SuccI);
+ }
+
+ MBBFreqInfo.setBlockFreq(&TailMBB, AccumulatedMBBFreq);
+
+ if (TailMBB.succ_size() <= 1)
+ return;
+
+ auto MaxEdgeFreq = *std::max_element(EdgeFreqLs.begin(), EdgeFreqLs.end());
+ uint64_t Scale = MaxEdgeFreq.getFrequency() / UINT32_MAX + 1;
+ auto EdgeFreq = EdgeFreqLs.begin();
+
+ for (auto SuccI = TailMBB.succ_begin(), SuccE = TailMBB.succ_end();
+ SuccI != SuccE; ++SuccI, ++EdgeFreq)
+ TailMBB.setSuccWeight(SuccI, EdgeFreq->getFrequency() / Scale);
+}
+
//===----------------------------------------------------------------------===//
// Branch Optimization
//===----------------------------------------------------------------------===//