diff options
Diffstat (limited to 'lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp')
-rw-r--r-- | lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 532 |
1 files changed, 521 insertions, 11 deletions
diff --git a/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index cd30880..0dfbf10 100644 --- a/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -24,11 +24,14 @@ using namespace llvm; #define DEBUG_TYPE "nvptx-isel" -static cl::opt<int> -FMAContractLevel("nvptx-fma-level", cl::ZeroOrMore, cl::Hidden, - cl::desc("NVPTX Specific: FMA contraction (0: don't do it" - " 1: do it 2: do it aggressively"), - cl::init(2)); +unsigned FMAContractLevel = 0; + +static cl::opt<unsigned, true> +FMAContractLevelOpt("nvptx-fma-level", cl::ZeroOrMore, cl::Hidden, + cl::desc("NVPTX Specific: FMA contraction (0: don't do it" + " 1: do it 2: do it aggressively"), + cl::location(FMAContractLevel), + cl::init(2)); static cl::opt<int> UsePrecDivF32( "nvptx-prec-divf32", cl::ZeroOrMore, cl::Hidden, @@ -138,7 +141,7 @@ SDNode *NVPTXDAGToDAGISel::Select(SDNode *N) { case NVPTXISD::LDGV4: case NVPTXISD::LDUV2: case NVPTXISD::LDUV4: - ResNode = SelectLDGLDUVector(N); + ResNode = SelectLDGLDU(N); break; case NVPTXISD::StoreV2: case NVPTXISD::StoreV4: @@ -164,6 +167,9 @@ SDNode *NVPTXDAGToDAGISel::Select(SDNode *N) { case ISD::INTRINSIC_WO_CHAIN: ResNode = SelectIntrinsicNoChain(N); break; + case ISD::INTRINSIC_W_CHAIN: + ResNode = SelectIntrinsicChain(N); + break; case NVPTXISD::Tex1DFloatI32: case NVPTXISD::Tex1DFloatFloat: case NVPTXISD::Tex1DFloatFloatLevel: @@ -253,6 +259,12 @@ SDNode *NVPTXDAGToDAGISel::Select(SDNode *N) { case NVPTXISD::Suld3DV4I32Trap: ResNode = SelectSurfaceIntrinsic(N); break; + case ISD::AND: + case ISD::SRA: + case ISD::SRL: + // Try to select BFE + ResNode = SelectBFE(N); + break; case ISD::ADDRSPACECAST: ResNode = SelectAddrSpaceCast(N); break; @@ -264,6 +276,21 @@ SDNode *NVPTXDAGToDAGISel::Select(SDNode *N) { return SelectCode(N); } +SDNode *NVPTXDAGToDAGISel::SelectIntrinsicChain(SDNode *N) { + unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); + switch (IID) { + default: + return NULL; + case Intrinsic::nvvm_ldg_global_f: + case Intrinsic::nvvm_ldg_global_i: + case Intrinsic::nvvm_ldg_global_p: + case Intrinsic::nvvm_ldu_global_f: + case Intrinsic::nvvm_ldu_global_i: + case Intrinsic::nvvm_ldu_global_p: + return SelectLDGLDU(N); + } +} + static unsigned int getCodeAddrSpace(MemSDNode *N, const NVPTXSubtarget &Subtarget) { const Value *Src = N->getMemOperand()->getValue(); @@ -981,22 +1008,101 @@ SDNode *NVPTXDAGToDAGISel::SelectLoadVector(SDNode *N) { return LD; } -SDNode *NVPTXDAGToDAGISel::SelectLDGLDUVector(SDNode *N) { +SDNode *NVPTXDAGToDAGISel::SelectLDGLDU(SDNode *N) { SDValue Chain = N->getOperand(0); - SDValue Op1 = N->getOperand(1); + SDValue Op1; + MemSDNode *Mem; + bool IsLDG = true; + + // If this is an LDG intrinsic, the address is the third operand. Its its an + // LDG/LDU SD node (from custom vector handling), then its the second operand + if (N->getOpcode() == ISD::INTRINSIC_W_CHAIN) { + Op1 = N->getOperand(2); + Mem = cast<MemIntrinsicSDNode>(N); + unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); + switch (IID) { + default: + return NULL; + case Intrinsic::nvvm_ldg_global_f: + case Intrinsic::nvvm_ldg_global_i: + case Intrinsic::nvvm_ldg_global_p: + IsLDG = true; + break; + case Intrinsic::nvvm_ldu_global_f: + case Intrinsic::nvvm_ldu_global_i: + case Intrinsic::nvvm_ldu_global_p: + IsLDG = false; + break; + } + } else { + Op1 = N->getOperand(1); + Mem = cast<MemSDNode>(N); + } + unsigned Opcode; SDLoc DL(N); SDNode *LD; - MemSDNode *Mem = cast<MemSDNode>(N); SDValue Base, Offset, Addr; - EVT EltVT = Mem->getMemoryVT().getVectorElementType(); + EVT EltVT = Mem->getMemoryVT(); + if (EltVT.isVector()) { + EltVT = EltVT.getVectorElementType(); + } if (SelectDirectAddr(Op1, Addr)) { switch (N->getOpcode()) { default: return nullptr; + case ISD::INTRINSIC_W_CHAIN: + if (IsLDG) { + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return nullptr; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i8avar; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i16avar; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i32avar; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i64avar; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_f32avar; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_f64avar; + break; + } + } else { + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return nullptr; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i8avar; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i16avar; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i32avar; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i64avar; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_f32avar; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_f64avar; + break; + } + } + break; case NVPTXISD::LDGV2: switch (EltVT.getSimpleVT().SimpleTy) { default: @@ -1092,6 +1198,55 @@ SDNode *NVPTXDAGToDAGISel::SelectLDGLDUVector(SDNode *N) { switch (N->getOpcode()) { default: return nullptr; + case ISD::INTRINSIC_W_CHAIN: + if (IsLDG) { + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return nullptr; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i8ari64; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i16ari64; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i32ari64; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i64ari64; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_f32ari64; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_f64ari64; + break; + } + } else { + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return nullptr; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i8ari64; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i16ari64; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i32ari64; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i64ari64; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_f32ari64; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_f64ari64; + break; + } + } + break; case NVPTXISD::LDGV2: switch (EltVT.getSimpleVT().SimpleTy) { default: @@ -1181,6 +1336,55 @@ SDNode *NVPTXDAGToDAGISel::SelectLDGLDUVector(SDNode *N) { switch (N->getOpcode()) { default: return nullptr; + case ISD::INTRINSIC_W_CHAIN: + if (IsLDG) { + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return nullptr; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i8ari; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i16ari; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i32ari; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i64ari; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_f32ari; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_f64ari; + break; + } + } else { + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return nullptr; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i8ari; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i16ari; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i32ari; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i64ari; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_f32ari; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_f64ari; + break; + } + } + break; case NVPTXISD::LDGV2: switch (EltVT.getSimpleVT().SimpleTy) { default: @@ -1276,6 +1480,55 @@ SDNode *NVPTXDAGToDAGISel::SelectLDGLDUVector(SDNode *N) { switch (N->getOpcode()) { default: return nullptr; + case ISD::INTRINSIC_W_CHAIN: + if (IsLDG) { + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return nullptr; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i8areg64; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i16areg64; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i32areg64; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i64areg64; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_f32areg64; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_f64areg64; + break; + } + } else { + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return nullptr; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i8areg64; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i16areg64; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i32areg64; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i64areg64; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_f32areg64; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_f64areg64; + break; + } + } + break; case NVPTXISD::LDGV2: switch (EltVT.getSimpleVT().SimpleTy) { default: @@ -1365,6 +1618,55 @@ SDNode *NVPTXDAGToDAGISel::SelectLDGLDUVector(SDNode *N) { switch (N->getOpcode()) { default: return nullptr; + case ISD::INTRINSIC_W_CHAIN: + if (IsLDG) { + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return nullptr; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i8areg; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i16areg; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i32areg; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_i64areg; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_f32areg; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDG_GLOBAL_f64areg; + break; + } + } else { + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return nullptr; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i8areg; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i16areg; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i32areg; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_i64areg; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_f32areg; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDU_GLOBAL_f64areg; + break; + } + } + break; case NVPTXISD::LDGV2: switch (EltVT.getSimpleVT().SimpleTy) { default: @@ -1457,7 +1759,7 @@ SDNode *NVPTXDAGToDAGISel::SelectLDGLDUVector(SDNode *N) { } MachineSDNode::mmo_iterator MemRefs0 = MF->allocateMemRefsArray(1); - MemRefs0[0] = cast<MemSDNode>(N)->getMemOperand(); + MemRefs0[0] = Mem->getMemOperand(); cast<MachineSDNode>(LD)->setMemRefs(MemRefs0, MemRefs0 + 1); return LD; @@ -2959,6 +3261,214 @@ SDNode *NVPTXDAGToDAGISel::SelectSurfaceIntrinsic(SDNode *N) { return Ret; } +/// SelectBFE - Look for instruction sequences that can be made more efficient +/// by using the 'bfe' (bit-field extract) PTX instruction +SDNode *NVPTXDAGToDAGISel::SelectBFE(SDNode *N) { + SDValue LHS = N->getOperand(0); + SDValue RHS = N->getOperand(1); + SDValue Len; + SDValue Start; + SDValue Val; + bool IsSigned = false; + + if (N->getOpcode() == ISD::AND) { + // Canonicalize the operands + // We want 'and %val, %mask' + if (isa<ConstantSDNode>(LHS) && !isa<ConstantSDNode>(RHS)) { + std::swap(LHS, RHS); + } + + ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(RHS); + if (!Mask) { + // We need a constant mask on the RHS of the AND + return NULL; + } + + // Extract the mask bits + uint64_t MaskVal = Mask->getZExtValue(); + if (!isMask_64(MaskVal)) { + // We *could* handle shifted masks here, but doing so would require an + // 'and' operation to fix up the low-order bits so we would trade + // shr+and for bfe+and, which has the same throughput + return NULL; + } + + // How many bits are in our mask? + uint64_t NumBits = CountTrailingOnes_64(MaskVal); + Len = CurDAG->getTargetConstant(NumBits, MVT::i32); + + if (LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SRA) { + // We have a 'srl/and' pair, extract the effective start bit and length + Val = LHS.getNode()->getOperand(0); + Start = LHS.getNode()->getOperand(1); + ConstantSDNode *StartConst = dyn_cast<ConstantSDNode>(Start); + if (StartConst) { + uint64_t StartVal = StartConst->getZExtValue(); + // How many "good" bits do we have left? "good" is defined here as bits + // that exist in the original value, not shifted in. + uint64_t GoodBits = Start.getValueType().getSizeInBits() - StartVal; + if (NumBits > GoodBits) { + // Do not handle the case where bits have been shifted in. In theory + // we could handle this, but the cost is likely higher than just + // emitting the srl/and pair. + return NULL; + } + Start = CurDAG->getTargetConstant(StartVal, MVT::i32); + } else { + // Do not handle the case where the shift amount (can be zero if no srl + // was found) is not constant. We could handle this case, but it would + // require run-time logic that would be more expensive than just + // emitting the srl/and pair. + return NULL; + } + } else { + // Do not handle the case where the LHS of the and is not a shift. While + // it would be trivial to handle this case, it would just transform + // 'and' -> 'bfe', but 'and' has higher-throughput. + return NULL; + } + } else if (N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) { + if (LHS->getOpcode() == ISD::AND) { + ConstantSDNode *ShiftCnst = dyn_cast<ConstantSDNode>(RHS); + if (!ShiftCnst) { + // Shift amount must be constant + return NULL; + } + + uint64_t ShiftAmt = ShiftCnst->getZExtValue(); + + SDValue AndLHS = LHS->getOperand(0); + SDValue AndRHS = LHS->getOperand(1); + + // Canonicalize the AND to have the mask on the RHS + if (isa<ConstantSDNode>(AndLHS)) { + std::swap(AndLHS, AndRHS); + } + + ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(AndRHS); + if (!MaskCnst) { + // Mask must be constant + return NULL; + } + + uint64_t MaskVal = MaskCnst->getZExtValue(); + uint64_t NumZeros; + uint64_t NumBits; + if (isMask_64(MaskVal)) { + NumZeros = 0; + // The number of bits in the result bitfield will be the number of + // trailing ones (the AND) minus the number of bits we shift off + NumBits = CountTrailingOnes_64(MaskVal) - ShiftAmt; + } else if (isShiftedMask_64(MaskVal)) { + NumZeros = countTrailingZeros(MaskVal); + unsigned NumOnes = CountTrailingOnes_64(MaskVal >> NumZeros); + // The number of bits in the result bitfield will be the number of + // trailing zeros plus the number of set bits in the mask minus the + // number of bits we shift off + NumBits = NumZeros + NumOnes - ShiftAmt; + } else { + // This is not a mask we can handle + return NULL; + } + + if (ShiftAmt < NumZeros) { + // Handling this case would require extra logic that would make this + // transformation non-profitable + return NULL; + } + + Val = AndLHS; + Start = CurDAG->getTargetConstant(ShiftAmt, MVT::i32); + Len = CurDAG->getTargetConstant(NumBits, MVT::i32); + } else if (LHS->getOpcode() == ISD::SHL) { + // Here, we have a pattern like: + // + // (sra (shl val, NN), MM) + // or + // (srl (shl val, NN), MM) + // + // If MM >= NN, we can efficiently optimize this with bfe + Val = LHS->getOperand(0); + + SDValue ShlRHS = LHS->getOperand(1); + ConstantSDNode *ShlCnst = dyn_cast<ConstantSDNode>(ShlRHS); + if (!ShlCnst) { + // Shift amount must be constant + return NULL; + } + uint64_t InnerShiftAmt = ShlCnst->getZExtValue(); + + SDValue ShrRHS = RHS; + ConstantSDNode *ShrCnst = dyn_cast<ConstantSDNode>(ShrRHS); + if (!ShrCnst) { + // Shift amount must be constant + return NULL; + } + uint64_t OuterShiftAmt = ShrCnst->getZExtValue(); + + // To avoid extra codegen and be profitable, we need Outer >= Inner + if (OuterShiftAmt < InnerShiftAmt) { + return NULL; + } + + // If the outer shift is more than the type size, we have no bitfield to + // extract (since we also check that the inner shift is <= the outer shift + // then this also implies that the inner shift is < the type size) + if (OuterShiftAmt >= Val.getValueType().getSizeInBits()) { + return NULL; + } + + Start = + CurDAG->getTargetConstant(OuterShiftAmt - InnerShiftAmt, MVT::i32); + Len = + CurDAG->getTargetConstant(Val.getValueType().getSizeInBits() - + OuterShiftAmt, MVT::i32); + + if (N->getOpcode() == ISD::SRA) { + // If we have a arithmetic right shift, we need to use the signed bfe + // variant + IsSigned = true; + } + } else { + // No can do... + return NULL; + } + } else { + // No can do... + return NULL; + } + + + unsigned Opc; + // For the BFE operations we form here from "and" and "srl", always use the + // unsigned variants. + if (Val.getValueType() == MVT::i32) { + if (IsSigned) { + Opc = NVPTX::BFE_S32rii; + } else { + Opc = NVPTX::BFE_U32rii; + } + } else if (Val.getValueType() == MVT::i64) { + if (IsSigned) { + Opc = NVPTX::BFE_S64rii; + } else { + Opc = NVPTX::BFE_U64rii; + } + } else { + // We cannot handle this type + return NULL; + } + + SDValue Ops[] = { + Val, Start, Len + }; + + SDNode *Ret = + CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops); + + return Ret; +} + // SelectDirectAddr - Match a direct address for DAG. // A direct address could be a globaladdress or externalsymbol. bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) { |