@@ -368,10 +368,10 @@ for (fname, elty) in ((:rocblas_stbmv,:Float32),
368368 x
369369 end
370370 function tbmv (
371- uplo:: Char , trans:: Char , diag:: Char ,
371+ uplo:: Char , trans:: Char , diag:: Char , k :: Integer ,
372372 A:: ROCMatrix{$elty} , x:: ROCVector{$elty} ,
373373 )
374- tbmv! (uplo, trans, diag, A, copy (x))
374+ tbmv! (uplo, trans, diag, k, A, copy (x))
375375 end
376376 end
377377end
@@ -496,10 +496,10 @@ for (fname, elty) in ((:rocblas_dsyr,:Float64),
496496end
497497
498498# ## her
499- for (fname, elty) in ((:rocblas_zher ,:ComplexF64 ),
500- (:rocblas_cher ,:ComplexF32 ))
499+ for (fname, elty, relty ) in ((:rocblas_zher ,:ComplexF64 , :Float64 ),
500+ (:rocblas_cher ,:ComplexF32 , :Float32 ))
501501 @eval begin
502- function her! (uplo:: Char , alpha:: $elty , x:: ROCVector{$elty} , A:: ROCMatrix{$elty} )
502+ function her! (uplo:: Char , alpha:: $relty , x:: ROCVector{$elty} , A:: ROCMatrix{$elty} )
503503 m, n = size (A)
504504 m == n || throw (DimensionMismatch (" Matrix A is $m by $n but must be square" ))
505505 length (x) == n || throw (DimensionMismatch (" Length of vector must be the same as the matrix dimensions" ))
639639 " - length(B)=$(length (B)) \n " *
640640 " - length(C)=$(length (C)) \n " ))
641641 end
642+ m, k, n = (- 1 , - 1 , - 1 )
642643 for (i, (As, Bs, Cs)) in enumerate (zip (A, B, C))
643644 m, k = size (As, transA == ' N' ? 1 : 2 ), size (As, transA == ' N' ? 2 : 1 )
644645 n, g = size (Bs, transB == ' N' ? 2 : 1 ), size (Bs, transB == ' N' ? 1 : 2 )
653654 lda = max (1 , stride (A[1 ], 2 ))
654655 ldb = max (1 , stride (B[1 ], 2 ))
655656 ldc = max (1 , stride (C[1 ], 2 ))
656- m, k, n, lda, ldb, ldc
657+ return m, k, n, lda, ldb, ldc
657658end
658659
659660# # (GE) general matrix-matrix multiplication batched
@@ -666,15 +667,17 @@ for (fname, elty) in
666667 @eval begin
667668 function gemm_batched! (
668669 transA:: Char , transB:: Char ,
669- alpha:: ($elty) , A:: ROCArray{$elty, 3} ,
670- B:: ROCArray{$elty, 3} , beta:: ($elty) , C:: ROCArray{$elty, 3} ,
671- )
670+ alpha:: ($elty) , A:: TA ,
671+ B:: TB , beta:: ($elty) , C:: TC ,
672+ ) where {TA<: Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}} ,
673+ TB<: Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}} ,
674+ TC<: Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}} }
672675 m, k, n, lda, ldb, ldc = _check_gemm_batched_dims (
673676 transA, transB, A, B, C)
674677
675- batch_count = size (C, 3 )
676- a_broadcast = (size (A, 3 ) == 1 ) && (batch_count > 1 )
677- b_broadcast = (size (B, 3 ) == 1 ) && (batch_count > 1 )
678+ batch_count = C isa ROCArray ? size (C, 3 ) : length (C )
679+ a_broadcast = A isa ROCArray && (size (A, 3 ) == 1 ) && (batch_count > 1 )
680+ b_broadcast = B isa ROCArray && (size (B, 3 ) == 1 ) && (batch_count > 1 )
678681 Ab = a_broadcast ? device_batch (A, batch_count) : device_batch (A)
679682 Bb = b_broadcast ? device_batch (B, batch_count) : device_batch (B)
680683 Cb = device_batch (C)
@@ -684,18 +687,18 @@ for (fname, elty) in
684687 handle, transA, transB,
685688 m, n, k, Ref (alpha), Ab, lda, Bb, ldb, Ref (beta),
686689 Cb, ldc, batch_count)
687- C
690+ return C
688691 end
689692 function gemm_batched (
690693 transA:: Char , transB:: Char , alpha:: ($elty) , A:: T , B:: K ,
691694 ) where {
692- T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}} ,
693- K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}} ,
695+ T <: Union{AnyROCArray{$elty, 3}, Vector{<: ROCMatrix{$elty}}} ,
696+ K <: Union{AnyROCArray{$elty, 3}, Vector{<: ROCMatrix{$elty}}} ,
694697 }
695698 is_ab_vec = Int (T <: Vector ) + Int (K <: Vector )
696699 (is_ab_vec != 0 ) && (is_ab_vec != 2 ) && throw (ArgumentError (
697700 " If `A` is a `Vector{ROCMatrix}`, then `B` must be too." ))
698- if T isa Vector
701+ if T <: Vector
699702 C = ROCMatrix{$ elty}[similar (B[i], $ elty, (
700703 size (A[i], transA == ' N' ? 1 : 2 ),
701704 size (B[i], transB == ' N' ? 2 : 1 ))) for i in 1 : length (A)]
@@ -704,13 +707,15 @@ for (fname, elty) in
704707 k = size (B, transB == ' N' ? 2 : 1 )
705708 C = similar (A, $ elty, (m, k, max (size (A, 3 ), size (B, 3 ))))
706709 end
707- gemm_batched! (transA, transB, alpha, A, B, zero ($ elty), C)
710+ m, k, n, lda, ldb, ldc = _check_gemm_batched_dims (
711+ transA, transB, A, B, C)
712+ return gemm_batched! (transA, transB, alpha, A, B, zero ($ elty), C)
708713 end
709714 function gemm_batched (transA:: Char , transB:: Char , A:: T , B:: K ) where {
710- T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}} ,
711- K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}} ,
715+ T <: Union{AnyROCArray{$elty, 3}, Vector{<: ROCMatrix{$elty}}} ,
716+ K <: Union{AnyROCArray{$elty, 3}, Vector{<: ROCMatrix{$elty}}} ,
712717 }
713- gemm_batched (transA, transB, one ($ elty), A, B)
718+ return gemm_batched (transA, transB, one ($ elty), A, B)
714719 end
715720 end
716721end
@@ -863,12 +868,12 @@ for (fname, elty) in ((:rocblas_zhemm,:ComplexF64),
863868end
864869
865870# # herk
866- for (fname, elty) in ((:rocblas_zherk ,:ComplexF64 ),
867- (:rocblas_cherk ,:ComplexF32 ))
871+ for (fname, elty, relty ) in ((:rocblas_zherk ,:ComplexF64 , :Float64 ),
872+ (:rocblas_cherk ,:ComplexF32 , :Float32 ))
868873 @eval begin
869874 function herk! (
870- uplo:: Char , trans:: Char , alpha:: ($elty ) , A:: ROCVecOrMat{$elty} ,
871- beta:: ($elty ) , C:: ROCMatrix{$elty} ,
875+ uplo:: Char , trans:: Char , alpha:: ($relty ) , A:: ROCVecOrMat{$elty} ,
876+ beta:: ($relty ) , C:: ROCMatrix{$elty} ,
872877 )
873878 mC, n = size (C)
874879 if mC != n throw (DimensionMismatch (" C must be square" )) end
@@ -881,12 +886,12 @@ for (fname, elty) in ((:rocblas_zherk,:ComplexF64),
881886 $ (fname)(handle, uplo, trans, n, k, Ref (alpha), A, lda, Ref (beta), C, ldc)
882887 C
883888 end
884- function herk (uplo:: Char , trans:: Char , alpha:: ($elty ) , A:: ROCVecOrMat{$elty} )
889+ function herk (uplo:: Char , trans:: Char , alpha:: ($relty ) , A:: ROCVecOrMat{$elty} )
885890 n = size (A, trans == ' N' ? 1 : 2 )
886- herk! (uplo, trans, alpha, A, zero ($ elty ), similar (A, $ elty, (n,n)))
891+ herk! (uplo, trans, alpha, A, zero ($ relty ), similar (A, $ elty, (n,n)))
887892 end
888893 herk (uplo:: Char , trans:: Char , A:: ROCVecOrMat{$elty} ) =
889- herk (uplo, trans, one ($ elty ), A)
894+ herk (uplo, trans, one ($ relty ), A)
890895 end
891896end
892897
@@ -1029,7 +1034,7 @@ for (fname, elty) in
10291034 @eval begin
10301035 function trsm_batched! (
10311036 side:: Char , uplo:: Char , transa:: Char , diag:: Char , alpha:: ($elty) ,
1032- A:: Array{ROCMatrix{$elty},1} , B:: Array{ROCMatrix{$elty},1} ,
1037+ A:: Array{<: ROCMatrix{$elty},1} , B:: Array{<: ROCMatrix{$elty},1} ,
10331038 )
10341039 if ( length (A) != length (B) )
10351040 throw (DimensionMismatch (" " ))
@@ -1051,7 +1056,7 @@ for (fname, elty) in
10511056 end
10521057 function trsm_batched (
10531058 side:: Char , uplo:: Char , transa:: Char , diag:: Char , alpha:: ($elty) ,
1054- A:: Array{ROCMatrix{$elty},1} , B:: Array{ROCMatrix{$elty},1} ,
1059+ A:: Array{<: ROCMatrix{$elty},1} , B:: Array{<: ROCMatrix{$elty},1} ,
10551060 )
10561061 trsm_batched! (side, uplo, transa, diag, alpha, A, copy (B) )
10571062 end
@@ -1092,13 +1097,13 @@ for (fname, elty) in ((:rocblas_dgeam,:Float64),
10921097 )
10931098 m,n = size (B)
10941099 if ((transb == ' T' || transb == ' C' ))
1095- geam! ( transa, transb, alpha, A, beta, B, similar (B, $ elty, (n,m) ) )
1100+ return geam! ( transa, transb, alpha, A, beta, B, similar (B, $ elty, (n,m) ) )
10961101 end
10971102 if (transb == ' N' )
1098- geam! ( transa, transb, alpha, A, beta, B, similar (B, $ elty, (m,n) ) )
1103+ return geam! ( transa, transb, alpha, A, beta, B, similar (B, $ elty, (m,n) ) )
10991104 end
11001105 end
1101- geam ( uplo :: Char , trans :: Char , A:: ROCMatrix{$elty} , B:: ROCMatrix{$elty} ) = geam ( uplo, trans , one ($ elty), A, one ($ elty), B)
1106+ geam ( transa :: Char , transb :: Char , A:: ROCMatrix{$elty} , B:: ROCMatrix{$elty} ) = geam ( transa, transb , one ($ elty), A, one ($ elty), B)
11021107 end
11031108end
11041109
0 commit comments