Skip to content

Commit ba2c9ef

Browse files
kshyattJutholkdvos
authored
A few more updates for GPU compatibility for TensorKit (#100)
* A few more updates for GPU compatibility for TensorKit * Update src/implementations/projections.jl Co-authored-by: Jutho <[email protected]> * Update antihermiticity check * Respond to comments, update tests * Add tests for SVD * remove `@invoke` calls * fix kwargs * further cleanup of `ishermitian` and friends * include diagonal tests in projections on CPU * make JET happy * Revert "Add tests for SVD" This reverts commit 1b4399f. * revert Diagonal to `diagm` for svd with gpuarrays * dont call project_isometric with invalid alg-input types * Add alg for amd test --------- Co-authored-by: Jutho <[email protected]> Co-authored-by: Lukas Devos <[email protected]>
1 parent 13a1771 commit ba2c9ef

File tree

8 files changed

+191
-120
lines changed

8 files changed

+191
-120
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -127,18 +127,14 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::Strid
127127
return nothing
128128
end
129129

130-
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all(A .== adjoint(A))
131-
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
132-
all(A.diag .== adjoint(A.diag))
133-
MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; kwargs...) =
134-
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)
135-
136-
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) =
137-
all(A .== -adjoint(A))
138-
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
139-
all(A.diag .== -adjoint(A.diag))
140-
MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; kwargs...) =
141-
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)
130+
# avoids calling the `StridedMatrix` specialization to avoid scalar indexing,
131+
# use (allocating) fallback instead until we write a dedicated kernel
132+
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = A == A'
133+
MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; atol, rtol, kwargs...) =
134+
norm(project_antihermitian(A; kwargs...)) max(atol, rtol * norm(A))
135+
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = A == -A'
136+
MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; atol, rtol, kwargs...) =
137+
norm(project_hermitian(A; kwargs...)) max(atol, rtol * norm(A))
142138

143139
function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
144140
axes(A) == axes(B) || throw(DimensionMismatch())

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -151,19 +151,14 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::Stride
151151
return nothing
152152
end
153153

154-
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) =
155-
all(A .== adjoint(A))
156-
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
157-
all(A.diag .== adjoint(A.diag))
158-
MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; kwargs...) =
159-
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)
160-
161-
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) =
162-
all(A .== -adjoint(A))
163-
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
164-
all(A.diag .== -adjoint(A.diag))
165-
MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; kwargs...) =
166-
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)
154+
# avoids calling the `StridedMatrix` specialization to avoid scalar indexing,
155+
# use (allocating) fallback instead until we write a dedicated kernel
156+
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = A == A'
157+
MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; atol, rtol, kwargs...) =
158+
norm(project_antihermitian(A; kwargs...)) max(atol, rtol * norm(A))
159+
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = A == -A'
160+
MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; atol, rtol, kwargs...) =
161+
norm(project_hermitian(A; kwargs...)) max(atol, rtol * norm(A))
167162

168163
function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
169164
axes(A) == axes(B) || throw(DimensionMismatch())

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module MatrixAlgebraKit
33
using LinearAlgebra: LinearAlgebra
44
using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
55
using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
6-
using LinearAlgebra: sylvester, lu!
6+
using LinearAlgebra: sylvester, lu!, diagm
77
using LinearAlgebra: isposdef, issymmetric
88
using LinearAlgebra: Diagonal, diag, diagind, isdiag
99
using LinearAlgebra: UpperTriangular, LowerTriangular

src/common/matrixproperties.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,13 @@ end
7979

8080
ishermitian_exact(A) = A == A'
8181
ishermitian_exact(A::StridedMatrix; kwargs...) = strided_ishermitian_exact(A, Val(false); kwargs...)
82+
ishermitian_exact(A::Diagonal) = diagonal_ishermitian_exact(A, Val(false))
83+
8284
function ishermitian_approx(A; atol, rtol, kwargs...)
8385
return norm(project_antihermitian(A; kwargs...)) max(atol, rtol * norm(A))
8486
end
8587
ishermitian_approx(A::StridedMatrix; kwargs...) = strided_ishermitian_approx(A, Val(false); kwargs...)
88+
ishermitian_approx(A::Diagonal; kwargs...) = diagonal_ishermitian_approx(A, Val(false); kwargs...)
8689

8790
"""
8891
isantihermitian(A; isapprox_kwargs...)
@@ -97,16 +100,15 @@ function isantihermitian(A; atol::Real = 0, rtol::Real = 0, kwargs...)
97100
return isantihermitian_approx(A; atol, rtol, kwargs...)
98101
end
99102
end
100-
function isantihermitian_exact(A)
101-
return A == -A'
102-
end
103-
function isantihermitian_exact(A::StridedMatrix; kwargs...)
104-
return strided_ishermitian_exact(A, Val(true); kwargs...)
105-
end
103+
isantihermitian_exact(A) = A == -A'
104+
isantihermitian_exact(A::StridedMatrix; kwargs...) = strided_ishermitian_exact(A, Val(true); kwargs...)
105+
isantihermitian_exact(A::Diagonal) = diagonal_ishermitian_exact(A, Val(true))
106+
106107
function isantihermitian_approx(A; atol, rtol, kwargs...)
107108
return norm(project_hermitian(A; kwargs...)) max(atol, rtol * norm(A))
108109
end
109110
isantihermitian_approx(A::StridedMatrix; kwargs...) = strided_ishermitian_approx(A, Val(true); kwargs...)
111+
isantihermitian_approx(A::Diagonal; kwargs...) = diagonal_ishermitian_approx(A, Val(true); kwargs...)
110112

111113
# blocked implementation of exact checks for strided matrices
112114
# -----------------------------------------------------------
@@ -145,7 +147,6 @@ function _ishermitian_exact_offdiag(Al, Au, ::Val{anti}) where {anti}
145147
return true
146148
end
147149

148-
149150
function strided_ishermitian_approx(
150151
A::AbstractMatrix, anti::Val;
151152
blocksize = 32, atol::Real = default_hermitian_tol(A), rtol::Real = 0
@@ -192,3 +193,16 @@ function _ishermitian_approx_offdiag(Al, Au, ::Val{anti}) where {anti}
192193
end
193194
return ϵ²
194195
end
196+
197+
diagonal_ishermitian_exact(A, ::Val{anti}) where {anti} = all(iszero (anti ? real : imag), diagview(A))
198+
199+
function diagonal_ishermitian_approx(
200+
A, ::Val{anti}; atol::Real = default_hermitian_tol(A), rtol::Real = 0
201+
) where {anti}
202+
m, n = size(A)
203+
m == n || throw(DimensionMismatch())
204+
init = abs2(zero(eltype(A)))
205+
ϵ² = sum(abs2 (anti ? real : imag), diagview(A); init)
206+
ϵ²max = oftype(ϵ², rtol > 0 ? max(atol, rtol * norm(A)) : atol)^2
207+
return ϵ² ϵ²max
208+
end

src/implementations/projections.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ copy_input(::typeof(project_isometric), A) = copy_input(left_polar, A)
99

1010
function check_input(::typeof(project_hermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
1111
LinearAlgebra.checksquare(A)
12-
n = Base.require_one_based_indexing(A)
12+
Base.require_one_based_indexing(A)
13+
n = size(A, 1)
1314
B === A || @check_size(B, (n, n))
1415
return nothing
1516
end
1617
function check_input(::typeof(project_antihermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
1718
LinearAlgebra.checksquare(A)
18-
n = Base.require_one_based_indexing(A)
19+
Base.require_one_based_indexing(A)
20+
n = size(A, 1)
1921
B === A || @check_size(B, (n, n))
2022
return nothing
2123
end
@@ -61,6 +63,15 @@ function project_isometric!(A::AbstractMatrix, W, alg::AbstractAlgorithm)
6163
return W
6264
end
6365

66+
function project_hermitian_native!(A::Diagonal, B::Diagonal, ::Val{anti}; kwargs...) where {anti}
67+
if anti
68+
diagview(A) .= _imimag.(diagview(B))
69+
else
70+
diagview(A) .= real.(diagview(B))
71+
end
72+
return A
73+
end
74+
6475
function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32)
6576
n = size(A, 1)
6677
for j in 1:blocksize:n

test/amd/projections.jl

Lines changed: 55 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,40 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
1212
m = 54
1313
noisefactor = eps(real(T))^(3 / 4)
1414
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
15-
A = ROCArray(randn(rng, T, m, m))
16-
Ah = (A + A') / 2
17-
Aa = (A - A') / 2
18-
Ac = copy(A)
15+
for A in (ROCArray(randn(rng, T, m, m)), Diagonal(ROCArray(randn(rng, T, m))))
16+
Ah = (A + A') / 2
17+
Aa = (A - A') / 2
18+
Ac = copy(A)
1919

20-
Bh = project_hermitian(A, alg)
21-
@test ishermitian(Bh)
22-
@test Bh Ah
23-
@test A == Ac
24-
Bh_approx = Bh + noisefactor * Aa
25-
@test !ishermitian(Bh_approx)
26-
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
20+
Bh = project_hermitian(A, alg)
21+
@test ishermitian(Bh)
22+
@test Bh Ah
23+
@test A == Ac
24+
Bh_approx = Bh + noisefactor * Aa
25+
# this is still hermitian for real Diagonal: |A - A'| == 0
26+
@test !ishermitian(Bh_approx) || norm(Aa) == 0
27+
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
2728

28-
Ba = project_antihermitian(A, alg)
29-
@test isantihermitian(Ba)
30-
@test Ba Aa
31-
@test A == Ac
32-
Ba_approx = Ba + noisefactor * Ah
33-
@test !isantihermitian(Ba_approx)
34-
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
29+
Ba = project_antihermitian(A, alg)
30+
@test isantihermitian(Ba)
31+
@test Ba Aa
32+
@test A == Ac
33+
Ba_approx = Ba + noisefactor * Ah
34+
@test !isantihermitian(Ba_approx)
35+
# this is never anti-hermitian for real Diagonal: |A - A'| == 0
36+
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0
3537

36-
Bh = project_hermitian!(Ac, alg)
37-
@test Bh === Ac
38-
@test ishermitian(Bh)
39-
@test Bh Ah
38+
Bh = project_hermitian!(Ac, alg)
39+
@test Bh === Ac
40+
@test ishermitian(Bh)
41+
@test Bh Ah
4042

41-
copy!(Ac, A)
42-
Ba = project_antihermitian!(Ac, alg)
43-
@test Ba === Ac
44-
@test isantihermitian(Ba)
45-
@test Ba Aa
43+
copy!(Ac, A)
44+
Ba = project_antihermitian!(Ac, alg)
45+
@test Ba === Ac
46+
@test isantihermitian(Ba)
47+
@test Ba Aa
48+
end
4649
end
4750
end
4851

@@ -68,10 +71,33 @@ end
6871

6972
# test that W is closer to A then any other isometry
7073
for k in 1:10
71-
δA = ROCArray(randn(rng, T, m, n))
74+
δA = ROCArray(randn(rng, T, size(A)...))
75+
W = project_isometric(A, alg)
76+
W2 = project_isometric(A + δA / 100, alg)
77+
@test norm(A - W2) >= norm(A - W)
78+
end
79+
end
80+
81+
m == n && @testset "DiagonalAlgorithm" begin
82+
A = Diagonal(ROCArray(randn(rng, T, m)))
83+
alg = PolarViaSVD(DiagonalAlgorithm())
84+
W = project_isometric(A, alg)
85+
@test isisometric(W)
86+
W2 = project_isometric(W, alg)
87+
@test W2 W # stability of the projection
88+
@test W * (W' * A) A
89+
90+
Ac = similar(A)
91+
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
92+
@test W2 === W
93+
@test isisometric(W)
94+
95+
# test that W is closer to A then any other isometry
96+
for k in 1:10
97+
δA = Diagonal(ROCArray(randn(rng, T, m)))
7298
W = project_isometric(A, alg)
7399
W2 = project_isometric(A + δA / 100, alg)
74-
@test norm(A - W2) > norm(A - W)
100+
@test norm(A - W2) >= norm(A - W)
75101
end
76102
end
77103
end

test/cuda/projections.jl

Lines changed: 55 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,40 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
1212
m = 54
1313
noisefactor = eps(real(T))^(3 / 4)
1414
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
15-
A = CuArray(randn(rng, T, m, m))
16-
Ah = (A + A') / 2
17-
Aa = (A - A') / 2
18-
Ac = copy(A)
15+
for A in (CuArray(randn(rng, T, m, m)), Diagonal(CuArray(randn(rng, T, m))))
16+
Ah = (A + A') / 2
17+
Aa = (A - A') / 2
18+
Ac = copy(A)
1919

20-
Bh = project_hermitian(A, alg)
21-
@test ishermitian(Bh)
22-
@test Bh Ah
23-
@test A == Ac
24-
Bh_approx = Bh + noisefactor * Aa
25-
@test !ishermitian(Bh_approx)
26-
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
20+
Bh = project_hermitian(A, alg)
21+
@test ishermitian(Bh)
22+
@test Bh Ah
23+
@test A == Ac
24+
Bh_approx = Bh + noisefactor * Aa
25+
# this is still hermitian for real Diagonal: |A - A'| == 0
26+
@test !ishermitian(Bh_approx) || norm(Aa) == 0
27+
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
2728

28-
Ba = project_antihermitian(A, alg)
29-
@test isantihermitian(Ba)
30-
@test Ba Aa
31-
@test A == Ac
32-
Ba_approx = Ba + noisefactor * Ah
33-
@test !isantihermitian(Ba_approx)
34-
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
29+
Ba = project_antihermitian(A, alg)
30+
@test isantihermitian(Ba)
31+
@test Ba Aa
32+
@test A == Ac
33+
Ba_approx = Ba + noisefactor * Ah
34+
@test !isantihermitian(Ba_approx)
35+
# this is never anti-hermitian for real Diagonal: |A - A'| == 0
36+
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0
3537

36-
Bh = project_hermitian!(Ac, alg)
37-
@test Bh === Ac
38-
@test ishermitian(Bh)
39-
@test Bh Ah
38+
Bh = project_hermitian!(Ac, alg)
39+
@test Bh === Ac
40+
@test ishermitian(Bh)
41+
@test Bh Ah
4042

41-
copy!(Ac, A)
42-
Ba = project_antihermitian!(Ac, alg)
43-
@test Ba === Ac
44-
@test isantihermitian(Ba)
45-
@test Ba Aa
43+
copy!(Ac, A)
44+
Ba = project_antihermitian!(Ac, alg)
45+
@test Ba === Ac
46+
@test isantihermitian(Ba)
47+
@test Ba Aa
48+
end
4649
end
4750
end
4851

@@ -68,10 +71,33 @@ end
6871

6972
# test that W is closer to A then any other isometry
7073
for k in 1:10
71-
δA = CuArray(randn(rng, T, m, n))
74+
δA = CuArray(randn(rng, T, size(A)...))
75+
W = project_isometric(A, alg)
76+
W2 = project_isometric(A + δA / 100, alg)
77+
@test norm(A - W2) >= norm(A - W)
78+
end
79+
end
80+
81+
m == n && @testset "DiagonalAlgorithm" begin
82+
A = Diagonal(CuArray(randn(rng, T, m)))
83+
alg = PolarViaSVD(DiagonalAlgorithm())
84+
W = project_isometric(A, alg)
85+
@test isisometric(W)
86+
W2 = project_isometric(W, alg)
87+
@test W2 W # stability of the projection
88+
@test W * (W' * A) A
89+
90+
Ac = similar(A)
91+
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
92+
@test W2 === W
93+
@test isisometric(W)
94+
95+
# test that W is closer to A then any other isometry
96+
for k in 1:10
97+
δA = Diagonal(CuArray(randn(rng, T, m)))
7298
W = project_isometric(A, alg)
7399
W2 = project_isometric(A + δA / 100, alg)
74-
@test norm(A - W2) > norm(A - W)
100+
@test norm(A - W2) >= norm(A - W)
75101
end
76102
end
77103
end

0 commit comments

Comments
 (0)