diff options
-rw-r--r-- | rs/java/android/renderscript/ScriptIntrinsicBLAS.java | 145 |
1 files changed, 88 insertions, 57 deletions
diff --git a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java index 94bde10..d503699 100644 --- a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java +++ b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java @@ -241,7 +241,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } static void validateUplo(@Uplo int Uplo) { - if (Uplo != LEFT && Uplo != RIGHT) { + if (Uplo != UPPER && Uplo != LOWER) { throw new RSRuntimeException("Invalid uplo passed to BLAS"); } } @@ -985,56 +985,74 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { */ static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { - int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1; + int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; if ((A != null && !A.getType().getElement().isCompatible(e)) || (B != null && !B.getType().getElement().isCompatible(e)) || (C != null && !C.getType().getElement().isCompatible(e))) { throw new RSRuntimeException("Called BLAS with wrong Element type"); } - if (C != null) { - cX = C.getType().getY(); - cY = C.getType().getX(); + if (C == null) { + //since matrix C is used to store the result, it cannot be null. + throw new RSRuntimeException("Allocation C cannot be null"); } + cM = C.getType().getY(); + cN = C.getType().getX(); + if (Side == RIGHT) { + if ((A == null && B != null) || (A != null && B == null)) { + throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa"); + } if (B != null) { - bX = A.getType().getY(); - bY = A.getType().getX(); + bM = A.getType().getY(); + bN = A.getType().getX(); } if (A != null) { - aX = B.getType().getY(); - aY = B.getType().getX(); + aM = B.getType().getY(); + aN = B.getType().getX(); } } else { if (A != null) { - if (TransA == TRANSPOSE) { - aY = A.getType().getY(); - aX = A.getType().getX(); + if (TransA != NO_TRANSPOSE) { + aN = A.getType().getY(); + aM = A.getType().getX(); } else { - aX = A.getType().getY(); - aY = A.getType().getX(); + aM = A.getType().getY(); + aN = A.getType().getX(); } } if (B != null) { - if (TransB == TRANSPOSE) { - bY = B.getType().getY(); - bX = B.getType().getX(); + if (TransB != NO_TRANSPOSE) { + bN = B.getType().getY(); + bM = B.getType().getX(); } else { - bX = B.getType().getY(); - bY = B.getType().getX(); + bM = B.getType().getY(); + bN = B.getType().getX(); } } } if (A != null && B != null && C != null) { - if (aY != bX || aX != cX || bY != cY) { + if (aN != bM || aM != cM || bN != cN) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } else if (A != null && C != null) { - // A and C only - if (aX != cY || aY != cX) { - throw new RSRuntimeException("Called BLAS with invalid dimensions"); + // A and C only, for SYRK + if (cM != cN) { + throw new RSRuntimeException("Matrix C is not symmetric"); + } + if (TransA != NO_TRANSPOSE) { + if (aN != cM) { + throw new RSRuntimeException("Called BLAS with invalid dimensions"); + } + } else { + if (aM != cM) { + throw new RSRuntimeException("Called BLAS with invalid dimensions"); + } } } else if (A != null && B != null) { // A and B only + if (aN != bM) { + throw new RSRuntimeException("Called BLAS with invalid dimensions"); + } } } @@ -1046,14 +1064,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; - if (TransA == TRANSPOSE) { + if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } - if (TransB == TRANSPOSE) { + if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); @@ -1067,14 +1085,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; - if (TransA == TRANSPOSE) { + if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } - if (TransB == TRANSPOSE) { + if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); @@ -1088,14 +1106,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; - if (TransA == TRANSPOSE) { + if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } - if (TransB == TRANSPOSE) { + if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); @@ -1110,14 +1128,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; - if (TransA == TRANSPOSE) { + if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } - if (TransB == TRANSPOSE) { + if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); @@ -1130,6 +1148,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, float beta, Allocation C) { validateSide(Side); validateUplo(Uplo); + //For SYMM, Matrix A should be symmetric + if (A.getType().getX() != A.getType().getY()) { + throw new RSRuntimeException("Matrix A is not symmetric"); + } validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); @@ -1138,6 +1160,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, double beta, Allocation C) { validateSide(Side); validateUplo(Uplo); + if (A.getType().getX() != A.getType().getY()) { + throw new RSRuntimeException("Matrix A is not symmetric"); + } validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); @@ -1146,6 +1171,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, Float2 beta, Allocation C) { validateSide(Side); validateUplo(Uplo); + if (A.getType().getX() != A.getType().getY()) { + throw new RSRuntimeException("Matrix A is not symmetric"); + } validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); @@ -1154,6 +1182,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, Double2 beta, Allocation C) { validateSide(Side); validateUplo(Uplo); + if (A.getType().getX() != A.getType().getY()) { + throw new RSRuntimeException("Matrix A is not symmetric"); + } validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); @@ -1164,7 +1195,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); int K = -1; - if (Trans == TRANSPOSE) { + if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); @@ -1178,7 +1209,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); int K = -1; - if (Trans == TRANSPOSE) { + if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); @@ -1190,7 +1221,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); int K = -1; - if (Trans == TRANSPOSE) { + if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); @@ -1203,7 +1234,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); int K = -1; - if (Trans == TRANSPOSE) { + if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); @@ -1229,7 +1260,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { // check rows versus C Cdim = A.getType().getY(); } - if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) { + if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) { throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); } // A dims == B dims @@ -1285,26 +1316,26 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { validateSide(Side); validateTranspose(TransA); - int aX = -1, aY = -1, bX = -1, bY = -1; + int aM = -1, aN = -1, bM = -1, bN = -1; if (!A.getType().getElement().isCompatible(e) || !B.getType().getElement().isCompatible(e)) { throw new RSRuntimeException("Called BLAS with wrong Element type"); } - if (TransA == TRANSPOSE) { - aY = A.getType().getY(); - aX = A.getType().getX(); - } else { - aY = A.getType().getX(); - aX = A.getType().getY(); + + aM = A.getType().getY(); + aN = A.getType().getX(); + if (aM != aN) { + throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A"); } - bX = B.getType().getY(); - bY = B.getType().getX(); + + bM = B.getType().getY(); + bN = B.getType().getX(); if (Side == LEFT) { - if (aX == 0 || aY != bX) { + if (aN != bM) { throw new RSRuntimeException("Called TRMM with invalid matrices"); } } else { - if (bY != aX || aY == 0) { + if (bN != aM) { throw new RSRuntimeException("Called TRMM with invalid matrices"); } } @@ -1339,7 +1370,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { - int adim = -1, bX = -1, bY = -1; + int adim = -1, bM = -1, bN = -1; validateSide(Side); validateTranspose(TransA); if (!A.getType().getElement().isCompatible(e) || @@ -1353,16 +1384,16 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { // for now we assume adapters are sufficient, will reevaluate in the future throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); } - bX = B.getType().getY(); - bY = B.getType().getX(); + bM = B.getType().getY(); + bN = B.getType().getX(); if (Side == LEFT) { // A is M*M - if (adim != bY) { + if (adim != bM) { throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); } } else { // A is N*N - if (adim != bX) { + if (adim != bN) { throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); } } @@ -1427,7 +1458,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { validateUplo(Uplo); - validateHEMM(Element.F32_2(mRS), Side, A, B, C); + validateHEMM(Element.F64_2(mRS), Side, A, B, C); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); } @@ -1443,11 +1474,11 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { throw new RSRuntimeException("Called HERK with non-square C"); } if (Trans == NO_TRANSPOSE) { - if (cdim != A.getType().getX()) { + if (cdim != A.getType().getY()) { throw new RSRuntimeException("Called HERK with invalid A"); } } else { - if (cdim != A.getType().getY()) { + if (cdim != A.getType().getX()) { throw new RSRuntimeException("Called HERK with invalid A"); } } @@ -1456,7 +1487,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateHERK(Element.F32_2(mRS), Trans, A, C); int k = 0; - if (Trans == TRANSPOSE) { + if (Trans == CONJ_TRANSPOSE) { k = A.getType().getY(); } else { k = A.getType().getX(); @@ -1468,7 +1499,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateHERK(Element.F64_2(mRS), Trans, A, C); int k = 0; - if (Trans == TRANSPOSE) { + if (Trans == CONJ_TRANSPOSE) { k = A.getType().getY(); } else { k = A.getType().getX(); |