Skip to content

Commit 504686e

Browse files
committed
Run tests on GPU
1 parent 0a110b3 commit 504686e

File tree

8 files changed

+76
-61
lines changed

8 files changed

+76
-61
lines changed

src/implementations/svd.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ end
152152
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
153153
check_input(svd_compact!, A, USVᴴ, alg)
154154
U, S, Vᴴ = USVᴴ
155+
if length(A) == 0
156+
one!(U)
157+
zero!(S)
158+
one!(Vᴴ)
159+
return USVᴴ
160+
end
155161

156162
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
157163
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
@@ -382,6 +388,12 @@ end
382388
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
383389
check_input(svd_compact!, A, USVᴴ, alg)
384390
U, S, Vᴴ = USVᴴ
391+
if length(A) == 0
392+
one!(U)
393+
zero!(S)
394+
one!(Vᴴ)
395+
return USVᴴ
396+
end
385397

386398
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
387399
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
@@ -406,6 +418,10 @@ _largest(x, y) = abs(x) < abs(y) ? y : x
406418

407419
function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
408420
check_input(svd_vals!, A, S, alg)
421+
if length(A) == 0
422+
zero!(S)
423+
return S
424+
end
409425
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
410426

411427
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})

test/polar.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ for T in BLASFloats, n in (37, m, 63)
1818
if is_buildkite
1919
if CUDA.functional()
2020
TestSuite.test_polar(CuMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false)
21-
TestSuite.test_polar(Diagonal{T, CuVector{T}}, m; test_pivoted = false, test_blocksize = false)
21+
# not supported
22+
#TestSuite.test_polar(Diagonal{T, CuVector{T}}, m; test_pivoted = false, test_blocksize = false)
2223
end
2324
if AMDGPU.functional()
2425
TestSuite.test_polar(ROCMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false)
25-
TestSuite.test_polar(Diagonal{T, ROCVector{T}}, m; test_pivoted = false, test_blocksize = false)
26+
# not supported
27+
#TestSuite.test_polar(Diagonal{T, ROCVector{T}}, m; test_pivoted = false, test_blocksize = false)
2628
end
2729
else
2830
TestSuite.test_polar(T, (m, n))

test/svd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using StableRNGs
5-
using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef, norm
6-
using MatrixAlgebraKit: TruncatedAlgorithm, diagview, isisometric
5+
using LinearAlgebra: Diagonal
6+
using CUDA, AMDGPU
77

88
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
99
GenericFloats = (Float16, BigFloat, Complex{BigFloat})

test/testsuite/eig.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ function test_eig(T::Type, sz; kwargs...)
44
summary_str = testargs_summary(T, sz)
55
return @testset "eig $summary_str" begin
66
test_eig_full(T, sz; kwargs...)
7-
test_eig_trunc(T, sz; kwargs...)
7+
if T <: Number || T <: Diagonal{<:Number, <:Vector}
8+
test_eig_trunc(T, sz; kwargs...)
9+
end
810
end
911
end
1012

test/testsuite/eigh.jl

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function test_eigh(T::Type, sz; kwargs...)
44
summary_str = testargs_summary(T, sz)
55
return @testset "eigh $summary_str" begin
66
test_eigh_full(T, sz; kwargs...)
7-
if eltype(T) <: Union{Float16, ComplexF16, Float32, Float64, ComplexF32, ComplexF64} && !(T <: Diagonal)
7+
if T <: Number && eltype(T) <: Union{Float16, ComplexF16, Float32, Float64, ComplexF32, ComplexF64} && !(T <: Diagonal)
88
test_eigh_trunc(T, sz; kwargs...)
99
end
1010
end
@@ -54,31 +54,29 @@ function test_eigh_trunc(
5454
r = m - 2
5555
s = 1 + sqrt(eps(real(eltype(T))))
5656
atol = sqrt(eps(real(eltype(T))))
57-
local V1, V2, V3
58-
@testset "truncrank" begin
59-
D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r))
60-
@test length(diagview(D1)) == r
61-
@test isisometric(V1)
62-
@test A * V1 V1 * D1
63-
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') D₀[r + 1]
64-
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
65-
end
66-
@testset "trunctol" begin
67-
trunc = trunctol(; atol = s * D₀[r + 1])
68-
D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc)
69-
@test length(diagview(D2)) == r
70-
@test isisometric(V2)
71-
@test A * V2 V2 * D2
72-
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
73-
end
74-
@testset "truncerror" begin
75-
s = 1 - sqrt(eps(real(eltype(T))))
76-
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
77-
D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc)
78-
@test length(diagview(D3)) == r
79-
@test A * V3 V3 * D3
80-
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
81-
end
57+
# truncrank
58+
D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r))
59+
@test length(diagview(D1)) == r
60+
@test isisometric(V1)
61+
@test A * V1 V1 * D1
62+
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') D₀[r + 1]
63+
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
64+
65+
# trunctol
66+
trunc = trunctol(; atol = s * D₀[r + 1])
67+
D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc)
68+
@test length(diagview(D2)) == r
69+
@test isisometric(V2)
70+
@test A * V2 V2 * D2
71+
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
72+
73+
#truncerror
74+
s = 1 - sqrt(eps(real(eltype(T))))
75+
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
76+
D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc)
77+
@test length(diagview(D3)) == r
78+
@test A * V3 V3 * D3
79+
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
8280

8381
# test for same subspace
8482
@test V1 * (V1' * V2) V2

test/testsuite/polar.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ function test_left_polar(
1919
A = instantiate_matrix(T, sz)
2020
algs = if T <: Diagonal
2121
(PolarNewton(),)
22-
else
22+
elseif T <: Number
2323
(PolarViaSVD(MatrixAlgebraKit.default_svd_algorithm(A)), PolarNewton())
24+
else
25+
(PolarViaSVD(MatrixAlgebraKit.default_svd_algorithm(A)),)
2426
end
2527
@testset "algorithm $alg" for alg in algs
2628
A = instantiate_matrix(T, sz)
@@ -61,8 +63,10 @@ function test_right_polar(
6163
A = instantiate_matrix(T, sz)
6264
algs = if T <: Diagonal
6365
(PolarNewton(),)
64-
else
66+
elseif T <: Number
6567
(PolarViaSVD(MatrixAlgebraKit.default_svd_algorithm(A)), PolarNewton())
68+
else
69+
(PolarViaSVD(MatrixAlgebraKit.default_svd_algorithm(A)),)
6670
end
6771
@testset "algorithm $alg" for alg in algs
6872
A = instantiate_matrix(T, sz)

test/testsuite/projections.jl

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -114,32 +114,25 @@ function test_project_isometric(
114114
summary_str = testargs_summary(T, sz)
115115
return @testset "project_isometric! $summary_str" begin
116116
A = instantiate_matrix(T, sz)
117-
algs = if T <: Diagonal
118-
(PolarNewton(),)
119-
else
120-
(PolarViaSVD(MatrixAlgebraKit.default_svd_algorithm(A)), PolarNewton())
121-
end
122-
@testset "algorithm $alg" for alg in algs
123-
A = instantiate_matrix(T, sz)
124-
Ac = deepcopy(A)
125-
k = min(size(A)...)
126-
W = project_isometric(A, alg)
127-
@test isisometric(W)
128-
W2 = project_isometric(W, alg)
129-
@test W2 W # stability of the projection
130-
@test W * (W' * A) A
131-
132-
W2 = @testinferred project_isometric!(Ac, W, alg)
133-
@test W2 === W
134-
@test isisometric(W)
135-
136-
# test that W is closer to A then any other isometry
137-
for k in 1:10
138-
δA = randn(rng, eltype(T), size(A)...)
139-
W = project_isometric(A, alg)
140-
W2 = project_isometric(A + δA / 100, alg)
141-
@test norm(A - W2) > norm(A - W)
142-
end
117+
Ac = deepcopy(A)
118+
k = min(size(A)...)
119+
W = project_isometric(A)
120+
@test isisometric(W)
121+
W2 = project_isometric(W)
122+
@test W2 W # stability of the projection
123+
@test W * (W' * A) A
124+
125+
W2 = @testinferred project_isometric!(Ac, W)
126+
@test W2 === W
127+
@test isisometric(W)
128+
129+
# test that W is closer to A then any other isometry
130+
for k in 1:10
131+
δA = instantiate_matrix(T, sz)
132+
W = project_isometric(A)
133+
W2 = project_isometric(A + δA / 100)
134+
# must be ≥ for real Diagonal case
135+
@test norm(A - W2) norm(A - W)
143136
end
144137
end
145138
end

test/testsuite/svd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function test_svd(T::Type, sz; kwargs...)
55
return @testset "svd $summary_str" begin
66
test_svd_compact(T, sz; kwargs...)
77
test_svd_full(T, sz; kwargs...)
8-
if min(sz...) > 0
8+
if min(sz...) > 0 && (T <: Number || T <: Diagonal{<:Number, <:Vector})
99
test_svd_trunc(T, sz; kwargs...)
1010
end
1111
end
@@ -86,7 +86,7 @@ function test_svd_full(
8686
Sc = similar(A, real(eltype(T)), min(m, n))
8787
Sc2 = svd_vals!(copy!(Ac, A), Sc)
8888
@test Sc === Sc2
89-
@test diagview(S) Sc
89+
@test collect(diagview(S)) collect(Sc)
9090
end
9191
end
9292

0 commit comments

Comments
 (0)