From 68ca43ebe6e162ff13fc4f96d5aacd494980f6b6 Mon Sep 17 00:00:00 2001 From: Miao Wang Date: Thu, 23 Apr 2015 15:06:09 -0700 Subject: [RenderScript] improve & minor fixes of L2 BLAS validation. Change-Id: If8dd0f9d4c7db03df22763a80fa6d600539be7c1 --- .../android/renderscript/ScriptIntrinsicBLAS.java | 124 ++++++++++++++------- 1 file changed, 82 insertions(+), 42 deletions(-) (limited to 'rs') diff --git a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java index 51096a0..02554be 100644 --- a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java +++ b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java @@ -276,7 +276,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { expectedYDim = 1 + (N - 1) * incY; } if (X.getType().getX() != expectedXDim || - Y.getType().getY() != expectedXDim) { + Y.getType().getX() != expectedYDim) { throw new RSRuntimeException("Incorrect vector dimensions for GEMV"); } } @@ -346,8 +346,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); } - static void validateTRMV(Element e, @Transpose int TransA, Allocation A, Allocation X, int incX) { + static void validateTRMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { validateTranspose(TransA); + validateUplo(Uplo); + validateDiag(Diag); int N = A.getType().getY(); if (A.getType().getX() != N) { throw new RSRuntimeException("A must be a square matrix for TRMV"); @@ -386,59 +388,75 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); + //is it really doing anything? if (Ap.getType().getX() != ((N * (N+1)) / 2)) { throw new RSRuntimeException("Invalid dimension for Ap"); } - + if (incX <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } int expectedXDim = 1 + (N - 1) * incX; if (X.getType().getX() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); + throw new RSRuntimeException("Incorrect vector dimensions for TPMV"); } return N; } void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - validateTRMV(Element.F32(mRS), TransA, A, X, incX); + validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - validateTRMV(Element.F64(mRS), TransA, A, X, incX); + validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } + void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBMV has the same requirements as TRMV - validateTRMV(Element.F32(mRS), TransA, A, X, incX); + // TBMV has the same requirements as TRMV + K >= 0 + if (K < 0) { + throw new RSRuntimeException("K must be greater than or equal to 0"); + } + validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBMV has the same requirements as TRMV - validateTRMV(Element.F64(mRS), TransA, A, X, incX); + // TBMV has the same requirements as TRMV + K >= 0 + if (K < 0) { + throw new RSRuntimeException("K must be greater than or equal to 0"); + } + validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBMV has the same requirements as TRMV - validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + // TBMV has the same requirements as TRMV + K >= 0 + if (K < 0) { + throw new RSRuntimeException("K must be greater than or equal to 0"); + } + validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBMV has the same requirements as TRMV - validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + // TBMV has the same requirements as TRMV + K >= 0 + if (K < 0) { + throw new RSRuntimeException("K must be greater than or equal to 0"); + } + validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } @@ -460,35 +478,35 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { // TRSV is the same as TRMV - validateTRMV(Element.F32(mRS), TransA, A, X, incX); + validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { // TRSV is the same as TRMV - validateTRMV(Element.F64(mRS), TransA, A, X, incX); + validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { // TRSV is the same as TRMV - validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { // TRSV is the same as TRMV - validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBSV is the same as TRMV - validateTRMV(Element.F32(mRS), TransA, A, X, incX); + // TBSV is the same as TRMV + K >= 0 + validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); if (K < 0) { throw new RSRuntimeException("Number of diagonals must be positive"); @@ -496,8 +514,8 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBSV is the same as TRMV - validateTRMV(Element.F64(mRS), TransA, A, X, incX); + // TBSV is the same as TRMV + K >= 0 + validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); if (K < 0) { throw new RSRuntimeException("Number of diagonals must be positive"); @@ -505,8 +523,8 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBSV is the same as TRMV - validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + // TBSV is the same as TRMV + K >= 0 + validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); if (K < 0) { throw new RSRuntimeException("Number of diagonals must be positive"); @@ -514,8 +532,8 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBSV is the same as TRMV - validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + // TBSV is the same as TRMV + K >= 0 + validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); if (K < 0) { throw new RSRuntimeException("Number of diagonals must be positive"); @@ -593,7 +611,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { if (Ap.getType().getX() != ((N * (N+1)) / 2)) { throw new RSRuntimeException("Invalid dimension for Ap"); } - + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } int expectedXDim = 1 + (N - 1) * incX; if (X.getType().getX() != expectedXDim) { throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); @@ -622,8 +642,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { if (N < 1 || M < 1) { throw new RSRuntimeException("M and N must be 1 or greater for GER"); } - - int expectedXDim = 1 + (N - 1) * incX; + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } + int expectedXDim = 1 + (M - 1) * incX; if (X.getType().getX() != expectedXDim) { throw new RSRuntimeException("Incorrect vector dimensions for GER"); } @@ -649,7 +671,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { if (N != A.getType().getY()) { throw new RSRuntimeException("A must be a symmetric matrix"); } - + if (incX <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } int expectedXDim = 1 + (N - 1) * incX; if (X.getType().getX() != expectedXDim) { throw new RSRuntimeException("Incorrect vector dimensions for SYR"); @@ -674,10 +698,12 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { if (Ap.getType().getX() != ((N * (N+1)) / 2)) { throw new RSRuntimeException("Invalid dimension for Ap"); } - + if (incX <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } int expectedXDim = 1 + (N - 1) * incX; if (X.getType().getX() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); + throw new RSRuntimeException("Incorrect vector dimensions for SPR"); } return N; @@ -700,7 +726,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { if (N != A.getType().getY()) { throw new RSRuntimeException("A must be a symmetric matrix"); } - + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } int expectedXDim = 1 + (N - 1) * incX; int expectedYDim = 1 + (N - 1) * incY; if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { @@ -728,11 +756,13 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { if (Ap.getType().getX() != ((N * (N+1)) / 2)) { throw new RSRuntimeException("Invalid dimension for Ap"); } - + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } int expectedXDim = 1 + (N - 1) * incX; int expectedYDim = 1 + (N - 1) * incY; if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); + throw new RSRuntimeException("Incorrect vector dimensions for SPR2"); } return N; @@ -743,7 +773,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); } void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { - // SBMV is the same as SYMV + // SBMV is the same as SYMV + K >= 0 + if (K < 0) { + throw new RSRuntimeException("K must be greater than or equal to 0"); + } int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); } @@ -754,6 +787,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { int M = A.getType().getY(); int N = A.getType().getX(); + validateGER(Element.F32(mRS), X, incX, Y, incY, A); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); } void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { @@ -777,7 +811,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); } void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { - // SBMV is the same as SYMV + // SBMV is the same as SYMV + K >= 0 + if (K < 0) { + throw new RSRuntimeException("K must be greater than or equal to 0"); + } int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); } @@ -788,6 +825,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { int M = A.getType().getY(); int N = A.getType().getX(); + validateGER(Element.F64(mRS), X, incX, Y, incY, A); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); } void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { @@ -824,8 +862,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { int M = A.getType().getY(); int N = A.getType().getX(); - - int expectedXDim = 1 + (N - 1) * incX; + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } + int expectedXDim = 1 + (M - 1) * incX; if (X.getType().getX() != expectedXDim) { throw new RSRuntimeException("Incorrect vector dimensions for GERU"); } @@ -869,7 +909,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { // same as SYR - int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A); + int N = validateSYR(Element.F32_2(mRS), Uplo, X, incX, A); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); } void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { -- cgit v1.1