639639 " - length(B)=$(length (B)) \n " *
640640 " - length(C)=$(length (C)) \n " ))
641641 end
642+ m, k, n = (- 1 , - 1 , - 1 )
642643 for (i, (As, Bs, Cs)) in enumerate (zip (A, B, C))
643644 m, k = size (As, transA == ' N' ? 1 : 2 ), size (As, transA == ' N' ? 2 : 1 )
644645 n, g = size (Bs, transB == ' N' ? 2 : 1 ), size (Bs, transB == ' N' ? 1 : 2 )
653654 lda = max (1 , stride (A[1 ], 2 ))
654655 ldb = max (1 , stride (B[1 ], 2 ))
655656 ldc = max (1 , stride (C[1 ], 2 ))
656- m, k, n, lda, ldb, ldc
657+ return m, k, n, lda, ldb, ldc
657658end
658659
## (GE) general matrix-matrix multiplication batched
@@ -666,15 +667,17 @@ for (fname, elty) in
666667 @eval begin
667668 function gemm_batched! (
668669 transA:: Char , transB:: Char ,
669- alpha:: ($elty) , A:: ROCArray{$elty, 3} ,
670- B:: ROCArray{$elty, 3} , beta:: ($elty) , C:: ROCArray{$elty, 3} ,
671- )
670+ alpha:: ($elty) , A:: TA ,
671+ B:: TB , beta:: ($elty) , C:: TC ,
672+ ) where {TA<: Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}} ,
673+ TB<: Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}} ,
674+ TC<: Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}} }
672675 m, k, n, lda, ldb, ldc = _check_gemm_batched_dims (
673676 transA, transB, A, B, C)
674677
675- batch_count = size (C, 3 )
676- a_broadcast = (size (A, 3 ) == 1 ) && (batch_count > 1 )
677- b_broadcast = (size (B, 3 ) == 1 ) && (batch_count > 1 )
678+ batch_count = C isa ROCArray ? size (C, 3 ) : length (C )
679+ a_broadcast = A isa ROCArray && (size (A, 3 ) == 1 ) && (batch_count > 1 )
680+ b_broadcast = B isa ROCArray && (size (B, 3 ) == 1 ) && (batch_count > 1 )
678681 Ab = a_broadcast ? device_batch (A, batch_count) : device_batch (A)
679682 Bb = b_broadcast ? device_batch (B, batch_count) : device_batch (B)
680683 Cb = device_batch (C)
@@ -684,18 +687,18 @@ for (fname, elty) in
684687 handle, transA, transB,
685688 m, n, k, Ref (alpha), Ab, lda, Bb, ldb, Ref (beta),
686689 Cb, ldc, batch_count)
687- C
690+ return C
688691 end
689692 function gemm_batched (
690693 transA:: Char , transB:: Char , alpha:: ($elty) , A:: T , B:: K ,
691694 ) where {
692- T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}} ,
693- K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}} ,
695+ T <: Union{AnyROCArray{$elty, 3}, Vector{<: ROCMatrix{$elty}}} ,
696+ K <: Union{AnyROCArray{$elty, 3}, Vector{<: ROCMatrix{$elty}}} ,
694697 }
695698 is_ab_vec = Int (T <: Vector ) + Int (K <: Vector )
696699 (is_ab_vec != 0 ) && (is_ab_vec != 2 ) && throw (ArgumentError (
697700 " If `A` is a `Vector{ROCMatrix}`, then `B` must be too." ))
698- if T isa Vector
701+ if T <: Vector
699702 C = ROCMatrix{$ elty}[similar (B[i], $ elty, (
700703 size (A[i], transA == ' N' ? 1 : 2 ),
701704 size (B[i], transB == ' N' ? 2 : 1 ))) for i in 1 : length (A)]
@@ -704,13 +707,15 @@ for (fname, elty) in
704707 k = size (B, transB == ' N' ? 2 : 1 )
705708 C = similar (A, $ elty, (m, k, max (size (A, 3 ), size (B, 3 ))))
706709 end
707- gemm_batched! (transA, transB, alpha, A, B, zero ($ elty), C)
710+ m, k, n, lda, ldb, ldc = _check_gemm_batched_dims (
711+ transA, transB, A, B, C)
712+ return gemm_batched! (transA, transB, alpha, A, B, zero ($ elty), C)
708713 end
# Convenience method for batched GEMM without an explicit scaling factor.
# `A` and `B` may each be either a 3-d ROC array or a vector of ROC matrices
# with element type `$elty`. The call is forwarded to the `gemm_batched`
# method that takes `alpha`, passing the multiplicative identity of `$elty`.
function gemm_batched(transA::Char, transB::Char, A::T, B::K) where {
    T <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
    K <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
}
    unit_alpha = one($elty)
    return gemm_batched(transA, transB, unit_alpha, A, B)
end
715720 end
716721end
@@ -1029,7 +1034,7 @@ for (fname, elty) in
10291034 @eval begin
10301035 function trsm_batched! (
10311036 side:: Char , uplo:: Char , transa:: Char , diag:: Char , alpha:: ($elty) ,
1032- A:: Array{ROCMatrix{$elty},1} , B:: Array{ROCMatrix{$elty},1} ,
1037+ A:: Array{<: ROCMatrix{$elty},1} , B:: Array{<: ROCMatrix{$elty},1} ,
10331038 )
10341039 if ( length (A) != length (B) )
10351040 throw (DimensionMismatch (" " ))
@@ -1051,7 +1056,7 @@ for (fname, elty) in
10511056 end
# Non-mutating counterpart of `trsm_batched!`.
#
# Arguments are passed through to `trsm_batched!` unchanged, except that `B`
# is replaced by a freshly allocated batch so the caller's matrices are left
# untouched. Returns whatever `trsm_batched!` returns for that new batch.
function trsm_batched(
    side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
    A::Array{<:ROCMatrix{$elty},1}, B::Array{<:ROCMatrix{$elty},1},
)
    # `copy(B)` would duplicate only the outer vector: its elements would
    # still be the caller's matrices, which `trsm_batched!` is expected to
    # overwrite in place (NOTE(review): its body is not fully visible here;
    # confirm). Copy every matrix instead. The typed comprehension keeps the
    # element type of `B`, so an empty batch still dispatches correctly.
    B_out = eltype(B)[copy(Bi) for Bi in B]
    return trsm_batched!(side, uplo, transa, diag, alpha, A, B_out)
end
0 commit comments