aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r-- lib/Target/NVPTX/NVPTXAsmPrinter.cpp 18
-rw-r--r-- lib/Target/NVPTX/NVPTXISelLowering.cpp 45
-rw-r--r-- test/CodeGen/NVPTX/vector-args.ll 27
3 files changed, 83 insertions, 7 deletions
diff --git a/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 0115e1f..c0e8670 100644
--- a/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1481,7 +1481,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F,
O << "(\n";
for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
- const Type *Ty = I->getType();
+ Type *Ty = I->getType();
if (!first)
O << ",\n";
@@ -1504,6 +1504,22 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F,
}
if (PAL.hasAttribute(paramIndex+1, Attribute::ByVal) == false) {
+ if (Ty->isVectorTy()) {
+ // Just print .param .align <a> .b8 <name>[<size>]
+ // <a> = PAL.getParamAlignment, or the ABI type alignment when that is 0
+ // <size> = TD->getTypeAllocSize of the whole vector type
+ unsigned align = PAL.getParamAlignment(paramIndex+1);
+ if (align == 0)
+ align = TD->getABITypeAlignment(Ty);
+
+ unsigned sz = TD->getTypeAllocSize(Ty);
+ O << "\t.param .align " << align
+ << " .b8 ";
+ printParamName(I, paramIndex, O);
+ O << "[" << sz << "]";
+
+ continue;
+ }
// Just a scalar
const PointerType *PTy = dyn_cast<PointerType>(Ty);
if (isKernelFunc) {
diff --git a/lib/Target/NVPTX/NVPTXISelLowering.cpp b/lib/Target/NVPTX/NVPTXISelLowering.cpp
index e9a9fbf..987d34b 100644
--- a/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1058,15 +1058,15 @@ NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
theArgs.push_back(I);
argTypes.push_back(I->getType());
}
- assert(argTypes.size() == Ins.size() &&
- "Ins types and function types did not match");
+ //assert(argTypes.size() == Ins.size() &&
+ // "Ins types and function types did not match");
int idx = 0;
- for (unsigned i=0, e=Ins.size(); i!=e; ++i, ++idx) {
+ for (unsigned i=0, e=argTypes.size(); i!=e; ++i, ++idx) {
Type *Ty = argTypes[i];
EVT ObjectVT = getValueType(Ty);
- assert(ObjectVT == Ins[i].VT &&
- "Ins type did not match function type");
+ //assert(ObjectVT == Ins[i].VT &&
+ // "Ins type did not match function type");
// If the kernel argument is image*_t or sampler_t, convert it to
// a i32 constant holding the parameter position. This can later
@@ -1081,7 +1081,15 @@ NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
if (theArgs[i]->use_empty()) {
// argument is dead
- InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
+ if (ObjectVT.isVector()) {
+ EVT EltVT = ObjectVT.getVectorElementType();
+ unsigned NumElts = ObjectVT.getVectorNumElements();
+ for (unsigned vi = 0; vi < NumElts; ++vi) {
+ InVals.push_back(DAG.getNode(ISD::UNDEF, dl, EltVT));
+ }
+ } else {
+ InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
+ }
continue;
}
@@ -1090,6 +1098,31 @@ NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
// appear in the same order as their order of appearance
// in the original function. "idx+1" holds that order.
if (PAL.hasAttribute(i+1, Attribute::ByVal) == false) {
+ if (ObjectVT.isVector()) {
+ unsigned NumElts = ObjectVT.getVectorNumElements();
+ EVT EltVT = ObjectVT.getVectorElementType();
+ unsigned Offset = 0;
+ for (unsigned vi = 0; vi < NumElts; ++vi) {
+ SDValue A = getParamSymbol(DAG, idx, getPointerTy());
+ SDValue B = DAG.getIntPtrConstant(Offset);
+ SDValue Addr = DAG.getNode(ISD::ADD, dl, getPointerTy(),
+ //getParamSymbol(DAG, idx, EltVT),
+ //DAG.getConstant(Offset, getPointerTy()));
+ A, B);
+ Value *SrcValue = Constant::getNullValue(PointerType::get(
+ EltVT.getTypeForEVT(F->getContext()),
+ llvm::ADDRESS_SPACE_PARAM));
+ SDValue Ld = DAG.getLoad(EltVT, dl, Root, Addr,
+ MachinePointerInfo(SrcValue),
+ false, false, false,
+ TD->getABITypeAlignment(EltVT.getTypeForEVT(
+ F->getContext())));
+ Offset += EltVT.getStoreSizeInBits()/8;
+ InVals.push_back(Ld);
+ }
+ continue;
+ }
+
// A plain scalar.
if (isABI || isKernel) {
// If ABI, load from the param symbol
diff --git a/test/CodeGen/NVPTX/vector-args.ll b/test/CodeGen/NVPTX/vector-args.ll
new file mode 100644
index 0000000..80deae4
--- /dev/null
+++ b/test/CodeGen/NVPTX/vector-args.ll
@@ -0,0 +1,27 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
+
+
+define float @foo(<2 x float> %a) { ; vector-argument test: takes a <2 x float> and returns the sum of its squared lanes
+; CHECK: .func (.param .b32 func_retval0) foo
+; CHECK: .param .align 8 .b8 foo_param_0[8]
+; CHECK: ld.param.f32 %f{{[0-9]+}}
+; CHECK: ld.param.f32 %f{{[0-9]+}}
+ %t1 = fmul <2 x float> %a, %a ; multiply the argument by itself, so the argument is live
+ %t2 = extractelement <2 x float> %t1, i32 0 ; lane 0 of the product
+ %t3 = extractelement <2 x float> %t1, i32 1 ; lane 1 of the product
+ %t4 = fadd float %t2, %t3 ; both lanes feed the scalar result
+ ret float %t4
+}
+
+
+define float @bar(<4 x float> %a) { ; vector-argument test: takes a <4 x float> but only lanes 0 and 1 reach the result
+; CHECK: .func (.param .b32 func_retval0) bar
+; CHECK: .param .align 16 .b8 bar_param_0[16]
+; CHECK: ld.param.f32 %f{{[0-9]+}}
+; CHECK: ld.param.f32 %f{{[0-9]+}}
+ %t1 = fmul <4 x float> %a, %a ; multiply the argument by itself, so the argument is live
+ %t2 = extractelement <4 x float> %t1, i32 0 ; lane 0 of the product
+ %t3 = extractelement <4 x float> %t1, i32 1 ; lane 1 of the product; lanes 2 and 3 are never read
+ %t4 = fadd float %t2, %t3 ; only two lanes feed the scalar result, matching the two loads the directives above require
+ ret float %t4
+}