aboutsummaryrefslogtreecommitdiffstats
path: root/lib/Analysis/DataStructure/DSCallSiteIterator.h
blob: bc51fcf3caaf8c58cece9c180a7ed58743e4fa13 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
//===- DSCallSiteIterator.h - Iterator for DSGraph call sites ---*- C++ -*-===//
// 
//                     The LLVM Compiler Infrastructure
//
// This file was developed by the LLVM research group and is distributed under
// the University of Illinois Open Source License. See LICENSE.TXT for details.
// 
//===----------------------------------------------------------------------===//
//
// This file implements an iterator for complete call sites in DSGraphs.  This
// code can iterate over either the normal call list or the aux calls list, and
// is used by the TD and BU passes.
//
//===----------------------------------------------------------------------===//

#ifndef DSCALLSITEITERATOR_H
#define DSCALLSITEITERATOR_H

#include "llvm/Analysis/DataStructure/DSGraph.h"
#include "llvm/Function.h"

namespace llvm {

struct DSCallSiteIterator {
  // FCs are the edges out of the current node are the call site targets...
  std::list<DSCallSite> *FCs;
  std::list<DSCallSite>::iterator CallSite;
  unsigned CallSiteEntry;

  DSCallSiteIterator(std::list<DSCallSite> &CS) : FCs(&CS) {
    CallSite = CS.begin(); CallSiteEntry = 0;
    advanceToValidCallee();
  }

  // End iterator ctor.
  DSCallSiteIterator(std::list<DSCallSite> &CS, bool) : FCs(&CS) {
    CallSite = CS.end(); CallSiteEntry = 0;
  }

  static bool isVAHackFn(const Function *F) {
    return F->getName() == "printf"  || F->getName() == "sscanf" ||
      F->getName() == "fprintf" || F->getName() == "open" ||
      F->getName() == "sprintf" || F->getName() == "fputs" ||
      F->getName() == "fscanf";
  }

  // isUnresolvableFunction - Return true if this is an unresolvable
  // external function.  A direct or indirect call to this cannot be resolved.
  // 
  static bool isUnresolvableFunc(const Function* callee) {
    return callee->isExternal() && !isVAHackFn(callee);
  } 

  void advanceToValidCallee() {
    while (CallSite != FCs->end()) {
      if (CallSite->isDirectCall()) {
        if (CallSiteEntry == 0 &&        // direct call only has one target...
            ! isUnresolvableFunc(CallSite->getCalleeFunc()))
          return;                       // and not an unresolvable external func
      } else {
        DSNode *CalleeNode = CallSite->getCalleeNode();
        if (CallSiteEntry || isCompleteNode(CalleeNode)) {
          const std::vector<GlobalValue*> &Callees = CalleeNode->getGlobals();
          while (CallSiteEntry < Callees.size()) {
            if (isa<Function>(Callees[CallSiteEntry]))
              return;
            ++CallSiteEntry;
          }
        }
      }
      CallSiteEntry = 0;
      ++CallSite;
    }
  }
  
  // isCompleteNode - Return true if we know all of the targets of this node,
  // and if the call sites are not external.
  //
  static inline bool isCompleteNode(DSNode *N) {
    if (N->isIncomplete()) return false;
    const std::vector<GlobalValue*> &Callees = N->getGlobals();
    for (unsigned i = 0, e = Callees.size(); i != e; ++i)
      if (isUnresolvableFunc(cast<Function>(Callees[i])))
        return false;               // Unresolvable external function found...
    return true;  // otherwise ok
  }

public:
  static DSCallSiteIterator begin_aux(DSGraph &G) {
    return G.getAuxFunctionCalls();
  }
  static DSCallSiteIterator end_aux(DSGraph &G) {
    return DSCallSiteIterator(G.getAuxFunctionCalls(), true);
  }
  static DSCallSiteIterator begin_std(DSGraph &G) {
    return G.getFunctionCalls();
  }
  static DSCallSiteIterator end_std(DSGraph &G) {
    return DSCallSiteIterator(G.getFunctionCalls(), true);
  }
  static DSCallSiteIterator begin(std::list<DSCallSite> &CSs) { return CSs; }
  static DSCallSiteIterator end(std::list<DSCallSite> &CSs) {
    return DSCallSiteIterator(CSs, true);
  }
  bool operator==(const DSCallSiteIterator &CSI) const {
    return CallSite == CSI.CallSite && CallSiteEntry == CSI.CallSiteEntry;
  }
  bool operator!=(const DSCallSiteIterator &CSI) const {
    return !operator==(CSI);
  }

  std::list<DSCallSite>::iterator getCallSiteIdx() const { return CallSite; }
  const DSCallSite &getCallSite() const { return *CallSite; }

  Function *operator*() const {
    if (CallSite->isDirectCall()) {
      return CallSite->getCalleeFunc();
    } else {
      DSNode *Node = CallSite->getCalleeNode();
      return cast<Function>(Node->getGlobals()[CallSiteEntry]);
    }
  }

  DSCallSiteIterator& operator++() {                // Preincrement
    ++CallSiteEntry;
    advanceToValidCallee();
    return *this;
  }
  DSCallSiteIterator operator++(int) { // Postincrement
    DSCallSiteIterator tmp = *this; ++*this; return tmp; 
  }
};

} // End llvm namespace

#endif