summaryrefslogtreecommitdiffstats
path: root/rs
diff options
context:
space:
mode:
Diffstat (limited to 'rs')
-rw-r--r--rs/java/android/renderscript/Allocation.java2
-rw-r--r--rs/java/android/renderscript/Element.java5
-rw-r--r--rs/java/android/renderscript/RenderScript.java24
-rw-r--r--rs/java/android/renderscript/Script.java4
-rw-r--r--rs/java/android/renderscript/ScriptGroup.java6
-rw-r--r--rs/java/android/renderscript/ScriptIntrinsicBLAS.java439
-rw-r--r--rs/jni/android_renderscript_RenderScript.cpp238
7 files changed, 457 insertions, 261 deletions
diff --git a/rs/java/android/renderscript/Allocation.java b/rs/java/android/renderscript/Allocation.java
index 3b61f9d..70a5821 100644
--- a/rs/java/android/renderscript/Allocation.java
+++ b/rs/java/android/renderscript/Allocation.java
@@ -1422,6 +1422,8 @@ public class Allocation extends BaseObj {
}
/**
+ * @hide
+ *
* This is only intended to be used by auto-generated code reflected from
* the RenderScript script files and should not be used by developers.
*
diff --git a/rs/java/android/renderscript/Element.java b/rs/java/android/renderscript/Element.java
index 4b3e30f..6efb6d6 100644
--- a/rs/java/android/renderscript/Element.java
+++ b/rs/java/android/renderscript/Element.java
@@ -536,8 +536,8 @@ public class Element extends BaseObj {
}
public static Element F16_3(RenderScript rs) {
- if(rs.mElement_FLOAT_3 == null) {
- rs.mElement_FLOAT_3 = createVector(rs, DataType.FLOAT_16, 3);
+ if(rs.mElement_HALF_3 == null) {
+ rs.mElement_HALF_3 = createVector(rs, DataType.FLOAT_16, 3);
}
return rs.mElement_HALF_3;
}
@@ -911,6 +911,7 @@ public class Element extends BaseObj {
switch (dt) {
// Support only primitive integer/float/boolean types as vectors.
+ case FLOAT_16:
case FLOAT_32:
case FLOAT_64:
case SIGNED_8:
diff --git a/rs/java/android/renderscript/RenderScript.java b/rs/java/android/renderscript/RenderScript.java
index e7f210b..27f2cc8 100644
--- a/rs/java/android/renderscript/RenderScript.java
+++ b/rs/java/android/renderscript/RenderScript.java
@@ -131,7 +131,7 @@ public class RenderScript {
// this should be a monotonically increasing ID
// used in conjunction with the API version of a device
- static final long sMinorID = 1;
+ static final long sMinorVersion = 1;
/**
* Returns an identifier that can be used to identify a particular
@@ -140,8 +140,8 @@ public class RenderScript {
* @return The minor RenderScript version number
*
*/
- public static long getMinorID() {
- return sMinorID;
+ public static long getMinorVersion() {
+ return sMinorVersion;
}
/**
@@ -302,8 +302,12 @@ public class RenderScript {
long[] fieldIDs, long[] values, int[] sizes, long[] depClosures,
long[] depFieldIDs) {
validate();
- return rsnClosureCreate(mContext, kernelID, returnValue, fieldIDs, values,
+ long c = rsnClosureCreate(mContext, kernelID, returnValue, fieldIDs, values,
sizes, depClosures, depFieldIDs);
+ if (c == 0) {
+ throw new RSRuntimeException("Failed creating closure.");
+ }
+ return c;
}
native long rsnInvokeClosureCreate(long con, long invokeID, byte[] params,
@@ -311,8 +315,12 @@ public class RenderScript {
synchronized long nInvokeClosureCreate(long invokeID, byte[] params,
long[] fieldIDs, long[] values, int[] sizes) {
validate();
- return rsnInvokeClosureCreate(mContext, invokeID, params, fieldIDs,
+ long c = rsnInvokeClosureCreate(mContext, invokeID, params, fieldIDs,
values, sizes);
+ if (c == 0) {
+ throw new RSRuntimeException("Failed creating closure.");
+ }
+ return c;
}
native void rsnClosureSetArg(long con, long closureID, int index,
@@ -337,7 +345,11 @@ public class RenderScript {
synchronized long nScriptGroup2Create(String name, String cachePath,
long[] closures) {
validate();
- return rsnScriptGroup2Create(mContext, name, cachePath, closures);
+ long g = rsnScriptGroup2Create(mContext, name, cachePath, closures);
+ if (g == 0) {
+ throw new RSRuntimeException("Failed creating script group.");
+ }
+ return g;
}
native void rsnScriptGroup2Execute(long con, long groupID);
diff --git a/rs/java/android/renderscript/Script.java b/rs/java/android/renderscript/Script.java
index 6a1efee..7cd6d09 100644
--- a/rs/java/android/renderscript/Script.java
+++ b/rs/java/android/renderscript/Script.java
@@ -182,9 +182,9 @@ public class Script extends BaseObj {
mRS.validateObject(ain);
mRS.validateObject(aout);
- if (ain == null && aout == null) {
+ if (ain == null && aout == null && sc == null) {
throw new RSIllegalArgumentException(
- "At least one of ain or aout is required to be non-null.");
+ "At least one of input allocation, output allocation, or LaunchOptions is required to be non-null.");
}
long[] in_ids = null;
diff --git a/rs/java/android/renderscript/ScriptGroup.java b/rs/java/android/renderscript/ScriptGroup.java
index be8b0fd..d1a12f9 100644
--- a/rs/java/android/renderscript/ScriptGroup.java
+++ b/rs/java/android/renderscript/ScriptGroup.java
@@ -400,8 +400,10 @@ public final class ScriptGroup extends BaseObj {
/**
* Executes a script group
*
- * @param inputs inputs to the script group
- * @return outputs of the script group as an array of objects
+ * @param inputs Values for inputs to the script group, in the order as the
+ * inputs are added via {@link Builder2#addInput}.
+ * @return Outputs of the script group as an array of objects, in the order
+ * as futures are passed to {@link Builder2#create}.
*/
public Object[] execute(Object... inputs) {
diff --git a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java
index 6cfdfee..f7e81b0 100644
--- a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java
+++ b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java
@@ -241,7 +241,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
static void validateUplo(@Uplo int Uplo) {
- if (Uplo != LEFT && Uplo != RIGHT) {
+ if (Uplo != UPPER && Uplo != LOWER) {
throw new RSRuntimeException("Invalid uplo passed to BLAS");
}
}
@@ -276,36 +276,36 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
expectedYDim = 1 + (N - 1) * incY;
}
if (X.getType().getX() != expectedXDim ||
- Y.getType().getY() != expectedXDim) {
+ Y.getType().getX() != expectedYDim) {
throw new RSRuntimeException("Incorrect vector dimensions for GEMV");
}
}
- void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
+ public void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
+ public void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
+ public void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
+ public void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
+ public void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
// GBMV has the same validation requirements as GEMV + KL and KU >= 0
validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
if (KL < 0 || KU < 0) {
@@ -315,7 +315,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
}
- void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
+ public void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
// GBMV has the same validation requirements as GEMV + KL and KU >= 0
validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
if (KL < 0 || KU < 0) {
@@ -325,7 +325,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
}
- void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
+ public void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
// GBMV has the same validation requirements as GEMV + KL and KU >= 0
validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
if (KL < 0 || KU < 0) {
@@ -335,7 +335,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
}
- void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
+ public void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
// GBMV has the same validation requirements as GEMV + KL and KU >= 0
validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
if (KL < 0 || KU < 0) {
@@ -346,8 +346,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
}
- static void validateTRMV(Element e, @Transpose int TransA, Allocation A, Allocation X, int incX) {
+ static void validateTRMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
validateTranspose(TransA);
+ validateUplo(Uplo);
+ validateDiag(Diag);
int N = A.getType().getY();
if (A.getType().getX() != N) {
throw new RSRuntimeException("A must be a square matrix for TRMV");
@@ -386,158 +388,174 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
+ //is it really doing anything?
if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
throw new RSRuntimeException("Invalid dimension for Ap");
}
-
+ if (incX <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
int expectedXDim = 1 + (N - 1) * incX;
if (X.getType().getX() != expectedXDim) {
- throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
+ throw new RSRuntimeException("Incorrect vector dimensions for TPMV");
}
return N;
}
- void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
- validateTRMV(Element.F32(mRS), TransA, A, X, incX);
+ public void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
- validateTRMV(Element.F64(mRS), TransA, A, X, incX);
+ public void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
- validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
+ public void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
- validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
+ public void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBMV has the same requirements as TRMV
- validateTRMV(Element.F32(mRS), TransA, A, X, incX);
+
+ public void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBMV has the same requirements as TRMV + K >= 0
+ if (K < 0) {
+ throw new RSRuntimeException("K must be greater than or equal to 0");
+ }
+ validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBMV has the same requirements as TRMV
- validateTRMV(Element.F64(mRS), TransA, A, X, incX);
+ public void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBMV has the same requirements as TRMV + K >= 0
+ if (K < 0) {
+ throw new RSRuntimeException("K must be greater than or equal to 0");
+ }
+ validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBMV has the same requirements as TRMV
- validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
+ public void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBMV has the same requirements as TRMV + K >= 0
+ if (K < 0) {
+ throw new RSRuntimeException("K must be greater than or equal to 0");
+ }
+ validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBMV has the same requirements as TRMV
- validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
+ public void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBMV has the same requirements as TRMV + K >= 0
+ if (K < 0) {
+ throw new RSRuntimeException("K must be greater than or equal to 0");
+ }
+ validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ public void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
// TRSV is the same as TRMV
- validateTRMV(Element.F32(mRS), TransA, A, X, incX);
+ validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ public void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
// TRSV is the same as TRMV
- validateTRMV(Element.F64(mRS), TransA, A, X, incX);
+ validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ public void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
// TRSV is the same as TRMV
- validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
+ validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ public void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
// TRSV is the same as TRMV
- validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
+ validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBSV is the same as TRMV
- validateTRMV(Element.F32(mRS), TransA, A, X, incX);
+ public void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBSV is the same as TRMV + K >= 0
+ validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
if (K < 0) {
throw new RSRuntimeException("Number of diagonals must be positive");
}
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBSV is the same as TRMV
- validateTRMV(Element.F64(mRS), TransA, A, X, incX);
+ public void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBSV is the same as TRMV + K >= 0
+ validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
if (K < 0) {
throw new RSRuntimeException("Number of diagonals must be positive");
}
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBSV is the same as TRMV
- validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
+ public void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBSV is the same as TRMV + K >= 0
+ validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
if (K < 0) {
throw new RSRuntimeException("Number of diagonals must be positive");
}
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBSV is the same as TRMV
- validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
+ public void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBSV is the same as TRMV + K >= 0
+ validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
if (K < 0) {
throw new RSRuntimeException("Number of diagonals must be positive");
}
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
// TPSV is same as TPMV
int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
// TPSV is same as TPMV
int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
// TPSV is same as TPMV
int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
// TPSV is same as TPMV
int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
@@ -593,7 +611,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
throw new RSRuntimeException("Invalid dimension for Ap");
}
-
+ if (incX <= 0 || incY <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
int expectedXDim = 1 + (N - 1) * incX;
if (X.getType().getX() != expectedXDim) {
throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
@@ -622,8 +642,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
if (N < 1 || M < 1) {
throw new RSRuntimeException("M and N must be 1 or greater for GER");
}
-
- int expectedXDim = 1 + (N - 1) * incX;
+ if (incX <= 0 || incY <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
+ int expectedXDim = 1 + (M - 1) * incX;
if (X.getType().getX() != expectedXDim) {
throw new RSRuntimeException("Incorrect vector dimensions for GER");
}
@@ -649,7 +671,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
if (N != A.getType().getY()) {
throw new RSRuntimeException("A must be a symmetric matrix");
}
-
+ if (incX <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
int expectedXDim = 1 + (N - 1) * incX;
if (X.getType().getX() != expectedXDim) {
throw new RSRuntimeException("Incorrect vector dimensions for SYR");
@@ -674,10 +698,12 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
throw new RSRuntimeException("Invalid dimension for Ap");
}
-
+ if (incX <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
int expectedXDim = 1 + (N - 1) * incX;
if (X.getType().getX() != expectedXDim) {
- throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
+ throw new RSRuntimeException("Incorrect vector dimensions for SPR");
}
return N;
@@ -700,7 +726,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
if (N != A.getType().getY()) {
throw new RSRuntimeException("A must be a symmetric matrix");
}
-
+ if (incX <= 0 || incY <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
int expectedXDim = 1 + (N - 1) * incX;
int expectedYDim = 1 + (N - 1) * incY;
if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
@@ -728,81 +756,91 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
throw new RSRuntimeException("Invalid dimension for Ap");
}
-
+ if (incX <= 0 || incY <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
int expectedXDim = 1 + (N - 1) * incX;
int expectedYDim = 1 + (N - 1) * incY;
if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
- throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
+ throw new RSRuntimeException("Incorrect vector dimensions for SPR2");
}
return N;
}
- void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
+ public void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
- // SBMV is the same as SYMV
+ public void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
+ // SBMV is the same as SYMV + K >= 0
+ if (K < 0) {
+ throw new RSRuntimeException("K must be greater than or equal to 0");
+ }
int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) {
+ public void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) {
int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
int M = A.getType().getY();
int N = A.getType().getX();
+ validateGER(Element.F32(mRS), X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
}
- void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
+ public void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
}
- void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
+ public void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
}
- void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
}
- void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
+ public void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
}
- void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
+ public void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
- // SBMV is the same as SYMV
+ public void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
+ // SBMV is the same as SYMV + K >= 0
+ if (K < 0) {
+ throw new RSRuntimeException("K must be greater than or equal to 0");
+ }
int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) {
+ public void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) {
int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
int M = A.getType().getY();
int N = A.getType().getX();
+ validateGER(Element.F64(mRS), X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
}
- void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
+ public void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
}
- void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
+ public void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
}
- void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
}
- void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
+ public void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
}
@@ -824,8 +862,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
int M = A.getType().getY();
int N = A.getType().getX();
-
- int expectedXDim = 1 + (N - 1) * incX;
+ if (incX <= 0 || incY <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
+ int expectedXDim = 1 + (M - 1) * incX;
if (X.getType().getX() != expectedXDim) {
throw new RSRuntimeException("Incorrect vector dimensions for GERU");
}
@@ -836,12 +876,12 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
- void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
+ public void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
// HEMV is the same as SYR2 validation-wise
int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
+ public void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
// HBMV is the same as SYR2 validation-wise
int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
if (K < 0) {
@@ -849,50 +889,50 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
+ public void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
// HPMV is the same as SPR2
int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
}
- void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
// same as GERU
validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
}
- void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
+ public void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
// same as SYR
- int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
+ int N = validateSYR(Element.F32_2(mRS), Uplo, X, incX, A);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
}
- void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
+ public void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
// equivalent to SPR for validation
int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
}
- void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
// same as SYR2
int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
}
- void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
+ public void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
// same as SPR2
int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
}
- void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
+ public void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
// HEMV is the same as SYR2 validation-wise
int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
+ public void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
// HBMV is the same as SYR2 validation-wise
int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
if (K < 0) {
@@ -900,40 +940,40 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
+ public void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
// HPMV is the same as SPR2
int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
}
- void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
// same as GERU
validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
}
- void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
+ public void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
// same as SYR
- int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
+ int N = validateSYR(Element.F64_2(mRS), Uplo, X, incX, A);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
}
- void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
+ public void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
// equivalent to SPR for validation
int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
}
- void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
// same as SYR2
int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
}
- void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
+ public void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
// same as SPR2
int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
@@ -945,56 +985,68 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
*/
static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) {
- int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1;
+ int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
if ((A != null && !A.getType().getElement().isCompatible(e)) ||
(B != null && !B.getType().getElement().isCompatible(e)) ||
(C != null && !C.getType().getElement().isCompatible(e))) {
throw new RSRuntimeException("Called BLAS with wrong Element type");
}
- if (C != null) {
- cX = C.getType().getY();
- cY = C.getType().getX();
+ if (C == null) {
+ //since matrix C is used to store the result, it cannot be null.
+ throw new RSRuntimeException("Allocation C cannot be null");
}
+ cM = C.getType().getY();
+ cN = C.getType().getX();
+
if (Side == RIGHT) {
+ if ((A == null && B != null) || (A != null && B == null)) {
+ throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa");
+ }
if (B != null) {
- bX = A.getType().getY();
- bY = A.getType().getX();
+ bM = A.getType().getY();
+ bN = A.getType().getX();
}
if (A != null) {
- aX = B.getType().getY();
- aY = B.getType().getX();
+ aM = B.getType().getY();
+ aN = B.getType().getX();
}
} else {
if (A != null) {
- if (TransA == TRANSPOSE) {
- aY = A.getType().getY();
- aX = A.getType().getX();
+ if (TransA == TRANSPOSE || TransA == CONJ_TRANSPOSE) {
+ aN = A.getType().getY();
+ aM = A.getType().getX();
} else {
- aX = A.getType().getY();
- aY = A.getType().getX();
+ aM = A.getType().getY();
+ aN = A.getType().getX();
}
}
if (B != null) {
- if (TransB == TRANSPOSE) {
- bY = B.getType().getY();
- bX = B.getType().getX();
+ if (TransB == TRANSPOSE || TransB == CONJ_TRANSPOSE) {
+ bN = B.getType().getY();
+ bM = B.getType().getX();
} else {
- bX = B.getType().getY();
- bY = B.getType().getX();
+ bM = B.getType().getY();
+ bN = B.getType().getX();
}
}
}
if (A != null && B != null && C != null) {
- if (aY != bX || aX != cX || bY != cY) {
+ if (aN != bM || aM != cM || bN != cN) {
throw new RSRuntimeException("Called BLAS with invalid dimensions");
}
} else if (A != null && C != null) {
- // A and C only
- if (aX != cY || aY != cX) {
+ // A and C only, for SYRK
+ if (cM != cN) {
+ throw new RSRuntimeException("Matrix C is not symmetric");
+ }
+ if (aM != cM) {
throw new RSRuntimeException("Called BLAS with invalid dimensions");
}
} else if (A != null && B != null) {
// A and B only
+ if (aN != bM) {
+ throw new RSRuntimeException("Called BLAS with invalid dimensions");
+ }
}
}
@@ -1006,14 +1058,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C);
int M = -1, N = -1, K = -1;
- if (TransA == TRANSPOSE) {
+ if (TransA != NO_TRANSPOSE) {
M = A.getType().getX();
K = A.getType().getY();
} else {
M = A.getType().getY();
K = A.getType().getX();
}
- if (TransB == TRANSPOSE) {
+ if (TransB != NO_TRANSPOSE) {
N = B.getType().getY();
} else {
N = B.getType().getX();
@@ -1027,14 +1079,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateTranspose(TransB);
validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C);
int M = -1, N = -1, K = -1;
- if (TransA == TRANSPOSE) {
+ if (TransA != NO_TRANSPOSE) {
M = A.getType().getX();
K = A.getType().getY();
} else {
M = A.getType().getY();
K = A.getType().getX();
}
- if (TransB == TRANSPOSE) {
+ if (TransB != NO_TRANSPOSE) {
N = B.getType().getY();
} else {
N = B.getType().getX();
@@ -1048,14 +1100,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateTranspose(TransB);
validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C);
int M = -1, N = -1, K = -1;
- if (TransA == TRANSPOSE) {
+ if (TransA != NO_TRANSPOSE) {
M = A.getType().getX();
K = A.getType().getY();
} else {
M = A.getType().getY();
K = A.getType().getX();
}
- if (TransB == TRANSPOSE) {
+ if (TransB != NO_TRANSPOSE) {
N = B.getType().getY();
} else {
N = B.getType().getX();
@@ -1070,14 +1122,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateTranspose(TransB);
validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C);
int M = -1, N = -1, K = -1;
- if (TransA == TRANSPOSE) {
+ if (TransA != NO_TRANSPOSE) {
M = A.getType().getX();
K = A.getType().getY();
} else {
M = A.getType().getY();
K = A.getType().getX();
}
- if (TransB == TRANSPOSE) {
+ if (TransB != NO_TRANSPOSE) {
N = B.getType().getY();
} else {
N = B.getType().getX();
@@ -1090,6 +1142,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
Allocation B, float beta, Allocation C) {
validateSide(Side);
validateUplo(Uplo);
+ //For SYMM, Matrix A should be symmetric
+ if (A.getType().getX() != A.getType().getY()) {
+ throw new RSRuntimeException("Matrix A is not symmetric");
+ }
validateL3(Element.F32(mRS), 0, 0, Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
beta, C.getID(mRS), 0, 0, 0, 0);
@@ -1098,6 +1154,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
Allocation B, double beta, Allocation C) {
validateSide(Side);
validateUplo(Uplo);
+ if (A.getType().getX() != A.getType().getY()) {
+ throw new RSRuntimeException("Matrix A is not symmetric");
+ }
validateL3(Element.F64(mRS), 0, 0, Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
beta, C.getID(mRS), 0, 0, 0, 0);
@@ -1106,6 +1165,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
Allocation B, Float2 beta, Allocation C) {
validateSide(Side);
validateUplo(Uplo);
+ if (A.getType().getX() != A.getType().getY()) {
+ throw new RSRuntimeException("Matrix A is not symmetric");
+ }
validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
@@ -1114,6 +1176,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
Allocation B, Double2 beta, Allocation C) {
validateSide(Side);
validateUplo(Uplo);
+ if (A.getType().getX() != A.getType().getY()) {
+ throw new RSRuntimeException("Matrix A is not symmetric");
+ }
validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
@@ -1124,7 +1189,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
@@ -1138,37 +1203,37 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
}
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
}
- public void CSYRK(@Uplo int Uplo, @Transpose int Trans, float alphaX, float alphaY, Allocation A, float betaX, float betaY, Allocation C) {
+ public void CSYRK(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Float2 beta, Allocation C) {
validateTranspose(Trans);
validateUplo(Uplo);
validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
}
- mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY,
+ mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y,
C.getID(mRS), 0, 0, 0, 0);
}
- public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, double alphaX, double alphaY, Allocation A, double betaX, double betaY, Allocation C) {
+ public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Double2 beta, Allocation C) {
validateTranspose(Trans);
validateUplo(Uplo);
validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
}
- mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY,
+ mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y,
C.getID(mRS), 0, 0, 0, 0);
}
@@ -1189,7 +1254,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
// check rows versus C
Cdim = A.getType().getY();
}
- if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) {
+ if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) {
throw new RSRuntimeException("Invalid symmetric matrix in SYR2K");
}
// A dims == B dims
@@ -1201,7 +1266,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateSYR2K(Element.F32(mRS), Trans, A, B, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
@@ -1212,59 +1277,59 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateSYR2K(Element.F64(mRS), Trans, A, B, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
}
- mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
+ mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
}
public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
validateUplo(Uplo);
validateSYR2K(Element.F32_2(mRS), Trans, A, B, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
}
- mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
+ mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
}
public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
validateUplo(Uplo);
validateSYR2K(Element.F64_2(mRS), Trans, A, B, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
}
- mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
+ mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
}
static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
validateSide(Side);
validateTranspose(TransA);
- int aX = -1, aY = -1, bX = -1, bY = -1;
+ int aM = -1, aN = -1, bM = -1, bN = -1;
if (!A.getType().getElement().isCompatible(e) ||
!B.getType().getElement().isCompatible(e)) {
throw new RSRuntimeException("Called BLAS with wrong Element type");
}
- if (TransA == TRANSPOSE) {
- aY = A.getType().getY();
- aX = A.getType().getX();
- } else {
- aY = A.getType().getX();
- aX = A.getType().getY();
+
+ aM = A.getType().getY();
+ aN = A.getType().getX();
+ if (aM != aN) {
+ throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A");
}
- bX = B.getType().getY();
- bY = B.getType().getX();
+
+ bM = B.getType().getY();
+ bN = B.getType().getX();
if (Side == LEFT) {
- if (aX == 0 || aY != bX) {
+ if (aN != bM) {
throw new RSRuntimeException("Called TRMM with invalid matrices");
}
} else {
- if (bY != aX || aY == 0) {
+ if (bN != aM) {
throw new RSRuntimeException("Called TRMM with invalid matrices");
}
}
@@ -1280,26 +1345,26 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateDiag(Diag);
validateTRMM(Element.F64(mRS), Side, TransA, A, B);
- mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
- alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0);
+ mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
+ alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
}
public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
validateUplo(Uplo);
validateDiag(Diag);
validateTRMM(Element.F32_2(mRS), Side, TransA, A, B);
- mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
+ mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
}
public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
validateUplo(Uplo);
validateDiag(Diag);
validateTRMM(Element.F64_2(mRS), Side, TransA, A, B);
- mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
+ mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
}
static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
- int adim = -1, bX = -1, bY = -1;
+ int adim = -1, bM = -1, bN = -1;
validateSide(Side);
validateTranspose(TransA);
if (!A.getType().getElement().isCompatible(e) ||
@@ -1313,16 +1378,16 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
// for now we assume adapters are sufficient, will reevaluate in the future
throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A");
}
- bX = B.getType().getY();
- bY = B.getType().getX();
+ bM = B.getType().getY();
+ bN = B.getType().getX();
if (Side == LEFT) {
// A is M*M
- if (adim != bY) {
+ if (adim != bM) {
throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
}
} else {
// A is N*N
- if (adim != bX) {
+ if (adim != bN) {
throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
}
}
@@ -1338,21 +1403,21 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateDiag(Diag);
validateTRSM(Element.F64(mRS), Side, TransA, A, B);
- mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
+ mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
}
public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
validateUplo(Uplo);
validateDiag(Diag);
validateTRSM(Element.F32_2(mRS), Side, TransA, A, B);
- mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
+ mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
}
public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
validateUplo(Uplo);
validateDiag(Diag);
validateTRSM(Element.F64_2(mRS), Side, TransA, A, B);
- mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
+ mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
}
@@ -1379,17 +1444,17 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
throw new RSRuntimeException("Called HEMM with mismatched B and C");
}
}
- public void CHEMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, Allocation B, float beta, Allocation C) {
+ public void CHEMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
validateUplo(Uplo);
validateHEMM(Element.F32_2(mRS), Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
- alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
+ alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
}
- public void ZHEMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, Allocation B, double beta, Allocation C) {
+ public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
validateUplo(Uplo);
- validateHEMM(Element.F32_2(mRS), Side, A, B, C);
+ validateHEMM(Element.F64_2(mRS), Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
- alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
+ alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
}
static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) {
@@ -1403,11 +1468,11 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
throw new RSRuntimeException("Called HERK with non-square C");
}
if (Trans == NO_TRANSPOSE) {
- if (cdim != A.getType().getX()) {
+ if (cdim != A.getType().getY()) {
throw new RSRuntimeException("Called HERK with invalid A");
}
} else {
- if (cdim != A.getType().getY()) {
+ if (cdim != A.getType().getX()) {
throw new RSRuntimeException("Called HERK with invalid A");
}
}
@@ -1416,7 +1481,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateHERK(Element.F32_2(mRS), Trans, A, C);
int k = 0;
- if (Trans == TRANSPOSE) {
+ if (Trans == CONJ_TRANSPOSE) {
k = A.getType().getY();
} else {
k = A.getType().getX();
@@ -1428,7 +1493,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateHERK(Element.F64_2(mRS), Trans, A, C);
int k = 0;
- if (Trans == TRANSPOSE) {
+ if (Trans == CONJ_TRANSPOSE) {
k = A.getType().getY();
} else {
k = A.getType().getX();
diff --git a/rs/jni/android_renderscript_RenderScript.cpp b/rs/jni/android_renderscript_RenderScript.cpp
index cbe87fc..1833a1c 100644
--- a/rs/jni/android_renderscript_RenderScript.cpp
+++ b/rs/jni/android_renderscript_RenderScript.cpp
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#define LOG_TAG "libRS_jni"
+#define LOG_TAG "RenderScript_jni"
#include <stdlib.h>
#include <stdio.h>
@@ -328,79 +328,167 @@ nClosureCreate(JNIEnv *_env, jobject _this, jlong con, jlong kernelID,
jlong returnValue, jlongArray fieldIDArray,
jlongArray valueArray, jintArray sizeArray,
jlongArray depClosureArray, jlongArray depFieldIDArray) {
+ jlong ret = 0;
+
jlong* jFieldIDs = _env->GetLongArrayElements(fieldIDArray, nullptr);
jsize fieldIDs_length = _env->GetArrayLength(fieldIDArray);
- RsScriptFieldID* fieldIDs =
- (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * fieldIDs_length);
- for (int i = 0; i< fieldIDs_length; i++) {
+ jlong* jValues = _env->GetLongArrayElements(valueArray, nullptr);
+ jsize values_length = _env->GetArrayLength(valueArray);
+ jint* jSizes = _env->GetIntArrayElements(sizeArray, nullptr);
+ jsize sizes_length = _env->GetArrayLength(sizeArray);
+ jlong* jDepClosures =
+ _env->GetLongArrayElements(depClosureArray, nullptr);
+ jsize depClosures_length = _env->GetArrayLength(depClosureArray);
+ jlong* jDepFieldIDs =
+ _env->GetLongArrayElements(depFieldIDArray, nullptr);
+ jsize depFieldIDs_length = _env->GetArrayLength(depFieldIDArray);
+
+ size_t numValues, numDependencies;
+ RsScriptFieldID* fieldIDs;
+ uintptr_t* values;
+ RsClosure* depClosures;
+ RsScriptFieldID* depFieldIDs;
+
+ if (fieldIDs_length != values_length || values_length != sizes_length) {
+ ALOGE("Unmatched field IDs, values, and sizes in closure creation.");
+ goto exit;
+ }
+
+ numValues = (size_t)fieldIDs_length;
+
+ if (depClosures_length != depFieldIDs_length) {
+ ALOGE("Unmatched closures and field IDs for dependencies in closure creation.");
+ goto exit;
+ }
+
+ numDependencies = (size_t)depClosures_length;
+
+ if (numDependencies > numValues) {
+ ALOGE("Unexpected number of dependencies in closure creation");
+ goto exit;
+ }
+
+ if (numValues > RS_CLOSURE_MAX_NUMBER_ARGS_AND_BINDINGS) {
+ ALOGE("Too many arguments or globals in closure creation");
+ goto exit;
+ }
+
+ fieldIDs = (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * numValues);
+ if (fieldIDs == nullptr) {
+ goto exit;
+ }
+
+ for (size_t i = 0; i < numValues; i++) {
fieldIDs[i] = (RsScriptFieldID)jFieldIDs[i];
}
- jlong* jValues = _env->GetLongArrayElements(valueArray, nullptr);
- jsize values_length = _env->GetArrayLength(valueArray);
- uintptr_t* values = (uintptr_t*)alloca(sizeof(uintptr_t) * values_length);
- for (int i = 0; i < values_length; i++) {
+ values = (uintptr_t*)alloca(sizeof(uintptr_t) * numValues);
+ if (values == nullptr) {
+ goto exit;
+ }
+
+ for (size_t i = 0; i < numValues; i++) {
values[i] = (uintptr_t)jValues[i];
}
- jint* sizes = _env->GetIntArrayElements(sizeArray, nullptr);
- jsize sizes_length = _env->GetArrayLength(sizeArray);
+ depClosures = (RsClosure*)alloca(sizeof(RsClosure) * numDependencies);
+ if (depClosures == nullptr) {
+ goto exit;
+ }
- jlong* jDepClosures =
- _env->GetLongArrayElements(depClosureArray, nullptr);
- jsize depClosures_length = _env->GetArrayLength(depClosureArray);
- RsClosure* depClosures =
- (RsClosure*)alloca(sizeof(RsClosure) * depClosures_length);
- for (int i = 0; i < depClosures_length; i++) {
+ for (size_t i = 0; i < numDependencies; i++) {
depClosures[i] = (RsClosure)jDepClosures[i];
}
- jlong* jDepFieldIDs =
- _env->GetLongArrayElements(depFieldIDArray, nullptr);
- jsize depFieldIDs_length = _env->GetArrayLength(depFieldIDArray);
- RsScriptFieldID* depFieldIDs =
- (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * depFieldIDs_length);
- for (int i = 0; i < depClosures_length; i++) {
+ depFieldIDs = (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * numDependencies);
+ if (depFieldIDs == nullptr) {
+ goto exit;
+ }
+
+ for (size_t i = 0; i < numDependencies; i++) {
depFieldIDs[i] = (RsClosure)jDepFieldIDs[i];
}
- return (jlong)(uintptr_t)rsClosureCreate(
+ ret = (jlong)(uintptr_t)rsClosureCreate(
(RsContext)con, (RsScriptKernelID)kernelID, (RsAllocation)returnValue,
- fieldIDs, (size_t)fieldIDs_length, values, (size_t)values_length,
- (int*)sizes, (size_t)sizes_length,
- depClosures, (size_t)depClosures_length,
- depFieldIDs, (size_t)depFieldIDs_length);
+ fieldIDs, numValues, values, numValues,
+ (int*)jSizes, numValues,
+ depClosures, numDependencies,
+ depFieldIDs, numDependencies);
+
+exit:
+
+ _env->ReleaseLongArrayElements(depFieldIDArray, jDepFieldIDs, JNI_ABORT);
+ _env->ReleaseLongArrayElements(depClosureArray, jDepClosures, JNI_ABORT);
+ _env->ReleaseIntArrayElements (sizeArray, jSizes, JNI_ABORT);
+ _env->ReleaseLongArrayElements(valueArray, jValues, JNI_ABORT);
+ _env->ReleaseLongArrayElements(fieldIDArray, jFieldIDs, JNI_ABORT);
+
+ return ret;
}
static jlong
nInvokeClosureCreate(JNIEnv *_env, jobject _this, jlong con, jlong invokeID,
jbyteArray paramArray, jlongArray fieldIDArray, jlongArray valueArray,
jintArray sizeArray) {
+ jlong ret = 0;
+
jbyte* jParams = _env->GetByteArrayElements(paramArray, nullptr);
jsize jParamLength = _env->GetArrayLength(paramArray);
-
jlong* jFieldIDs = _env->GetLongArrayElements(fieldIDArray, nullptr);
jsize fieldIDs_length = _env->GetArrayLength(fieldIDArray);
- RsScriptFieldID* fieldIDs =
- (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * fieldIDs_length);
- for (int i = 0; i< fieldIDs_length; i++) {
+ jlong* jValues = _env->GetLongArrayElements(valueArray, nullptr);
+ jsize values_length = _env->GetArrayLength(valueArray);
+ jint* jSizes = _env->GetIntArrayElements(sizeArray, nullptr);
+ jsize sizes_length = _env->GetArrayLength(sizeArray);
+
+ size_t numValues;
+ RsScriptFieldID* fieldIDs;
+ uintptr_t* values;
+
+ if (fieldIDs_length != values_length || values_length != sizes_length) {
+ ALOGE("Unmatched field IDs, values, and sizes in closure creation.");
+ goto exit;
+ }
+
+ numValues = (size_t) fieldIDs_length;
+
+ if (numValues > RS_CLOSURE_MAX_NUMBER_ARGS_AND_BINDINGS) {
+ ALOGE("Too many arguments or globals in closure creation");
+ goto exit;
+ }
+
+ fieldIDs = (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * numValues);
+ if (fieldIDs == nullptr) {
+ goto exit;
+ }
+
+ for (size_t i = 0; i< numValues; i++) {
fieldIDs[i] = (RsScriptFieldID)jFieldIDs[i];
}
- jlong* jValues = _env->GetLongArrayElements(valueArray, nullptr);
- jsize values_length = _env->GetArrayLength(valueArray);
- uintptr_t* values = (uintptr_t*)alloca(sizeof(uintptr_t) * values_length);
- for (int i = 0; i < values_length; i++) {
- values[i] = (uintptr_t)jValues[i];
+ values = (uintptr_t*)alloca(sizeof(uintptr_t) * numValues);
+ if (values == nullptr) {
+ goto exit;
}
- jint* sizes = _env->GetIntArrayElements(sizeArray, nullptr);
- jsize sizes_length = _env->GetArrayLength(sizeArray);
+ for (size_t i = 0; i < numValues; i++) {
+ values[i] = (uintptr_t)jValues[i];
+ }
- return (jlong)(uintptr_t)rsInvokeClosureCreate(
+ ret = (jlong)(uintptr_t)rsInvokeClosureCreate(
(RsContext)con, (RsScriptInvokeID)invokeID, jParams, jParamLength,
- fieldIDs, (size_t)fieldIDs_length, values, (size_t)values_length,
- (int*)sizes, (size_t)sizes_length);
+ fieldIDs, numValues, values, numValues,
+ (int*)jSizes, numValues);
+
+exit:
+
+ _env->ReleaseIntArrayElements (sizeArray, jSizes, JNI_ABORT);
+ _env->ReleaseLongArrayElements(valueArray, jValues, JNI_ABORT);
+ _env->ReleaseLongArrayElements(fieldIDArray, jFieldIDs, JNI_ABORT);
+ _env->ReleaseByteArrayElements(paramArray, jParams, JNI_ABORT);
+
+ return ret;
}
static void
@@ -420,20 +508,40 @@ nClosureSetGlobal(JNIEnv *_env, jobject _this, jlong con, jlong closureID,
static long
nScriptGroup2Create(JNIEnv *_env, jobject _this, jlong con, jstring name,
jstring cacheDir, jlongArray closureArray) {
+ jlong ret = 0;
+
AutoJavaStringToUTF8 nameUTF(_env, name);
AutoJavaStringToUTF8 cacheDirUTF(_env, cacheDir);
jlong* jClosures = _env->GetLongArrayElements(closureArray, nullptr);
jsize numClosures = _env->GetArrayLength(closureArray);
- RsClosure* closures = (RsClosure*)alloca(sizeof(RsClosure) * numClosures);
+
+ RsClosure* closures;
+
+ if (numClosures > (jsize) RS_SCRIPT_GROUP_MAX_NUMBER_CLOSURES) {
+ ALOGE("Too many closures in script group");
+ goto exit;
+ }
+
+ closures = (RsClosure*)alloca(sizeof(RsClosure) * numClosures);
+ if (closures == nullptr) {
+ goto exit;
+ }
+
for (int i = 0; i < numClosures; i++) {
closures[i] = (RsClosure)jClosures[i];
}
- return (jlong)(uintptr_t)rsScriptGroup2Create(
+ ret = (jlong)(uintptr_t)rsScriptGroup2Create(
(RsContext)con, nameUTF.c_str(), nameUTF.length(),
cacheDirUTF.c_str(), cacheDirUTF.length(),
closures, numClosures);
+
+exit:
+
+ _env->ReleaseLongArrayElements(closureArray, jClosures, JNI_ABORT);
+
+ return ret;
}
static void
@@ -526,7 +634,7 @@ nScriptIntrinsicBLAS_Complex(JNIEnv *_env, jobject _this, jlong con, jlong id, j
call.alpha.c.r = alphaX;
call.alpha.c.i = alphaY;
call.beta.c.r = betaX;
- call.beta.c.r = betaY;
+ call.beta.c.i = betaY;
call.incX = incX;
call.incY = incY;
call.KL = KL;
@@ -561,7 +669,7 @@ nScriptIntrinsicBLAS_Z(JNIEnv *_env, jobject _this, jlong con, jlong id, jint fu
call.alpha.z.r = alphaX;
call.alpha.z.i = alphaY;
call.beta.z.r = betaX;
- call.beta.z.r = betaY;
+ call.beta.z.i = betaY;
call.incX = incX;
call.incY = incY;
call.KL = KL;
@@ -1102,9 +1210,8 @@ static jlong
nAllocationCreateFromBitmap(JNIEnv *_env, jobject _this, jlong con, jlong type, jint mip,
jobject jbitmap, jint usage)
{
- SkBitmap const * nativeBitmap =
- GraphicsJNI::getSkBitmap(_env, jbitmap);
- const SkBitmap& bitmap(*nativeBitmap);
+ SkBitmap bitmap;
+ GraphicsJNI::getSkBitmap(_env, jbitmap, &bitmap);
bitmap.lockPixels();
const void* ptr = bitmap.getPixels();
@@ -1119,9 +1226,8 @@ static jlong
nAllocationCreateBitmapBackedAllocation(JNIEnv *_env, jobject _this, jlong con, jlong type,
jint mip, jobject jbitmap, jint usage)
{
- SkBitmap const * nativeBitmap =
- GraphicsJNI::getSkBitmap(_env, jbitmap);
- const SkBitmap& bitmap(*nativeBitmap);
+ SkBitmap bitmap;
+ GraphicsJNI::getSkBitmap(_env, jbitmap, &bitmap);
bitmap.lockPixels();
const void* ptr = bitmap.getPixels();
@@ -1136,9 +1242,8 @@ static jlong
nAllocationCubeCreateFromBitmap(JNIEnv *_env, jobject _this, jlong con, jlong type, jint mip,
jobject jbitmap, jint usage)
{
- SkBitmap const * nativeBitmap =
- GraphicsJNI::getSkBitmap(_env, jbitmap);
- const SkBitmap& bitmap(*nativeBitmap);
+ SkBitmap bitmap;
+ GraphicsJNI::getSkBitmap(_env, jbitmap, &bitmap);
bitmap.lockPixels();
const void* ptr = bitmap.getPixels();
@@ -1152,9 +1257,8 @@ nAllocationCubeCreateFromBitmap(JNIEnv *_env, jobject _this, jlong con, jlong ty
static void
nAllocationCopyFromBitmap(JNIEnv *_env, jobject _this, jlong con, jlong alloc, jobject jbitmap)
{
- SkBitmap const * nativeBitmap =
- GraphicsJNI::getSkBitmap(_env, jbitmap);
- const SkBitmap& bitmap(*nativeBitmap);
+ SkBitmap bitmap;
+ GraphicsJNI::getSkBitmap(_env, jbitmap, &bitmap);
int w = bitmap.width();
int h = bitmap.height();
@@ -1169,9 +1273,8 @@ nAllocationCopyFromBitmap(JNIEnv *_env, jobject _this, jlong con, jlong alloc, j
static void
nAllocationCopyToBitmap(JNIEnv *_env, jobject _this, jlong con, jlong alloc, jobject jbitmap)
{
- SkBitmap const * nativeBitmap =
- GraphicsJNI::getSkBitmap(_env, jbitmap);
- const SkBitmap& bitmap(*nativeBitmap);
+ SkBitmap bitmap;
+ GraphicsJNI::getSkBitmap(_env, jbitmap, &bitmap);
bitmap.lockPixels();
void* ptr = bitmap.getPixels();
@@ -1751,7 +1854,7 @@ nScriptForEach(JNIEnv *_env, jobject _this, jlong con, jlong script, jint slot,
jintArray limits)
{
if (kLogApi) {
- ALOGD("nScriptForEach, con(%p), s(%p), slot(%i)", (RsContext)con, (void *)script, slot);
+ ALOGD("nScriptForEach, con(%p), s(%p), slot(%i) ains(%p) aout(%" PRId64 ")", (RsContext)con, (void *)script, slot, ains, aout);
}
jint in_len = 0;
@@ -1761,8 +1864,14 @@ nScriptForEach(JNIEnv *_env, jobject _this, jlong con, jlong script, jint slot,
if (ains != nullptr) {
in_len = _env->GetArrayLength(ains);
- in_ptr = _env->GetLongArrayElements(ains, nullptr);
+ if (in_len > (jint)RS_KERNEL_MAX_ARGUMENTS) {
+ ALOGE("Too many arguments in kernel launch.");
+ // TODO (b/20758983): Report back to Java and throw an exception
+ return;
+ }
+ // TODO (b/20760800): Check in_ptr is not null
+ in_ptr = _env->GetLongArrayElements(ains, nullptr);
if (sizeof(RsAllocation) == sizeof(jlong)) {
in_allocs = (RsAllocation*)in_ptr;
@@ -1770,6 +1879,11 @@ nScriptForEach(JNIEnv *_env, jobject _this, jlong con, jlong script, jint slot,
// Convert from 64-bit jlong types to the native pointer type.
in_allocs = (RsAllocation*)alloca(in_len * sizeof(RsAllocation));
+ if (in_allocs == nullptr) {
+ ALOGE("Failed launching kernel for lack of memory.");
+ _env->ReleaseLongArrayElements(ains, in_ptr, JNI_ABORT);
+ return;
+ }
for (int index = in_len; --index >= 0;) {
in_allocs[index] = (RsAllocation)in_ptr[index];