aboutsummaryrefslogtreecommitdiffstats
path: root/lib/Target/NVPTX/NVPTXLowerStructArgs.cpp
blob: 68dfbb716139010c910a0727ffc71577985109a2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
//===-- NVPTXLowerStructArgs.cpp - Copy struct args to local memory =====--===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// Copy struct args to local memory. This is needed for kernel functions only.
// This is a preparation for handling cases like
//
// kernel void foo(struct A arg, ...)
// {
//     struct A *p = &arg;
//     ...
//     ... = p->field1 ...  (there is no generic address for .param)
//     p->field2 = ...      (there is no write access to .param)
// }
//
//===----------------------------------------------------------------------===//

#include "NVPTX.h"
#include "NVPTXUtilities.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"

using namespace llvm;

namespace llvm {
void initializeNVPTXLowerStructArgsPass(PassRegistry &);
}

namespace {
/// Function pass that, for kernel functions, copies every byval pointer
/// argument out of the .param address space into a fresh stack allocation and
/// redirects the argument's uses to that copy.
class NVPTXLowerStructArgs : public FunctionPass {
public:
  /// Pass identification, replacement for typeid.
  static char ID;

  NVPTXLowerStructArgs() : FunctionPass(ID) {}

  const char *getPassName() const override {
    return "Copy structure (byval *) arguments to stack";
  }

private:
  bool runOnFunction(Function &F) override;

  /// Rewrites each byval pointer argument of the given function.
  void handleStructPtrArgs(Function &);
  /// Copies one byval argument into an alloca and replaces its uses.
  void handleParam(Argument *);
};
} // end anonymous namespace

// Pass identifier. LLVM distinguishes passes by the address of ID; the stored
// value is presumably never read (LLVM passes conventionally use 0) — confirm.
char NVPTXLowerStructArgs::ID = 1;

// Registers the pass under the command-line name "nvptx-lower-struct-args".
// The two trailing flags are presumably (CFG-only = false, analysis = false);
// see the INITIALIZE_PASS macro definition to confirm.
INITIALIZE_PASS(NVPTXLowerStructArgs, "nvptx-lower-struct-args",
                "Lower structure arguments (NVPTX)", false, false)

// Copies the pointee of one byval pointer argument into a new alloca in the
// entry block and redirects every existing use of the argument to that alloca.
// The copy is read through the param address space via the
// nvvm_ptr_gen_to_param intrinsic. All new instructions are inserted, in
// creation order, in front of what is currently the entry block's first
// instruction.
//
// Arg must have pointer type; handleStructPtrArgs only passes byval pointer
// arguments.
void NVPTXLowerStructArgs::handleParam(Argument *Arg) {
  Function *Func = Arg->getParent();
  Module *M = Func->getParent();
  LLVMContext &Ctx = Func->getContext();
  Instruction *FirstInst = &Func->getEntryBlock().front();

  // cast<> asserts the pointer type itself; no separate dyn_cast + assert.
  PointerType *PType = cast<PointerType>(Arg->getType());
  // Named so it does not shadow llvm::StructType.
  Type *ArgElemTy = PType->getElementType();

  AllocaInst *AllocA = new AllocaInst(ArgElemTy, Arg->getName(), FirstInst);

  // Set the alignment to the alignment of the byval parameter: later
  // loads/stores assume that alignment, and the alloca is about to replace
  // every use of the byval parameter. (Attribute index 0 is the return value,
  // hence the +1.)
  AllocA->setAlignment(Func->getParamAlignment(Arg->getArgNo() + 1));

  // Must happen BEFORE the bitcast of Arg below is created. Otherwise that
  // bitcast would also be redirected to the alloca, and the copy would read
  // from the (uninitialized) alloca instead of from the parameter.
  Arg->replaceAllUsesWith(AllocA);

  // Declaration of the (overloaded) generic -> param pointer conversion.
  Type *CvtTypes[] = {Type::getInt8PtrTy(Ctx, ADDRESS_SPACE_PARAM),
                      Type::getInt8PtrTy(Ctx, ADDRESS_SPACE_GENERIC)};
  Function *CvtFunc =
      Intrinsic::getDeclaration(M, Intrinsic::nvvm_ptr_gen_to_param, CvtTypes);

  Value *BitcastArgs[] = {
      new BitCastInst(Arg, Type::getInt8PtrTy(Ctx, ADDRESS_SPACE_GENERIC),
                      Arg->getName(), FirstInst)};
  CallInst *CallCVT =
      CallInst::Create(CvtFunc, BitcastArgs, "cvt_to_param", FirstInst);

  // Load the whole aggregate from param space and store it into the alloca.
  BitCastInst *BitCast = new BitCastInst(
      CallCVT, PointerType::get(ArgElemTy, ADDRESS_SPACE_PARAM), Arg->getName(),
      FirstInst);
  LoadInst *LI = new LoadInst(BitCast, Arg->getName(), FirstInst);
  new StoreInst(LI, AllocA, FirstInst);
}

// =============================================================================
// If the function has a byval struct pointer argument, say
// foo(%struct.x* byval %d), the following instructions are added to the start
// of the entry block (intrinsic overload suffix omitted):
//
// %temp = alloca %struct.x, align 8
// %tt1 = bitcast %struct.x* %d to i8*
// %tt2 = call i8 addrspace(101)* @llvm.nvvm.ptr.gen.to.param(i8* %tt1)
// %tempd = bitcast i8 addrspace(101)* %tt2 to %struct.x addrspace(101)*
// %tv = load %struct.x addrspace(101)* %tempd
// store %struct.x %tv, %struct.x* %temp, align 8
//
// This allocates space on the stack and copies the incoming struct from param
// space to local space. Every pre-existing use of %d is replaced by %temp.
// =============================================================================
void NVPTXLowerStructArgs::handleStructPtrArgs(Function &F) {
  for (auto I = F.arg_begin(), E = F.arg_end(); I != E; ++I) {
    Argument *Arg = &*I;
    // Only pointer arguments carrying the byval attribute are rewritten.
    if (!Arg->getType()->isPointerTy() || !Arg->hasByValAttr())
      continue;
    handleParam(Arg);
  }
}

// =============================================================================
// Main function for this pass. Kernels with at least one byval pointer argument
// are rewritten by handleStructPtrArgs; every other function is left untouched.
// Returns true only when the IR was actually modified.
// =============================================================================
bool NVPTXLowerStructArgs::runOnFunction(Function &F) {
  // Skip non-kernels. See the comments at the top of this file.
  if (!isKernelFunction(F))
    return false;

  // handleStructPtrArgs only rewrites pointer arguments that carry the byval
  // attribute (this is the same predicate it uses). With none present,
  // nothing is modified. A FunctionPass must then return false so cached
  // analyses are not needlessly invalidated.
  bool HasByValPtrArg = false;
  for (Argument &Arg : F.args()) {
    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
      HasByValPtrArg = true;
      break;
    }
  }
  if (!HasByValPtrArg)
    return false;

  handleStructPtrArgs(F);
  return true;
}

// Creates a new instance of this pass. The object is heap-allocated with new.
// Ownership passes to the caller, which is typically a pass manager.
FunctionPass *llvm::createNVPTXLowerStructArgsPass() {
  NVPTXLowerStructArgs *P = new NVPTXLowerStructArgs;
  return P;
}