diff options
Diffstat (limited to 'lib/Target/X86/X86TargetTransformInfo.cpp')
-rw-r--r-- | lib/Target/X86/X86TargetTransformInfo.cpp | 112 |
1 files changed, 103 insertions, 9 deletions
diff --git a/lib/Target/X86/X86TargetTransformInfo.cpp b/lib/Target/X86/X86TargetTransformInfo.cpp index c961e2f..2b70fd0 100644 --- a/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/lib/Target/X86/X86TargetTransformInfo.cpp @@ -48,8 +48,8 @@ public: } X86TTI(const X86TargetMachine *TM) - : ImmutablePass(ID), ST(TM->getSubtargetImpl()), - TLI(TM->getTargetLowering()) { + : ImmutablePass(ID), ST(TM->getSubtargetImpl()), + TLI(TM->getSubtargetImpl()->getTargetLowering()) { initializeX86TTIPass(*PassRegistry::getPassRegistry()); } @@ -82,9 +82,10 @@ public: unsigned getNumberOfRegisters(bool Vector) const override; unsigned getRegisterBitWidth(bool Vector) const override; - unsigned getMaximumUnrollFactor() const override; + unsigned getMaxInterleaveFactor() const override; unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty, OperandValueKind, - OperandValueKind) const override; + OperandValueKind, OperandValueProperties, + OperandValueProperties) const override; unsigned getShuffleCost(ShuffleKind Kind, Type *Tp, int Index, Type *SubTp) const override; unsigned getCastInstrCost(unsigned Opcode, Type *Dst, @@ -166,7 +167,7 @@ unsigned X86TTI::getRegisterBitWidth(bool Vector) const { } -unsigned X86TTI::getMaximumUnrollFactor() const { +unsigned X86TTI::getMaxInterleaveFactor() const { if (ST->isAtom()) return 1; @@ -178,15 +179,37 @@ unsigned X86TTI::getMaximumUnrollFactor() const { return 2; } -unsigned X86TTI::getArithmeticInstrCost(unsigned Opcode, Type *Ty, - OperandValueKind Op1Info, - OperandValueKind Op2Info) const { +unsigned X86TTI::getArithmeticInstrCost( + unsigned Opcode, Type *Ty, OperandValueKind Op1Info, + OperandValueKind Op2Info, OperandValueProperties Opd1PropInfo, + OperandValueProperties Opd2PropInfo) const { // Legalize the type. std::pair<unsigned, MVT> LT = TLI->getTypeLegalizationCost(Ty); int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); + if (ISD == ISD::SDIV && + Op2Info == TargetTransformInfo::OK_UniformConstantValue && + Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) { + // On X86, vector signed division by constants power-of-two are + // normally expanded to the sequence SRA + SRL + ADD + SRA. + // The OperandValue properties many not be same as that of previous + // operation;conservatively assume OP_None. + unsigned Cost = + 2 * getArithmeticInstrCost(Instruction::AShr, Ty, Op1Info, Op2Info, + TargetTransformInfo::OP_None, + TargetTransformInfo::OP_None); + Cost += getArithmeticInstrCost(Instruction::LShr, Ty, Op1Info, Op2Info, + TargetTransformInfo::OP_None, + TargetTransformInfo::OP_None); + Cost += getArithmeticInstrCost(Instruction::Add, Ty, Op1Info, Op2Info, + TargetTransformInfo::OP_None, + TargetTransformInfo::OP_None); + + return Cost; + } + static const CostTblEntry<MVT::SimpleValueType> AVX2UniformConstCostTable[] = { { ISD::SDIV, MVT::v16i16, 6 }, // vpmulhw sequence @@ -202,6 +225,15 @@ unsigned X86TTI::getArithmeticInstrCost(unsigned Opcode, Type *Ty, return LT.first * AVX2UniformConstCostTable[Idx].Cost; } + static const CostTblEntry<MVT::SimpleValueType> AVX512CostTable[] = { + { ISD::SHL, MVT::v16i32, 1 }, + { ISD::SRL, MVT::v16i32, 1 }, + { ISD::SRA, MVT::v16i32, 1 }, + { ISD::SHL, MVT::v8i64, 1 }, + { ISD::SRL, MVT::v8i64, 1 }, + { ISD::SRA, MVT::v8i64, 1 }, + }; + static const CostTblEntry<MVT::SimpleValueType> AVX2CostTable[] = { // Shifts on v4i64/v8i32 on AVX2 is legal even though we declare to // customize them to detect the cases where shift amount is a scalar one. @@ -237,6 +269,11 @@ unsigned X86TTI::getArithmeticInstrCost(unsigned Opcode, Type *Ty, { ISD::UDIV, MVT::v4i64, 4*20 }, }; + if (ST->hasAVX512()) { + int Idx = CostTableLookup(AVX512CostTable, ISD, LT.second); + if (Idx != -1) + return LT.first * AVX512CostTable[Idx].Cost; + } // Look for AVX2 lowering tricks. if (ST->hasAVX2()) { if (ISD == ISD::SHL && LT.second == MVT::v16i16 && @@ -541,7 +578,7 @@ unsigned X86TTI::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const { { ISD::SINT_TO_FP, MVT::v2f64, MVT::v16i8, 16*10 }, // There are faster sequences for float conversions. { ISD::UINT_TO_FP, MVT::v4f32, MVT::v2i64, 15 }, - { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 15 }, + { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 8 }, { ISD::UINT_TO_FP, MVT::v4f32, MVT::v8i16, 15 }, { ISD::UINT_TO_FP, MVT::v4f32, MVT::v16i8, 8 }, { ISD::SINT_TO_FP, MVT::v4f32, MVT::v2i64, 15 }, @@ -557,6 +594,45 @@ unsigned X86TTI::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const { return LTSrc.first * SSE2ConvTbl[Idx].Cost; } + static const TypeConversionCostTblEntry<MVT::SimpleValueType> + AVX512ConversionTbl[] = { + { ISD::FP_EXTEND, MVT::v8f64, MVT::v8f32, 1 }, + { ISD::FP_EXTEND, MVT::v8f64, MVT::v16f32, 3 }, + { ISD::FP_ROUND, MVT::v8f32, MVT::v8f64, 1 }, + { ISD::FP_ROUND, MVT::v16f32, MVT::v8f64, 3 }, + + { ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 1 }, + { ISD::TRUNCATE, MVT::v16i16, MVT::v16i32, 1 }, + { ISD::TRUNCATE, MVT::v8i16, MVT::v8i64, 1 }, + { ISD::TRUNCATE, MVT::v8i32, MVT::v8i64, 1 }, + { ISD::TRUNCATE, MVT::v16i32, MVT::v8i64, 4 }, + + // v16i1 -> v16i32 - load + broadcast + { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i1, 2 }, + { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i1, 2 }, + + { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 1 }, + { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 1 }, + { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i16, 1 }, + { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i16, 1 }, + { ISD::SIGN_EXTEND, MVT::v8i64, MVT::v16i32, 3 }, + { ISD::ZERO_EXTEND, MVT::v8i64, MVT::v16i32, 3 }, + + { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i1, 3 }, + { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 2 }, + { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i16, 2 }, + { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i32, 1 }, + { ISD::SINT_TO_FP, MVT::v8f64, MVT::v8i1, 4 }, + { ISD::SINT_TO_FP, MVT::v8f64, MVT::v8i16, 2 }, + { ISD::SINT_TO_FP, MVT::v8f64, MVT::v8i32, 1 }, + }; + + if (ST->hasAVX512()) { + int Idx = ConvertCostTableLookup(AVX512ConversionTbl, ISD, LTDest.second, + LTSrc.second); + if (Idx != -1) + return AVX512ConversionTbl[Idx].Cost; + } EVT SrcTy = TLI->getValueType(Src); EVT DstTy = TLI->getValueType(Dst); @@ -589,6 +665,11 @@ unsigned X86TTI::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const { { ISD::TRUNCATE, MVT::v8i8, MVT::v8i32, 2 }, { ISD::TRUNCATE, MVT::v8i16, MVT::v8i32, 2 }, { ISD::TRUNCATE, MVT::v8i32, MVT::v8i64, 4 }, + + { ISD::FP_EXTEND, MVT::v8f64, MVT::v8f32, 3 }, + { ISD::FP_ROUND, MVT::v8f32, MVT::v8f64, 3 }, + + { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i32, 8 }, }; static const TypeConversionCostTblEntry<MVT::SimpleValueType> @@ -715,6 +796,19 @@ unsigned X86TTI::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, { ISD::SETCC, MVT::v32i8, 1 }, }; + static const CostTblEntry<MVT::SimpleValueType> AVX512CostTbl[] = { + { ISD::SETCC, MVT::v8i64, 1 }, + { ISD::SETCC, MVT::v16i32, 1 }, + { ISD::SETCC, MVT::v8f64, 1 }, + { ISD::SETCC, MVT::v16f32, 1 }, + }; + + if (ST->hasAVX512()) { + int Idx = CostTableLookup(AVX512CostTbl, ISD, MTy); + if (Idx != -1) + return LT.first * AVX512CostTbl[Idx].Cost; + } + if (ST->hasAVX2()) { int Idx = CostTableLookup(AVX2CostTbl, ISD, MTy); if (Idx != -1) |