diff options
Diffstat (limited to 'rs')
-rw-r--r-- | rs/java/android/renderscript/Allocation.java | 2 | ||||
-rw-r--r-- | rs/java/android/renderscript/Element.java | 5 | ||||
-rw-r--r-- | rs/java/android/renderscript/RenderScript.java | 24 | ||||
-rw-r--r-- | rs/java/android/renderscript/Script.java | 4 | ||||
-rw-r--r-- | rs/java/android/renderscript/ScriptGroup.java | 6 | ||||
-rw-r--r-- | rs/java/android/renderscript/ScriptIntrinsicBLAS.java | 439 | ||||
-rw-r--r-- | rs/jni/android_renderscript_RenderScript.cpp | 238 |
7 files changed, 457 insertions, 261 deletions
diff --git a/rs/java/android/renderscript/Allocation.java b/rs/java/android/renderscript/Allocation.java index 3b61f9d..70a5821 100644 --- a/rs/java/android/renderscript/Allocation.java +++ b/rs/java/android/renderscript/Allocation.java @@ -1422,6 +1422,8 @@ public class Allocation extends BaseObj { } /** + * @hide + * * This is only intended to be used by auto-generated code reflected from * the RenderScript script files and should not be used by developers. * diff --git a/rs/java/android/renderscript/Element.java b/rs/java/android/renderscript/Element.java index 4b3e30f..6efb6d6 100644 --- a/rs/java/android/renderscript/Element.java +++ b/rs/java/android/renderscript/Element.java @@ -536,8 +536,8 @@ public class Element extends BaseObj { } public static Element F16_3(RenderScript rs) { - if(rs.mElement_FLOAT_3 == null) { - rs.mElement_FLOAT_3 = createVector(rs, DataType.FLOAT_16, 3); + if(rs.mElement_HALF_3 == null) { + rs.mElement_HALF_3 = createVector(rs, DataType.FLOAT_16, 3); } return rs.mElement_HALF_3; } @@ -911,6 +911,7 @@ public class Element extends BaseObj { switch (dt) { // Support only primitive integer/float/boolean types as vectors. + case FLOAT_16: case FLOAT_32: case FLOAT_64: case SIGNED_8: diff --git a/rs/java/android/renderscript/RenderScript.java b/rs/java/android/renderscript/RenderScript.java index e7f210b..27f2cc8 100644 --- a/rs/java/android/renderscript/RenderScript.java +++ b/rs/java/android/renderscript/RenderScript.java @@ -131,7 +131,7 @@ public class RenderScript { // this should be a monotonically increasing ID // used in conjunction with the API version of a device - static final long sMinorID = 1; + static final long sMinorVersion = 1; /** * Returns an identifier that can be used to identify a particular @@ -140,8 +140,8 @@ public class RenderScript { * @return The minor RenderScript version number * */ - public static long getMinorID() { - return sMinorID; + public static long getMinorVersion() { + return sMinorVersion; } /** @@ -302,8 +302,12 @@ public class RenderScript { long[] fieldIDs, long[] values, int[] sizes, long[] depClosures, long[] depFieldIDs) { validate(); - return rsnClosureCreate(mContext, kernelID, returnValue, fieldIDs, values, + long c = rsnClosureCreate(mContext, kernelID, returnValue, fieldIDs, values, sizes, depClosures, depFieldIDs); + if (c == 0) { + throw new RSRuntimeException("Failed creating closure."); + } + return c; } native long rsnInvokeClosureCreate(long con, long invokeID, byte[] params, @@ -311,8 +315,12 @@ public class RenderScript { synchronized long nInvokeClosureCreate(long invokeID, byte[] params, long[] fieldIDs, long[] values, int[] sizes) { validate(); - return rsnInvokeClosureCreate(mContext, invokeID, params, fieldIDs, + long c = rsnInvokeClosureCreate(mContext, invokeID, params, fieldIDs, values, sizes); + if (c == 0) { + throw new RSRuntimeException("Failed creating closure."); + } + return c; } native void rsnClosureSetArg(long con, long closureID, int index, @@ -337,7 +345,11 @@ public class RenderScript { synchronized long nScriptGroup2Create(String name, String cachePath, long[] closures) { validate(); - return rsnScriptGroup2Create(mContext, name, cachePath, closures); + long g = rsnScriptGroup2Create(mContext, name, cachePath, closures); + if (g == 0) { + throw new RSRuntimeException("Failed creating script group."); + } + return g; } native void rsnScriptGroup2Execute(long con, long groupID); diff --git a/rs/java/android/renderscript/Script.java b/rs/java/android/renderscript/Script.java index 6a1efee..7cd6d09 100644 --- a/rs/java/android/renderscript/Script.java +++ b/rs/java/android/renderscript/Script.java @@ -182,9 +182,9 @@ public class Script extends BaseObj { mRS.validateObject(ain); mRS.validateObject(aout); - if (ain == null && aout == null) { + if (ain == null && aout == null && sc == null) { throw new RSIllegalArgumentException( - "At least one of ain or aout is required to be non-null."); + "At least one of input allocation, output allocation, or LaunchOptions is required to be non-null."); } long[] in_ids = null; diff --git a/rs/java/android/renderscript/ScriptGroup.java b/rs/java/android/renderscript/ScriptGroup.java index be8b0fd..d1a12f9 100644 --- a/rs/java/android/renderscript/ScriptGroup.java +++ b/rs/java/android/renderscript/ScriptGroup.java @@ -400,8 +400,10 @@ public final class ScriptGroup extends BaseObj { /** * Executes a script group * - * @param inputs inputs to the script group - * @return outputs of the script group as an array of objects + * @param inputs Values for inputs to the script group, in the order as the + * inputs are added via {@link Builder2#addInput}. + * @return Outputs of the script group as an array of objects, in the order + * as futures are passed to {@link Builder2#create}. */ public Object[] execute(Object... inputs) { diff --git a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java index 6cfdfee..f7e81b0 100644 --- a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java +++ b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java @@ -241,7 +241,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } static void validateUplo(@Uplo int Uplo) { - if (Uplo != LEFT && Uplo != RIGHT) { + if (Uplo != UPPER && Uplo != LOWER) { throw new RSRuntimeException("Invalid uplo passed to BLAS"); } } @@ -276,36 +276,36 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { expectedYDim = 1 + (N - 1) * incY; } if (X.getType().getX() != expectedXDim || - Y.getType().getY() != expectedXDim) { + Y.getType().getX() != expectedYDim) { throw new RSRuntimeException("Incorrect vector dimensions for GEMV"); } } - void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { + public void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY); int M = A.getType().getY(); int N = A.getType().getX(); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); } - void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { + public void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY); int M = A.getType().getY(); int N = A.getType().getX(); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); } - void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + public void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY); int M = A.getType().getY(); int N = A.getType().getX(); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); } - void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + public void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY); int M = A.getType().getY(); int N = A.getType().getX(); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); } - void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { + public void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { // GBMV has the same validation requirements as GEMV + KL and KU >= 0 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY); if (KL < 0 || KU < 0) { @@ -315,7 +315,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { int N = A.getType().getX(); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU); } - void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { + public void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { // GBMV has the same validation requirements as GEMV + KL and KU >= 0 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY); if (KL < 0 || KU < 0) { @@ -325,7 +325,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { int N = A.getType().getX(); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU); } - void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + public void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { // GBMV has the same validation requirements as GEMV + KL and KU >= 0 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY); if (KL < 0 || KU < 0) { @@ -335,7 +335,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { int N = A.getType().getX(); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); } - void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + public void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { // GBMV has the same validation requirements as GEMV + KL and KU >= 0 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY); if (KL < 0 || KU < 0) { @@ -346,8 +346,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); } - static void validateTRMV(Element e, @Transpose int TransA, Allocation A, Allocation X, int incX) { + static void validateTRMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { validateTranspose(TransA); + validateUplo(Uplo); + validateDiag(Diag); int N = A.getType().getY(); if (A.getType().getX() != N) { throw new RSRuntimeException("A must be a square matrix for TRMV"); @@ -386,158 +388,174 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); + //is it really doing anything? if (Ap.getType().getX() != ((N * (N+1)) / 2)) { throw new RSRuntimeException("Invalid dimension for Ap"); } - + if (incX <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } int expectedXDim = 1 + (N - 1) * incX; if (X.getType().getX() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); + throw new RSRuntimeException("Incorrect vector dimensions for TPMV"); } return N; } - void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - validateTRMV(Element.F32(mRS), TransA, A, X, incX); + public void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } - void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - validateTRMV(Element.F64(mRS), TransA, A, X, incX); + public void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } - void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + public void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } - void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + public void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } - void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBMV has the same requirements as TRMV - validateTRMV(Element.F32(mRS), TransA, A, X, incX); + + public void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBMV has the same requirements as TRMV + K >= 0 + if (K < 0) { + throw new RSRuntimeException("K must be greater than or equal to 0"); + } + validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } - void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBMV has the same requirements as TRMV - validateTRMV(Element.F64(mRS), TransA, A, X, incX); + public void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBMV has the same requirements as TRMV + K >= 0 + if (K < 0) { + throw new RSRuntimeException("K must be greater than or equal to 0"); + } + validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } - void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBMV has the same requirements as TRMV - validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + public void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBMV has the same requirements as TRMV + K >= 0 + if (K < 0) { + throw new RSRuntimeException("K must be greater than or equal to 0"); + } + validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } - void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBMV has the same requirements as TRMV - validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + public void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBMV has the same requirements as TRMV + K >= 0 + if (K < 0) { + throw new RSRuntimeException("K must be greater than or equal to 0"); + } + validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } - void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + public void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } - void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + public void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } - void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + public void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } - void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + public void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } - void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + public void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { // TRSV is the same as TRMV - validateTRMV(Element.F32(mRS), TransA, A, X, incX); + validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } - void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + public void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { // TRSV is the same as TRMV - validateTRMV(Element.F64(mRS), TransA, A, X, incX); + validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } - void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + public void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { // TRSV is the same as TRMV - validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } - void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + public void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { // TRSV is the same as TRMV - validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } - void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBSV is the same as TRMV - validateTRMV(Element.F32(mRS), TransA, A, X, incX); + public void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBSV is the same as TRMV + K >= 0 + validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); if (K < 0) { throw new RSRuntimeException("Number of diagonals must be positive"); } mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } - void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBSV is the same as TRMV - validateTRMV(Element.F64(mRS), TransA, A, X, incX); + public void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBSV is the same as TRMV + K >= 0 + validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); if (K < 0) { throw new RSRuntimeException("Number of diagonals must be positive"); } mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } - void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBSV is the same as TRMV - validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + public void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBSV is the same as TRMV + K >= 0 + validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); if (K < 0) { throw new RSRuntimeException("Number of diagonals must be positive"); } mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } - void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBSV is the same as TRMV - validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + public void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBSV is the same as TRMV + K >= 0 + validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A.getType().getY(); if (K < 0) { throw new RSRuntimeException("Number of diagonals must be positive"); } mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } - void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + public void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { // TPSV is same as TPMV int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } - void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + public void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { // TPSV is same as TPMV int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); } - void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + public void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { // TPSV is same as TPMV int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); } - void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + public void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { // TPSV is same as TPMV int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); @@ -593,7 +611,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { if (Ap.getType().getX() != ((N * (N+1)) / 2)) { throw new RSRuntimeException("Invalid dimension for Ap"); } - + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } int expectedXDim = 1 + (N - 1) * incX; if (X.getType().getX() != expectedXDim) { throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); @@ -622,8 +642,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { if (N < 1 || M < 1) { throw new RSRuntimeException("M and N must be 1 or greater for GER"); } - - int expectedXDim = 1 + (N - 1) * incX; + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } + int expectedXDim = 1 + (M - 1) * incX; if (X.getType().getX() != expectedXDim) { throw new RSRuntimeException("Incorrect vector dimensions for GER"); } @@ -649,7 +671,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { if (N != A.getType().getY()) { throw new RSRuntimeException("A must be a symmetric matrix"); } - + if (incX <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } int expectedXDim = 1 + (N - 1) * incX; if (X.getType().getX() != expectedXDim) { throw new RSRuntimeException("Incorrect vector dimensions for SYR"); @@ -674,10 +698,12 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { if (Ap.getType().getX() != ((N * (N+1)) / 2)) { throw new RSRuntimeException("Invalid dimension for Ap"); } - + if (incX <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } int expectedXDim = 1 + (N - 1) * incX; if (X.getType().getX() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); + throw new RSRuntimeException("Incorrect vector dimensions for SPR"); } return N; @@ -700,7 +726,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { if (N != A.getType().getY()) { throw new RSRuntimeException("A must be a symmetric matrix"); } - + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } int expectedXDim = 1 + (N - 1) * incX; int expectedYDim = 1 + (N - 1) * incY; if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { @@ -728,81 +756,91 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { if (Ap.getType().getX() != ((N * (N+1)) / 2)) { throw new RSRuntimeException("Invalid dimension for Ap"); } - + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } int expectedXDim = 1 + (N - 1) * incX; int expectedYDim = 1 + (N - 1) * incY; if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); + throw new RSRuntimeException("Incorrect vector dimensions for SPR2"); } return N; } - void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { + public void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); } - void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { - // SBMV is the same as SYMV + public void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { + // SBMV is the same as SYMV + K >= 0 + if (K < 0) { + throw new RSRuntimeException("K must be greater than or equal to 0"); + } int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); } - void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) { + public void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) { int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); } - void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + public void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { int M = A.getType().getY(); int N = A.getType().getX(); + validateGER(Element.F32(mRS), X, incX, Y, incY, A); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); } - void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { + public void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0); } - void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { + public void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0); } - void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + public void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0); } - void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + public void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0); } - void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { + public void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); } - void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { - // SBMV is the same as SYMV + public void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { + // SBMV is the same as SYMV + K >= 0 + if (K < 0) { + throw new RSRuntimeException("K must be greater than or equal to 0"); + } int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); } - void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) { + public void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) { int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); } - void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + public void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { int M = A.getType().getY(); int N = A.getType().getX(); + validateGER(Element.F64(mRS), X, incX, Y, incY, A); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); } - void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { + public void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0); } - void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { + public void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0); } - void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + public void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0); } - void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + public void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0); } @@ -824,8 +862,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { int M = A.getType().getY(); int N = A.getType().getX(); - - int expectedXDim = 1 + (N - 1) * incX; + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } + int expectedXDim = 1 + (M - 1) * incX; if (X.getType().getX() != expectedXDim) { throw new RSRuntimeException("Incorrect vector dimensions for GERU"); } @@ -836,12 +876,12 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } - void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + public void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { // HEMV is the same as SYR2 validation-wise int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); } - void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + public void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { // HBMV is the same as SYR2 validation-wise int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); if (K < 0) { @@ -849,50 +889,50 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); } - void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + public void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { // HPMV is the same as SPR2 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); } - void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + public void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A); int M = A.getType().getY(); int N = A.getType().getX(); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); } - void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + public void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { // same as GERU validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A); int M = A.getType().getY(); int N = A.getType().getX(); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); } - void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { + public void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { // same as SYR - int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A); + int N = validateSYR(Element.F32_2(mRS), Uplo, X, incX, A); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); } - void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { + public void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { // equivalent to SPR for validation int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0); } - void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + public void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { // same as SYR2 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); } - void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + public void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { // same as SPR2 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0); } - void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + public void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { // HEMV is the same as SYR2 validation-wise int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); } - void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + public void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { // HBMV is the same as SYR2 validation-wise int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); if (K < 0) { @@ -900,40 +940,40 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); } - void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + public void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { // HPMV is the same as SPR2 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); } - void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + public void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A); int M = A.getType().getY(); int N = A.getType().getX(); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); } - void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + public void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { // same as GERU validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A); int M = A.getType().getY(); int N = A.getType().getX(); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); } - void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { + public void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { // same as SYR - int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); + int N = validateSYR(Element.F64_2(mRS), Uplo, X, incX, A); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); } - void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { + public void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { // equivalent to SPR for validation int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0); } - void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + public void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { // same as SYR2 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); } - void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + public void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { // same as SPR2 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0); @@ -945,56 +985,68 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { */ static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { - int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1; + int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; if ((A != null && !A.getType().getElement().isCompatible(e)) || (B != null && !B.getType().getElement().isCompatible(e)) || (C != null && !C.getType().getElement().isCompatible(e))) { throw new RSRuntimeException("Called BLAS with wrong Element type"); } - if (C != null) { - cX = C.getType().getY(); - cY = C.getType().getX(); + if (C == null) { + //since matrix C is used to store the result, it cannot be null. + throw new RSRuntimeException("Allocation C cannot be null"); } + cM = C.getType().getY(); + cN = C.getType().getX(); + if (Side == RIGHT) { + if ((A == null && B != null) || (A != null && B == null)) { + throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa"); + } if (B != null) { - bX = A.getType().getY(); - bY = A.getType().getX(); + bM = A.getType().getY(); + bN = A.getType().getX(); } if (A != null) { - aX = B.getType().getY(); - aY = B.getType().getX(); + aM = B.getType().getY(); + aN = B.getType().getX(); } } else { if (A != null) { - if (TransA == TRANSPOSE) { - aY = A.getType().getY(); - aX = A.getType().getX(); + if (TransA == TRANSPOSE || TransA == CONJ_TRANSPOSE) { + aN = A.getType().getY(); + aM = A.getType().getX(); } else { - aX = A.getType().getY(); - aY = A.getType().getX(); + aM = A.getType().getY(); + aN = A.getType().getX(); } } if (B != null) { - if (TransB == TRANSPOSE) { - bY = B.getType().getY(); - bX = B.getType().getX(); + if (TransB == TRANSPOSE || TransB == CONJ_TRANSPOSE) { + bN = B.getType().getY(); + bM = B.getType().getX(); } else { - bX = B.getType().getY(); - bY = B.getType().getX(); + bM = B.getType().getY(); + bN = B.getType().getX(); } } } if (A != null && B != null && C != null) { - if (aY != bX || aX != cX || bY != cY) { + if (aN != bM || aM != cM || bN != cN) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } else if (A != null && C != null) { - // A and C only - if (aX != cY || aY != cX) { + // A and C only, for SYRK + if (cM != cN) { + throw new RSRuntimeException("Matrix C is not symmetric"); + } + if (aM != cM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } else if (A != null && B != null) { // A and B only + if (aN != bM) { + throw new RSRuntimeException("Called BLAS with invalid dimensions"); + } } } @@ -1006,14 +1058,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; - if (TransA == TRANSPOSE) { + if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } - if (TransB == TRANSPOSE) { + if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); @@ -1027,14 +1079,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; - if (TransA == TRANSPOSE) { + if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } - if (TransB == TRANSPOSE) { + if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); @@ -1048,14 +1100,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; - if (TransA == TRANSPOSE) { + if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } - if (TransB == TRANSPOSE) { + if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); @@ -1070,14 +1122,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; - if (TransA == TRANSPOSE) { + if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } - if (TransB == TRANSPOSE) { + if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); @@ -1090,6 +1142,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, float beta, Allocation C) { validateSide(Side); validateUplo(Uplo); + //For SYMM, Matrix A should be symmetric + if (A.getType().getX() != A.getType().getY()) { + throw new RSRuntimeException("Matrix A is not symmetric"); + } validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); @@ -1098,6 +1154,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, double beta, Allocation C) { validateSide(Side); validateUplo(Uplo); + if (A.getType().getX() != A.getType().getY()) { + throw new RSRuntimeException("Matrix A is not symmetric"); + } validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); @@ -1106,6 +1165,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, Float2 beta, Allocation C) { validateSide(Side); validateUplo(Uplo); + if (A.getType().getX() != A.getType().getY()) { + throw new RSRuntimeException("Matrix A is not symmetric"); + } validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); @@ -1114,6 +1176,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, Double2 beta, Allocation C) { validateSide(Side); validateUplo(Uplo); + if (A.getType().getX() != A.getType().getY()) { + throw new RSRuntimeException("Matrix A is not symmetric"); + } validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); @@ -1124,7 +1189,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); int K = -1; - if (Trans == TRANSPOSE) { + if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); @@ -1138,37 +1203,37 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); int K = -1; - if (Trans == TRANSPOSE) { + if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); } mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0); } - public void CSYRK(@Uplo int Uplo, @Transpose int Trans, float alphaX, float alphaY, Allocation A, float betaX, float betaY, Allocation C) { + public void CSYRK(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Float2 beta, Allocation C) { validateTranspose(Trans); validateUplo(Uplo); validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); int K = -1; - if (Trans == TRANSPOSE) { + if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); } - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY, + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); } - public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, double alphaX, double alphaY, Allocation A, double betaX, double betaY, Allocation C) { + public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Double2 beta, Allocation C) { validateTranspose(Trans); validateUplo(Uplo); validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); int K = -1; - if (Trans == TRANSPOSE) { + if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); } - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY, + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); } @@ -1189,7 +1254,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { // check rows versus C Cdim = A.getType().getY(); } - if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) { + if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) { throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); } // A dims == B dims @@ -1201,7 +1266,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateSYR2K(Element.F32(mRS), Trans, A, B, C); int K = -1; - if (Trans == TRANSPOSE) { + if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); @@ -1212,59 +1277,59 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateSYR2K(Element.F64(mRS), Trans, A, B, C); int K = -1; - if (Trans == TRANSPOSE) { + if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); } - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); } public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) { validateUplo(Uplo); validateSYR2K(Element.F32_2(mRS), Trans, A, B, C); int K = -1; - if (Trans == TRANSPOSE) { + if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); } - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); } public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { validateUplo(Uplo); validateSYR2K(Element.F64_2(mRS), Trans, A, B, C); int K = -1; - if (Trans == TRANSPOSE) { + if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); } - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); } static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { validateSide(Side); validateTranspose(TransA); - int aX = -1, aY = -1, bX = -1, bY = -1; + int aM = -1, aN = -1, bM = -1, bN = -1; if (!A.getType().getElement().isCompatible(e) || !B.getType().getElement().isCompatible(e)) { throw new RSRuntimeException("Called BLAS with wrong Element type"); } - if (TransA == TRANSPOSE) { - aY = A.getType().getY(); - aX = A.getType().getX(); - } else { - aY = A.getType().getX(); - aX = A.getType().getY(); + + aM = A.getType().getY(); + aN = A.getType().getX(); + if (aM != aN) { + throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A"); } - bX = B.getType().getY(); - bY = B.getType().getX(); + + bM = B.getType().getY(); + bN = B.getType().getX(); if (Side == LEFT) { - if (aX == 0 || aY != bX) { + if (aN != bM) { throw new RSRuntimeException("Called TRMM with invalid matrices"); } } else { - if (bY != aX || aY == 0) { + if (bN != aM) { throw new RSRuntimeException("Called TRMM with invalid matrices"); } } @@ -1280,26 +1345,26 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateDiag(Diag); validateTRMM(Element.F64(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, - alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); } public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) { validateUplo(Uplo); validateDiag(Diag); validateTRMM(Element.F32_2(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); } public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) { validateUplo(Uplo); validateDiag(Diag); validateTRMM(Element.F64_2(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); } static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { - int adim = -1, bX = -1, bY = -1; + int adim = -1, bM = -1, bN = -1; validateSide(Side); validateTranspose(TransA); if (!A.getType().getElement().isCompatible(e) || @@ -1313,16 +1378,16 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { // for now we assume adapters are sufficient, will reevaluate in the future throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); } - bX = B.getType().getY(); - bY = B.getType().getX(); + bM = B.getType().getY(); + bN = B.getType().getX(); if (Side == LEFT) { // A is M*M - if (adim != bY) { + if (adim != bM) { throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); } } else { // A is N*N - if (adim != bX) { + if (adim != bN) { throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); } } @@ -1338,21 +1403,21 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateDiag(Diag); validateTRSM(Element.F64(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); } public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) { validateUplo(Uplo); validateDiag(Diag); validateTRSM(Element.F32_2(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); } public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) { validateUplo(Uplo); validateDiag(Diag); validateTRSM(Element.F64_2(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); } @@ -1379,17 +1444,17 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { throw new RSRuntimeException("Called HEMM with mismatched B and C"); } } - public void CHEMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, Allocation B, float beta, Allocation C) { + public void CHEMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) { validateUplo(Uplo); validateHEMM(Element.F32_2(mRS), Side, A, B, C); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, - alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); + alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); } - public void ZHEMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, Allocation B, double beta, Allocation C) { + public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { validateUplo(Uplo); - validateHEMM(Element.F32_2(mRS), Side, A, B, C); + validateHEMM(Element.F64_2(mRS), Side, A, B, C); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, - alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); + alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); } static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) { @@ -1403,11 +1468,11 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { throw new RSRuntimeException("Called HERK with non-square C"); } if (Trans == NO_TRANSPOSE) { - if (cdim != A.getType().getX()) { + if (cdim != A.getType().getY()) { throw new RSRuntimeException("Called HERK with invalid A"); } } else { - if (cdim != A.getType().getY()) { + if (cdim != A.getType().getX()) { throw new RSRuntimeException("Called HERK with invalid A"); } } @@ -1416,7 +1481,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateHERK(Element.F32_2(mRS), Trans, A, C); int k = 0; - if (Trans == TRANSPOSE) { + if (Trans == CONJ_TRANSPOSE) { k = A.getType().getY(); } else { k = A.getType().getX(); @@ -1428,7 +1493,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateHERK(Element.F64_2(mRS), Trans, A, C); int k = 0; - if (Trans == TRANSPOSE) { + if (Trans == CONJ_TRANSPOSE) { k = A.getType().getY(); } else { k = A.getType().getX(); diff --git a/rs/jni/android_renderscript_RenderScript.cpp b/rs/jni/android_renderscript_RenderScript.cpp index cbe87fc..1833a1c 100644 --- a/rs/jni/android_renderscript_RenderScript.cpp +++ b/rs/jni/android_renderscript_RenderScript.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#define LOG_TAG "libRS_jni" +#define LOG_TAG "RenderScript_jni" #include <stdlib.h> #include <stdio.h> @@ -328,79 +328,167 @@ nClosureCreate(JNIEnv *_env, jobject _this, jlong con, jlong kernelID, jlong returnValue, jlongArray fieldIDArray, jlongArray valueArray, jintArray sizeArray, jlongArray depClosureArray, jlongArray depFieldIDArray) { + jlong ret = 0; + jlong* jFieldIDs = _env->GetLongArrayElements(fieldIDArray, nullptr); jsize fieldIDs_length = _env->GetArrayLength(fieldIDArray); - RsScriptFieldID* fieldIDs = - (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * fieldIDs_length); - for (int i = 0; i< fieldIDs_length; i++) { + jlong* jValues = _env->GetLongArrayElements(valueArray, nullptr); + jsize values_length = _env->GetArrayLength(valueArray); + jint* jSizes = _env->GetIntArrayElements(sizeArray, nullptr); + jsize sizes_length = _env->GetArrayLength(sizeArray); + jlong* jDepClosures = + _env->GetLongArrayElements(depClosureArray, nullptr); + jsize depClosures_length = _env->GetArrayLength(depClosureArray); + jlong* jDepFieldIDs = + _env->GetLongArrayElements(depFieldIDArray, nullptr); + jsize depFieldIDs_length = _env->GetArrayLength(depFieldIDArray); + + size_t numValues, numDependencies; + RsScriptFieldID* fieldIDs; + uintptr_t* values; + RsClosure* depClosures; + RsScriptFieldID* depFieldIDs; + + if (fieldIDs_length != values_length || values_length != sizes_length) { + ALOGE("Unmatched field IDs, values, and sizes in closure creation."); + goto exit; + } + + numValues = (size_t)fieldIDs_length; + + if (depClosures_length != depFieldIDs_length) { + ALOGE("Unmatched closures and field IDs for dependencies in closure creation."); + goto exit; + } + + numDependencies = (size_t)depClosures_length; + + if (numDependencies > numValues) { + ALOGE("Unexpected number of dependencies in closure creation"); + goto exit; + } + + if (numValues > RS_CLOSURE_MAX_NUMBER_ARGS_AND_BINDINGS) { + ALOGE("Too many arguments or globals in closure creation"); + goto exit; + } + + fieldIDs = (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * numValues); + if (fieldIDs == nullptr) { + goto exit; + } + + for (size_t i = 0; i < numValues; i++) { fieldIDs[i] = (RsScriptFieldID)jFieldIDs[i]; } - jlong* jValues = _env->GetLongArrayElements(valueArray, nullptr); - jsize values_length = _env->GetArrayLength(valueArray); - uintptr_t* values = (uintptr_t*)alloca(sizeof(uintptr_t) * values_length); - for (int i = 0; i < values_length; i++) { + values = (uintptr_t*)alloca(sizeof(uintptr_t) * numValues); + if (values == nullptr) { + goto exit; + } + + for (size_t i = 0; i < numValues; i++) { values[i] = (uintptr_t)jValues[i]; } - jint* sizes = _env->GetIntArrayElements(sizeArray, nullptr); - jsize sizes_length = _env->GetArrayLength(sizeArray); + depClosures = (RsClosure*)alloca(sizeof(RsClosure) * numDependencies); + if (depClosures == nullptr) { + goto exit; + } - jlong* jDepClosures = - _env->GetLongArrayElements(depClosureArray, nullptr); - jsize depClosures_length = _env->GetArrayLength(depClosureArray); - RsClosure* depClosures = - (RsClosure*)alloca(sizeof(RsClosure) * depClosures_length); - for (int i = 0; i < depClosures_length; i++) { + for (size_t i = 0; i < numDependencies; i++) { depClosures[i] = (RsClosure)jDepClosures[i]; } - jlong* jDepFieldIDs = - _env->GetLongArrayElements(depFieldIDArray, nullptr); - jsize depFieldIDs_length = _env->GetArrayLength(depFieldIDArray); - RsScriptFieldID* depFieldIDs = - (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * depFieldIDs_length); - for (int i = 0; i < depClosures_length; i++) { + depFieldIDs = (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * numDependencies); + if (depFieldIDs == nullptr) { + goto exit; + } + + for (size_t i = 0; i < numDependencies; i++) { depFieldIDs[i] = (RsClosure)jDepFieldIDs[i]; } - return (jlong)(uintptr_t)rsClosureCreate( + ret = (jlong)(uintptr_t)rsClosureCreate( (RsContext)con, (RsScriptKernelID)kernelID, (RsAllocation)returnValue, - fieldIDs, (size_t)fieldIDs_length, values, (size_t)values_length, - (int*)sizes, (size_t)sizes_length, - depClosures, (size_t)depClosures_length, - depFieldIDs, (size_t)depFieldIDs_length); + fieldIDs, numValues, values, numValues, + (int*)jSizes, numValues, + depClosures, numDependencies, + depFieldIDs, numDependencies); + +exit: + + _env->ReleaseLongArrayElements(depFieldIDArray, jDepFieldIDs, JNI_ABORT); + _env->ReleaseLongArrayElements(depClosureArray, jDepClosures, JNI_ABORT); + _env->ReleaseIntArrayElements (sizeArray, jSizes, JNI_ABORT); + _env->ReleaseLongArrayElements(valueArray, jValues, JNI_ABORT); + _env->ReleaseLongArrayElements(fieldIDArray, jFieldIDs, JNI_ABORT); + + return ret; } static jlong nInvokeClosureCreate(JNIEnv *_env, jobject _this, jlong con, jlong invokeID, jbyteArray paramArray, jlongArray fieldIDArray, jlongArray valueArray, jintArray sizeArray) { + jlong ret = 0; + jbyte* jParams = _env->GetByteArrayElements(paramArray, nullptr); jsize jParamLength = _env->GetArrayLength(paramArray); - jlong* jFieldIDs = _env->GetLongArrayElements(fieldIDArray, nullptr); jsize fieldIDs_length = _env->GetArrayLength(fieldIDArray); - RsScriptFieldID* fieldIDs = - (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * fieldIDs_length); - for (int i = 0; i< fieldIDs_length; i++) { + jlong* jValues = _env->GetLongArrayElements(valueArray, nullptr); + jsize values_length = _env->GetArrayLength(valueArray); + jint* jSizes = _env->GetIntArrayElements(sizeArray, nullptr); + jsize sizes_length = _env->GetArrayLength(sizeArray); + + size_t numValues; + RsScriptFieldID* fieldIDs; + uintptr_t* values; + + if (fieldIDs_length != values_length || values_length != sizes_length) { + ALOGE("Unmatched field IDs, values, and sizes in closure creation."); + goto exit; + } + + numValues = (size_t) fieldIDs_length; + + if (numValues > RS_CLOSURE_MAX_NUMBER_ARGS_AND_BINDINGS) { + ALOGE("Too many arguments or globals in closure creation"); + goto exit; + } + + fieldIDs = (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * numValues); + if (fieldIDs == nullptr) { + goto exit; + } + + for (size_t i = 0; i< numValues; i++) { fieldIDs[i] = (RsScriptFieldID)jFieldIDs[i]; } - jlong* jValues = _env->GetLongArrayElements(valueArray, nullptr); - jsize values_length = _env->GetArrayLength(valueArray); - uintptr_t* values = (uintptr_t*)alloca(sizeof(uintptr_t) * values_length); - for (int i = 0; i < values_length; i++) { - values[i] = (uintptr_t)jValues[i]; + values = (uintptr_t*)alloca(sizeof(uintptr_t) * numValues); + if (values == nullptr) { + goto exit; } - jint* sizes = _env->GetIntArrayElements(sizeArray, nullptr); - jsize sizes_length = _env->GetArrayLength(sizeArray); + for (size_t i = 0; i < numValues; i++) { + values[i] = (uintptr_t)jValues[i]; + } - return (jlong)(uintptr_t)rsInvokeClosureCreate( + ret = (jlong)(uintptr_t)rsInvokeClosureCreate( (RsContext)con, (RsScriptInvokeID)invokeID, jParams, jParamLength, - fieldIDs, (size_t)fieldIDs_length, values, (size_t)values_length, - (int*)sizes, (size_t)sizes_length); + fieldIDs, numValues, values, numValues, + (int*)jSizes, numValues); + +exit: + + _env->ReleaseIntArrayElements (sizeArray, jSizes, JNI_ABORT); + _env->ReleaseLongArrayElements(valueArray, jValues, JNI_ABORT); + _env->ReleaseLongArrayElements(fieldIDArray, jFieldIDs, JNI_ABORT); + _env->ReleaseByteArrayElements(paramArray, jParams, JNI_ABORT); + + return ret; } static void @@ -420,20 +508,40 @@ nClosureSetGlobal(JNIEnv *_env, jobject _this, jlong con, jlong closureID, static long nScriptGroup2Create(JNIEnv *_env, jobject _this, jlong con, jstring name, jstring cacheDir, jlongArray closureArray) { + jlong ret = 0; + AutoJavaStringToUTF8 nameUTF(_env, name); AutoJavaStringToUTF8 cacheDirUTF(_env, cacheDir); jlong* jClosures = _env->GetLongArrayElements(closureArray, nullptr); jsize numClosures = _env->GetArrayLength(closureArray); - RsClosure* closures = (RsClosure*)alloca(sizeof(RsClosure) * numClosures); + + RsClosure* closures; + + if (numClosures > (jsize) RS_SCRIPT_GROUP_MAX_NUMBER_CLOSURES) { + ALOGE("Too many closures in script group"); + goto exit; + } + + closures = (RsClosure*)alloca(sizeof(RsClosure) * numClosures); + if (closures == nullptr) { + goto exit; + } + for (int i = 0; i < numClosures; i++) { closures[i] = (RsClosure)jClosures[i]; } - return (jlong)(uintptr_t)rsScriptGroup2Create( + ret = (jlong)(uintptr_t)rsScriptGroup2Create( (RsContext)con, nameUTF.c_str(), nameUTF.length(), cacheDirUTF.c_str(), cacheDirUTF.length(), closures, numClosures); + +exit: + + _env->ReleaseLongArrayElements(closureArray, jClosures, JNI_ABORT); + + return ret; } static void @@ -526,7 +634,7 @@ nScriptIntrinsicBLAS_Complex(JNIEnv *_env, jobject _this, jlong con, jlong id, j call.alpha.c.r = alphaX; call.alpha.c.i = alphaY; call.beta.c.r = betaX; - call.beta.c.r = betaY; + call.beta.c.i = betaY; call.incX = incX; call.incY = incY; call.KL = KL; @@ -561,7 +669,7 @@ nScriptIntrinsicBLAS_Z(JNIEnv *_env, jobject _this, jlong con, jlong id, jint fu call.alpha.z.r = alphaX; call.alpha.z.i = alphaY; call.beta.z.r = betaX; - call.beta.z.r = betaY; + call.beta.z.i = betaY; call.incX = incX; call.incY = incY; call.KL = KL; @@ -1102,9 +1210,8 @@ static jlong nAllocationCreateFromBitmap(JNIEnv *_env, jobject _this, jlong con, jlong type, jint mip, jobject jbitmap, jint usage) { - SkBitmap const * nativeBitmap = - GraphicsJNI::getSkBitmap(_env, jbitmap); - const SkBitmap& bitmap(*nativeBitmap); + SkBitmap bitmap; + GraphicsJNI::getSkBitmap(_env, jbitmap, &bitmap); bitmap.lockPixels(); const void* ptr = bitmap.getPixels(); @@ -1119,9 +1226,8 @@ static jlong nAllocationCreateBitmapBackedAllocation(JNIEnv *_env, jobject _this, jlong con, jlong type, jint mip, jobject jbitmap, jint usage) { - SkBitmap const * nativeBitmap = - GraphicsJNI::getSkBitmap(_env, jbitmap); - const SkBitmap& bitmap(*nativeBitmap); + SkBitmap bitmap; + GraphicsJNI::getSkBitmap(_env, jbitmap, &bitmap); bitmap.lockPixels(); const void* ptr = bitmap.getPixels(); @@ -1136,9 +1242,8 @@ static jlong nAllocationCubeCreateFromBitmap(JNIEnv *_env, jobject _this, jlong con, jlong type, jint mip, jobject jbitmap, jint usage) { - SkBitmap const * nativeBitmap = - GraphicsJNI::getSkBitmap(_env, jbitmap); - const SkBitmap& bitmap(*nativeBitmap); + SkBitmap bitmap; + GraphicsJNI::getSkBitmap(_env, jbitmap, &bitmap); bitmap.lockPixels(); const void* ptr = bitmap.getPixels(); @@ -1152,9 +1257,8 @@ nAllocationCubeCreateFromBitmap(JNIEnv *_env, jobject _this, jlong con, jlong ty static void nAllocationCopyFromBitmap(JNIEnv *_env, jobject _this, jlong con, jlong alloc, jobject jbitmap) { - SkBitmap const * nativeBitmap = - GraphicsJNI::getSkBitmap(_env, jbitmap); - const SkBitmap& bitmap(*nativeBitmap); + SkBitmap bitmap; + GraphicsJNI::getSkBitmap(_env, jbitmap, &bitmap); int w = bitmap.width(); int h = bitmap.height(); @@ -1169,9 +1273,8 @@ nAllocationCopyFromBitmap(JNIEnv *_env, jobject _this, jlong con, jlong alloc, j static void nAllocationCopyToBitmap(JNIEnv *_env, jobject _this, jlong con, jlong alloc, jobject jbitmap) { - SkBitmap const * nativeBitmap = - GraphicsJNI::getSkBitmap(_env, jbitmap); - const SkBitmap& bitmap(*nativeBitmap); + SkBitmap bitmap; + GraphicsJNI::getSkBitmap(_env, jbitmap, &bitmap); bitmap.lockPixels(); void* ptr = bitmap.getPixels(); @@ -1751,7 +1854,7 @@ nScriptForEach(JNIEnv *_env, jobject _this, jlong con, jlong script, jint slot, jintArray limits) { if (kLogApi) { - ALOGD("nScriptForEach, con(%p), s(%p), slot(%i)", (RsContext)con, (void *)script, slot); + ALOGD("nScriptForEach, con(%p), s(%p), slot(%i) ains(%p) aout(%" PRId64 ")", (RsContext)con, (void *)script, slot, ains, aout); } jint in_len = 0; @@ -1761,8 +1864,14 @@ nScriptForEach(JNIEnv *_env, jobject _this, jlong con, jlong script, jint slot, if (ains != nullptr) { in_len = _env->GetArrayLength(ains); - in_ptr = _env->GetLongArrayElements(ains, nullptr); + if (in_len > (jint)RS_KERNEL_MAX_ARGUMENTS) { + ALOGE("Too many arguments in kernel launch."); + // TODO (b/20758983): Report back to Java and throw an exception + return; + } + // TODO (b/20760800): Check in_ptr is not null + in_ptr = _env->GetLongArrayElements(ains, nullptr); if (sizeof(RsAllocation) == sizeof(jlong)) { in_allocs = (RsAllocation*)in_ptr; @@ -1770,6 +1879,11 @@ nScriptForEach(JNIEnv *_env, jobject _this, jlong con, jlong script, jint slot, // Convert from 64-bit jlong types to the native pointer type. in_allocs = (RsAllocation*)alloca(in_len * sizeof(RsAllocation)); + if (in_allocs == nullptr) { + ALOGE("Failed launching kernel for lack of memory."); + _env->ReleaseLongArrayElements(ains, in_ptr, JNI_ABORT); + return; + } for (int index = in_len; --index >= 0;) { in_allocs[index] = (RsAllocation)in_ptr[index]; |