summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--rs/java/android/renderscript/ScriptIntrinsicBLAS.java145
1 files changed, 88 insertions, 57 deletions
diff --git a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java
index 94bde10..d503699 100644
--- a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java
+++ b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java
@@ -241,7 +241,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
static void validateUplo(@Uplo int Uplo) {
- if (Uplo != LEFT && Uplo != RIGHT) {
+ if (Uplo != UPPER && Uplo != LOWER) {
throw new RSRuntimeException("Invalid uplo passed to BLAS");
}
}
@@ -985,56 +985,74 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
*/
static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) {
- int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1;
+ int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
if ((A != null && !A.getType().getElement().isCompatible(e)) ||
(B != null && !B.getType().getElement().isCompatible(e)) ||
(C != null && !C.getType().getElement().isCompatible(e))) {
throw new RSRuntimeException("Called BLAS with wrong Element type");
}
- if (C != null) {
- cX = C.getType().getY();
- cY = C.getType().getX();
+ if (C == null) {
+ //since matrix C is used to store the result, it cannot be null.
+ throw new RSRuntimeException("Allocation C cannot be null");
}
+ cM = C.getType().getY();
+ cN = C.getType().getX();
+
if (Side == RIGHT) {
+ if ((A == null && B != null) || (A != null && B == null)) {
+ throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa");
+ }
if (B != null) {
- bX = A.getType().getY();
- bY = A.getType().getX();
+ bM = A.getType().getY();
+ bN = A.getType().getX();
}
if (A != null) {
- aX = B.getType().getY();
- aY = B.getType().getX();
+ aM = B.getType().getY();
+ aN = B.getType().getX();
}
} else {
if (A != null) {
- if (TransA == TRANSPOSE) {
- aY = A.getType().getY();
- aX = A.getType().getX();
+ if (TransA != NO_TRANSPOSE) {
+ aN = A.getType().getY();
+ aM = A.getType().getX();
} else {
- aX = A.getType().getY();
- aY = A.getType().getX();
+ aM = A.getType().getY();
+ aN = A.getType().getX();
}
}
if (B != null) {
- if (TransB == TRANSPOSE) {
- bY = B.getType().getY();
- bX = B.getType().getX();
+ if (TransB != NO_TRANSPOSE) {
+ bN = B.getType().getY();
+ bM = B.getType().getX();
} else {
- bX = B.getType().getY();
- bY = B.getType().getX();
+ bM = B.getType().getY();
+ bN = B.getType().getX();
}
}
}
if (A != null && B != null && C != null) {
- if (aY != bX || aX != cX || bY != cY) {
+ if (aN != bM || aM != cM || bN != cN) {
throw new RSRuntimeException("Called BLAS with invalid dimensions");
}
} else if (A != null && C != null) {
- // A and C only
- if (aX != cY || aY != cX) {
- throw new RSRuntimeException("Called BLAS with invalid dimensions");
+ // A and C only, for SYRK
+ if (cM != cN) {
+ throw new RSRuntimeException("Matrix C is not symmetric");
+ }
+ if (TransA != NO_TRANSPOSE) {
+ if (aN != cM) {
+ throw new RSRuntimeException("Called BLAS with invalid dimensions");
+ }
+ } else {
+ if (aM != cM) {
+ throw new RSRuntimeException("Called BLAS with invalid dimensions");
+ }
}
} else if (A != null && B != null) {
// A and B only
+ if (aN != bM) {
+ throw new RSRuntimeException("Called BLAS with invalid dimensions");
+ }
}
}
@@ -1046,14 +1064,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C);
int M = -1, N = -1, K = -1;
- if (TransA == TRANSPOSE) {
+ if (TransA != NO_TRANSPOSE) {
M = A.getType().getX();
K = A.getType().getY();
} else {
M = A.getType().getY();
K = A.getType().getX();
}
- if (TransB == TRANSPOSE) {
+ if (TransB != NO_TRANSPOSE) {
N = B.getType().getY();
} else {
N = B.getType().getX();
@@ -1067,14 +1085,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateTranspose(TransB);
validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C);
int M = -1, N = -1, K = -1;
- if (TransA == TRANSPOSE) {
+ if (TransA != NO_TRANSPOSE) {
M = A.getType().getX();
K = A.getType().getY();
} else {
M = A.getType().getY();
K = A.getType().getX();
}
- if (TransB == TRANSPOSE) {
+ if (TransB != NO_TRANSPOSE) {
N = B.getType().getY();
} else {
N = B.getType().getX();
@@ -1088,14 +1106,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateTranspose(TransB);
validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C);
int M = -1, N = -1, K = -1;
- if (TransA == TRANSPOSE) {
+ if (TransA != NO_TRANSPOSE) {
M = A.getType().getX();
K = A.getType().getY();
} else {
M = A.getType().getY();
K = A.getType().getX();
}
- if (TransB == TRANSPOSE) {
+ if (TransB != NO_TRANSPOSE) {
N = B.getType().getY();
} else {
N = B.getType().getX();
@@ -1110,14 +1128,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateTranspose(TransB);
validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C);
int M = -1, N = -1, K = -1;
- if (TransA == TRANSPOSE) {
+ if (TransA != NO_TRANSPOSE) {
M = A.getType().getX();
K = A.getType().getY();
} else {
M = A.getType().getY();
K = A.getType().getX();
}
- if (TransB == TRANSPOSE) {
+ if (TransB != NO_TRANSPOSE) {
N = B.getType().getY();
} else {
N = B.getType().getX();
@@ -1130,6 +1148,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
Allocation B, float beta, Allocation C) {
validateSide(Side);
validateUplo(Uplo);
+ //For SYMM, Matrix A should be symmetric
+ if (A.getType().getX() != A.getType().getY()) {
+ throw new RSRuntimeException("Matrix A is not symmetric");
+ }
validateL3(Element.F32(mRS), 0, 0, Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
beta, C.getID(mRS), 0, 0, 0, 0);
@@ -1138,6 +1160,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
Allocation B, double beta, Allocation C) {
validateSide(Side);
validateUplo(Uplo);
+ if (A.getType().getX() != A.getType().getY()) {
+ throw new RSRuntimeException("Matrix A is not symmetric");
+ }
validateL3(Element.F64(mRS), 0, 0, Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
beta, C.getID(mRS), 0, 0, 0, 0);
@@ -1146,6 +1171,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
Allocation B, Float2 beta, Allocation C) {
validateSide(Side);
validateUplo(Uplo);
+ if (A.getType().getX() != A.getType().getY()) {
+ throw new RSRuntimeException("Matrix A is not symmetric");
+ }
validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
@@ -1154,6 +1182,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
Allocation B, Double2 beta, Allocation C) {
validateSide(Side);
validateUplo(Uplo);
+ if (A.getType().getX() != A.getType().getY()) {
+ throw new RSRuntimeException("Matrix A is not symmetric");
+ }
validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
@@ -1164,7 +1195,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
@@ -1178,7 +1209,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
@@ -1190,7 +1221,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
@@ -1203,7 +1234,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C);
int K = -1;
- if (Trans == TRANSPOSE) {
+ if (Trans != NO_TRANSPOSE) {
K = A.getType().getY();
} else {
K = A.getType().getX();
@@ -1229,7 +1260,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
// check rows versus C
Cdim = A.getType().getY();
}
- if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) {
+ if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) {
throw new RSRuntimeException("Invalid symmetric matrix in SYR2K");
}
// A dims == B dims
@@ -1285,26 +1316,26 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
validateSide(Side);
validateTranspose(TransA);
- int aX = -1, aY = -1, bX = -1, bY = -1;
+ int aM = -1, aN = -1, bM = -1, bN = -1;
if (!A.getType().getElement().isCompatible(e) ||
!B.getType().getElement().isCompatible(e)) {
throw new RSRuntimeException("Called BLAS with wrong Element type");
}
- if (TransA == TRANSPOSE) {
- aY = A.getType().getY();
- aX = A.getType().getX();
- } else {
- aY = A.getType().getX();
- aX = A.getType().getY();
+
+ aM = A.getType().getY();
+ aN = A.getType().getX();
+ if (aM != aN) {
+ throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A");
}
- bX = B.getType().getY();
- bY = B.getType().getX();
+
+ bM = B.getType().getY();
+ bN = B.getType().getX();
if (Side == LEFT) {
- if (aX == 0 || aY != bX) {
+ if (aN != bM) {
throw new RSRuntimeException("Called TRMM with invalid matrices");
}
} else {
- if (bY != aX || aY == 0) {
+ if (bN != aM) {
throw new RSRuntimeException("Called TRMM with invalid matrices");
}
}
@@ -1339,7 +1370,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
- int adim = -1, bX = -1, bY = -1;
+ int adim = -1, bM = -1, bN = -1;
validateSide(Side);
validateTranspose(TransA);
if (!A.getType().getElement().isCompatible(e) ||
@@ -1353,16 +1384,16 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
// for now we assume adapters are sufficient, will reevaluate in the future
throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A");
}
- bX = B.getType().getY();
- bY = B.getType().getX();
+ bM = B.getType().getY();
+ bN = B.getType().getX();
if (Side == LEFT) {
// A is M*M
- if (adim != bY) {
+ if (adim != bM) {
throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
}
} else {
// A is N*N
- if (adim != bX) {
+ if (adim != bN) {
throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
}
}
@@ -1427,7 +1458,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
}
public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
validateUplo(Uplo);
- validateHEMM(Element.F32_2(mRS), Side, A, B, C);
+ validateHEMM(Element.F64_2(mRS), Side, A, B, C);
mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
}
@@ -1443,11 +1474,11 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
throw new RSRuntimeException("Called HERK with non-square C");
}
if (Trans == NO_TRANSPOSE) {
- if (cdim != A.getType().getX()) {
+ if (cdim != A.getType().getY()) {
throw new RSRuntimeException("Called HERK with invalid A");
}
} else {
- if (cdim != A.getType().getY()) {
+ if (cdim != A.getType().getX()) {
throw new RSRuntimeException("Called HERK with invalid A");
}
}
@@ -1456,7 +1487,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateHERK(Element.F32_2(mRS), Trans, A, C);
int k = 0;
- if (Trans == TRANSPOSE) {
+ if (Trans == CONJ_TRANSPOSE) {
k = A.getType().getY();
} else {
k = A.getType().getX();
@@ -1468,7 +1499,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
validateUplo(Uplo);
validateHERK(Element.F64_2(mRS), Trans, A, C);
int k = 0;
- if (Trans == TRANSPOSE) {
+ if (Trans == CONJ_TRANSPOSE) {
k = A.getType().getY();
} else {
k = A.getType().getX();