diff options
Diffstat (limited to 'lib/Target/X86/X86TargetTransformInfo.cpp')
-rw-r--r-- | lib/Target/X86/X86TargetTransformInfo.cpp | 149 |
1 files changed, 114 insertions, 35 deletions
diff --git a/lib/Target/X86/X86TargetTransformInfo.cpp b/lib/Target/X86/X86TargetTransformInfo.cpp index 3bbddad..f88a666 100644 --- a/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/lib/Target/X86/X86TargetTransformInfo.cpp @@ -101,6 +101,9 @@ public: unsigned AddressSpace) const; virtual unsigned getAddressComputationCost(Type *PtrTy, bool IsComplex) const; + + virtual unsigned getReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm) const; /// @} }; @@ -127,8 +130,8 @@ X86TTI::PopcntSupportKind X86TTI::getPopcntSupport(unsigned TyWidth) const { assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2"); // TODO: Currently the __builtin_popcount() implementation using SSE3 // instructions is inefficient. Once the problem is fixed, we should - // call ST->hasSSE3() instead of ST->hasSSE4(). - return ST->hasSSE41() ? PSK_FastHardware : PSK_Software; + // call ST->hasSSE3() instead of ST->hasPOPCNT(). + return ST->hasPOPCNT() ? PSK_FastHardware : PSK_Software; } unsigned X86TTI::getNumberOfRegisters(bool Vector) const { @@ -174,7 +177,7 @@ unsigned X86TTI::getArithmeticInstrCost(unsigned Opcode, Type *Ty, int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); - static const CostTblEntry<MVT> AVX2CostTable[] = { + static const CostTblEntry<MVT::SimpleValueType> AVX2CostTable[] = { // Shifts on v4i64/v8i32 on AVX2 is legal even though we declare to // customize them to detect the cases where shift amount is a scalar one. { ISD::SHL, MVT::v4i32, 1 }, @@ -211,13 +214,13 @@ unsigned X86TTI::getArithmeticInstrCost(unsigned Opcode, Type *Ty, // Look for AVX2 lowering tricks. if (ST->hasAVX2()) { - int Idx = CostTableLookup<MVT>(AVX2CostTable, array_lengthof(AVX2CostTable), - ISD, LT.second); + int Idx = CostTableLookup(AVX2CostTable, ISD, LT.second); if (Idx != -1) return LT.first * AVX2CostTable[Idx].Cost; } - static const CostTblEntry<MVT> SSE2UniformConstCostTable[] = { + static const CostTblEntry<MVT::SimpleValueType> + SSE2UniformConstCostTable[] = { // We don't correctly identify costs of casts because they are marked as // custom. // Constant splats are cheaper for the following instructions. @@ -238,15 +241,13 @@ unsigned X86TTI::getArithmeticInstrCost(unsigned Opcode, Type *Ty, if (Op2Info == TargetTransformInfo::OK_UniformConstantValue && ST->hasSSE2()) { - int Idx = CostTableLookup<MVT>(SSE2UniformConstCostTable, - array_lengthof(SSE2UniformConstCostTable), - ISD, LT.second); + int Idx = CostTableLookup(SSE2UniformConstCostTable, ISD, LT.second); if (Idx != -1) return LT.first * SSE2UniformConstCostTable[Idx].Cost; } - static const CostTblEntry<MVT> SSE2CostTable[] = { + static const CostTblEntry<MVT::SimpleValueType> SSE2CostTable[] = { // We don't correctly identify costs of casts because they are marked as // custom. // For some cases, where the shift amount is a scalar we would be able @@ -287,13 +288,12 @@ unsigned X86TTI::getArithmeticInstrCost(unsigned Opcode, Type *Ty, }; if (ST->hasSSE2()) { - int Idx = CostTableLookup<MVT>(SSE2CostTable, array_lengthof(SSE2CostTable), - ISD, LT.second); + int Idx = CostTableLookup(SSE2CostTable, ISD, LT.second); if (Idx != -1) return LT.first * SSE2CostTable[Idx].Cost; } - static const CostTblEntry<MVT> AVX1CostTable[] = { + static const CostTblEntry<MVT::SimpleValueType> AVX1CostTable[] = { // We don't have to scalarize unsupported ops. We can issue two half-sized // operations and we only need to extract the upper YMM half. // Two ops + 1 extract + 1 insert = 4. @@ -312,21 +312,19 @@ unsigned X86TTI::getArithmeticInstrCost(unsigned Opcode, Type *Ty, // Look for AVX1 lowering tricks. if (ST->hasAVX() && !ST->hasAVX2()) { - int Idx = CostTableLookup<MVT>(AVX1CostTable, array_lengthof(AVX1CostTable), - ISD, LT.second); + int Idx = CostTableLookup(AVX1CostTable, ISD, LT.second); if (Idx != -1) return LT.first * AVX1CostTable[Idx].Cost; } // Custom lowering of vectors. - static const CostTblEntry<MVT> CustomLowered[] = { + static const CostTblEntry<MVT::SimpleValueType> CustomLowered[] = { // A v2i64/v4i64 and multiply is custom lowered as a series of long // multiplies(3), shifts(4) and adds(2). { ISD::MUL, MVT::v2i64, 9 }, { ISD::MUL, MVT::v4i64, 9 }, }; - int Idx = CostTableLookup<MVT>(CustomLowered, array_lengthof(CustomLowered), - ISD, LT.second); + int Idx = CostTableLookup(CustomLowered, ISD, LT.second); if (Idx != -1) return LT.first * CustomLowered[Idx].Cost; @@ -363,7 +361,8 @@ unsigned X86TTI::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const { std::pair<unsigned, MVT> LTSrc = TLI->getTypeLegalizationCost(Src); std::pair<unsigned, MVT> LTDest = TLI->getTypeLegalizationCost(Dst); - static const TypeConversionCostTblEntry<MVT> SSE2ConvTbl[] = { + static const TypeConversionCostTblEntry<MVT::SimpleValueType> + SSE2ConvTbl[] = { // These are somewhat magic numbers justified by looking at the output of // Intel's IACA, running some kernels and making sure when we take // legalization into account the throughput will be overestimated. @@ -387,9 +386,8 @@ unsigned X86TTI::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const { }; if (ST->hasSSE2() && !ST->hasAVX()) { - int Idx = ConvertCostTableLookup<MVT>(SSE2ConvTbl, - array_lengthof(SSE2ConvTbl), - ISD, LTDest.second, LTSrc.second); + int Idx = + ConvertCostTableLookup(SSE2ConvTbl, ISD, LTDest.second, LTSrc.second); if (Idx != -1) return LTSrc.first * SSE2ConvTbl[Idx].Cost; } @@ -401,13 +399,17 @@ unsigned X86TTI::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const { if (!SrcTy.isSimple() || !DstTy.isSimple()) return TargetTransformInfo::getCastInstrCost(Opcode, Dst, Src); - static const TypeConversionCostTblEntry<MVT> AVXConversionTbl[] = { + static const TypeConversionCostTblEntry<MVT::SimpleValueType> + AVXConversionTbl[] = { + { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 1 }, + { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 1 }, { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i16, 1 }, { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i16, 1 }, { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i32, 1 }, { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i32, 1 }, { ISD::TRUNCATE, MVT::v4i32, MVT::v4i64, 1 }, { ISD::TRUNCATE, MVT::v8i16, MVT::v8i32, 1 }, + { ISD::TRUNCATE, MVT::v16i8, MVT::v16i16, 2 }, { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i1, 8 }, { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8, 8 }, @@ -446,9 +448,8 @@ unsigned X86TTI::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const { }; if (ST->hasAVX()) { - int Idx = ConvertCostTableLookup<MVT>(AVXConversionTbl, - array_lengthof(AVXConversionTbl), - ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()); + int Idx = ConvertCostTableLookup(AVXConversionTbl, ISD, DstTy.getSimpleVT(), + SrcTy.getSimpleVT()); if (Idx != -1) return AVXConversionTbl[Idx].Cost; } @@ -466,7 +467,7 @@ unsigned X86TTI::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); - static const CostTblEntry<MVT> SSE42CostTbl[] = { + static const CostTblEntry<MVT::SimpleValueType> SSE42CostTbl[] = { { ISD::SETCC, MVT::v2f64, 1 }, { ISD::SETCC, MVT::v4f32, 1 }, { ISD::SETCC, MVT::v2i64, 1 }, @@ -475,7 +476,7 @@ unsigned X86TTI::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, { ISD::SETCC, MVT::v16i8, 1 }, }; - static const CostTblEntry<MVT> AVX1CostTbl[] = { + static const CostTblEntry<MVT::SimpleValueType> AVX1CostTbl[] = { { ISD::SETCC, MVT::v4f64, 1 }, { ISD::SETCC, MVT::v8f32, 1 }, // AVX1 does not support 8-wide integer compare. @@ -485,7 +486,7 @@ unsigned X86TTI::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, { ISD::SETCC, MVT::v32i8, 4 }, }; - static const CostTblEntry<MVT> AVX2CostTbl[] = { + static const CostTblEntry<MVT::SimpleValueType> AVX2CostTbl[] = { { ISD::SETCC, MVT::v4i64, 1 }, { ISD::SETCC, MVT::v8i32, 1 }, { ISD::SETCC, MVT::v16i16, 1 }, @@ -493,22 +494,19 @@ unsigned X86TTI::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, }; if (ST->hasAVX2()) { - int Idx = CostTableLookup<MVT>(AVX2CostTbl, array_lengthof(AVX2CostTbl), - ISD, MTy); + int Idx = CostTableLookup(AVX2CostTbl, ISD, MTy); if (Idx != -1) return LT.first * AVX2CostTbl[Idx].Cost; } if (ST->hasAVX()) { - int Idx = CostTableLookup<MVT>(AVX1CostTbl, array_lengthof(AVX1CostTbl), - ISD, MTy); + int Idx = CostTableLookup(AVX1CostTbl, ISD, MTy); if (Idx != -1) return LT.first * AVX1CostTbl[Idx].Cost; } if (ST->hasSSE42()) { - int Idx = CostTableLookup<MVT>(SSE42CostTbl, array_lengthof(SSE42CostTbl), - ISD, MTy); + int Idx = CostTableLookup(SSE42CostTbl, ISD, MTy); if (Idx != -1) return LT.first * SSE42CostTbl[Idx].Cost; } @@ -613,3 +611,84 @@ unsigned X86TTI::getAddressComputationCost(Type *Ty, bool IsComplex) const { return TargetTransformInfo::getAddressComputationCost(Ty, IsComplex); } + +unsigned X86TTI::getReductionCost(unsigned Opcode, Type *ValTy, + bool IsPairwise) const { + + std::pair<unsigned, MVT> LT = TLI->getTypeLegalizationCost(ValTy); + + MVT MTy = LT.second; + + int ISD = TLI->InstructionOpcodeToISD(Opcode); + assert(ISD && "Invalid opcode"); + + // We use the Intel Architecture Code Analyzer(IACA) to measure the throughput + // and make it as the cost. + + static const CostTblEntry<MVT::SimpleValueType> SSE42CostTblPairWise[] = { + { ISD::FADD, MVT::v2f64, 2 }, + { ISD::FADD, MVT::v4f32, 4 }, + { ISD::ADD, MVT::v2i64, 2 }, // The data reported by the IACA tool is "1.6". + { ISD::ADD, MVT::v4i32, 3 }, // The data reported by the IACA tool is "3.5". + { ISD::ADD, MVT::v8i16, 5 }, + }; + + static const CostTblEntry<MVT::SimpleValueType> AVX1CostTblPairWise[] = { + { ISD::FADD, MVT::v4f32, 4 }, + { ISD::FADD, MVT::v4f64, 5 }, + { ISD::FADD, MVT::v8f32, 7 }, + { ISD::ADD, MVT::v2i64, 1 }, // The data reported by the IACA tool is "1.5". + { ISD::ADD, MVT::v4i32, 3 }, // The data reported by the IACA tool is "3.5". + { ISD::ADD, MVT::v4i64, 5 }, // The data reported by the IACA tool is "4.8". + { ISD::ADD, MVT::v8i16, 5 }, + { ISD::ADD, MVT::v8i32, 5 }, + }; + + static const CostTblEntry<MVT::SimpleValueType> SSE42CostTblNoPairWise[] = { + { ISD::FADD, MVT::v2f64, 2 }, + { ISD::FADD, MVT::v4f32, 4 }, + { ISD::ADD, MVT::v2i64, 2 }, // The data reported by the IACA tool is "1.6". + { ISD::ADD, MVT::v4i32, 3 }, // The data reported by the IACA tool is "3.3". + { ISD::ADD, MVT::v8i16, 4 }, // The data reported by the IACA tool is "4.3". + }; + + static const CostTblEntry<MVT::SimpleValueType> AVX1CostTblNoPairWise[] = { + { ISD::FADD, MVT::v4f32, 3 }, + { ISD::FADD, MVT::v4f64, 3 }, + { ISD::FADD, MVT::v8f32, 4 }, + { ISD::ADD, MVT::v2i64, 1 }, // The data reported by the IACA tool is "1.5". + { ISD::ADD, MVT::v4i32, 3 }, // The data reported by the IACA tool is "2.8". + { ISD::ADD, MVT::v4i64, 3 }, + { ISD::ADD, MVT::v8i16, 4 }, + { ISD::ADD, MVT::v8i32, 5 }, + }; + + if (IsPairwise) { + if (ST->hasAVX()) { + int Idx = CostTableLookup(AVX1CostTblPairWise, ISD, MTy); + if (Idx != -1) + return LT.first * AVX1CostTblPairWise[Idx].Cost; + } + + if (ST->hasSSE42()) { + int Idx = CostTableLookup(SSE42CostTblPairWise, ISD, MTy); + if (Idx != -1) + return LT.first * SSE42CostTblPairWise[Idx].Cost; + } + } else { + if (ST->hasAVX()) { + int Idx = CostTableLookup(AVX1CostTblNoPairWise, ISD, MTy); + if (Idx != -1) + return LT.first * AVX1CostTblNoPairWise[Idx].Cost; + } + + if (ST->hasSSE42()) { + int Idx = CostTableLookup(SSE42CostTblNoPairWise, ISD, MTy); + if (Idx != -1) + return LT.first * SSE42CostTblNoPairWise[Idx].Cost; + } + } + + return TargetTransformInfo::getReductionCost(Opcode, ValTy, IsPairwise); +} + |