aboutsummaryrefslogtreecommitdiffstats
path: root/lib/CodeGen/CalcSpillWeights.cpp
blob: dcffb8a247a9f1873522c0621e19b439bdbe378a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
//===------------------------ CalcSpillWeights.cpp ------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "calcspillweights"

#include "llvm/Function.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/CodeGen/CalcSpillWeights.h"
#include "llvm/CodeGen/LiveIntervalAnalysis.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/SlotIndexes.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetInstrInfo.h"
#include "llvm/Target/TargetRegisterInfo.h"

using namespace llvm;

// Out-of-line definition of the pass identifier declared in the
// CalculateSpillWeights class.
char CalculateSpillWeights::ID = 0;

// Registers the pass with the name "calcspillweights" and the description
// "Calculate spill weights".
static RegisterPass<CalculateSpillWeights>
CalcSpillWeightsRegistration("calcspillweights", "Calculate spill weights");

void CalculateSpillWeights::getAnalysisUsage(AnalysisUsage &au) const {
  au.addRequired<LiveIntervals>();
  au.addRequired<MachineLoopInfo>();
  au.setPreservesAll();
  MachineFunctionPass::getAnalysisUsage(au);
}

/// Computes LiveInterval::weight for every virtual register in \p fn.
///
/// Phase 1 (accumulation): every instruction that is neither an identity copy
/// nor an IMPLICIT_DEF adds LiveIntervals::getSpillWeight(hasDef, hasUse,
/// loopDepth) to the weight of each distinct virtual register it references.
/// A register referenced by several operands of one instruction is counted
/// once, with hasDef/hasUse merged across those operands. A def in a loop
/// exiting block whose live range extends past the end of that block is
/// weighted 3x, on the heuristic that it is a loop counter update.
///
/// Phase 2 (normalization): for each virtual register interval,
///  - zero-length intervals get HUGE_VALF (never worth spilling);
///  - rematerializable intervals are discounted (x0.9 if a def is a load,
///    x0.5 otherwise);
///  - intervals with a register allocation hint are nudged up by 1%;
///  - the weight is divided by the interval's approximate size.
///
/// Returns false: the pass only writes interval weights and does not change
/// the machine code.
bool CalculateSpillWeights::runOnMachineFunction(MachineFunction &fn) {

  DEBUG(errs() << "********** Compute Spill Weights **********\n"
               << "********** Function: "
               << fn.getFunction()->getName() << '\n');

  LiveIntervals *lis = &getAnalysis<LiveIntervals>();
  MachineLoopInfo *loopInfo = &getAnalysis<MachineLoopInfo>();
  const TargetInstrInfo *tii = fn.getTarget().getInstrInfo();
  MachineRegisterInfo *mri = &fn.getRegInfo();

  // Registers already accounted for in the current instruction. Cleared
  // after every instruction.
  SmallSet<unsigned, 4> processed;
  for (MachineFunction::iterator mbbi = fn.begin(), mbbe = fn.end();
       mbbi != mbbe; ++mbbi) {
    MachineBasicBlock* mbb = mbbi;
    SlotIndex mbbEnd = lis->getMBBEndIdx(mbb);
    MachineLoop* loop = loopInfo->getLoopFor(mbb);
    unsigned loopDepth = loop ? loop->getLoopDepth() : 0;
    bool isExiting = loop ? loop->isLoopExiting(mbb) : false;

    for (MachineBasicBlock::const_iterator mii = mbb->begin(), mie = mbb->end();
         mii != mie; ++mii) {
      const MachineInstr *mi = mii;
      // Identity copies and IMPLICIT_DEFs contribute nothing to spill cost.
      if (tii->isIdentityCopy(*mi))
        continue;

      if (mi->getOpcode() == TargetInstrInfo::IMPLICIT_DEF)
        continue;

      for (unsigned i = 0, e = mi->getNumOperands(); i != e; ++i) {
        const MachineOperand &mopi = mi->getOperand(i);
        if (!mopi.isReg() || mopi.getReg() == 0)
          continue;
        unsigned reg = mopi.getReg();
        if (!TargetRegisterInfo::isVirtualRegister(reg))
          continue;
        // Multiple uses of reg by the same instruction. It should not
        // contribute to spill weight again.
        if (!processed.insert(reg))
          continue;

        // Merge def/use information from the remaining operands that
        // reference the same register, so one weight covers the whole
        // instruction.
        bool hasDef = mopi.isDef();
        bool hasUse = !hasDef;
        for (unsigned j = i+1; j != e; ++j) {
          const MachineOperand &mopj = mi->getOperand(j);
          if (!mopj.isReg() || mopj.getReg() != reg)
            continue;
          hasDef |= mopj.isDef();
          hasUse |= mopj.isUse();
          if (hasDef && hasUse)
            break;
        }

        LiveInterval &regInt = lis->getInterval(reg);
        float weight = lis->getSpillWeight(hasDef, hasUse, loopDepth);
        if (hasDef && isExiting) {
          // Looks like this is a loop count variable update.
          SlotIndex defIdx = lis->getInstructionIndex(mi).getDefIndex();
          const LiveRange *dlr = regInt.getLiveRangeContaining(defIdx);
          // Be defensive: only apply the bonus when a containing range
          // exists, rather than dereferencing a null pointer.
          if (dlr && dlr->end > mbbEnd)
            weight *= 3.0F;
        }
        regInt.weight += weight;
      }
      processed.clear();
    }
  }

  for (LiveIntervals::iterator I = lis->begin(), E = lis->end(); I != E; ++I) {
    LiveInterval &li = *I->second;
    if (TargetRegisterInfo::isVirtualRegister(li.reg)) {
      // If the live interval length is essentially zero, i.e. in every live
      // range the use follows def immediately, it doesn't make sense to spill
      // it and hope it will be easier to allocate for this li.
      if (isZeroLengthInterval(&li)) {
        li.weight = HUGE_VALF;
        continue;
      }

      bool isLoad = false;
      SmallVector<LiveInterval*, 4> spillIs;
      if (lis->isReMaterializable(li, spillIs, isLoad)) {
        // If all of the definitions of the interval are re-materializable,
        // it is a preferred candidate for spilling. If none of the defs are
        // loads, then it's potentially very cheap to re-materialize.
        // FIXME: this gets much more complicated once we support non-trivial
        // re-materialization.
        if (isLoad)
          li.weight *= 0.9F;
        else
          li.weight *= 0.5F;
      }

      // Slightly prefer live interval that has been assigned a preferred reg.
      std::pair<unsigned, unsigned> Hint = mri->getRegAllocationHint(li.reg);
      if (Hint.first || Hint.second)
        li.weight *= 1.01F;

      // Divide the weight of the interval by its size.  This encourages
      // spilling of intervals that are large and have few uses, and
      // discourages spilling of small intervals with many uses.
      // The approximate instruction count is not guaranteed to be non-zero
      // here. Dividing by 0 would yield +inf, or NaN when the accumulated
      // weight is 0, so clamp it to at least one instruction.
      unsigned instrCount = lis->getApproximateInstructionCount(li);
      if (instrCount == 0)
        instrCount = 1;
      li.weight /= instrCount * SlotIndex::NUM;
    }
  }
  
  return false;
}

/// Reports whether \p li is "zero length": for every one of its live ranges,
/// stepping the end back by one index (SlotIndex::getPrevIndex) does not move
/// past the range's start. An interval with no ranges at all is also
/// reported as zero length.
bool CalculateSpillWeights::isZeroLengthInterval(LiveInterval *li) const {
  typedef LiveInterval::Ranges::const_iterator RangeIter;

  RangeIter range = li->ranges.begin();
  const RangeIter rangesEnd = li->ranges.end();
  while (range != rangesEnd) {
    // Any range that extends beyond its first index disqualifies the interval.
    if (range->end.getPrevIndex() > range->start)
      return false;
    ++range;
  }
  return true;
}