Skip to content

Commit b799f4a

Browse files
author
Katharine Hyatt
committed
Even more BLAS tests and fixes
1 parent eb928bc commit b799f4a

File tree

2 files changed

+80
-17
lines changed

2 files changed

+80
-17
lines changed

src/blas/wrappers.jl

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ end
639639
"- length(B)=$(length(B))\n" *
640640
"- length(C)=$(length(C))\n"))
641641
end
642+
m, k, n = (-1, -1, -1)
642643
for (i, (As, Bs, Cs)) in enumerate(zip(A, B, C))
643644
m, k = size(As, transA == 'N' ? 1 : 2), size(As, transA == 'N' ? 2 : 1)
644645
n, g = size(Bs, transB == 'N' ? 2 : 1), size(Bs, transB == 'N' ? 1 : 2)
@@ -653,7 +654,7 @@ end
653654
lda = max(1, stride(A[1], 2))
654655
ldb = max(1, stride(B[1], 2))
655656
ldc = max(1, stride(C[1], 2))
656-
m, k, n, lda, ldb, ldc
657+
return m, k, n, lda, ldb, ldc
657658
end
658659

659660
## (GE) general matrix-matrix multiplication batched
@@ -666,15 +667,17 @@ for (fname, elty) in
666667
@eval begin
667668
function gemm_batched!(
668669
transA::Char, transB::Char,
669-
alpha::($elty), A::ROCArray{$elty, 3},
670-
B::ROCArray{$elty, 3}, beta::($elty), C::ROCArray{$elty, 3},
671-
)
670+
alpha::($elty), A::TA,
671+
B::TB, beta::($elty), C::TC,
672+
) where {TA<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
673+
TB<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
674+
TC<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}}}
672675
m, k, n, lda, ldb, ldc = _check_gemm_batched_dims(
673676
transA, transB, A, B, C)
674677

675-
batch_count = size(C, 3)
676-
a_broadcast = (size(A, 3) == 1) && (batch_count > 1)
677-
b_broadcast = (size(B, 3) == 1) && (batch_count > 1)
678+
batch_count = C isa ROCArray ? size(C, 3) : length(C)
679+
a_broadcast = A isa ROCArray && (size(A, 3) == 1) && (batch_count > 1)
680+
b_broadcast = B isa ROCArray && (size(B, 3) == 1) && (batch_count > 1)
678681
Ab = a_broadcast ? device_batch(A, batch_count) : device_batch(A)
679682
Bb = b_broadcast ? device_batch(B, batch_count) : device_batch(B)
680683
Cb = device_batch(C)
@@ -684,18 +687,18 @@ for (fname, elty) in
684687
handle, transA, transB,
685688
m, n, k, Ref(alpha), Ab, lda, Bb, ldb, Ref(beta),
686689
Cb, ldc, batch_count)
687-
C
690+
return C
688691
end
689692
function gemm_batched(
690693
transA::Char, transB::Char, alpha::($elty), A::T, B::K,
691694
) where {
692-
T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
693-
K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
695+
T <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
696+
K <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
694697
}
695698
is_ab_vec = Int(T <: Vector) + Int(K <: Vector)
696699
(is_ab_vec != 0) && (is_ab_vec != 2) && throw(ArgumentError(
697700
"If `A` is a `Vector{ROCMatrix}`, then `B` must be too."))
698-
if T isa Vector
701+
if T <: Vector
699702
C = ROCMatrix{$elty}[similar(B[i], $elty, (
700703
size(A[i], transA == 'N' ? 1 : 2),
701704
size(B[i], transB == 'N' ? 2 : 1))) for i in 1:length(A)]
@@ -704,13 +707,15 @@ for (fname, elty) in
704707
k = size(B, transB == 'N' ? 2 : 1)
705708
C = similar(A, $elty, (m, k, max(size(A, 3), size(B, 3))))
706709
end
707-
gemm_batched!(transA, transB, alpha, A, B, zero($elty), C)
710+
m, k, n, lda, ldb, ldc = _check_gemm_batched_dims(
711+
transA, transB, A, B, C)
712+
return gemm_batched!(transA, transB, alpha, A, B, zero($elty), C)
708713
end
709714
function gemm_batched(transA::Char, transB::Char, A::T, B::K) where {
710-
T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
711-
K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
715+
T <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
716+
K <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
712717
}
713-
gemm_batched(transA, transB, one($elty), A, B)
718+
return gemm_batched(transA, transB, one($elty), A, B)
714719
end
715720
end
716721
end
@@ -1029,7 +1034,7 @@ for (fname, elty) in
10291034
@eval begin
10301035
function trsm_batched!(
10311036
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
1032-
A::Array{ROCMatrix{$elty},1}, B::Array{ROCMatrix{$elty},1},
1037+
A::Array{<:ROCMatrix{$elty},1}, B::Array{<:ROCMatrix{$elty},1},
10331038
)
10341039
if( length(A) != length(B) )
10351040
throw(DimensionMismatch(""))
@@ -1051,7 +1056,7 @@ for (fname, elty) in
10511056
end
10521057
function trsm_batched(
10531058
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
1054-
A::Array{ROCMatrix{$elty},1}, B::Array{ROCMatrix{$elty},1},
1059+
A::Array{<:ROCMatrix{$elty},1}, B::Array{<:ROCMatrix{$elty},1},
10551060
)
10561061
trsm_batched!(side, uplo, transa, diag, alpha, A, copy(B) )
10571062
end

test/rocarray/blas.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ end
5151
@test testf(norm, rand(T, m))
5252
@test testf(BLAS.asum, rand(T, m))
5353
@test testf(axpy!, Ref(rand()), rand(T, m), rand(T, m))
54+
@test testf(axpy!, Ref(rand()), rand(T, m), 1:m-1, rand(T, m), 1:m-1)
5455
@test testf(axpby!, Ref(rand()), rand(T, m), Ref(rand()), rand(T, m))
5556

5657
@test testf(rotate!, rand(T, m), rand(T, m), rand(real(T)), rand(real(T)))
@@ -147,6 +148,14 @@ end
147148
@test testf((a, b) -> f(TR(a)) * b, A, x)
148149
@test testf((a, b) -> lmul!(f(TR(a)), b), A, copy(x))
149150
end
151+
@testset "trmv" begin
152+
A = rand(T, m, m)
153+
dA = ROCArray(A)
154+
x = rand(T, m)
155+
dx = ROCArray(x)
156+
dy = rocBLAS.trmv('U', 'N', 'N', dA, dx)
157+
@test collect(dy) triu(A) * x
158+
end
150159

151160
A, x = rand(T, m, m), rand(T, m)
152161
@testset "Triangular ldiv" for TR in (
@@ -157,6 +166,14 @@ end
157166
@test testf((a, b) -> f(TR(a)) \ b, A, x)
158167
@test testf((a, b) -> ldiv!(f(TR(a)), b), A, copy(x))
159168
end
169+
@testset "trsv" begin
170+
A = rand(T, m, m)
171+
dA = ROCArray(A)
172+
x = rand(T, m)
173+
dx = ROCArray(x)
174+
dy = rocBLAS.trsv('U', 'N', 'N', dA, dx)
175+
@test collect(dy) triu(A) \ x
176+
end
160177

161178
x = rand(T, m, m)
162179
@testset "inv($TR)" for TR in (
@@ -372,6 +389,25 @@ end
372389
(a, b) -> b / adjtype(uplotype(a)),
373390
triu(rand(T, m, m)), rand(T, n, m))
374391
end
392+
@testset "trsm" begin
393+
A = rand(T, m, m)
394+
dA = ROCArray(A)
395+
B = rand(T, m, m)
396+
dB = ROCArray(B)
397+
dC = rocBLAS.trsm('L', 'U', 'N', 'N', one(T), dA, dB)
398+
@test collect(dC) triu(A) \ B
399+
end
400+
@testset "trsm_batched" begin
401+
batch_count = 3
402+
A = [rand(T, m, m) for ix in 1:batch_count]
403+
dA = [ROCArray(A_) for A_ in A]
404+
B = [rand(T, m, m) for ix in 1:batch_count]
405+
dB = [ROCArray(B_) for B_ in B]
406+
dC = rocBLAS.trsm_batched('L', 'U', 'N', 'N', one(T), dA, dB)
407+
for ix in 1:batch_count
408+
@test collect(dC[ix]) triu(A[ix]) \ B[ix]
409+
end
410+
end
375411

376412
@testset "triangular-dense mul ($T, $adjtype, $uplotype)" for adjtype in (
377413
identity, adjoint, transpose,
@@ -389,6 +425,14 @@ end
389425
(c, a, b) -> mul!(c, b, adjtype(uplotype(a))),
390426
zeros(T, n, m), A, rand(T, n, m))
391427
end
428+
@testset "trmm" begin
429+
A = rand(T, m, m)
430+
dA = ROCArray(A)
431+
B = rand(T, m, m)
432+
dB = ROCArray(B)
433+
dC = rocBLAS.trmm('L', 'U', 'N', 'N', one(T), dA, dB)
434+
@test collect(dC) triu(A) * B
435+
end
392436

393437
@testset "triangular-triangular mul" for (TRa, ta, TRb, tb) in (
394438
(UpperTriangular, identity, LowerTriangular, identity),
@@ -452,6 +496,20 @@ end
452496
(bt == 'T' ? transpose(B[:, :, i]) : B[:, :, i])
453497
@test C[:, :, i] c
454498
end
499+
A = [rand(T, 4, 4) for ix in 1:batch_count]
500+
B = [rand(T, 4, 4) for ix in 1:batch_count]
501+
RA = [ROCArray(A_) for A_ in A]
502+
RB = [ROCArray(B_) for B_ in B]
503+
504+
RC = rocBLAS.gemm_batched(at, bt, RA, RB)
505+
@test length(RC) == batch_count
506+
C = [Array(RC_) for RC_ in RC]
507+
for i in 1:batch_count
508+
c =
509+
(at == 'T' ? transpose(A[i]) : A[i]) *
510+
(bt == 'T' ? transpose(B[i]) : B[i])
511+
@test C[i] c
512+
end
455513
end
456514
end
457515
end

0 commit comments

Comments
 (0)