Skip to content

Commit b1f1c64

Browse files
committed
Implement DiagonalAlgorithm for GPU
1 parent ba2c9ef commit b1f1c64

File tree

12 files changed

+560
-195
lines changed

12 files changed

+560
-195
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,18 @@ end
2727
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
2828
return ROCSOLVER_DivideAndConquer(; kwargs...)
2929
end
30+
for f in (
31+
:(MatrixAlgebraKit.default_lq_algorithm),
32+
:(MatrixAlgebraKit.default_qr_algorithm),
33+
:(MatrixAlgebraKit.default_eig_algorithm),
34+
:(MatrixAlgebraKit.default_eigh_algorithm),
35+
:(MatrixAlgebraKit.default_svd_algorithm),
36+
)
37+
38+
@eval function $f(::Type{T}; kwargs...) where {S, T <: Diagonal{S, <:StridedROCVector}}
39+
return DiagonalAlgorithm(; kwargs...)
40+
end
41+
end
3042

3143
_gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A)
3244
_gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ)

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,19 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT
3232
return CUSOLVER_DivideAndConquer(; kwargs...)
3333
end
3434

35+
for f in (
36+
:(MatrixAlgebraKit.default_lq_algorithm),
37+
:(MatrixAlgebraKit.default_qr_algorithm),
38+
:(MatrixAlgebraKit.default_eig_algorithm),
39+
:(MatrixAlgebraKit.default_eigh_algorithm),
40+
:(MatrixAlgebraKit.default_svd_algorithm),
41+
)
42+
43+
@eval function $f(::Type{T}; kwargs...) where {S, T <: Diagonal{S, <:StridedCuVector}}
44+
return DiagonalAlgorithm(; kwargs...)
45+
end
46+
end
47+
3548
# include for block sector support
3649
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
3750
return CUSOLVER_HouseholderQR(; kwargs...)

src/implementations/svd.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,12 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
218218
check_input(svd_full!, A, USVᴴ, alg)
219219
Ad = diagview(A)
220220
U, S, Vᴴ = USVᴴ
221-
p = sortperm(Ad; by = abs, rev = true)
221+
p = if isempty(Ad)
222+
Int[]
223+
else
224+
sortperm(Ad; by = abs, rev = true)
225+
end
226+
222227
zero!(U)
223228
zero!(Vᴴ)
224229
n = size(A, 1)

test/amd/eigh.jl

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ using LinearAlgebra: LinearAlgebra, Diagonal, I
66
using MatrixAlgebraKit: TruncatedAlgorithm, diagview
77
using AMDGPU
88

9-
@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
9+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
10+
11+
@testset "eigh_full! for T = $T" for T in BLASFloats
1012
rng = StableRNG(123)
1113
m = 54
1214
for alg in (
@@ -32,11 +34,11 @@ using AMDGPU
3234
end
3335
end
3436

35-
#=@testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
37+
#=@testset "eigh_trunc! for T = $T" for T in BLASFloats
3638
rng = StableRNG(123)
3739
m = 54
38-
for alg in (CUSOLVER_QRIteration(),
39-
CUSOLVER_DivideAndConquer(),
40+
for alg in (ROCSOLVER_QRIteration(),
41+
ROCSOLVER_DivideAndConquer(),
4042
)
4143
A = ROCArray(randn(rng, T, m, m))
4244
A = A * A'
@@ -64,18 +66,40 @@ end
6466
end
6567
end
6668
67-
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in
68-
(Float32, Float64,
69-
ComplexF32,
70-
ComplexF64)
69+
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats
7170
rng = StableRNG(123)
7271
m = 4
7372
V = qr_compact(ROCArray(randn(rng, T, m, m)))[1]
7473
D = Diagonal([0.9, 0.3, 0.1, 0.01])
7574
A = V * D * V'
7675
A = (A + A') / 2
77-
alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2))
76+
alg = TruncatedAlgorithm(ROCSOLVER_QRIteration(), truncrank(2))
7877
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
7978
@test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
8079
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
8180
end=#
81+
82+
@testset "eigh for Diagonal{$T}" for T in BLASFloats
83+
rng = StableRNG(123)
84+
m = 54
85+
Ad = randn(rng, T, m)
86+
Ad .+= conj.(Ad)
87+
A = Diagonal(ROCArray(Ad))
88+
atol = sqrt(eps(real(T)))
89+
90+
D, V = @constinferred eigh_full(A)
91+
@test D isa Diagonal{real(T)} && size(D) == size(A)
92+
@test V isa Diagonal{T} && size(V) == size(A)
93+
@test A * V V * D
94+
95+
D2 = @constinferred eigh_vals(A)
96+
@test D2 isa AbstractVector{real(T)} && length(D2) == m
97+
@test diagview(D) D2
98+
99+
# TODO partialsortperm
100+
#=A2 = Diagonal(ROCArray(T[0.9, 0.3, 0.1, 0.01]))
101+
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
102+
D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg)
103+
@test diagview(D2) ≈ diagview(A2)[1:2]
104+
@test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol=#
105+
end

test/amd/lq.jl

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ using Test
44
using TestExtras
55
using StableRNGs
66
using AMDGPU
7+
using LinearAlgebra
78

89
include(joinpath("..", "utilities.jl"))
910

10-
@testset "lq_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
11+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
12+
13+
@testset "lq_compact! for T = $T" for T in BLASFloats
1114
rng = StableRNG(123)
1215
m = 54
1316
for n in (37, m, 63)
@@ -65,7 +68,7 @@ include(joinpath("..", "utilities.jl"))
6568
end
6669
end
6770

68-
@testset "lq_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
71+
@testset "lq_full! for T = $T" for T in BLASFloats
6972
rng = StableRNG(123)
7073
m = 54
7174
for n in (37, m, 63)
@@ -115,3 +118,48 @@ end
115118
@test_throws ArgumentError lq_full!(copy!(Ac, A), (L, Q); blocksize = 8)
116119
end
117120
end
121+
122+
@testset "lq_compact, lq_full and lq_null for Diagonal{$T}" for T in BLASFloats
123+
rng = StableRNG(123)
124+
atol = eps(real(T))^(3 / 4)
125+
for m in (54, 0)
126+
Ad = ROCArray(randn(rng, T, m))
127+
A = Diagonal(Ad)
128+
129+
# compact
130+
L, Q = @constinferred lq_compact(A)
131+
@test Q isa Diagonal{T} && size(Q) == (m, m)
132+
@test L isa Diagonal{T} && size(L) == (m, m)
133+
@test L * Q A
134+
@test isunitary(Q)
135+
136+
# compact and positive
137+
Lp, Qp = @constinferred lq_compact(A; positive = true)
138+
@test Qp isa Diagonal{T} && size(Qp) == (m, m)
139+
@test Lp isa Diagonal{T} && size(Lp) == (m, m)
140+
@test Lp * Qp A
141+
@test isunitary(Qp)
142+
@test all((zero(real(T))), real(diag(Lp))) &&
143+
all((zero(real(T)); atol), imag(diag(Lp)))
144+
145+
# full
146+
L, Q = @constinferred lq_full(A)
147+
@test Q isa Diagonal{T} && size(Q) == (m, m)
148+
@test L isa Diagonal{T} && size(L) == (m, m)
149+
@test L * Q A
150+
@test isunitary(Q)
151+
152+
# full and positive
153+
Lp, Qp = @constinferred lq_full(A; positive = true)
154+
@test Qp isa Diagonal{T} && size(Qp) == (m, m)
155+
@test Lp isa Diagonal{T} && size(Lp) == (m, m)
156+
@test Lp * Qp A
157+
@test isunitary(Qp)
158+
@test all((zero(real(T))), real(diag(Lp))) &&
159+
all((zero(real(T)); atol), imag(diag(Lp)))
160+
161+
# null
162+
N = @constinferred lq_null(A)
163+
@test N isa AbstractMatrix{T} && size(N) == (0, m)
164+
end
165+
end

test/amd/qr.jl

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ using Test
44
using TestExtras
55
using StableRNGs
66
using AMDGPU
7+
using LinearAlgebra
78

89
include(joinpath("..", "utilities.jl"))
910

10-
eltypes = (Float32, Float64, ComplexF32, ComplexF64)
11+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
1112

12-
@testset "qr_compact! and qr_null! for T = $T" for T in eltypes
13+
@testset "qr_compact! and qr_null! for T = $T" for T in BLASFloats
1314
rng = StableRNG(123)
1415
m = 54
1516
for n in (37, m, 63)
@@ -68,7 +69,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64)
6869
end
6970
end
7071

71-
@testset "qr_full! for T = $T" for T in eltypes
72+
@testset "qr_full! for T = $T" for T in BLASFloats
7273
rng = StableRNG(123)
7374
m = 63
7475
for n in (37, m, 63)
@@ -121,3 +122,48 @@ end
121122
@test_throws ArgumentError qr_full!(copy!(Ac, A), (Q, R); blocksize = 8)
122123
end
123124
end
125+
126+
@testset "qr_compact, qr_full and qr_null for Diagonal{$T}" for T in BLASFloats
127+
rng = StableRNG(123)
128+
atol = eps(real(T))^(3 / 4)
129+
for m in (54, 0)
130+
Ad = ROCArray(randn(rng, T, m))
131+
A = Diagonal(Ad)
132+
133+
# compact
134+
Q, R = @constinferred qr_compact(A)
135+
@test Q isa Diagonal{T} && size(Q) == (m, m)
136+
@test R isa Diagonal{T} && size(R) == (m, m)
137+
@test Q * R A
138+
@test isunitary(Q)
139+
140+
# compact and positive
141+
Qp, Rp = @constinferred qr_compact(A; positive = true)
142+
@test Qp isa Diagonal{T} && size(Qp) == (m, m)
143+
@test Rp isa Diagonal{T} && size(Rp) == (m, m)
144+
@test Qp * Rp A
145+
@test isunitary(Qp)
146+
@test all((zero(real(T))), real(diag(Rp))) &&
147+
all((zero(real(T)); atol), imag(diag(Rp)))
148+
149+
# full
150+
Q, R = @constinferred qr_full(A)
151+
@test Q isa Diagonal{T} && size(Q) == (m, m)
152+
@test R isa Diagonal{T} && size(R) == (m, m)
153+
@test Q * R A
154+
@test isunitary(Q)
155+
156+
# full and positive
157+
Qp, Rp = @constinferred qr_full(A; positive = true)
158+
@test Qp isa Diagonal{T} && size(Qp) == (m, m)
159+
@test Rp isa Diagonal{T} && size(Rp) == (m, m)
160+
@test Qp * Rp A
161+
@test isunitary(Qp)
162+
@test all((zero(real(T))), real(diag(Rp))) &&
163+
all((zero(real(T)); atol), imag(diag(Rp)))
164+
165+
# null
166+
N = @constinferred qr_null(A)
167+
@test N isa AbstractMatrix{T} && size(N) == (m, 0)
168+
end
169+
end

0 commit comments

Comments
 (0)