diff options
author | Tim Murray <timmurray@google.com> | 2015-01-12 16:47:56 -0800 |
---|---|---|
committer | Tim Murray <timmurray@google.com> | 2015-02-17 13:10:25 -0800 |
commit | 25207df658d6a8a3e885c7017fcc25702363583c (patch) | |
tree | 1f6c3911bdac4830524e1c19ff31111a6955bab3 /rs/java/android/renderscript/ScriptIntrinsicBLAS.java | |
parent | 627396ac90dd4f8510e931bffb6705f41878c87a (diff) | |
download | frameworks_base-25207df658d6a8a3e885c7017fcc25702363583c.zip frameworks_base-25207df658d6a8a3e885c7017fcc25702363583c.tar.gz frameworks_base-25207df658d6a8a3e885c7017fcc25702363583c.tar.bz2 |
Add BLAS intrinsic.
Change-Id: I95ddc46cb3f6217d2ead1091fd47450389544324
Diffstat (limited to 'rs/java/android/renderscript/ScriptIntrinsicBLAS.java')
-rw-r--r-- | rs/java/android/renderscript/ScriptIntrinsicBLAS.java | 1489 |
1 files changed, 1489 insertions, 0 deletions
diff --git a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java new file mode 100644 index 0000000..90d2300 --- /dev/null +++ b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java @@ -0,0 +1,1489 @@ +/* + * Copyright (C) 2015 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.renderscript; + +import android.annotation.IntDef; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +/** + * + * BLAS + * + * @hide + **/ +public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { + private Allocation mLUT; + + private ScriptIntrinsicBLAS(long id, RenderScript rs) { + super(id, rs); + } + + private static final int RsBlas_sdsdot = 1; + private static final int RsBlas_dsdot = 2; + private static final int RsBlas_sdot = 3; + private static final int RsBlas_ddot = 4; + private static final int RsBlas_cdotu_sub = 5; + private static final int RsBlas_cdotc_sub = 6; + private static final int RsBlas_zdotu_sub = 7; + private static final int RsBlas_zdotc_sub = 8; + private static final int RsBlas_snrm2 = 9; + private static final int RsBlas_sasum = 10; + private static final int RsBlas_dnrm2 = 11; + private static final int RsBlas_dasum = 12; + private static final int RsBlas_scnrm2 = 13; + private static final int RsBlas_scasum = 14; + private static final int RsBlas_dznrm2 = 15; + private static final int RsBlas_dzasum = 16; + private static final int RsBlas_isamax = 17; + private static final int RsBlas_idamax = 18; + private static final int RsBlas_icamax = 19; + private static final int RsBlas_izamax = 20; + private static final int RsBlas_sswap = 21; + private static final int RsBlas_scopy = 22; + private static final int RsBlas_saxpy = 23; + private static final int RsBlas_dswap = 24; + private static final int RsBlas_dcopy = 25; + private static final int RsBlas_daxpy = 26; + private static final int RsBlas_cswap = 27; + private static final int RsBlas_ccopy = 28; + private static final int RsBlas_caxpy = 29; + private static final int RsBlas_zswap = 30; + private static final int RsBlas_zcopy = 31; + private static final int RsBlas_zaxpy = 32; + private static final int RsBlas_srotg = 33; + private static final int RsBlas_srotmg = 34; + private static final int RsBlas_srot = 35; + private static final int RsBlas_srotm = 36; + private static final int RsBlas_drotg = 37; + private static final int RsBlas_drotmg = 38; + private static final int RsBlas_drot = 39; + private static final int RsBlas_drotm = 40; + private static final int RsBlas_sscal = 41; + private static final int RsBlas_dscal = 42; + private static final int RsBlas_cscal = 43; + private static final int RsBlas_zscal = 44; + private static final int RsBlas_csscal = 45; + private static final int RsBlas_zdscal = 46; + private static final int RsBlas_sgemv = 47; + private static final int RsBlas_sgbmv = 48; + private static final int RsBlas_strmv = 49; + private static final int RsBlas_stbmv = 50; + private static final int RsBlas_stpmv = 51; + private static final int RsBlas_strsv = 52; + private static final int RsBlas_stbsv = 53; + private static final int RsBlas_stpsv = 54; + private static final int RsBlas_dgemv = 55; + private static final int RsBlas_dgbmv = 56; + private static final int RsBlas_dtrmv = 57; + private static final int RsBlas_dtbmv = 58; + private static final int RsBlas_dtpmv = 59; + private static final int RsBlas_dtrsv = 60; + private static final int RsBlas_dtbsv = 61; + private static final int RsBlas_dtpsv = 62; + private static final int RsBlas_cgemv = 63; + private static final int RsBlas_cgbmv = 64; + private static final int RsBlas_ctrmv = 65; + private static final int RsBlas_ctbmv = 66; + private static final int RsBlas_ctpmv = 67; + private static final int RsBlas_ctrsv = 68; + private static final int RsBlas_ctbsv = 69; + private static final int RsBlas_ctpsv = 70; + private static final int RsBlas_zgemv = 71; + private static final int RsBlas_zgbmv = 72; + private static final int RsBlas_ztrmv = 73; + private static final int RsBlas_ztbmv = 74; + private static final int RsBlas_ztpmv = 75; + private static final int RsBlas_ztrsv = 76; + private static final int RsBlas_ztbsv = 77; + private static final int RsBlas_ztpsv = 78; + private static final int RsBlas_ssymv = 79; + private static final int RsBlas_ssbmv = 80; + private static final int RsBlas_sspmv = 81; + private static final int RsBlas_sger = 82; + private static final int RsBlas_ssyr = 83; + private static final int RsBlas_sspr = 84; + private static final int RsBlas_ssyr2 = 85; + private static final int RsBlas_sspr2 = 86; + private static final int RsBlas_dsymv = 87; + private static final int RsBlas_dsbmv = 88; + private static final int RsBlas_dspmv = 89; + private static final int RsBlas_dger = 90; + private static final int RsBlas_dsyr = 91; + private static final int RsBlas_dspr = 92; + private static final int RsBlas_dsyr2 = 93; + private static final int RsBlas_dspr2 = 94; + private static final int RsBlas_chemv = 95; + private static final int RsBlas_chbmv = 96; + private static final int RsBlas_chpmv = 97; + private static final int RsBlas_cgeru = 98; + private static final int RsBlas_cgerc = 99; + private static final int RsBlas_cher = 100; + private static final int RsBlas_chpr = 101; + private static final int RsBlas_cher2 = 102; + private static final int RsBlas_chpr2 = 103; + private static final int RsBlas_zhemv = 104; + private static final int RsBlas_zhbmv = 105; + private static final int RsBlas_zhpmv = 106; + private static final int RsBlas_zgeru = 107; + private static final int RsBlas_zgerc = 108; + private static final int RsBlas_zher = 109; + private static final int RsBlas_zhpr = 110; + private static final int RsBlas_zher2 = 111; + private static final int RsBlas_zhpr2 = 112; + private static final int RsBlas_sgemm = 113; + private static final int RsBlas_ssymm = 114; + private static final int RsBlas_ssyrk = 115; + private static final int RsBlas_ssyr2k = 116; + private static final int RsBlas_strmm = 117; + private static final int RsBlas_strsm = 118; + private static final int RsBlas_dgemm = 119; + private static final int RsBlas_dsymm = 120; + private static final int RsBlas_dsyrk = 121; + private static final int RsBlas_dsyr2k = 122; + private static final int RsBlas_dtrmm = 123; + private static final int RsBlas_dtrsm = 124; + private static final int RsBlas_cgemm = 125; + private static final int RsBlas_csymm = 126; + private static final int RsBlas_csyrk = 127; + private static final int RsBlas_csyr2k = 128; + private static final int RsBlas_ctrmm = 129; + private static final int RsBlas_ctrsm = 130; + private static final int RsBlas_zgemm = 131; + private static final int RsBlas_zsymm = 132; + private static final int RsBlas_zsyrk = 133; + private static final int RsBlas_zsyr2k = 134; + private static final int RsBlas_ztrmm = 135; + private static final int RsBlas_ztrsm = 136; + private static final int RsBlas_chemm = 137; + private static final int RsBlas_cherk = 138; + private static final int RsBlas_cher2k = 139; + private static final int RsBlas_zhemm = 140; + private static final int RsBlas_zherk = 141; + private static final int RsBlas_zher2k = 142; + + /** + */ + public static ScriptIntrinsicBLAS create(RenderScript rs) { + long id = rs.nScriptIntrinsicCreate(13, Element.U32(rs).getID(rs)); + return new ScriptIntrinsicBLAS(id, rs); + } + + @IntDef({NO_TRANSPOSE, TRANSPOSE, CONJ_TRANSPOSE}) + @Retention(RetentionPolicy.SOURCE) + public @interface Transpose {} + + @IntDef({UPPER, LOWER}) + @Retention(RetentionPolicy.SOURCE) + public @interface Uplo {} + + @IntDef({NON_UNIT, UNIT}) + @Retention(RetentionPolicy.SOURCE) + public @interface Diag {} + + @IntDef({LEFT, RIGHT}) + @Retention(RetentionPolicy.SOURCE) + public @interface Side {} + + public static final int NO_TRANSPOSE = 111; + public static final int TRANSPOSE = 112; + public static final int CONJ_TRANSPOSE = 113; + + public static final int UPPER = 121; + public static final int LOWER = 122; + + public static final int NON_UNIT = 131; + public static final int UNIT = 132; + + public static final int LEFT = 141; + public static final int RIGHT = 142; + + static void validateSide(@Side int Side) { + if (Side != LEFT && Side != RIGHT) { + throw new RSRuntimeException("Invalid side passed to BLAS"); + } + } + + static void validateTranspose(@Transpose int Trans) { + if (Trans != NO_TRANSPOSE && Trans != TRANSPOSE && + Trans != CONJ_TRANSPOSE) { + throw new RSRuntimeException("Invalid transpose passed to BLAS"); + } + } + + static void validateConjTranspose(@Transpose int Trans) { + if (Trans != NO_TRANSPOSE && + Trans != CONJ_TRANSPOSE) { + throw new RSRuntimeException("Invalid transpose passed to BLAS"); + } + } + + static void validateDiag(@Diag int Diag) { + if (Diag != NON_UNIT && Diag != UNIT) { + throw new RSRuntimeException("Invalid diag passed to BLAS"); + } + } + + static void validateUplo(@Uplo int Uplo) { + if (Uplo != LEFT && Uplo != RIGHT) { + throw new RSRuntimeException("Invalid uplo passed to BLAS"); + } + } + + + /** + * Level 2 BLAS + */ + + static void validateGEMV(Element e, int TransA, Allocation A, Allocation X, int incX, Allocation Y, int incY) { + validateTranspose(TransA); + int M = A.getType().getY(); + int N = A.getType().getX(); + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } + int expectedXDim = -1, expectedYDim = -1; + if (TransA == NO_TRANSPOSE) { + expectedXDim = 1 + (N - 1) * incX; + expectedYDim = 1 + (M - 1) * incY; + } else { + expectedXDim = 1 + (M - 1) * incX; + expectedYDim = 1 + (N - 1) * incY; + } + if (X.getType().getX() != expectedXDim || + Y.getType().getY() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for GEMV"); + } + } + void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { + validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { + validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + + void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { + // GBMV has the same validation requirements as GEMV + KL and KU >= 0 + validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY); + if (KL < 0 || KU < 0) { + throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); + } + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU); + } + void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { + // GBMV has the same validation requirements as GEMV + KL and KU >= 0 + validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY); + if (KL < 0 || KU < 0) { + throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); + } + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU); + } + void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + // GBMV has the same validation requirements as GEMV + KL and KU >= 0 + validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY); + if (KL < 0 || KU < 0) { + throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); + } + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); + } + void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + // GBMV has the same validation requirements as GEMV + KL and KU >= 0 + validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY); + if (KL < 0 || KU < 0) { + throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); + } + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); + } + + static void validateTRMV(Element e, @Transpose int TransA, Allocation A, Allocation X, int incX) { + validateTranspose(TransA); + int N = A.getType().getY(); + if (A.getType().getX() != N) { + throw new RSRuntimeException("A must be a square matrix for TRMV"); + } + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (incX <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for TRMV"); + } + } + + static int validateTPMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + validateTranspose(TransA); + validateUplo(Uplo); + validateDiag(Diag); + if (!Ap.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (Ap.getType().getY() > 1) { + throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); + } + + int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); + if (Ap.getType().getX() != ((N * (N+1)) / 2)) { + throw new RSRuntimeException("Invalid dimension for Ap"); + } + + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); + } + + return N; + } + + void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + validateTRMV(Element.F32(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + validateTRMV(Element.F64(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBMV has the same requirements as TRMV + validateTRMV(Element.F32(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBMV has the same requirements as TRMV + validateTRMV(Element.F64(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBMV has the same requirements as TRMV + validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBMV has the same requirements as TRMV + validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + // TRSV is the same as TRMV + validateTRMV(Element.F32(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + + } + void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + // TRSV is the same as TRMV + validateTRMV(Element.F64(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + + } + void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + // TRSV is the same as TRMV + validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + + } + void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + // TRSV is the same as TRMV + validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + + } + void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBSV is the same as TRMV + validateTRMV(Element.F32(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + if (K < 0) { + throw new RSRuntimeException("Number of diagonals must be positive"); + } + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBSV is the same as TRMV + validateTRMV(Element.F64(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + if (K < 0) { + throw new RSRuntimeException("Number of diagonals must be positive"); + } + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBSV is the same as TRMV + validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + if (K < 0) { + throw new RSRuntimeException("Number of diagonals must be positive"); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBSV is the same as TRMV + validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + if (K < 0) { + throw new RSRuntimeException("Number of diagonals must be positive"); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + // TPSV is same as TPMV + int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + // TPSV is same as TPMV + int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + // TPSV is same as TPMV + int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + // TPSV is same as TPMV + int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + + /** + * Level 2, S and D only + */ + static int validateSYMV(Element e, @Uplo int Uplo, Allocation A, Allocation X, Allocation Y, int incX, int incY) { + validateUplo(Uplo); + int N = A.getType().getY(); + if (A.getType().getX() != N) { + throw new RSRuntimeException("A must be a square matrix for SYMV"); + } + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e) ) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); + } + int expectedYDim = 1 + (N - 1) * incY; + if (Y.getType().getX() != expectedYDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); + } + return N; + } + static int validateSPMV(Element e, @Uplo int Uplo, Allocation Ap, Allocation X, int incX, Allocation Y, int incY) { + validateUplo(Uplo); + if (!Ap.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (Ap.getType().getY() > 1) { + throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); + } + + int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); + if (Ap.getType().getX() != ((N * (N+1)) / 2)) { + throw new RSRuntimeException("Invalid dimension for Ap"); + } + + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); + } + int expectedYDim = 1 + (N - 1) * incY; + if (Y.getType().getX() != expectedYDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); + } + + return N; + } + static void validateGER(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e) ) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + int M = A.getType().getY(); + int N = A.getType().getX(); + + if (N < 1 || M < 1) { + throw new RSRuntimeException("M and N must be 1 or greater for GER"); + } + + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for GER"); + } + int expectedYDim = 1 + (N - 1) * incY; + if (Y.getType().getX() != expectedYDim) { + throw new RSRuntimeException("Incorrect vector dimensions for GER"); + } + + + } + static int validateSYR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation A) { + validateUplo(Uplo); + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + + int N = A.getType().getX(); + + if (X.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + if (N != A.getType().getY()) { + throw new RSRuntimeException("A must be a symmetric matrix"); + } + + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SYR"); + } + return N; + } + static int validateSPR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Ap) { + validateUplo(Uplo); + if (!Ap.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (Ap.getType().getY() > 1) { + throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); + } + + int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); + if (Ap.getType().getX() != ((N * (N+1)) / 2)) { + throw new RSRuntimeException("Invalid dimension for Ap"); + } + + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); + } + + return N; + } + + static int validateSYR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + validateUplo(Uplo); + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + int N = A.getType().getX(); + + if (N != A.getType().getY()) { + throw new RSRuntimeException("A must be a symmetric matrix"); + } + + int expectedXDim = 1 + (N - 1) * incX; + int expectedYDim = 1 + (N - 1) * incY; + if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SYR"); + } + return N; + + } + static int validateSPR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + validateUplo(Uplo); + if (!Ap.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (Ap.getType().getY() > 1) { + throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); + } + + int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); + if (Ap.getType().getX() != ((N * (N+1)) / 2)) { + throw new RSRuntimeException("Invalid dimension for Ap"); + } + + int expectedXDim = 1 + (N - 1) * incX; + int expectedYDim = 1 + (N - 1) * incY; + if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); + } + + return N; + } + + void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { + int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { + // SBMV is the same as SYMV + int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) { + int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); + } + void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { + int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0); + } + void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { + int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0); + } + void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0); + } + void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0); + } + void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { + int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { + // SBMV is the same as SYMV + int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) { + int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); + } + void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { + int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0); + } + void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { + int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0); + } + void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0); + } + void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0); + } + + + /** + * Level 2, C and Z only + */ + + static void validateGERU(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + int M = A.getType().getY(); + int N = A.getType().getX(); + + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for GERU"); + } + int expectedYDim = 1 + (N - 1) * incY; + if (Y.getType().getX() != expectedYDim) { + throw new RSRuntimeException("Incorrect vector dimensions for GERU"); + } + + } + + void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + // HEMV is the same as SYR2 validation-wise + int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + // HBMV is the same as SYR2 validation-wise + int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); + if (K < 0) { + throw new RSRuntimeException("K must be 0 or greater for HBMV"); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + // HPMV is the same as SPR2 + int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); + } + void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + // same as GERU + validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); + } + void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { + // same as SYR + int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); + } + void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { + // equivalent to SPR for validation + int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0); + } + void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + // same as SYR2 + int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); + } + void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + // same as SPR2 + int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0); + } + void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + // HEMV is the same as SYR2 validation-wise + int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + // HBMV is the same as SYR2 validation-wise + int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); + if (K < 0) { + throw new RSRuntimeException("K must be 0 or greater for HBMV"); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + // HPMV is the same as SPR2 + int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); + } + void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + // same as GERU + validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); + } + void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { + // same as SYR + int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); + } + void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { + // equivalent to SPR for validation + int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0); + } + void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + // same as SYR2 + int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); + } + void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + // same as SPR2 + int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0); + } + + + /** + * Level 3 BLAS + */ + + static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { + int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1; + if ((A != null && !A.getType().getElement().isCompatible(e)) || + (B != null && !B.getType().getElement().isCompatible(e)) || + (C != null && !C.getType().getElement().isCompatible(e))) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (C != null) { + cX = C.getType().getY(); + cY = C.getType().getX(); + } + if (Side == RIGHT) { + if (B != null) { + bX = A.getType().getY(); + bY = A.getType().getX(); + } + if (A != null) { + aX = B.getType().getY(); + aY = B.getType().getX(); + } + } else { + if (A != null) { + if (TransA == TRANSPOSE) { + aY = A.getType().getY(); + aX = A.getType().getX(); + } else { + aX = A.getType().getY(); + aY = A.getType().getX(); + } + } + if (B != null) { + if (TransB == TRANSPOSE) { + bY = B.getType().getY(); + bX = B.getType().getX(); + } else { + bX = B.getType().getY(); + bY = B.getType().getX(); + } + } + } + if (A != null && B != null && C != null) { + if (aY != bX || aX != cX || bY != cY) { + throw new RSRuntimeException("Called BLAS with invalid dimensions"); + } + } else if (A != null && C != null) { + // A and C only + if (aX != cY || aY != cX) { + throw new RSRuntimeException("Called BLAS with invalid dimensions"); + } + } else if (A != null && B != null) { + // A and B only + } + + } + + public void SGEMM(@Transpose int TransA, @Transpose int TransB, float alpha, Allocation A, + Allocation B, float beta, Allocation C) { + validateTranspose(TransA); + validateTranspose(TransB); + validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); + + int M = -1, N = -1, K = -1; + if (TransA == TRANSPOSE) { + M = A.getType().getX(); + K = A.getType().getY(); + } else { + M = A.getType().getY(); + K = A.getType().getX(); + } + if (TransB == TRANSPOSE) { + N = B.getType().getY(); + } else { + N = B.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS), + beta, C.getID(mRS), 0, 0, 0, 0); + } + public void DGEMM(@Transpose int TransA, @Transpose int TransB, double alpha, Allocation A, + Allocation B, double beta, Allocation C) { + validateTranspose(TransA); + validateTranspose(TransB); + validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); + int M = -1, N = -1, K = -1; + if (TransA == TRANSPOSE) { + M = A.getType().getX(); + K = A.getType().getY(); + } else { + M = A.getType().getY(); + K = A.getType().getX(); + } + if (TransB == TRANSPOSE) { + N = B.getType().getY(); + } else { + N = B.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS), + beta, C.getID(mRS), 0, 0, 0, 0); + } + public void CGEMM(@Transpose int TransA, @Transpose int TransB, Float2 alpha, Allocation A, + Allocation B, Float2 beta, Allocation C) { + validateTranspose(TransA); + validateTranspose(TransB); + validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); + int M = -1, N = -1, K = -1; + if (TransA == TRANSPOSE) { + M = A.getType().getX(); + K = A.getType().getY(); + } else { + M = A.getType().getY(); + K = A.getType().getX(); + } + if (TransB == TRANSPOSE) { + N = B.getType().getY(); + } else { + N = B.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), + beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + } + + public void ZGEMM(@Transpose int TransA, @Transpose int TransB, Double2 alpha, Allocation A, + Allocation B, Double2 beta, Allocation C) { + validateTranspose(TransA); + validateTranspose(TransB); + validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); + int M = -1, N = -1, K = -1; + if (TransA == TRANSPOSE) { + M = A.getType().getX(); + K = A.getType().getY(); + } else { + M = A.getType().getY(); + K = A.getType().getX(); + } + if (TransB == TRANSPOSE) { + N = B.getType().getY(); + } else { + N = B.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), + beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + } + + public void SSYMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, + Allocation B, float beta, Allocation C) { + validateSide(Side); + validateUplo(Uplo); + validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), + beta, C.getID(mRS), 0, 0, 0, 0); + } + public void DSYMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, + Allocation B, double beta, Allocation C) { + validateSide(Side); + validateUplo(Uplo); + validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), + beta, C.getID(mRS), 0, 0, 0, 0); + } + public void CSYMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, + Allocation B, Float2 beta, Allocation C) { + validateSide(Side); + validateUplo(Uplo); + validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), + beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + } + public void ZSYMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, + Allocation B, Double2 beta, Allocation C) { + validateSide(Side); + validateUplo(Uplo); + validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), + beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + } + + public void SSYRK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) { + validateTranspose(Trans); + validateUplo(Uplo); + validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0); + } + + public void DSYRK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) { + validateTranspose(Trans); + validateUplo(Uplo); + validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0); + } + public void CSYRK(@Uplo int Uplo, @Transpose int Trans, float alphaX, float alphaY, Allocation A, float betaX, float betaY, Allocation C) { + validateTranspose(Trans); + validateUplo(Uplo); + validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY, + C.getID(mRS), 0, 0, 0, 0); + } + public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, double alphaX, double alphaY, Allocation A, double betaX, double betaY, Allocation C) { + validateTranspose(Trans); + validateUplo(Uplo); + validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY, + C.getID(mRS), 0, 0, 0, 0); + } + + static void validateSYR2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) { + validateTranspose(Trans); + if (!A.getType().getElement().isCompatible(e) || + !B.getType().getElement().isCompatible(e) || + !C.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + int Cdim = -1; + // A is n x k if no transpose, k x n if transpose + // C is n x n + if (Trans == TRANSPOSE) { + // check columns versus C + Cdim = A.getType().getX(); + } else { + // check rows versus C + Cdim = A.getType().getY(); + } + if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) { + throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); + } + // A dims == B dims + if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) { + throw new RSRuntimeException("Invalid A and B in SYR2K"); + } + } + public void SSYR2K(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, Allocation B, float beta, Allocation C) { + validateUplo(Uplo); + validateSYR2K(Element.F32(mRS), Trans, A, B, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); + } + public void DSYR2K(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, Allocation B, double beta, Allocation C) { + validateUplo(Uplo); + validateSYR2K(Element.F64(mRS), Trans, A, B, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); + } + public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) { + validateUplo(Uplo); + validateSYR2K(Element.F32_2(mRS), Trans, A, B, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + } + public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { + validateUplo(Uplo); + validateSYR2K(Element.F64_2(mRS), Trans, A, B, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + } + + static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { + validateSide(Side); + validateTranspose(TransA); + int aX = -1, aY = -1, bX = -1, bY = -1; + if (!A.getType().getElement().isCompatible(e) || + !B.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (TransA == TRANSPOSE) { + aY = A.getType().getY(); + aX = A.getType().getX(); + } else { + aY = A.getType().getX(); + aX = A.getType().getY(); + } + bX = B.getType().getY(); + bY = B.getType().getX(); + if (Side == LEFT) { + if (aX == 0 || aY != bX) { + throw new RSRuntimeException("Called TRMM with invalid matrices"); + } + } else { + if (bY != aX || aY == 0) { + throw new RSRuntimeException("Called TRMM with invalid matrices"); + } + } + } + public void STRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRMM(Element.F32(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0); + } + public void DTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRMM(Element.F64(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0); + } + public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRMM(Element.F32_2(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); + } + public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRMM(Element.F64_2(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); + } + + static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { + int adim = -1, bX = -1, bY = -1; + validateSide(Side); + validateTranspose(TransA); + if (!A.getType().getElement().isCompatible(e) || + !B.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + adim = A.getType().getX(); + if (adim != A.getType().getY()) { + // this may be unnecessary, the restriction could potentially be relaxed + // A needs to contain at least that symmetric matrix but could theoretically be larger + // for now we assume adapters are sufficient, will reevaluate in the future + throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); + } + bX = B.getType().getY(); + bY = B.getType().getX(); + if (Side == LEFT) { + // A is M*M + if (adim != bY) { + throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); + } + } else { + // A is N*N + if (adim != bX) { + throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); + } + } + } + public void STRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRSM(Element.F32(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); + } + public void DTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRSM(Element.F64(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); + } + public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRSM(Element.F32_2(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); + } + public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRSM(Element.F64_2(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); + } + + static void validateHEMM(Element e, @Side int Side, Allocation A, Allocation B, Allocation C) { + validateSide(Side); + + if (!A.getType().getElement().isCompatible(e) || + !B.getType().getElement().isCompatible(e) || + !C.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + + // A must be square; can potentially be relaxed similar to TRSM + int adim = A.getType().getX(); + if (adim != A.getType().getY()) { + throw new RSRuntimeException("Called HEMM with non-square A"); + } + if ((Side == LEFT && adim != B.getType().getY()) || + (Side == RIGHT && adim != B.getType().getX())) { + throw new RSRuntimeException("Called HEMM with invalid B"); + } + if (B.getType().getX() != C.getType().getX() || + B.getType().getY() != C.getType().getY()) { + throw new RSRuntimeException("Called HEMM with mismatched B and C"); + } + } + public void CHEMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, Allocation B, float beta, Allocation C) { + validateUplo(Uplo); + validateHEMM(Element.F32_2(mRS), Side, A, B, C); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, + alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); + } + public void ZHEMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, Allocation B, double beta, Allocation C) { + validateUplo(Uplo); + validateHEMM(Element.F32_2(mRS), Side, A, B, C); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, + alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); + } + + static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) { + if (!A.getType().getElement().isCompatible(e) || + !C.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + validateConjTranspose(Trans); + int cdim = C.getType().getX(); + if (cdim != C.getType().getY()) { + throw new RSRuntimeException("Called HERK with non-square C"); + } + if (Trans == NO_TRANSPOSE) { + if (cdim != A.getType().getX()) { + throw new RSRuntimeException("Called HERK with invalid A"); + } + } else { + if (cdim != A.getType().getY()) { + throw new RSRuntimeException("Called HERK with invalid A"); + } + } + } + public void CHERK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) { + validateUplo(Uplo); + validateHERK(Element.F32_2(mRS), Trans, A, C); + int k = 0; + if (Trans == TRANSPOSE) { + k = A.getType().getY(); + } else { + k = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, + alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0); + } + public void ZHERK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) { + validateUplo(Uplo); + validateHERK(Element.F64_2(mRS), Trans, A, C); + int k = 0; + if (Trans == TRANSPOSE) { + k = A.getType().getY(); + } else { + k = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, + alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0); + } + + static void validateHER2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) { + if (!A.getType().getElement().isCompatible(e) || + !B.getType().getElement().isCompatible(e) || + !C.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + validateConjTranspose(Trans); + int cdim = C.getType().getX(); + if (cdim != C.getType().getY()) { + throw new RSRuntimeException("Called HER2K with non-square C"); + } + if (Trans == NO_TRANSPOSE) { + if (A.getType().getY() != cdim) { + throw new RSRuntimeException("Called HER2K with invalid matrices"); + } + } else { + if (A.getType().getX() != cdim) { + throw new RSRuntimeException("Called HER2K with invalid matrices"); + } + } + if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) { + throw new RSRuntimeException("Called HER2K with invalid A and B matrices"); + } + } + public void CHER2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, float beta, Allocation C) { + validateUplo(Uplo); + validateHER2K(Element.F32_2(mRS), Trans, A, B, C); + int k = 0; + if (Trans == NO_TRANSPOSE) { + k = A.getType().getX(); + } else { + k = A.getType().getY(); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y, + A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); + } + public void ZHER2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, double beta, Allocation C) { + validateUplo(Uplo); + validateHER2K(Element.F64_2(mRS), Trans, A, B, C); + int k = 0; + if (Trans == NO_TRANSPOSE) { + k = A.getType().getX(); + } else { + k = A.getType().getY(); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y, + A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); + } + + + +} |