summaryrefslogtreecommitdiffstats
path: root/rs
diff options
context:
space:
mode:
Diffstat (limited to 'rs')
-rw-r--r--rs/java/android/renderscript/ScriptIntrinsicBLAS.java419
1 files changed, 245 insertions, 174 deletions
diff --git a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java
index 6cfdfee..a387aab 100644
--- a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java
+++ b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java
@@ -241,7 +241,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
static void validateUplo(@Uplo int Uplo) {
- if (Uplo != LEFT && Uplo != RIGHT) {
+ if (Uplo != UPPER && Uplo != LOWER) {
throw new RSRuntimeException("Invalid uplo passed to BLAS");
}
}
@@ -276,36 +276,36 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
expectedYDim = 1 + (N - 1) * incY;
}
if (X.getType().getX() != expectedXDim ||
- Y.getType().getY() != expectedXDim) {
+ Y.getType().getX() != expectedYDim) {
throw new RSRuntimeException("Incorrect vector dimensions for GEMV");
}
}
- void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
+ public void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
+ public void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
+ public void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
+ public void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
+ public void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
// GBMV has the same validation requirements as GEMV + KL and KU >= 0
validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
if (KL < 0 || KU < 0) {
@@ -315,7 +315,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
}
- void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
+ public void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
// GBMV has the same validation requirements as GEMV + KL and KU >= 0
validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
if (KL < 0 || KU < 0) {
@@ -325,7 +325,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
}
- void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
+ public void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
// GBMV has the same validation requirements as GEMV + KL and KU >= 0
validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
if (KL < 0 || KU < 0) {
@@ -335,7 +335,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
}
- void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
+ public void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
// GBMV has the same validation requirements as GEMV + KL and KU >= 0
validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
if (KL < 0 || KU < 0) {
@@ -346,8 +346,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
}
- static void validateTRMV(Element e, @Transpose int TransA, Allocation A, Allocation X, int incX) {
+ static void validateTRMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
validateTranspose(TransA);
+ validateUplo(Uplo);
+ validateDiag(Diag);
int N = A.getType().getY();
if (A.getType().getX() != N) {
throw new RSRuntimeException("A must be a square matrix for TRMV");
@@ -386,158 +388,174 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
+ //is it really doing anything?
if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
throw new RSRuntimeException("Invalid dimension for Ap");
}
-
+ if (incX <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
int expectedXDim = 1 + (N - 1) * incX;
if (X.getType().getX() != expectedXDim) {
- throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
+ throw new RSRuntimeException("Incorrect vector dimensions for TPMV");
}
return N;
}
- void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
- validateTRMV(Element.F32(mRS), TransA, A, X, incX);
+ public void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
- validateTRMV(Element.F64(mRS), TransA, A, X, incX);
+ public void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
- validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
+ public void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
- validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
+ public void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBMV has the same requirements as TRMV
- validateTRMV(Element.F32(mRS), TransA, A, X, incX);
+
+ public void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBMV has the same requirements as TRMV + K >= 0
+ if (K < 0) {
+ throw new RSRuntimeException("K must be greater than or equal to 0");
+ }
+ validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBMV has the same requirements as TRMV
- validateTRMV(Element.F64(mRS), TransA, A, X, incX);
+ public void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBMV has the same requirements as TRMV + K >= 0
+ if (K < 0) {
+ throw new RSRuntimeException("K must be greater than or equal to 0");
+ }
+ validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBMV has the same requirements as TRMV
- validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
+ public void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBMV has the same requirements as TRMV + K >= 0
+ if (K < 0) {
+ throw new RSRuntimeException("K must be greater than or equal to 0");
+ }
+ validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBMV has the same requirements as TRMV
- validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
+ public void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBMV has the same requirements as TRMV + K >= 0
+ if (K < 0) {
+ throw new RSRuntimeException("K must be greater than or equal to 0");
+ }
+ validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ public void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
// TRSV is the same as TRMV
- validateTRMV(Element.F32(mRS), TransA, A, X, incX);
+ validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ public void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
// TRSV is the same as TRMV
- validateTRMV(Element.F64(mRS), TransA, A, X, incX);
+ validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ public void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
// TRSV is the same as TRMV
- validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
+ validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
+ public void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
// TRSV is the same as TRMV
- validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
+ validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBSV is the same as TRMV
- validateTRMV(Element.F32(mRS), TransA, A, X, incX);
+ public void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBSV is the same as TRMV + K >= 0
+ validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
if (K < 0) {
throw new RSRuntimeException("Number of diagonals must be positive");
}
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBSV is the same as TRMV
- validateTRMV(Element.F64(mRS), TransA, A, X, incX);
+ public void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBSV is the same as TRMV + K >= 0
+ validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
if (K < 0) {
throw new RSRuntimeException("Number of diagonals must be positive");
}
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBSV is the same as TRMV
- validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
+ public void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBSV is the same as TRMV + K >= 0
+ validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
if (K < 0) {
throw new RSRuntimeException("Number of diagonals must be positive");
}
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
- // TBSV is the same as TRMV
- validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
+ public void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
+ // TBSV is the same as TRMV + K >= 0
+ validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
int N = A.getType().getY();
if (K < 0) {
throw new RSRuntimeException("Number of diagonals must be positive");
}
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
// TPSV is same as TPMV
int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
// TPSV is same as TPMV
int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
}
- void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
// TPSV is same as TPMV
int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
}
- void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
+ public void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
// TPSV is same as TPMV
int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
@@ -593,7 +611,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
throw new RSRuntimeException("Invalid dimension for Ap");
}
-
+ if (incX <= 0 || incY <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
int expectedXDim = 1 + (N - 1) * incX;
if (X.getType().getX() != expectedXDim) {
throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
@@ -622,8 +642,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
if (N < 1 || M < 1) {
throw new RSRuntimeException("M and N must be 1 or greater for GER");
}
-
- int expectedXDim = 1 + (N - 1) * incX;
+ if (incX <= 0 || incY <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
+ int expectedXDim = 1 + (M - 1) * incX;
if (X.getType().getX() != expectedXDim) {
throw new RSRuntimeException("Incorrect vector dimensions for GER");
}
@@ -649,7 +671,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
if (N != A.getType().getY()) {
throw new RSRuntimeException("A must be a symmetric matrix");
}
-
+ if (incX <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
int expectedXDim = 1 + (N - 1) * incX;
if (X.getType().getX() != expectedXDim) {
throw new RSRuntimeException("Incorrect vector dimensions for SYR");
@@ -674,10 +698,12 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
throw new RSRuntimeException("Invalid dimension for Ap");
}
-
+ if (incX <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
int expectedXDim = 1 + (N - 1) * incX;
if (X.getType().getX() != expectedXDim) {
- throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
+ throw new RSRuntimeException("Incorrect vector dimensions for SPR");
}
return N;
@@ -700,7 +726,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
if (N != A.getType().getY()) {
throw new RSRuntimeException("A must be a symmetric matrix");
}
-
+ if (incX <= 0 || incY <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
int expectedXDim = 1 + (N - 1) * incX;
int expectedYDim = 1 + (N - 1) * incY;
if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
@@ -728,81 +756,91 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
throw new RSRuntimeException("Invalid dimension for Ap");
}
-
+ if (incX <= 0 || incY <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
int expectedXDim = 1 + (N - 1) * incX;
int expectedYDim = 1 + (N - 1) * incY;
if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
- throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
+ throw new RSRuntimeException("Incorrect vector dimensions for SPR2");
}
return N;
}
- void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
+ public void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
- // SBMV is the same as SYMV
+ public void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
+ // SBMV is the same as SYMV + K >= 0
+ if (K < 0) {
+ throw new RSRuntimeException("K must be greater than or equal to 0");
+ }
int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) {
+ public void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) {
int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
int M = A.getType().getY();
int N = A.getType().getX();
+ validateGER(Element.F32(mRS), X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
}
- void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
+ public void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
}
- void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
+ public void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
}
- void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
}
- void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
+ public void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
}
- void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
+ public void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
- // SBMV is the same as SYMV
+ public void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
+ // SBMV is the same as SYMV + K >= 0
+ if (K < 0) {
+ throw new RSRuntimeException("K must be greater than or equal to 0");
+ }
int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) {
+ public void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) {
int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
}
- void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
int M = A.getType().getY();
int N = A.getType().getX();
+ validateGER(Element.F64(mRS), X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
}
- void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
+ public void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
}
- void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
+ public void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
}
- void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
}
- void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
+ public void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
}
@@ -824,8 +862,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
int M = A.getType().getY();
int N = A.getType().getX();
-
- int expectedXDim = 1 + (N - 1) * incX;
+ if (incX <= 0 || incY <= 0) {
+ throw new RSRuntimeException("Vector increments must be greater than 0");
+ }
+ int expectedXDim = 1 + (M - 1) * incX;
if (X.getType().getX() != expectedXDim) {
throw new RSRuntimeException("Incorrect vector dimensions for GERU");
}
@@ -836,12 +876,12 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
- void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
+ public void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
// HEMV is the same as SYR2 validation-wise
int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
+ public void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
// HBMV is the same as SYR2 validation-wise
int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
if (K < 0) {
@@ -849,50 +889,50 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
+ public void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
// HPMV is the same as SPR2
int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
}
- void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
// same as GERU
validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
}
- void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
+ public void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
// same as SYR
- int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
+ int N = validateSYR(Element.F32_2(mRS), Uplo, X, incX, A);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
}
- void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
+ public void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
// equivalent to SPR for validation
int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
}
- void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
// same as SYR2
int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
}
- void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
+ public void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
// same as SPR2
int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
}
- void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
+ public void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
// HEMV is the same as SYR2 validation-wise
int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
+ public void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
// HBMV is the same as SYR2 validation-wise
int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
if (K < 0) {
@@ -900,40 +940,40 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
+ public void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
// HPMV is the same as SPR2
int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
}
- void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
}
- void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
// same as GERU
validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
int M = A.getType().getY();
int N = A.getType().getX();
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
}
- void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
+ public void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
// same as SYR
- int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
+ int N = validateSYR(Element.F64_2(mRS), Uplo, X, incX, A);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
}
- void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
+ public void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
// equivalent to SPR for validation
int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
}
- void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
+ public void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
// same as SYR2
int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
}
- void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
+ public void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
// same as SPR2
int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
@@ -945,56 +985,74 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
*/
static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) {
- int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1;
+ int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
if ((A != null && !A.getType().getElement().isCompatible(e)) ||
(B != null && !B.getType().getElement().isCompatible(e)) ||
(C != null && !C.getType().getElement().isCompatible(e))) {
throw new RSRuntimeException("Called BLAS with wrong Element type");
}
- if (C != null) {
- cX = C.getType().getY();
- cY = C.getType().getX();
+ if (C == null) {
+ //since matrix C is used to store the result, it cannot be null.
+ throw new RSRuntimeException("Allocation C cannot be null");
}
+ cM = C.getType().getY();
+ cN = C.getType().getX();
+
if (Side == RIGHT) {
+ if ((A == null && B != null) || (A != null && B == null)) {
+ throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa");
+ }
if (B != null) {
- bX = A.getType().getY();
- bY = A.getType().getX();
+ bM = A.getType().getY();
+ bN = A.getType().getX();
}
if (A != null) {
- aX = B.getType().getY();
- aY = B.getType().getX();
+ aM = B.getType().getY();
+ aN = B.getType().getX();
}
} else {
if (A != null) {
- if (TransA == TRANSPOSE) {
- aY = A.getType().getY();
- aX = A.getType().getX();
+ if (TransA != NO_TRANSPOSE) {
+ aN = A.getType().getY();
+ aM = A.getType().getX();
} else {
- aX = A.getType().getY();
- aY = A.getType().getX();
+ aM = A.getType().getY();
+ aN = A.getType().getX();
}
}
if (B != null) {
- if (TransB == TRANSPOSE) {
- bY = B.getType().getY();
- bX = B.getType().getX();
+ if (TransB != NO_TRANSPOSE) {
+ bN = B.getType().getY();
+ bM = B.getType().getX();
} else {
- bX = B.getType().getY();
- bY = B.getType().getX();
+ bM = B.getType().getY();
+ bN = B.getType().getX();
}
}
}
if (A != null && B != null && C != null) {
- if (aY != bX || aX != cX || bY != cY) {
+ if (aN != bM || aM != cM || bN != cN) {
throw new RSRuntimeException("Called BLAS with invalid dimensions");
}
} else if (A != null && C != null) {
- // A and C only
- if (aX != cY || aY != cX) {
- throw new RSRuntimeException("Called BLAS with invalid dimensions");
+ // A and C only, for SYRK
+ if (cM != cN) {
+ throw new RSRuntimeException("Matrix C is not symmetric");
+ }
+ if (TransA != NO_TRANSPOSE) {
+ if (aN != cM) {
+ throw new RSRuntimeException("Called BLAS with invalid dimensions");
+ }
+ } else {
+ if (aM != cM) {
+ throw new RSRuntimeException("Called BLAS with invalid dimensions");
+ }
}
} else if (A != null && B != null) {
// A and B only
+ if (aN != bM) {
+ throw new RSRuntimeException("Called BLAS with invalid dimensions");
+ }
}
}
@@ -1006,14 +1064,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C);
int M = -1, N = -1, K = -1;
- if (TransA == TRANSPOSE) {
+ if (TransA != NO_TRANSPOSE) {
M = A.getType().getX();
K = A.getType().getY();
} else {
M = A.getType().getY();
K = A.getType().getX();
}
- if (TransB == TRANSPOSE) {
+ if (TransB != NO_TRANSPOSE) {
N = B.getType().getY();
} else {
N = B.getType().getX();
@@ -1027,14 +1085,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateTranspose(TransB);
validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C);
int M = -1, N = -1, K = -1;
- if (TransA == TRANSPOSE) {
+ if (TransA != NO_TRANSPOSE) {
M = A.getType().getX();
K = A.getType().getY();
} else {
M = A.getType().getY();
K = A.getType().getX();
}
- if (TransB == TRANSPOSE) {
+ if (TransB != NO_TRANSPOSE) {
N = B.getType().getY();
} else {
N = B.getType().getX();
@@ -1048,14 +1106,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateTranspose(TransB);
validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C);
int M = -1, N = -1, K = -1;
- if (TransA == TRANSPOSE) {
+ if (TransA != NO_TRANSPOSE) {
M = A.getType().getX();
K = A.getType().getY();
} else {
M = A.getType().getY();
K = A.getType().getX();
}
- if (TransB == TRANSPOSE) {
+ if (TransB != NO_TRANSPOSE) {
N = B.getType().getY();
} else {
N = B.getType().getX();
@@ -1070,14 +1128,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateTranspose(TransB);
validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C);
int M = -1, N = -1, K = -1;
- if (TransA == TRANSPOSE) {
+ if (TransA != NO_TRANSPOSE) {
M = A.getType().getX();
K = A.getType().getY();
} else {
M = A.getType().getY();
K = A.getType().getX();
}
- if (TransB == TRANSPOSE) {
+ if (TransB != NO_TRANSPOSE) {
N = B.getType().getY();
} else {
N = B.getType().getX();
@@ -1090,6 +1148,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
Allocation B, float beta, Allocation C) {
validateSide(Side);
validateUplo(Uplo);
+ //For SYMM, Matrix A should be symmetric
+ if (A.getType().getX() != A.getType().getY()) {
+ throw new RSRuntimeException("Matrix A is not symmetric");
+ }
validateL3(Element.F32(mRS), 0, 0, Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
beta, C.getID(mRS), 0, 0, 0, 0);
@@ -1098,6 +1160,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
Allocation B, double beta, Allocation C) {
validateSide(Side);
validateUplo(Uplo);
+ if (A.getType().getX() != A.getType().getY()) {
+ throw new RSRuntimeException("Matrix A is not symmetric");
+ }
validateL3(Element.F64(mRS), 0, 0, Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
beta, C.getID(mRS), 0, 0, 0, 0);
@@ -1106,6 +1171,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
Allocation B, Float2 beta, Allocation C) {
validateSide(Side);
validateUplo(Uplo);
+ if (A.getType().getX() != A.getType().getY()) {
+ throw new RSRuntimeException("Matrix A is not symmetric");
+ }
validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
@@ -1114,6 +1182,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
Allocation B, Double2 beta, Allocation C) {
validateSide(Side);
validateUplo(Uplo);
+ if (A.getType().getX() != A.getType().getY()) {
+ throw new RSRuntimeException("Matrix A is not symmetric");
+ }
validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
@@ -1124,7 +1195,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
@@ -1138,37 +1209,37 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
}
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
}
- public void CSYRK(@Uplo int Uplo, @Transpose int Trans, float alphaX, float alphaY, Allocation A, float betaX, float betaY, Allocation C) {
+ public void CSYRK(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Float2 beta, Allocation C) {
validateTranspose(Trans);
validateUplo(Uplo);
validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
}
- mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY,
+ mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y,
C.getID(mRS), 0, 0, 0, 0);
}
- public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, double alphaX, double alphaY, Allocation A, double betaX, double betaY, Allocation C) {
+ public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Double2 beta, Allocation C) {
validateTranspose(Trans);
validateUplo(Uplo);
validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
}
- mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY,
+ mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y,
C.getID(mRS), 0, 0, 0, 0);
}
@@ -1189,7 +1260,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
// check rows versus C
Cdim = A.getType().getY();
}
- if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) {
+ if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) {
throw new RSRuntimeException("Invalid symmetric matrix in SYR2K");
}
// A dims == B dims
@@ -1245,26 +1316,26 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
validateSide(Side);
validateTranspose(TransA);
- int aX = -1, aY = -1, bX = -1, bY = -1;
+ int aM = -1, aN = -1, bM = -1, bN = -1;
if (!A.getType().getElement().isCompatible(e) ||
!B.getType().getElement().isCompatible(e)) {
throw new RSRuntimeException("Called BLAS with wrong Element type");
}
- if (TransA == TRANSPOSE) {
- aY = A.getType().getY();
- aX = A.getType().getX();
- } else {
- aY = A.getType().getX();
- aX = A.getType().getY();
+
+ aM = A.getType().getY();
+ aN = A.getType().getX();
+ if (aM != aN) {
+ throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A");
}
- bX = B.getType().getY();
- bY = B.getType().getX();
+
+ bM = B.getType().getY();
+ bN = B.getType().getX();
if (Side == LEFT) {
- if (aX == 0 || aY != bX) {
+ if (aN != bM) {
throw new RSRuntimeException("Called TRMM with invalid matrices");
}
} else {
- if (bY != aX || aY == 0) {
+ if (bN != aM) {
throw new RSRuntimeException("Called TRMM with invalid matrices");
}
}
@@ -1299,7 +1370,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
- int adim = -1, bX = -1, bY = -1;
+ int adim = -1, bM = -1, bN = -1;
validateSide(Side);
validateTranspose(TransA);
if (!A.getType().getElement().isCompatible(e) ||
@@ -1313,16 +1384,16 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
// for now we assume adapters are sufficient, will reevaluate in the future
throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A");
}
- bX = B.getType().getY();
- bY = B.getType().getX();
+ bM = B.getType().getY();
+ bN = B.getType().getX();
if (Side == LEFT) {
// A is M*M
- if (adim != bY) {
+ if (adim != bM) {
throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
}
} else {
// A is N*N
- if (adim != bX) {
+ if (adim != bN) {
throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
}
}
@@ -1379,17 +1450,17 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
throw new RSRuntimeException("Called HEMM with mismatched B and C");
}
}
- public void CHEMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, Allocation B, float beta, Allocation C) {
+ public void CHEMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
validateUplo(Uplo);
validateHEMM(Element.F32_2(mRS), Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
- alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
+ alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
}
- public void ZHEMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, Allocation B, double beta, Allocation C) {
+ public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
validateUplo(Uplo);
- validateHEMM(Element.F32_2(mRS), Side, A, B, C);
+ validateHEMM(Element.F64_2(mRS), Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
- alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
+ alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
}
static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) {
@@ -1403,11 +1474,11 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
throw new RSRuntimeException("Called HERK with non-square C");
}
if (Trans == NO_TRANSPOSE) {
- if (cdim != A.getType().getX()) {
+ if (cdim != A.getType().getY()) {
throw new RSRuntimeException("Called HERK with invalid A");
}
} else {
- if (cdim != A.getType().getY()) {
+ if (cdim != A.getType().getX()) {
throw new RSRuntimeException("Called HERK with invalid A");
}
}
@@ -1416,7 +1487,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateHERK(Element.F32_2(mRS), Trans, A, C);
int k = 0;
- if (Trans == TRANSPOSE) {
+ if (Trans == CONJ_TRANSPOSE) {
k = A.getType().getY();
} else {
k = A.getType().getX();
@@ -1428,7 +1499,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateHERK(Element.F64_2(mRS), Trans, A, C);
int k = 0;
- if (Trans == TRANSPOSE) {
+ if (Trans == CONJ_TRANSPOSE) {
k = A.getType().getY();
} else {
k = A.getType().getX();