Skip to content

Commit 1306be5

Browse files
Merge branch 'JuliaGPU:master' into syrk_strided
2 parents 07ba9be + 15824e6 commit 1306be5

File tree

11 files changed

+512
-45
lines changed

11 files changed

+512
-45
lines changed

.buildkite/pipeline.yml

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ steps:
7575
- label: "Julia 1.12"
7676
plugins:
7777
- JuliaCI/julia#v1:
78-
version: "1.12-nightly"
78+
version: "1.12"
7979
- JuliaCI/julia-test#v1:
8080
- JuliaCI/julia-coverage#v1:
8181
codecov: true
@@ -111,7 +111,25 @@ steps:
111111
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
112112
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
113113
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
114-
soft_fail: true
114+
115+
- label: "Julia 1.11 Enzyme"
116+
plugins:
117+
- JuliaCI/julia#v1:
118+
version: "1.11"
119+
- JuliaCI/julia-test#v1:
120+
test_args: "enzyme"
121+
agents:
122+
queue: "juliagpu"
123+
rocm: "*"
124+
rocmgpu: "*"
125+
if: build.message !~ /\[skip tests\]/
126+
command: "julia --project -e 'using Pkg; Pkg.update()'"
127+
timeout_in_minutes: 30
128+
env:
129+
JULIA_NUM_THREADS: 4
130+
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
131+
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
132+
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
115133

116134
- label: "GPU-less environment"
117135
plugins:

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AMDGPU"
22
uuid = "21141c5a-9bdb-4563-92ae-f87d6854732e"
33
authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>", "Valentin Churavy <v.churavy@gmail.com>", "Anton Smirnov <tonysmn97@gmail.com>"]
4-
version = "2.1.0"
4+
version = "2.1.2"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -52,13 +52,13 @@ EnzymeCore = "0.8"
5252
ExprTools = "0.1"
5353
GPUArrays = "11.2.1"
5454
GPUCompiler = "1"
55-
GPUToolbox = "0.1.0, 0.2, 0.3"
55+
GPUToolbox = "0.1.0, 0.2, 0.3, 1"
5656
KernelAbstractions = "0.9.2"
5757
LLD_jll = "15, 16, 17, 18, 19"
5858
LLVM = "9"
5959
LLVM_jll = "15, 16, 17, 18, 19"
6060
Preferences = "1"
61-
PrettyTables = "2"
61+
PrettyTables = "3"
6262
ROCmDeviceLibs_jll = "=5.6.1, =6.2.1"
6363
Random123 = "1.6"
6464
RandomNumbers = "1.5"

ext/AMDGPUEnzymeCoreExt/AMDGPUEnzymeCoreExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ function EnzymeCore.compiler_job_from_backend(
1414
return GPUCompiler.CompilerJob(mi, AMDGPU.compiler_config(AMDGPU.device()))
1515
end
1616

17+
function EnzymeRules.forward(
18+
config, fn::Const{typeof(AMDGPU.hipfunction)}, ::Type{<: Const},
19+
f::Const{F}, tt::Const{TT}; kwargs...,
20+
) where {F, TT}
21+
res = fn.val(f.val, tt.val; kwargs...)
22+
return res
23+
end
24+
1725
function EnzymeRules.forward(
1826
config, fn::Const{typeof(AMDGPU.hipfunction)}, ::Type{<: Duplicated},
1927
f::Const{F}, tt::Const{TT}; kwargs...,

src/blas/highlevel.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ function LinearAlgebra.:(*)(
4545
dotu(n, x, stride(x, 1), y, stride(y, 1))
4646
end
4747

48+
LinearAlgebra.norm(x::Diagonal{T, <:StridedROCVector{T}}, p::Real=2) where {T<:Union{Float16, ComplexF16, ROCBLASFloat}} = norm(x.diag, p)
4849
LinearAlgebra.norm(x::ROCArray{T}) where T <: ROCBLASFloat = nrm2(length(x), x, stride(x, 1))
4950
LinearAlgebra.BLAS.asum(x::ROCArray{T}) where T <: ROCBLASFloat = asum(length(x), x, stride(x, 1))
5051

@@ -163,7 +164,7 @@ if VERSION >= v"1.12-"
163164
# https://github.com/JuliaLang/LinearAlgebra.jl/blob/4e7c3f40316a956119ac419a97c4b8aad7a17e6c/src/matmul.jl#L490
164165
for blas_flag in (LinearAlgebra.BlasFlag.SyrkHerkGemm, LinearAlgebra.BlasFlag.SymmHemmGeneric)
165166
@eval LinearAlgebra.generic_matmatmul_wrapper!(
166-
C::StridedROCVecOrMat, tA, tB, A::StridedROCVecOrMat, B::StridedROCVecOrMat,
167+
C::StridedROCMatrix, tA::AbstractChar, tB::AbstractChar, A::StridedROCVecOrMat, B::StridedROCVecOrMat,
167168
alpha::Number, beta::Number, ::$blas_flag) =
168169
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
169170
end
@@ -327,7 +328,7 @@ end
327328
if VERSION >= v"1.12-"
328329
# Otherwise, dispatches to:
329330
# https://github.com/JuliaLang/LinearAlgebra.jl/blob/4e7c3f40316a956119ac419a97c4b8aad7a17e6c/src/generic.jl#L2092
330-
LinearAlgebra.copytrito!(B::Matrix{T}, A::ROCMatrix{T}, uplo::AbstractChar) where {T <: ROCBLASFloat} =
331+
LinearAlgebra.copytrito!(B::Matrix{T}, A::ROCMatrix{T}, uplo::AbstractChar) where {T <: ROCBLASFloat} =
331332
invoke(LinearAlgebra.copytrito!, Tuple{AbstractMatrix, AbstractMatrix, AbstractChar}, B, A, uplo)
332333
end
333334

src/blas/util.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,35 @@
1+
# convert matrix to band storage
2+
function band(A::AbstractMatrix,kl,ku)
3+
m, n = size(A)
4+
AB = zeros(eltype(A),kl+ku+1,n)
5+
for j = 1:n
6+
for i = max(1,j-ku):min(m,j+kl)
7+
AB[ku+1-j+i,j] = A[i,j]
8+
end
9+
end
10+
return AB
11+
end
12+
13+
# convert band storage to general matrix
14+
function unband(AB::AbstractMatrix,m,kl,ku)
15+
bm, n = size(AB)
16+
A = zeros(eltype(AB),m,n)
17+
for j = 1:n
18+
for i = max(1,j-ku):min(m,j+kl)
19+
A[i,j] = AB[ku+1-j+i,j]
20+
end
21+
end
22+
return A
23+
end
24+
25+
# zero out elements not on matrix bands
26+
function bandex(A::AbstractMatrix,kl,ku)
27+
m, n = size(A)
28+
AB = band(A,kl,ku)
29+
B = unband(AB,m,kl,ku)
30+
return B
31+
end
32+
133
const ROCBLASReal = Union{Float32, Float64}
234
const ROCBLASComplex = Union{ComplexF32, ComplexF64}
335
const ROCBLASFloat = Union{ROCBLASReal, ROCBLASComplex}

src/blas/wrappers.jl

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -368,10 +368,10 @@ for (fname, elty) in ((:rocblas_stbmv,:Float32),
368368
x
369369
end
370370
function tbmv(
371-
uplo::Char, trans::Char, diag::Char,
371+
uplo::Char, trans::Char, diag::Char, k::Integer,
372372
A::ROCMatrix{$elty}, x::ROCVector{$elty},
373373
)
374-
tbmv!(uplo, trans, diag, A, copy(x))
374+
tbmv!(uplo, trans, diag, k, A, copy(x))
375375
end
376376
end
377377
end
@@ -496,10 +496,10 @@ for (fname, elty) in ((:rocblas_dsyr,:Float64),
496496
end
497497

498498
### her
499-
for (fname, elty) in ((:rocblas_zher,:ComplexF64),
500-
(:rocblas_cher,:ComplexF32))
499+
for (fname, elty, relty) in ((:rocblas_zher,:ComplexF64,:Float64),
500+
(:rocblas_cher,:ComplexF32,:Float32))
501501
@eval begin
502-
function her!(uplo::Char, alpha::$elty, x::ROCVector{$elty}, A::ROCMatrix{$elty})
502+
function her!(uplo::Char, alpha::$relty, x::ROCVector{$elty}, A::ROCMatrix{$elty})
503503
m, n = size(A)
504504
m == n || throw(DimensionMismatch("Matrix A is $m by $n but must be square"))
505505
length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
@@ -639,6 +639,7 @@ end
639639
"- length(B)=$(length(B))\n" *
640640
"- length(C)=$(length(C))\n"))
641641
end
642+
m, k, n = (-1, -1, -1)
642643
for (i, (As, Bs, Cs)) in enumerate(zip(A, B, C))
643644
m, k = size(As, transA == 'N' ? 1 : 2), size(As, transA == 'N' ? 2 : 1)
644645
n, g = size(Bs, transB == 'N' ? 2 : 1), size(Bs, transB == 'N' ? 1 : 2)
@@ -653,7 +654,7 @@ end
653654
lda = max(1, stride(A[1], 2))
654655
ldb = max(1, stride(B[1], 2))
655656
ldc = max(1, stride(C[1], 2))
656-
m, k, n, lda, ldb, ldc
657+
return m, k, n, lda, ldb, ldc
657658
end
658659

659660
## (GE) general matrix-matrix multiplication batched
@@ -666,15 +667,17 @@ for (fname, elty) in
666667
@eval begin
667668
function gemm_batched!(
668669
transA::Char, transB::Char,
669-
alpha::($elty), A::ROCArray{$elty, 3},
670-
B::ROCArray{$elty, 3}, beta::($elty), C::ROCArray{$elty, 3},
671-
)
670+
alpha::($elty), A::TA,
671+
B::TB, beta::($elty), C::TC,
672+
) where {TA<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
673+
TB<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
674+
TC<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}}}
672675
m, k, n, lda, ldb, ldc = _check_gemm_batched_dims(
673676
transA, transB, A, B, C)
674677

675-
batch_count = size(C, 3)
676-
a_broadcast = (size(A, 3) == 1) && (batch_count > 1)
677-
b_broadcast = (size(B, 3) == 1) && (batch_count > 1)
678+
batch_count = C isa ROCArray ? size(C, 3) : length(C)
679+
a_broadcast = A isa ROCArray && (size(A, 3) == 1) && (batch_count > 1)
680+
b_broadcast = B isa ROCArray && (size(B, 3) == 1) && (batch_count > 1)
678681
Ab = a_broadcast ? device_batch(A, batch_count) : device_batch(A)
679682
Bb = b_broadcast ? device_batch(B, batch_count) : device_batch(B)
680683
Cb = device_batch(C)
@@ -684,18 +687,18 @@ for (fname, elty) in
684687
handle, transA, transB,
685688
m, n, k, Ref(alpha), Ab, lda, Bb, ldb, Ref(beta),
686689
Cb, ldc, batch_count)
687-
C
690+
return C
688691
end
689692
function gemm_batched(
690693
transA::Char, transB::Char, alpha::($elty), A::T, B::K,
691694
) where {
692-
T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
693-
K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
695+
T <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
696+
K <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
694697
}
695698
is_ab_vec = Int(T <: Vector) + Int(K <: Vector)
696699
(is_ab_vec != 0) && (is_ab_vec != 2) && throw(ArgumentError(
697700
"If `A` is a `Vector{ROCMatrix}`, then `B` must be too."))
698-
if T isa Vector
701+
if T <: Vector
699702
C = ROCMatrix{$elty}[similar(B[i], $elty, (
700703
size(A[i], transA == 'N' ? 1 : 2),
701704
size(B[i], transB == 'N' ? 2 : 1))) for i in 1:length(A)]
@@ -704,13 +707,15 @@ for (fname, elty) in
704707
k = size(B, transB == 'N' ? 2 : 1)
705708
C = similar(A, $elty, (m, k, max(size(A, 3), size(B, 3))))
706709
end
707-
gemm_batched!(transA, transB, alpha, A, B, zero($elty), C)
710+
m, k, n, lda, ldb, ldc = _check_gemm_batched_dims(
711+
transA, transB, A, B, C)
712+
return gemm_batched!(transA, transB, alpha, A, B, zero($elty), C)
708713
end
709714
function gemm_batched(transA::Char, transB::Char, A::T, B::K) where {
710-
T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
711-
K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
715+
T <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
716+
K <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
712717
}
713-
gemm_batched(transA, transB, one($elty), A, B)
718+
return gemm_batched(transA, transB, one($elty), A, B)
714719
end
715720
end
716721
end
@@ -863,12 +868,12 @@ for (fname, elty) in ((:rocblas_zhemm,:ComplexF64),
863868
end
864869

865870
## herk
866-
for (fname, elty) in ((:rocblas_zherk,:ComplexF64),
867-
(:rocblas_cherk,:ComplexF32))
871+
for (fname, elty, relty) in ((:rocblas_zherk,:ComplexF64,:Float64),
872+
(:rocblas_cherk,:ComplexF32,:Float32))
868873
@eval begin
869874
function herk!(
870-
uplo::Char, trans::Char, alpha::($elty), A::ROCVecOrMat{$elty},
871-
beta::($elty), C::ROCMatrix{$elty},
875+
uplo::Char, trans::Char, alpha::($relty), A::ROCVecOrMat{$elty},
876+
beta::($relty), C::ROCMatrix{$elty},
872877
)
873878
mC, n = size(C)
874879
if mC != n throw(DimensionMismatch("C must be square")) end
@@ -881,12 +886,12 @@ for (fname, elty) in ((:rocblas_zherk,:ComplexF64),
881886
$(fname)(handle, uplo, trans, n, k, Ref(alpha), A, lda, Ref(beta), C, ldc)
882887
C
883888
end
884-
function herk(uplo::Char, trans::Char, alpha::($elty), A::ROCVecOrMat{$elty})
889+
function herk(uplo::Char, trans::Char, alpha::($relty), A::ROCVecOrMat{$elty})
885890
n = size(A, trans == 'N' ? 1 : 2)
886-
herk!(uplo, trans, alpha, A, zero($elty), similar(A, $elty, (n,n)))
891+
herk!(uplo, trans, alpha, A, zero($relty), similar(A, $elty, (n,n)))
887892
end
888893
herk(uplo::Char, trans::Char, A::ROCVecOrMat{$elty}) =
889-
herk(uplo, trans, one($elty), A)
894+
herk(uplo, trans, one($relty), A)
890895
end
891896
end
892897

@@ -1029,7 +1034,7 @@ for (fname, elty) in
10291034
@eval begin
10301035
function trsm_batched!(
10311036
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
1032-
A::Array{ROCMatrix{$elty},1}, B::Array{ROCMatrix{$elty},1},
1037+
A::Array{<:ROCMatrix{$elty},1}, B::Array{<:ROCMatrix{$elty},1},
10331038
)
10341039
if( length(A) != length(B) )
10351040
throw(DimensionMismatch(""))
@@ -1051,7 +1056,7 @@ for (fname, elty) in
10511056
end
10521057
function trsm_batched(
10531058
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
1054-
A::Array{ROCMatrix{$elty},1}, B::Array{ROCMatrix{$elty},1},
1059+
A::Array{<:ROCMatrix{$elty},1}, B::Array{<:ROCMatrix{$elty},1},
10551060
)
10561061
trsm_batched!(side, uplo, transa, diag, alpha, A, copy(B) )
10571062
end
@@ -1092,13 +1097,13 @@ for (fname, elty) in ((:rocblas_dgeam,:Float64),
10921097
)
10931098
m,n = size(B)
10941099
if ((transb == 'T' || transb == 'C'))
1095-
geam!( transa, transb, alpha, A, beta, B, similar(B, $elty, (n,m) ) )
1100+
return geam!( transa, transb, alpha, A, beta, B, similar(B, $elty, (n,m) ) )
10961101
end
10971102
if (transb == 'N')
1098-
geam!( transa, transb, alpha, A, beta, B, similar(B, $elty, (m,n) ) )
1103+
return geam!( transa, transb, alpha, A, beta, B, similar(B, $elty, (m,n) ) )
10991104
end
11001105
end
1101-
geam( uplo::Char, trans::Char, A::ROCMatrix{$elty}, B::ROCMatrix{$elty}) = geam( uplo, trans, one($elty), A, one($elty), B)
1106+
geam( transa::Char, transb::Char, A::ROCMatrix{$elty}, B::ROCMatrix{$elty}) = geam( transa, transb, one($elty), A, one($elty), B)
11021107
end
11031108
end
11041109

src/discovery/discovery.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function _hip_runtime_version()
4444
VersionNumber(major, minor, patch)
4545
end
4646

47-
global rel_libdir::String = ""
47+
global rel_libdir::String = Sys.islinux() ? "" : "bin"
4848
global libhsaruntime::String = ""
4949
global lld_path::String = ""
5050
global lld_artifact::Bool = false

src/hip/device.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,11 @@ function __pretty_data(dev::HIPDevice)
168168
end
169169

170170
function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, dev::HIPDevice)
171-
PrettyTables.pretty_table(io, __pretty_data(dev); header=[
171+
PrettyTables.pretty_table(io, __pretty_data(dev); column_labels=[
172172
"Id", "Name", "GCN arch", "Wavefront", "Memory", "Shared Memory"])
173173
end
174174

175175
function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, devs::Vector{HIPDevice})
176-
PrettyTables.pretty_table(io, vcat(__pretty_data.(devs)...); header=[
176+
PrettyTables.pretty_table(io, vcat(__pretty_data.(devs)...); column_labels=[
177177
"Id", "Name", "GCN arch", "Wavefront", "Memory", "Shared Memory"])
178178
end

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function versioninfo()
1515
_status(functional(:MIOpen)) "MIOpen" _ver(:MIOpen, MIOpen.version) _libpath(libMIOpen_path);
1616
]
1717

18-
PrettyTables.pretty_table(data; header=[
18+
PrettyTables.pretty_table(data; column_labels=[
1919
"Available", "Name", "Version", "Path"],
2020
alignment=[:c, :l, :l, :l])
2121

0 commit comments

Comments
 (0)