aboutsummaryrefslogtreecommitdiffstats
path: root/lib/CodeGen/ForwardControlFlowIntegrity.cpp
blob: 63c3699e13b61f80da05a8cc2f030dfb107fba75 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
//===-- ForwardControlFlowIntegrity.cpp: Forward-Edge CFI -----------------===//
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
///
/// \file
/// \brief A pass that instruments code with fast checks for indirect calls and
/// hooks for a function to check violations.
///
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "cfi"

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/JumpInstrTableInfo.h"
#include "llvm/CodeGen/ForwardControlFlowIntegrity.h"
#include "llvm/CodeGen/JumpInstrTables.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

STATISTIC(NumCFIIndirectCalls,
          "Number of indirect call sites rewritten by the CFI pass");

char ForwardControlFlowIntegrity::ID = 0;
INITIALIZE_PASS_BEGIN(ForwardControlFlowIntegrity, "forward-cfi",
                      "Control-Flow Integrity", true, true)
INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo);
INITIALIZE_PASS_DEPENDENCY(JumpInstrTables);
INITIALIZE_PASS_END(ForwardControlFlowIntegrity, "forward-cfi",
                    "Control-Flow Integrity", true, true)

ModulePass *llvm::createForwardControlFlowIntegrityPass() {
  return new ForwardControlFlowIntegrity();
}

ModulePass *llvm::createForwardControlFlowIntegrityPass(
    JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
    StringRef CFIFuncName) {
  return new ForwardControlFlowIntegrity(JTT, CFIType, CFIEnforcing,
                                         CFIFuncName);
}

// Checks to see if a given CallSite is making an indirect call, including
// cases where the indirect call is made through a bitcast.
static bool isIndirectCall(CallSite &CS) {
  if (CS.getCalledFunction())
    return false;

  // Check the value to see if it is merely a bitcast of a function. In
  // this case, it will translate to a direct function call in the resulting
  // assembly, so we won't treat it as an indirect call here.
  const Value *V = CS.getCalledValue();
  if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
    return !(CE->isCast() && isa<Function>(CE->getOperand(0)));
  }

  // Otherwise, since we know it's a call, it must be an indirect call
  return true;
}

static const char cfi_failure_func_name[] = "__llvm_cfi_pointer_warning";

ForwardControlFlowIntegrity::ForwardControlFlowIntegrity()
    : ModulePass(ID), IndirectCalls(), JTType(JumpTable::Single),
      CFIType(CFIntegrity::Sub), CFIEnforcing(false), CFIFuncName("") {
  initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
}

ForwardControlFlowIntegrity::ForwardControlFlowIntegrity(
    JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
    std::string CFIFuncName)
    : ModulePass(ID), IndirectCalls(), JTType(JTT), CFIType(CFIType),
      CFIEnforcing(CFIEnforcing), CFIFuncName(CFIFuncName) {
  initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
}

ForwardControlFlowIntegrity::~ForwardControlFlowIntegrity() {}

void ForwardControlFlowIntegrity::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<JumpInstrTableInfo>();
  AU.addRequired<JumpInstrTables>();
}

void ForwardControlFlowIntegrity::getIndirectCalls(Module &M) {
  // To get the indirect calls, we iterate over all functions and iterate over
  // the list of basic blocks in each. We extract a total list of indirect calls
  // before modifying any of them, since our modifications will modify the list
  // of basic blocks.
  for (Function &F : M) {
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        CallSite CS(&I);
        if (!(CS && isIndirectCall(CS)))
          continue;

        Value *CalledValue = CS.getCalledValue();

        // Don't rewrite this instruction if the indirect call is actually just
        // inline assembly, since our transformation will generate an invalid
        // module in that case.
        if (isa<InlineAsm>(CalledValue))
          continue;

        IndirectCalls.push_back(&I);
      }
    }
  }
}

void ForwardControlFlowIntegrity::updateIndirectCalls(Module &M,
                                                      CFITables &CFIT) {
  Type *Int64Ty = Type::getInt64Ty(M.getContext());
  for (Instruction *I : IndirectCalls) {
    CallSite CS(I);
    Value *CalledValue = CS.getCalledValue();

    // Get the function type for this call and look it up in the tables.
    Type *VTy = CalledValue->getType();
    PointerType *PTy = dyn_cast<PointerType>(VTy);
    Type *EltTy = PTy->getElementType();
    FunctionType *FunTy = dyn_cast<FunctionType>(EltTy);
    FunctionType *TransformedTy = JumpInstrTables::transformType(JTType, FunTy);
    ++NumCFIIndirectCalls;
    Constant *JumpTableStart = nullptr;
    Constant *JumpTableMask = nullptr;
    Constant *JumpTableSize = nullptr;

    // Some call sites have function types that don't correspond to any
    // address-taken function in the module. This happens when function pointers
    // are passed in from external code.
    auto it = CFIT.find(TransformedTy);
    if (it == CFIT.end()) {
      // In this case, make sure that the function pointer will change by
      // setting the mask and the start to be 0 so that the transformed
      // function is 0.
      JumpTableStart = ConstantInt::get(Int64Ty, 0);
      JumpTableMask = ConstantInt::get(Int64Ty, 0);
      JumpTableSize = ConstantInt::get(Int64Ty, 0);
    } else {
      JumpTableStart = it->second.StartValue;
      JumpTableMask = it->second.MaskValue;
      JumpTableSize = it->second.Size;
    }

    rewriteFunctionPointer(M, I, CalledValue, JumpTableStart, JumpTableMask,
                           JumpTableSize);
  }

  return;
}

bool ForwardControlFlowIntegrity::runOnModule(Module &M) {
  JumpInstrTableInfo *JITI = &getAnalysis<JumpInstrTableInfo>();
  Type *Int64Ty = Type::getInt64Ty(M.getContext());
  Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());

  // JumpInstrTableInfo stores information about the alignment of each entry.
  // The alignment returned by JumpInstrTableInfo is alignment in bytes, not
  // in the exponent.
  ByteAlignment = JITI->entryByteAlignment();
  LogByteAlignment = llvm::Log2_64(ByteAlignment);

  // Set up tables for control-flow integrity based on information about the
  // jump-instruction tables.
  CFITables CFIT;
  for (const auto &KV : JITI->getTables()) {
    uint64_t Size = static_cast<uint64_t>(KV.second.size());
    uint64_t TableSize = NextPowerOf2(Size);

    int64_t MaskValue = ((TableSize << LogByteAlignment) - 1) & -ByteAlignment;
    Constant *JumpTableMaskValue = ConstantInt::get(Int64Ty, MaskValue);
    Constant *JumpTableSize = ConstantInt::get(Int64Ty, Size);

    // The base of the table is defined to be the first jumptable function in
    // the table.
    Function *First = KV.second.begin()->second;
    Constant *JumpTableStartValue = ConstantExpr::getBitCast(First, VoidPtrTy);
    CFIT[KV.first].StartValue = JumpTableStartValue;
    CFIT[KV.first].MaskValue = JumpTableMaskValue;
    CFIT[KV.first].Size = JumpTableSize;
  }

  if (CFIT.empty())
    return false;

  getIndirectCalls(M);

  if (!CFIEnforcing) {
    addWarningFunction(M);
  }

  // Update the instructions with the check and the indirect jump through our
  // table.
  updateIndirectCalls(M, CFIT);

  return true;
}

void ForwardControlFlowIntegrity::addWarningFunction(Module &M) {
  PointerType *CharPtrTy = Type::getInt8PtrTy(M.getContext());

  // Get the type of the Warning Function: void (i8*, i8*),
  // where the first argument is the name of the function in which the violation
  // occurs, and the second is the function pointer that violates CFI.
  SmallVector<Type *, 2> WarningFunArgs;
  WarningFunArgs.push_back(CharPtrTy);
  WarningFunArgs.push_back(CharPtrTy);
  FunctionType *WarningFunTy =
      FunctionType::get(Type::getVoidTy(M.getContext()), WarningFunArgs, false);

  if (!CFIFuncName.empty()) {
    Constant *FailureFun = M.getOrInsertFunction(CFIFuncName, WarningFunTy);
    if (!FailureFun)
      report_fatal_error("Could not get or insert the function specified by"
                         " -cfi-func-name");
  } else {
    // The default warning function swallows the warning and lets the call
    // continue, since there's no generic way for it to print out this
    // information.
    Function *WarningFun = M.getFunction(cfi_failure_func_name);
    if (!WarningFun) {
      WarningFun =
          Function::Create(WarningFunTy, GlobalValue::LinkOnceAnyLinkage,
                           cfi_failure_func_name, &M);
    }

    BasicBlock *Entry =
        BasicBlock::Create(M.getContext(), "entry", WarningFun, 0);
    ReturnInst::Create(M.getContext(), Entry);
  }
}

void ForwardControlFlowIntegrity::rewriteFunctionPointer(
    Module &M, Instruction *I, Value *FunPtr, Constant *JumpTableStart,
    Constant *JumpTableMask, Constant *JumpTableSize) {
  IRBuilder<> TempBuilder(I);

  Type *OrigFunType = FunPtr->getType();

  BasicBlock *CurBB = cast<BasicBlock>(I->getParent());
  Function *CurF = cast<Function>(CurBB->getParent());
  Type *Int64Ty = Type::getInt64Ty(M.getContext());

  Value *TI = TempBuilder.CreatePtrToInt(FunPtr, Int64Ty);
  Value *TStartInt = TempBuilder.CreatePtrToInt(JumpTableStart, Int64Ty);

  Value *NewFunPtr = nullptr;
  Value *Check = nullptr;
  switch (CFIType) {
  case CFIntegrity::Sub: {
    // This is the subtract, mask, and add version.
    // Subtract from the base.
    Value *Sub = TempBuilder.CreateSub(TI, TStartInt);

    // Mask the difference to force this to be a table offset.
    Value *And = TempBuilder.CreateAnd(Sub, JumpTableMask);

    // Add it back to the base.
    Value *Result = TempBuilder.CreateAdd(And, TStartInt);

    // Convert it back into a function pointer that we can call.
    NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
    break;
  }
  case CFIntegrity::Ror: {
    // This is the subtract and rotate version.
    // Rotate right by the alignment value. The optimizer should recognize
    // this sequence as a rotation.

    // This cast is safe, since unsigned is always a subset of uint64_t.
    uint64_t LogByteAlignment64 = static_cast<uint64_t>(LogByteAlignment);
    Constant *RightShift = ConstantInt::get(Int64Ty, LogByteAlignment64);
    Constant *LeftShift = ConstantInt::get(Int64Ty, 64 - LogByteAlignment64);

    // Subtract from the base.
    Value *Sub = TempBuilder.CreateSub(TI, TStartInt);

    // Create the equivalent of a rotate-right instruction.
    Value *Shr = TempBuilder.CreateLShr(Sub, RightShift);
    Value *Shl = TempBuilder.CreateShl(Sub, LeftShift);
    Value *Or = TempBuilder.CreateOr(Shr, Shl);

    // Perform unsigned comparison to check for inclusion in the table.
    Check = TempBuilder.CreateICmpULT(Or, JumpTableSize);
    NewFunPtr = FunPtr;
    break;
  }
  case CFIntegrity::Add: {
    // This is the mask and add version.
    // Mask the function pointer to turn it into an offset into the table.
    Value *And = TempBuilder.CreateAnd(TI, JumpTableMask);

    // Then or this offset to the base and get the pointer value.
    Value *Result = TempBuilder.CreateAdd(And, TStartInt);

    // Convert it back into a function pointer that we can call.
    NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
    break;
  }
  }

  if (!CFIEnforcing) {
    // If a check hasn't been added (in the rotation version), then check to see
    // if it's the same as the original function. This check determines whether
    // or not we call the CFI failure function.
    if (!Check)
      Check = TempBuilder.CreateICmpEQ(NewFunPtr, FunPtr);
    BasicBlock *InvalidPtrBlock =
        BasicBlock::Create(M.getContext(), "invalid.ptr", CurF, 0);
    BasicBlock *ContinuationBB = CurBB->splitBasicBlock(I);

    // Remove the unconditional branch that connects the two blocks.
    TerminatorInst *TermInst = CurBB->getTerminator();
    TermInst->eraseFromParent();

    // Add a conditional branch that depends on the Check above.
    BranchInst::Create(ContinuationBB, InvalidPtrBlock, Check, CurBB);

    // Call the warning function for this pointer, then continue.
    Instruction *BI = BranchInst::Create(ContinuationBB, InvalidPtrBlock);
    insertWarning(M, InvalidPtrBlock, BI, FunPtr);
  } else {
    // Modify the instruction to call this value.
    CallSite CS(I);
    CS.setCalledFunction(NewFunPtr);
  }
}

void ForwardControlFlowIntegrity::insertWarning(Module &M, BasicBlock *Block,
                                                Instruction *I, Value *FunPtr) {
  Function *ParentFun = cast<Function>(Block->getParent());

  // Get the function to call right before the instruction.
  Function *WarningFun = nullptr;
  if (CFIFuncName.empty()) {
    WarningFun = M.getFunction(cfi_failure_func_name);
  } else {
    WarningFun = M.getFunction(CFIFuncName);
  }

  assert(WarningFun && "Could not find the CFI failure function");

  Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());

  IRBuilder<> WarningInserter(I);
  // Create a mergeable GlobalVariable containing the name of the function.
  Value *ParentNameGV =
      WarningInserter.CreateGlobalString(ParentFun->getName());
  Value *ParentNamePtr = WarningInserter.CreateBitCast(ParentNameGV, VoidPtrTy);
  Value *FunVoidPtr = WarningInserter.CreateBitCast(FunPtr, VoidPtrTy);
  WarningInserter.CreateCall2(WarningFun, ParentNamePtr, FunVoidPtr);
}