aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lib/Target/ARM/ARMISelLowering.cpp96
-rw-r--r--test/CodeGen/ARM/vmul.ll29
2 files changed, 110 insertions, 15 deletions
diff --git a/lib/Target/ARM/ARMISelLowering.cpp b/lib/Target/ARM/ARMISelLowering.cpp
index 623393d..3267ea8 100644
--- a/lib/Target/ARM/ARMISelLowering.cpp
+++ b/lib/Target/ARM/ARMISelLowering.cpp
@@ -4370,6 +4370,28 @@ static SDValue SkipExtension(SDNode *N, SelectionDAG &DAG) {
MVT::getVectorVT(TruncVT, NumElts), Ops.data(), NumElts);
}
+static bool isAddSubSExt(SDNode *N, SelectionDAG &DAG) {
+ unsigned Opcode = N->getOpcode();
+ if (Opcode == ISD::ADD || Opcode == ISD::SUB) {
+ SDNode *N0 = N->getOperand(0).getNode();
+ SDNode *N1 = N->getOperand(1).getNode();
+ return N0->hasOneUse() && N1->hasOneUse() &&
+ isSignExtended(N0, DAG) && isSignExtended(N1, DAG);
+ }
+ return false;
+}
+
+static bool isAddSubZExt(SDNode *N, SelectionDAG &DAG) {
+ unsigned Opcode = N->getOpcode();
+ if (Opcode == ISD::ADD || Opcode == ISD::SUB) {
+ SDNode *N0 = N->getOperand(0).getNode();
+ SDNode *N1 = N->getOperand(1).getNode();
+ return N0->hasOneUse() && N1->hasOneUse() &&
+ isZeroExtended(N0, DAG) && isZeroExtended(N1, DAG);
+ }
+ return false;
+}
+
static SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) {
// Multiplications are only custom-lowered for 128-bit vectors so that
// VMULL can be detected. Otherwise v2i64 multiplications are not legal.
@@ -4378,26 +4400,70 @@ static SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) {
SDNode *N0 = Op.getOperand(0).getNode();
SDNode *N1 = Op.getOperand(1).getNode();
unsigned NewOpc = 0;
- if (isSignExtended(N0, DAG) && isSignExtended(N1, DAG))
+ bool isMLA = false;
+ bool isN0SExt = isSignExtended(N0, DAG);
+ bool isN1SExt = isSignExtended(N1, DAG);
+ if (isN0SExt && isN1SExt)
NewOpc = ARMISD::VMULLs;
- else if (isZeroExtended(N0, DAG) && isZeroExtended(N1, DAG))
- NewOpc = ARMISD::VMULLu;
- else if (VT == MVT::v2i64)
- // Fall through to expand this. It is not legal.
- return SDValue();
- else
- // Other vector multiplications are legal.
- return Op;
+ else {
+ bool isN0ZExt = isZeroExtended(N0, DAG);
+ bool isN1ZExt = isZeroExtended(N1, DAG);
+ if (isN0ZExt && isN1ZExt)
+ NewOpc = ARMISD::VMULLu;
+ else if (isN1SExt || isN1ZExt) {
+ // Look for (s/zext A + s/zext B) * (s/zext C). We want to turn these
+ // into (s/zext A * s/zext C) + (s/zext B * s/zext C)
+ if (isN1SExt && isAddSubSExt(N0, DAG)) {
+ NewOpc = ARMISD::VMULLs;
+ isMLA = true;
+ } else if (isN1ZExt && isAddSubZExt(N0, DAG)) {
+ NewOpc = ARMISD::VMULLu;
+ isMLA = true;
+ } else if (isN0ZExt && isAddSubZExt(N1, DAG)) {
+ std::swap(N0, N1);
+ NewOpc = ARMISD::VMULLu;
+ isMLA = true;
+ }
+ }
+
+ if (!NewOpc) {
+ if (VT == MVT::v2i64)
+ // Fall through to expand this. It is not legal.
+ return SDValue();
+ else
+ // Other vector multiplications are legal.
+ return Op;
+ }
+ }
// Legalize to a VMULL instruction.
DebugLoc DL = Op.getDebugLoc();
- SDValue Op0 = SkipExtension(N0, DAG);
+ SDValue Op0;
SDValue Op1 = SkipExtension(N1, DAG);
-
- assert(Op0.getValueType().is64BitVector() &&
- Op1.getValueType().is64BitVector() &&
- "unexpected types for extended operands to VMULL");
- return DAG.getNode(NewOpc, DL, VT, Op0, Op1);
+ if (!isMLA) {
+ Op0 = SkipExtension(N0, DAG);
+ assert(Op0.getValueType().is64BitVector() &&
+ Op1.getValueType().is64BitVector() &&
+ "unexpected types for extended operands to VMULL");
+ return DAG.getNode(NewOpc, DL, VT, Op0, Op1);
+ }
+
+ // Optimizing (zext A + zext B) * C, to (VMULL A, C) + (VMULL B, C) during
+ // isel lowering to take advantage of no-stall back to back vmul + vmla.
+ // vmull q0, d4, d6
+ // vmlal q0, d5, d6
+ // is faster than
+ // vaddl q0, d4, d5
+ // vmovl q1, d6
+ // vmul q0, q0, q1
+ SDValue N00 = SkipExtension(N0->getOperand(0).getNode(), DAG);
+ SDValue N01 = SkipExtension(N0->getOperand(1).getNode(), DAG);
+ EVT Op1VT = Op1.getValueType();
+ return DAG.getNode(N0->getOpcode(), DL, VT,
+ DAG.getNode(NewOpc, DL, VT,
+ DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1),
+ DAG.getNode(NewOpc, DL, VT,
+ DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1));
}
static SDValue
diff --git a/test/CodeGen/ARM/vmul.ll b/test/CodeGen/ARM/vmul.ll
index ee033ca..585394e 100644
--- a/test/CodeGen/ARM/vmul.ll
+++ b/test/CodeGen/ARM/vmul.ll
@@ -339,3 +339,32 @@ define <2 x i64> @vmull_extvec_u32(<2 x i32> %arg) nounwind {
%tmp4 = mul <2 x i64> %tmp3, <i64 1234, i64 1234>
ret <2 x i64> %tmp4
}
+
+; rdar://9197392
+define void @distribue(i16* %dst, i8* %src, i32 %mul) nounwind {
+entry:
+; CHECK: distribue:
+; CHECK: vmull.u8 [[REG1:(q[0-9]+)]], d{{.*}}, [[REG2:(d[0-9]+)]]
+; CHECK: vmlal.u8 [[REG1]], d{{.*}}, [[REG2]]
+ %0 = trunc i32 %mul to i8
+ %1 = insertelement <8 x i8> undef, i8 %0, i32 0
+ %2 = shufflevector <8 x i8> %1, <8 x i8> undef, <8 x i32> zeroinitializer
+ %3 = tail call <16 x i8> @llvm.arm.neon.vld1.v16i8(i8* %src, i32 1)
+ %4 = bitcast <16 x i8> %3 to <2 x double>
+ %5 = extractelement <2 x double> %4, i32 1
+ %6 = bitcast double %5 to <8 x i8>
+ %7 = zext <8 x i8> %6 to <8 x i16>
+ %8 = zext <8 x i8> %2 to <8 x i16>
+ %9 = extractelement <2 x double> %4, i32 0
+ %10 = bitcast double %9 to <8 x i8>
+ %11 = zext <8 x i8> %10 to <8 x i16>
+ %12 = add <8 x i16> %7, %11
+ %13 = mul <8 x i16> %12, %8
+ %14 = bitcast i16* %dst to i8*
+ tail call void @llvm.arm.neon.vst1.v8i16(i8* %14, <8 x i16> %13, i32 2)
+ ret void
+}
+
+declare <16 x i8> @llvm.arm.neon.vld1.v16i8(i8*, i32) nounwind readonly
+
+declare void @llvm.arm.neon.vst1.v8i16(i8*, <8 x i16>, i32) nounwind