From 235c8a5716f2708876f3371186f09b10360cb313 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 6 Nov 2025 09:28:36 -0500 Subject: [PATCH 1/6] add testsuite --- Project.toml | 6 +- .../MatrixAlgebraKitAMDGPUExt.jl | 39 --- src/implementations/qr.jl | 8 +- src/implementations/schur.jl | 51 ++++ src/implementations/svd.jl | 16 + test/amd/eigh.jl | 105 ------- test/amd/lq.jl | 163 ---------- test/amd/orthnull.jl | 285 ------------------ test/amd/polar.jl | 83 ----- test/amd/projections.jl | 104 ------- test/amd/qr.jl | 167 ---------- test/amd/svd.jl | 157 ---------- test/cuda/eig.jl | 108 ------- test/cuda/eigh.jl | 118 -------- test/cuda/lq.jl | 163 ---------- test/cuda/orthnull.jl | 264 ---------------- test/cuda/polar.jl | 83 ----- test/cuda/projections.jl | 104 ------- test/cuda/qr.jl | 168 ----------- test/cuda/svd.jl | 161 ---------- test/eig.jl | 131 ++------ test/eigh.jl | 137 ++------- test/genericlinearalgebra/eigh.jl | 93 ------ test/genericlinearalgebra/lq.jl | 124 -------- test/genericlinearalgebra/qr.jl | 109 ------- test/genericlinearalgebra/svd.jl | 171 ----------- test/genericschur/eig.jl | 116 ------- test/linearmap.jl | 8 +- test/lq.jl | 258 ++-------------- test/orthnull.jl | 235 ++------------- test/polar.jl | 97 ++---- test/projections.jl | 107 ++----- test/qr.jl | 227 ++------------ test/runtests.jl | 161 +++------- test/schur.jl | 47 +-- test/svd.jl | 243 ++------------- test/testsuite/TestSuite.jl | 76 +++++ test/testsuite/eig.jl | 100 ++++++ test/testsuite/eigh.jl | 107 +++++++ test/testsuite/lq.jl | 162 ++++++++++ test/testsuite/orthnull.jl | 268 ++++++++++++++++ test/testsuite/polar.jl | 102 +++++++ test/testsuite/projections.jl | 138 +++++++++ test/testsuite/qr.jl | 162 ++++++++++ test/testsuite/schur.jl | 38 +++ test/testsuite/svd.jl | 167 ++++++++++ 46 files changed, 1658 insertions(+), 4279 deletions(-) delete mode 100644 test/amd/eigh.jl delete mode 100644 test/amd/lq.jl delete mode 100644 test/amd/orthnull.jl delete mode 100644 
test/amd/polar.jl delete mode 100644 test/amd/projections.jl delete mode 100644 test/amd/qr.jl delete mode 100644 test/amd/svd.jl delete mode 100644 test/cuda/eig.jl delete mode 100644 test/cuda/eigh.jl delete mode 100644 test/cuda/lq.jl delete mode 100644 test/cuda/orthnull.jl delete mode 100644 test/cuda/polar.jl delete mode 100644 test/cuda/projections.jl delete mode 100644 test/cuda/qr.jl delete mode 100644 test/cuda/svd.jl delete mode 100644 test/genericlinearalgebra/eigh.jl delete mode 100644 test/genericlinearalgebra/lq.jl delete mode 100644 test/genericlinearalgebra/qr.jl delete mode 100644 test/genericlinearalgebra/svd.jl delete mode 100644 test/genericschur/eig.jl create mode 100644 test/testsuite/TestSuite.jl create mode 100644 test/testsuite/eig.jl create mode 100644 test/testsuite/eigh.jl create mode 100644 test/testsuite/lq.jl create mode 100644 test/testsuite/orthnull.jl create mode 100644 test/testsuite/polar.jl create mode 100644 test/testsuite/projections.jl create mode 100644 test/testsuite/qr.jl create mode 100644 test/testsuite/schur.jl create mode 100644 test/testsuite/svd.jl diff --git a/Project.toml b/Project.toml index f09ad557..d4d5a143 100644 --- a/Project.toml +++ b/Project.toml @@ -33,10 +33,11 @@ GenericSchur = "0.5.6" JET = "0.9, 0.10" LinearAlgebra = "1" Mooncake = "0.4.183" +Random = "1" SafeTestsets = "0.1" StableRNGs = "1" Test = "1" -TestExtras = "0.2,0.3" +TestExtras = "0.3.2" Zygote = "0.7" julia = "1.10" @@ -47,6 +48,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -54,4 +56,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = 
"e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Random", "Mooncake"] diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index ff150f24..4e2255a6 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -159,43 +159,4 @@ function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix) return A, B end -function MatrixAlgebraKit.truncate( - ::typeof(left_null!), US::Tuple{TU, TS}, strategy::TruncationStrategy - ) where {TU <: ROCMatrix, TS} - # TODO: avoid allocation? - U, S = US - extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2)))) - ind = MatrixAlgebraKit.findtruncated(extended_S, strategy) - trunc_cols = collect(1:size(U, 2))[ind] - Utrunc = U[:, trunc_cols] - return Utrunc, ind -end -function MatrixAlgebraKit.truncate( - ::typeof(right_null!), SVᴴ::Tuple{TS, TVᴴ}, strategy::TruncationStrategy - ) where {TS, TVᴴ <: ROCMatrix} - # TODO: avoid allocation? 
- S, Vᴴ = SVᴴ - extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1)))) - ind = MatrixAlgebraKit.findtruncated(extended_S, strategy) - trunc_rows = collect(1:size(Vᴴ, 1))[ind] - Vᴴtrunc = Vᴴ[trunc_rows, :] - return Vᴴtrunc, ind -end - -# disambiguate: -function MatrixAlgebraKit.truncate( - ::typeof(left_null!), (U, S)::Tuple{TU, TS}, ::NoTruncation - ) where {TU <: ROCMatrix, TS} - m, n = size(S) - ind = (n + 1):m - return U[:, ind], ind -end -function MatrixAlgebraKit.truncate( - ::typeof(right_null!), (S, Vᴴ)::Tuple{TS, TVᴴ}, ::NoTruncation - ) where {TS, TVᴴ <: ROCMatrix} - m, n = size(S) - ind = (m + 1):n - return Vᴴ[ind, :], ind -end - end diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index a3c9d1f5..c3937b27 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -270,10 +270,12 @@ function _gpu_unmqr!( end function _gpu_qr!( - A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive = false, blocksize = 1 + A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; pivoted = false, positive = false, blocksize = 1 ) blocksize > 1 && throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition")) + pivoted && + throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a pivoted implementation for a QR decomposition")) m, n = size(A) minmn = min(m, n) computeR = length(R) > 0 @@ -309,10 +311,12 @@ function _gpu_qr!( end function _gpu_qr_null!( - A::AbstractMatrix, N::AbstractMatrix; positive = false, blocksize = 1 + A::AbstractMatrix, N::AbstractMatrix; positive = false, blocksize = 1, pivoted = false ) blocksize > 1 && throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition")) + pivoted && + throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a pivoted implementation for a QR decomposition")) m, n = size(A) minmn = min(m, n) fill!(N, zero(eltype(N))) diff --git a/src/implementations/schur.jl 
b/src/implementations/schur.jl index 3cad9e0f..605f8176 100644 --- a/src/implementations/schur.jl +++ b/src/implementations/schur.jl @@ -26,6 +26,29 @@ function check_input(::typeof(schur_vals!), A::AbstractMatrix, vals, ::AbstractA return nothing end +function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + T, Z, vals = TZv + @assert vals isa AbstractVector && Z isa Diagonal + @check_scalar(T, A) + @check_size(Z, (m, m)) + @check_scalar(Z, A) + @check_size(vals, (n,)) + # Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable + @check_scalar(vals, A) + return nothing +end +function check_input(::typeof(schur_vals!), A::AbstractMatrix, vals, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert vals isa AbstractVector + @check_size(vals, (n,)) + # Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable + @check_scalar(vals, A) + return nothing +end + # Outputs # ------- function initialize_output(::typeof(schur_full!), A::AbstractMatrix, ::AbstractAlgorithm) @@ -39,6 +62,17 @@ function initialize_output(::typeof(schur_vals!), A::AbstractMatrix, ::AbstractA vals = similar(A, complex(eltype(A)), n) return vals end +function initialize_output(::typeof(schur_full!), A::Diagonal, ::DiagonalAlgorithm) + n = size(A, 1) + Z = similar(A) + vals = similar(A, eltype(A), n) + return (A, Z, vals) +end +function initialize_output(::typeof(schur_vals!), A::Diagonal, ::DiagonalAlgorithm) + n = size(A, 1) + vals = similar(A, eltype(A), n) + return vals +end # Implementation # -------------- @@ -72,3 +106,20 @@ function schur_vals!(A::AbstractMatrix, vals, alg::LAPACK_EigAlgorithm) end return vals end + +# Diagonal logic +# -------------- +function schur_full!(A::Diagonal, (T, Z, vals)::Tuple{Diagonal, Diagonal, <:AbstractVector}, alg::DiagonalAlgorithm) + check_input(schur_full!, A, (T, Z, vals), alg) + 
copy!(vals, diagview(A)) + one!(Z) + T === A || copy!(T, A) + return T, Z, vals +end + +function schur_vals!(A::Diagonal, vals::AbstractVector, alg::DiagonalAlgorithm) + check_input(schur_vals!, A, vals, alg) + Ad = diagview(A) + vals === Ad || copy!(vals, Ad) + return vals +end diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 126e6a04..007b88af 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -152,6 +152,12 @@ end function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ + if length(A) == 0 + one!(U) + zero!(S) + one!(Vᴴ) + return USVᴴ + end do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) @@ -398,6 +404,12 @@ end function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ + if length(A) == 0 + one!(U) + zero!(S) + one!(Vᴴ) + return USVᴴ + end do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) @@ -422,6 +434,10 @@ _largest(x, y) = abs(x) < abs(y) ? y : x function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) check_input(svd_vals!, A, S, alg) + if length(A) == 0 + zero!(S) + return S + end U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) diff --git a/test/amd/eigh.jl b/test/amd/eigh.jl deleted file mode 100644 index 4d23f128..00000000 --- a/test/amd/eigh.jl +++ /dev/null @@ -1,105 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: TruncatedAlgorithm, diagview -using AMDGPU - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "eigh_full! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in ( - ROCSOLVER_DivideAndConquer(), - ROCSOLVER_Jacobi(), - ROCSOLVER_Bisection(), - ROCSOLVER_QRIteration(), - ) - A = ROCArray(randn(rng, T, m, m)) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A; alg) - @test A * V ≈ V * D - @test isunitary(V) - @test all(isreal, D) - - D2, V2 = eigh_full!(copy(A), (D, V), alg) - @test D2 === D - @test V2 === V - - D3 = @constinferred eigh_vals(A, alg) - @test parent(D) ≈ D3 - end -end - -#=@testset "eigh_trunc! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in (ROCSOLVER_QRIteration(), - ROCSOLVER_DivideAndConquer(), - ) - A = ROCArray(randn(rng, T, m, m)) - A = A * A' - A = (A + A') / 2 - Ac = similar(A) - D₀ = reverse(eigh_vals(A)) - r = m - 2 - s = 1 + sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - - trunc = trunctol(; atol=s * D₀[r + 1]) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - - # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - end -end - -@testset "eigh_trunc! 
specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 4 - V = qr_compact(ROCArray(randn(rng, T, m, m)))[1] - D = Diagonal([0.9, 0.3, 0.1, 0.01]) - A = V * D * V' - A = (A + A') / 2 - alg = TruncatedAlgorithm(ROCSOLVER_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) - @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) -end=# - -@testset "eigh for Diagonal{$T}" for T in BLASFloats - rng = StableRNG(123) - m = 54 - Ad = randn(rng, T, m) - Ad .+= conj.(Ad) - A = Diagonal(ROCArray(Ad)) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eigh_full(A) - @test D isa Diagonal{real(T)} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eigh_vals(A) - @test D2 isa AbstractVector{real(T)} && length(D2) == m - @test diagview(D) ≈ D2 - - # TODO partialsortperm - #=A2 = Diagonal(ROCArray(T[0.9, 0.3, 0.1, 0.01])) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol=# -end diff --git a/test/amd/lq.jl b/test/amd/lq.jl deleted file mode 100644 index ef46bbe9..00000000 --- a/test/amd/lq.jl +++ /dev/null @@ -1,163 +0,0 @@ -using MatrixAlgebraKit -using MatrixAlgebraKit: diagview -using Test -using TestExtras -using StableRNGs -using AMDGPU -using LinearAlgebra - -include(joinpath("..", "utilities.jl")) - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "lq_compact! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = ROCArray(randn(rng, T, m, n)) - L, Q = @constinferred lq_compact(A) - @test L isa ROCMatrix{T} && size(L) == (m, minmn) - @test Q isa ROCMatrix{T} && size(Q) == (minmn, n) - @test L * Q ≈ A - @test isapproxone(Q * Q') - Nᴴ = @constinferred lq_null(A) - @test Nᴴ isa ROCMatrix{T} && size(Nᴴ) == (n - minmn, n) - @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) - @test isapproxone(Nᴴ * Nᴴ') - - Ac = similar(A) - L2, Q2 = @constinferred lq_compact!(copy!(Ac, A), (L, Q)) - @test L2 === L - @test Q2 === Q - Nᴴ2 = @constinferred lq_null!(copy!(Ac, A), Nᴴ) - @test Nᴴ2 === Nᴴ - - # noL - noL = similar(A, 0, minmn) - Q2 = similar(Q) - lq_compact!(copy!(Ac, A), (noL, Q2)) - @test Q == Q2 - - # positive - lq_compact!(copy!(Ac, A), (L, Q); positive = true) - @test L * Q ≈ A - @test isapproxone(Q * Q') - @test all(>=(zero(real(T))), real(diagview(L))) - lq_compact!(copy!(Ac, A), (noL, Q2); positive = true) - @test Q == Q2 - - # explicit blocksize - lq_compact!(copy!(Ac, A), (L, Q); blocksize = 1) - @test L * Q ≈ A - @test isapproxone(Q * Q') - lq_compact!(copy!(Ac, A), (noL, Q2); blocksize = 1) - @test Q == Q2 - lq_null!(copy!(Ac, A), Nᴴ; blocksize = 1) - @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) - @test isapproxone(Nᴴ * Nᴴ') - if m <= n - lq_compact!(copy!(Q2, A), (noL, Q2); blocksize = 1) # in-place Q - @test Q ≈ Q2 - # these do not work because of the in-place Q - @test_throws ArgumentError lq_compact!(copy!(Q2, A), (L, Q2); blocksize = 1) - @test_throws ArgumentError lq_compact!(copy!(Q2, A), (noL, Q2); positive = true) - end - # no blocked CUDA - @test_throws ArgumentError lq_compact!(copy!(Q2, A), (L, Q2); blocksize = 8) - end -end - -@testset "lq_full! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = ROCArray(randn(rng, T, m, n)) - L, Q = lq_full(A) - @test L isa ROCMatrix{T} && size(L) == (m, n) - @test Q isa ROCMatrix{T} && size(Q) == (n, n) - @test L * Q ≈ A - @test isapproxone(Q * Q') - - Ac = similar(A) - L2, Q2 = @constinferred lq_full!(copy!(Ac, A), (L, Q)) - @test L2 === L - @test Q2 === Q - @test L * Q ≈ A - @test isapproxone(Q * Q') - - # noL - noL = similar(A, 0, n) - Q2 = similar(Q) - lq_full!(copy!(Ac, A), (noL, Q2)) - @test Q == Q2 - - # positive - lq_full!(copy!(Ac, A), (L, Q); positive = true) - @test L * Q ≈ A - @test isapproxone(Q * Q') - @test all(>=(zero(real(T))), real(diagview(L))) - lq_full!(copy!(Ac, A), (noL, Q2); positive = true) - @test Q == Q2 - - # explicit blocksize - lq_full!(copy!(Ac, A), (L, Q); blocksize = 1) - @test L * Q ≈ A - @test isapproxone(Q * Q') - lq_full!(copy!(Ac, A), (noL, Q2); blocksize = 1) - @test Q == Q2 - if n == m - lq_full!(copy!(Q2, A), (noL, Q2); blocksize = 1) # in-place Q - @test Q ≈ Q2 - # these do not work because of the in-place Q - @test_throws ArgumentError lq_full!(copy!(Q2, A), (L, Q2); blocksize = 1) - @test_throws ArgumentError lq_full!(copy!(Q2, A), (noL, Q2); positive = true) - end - # no blocked CUDA - @test_throws ArgumentError lq_full!(copy!(Ac, A), (L, Q); blocksize = 8) - end -end - -@testset "lq_compact, lq_full and lq_null for Diagonal{$T}" for T in BLASFloats - rng = StableRNG(123) - atol = eps(real(T))^(3 / 4) - for m in (54, 0) - Ad = ROCArray(randn(rng, T, m)) - A = Diagonal(Ad) - - # compact - L, Q = @constinferred lq_compact(A) - @test Q isa Diagonal{T} && size(Q) == (m, m) - @test L isa Diagonal{T} && size(L) == (m, m) - @test L * Q ≈ A - @test isunitary(Q) - - # compact and positive - Lp, Qp = @constinferred lq_compact(A; positive = true) - @test Qp isa Diagonal{T} && size(Qp) == (m, m) - @test Lp isa Diagonal{T} && size(Lp) == (m, m) - @test Lp * Qp ≈ A - @test 
isunitary(Qp) - @test all(isposdef.(diagview(Lp))) - - # full - L, Q = @constinferred lq_full(A) - @test Q isa Diagonal{T} && size(Q) == (m, m) - @test L isa Diagonal{T} && size(L) == (m, m) - @test L * Q ≈ A - @test isunitary(Q) - - # full and positive - Lp, Qp = @constinferred lq_full(A; positive = true) - @test Qp isa Diagonal{T} && size(Qp) == (m, m) - @test Lp isa Diagonal{T} && size(Lp) == (m, m) - @test Lp * Qp ≈ A - @test isunitary(Qp) - @test all(isposdef.(diagview(Lp))) - - # null - N = @constinferred lq_null(A) - @test N isa AbstractMatrix{T} && size(N) == (0, m) - end -end diff --git a/test/amd/orthnull.jl b/test/amd/orthnull.jl deleted file mode 100644 index 3223979c..00000000 --- a/test/amd/orthnull.jl +++ /dev/null @@ -1,285 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, I, mul!, diagm, norm -using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_algorithm, - initialize_output, AbstractAlgorithm -using AMDGPU - -# testing non-AbstractArray codepaths: -include(joinpath("..", "linearmap.jl")) - -eltypes = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "left_orth and left_null for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - @testset for n in (37, m, 63) - minmn = min(m, n) - A = ROCArray(randn(rng, T, m, n)) - V, C = @constinferred left_orth(A) - N = @constinferred left_null(A) - @test V isa ROCMatrix{T} && size(V) == (m, minmn) - @test C isa ROCMatrix{T} && size(C) == (minmn, n) - @test N isa ROCMatrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - hV = collect(V) - hN = collect(N) - @test hV * hV' + hN * hN' ≈ I - - M = LinearMap(A) - VM, CM = @constinferred left_orth(M; alg = :svd) - @test parent(VM) * parent(CM) ≈ A - - if m > n - nullity = 5 - V, C = @constinferred left_orth(A) - AMDGPU.@allowscalar begin - 
N = @constinferred left_null(A; trunc = (; maxnullity = nullity)) - end - @test V isa ROCMatrix{T} && size(V) == (m, minmn) - @test C isa ROCMatrix{T} && size(C) == (minmn, n) - @test N isa ROCMatrix{T} && size(N) == (m, nullity) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - end - - # passing a kind and some kwargs - V, C = @constinferred left_orth(A; alg = :qr, positive = true) - N = @constinferred left_null(A; alg = :qr, positive = true) - @test V isa ROCMatrix{T} && size(V) == (m, minmn) - @test C isa ROCMatrix{T} && size(C) == (minmn, n) - @test N isa ROCMatrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - hV = collect(V) - hN = collect(N) - @test hV * hV' + hN * hN' ≈ I - - # passing an algorithm - V, C = @constinferred left_orth(A; alg = CUSOLVER_HouseholderQR()) - N = @constinferred left_null(A; alg = :qr, positive = true) - @test V isa ROCMatrix{T} && size(V) == (m, minmn) - @test C isa ROCMatrix{T} && size(C) == (minmn, n) - @test N isa ROCMatrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - hV = collect(V) - hN = collect(N) - @test hV * hV' + hN * hN' ≈ I - - Ac = similar(A) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C)) - N2 = @constinferred left_null!(copy!(Ac, A), N) - @test V2 === V - @test C2 === C - @test N2 === N - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - hV2 = collect(V2) - hN2 = collect(N2) - @test hV2 * hV2' + hN2 * hN2' ≈ I - - atol = eps(real(T)) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = (; atol = atol)) - AMDGPU.@allowscalar begin - N2 = @constinferred 
left_null!(copy!(Ac, A), N; trunc = (; atol = atol)) - end - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - hV2 = collect(V2) - hN2 = collect(N2) - @test hV2 * hV2' + hN2 * hN2' ≈ I - - rtol = eps(real(T)) - for (trunc_orth, trunc_null) in ( - ((; rtol = rtol), (; rtol = rtol)), - (trunctol(; rtol), trunctol(; rtol, keep_below = true)), - ) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = trunc_orth) - AMDGPU.@allowscalar begin - N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = trunc_null) - end - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - hV2 = collect(V2) - hN2 = collect(N2) - @test hV2 * hV2' + hN2 * hN2' ≈ I - end - - @testset for alg in (:qr, :polar, :svd) # explicit alg kwarg - m < n && alg == :polar && continue - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg))) - @test V2 * C2 ≈ A - @test isisometric(V2) - if alg != :polar - N2 = @constinferred left_null!(copy!(Ac, A), N; alg) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - hV2 = collect(V2) - hN2 = collect(N2) - @test hV2 * hV2' + hN2 * hN2' ≈ I - end - - # with alg and tol kwargs - if alg == :svd - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; atol)) - AMDGPU.@allowscalar begin - N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) - end - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - hV2 = collect(V2) - hN2 = collect(N2) - @test hV2 * hV2' + hN2 * hN2' ≈ I - - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; rtol)) - AMDGPU.@allowscalar begin - N2 = @constinferred left_null!(copy!(Ac, A), 
N; alg, trunc = (; rtol)) - end - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - hV2 = collect(V2) - hN2 = collect(N2) - @test hV2 * hV2' + hN2 * hN2' ≈ I - else - @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; atol)) - @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; rtol)) - alg == :polar && continue - @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) - @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) - end - end - end -end - -@testset "right_orth and right_null for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - @testset for n in (37, m, 63) - minmn = min(m, n) - A = ROCArray(randn(rng, T, m, n)) - C, Vᴴ = @constinferred right_orth(A) - Nᴴ = @constinferred right_null(A) - @test C isa ROCMatrix{T} && size(C) == (m, minmn) - @test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (minmn, n) - @test Nᴴ isa ROCMatrix{T} && size(Nᴴ) == (n - minmn, n) - @test C * Vᴴ ≈ A - @test isisometric(Vᴴ; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ; side = :right) - hVᴴ = collect(Vᴴ) - hNᴴ = collect(Nᴴ) - @test hVᴴ' * hVᴴ + hNᴴ' * hNᴴ ≈ I - - M = LinearMap(A) - CM, VMᴴ = @constinferred right_orth(M; alg = :svd) - @test parent(CM) * parent(VMᴴ) ≈ A - - Ac = similar(A) - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ; side = :right) - hVᴴ2 = collect(Vᴴ2) - hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - - atol = eps(real(T)) - rtol = eps(real(T)) - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol = atol)) - AMDGPU.@allowscalar 
begin - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol = atol)) - end - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ; side = :right) - hVᴴ2 = collect(Vᴴ2) - hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol = rtol)) - AMDGPU.@allowscalar begin - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol = rtol)) - end - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ2; side = :right) - hVᴴ2 = collect(Vᴴ2) - hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - - - @testset "alg = $alg" for alg in (:lq, :polar, :svd) - n < m && alg == :polar && continue - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg))) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - if alg != :polar - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg))) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ2; side = :right) - hVᴴ2 = collect(Vᴴ2) - hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - end - - if alg == :svd - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; atol)) - AMDGPU.@allowscalar begin - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; atol)) - end - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ2; side = :right) - hVᴴ2 = collect(Vᴴ2) - hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = 
$(QuoteNode(alg)), trunc = (; rtol)) - AMDGPU.@allowscalar begin - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; rtol)) - end - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ2; side = :right) - hVᴴ2 = collect(Vᴴ2) - hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - else - @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; atol)) - @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; rtol)) - alg == :polar && continue - @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; atol)) - @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; rtol)) - end - end - - end -end diff --git a/test/amd/polar.jl b/test/amd/polar.jl deleted file mode 100644 index 4040b674..00000000 --- a/test/amd/polar.jl +++ /dev/null @@ -1,83 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, I, isposdef, Hermitian -using MatrixAlgebraKit: PolarViaSVD -using AMDGPU - -@testset "left_polar! 
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m) - k = min(m, n) - svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) - @testset "algorithm $svd_alg" for svd_alg in svd_algs - A = ROCArray(randn(rng, T, m, n)) - alg = PolarViaSVD(svd_alg) - W, P = left_polar(A; alg) - @test W isa ROCMatrix{T} && size(W) == (m, n) - @test P isa ROCMatrix{T} && size(P) == (n, n) - @test W * P ≈ A - @test isisometric(W) - # work around extremely strict Julia criteria for Hermiticity - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P)) - - Ac = similar(A) - W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, P), alg) - @test W2 === W - @test P2 === P - @test W * P ≈ A - @test isisometric(W) - # work around extremely strict Julia criteria for Hermiticity - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P)) - - noP = similar(P, (0, 0)) - W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, noP), alg) - @test P2 === noP - @test W2 === W - @test isisometric(W) - P = W' * A # compute P explicitly to verify W correctness - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(Hermitian(project_hermitian!(P))) - end - end -end - -@testset "right_polar! 
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) - rng = StableRNG(123) - n = 54 - @testset "size ($m, $n)" for m in (37, n) - k = min(m, n) - svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) - @testset "algorithm $svd_alg" for svd_alg in svd_algs - A = ROCArray(randn(rng, T, m, n)) - alg = PolarViaSVD(svd_alg) - P, Wᴴ = right_polar(A; alg) - @test Wᴴ isa ROCMatrix{T} && size(Wᴴ) == (m, n) - @test P isa ROCMatrix{T} && size(P) == (m, m) - @test P * Wᴴ ≈ A - @test isisometric(Wᴴ; side = :right) - # work around extremely strict Julia criteria for Hermiticity - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P)) - - Ac = similar(A) - P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (P, Wᴴ), alg) - @test P2 === P - @test Wᴴ2 === Wᴴ - @test P * Wᴴ ≈ A - @test isisometric(Wᴴ; side = :right) - # work around extremely strict Julia criteria for Hermiticity - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P)) - - noP = similar(P, (0, 0)) - P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg) - @test P2 === noP - @test Wᴴ2 === Wᴴ - @test isisometric(Wᴴ; side = :right) - P = A * Wᴴ' # compute P explicitly to verify W correctness - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(Hermitian(project_hermitian!(P))) - end - end -end diff --git a/test/amd/projections.jl b/test/amd/projections.jl deleted file mode 100644 index a06b152c..00000000 --- a/test/amd/projections.jl +++ /dev/null @@ -1,104 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, norm -using AMDGPU - -const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "project_(anti)hermitian! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - noisefactor = eps(real(T))^(3 / 4) - for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64)) - for A in (ROCArray(randn(rng, T, m, m)), Diagonal(ROCArray(randn(rng, T, m)))) - Ah = (A + A') / 2 - Aa = (A - A') / 2 - Ac = copy(A) - - Bh = project_hermitian(A, alg) - @test ishermitian(Bh) - @test Bh ≈ Ah - @test A == Ac - Bh_approx = Bh + noisefactor * Aa - # this is still hermitian for real Diagonal: |A - A'| == 0 - @test !ishermitian(Bh_approx) || norm(Aa) == 0 - @test ishermitian(Bh_approx; rtol = 10 * noisefactor) - - Ba = project_antihermitian(A, alg) - @test isantihermitian(Ba) - @test Ba ≈ Aa - @test A == Ac - Ba_approx = Ba + noisefactor * Ah - @test !isantihermitian(Ba_approx) - # this is never anti-hermitian for real Diagonal: |A - A'| == 0 - @test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0 - - Bh = project_hermitian!(Ac, alg) - @test Bh === Ac - @test ishermitian(Bh) - @test Bh ≈ Ah - - copy!(Ac, A) - Ba = project_antihermitian!(Ac, alg) - @test Ba === Ac - @test isantihermitian(Ba) - @test Ba ≈ Aa - end - end -end - -@testset "project_isometric! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m) - k = min(m, n) - svdalgs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) - algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO - @testset "algorithm $alg" for alg in algs - A = ROCArray(randn(rng, T, m, n)) - W = project_isometric(A, alg) - @test isisometric(W) - W2 = project_isometric(W, alg) - @test W2 ≈ W # stability of the projection - @test W * (W' * A) ≈ A - - Ac = similar(A) - W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg) - @test W2 === W - @test isisometric(W) - - # test that W is closer to A then any other isometry - for k in 1:10 - δA = ROCArray(randn(rng, T, size(A)...)) - W = project_isometric(A, alg) - W2 = project_isometric(A + δA / 100, alg) - @test norm(A - W2) >= norm(A - W) - end - end - - m == n && @testset "DiagonalAlgorithm" begin - A = Diagonal(ROCArray(randn(rng, T, m))) - alg = PolarViaSVD(DiagonalAlgorithm()) - W = project_isometric(A, alg) - @test isisometric(W) - W2 = project_isometric(W, alg) - @test W2 ≈ W # stability of the projection - @test W * (W' * A) ≈ A - - Ac = similar(A) - W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg) - @test W2 === W - @test isisometric(W) - - # test that W is closer to A then any other isometry - for k in 1:10 - δA = Diagonal(ROCArray(randn(rng, T, m))) - W = project_isometric(A, alg) - W2 = project_isometric(A + δA / 100, alg) - @test norm(A - W2) >= norm(A - W) - end - end - end -end diff --git a/test/amd/qr.jl b/test/amd/qr.jl deleted file mode 100644 index b0708ae1..00000000 --- a/test/amd/qr.jl +++ /dev/null @@ -1,167 +0,0 @@ -using MatrixAlgebraKit -using MatrixAlgebraKit: diagview -using Test -using TestExtras -using StableRNGs -using AMDGPU -using LinearAlgebra - -include(joinpath("..", "utilities.jl")) - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "qr_compact! and qr_null! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = ROCArray(randn(rng, T, m, n)) - Q, R = @constinferred qr_compact(A) - @test Q isa ROCMatrix{T} && size(Q) == (m, minmn) - @test R isa ROCMatrix{T} && size(R) == (minmn, n) - @test Q * R ≈ A - N = @constinferred qr_null(A) - @test N isa ROCMatrix{T} && size(N) == (m, m - minmn) - @test isapproxone(Q' * Q) - @test maximum(abs, A' * N) < eps(real(T))^(2 / 3) - @test isapproxone(N' * N) - - Ac = similar(A) - Q2, R2 = @constinferred qr_compact!(copy!(Ac, A), (Q, R)) - @test Q2 === Q - @test R2 === R - N2 = @constinferred qr_null!(copy!(Ac, A), N) - @test N2 === N - - # noR - Q2 = similar(Q) - noR = similar(A, minmn, 0) - qr_compact!(copy!(Ac, A), (Q2, noR)) - @test Q == Q2 - - # positive - qr_compact!(copy!(Ac, A), (Q, R); positive = true) - @test Q * R ≈ A - @test isapproxone(Q' * Q) - @test all(>=(zero(real(T))), real(diagview(R))) - qr_compact!(copy!(Ac, A), (Q2, noR); positive = true) - @test Q == Q2 - - # explicit blocksize - qr_compact!(copy!(Ac, A), (Q, R); blocksize = 1) - @test Q * R ≈ A - @test isapproxone(Q' * Q) - qr_compact!(copy!(Ac, A), (Q2, noR); blocksize = 1) - @test Q == Q2 - qr_compact!(copy!(Ac, A), (Q2, noR); blocksize = 1) - qr_null!(copy!(Ac, A), N; blocksize = 1) - @test maximum(abs, A' * N) < eps(real(T))^(2 / 3) - @test isapproxone(N' * N) - if n <= m - qr_compact!(copy!(Q2, A), (Q2, noR); blocksize = 1) # in-place Q - @test Q ≈ Q2 - # these do not work because of the in-place Q - @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, R2)) - @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, noR); positive = true) - end - # no blocked CUDA - @test_throws ArgumentError qr_compact!(copy!(Ac, A), (Q2, R); blocksize = 8) - end -end - -@testset "qr_full! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 63 - for n in (37, m, 63) - minmn = min(m, n) - A = ROCArray(randn(rng, T, m, n)) - Q, R = qr_full(A) - @test Q isa ROCMatrix{T} && size(Q) == (m, m) - @test R isa ROCMatrix{T} && size(R) == (m, n) - @test Q * R ≈ A - @test isapproxone(Q' * Q) - - Ac = similar(A) - Q2 = similar(Q) - noR = similar(A, m, 0) - Q2, R2 = @constinferred qr_full!(copy!(Ac, A), (Q, R)) - @test Q2 === Q - @test R2 === R - @test Q * R ≈ A - @test isapproxone(Q' * Q) - qr_full!(copy!(Ac, A), (Q2, noR)) - @test Q == Q2 - - # noR - noR = similar(A, m, 0) - Q2 = similar(Q) - qr_full!(copy!(Ac, A), (Q2, noR)) - @test Q == Q2 - - # positive - qr_full!(copy!(Ac, A), (Q, R); positive = true) - @test Q * R ≈ A - @test isapproxone(Q' * Q) - @test all(>=(zero(real(T))), real(diagview(R))) - qr_full!(copy!(Ac, A), (Q2, noR); positive = true) - @test Q == Q2 - - # explicit blocksize - qr_full!(copy!(Ac, A), (Q, R); blocksize = 1) - @test Q * R ≈ A - @test isapproxone(Q' * Q) - qr_full!(copy!(Ac, A), (Q2, noR); blocksize = 1) - @test Q == Q2 - if n == m - qr_full!(copy!(Q2, A), (Q2, noR); blocksize = 1) # in-place Q - @test Q ≈ Q2 - @test_throws ArgumentError qr_full!(copy!(Q2, A), (Q2, R2)) - @test_throws ArgumentError qr_full!(copy!(Q2, A), (Q2, noR); positive = true) - end - # no blocked CUDA - @test_throws ArgumentError qr_full!(copy!(Ac, A), (Q, R); blocksize = 8) - end -end - -@testset "qr_compact, qr_full and qr_null for Diagonal{$T}" for T in BLASFloats - rng = StableRNG(123) - atol = eps(real(T))^(3 / 4) - for m in (54, 0) - Ad = ROCArray(randn(rng, T, m)) - A = Diagonal(Ad) - - # compact - Q, R = @constinferred qr_compact(A) - @test Q isa Diagonal{T} && size(Q) == (m, m) - @test R isa Diagonal{T} && size(R) == (m, m) - @test Q * R ≈ A - @test isunitary(Q) - - # compact and positive - Qp, Rp = @constinferred qr_compact(A; positive = true) - @test Qp isa Diagonal{T} && size(Qp) == (m, m) - @test Rp isa Diagonal{T} && size(Rp) == (m, m) - 
@test Qp * Rp ≈ A - @test isunitary(Qp) - @test all(isposdef.(diagview(Rp))) - - # full - Q, R = @constinferred qr_full(A) - @test Q isa Diagonal{T} && size(Q) == (m, m) - @test R isa Diagonal{T} && size(R) == (m, m) - @test Q * R ≈ A - @test isunitary(Q) - - # full and positive - Qp, Rp = @constinferred qr_full(A; positive = true) - @test Qp isa Diagonal{T} && size(Qp) == (m, m) - @test Rp isa Diagonal{T} && size(Rp) == (m, m) - @test Qp * Rp ≈ A - @test isunitary(Qp) - @test all(isposdef.(diagview(Rp))) - - # null - N = @constinferred qr_null(A) - @test N isa AbstractMatrix{T} && size(N) == (m, 0) - end -end diff --git a/test/amd/svd.jl b/test/amd/svd.jl deleted file mode 100644 index fcd5b490..00000000 --- a/test/amd/svd.jl +++ /dev/null @@ -1,157 +0,0 @@ -using MatrixAlgebraKit -using MatrixAlgebraKit: diagview -using LinearAlgebra: Diagonal, isposdef -using Test -using TestExtras -using StableRNGs -using AMDGPU - -include(joinpath("..", "utilities.jl")) - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "svd_compact! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63) - k = min(m, n) - algs(::ROCArray) = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) - algs(::Diagonal) = (DiagonalAlgorithm(),) - As = m == n ? 
(ROCArray(randn(rng, T, m, n)), Diagonal(ROCArray(randn(rng, T, m)))) : (ROCArray(randn(rng, T, m, n)),) - for A in As - @testset "algorithm $alg" for alg in algs(A) - minmn = min(m, n) - - U, S, Vᴴ = svd_compact(A; alg) - @test U isa ROCMatrix{T} && size(U) == (m, minmn) - @test S isa Diagonal{real(T), <:ROCVector} && size(S) == (minmn, minmn) - @test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (minmn, n) - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(Vᴴ * Vᴴ') - @test isposdef(S) - - Ac = similar(A) - U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg) - @test U2 === U - @test S2 === S - @test V2ᴴ === Vᴴ - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(Vᴴ * Vᴴ') - @test isposdef(S) - - Sd = svd_vals(A, alg) - @test ROCArray(diagview(S)) ≈ Sd - # ROCArray is necessary because norm of ROCArray view with non-unit step is broken - if alg isa ROCSOLVER_QRIteration - @test_warn "invalid keyword arguments for GPU_QRIteration" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad")) - end - end - end - end -end - -@testset "svd_full! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - algs(::ROCArray) = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) - algs(::Diagonal) = (DiagonalAlgorithm(),) - @testset "size ($m, $n)" for n in (37, m, 63) - As = m == n ? 
(ROCArray(randn(rng, T, m, n)), Diagonal(ROCArray(randn(rng, T, m)))) : (ROCArray(randn(rng, T, m, n)),) - for A in As - @testset "algorithm $alg" for alg in algs(A) - U, S, Vᴴ = svd_full(A; alg) - @test U isa ROCMatrix{T} && size(U) == (m, m) - if A isa Diagonal - @test S isa Diagonal{real(T), <:ROCVector{real(T)}} && size(S) == (m, n) - else - @test S isa ROCMatrix{real(T)} && size(S) == (m, n) - end - @test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (n, n) - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(U * U') - @test isapproxone(Vᴴ * Vᴴ') - @test isapproxone(Vᴴ' * Vᴴ) - @test all(isposdef, diagview(S)) - - Ac = similar(A) - U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) - @test U2 === U - @test S2 === S - @test V2ᴴ === Vᴴ - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(U * U') - @test isapproxone(Vᴴ * Vᴴ') - @test isapproxone(Vᴴ' * Vᴴ) - @test all(isposdef, diagview(S)) - - Sc = similar(A, real(T), min(m, n)) - Sc2 = svd_vals!(copy!(Ac, A), Sc, alg) - @test Sc === Sc2 - @test ROCArray(diagview(S)) ≈ Sc - # ROCArray is necessary because norm of ROCArray view with non-unit step is broken - if alg isa ROCSOLVER_QRIteration - @test_warn "invalid keyword arguments for GPU_QRIteration" svd_full!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad")) - @test_warn "invalid keyword arguments for GPU_QRIteration" svd_vals!(copy!(Ac, A), Sc, ROCSOLVER_QRIteration(; bad = "bad")) - end - end - end - end - @testset "size (0, 0)" begin - for A in (ROCArray(randn(rng, T, 0, 0)), Diagonal(ROCArray(randn(rng, T, 0)))) - @testset "algorithm $alg" for alg in algs(A) - U, S, Vᴴ = svd_full(A; alg) - @test U isa ROCMatrix{T} && size(U) == (0, 0) - if isa(A, Diagonal) - @test S isa Diagonal{real(T), <:ROCVector{real(T)}} - else - @test S isa ROCMatrix{real(T)} - end - @test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (0, 0) - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(U * U') - @test 
isapproxone(Vᴴ * Vᴴ') - @test isapproxone(Vᴴ' * Vᴴ) - @test all(isposdef, diagview(S)) - end - end - end -end - -# @testset "svd_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) -# rng = StableRNG(123) -# m = 54 -# if LinearAlgebra.LAPACK.version() < v"3.12.0" -# algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) -# else -# algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), -# LAPACK_Jacobi()) -# end -# -# @testset "size ($m, $n)" for n in (37, m, 63) -# @testset "algorithm $alg" for alg in algs -# n > m && alg isa LAPACK_Jacobi && continue # not supported -# A = randn(rng, T, m, n) -# S₀ = svd_vals(A) -# minmn = min(m, n) -# r = minmn - 2 -# -# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) -# @test length(S1.diag) == r -# @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] -# -# s = 1 + sqrt(eps(real(T))) -# trunc2 = trunctol(; atol=s * S₀[r + 1]) -# -# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) -# @test length(S2.diag) == r -# @test U1 ≈ U2 -# @test S1 ≈ S2 -# @test V1ᴴ ≈ V2ᴴ -# end -# end -# end diff --git a/test/cuda/eig.jl b/test/cuda/eig.jl deleted file mode 100644 index e40ee9b8..00000000 --- a/test/cuda/eig.jl +++ /dev/null @@ -1,108 +0,0 @@ -using MatrixAlgebraKit -using LinearAlgebra: Diagonal -using Test -using TestExtras -using StableRNGs -using CUDA -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm - -include(joinpath("..", "utilities.jl")) - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "eig_full! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in (CUSOLVER_Simple(), :CUSOLVER_Simple, CUSOLVER_Simple) - A = CuArray(randn(rng, T, m, m)) - Tc = complex(T) - - D, V = @constinferred eig_full(A; alg = ($alg)) - @test eltype(D) == eltype(V) == Tc - @test A * V ≈ V * D - - alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, $alg) - - Ac = similar(A) - D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′) - @test D2 === D - @test V2 === V - @test A * V ≈ V * D - - Dc = @constinferred eig_vals(A, alg′) - @test eltype(Dc) == Tc - @test parent(D) ≈ Dc - end -end - -#= -@testset "eig_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) - rng = StableRNG(123) - m = 54 - for alg in (CUSOLVER_Simple(),) - A = CuArray(randn(rng, T, m, m)) - A *= A' # TODO: deal with eigenvalue ordering etc - # eigenvalues are sorted by ascending real component... - D₀ = sort!(eig_vals(A); by=abs, rev=true) - rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) - r = length(D₀) - rmin - - D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc=truncrank(r)) - @test length(D1.diag) == r - @test A * V1 ≈ V1 * D1 - - s = 1 + sqrt(eps(real(T))) - trunc = trunctol(; atol=s * abs(D₀[r + 1])) - D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test A * V2 ≈ V2 * D2 - - # trunctol keeps order, truncrank might not - # test for same subspace - @test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2 - @test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1 - end -end - -@testset "eig_trunc! 
specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 4 - atol = sqrt(eps(real(T))) - V = randn(rng, T, m, m) - D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - A = V * D * inv(V) - alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) - - alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1)) - D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol -end -=# - -@testset "eig for Diagonal{$T}" for T in BLASFloats - rng = StableRNG(123) - m = 54 - Ad = CuArray(randn(rng, T, m)) - A = Diagonal(Ad) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eig_full(A) - @test D isa Diagonal{T} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eig_vals(A) - @test D2 isa AbstractVector{T} && length(D2) == m - @test diagview(D) ≈ D2 - - #=A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol=# -end diff --git a/test/cuda/eigh.jl b/test/cuda/eigh.jl deleted file mode 100644 index a8171615..00000000 --- a/test/cuda/eigh.jl +++ /dev/null @@ -1,118 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: TruncatedAlgorithm, diagview -using CUDA - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "eigh_full! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in (CUSOLVER_DivideAndConquer(), CUSOLVER_Jacobi()) - A = CuArray(randn(rng, T, m, m)) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A; alg) - @test A * V ≈ V * D - @test isunitary(V) - @test all(isreal, D) - - D2, V2 = eigh_full!(copy(A), (D, V), alg) - @test D2 === D - @test V2 === V - - D3 = @constinferred eigh_vals(A, alg) - @test parent(D) ≈ D3 - end -end -#= #TODO mul! -@testset "eigh_trunc! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in (CUSOLVER_QRIteration(), - CUSOLVER_DivideAndConquer(), - ) - A = CuArray(randn(rng, T, m, m)) - A = A * A' - A = (A + A') / 2 - Ac = similar(A) - D₀ = reverse(eigh_vals(A)) - r = m - 2 - s = 1 + sqrt(eps(real(T))) - atol = sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - @test V1 * (V1' * V3) ≈ V3 - @test V3 * (V3' * V1) ≈ V1 - end -end - -@testset "eigh_trunc! 
specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 4 - V = qr_compact(CuArray(randn(rng, T, m, m)))[1] - D = Diagonal([0.9, 0.3, 0.1, 0.01]) - A = V * D * V' - A = (A + A') / 2 - alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) - @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - - alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncerror(; atol = 0.2)) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol -end -=# -@testset "eigh for Diagonal{$T}" for T in BLASFloats - rng = StableRNG(123) - m = 54 - Ad = randn(rng, T, m) - Ad .+= conj.(Ad) - A = Diagonal(CuArray(Ad)) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eigh_full(A) - @test D isa Diagonal{real(T)} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eigh_vals(A) - @test D2 isa AbstractVector{real(T)} && length(D2) == m - @test diagview(D) ≈ D2 - - # TODO partialsortperm - #=A2 = Diagonal(CuArray(T[0.9, 0.3, 0.1, 0.01])) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol=# -end diff --git a/test/cuda/lq.jl b/test/cuda/lq.jl deleted file mode 100644 index de904c88..00000000 --- a/test/cuda/lq.jl +++ /dev/null @@ -1,163 +0,0 @@ -using MatrixAlgebraKit -using MatrixAlgebraKit: diagview -using Test -using TestExtras -using StableRNGs -using CUDA -using LinearAlgebra - -include(joinpath("..", "utilities.jl")) - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "lq_compact! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = CuArray(randn(rng, T, m, n)) - L, Q = @constinferred lq_compact(A) - @test L isa CuMatrix{T} && size(L) == (m, minmn) - @test Q isa CuMatrix{T} && size(Q) == (minmn, n) - @test L * Q ≈ A - @test isapproxone(Q * Q') - Nᴴ = @constinferred lq_null(A) - @test Nᴴ isa CuMatrix{T} && size(Nᴴ) == (n - minmn, n) - @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) - @test isapproxone(Nᴴ * Nᴴ') - - Ac = similar(A) - L2, Q2 = @constinferred lq_compact!(copy!(Ac, A), (L, Q)) - @test L2 === L - @test Q2 === Q - Nᴴ2 = @constinferred lq_null!(copy!(Ac, A), Nᴴ) - @test Nᴴ2 === Nᴴ - - # noL - noL = similar(A, 0, minmn) - Q2 = similar(Q) - lq_compact!(copy!(Ac, A), (noL, Q2)) - @test Q == Q2 - - # positive - lq_compact!(copy!(Ac, A), (L, Q); positive = true) - @test L * Q ≈ A - @test isapproxone(Q * Q') - @test all(>=(zero(real(T))), real(diagview(L))) - lq_compact!(copy!(Ac, A), (noL, Q2); positive = true) - @test Q == Q2 - - # explicit blocksize - lq_compact!(copy!(Ac, A), (L, Q); blocksize = 1) - @test L * Q ≈ A - @test isapproxone(Q * Q') - lq_compact!(copy!(Ac, A), (noL, Q2); blocksize = 1) - @test Q == Q2 - lq_null!(copy!(Ac, A), Nᴴ; blocksize = 1) - @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) - @test isapproxone(Nᴴ * Nᴴ') - if m <= n - lq_compact!(copy!(Q2, A), (noL, Q2); blocksize = 1) # in-place Q - @test Q ≈ Q2 - # these do not work because of the in-place Q - @test_throws ArgumentError lq_compact!(copy!(Q2, A), (L, Q2); blocksize = 1) - @test_throws ArgumentError lq_compact!(copy!(Q2, A), (noL, Q2); positive = true) - end - # no blocked CUDA - @test_throws ArgumentError lq_compact!(copy!(Q2, A), (L, Q2); blocksize = 8) - end -end - -@testset "lq_full! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = CuArray(randn(rng, T, m, n)) - L, Q = lq_full(A) - @test L isa CuMatrix{T} && size(L) == (m, n) - @test Q isa CuMatrix{T} && size(Q) == (n, n) - @test L * Q ≈ A - @test isapproxone(Q * Q') - - Ac = similar(A) - L2, Q2 = @constinferred lq_full!(copy!(Ac, A), (L, Q)) - @test L2 === L - @test Q2 === Q - @test L * Q ≈ A - @test isapproxone(Q * Q') - - # noL - noL = similar(A, 0, n) - Q2 = similar(Q) - lq_full!(copy!(Ac, A), (noL, Q2)) - @test Q == Q2 - - # positive - lq_full!(copy!(Ac, A), (L, Q); positive = true) - @test L * Q ≈ A - @test isapproxone(Q * Q') - @test all(>=(zero(real(T))), real(diagview(L))) - lq_full!(copy!(Ac, A), (noL, Q2); positive = true) - @test Q == Q2 - - # explicit blocksize - lq_full!(copy!(Ac, A), (L, Q); blocksize = 1) - @test L * Q ≈ A - @test isapproxone(Q * Q') - lq_full!(copy!(Ac, A), (noL, Q2); blocksize = 1) - @test Q == Q2 - if n == m - lq_full!(copy!(Q2, A), (noL, Q2); blocksize = 1) # in-place Q - @test Q ≈ Q2 - # these do not work because of the in-place Q - @test_throws ArgumentError lq_full!(copy!(Q2, A), (L, Q2); blocksize = 1) - @test_throws ArgumentError lq_full!(copy!(Q2, A), (noL, Q2); positive = true) - end - # no blocked CUDA - @test_throws ArgumentError lq_full!(copy!(Ac, A), (L, Q); blocksize = 8) - end -end - -@testset "lq_compact, lq_full and lq_null for Diagonal{$T}" for T in BLASFloats - rng = StableRNG(123) - atol = eps(real(T))^(3 / 4) - for m in (54, 0) - Ad = CuArray(randn(rng, T, m)) - A = Diagonal(Ad) - - # compact - L, Q = @constinferred lq_compact(A) - @test Q isa Diagonal{T} && size(Q) == (m, m) - @test L isa Diagonal{T} && size(L) == (m, m) - @test L * Q ≈ A - @test isunitary(Q) - - # compact and positive - Lp, Qp = @constinferred lq_compact(A; positive = true) - @test Qp isa Diagonal{T} && size(Qp) == (m, m) - @test Lp isa Diagonal{T} && size(Lp) == (m, m) - @test Lp * Qp ≈ A - @test 
isunitary(Qp) - @test all(isposdef.(diagview(Lp))) - - # full - L, Q = @constinferred lq_full(A) - @test Q isa Diagonal{T} && size(Q) == (m, m) - @test L isa Diagonal{T} && size(L) == (m, m) - @test L * Q ≈ A - @test isunitary(Q) - - # full and positive - Lp, Qp = @constinferred lq_full(A; positive = true) - @test Qp isa Diagonal{T} && size(Qp) == (m, m) - @test Lp isa Diagonal{T} && size(Lp) == (m, m) - @test Lp * Qp ≈ A - @test isunitary(Qp) - @test all(isposdef.(diagview(Lp))) - - # null - N = @constinferred lq_null(A) - @test N isa AbstractMatrix{T} && size(N) == (0, m) - end -end diff --git a/test/cuda/orthnull.jl b/test/cuda/orthnull.jl deleted file mode 100644 index 2a2a26f6..00000000 --- a/test/cuda/orthnull.jl +++ /dev/null @@ -1,264 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, I, mul!, diagm, norm -using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_algorithm, - initialize_output, AbstractAlgorithm -using CUDA - -# testing non-AbstractArray codepaths: -include(joinpath("..", "linearmap.jl")) - -eltypes = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "left_orth and left_null for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - @testset for n in (37, m, 63) - minmn = min(m, n) - A = CuArray(randn(rng, T, m, n)) - V, C = @constinferred left_orth(A) - N = @constinferred left_null(A) - @test V isa CuMatrix{T} && size(V) == (m, minmn) - @test C isa CuMatrix{T} && size(C) == (minmn, n) - @test N isa CuMatrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - hV = collect(V) - hN = collect(N) - @test hV * hV' + hN * hN' ≈ I - - M = LinearMap(A) - VM, CM = @constinferred left_orth(M; alg = :svd) - @test parent(VM) * parent(CM) ≈ A - - if m > n - nullity = 5 - V, C = @constinferred left_orth(A) - CUDA.@allowscalar begin - N = @constinferred 
left_null(A; trunc = (; maxnullity = nullity)) - end - @test V isa CuMatrix{T} && size(V) == (m, minmn) - @test C isa CuMatrix{T} && size(C) == (minmn, n) - @test N isa CuMatrix{T} && size(N) == (m, nullity) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - end - - # passing a kind and some kwargs - V, C = @constinferred left_orth(A; alg = :qr, positive = true) - N = @constinferred left_null(A; alg = :qr, positive = true) - @test V isa CuMatrix{T} && size(V) == (m, minmn) - @test C isa CuMatrix{T} && size(C) == (minmn, n) - @test N isa CuMatrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - hV = collect(V) - hN = collect(N) - @test hV * hV' + hN * hN' ≈ I - - # passing an algorithm - V, C = @constinferred left_orth(A; alg = CUSOLVER_HouseholderQR()) - N = @constinferred left_null(A; alg = :qr, positive = true) - @test V isa CuMatrix{T} && size(V) == (m, minmn) - @test C isa CuMatrix{T} && size(C) == (minmn, n) - @test N isa CuMatrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - hV = collect(V) - hN = collect(N) - @test hV * hV' + hN * hN' ≈ I - - Ac = similar(A) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C)) - N2 = @constinferred left_null!(copy!(Ac, A), N) - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - hV2 = collect(V2) - hN2 = collect(N2) - @test hV2 * hV2' + hN2 * hN2' ≈ I - - atol = eps(real(T)) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = (; atol = atol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = (; atol = atol)) - @test V2 * C2 ≈ A - @test isisometric(V2) - @test 
LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - hV2 = collect(V2) - hN2 = collect(N2) - @test hV2 * hV2' + hN2 * hN2' ≈ I - - rtol = eps(real(T)) - for (trunc_orth, trunc_null) in ( - ((; rtol = rtol), (; rtol = rtol)), - (trunctol(; rtol), trunctol(; rtol, keep_below = true)), - ) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = trunc_orth) - N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = trunc_null) - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - hV2 = collect(V2) - hN2 = collect(N2) - @test hV2 * hV2' + hN2 * hN2' ≈ I - end - - @testset for alg in (:qr, :polar, :svd) # explicit alg kwarg - m < n && alg == :polar && continue - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg))) - @test V2 * C2 ≈ A - @test isisometric(V2) - if alg != :polar - N2 = @constinferred left_null!(copy!(Ac, A), N; alg) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - hV2 = collect(V2) - hN2 = collect(N2) - @test hV2 * hV2' + hN2 * hN2' ≈ I - end - - # with alg and tol kwargs - if alg == :svd - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; atol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - hV2 = collect(V2) - hN2 = collect(N2) - @test hV2 * hV2' + hN2 * hN2' ≈ I - - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; rtol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - hV2 = collect(V2) - hN2 = 
collect(N2) - @test hV2 * hV2' + hN2 * hN2' ≈ I - else - @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; atol)) - @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; rtol)) - alg == :polar && continue - @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) - @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) - end - end - end -end - -@testset "right_orth and right_null for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - @testset for n in (37, m, 63) - minmn = min(m, n) - A = CuArray(randn(rng, T, m, n)) - C, Vᴴ = @constinferred right_orth(A) - Nᴴ = @constinferred right_null(A) - @test C isa CuMatrix{T} && size(C) == (m, minmn) - @test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (minmn, n) - @test Nᴴ isa CuMatrix{T} && size(Nᴴ) == (n - minmn, n) - @test C * Vᴴ ≈ A - @test isisometric(Vᴴ; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ; side = :right) - hVᴴ = collect(Vᴴ) - hNᴴ = collect(Nᴴ) - @test hVᴴ' * hVᴴ + hNᴴ' * hNᴴ ≈ I - - M = LinearMap(A) - CM, VMᴴ = @constinferred right_orth(M; alg = :svd) - @test parent(CM) * parent(VMᴴ) ≈ A - - Ac = similar(A) - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ; side = :right) - hVᴴ2 = collect(Vᴴ2) - hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - - atol = eps(real(T)) - rtol = eps(real(T)) - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol = atol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol = atol)) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - 
@test isisometric(Nᴴ; side = :right) - hVᴴ2 = collect(Vᴴ2) - hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol = rtol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol = rtol)) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ2; side = :right) - hVᴴ2 = collect(Vᴴ2) - hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - - @testset "alg = $alg" for alg in (:lq, :polar, :svd) - n < m && alg == :polar && continue - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg))) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - if alg != :polar - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg))) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ2; side = :right) - hVᴴ2 = collect(Vᴴ2) - hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - end - - if alg == :svd - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; atol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; atol)) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ2; side = :right) - hVᴴ2 = collect(Vᴴ2) - hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; rtol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; rtol)) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ2; side = :right) 
- hVᴴ2 = collect(Vᴴ2) - hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - else - @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; atol)) - @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; rtol)) - alg == :polar && continue - @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; atol)) - @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; rtol)) - end - end - end -end diff --git a/test/cuda/polar.jl b/test/cuda/polar.jl deleted file mode 100644 index 1f512367..00000000 --- a/test/cuda/polar.jl +++ /dev/null @@ -1,83 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, I, isposdef, Hermitian -using MatrixAlgebraKit: PolarViaSVD -using CUDA - -@testset "left_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m) - k = min(m, n) - svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi()) - @testset "algorithm $svd_alg" for svd_alg in svd_algs - A = CuArray(randn(rng, T, m, n)) - alg = PolarViaSVD(svd_alg) - W, P = left_polar(A; alg) - @test W isa CuMatrix{T} && size(W) == (m, n) - @test P isa CuMatrix{T} && size(P) == (n, n) - @test W * P ≈ A - @test isisometric(W) - # work around extremely strict Julia criteria for Hermiticity - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P)) - - Ac = similar(A) - W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, P), alg) - @test W2 === W - @test P2 === P - @test W * P ≈ A - @test isisometric(W) - # work around extremely strict Julia criteria for Hermiticity - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P)) - - noP = similar(P, (0, 0)) - W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, noP), alg) - @test P2 === noP - @test W2 === W - @test isisometric(W) - P = W' * A # compute P explicitly to 
verify W correctness - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(Hermitian(project_hermitian!(P))) - end - end -end - -@testset "right_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) - rng = StableRNG(123) - n = 54 - @testset "size ($m, $n)" for m in (37, n) - k = min(m, n) - svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi()) - @testset "algorithm $svd_alg" for svd_alg in svd_algs - A = CuArray(randn(rng, T, m, n)) - alg = PolarViaSVD(svd_alg) - P, Wᴴ = right_polar(A; alg) - @test Wᴴ isa CuMatrix{T} && size(Wᴴ) == (m, n) - @test P isa CuMatrix{T} && size(P) == (m, m) - @test P * Wᴴ ≈ A - @test isisometric(Wᴴ; side = :right) - # work around extremely strict Julia criteria for Hermiticity - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P)) - - Ac = similar(A) - P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (P, Wᴴ), alg) - @test P2 === P - @test Wᴴ2 === Wᴴ - @test P * Wᴴ ≈ A - @test isisometric(Wᴴ; side = :right) - # work around extremely strict Julia criteria for Hermiticity - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P)) - - noP = similar(P, (0, 0)) - P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg) - @test P2 === noP - @test Wᴴ2 === Wᴴ - @test isisometric(Wᴴ; side = :right) - P = A * Wᴴ' # compute P explicitly to verify W correctness - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(Hermitian(project_hermitian!(P))) - end - end -end diff --git a/test/cuda/projections.jl b/test/cuda/projections.jl deleted file mode 100644 index 677ef520..00000000 --- a/test/cuda/projections.jl +++ /dev/null @@ -1,104 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, norm -using CUDA - -const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "project_(anti)hermitian! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - noisefactor = eps(real(T))^(3 / 4) - for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64)) - for A in (CuArray(randn(rng, T, m, m)), Diagonal(CuArray(randn(rng, T, m)))) - Ah = (A + A') / 2 - Aa = (A - A') / 2 - Ac = copy(A) - - Bh = project_hermitian(A, alg) - @test ishermitian(Bh) - @test Bh ≈ Ah - @test A == Ac - Bh_approx = Bh + noisefactor * Aa - # this is still hermitian for real Diagonal: |A - A'| == 0 - @test !ishermitian(Bh_approx) || norm(Aa) == 0 - @test ishermitian(Bh_approx; rtol = 10 * noisefactor) - - Ba = project_antihermitian(A, alg) - @test isantihermitian(Ba) - @test Ba ≈ Aa - @test A == Ac - Ba_approx = Ba + noisefactor * Ah - @test !isantihermitian(Ba_approx) - # this is never anti-hermitian for real Diagonal: |A - A'| == 0 - @test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0 - - Bh = project_hermitian!(Ac, alg) - @test Bh === Ac - @test ishermitian(Bh) - @test Bh ≈ Ah - - copy!(Ac, A) - Ba = project_antihermitian!(Ac, alg) - @test Ba === Ac - @test isantihermitian(Ba) - @test Ba ≈ Aa - end - end -end - -@testset "project_isometric! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m) - k = min(m, n) - svdalgs = (CUSOLVER_SVDPolar(), CUSOLVER_QRIteration(), CUSOLVER_Jacobi()) - algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO - @testset "algorithm $alg" for alg in algs - A = CuArray(randn(rng, T, m, n)) - W = project_isometric(A, alg) - @test isisometric(W) - W2 = project_isometric(W, alg) - @test W2 ≈ W # stability of the projection - @test W * (W' * A) ≈ A - - Ac = similar(A) - W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg) - @test W2 === W - @test isisometric(W) - - # test that W is closer to A then any other isometry - for k in 1:10 - δA = CuArray(randn(rng, T, size(A)...)) - W = project_isometric(A, alg) - W2 = project_isometric(A + δA / 100, alg) - @test norm(A - W2) >= norm(A - W) - end - end - - m == n && @testset "DiagonalAlgorithm" begin - A = Diagonal(CuArray(randn(rng, T, m))) - alg = PolarViaSVD(DiagonalAlgorithm()) - W = project_isometric(A, alg) - @test isisometric(W) - W2 = project_isometric(W, alg) - @test W2 ≈ W # stability of the projection - @test W * (W' * A) ≈ A - - Ac = similar(A) - W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg) - @test W2 === W - @test isisometric(W) - - # test that W is closer to A then any other isometry - for k in 1:10 - δA = Diagonal(CuArray(randn(rng, T, m))) - W = project_isometric(A, alg) - W2 = project_isometric(A + δA / 100, alg) - @test norm(A - W2) >= norm(A - W) - end - end - end -end diff --git a/test/cuda/qr.jl b/test/cuda/qr.jl deleted file mode 100644 index 73fa3985..00000000 --- a/test/cuda/qr.jl +++ /dev/null @@ -1,168 +0,0 @@ -using MatrixAlgebraKit -using MatrixAlgebraKit: diagview -using Test -using TestExtras -using StableRNGs -using CUDA -using LinearAlgebra -using LinearAlgebra: isposdef - -include(joinpath("..", "utilities.jl")) - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "qr_compact! and qr_null! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = CuArray(randn(rng, T, m, n)) - Q, R = @constinferred qr_compact(A) - @test Q isa CuMatrix{T} && size(Q) == (m, minmn) - @test R isa CuMatrix{T} && size(R) == (minmn, n) - @test Q * R ≈ A - N = @constinferred qr_null(A) - @test N isa CuMatrix{T} && size(N) == (m, m - minmn) - @test isapproxone(Q' * Q) - @test maximum(abs, A' * N) < eps(real(T))^(2 / 3) - @test isapproxone(N' * N) - - Ac = similar(A) - Q2, R2 = @constinferred qr_compact!(copy!(Ac, A), (Q, R)) - @test Q2 === Q - @test R2 === R - N2 = @constinferred qr_null!(copy!(Ac, A), N) - @test N2 === N - - # noR - Q2 = similar(Q) - noR = similar(A, minmn, 0) - qr_compact!(copy!(Ac, A), (Q2, noR)) - @test Q == Q2 - - # positive - qr_compact!(copy!(Ac, A), (Q, R); positive = true) - @test Q * R ≈ A - @test isapproxone(Q' * Q) - @test all(>=(zero(real(T))), real(diagview(R))) - qr_compact!(copy!(Ac, A), (Q2, noR); positive = true) - @test Q == Q2 - - # explicit blocksize - qr_compact!(copy!(Ac, A), (Q, R); blocksize = 1) - @test Q * R ≈ A - @test isapproxone(Q' * Q) - qr_compact!(copy!(Ac, A), (Q2, noR); blocksize = 1) - @test Q == Q2 - qr_compact!(copy!(Ac, A), (Q2, noR); blocksize = 1) - qr_null!(copy!(Ac, A), N; blocksize = 1) - @test maximum(abs, A' * N) < eps(real(T))^(2 / 3) - @test isapproxone(N' * N) - if n <= m - qr_compact!(copy!(Q2, A), (Q2, noR); blocksize = 1) # in-place Q - @test Q ≈ Q2 - # these do not work because of the in-place Q - @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, R2)) - @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, noR); positive = true) - end - # no blocked CUDA - @test_throws ArgumentError qr_compact!(copy!(Ac, A), (Q2, R); blocksize = 8) - end -end - -@testset "qr_full! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 63 - for n in (37, m, 63) - minmn = min(m, n) - A = CuArray(randn(rng, T, m, n)) - Q, R = qr_full(A) - @test Q isa CuMatrix{T} && size(Q) == (m, m) - @test R isa CuMatrix{T} && size(R) == (m, n) - @test Q * R ≈ A - @test isapproxone(Q' * Q) - - Ac = similar(A) - Q2 = similar(Q) - noR = similar(A, m, 0) - Q2, R2 = @constinferred qr_full!(copy!(Ac, A), (Q, R)) - @test Q2 === Q - @test R2 === R - @test Q * R ≈ A - @test isapproxone(Q' * Q) - qr_full!(copy!(Ac, A), (Q2, noR)) - @test Q == Q2 - - # noR - noR = similar(A, m, 0) - Q2 = similar(Q) - qr_full!(copy!(Ac, A), (Q2, noR)) - @test Q == Q2 - - # positive - qr_full!(copy!(Ac, A), (Q, R); positive = true) - @test Q * R ≈ A - @test isapproxone(Q' * Q) - @test all(>=(zero(real(T))), real(diagview(R))) - qr_full!(copy!(Ac, A), (Q2, noR); positive = true) - @test Q == Q2 - - # explicit blocksize - qr_full!(copy!(Ac, A), (Q, R); blocksize = 1) - @test Q * R ≈ A - @test isapproxone(Q' * Q) - qr_full!(copy!(Ac, A), (Q2, noR); blocksize = 1) - @test Q == Q2 - if n == m - qr_full!(copy!(Q2, A), (Q2, noR); blocksize = 1) # in-place Q - @test Q ≈ Q2 - @test_throws ArgumentError qr_full!(copy!(Q2, A), (Q2, R2)) - @test_throws ArgumentError qr_full!(copy!(Q2, A), (Q2, noR); positive = true) - end - # no blocked CUDA - @test_throws ArgumentError qr_full!(copy!(Ac, A), (Q, R); blocksize = 8) - end -end - -@testset "qr_compact, qr_full and qr_null for Diagonal{$T}" for T in BLASFloats - rng = StableRNG(123) - atol = eps(real(T))^(3 / 4) - for m in (54, 0) - Ad = CuArray(randn(rng, T, m)) - A = Diagonal(Ad) - - # compact - Q, R = @constinferred qr_compact(A) - @test Q isa Diagonal{T} && size(Q) == (m, m) - @test R isa Diagonal{T} && size(R) == (m, m) - @test Q * R ≈ A - @test isunitary(Q) - - # compact and positive - Qp, Rp = @constinferred qr_compact(A; positive = true) - @test Qp isa Diagonal{T} && size(Qp) == (m, m) - @test Rp isa Diagonal{T} && size(Rp) == (m, m) - 
@test Qp * Rp ≈ A - @test isunitary(Qp) - @test all(isposdef.(diagview(Rp))) - - # full - Q, R = @constinferred qr_full(A) - @test Q isa Diagonal{T} && size(Q) == (m, m) - @test R isa Diagonal{T} && size(R) == (m, m) - @test Q * R ≈ A - @test isunitary(Q) - - # full and positive - Qp, Rp = @constinferred qr_full(A; positive = true) - @test Qp isa Diagonal{T} && size(Qp) == (m, m) - @test Rp isa Diagonal{T} && size(Rp) == (m, m) - @test Qp * Rp ≈ A - @test isunitary(Qp) - @test all(isposdef.(diagview(Rp))) - - # null - N = @constinferred qr_null(A) - @test N isa AbstractMatrix{T} && size(N) == (m, 0) - end -end diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl deleted file mode 100644 index fc564fec..00000000 --- a/test/cuda/svd.jl +++ /dev/null @@ -1,161 +0,0 @@ -using MatrixAlgebraKit -using MatrixAlgebraKit: diagview -using LinearAlgebra: Diagonal, isposdef, norm, opnorm -using Test -using TestExtras -using StableRNGs -using CUDA - -include(joinpath("..", "utilities.jl")) - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "svd_compact! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63) - k = min(m, n) - algs(::CuArray) = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi()) - algs(::Diagonal) = (DiagonalAlgorithm(),) - As = m == n ? 
(CuArray(randn(rng, T, m, n)), Diagonal(CuArray(randn(rng, T, m)))) : (CuArray(randn(rng, T, m, n)),) - for A in As - @testset "algorithm $alg" for alg in algs(A) - minmn = min(m, n) - U, S, Vᴴ = svd_compact(A; alg) - @test U isa CuMatrix{T} && size(U) == (m, minmn) - @test S isa Diagonal{real(T), <:CuVector} && size(S) == (minmn, minmn) - @test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (minmn, n) - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(Vᴴ * Vᴴ') - @test isposdef(S) - - Ac = similar(A) - U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg) - @test U2 === U - @test S2 === S - @test V2ᴴ === Vᴴ - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(Vᴴ * Vᴴ') - @test isposdef(S) - - Sd = svd_vals(A, alg) - @test CuArray(diagview(S)) ≈ Sd - # CuArray is necessary because norm of CuArray view with non-unit step is broken - if alg isa CUSOLVER_QRIteration - @test_warn "invalid keyword arguments for GPU_QRIteration" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad")) - end - end - end - end -end - -@testset "svd_full! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - algs(::CuArray) = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi()) - algs(::Diagonal) = (DiagonalAlgorithm(),) - @testset "size ($m, $n)" for n in (37, m, 63) - As = m == n ? 
(CuArray(randn(rng, T, m, n)), Diagonal(CuArray(randn(rng, T, m)))) : (CuArray(randn(rng, T, m, n)),) - for A in As - @testset "algorithm $alg" for alg in algs(A) - minmn = min(m, n) - U, S, Vᴴ = svd_full(A; alg) - @test U isa CuMatrix{T} && size(U) == (m, m) - if A isa Diagonal - @test S isa Diagonal{real(T), <:CuVector{real(T)}} && size(S) == (m, n) - else - @test S isa CuMatrix{real(T)} && size(S) == (m, n) - end - @test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (n, n) - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(U * U') - @test isapproxone(Vᴴ * Vᴴ') - @test isapproxone(Vᴴ' * Vᴴ) - @test all(isposdef, diagview(S)) - - Ac = similar(A) - U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) - @test U2 === U - @test S2 === S - @test V2ᴴ === Vᴴ - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(U * U') - @test isapproxone(Vᴴ * Vᴴ') - @test isapproxone(Vᴴ' * Vᴴ) - @test all(isposdef, diagview(S)) - - minmn = min(m, n) - Sc = similar(A, real(T), minmn) - Sc2 = svd_vals!(copy!(Ac, A), Sc, alg) - @test Sc === Sc2 - @test CuArray(diagview(S)) ≈ Sc - # CuArray is necessary because norm of CuArray view with non-unit step is broken - if alg isa CUSOLVER_QRIteration - @test_warn "invalid keyword arguments for GPU_QRIteration" svd_full!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad")) - @test_warn "invalid keyword arguments for GPU_QRIteration" svd_vals!(copy!(Ac, A), Sc, CUSOLVER_QRIteration(; bad = "bad")) - end - end - end - end - @testset "size (0, 0)" begin - for A in (CuArray(randn(rng, T, 0, 0)), Diagonal(CuArray(randn(rng, T, 0)))) - @testset "algorithm $alg" for alg in algs(A) - U, S, Vᴴ = svd_full(A; alg) - @test U isa CuMatrix{T} && size(U) == (0, 0) - @test size(S) == (0, 0) - if isa(A, Diagonal) - @test S isa Diagonal{real(T), <:CuVector{real(T)}} - else - @test S isa CuMatrix{real(T)} - end - @test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (0, 0) - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - 
@test isapproxone(U * U') - @test isapproxone(Vᴴ * Vᴴ') - @test isapproxone(Vᴴ' * Vᴴ) - @test all(isposdef, diagview(S)) - end - end - end -end - -@testset "svd_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63) - k = min(m, n) - 20 - p = min(m, n) - k - 1 - algs(::CuArray) = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi(), CUSOLVER_Randomized(; k = k, p = p, niters = 100)) - algs(::Diagonal) = (DiagonalAlgorithm(),) - hAs = m == n ? (randn(rng, T, m, n), Diagonal(randn(rng, T, m))) : (randn(rng, T, m, n),) - minmn = min(m, n) - for hA in hAs - A = CuArray(hA) - @testset "algorithm $alg" for alg in algs(A) - S₀ = svd_vals(hA) - r = k - - U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) - @test length(S1.diag) == r - @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] - @test norm(A - U1 * S1 * V1ᴴ) ≈ ϵ1 - - if !(alg isa CUSOLVER_Randomized) - s = 1 + sqrt(eps(real(T))) - trunc2 = trunctol(; atol = s * S₀[r + 1]) - - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) - @test length(S2.diag) == r - @test U1 ≈ U2 - @test parent(S1) ≈ parent(S2) - @test V1ᴴ ≈ V2ᴴ - end - end - end - end -end diff --git a/test/eig.jl b/test/eig.jl index 6da6d72c..5558be02 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -4,113 +4,36 @@ using TestExtras using StableRNGs using LinearAlgebra: Diagonal using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm +using CUDA, AMDGPU BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) -GenericFloats = (Float16, BigFloat, Complex{BigFloat}) - -@testset "eig_full! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in (LAPACK_Simple(), LAPACK_Expert(), :LAPACK_Simple, LAPACK_Simple) - A = randn(rng, T, m, m) - Tc = complex(T) - - D, V = @constinferred eig_full(A; alg = ($alg)) - @test eltype(D) == eltype(V) == Tc - @test A * V ≈ V * D - - alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, $alg) - - Ac = similar(A) - D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′) - @test D2 === D - @test V2 === V - @test A * V ≈ V * D - - Dc = @constinferred eig_vals(A, alg′) - @test eltype(Dc) == Tc - @test D ≈ Diagonal(Dc) +GenericFloats = (Float16,) #BigFloat, Complex{BigFloat}) + +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 54 +for T in BLASFloats + TestSuite.seed_rng!(123) + if is_buildkite + if CUDA.functional() + TestSuite.test_eig(CuMatrix{T}, (m, m); test_blocksize = false) + TestSuite.test_eig(Diagonal{T, CuVector{T}}, m; test_blocksize = false) + end + #= not yet supported + if AMDGPU.functional() + TestSuite.test_eig(ROCMatrix{T}, (m, m); test_blocksize = false) + TestSuite.test_eig(Diagonal{T, ROCVector{T}}, m; test_blocksize = false) + end=# + else + TestSuite.test_eig(T, (m, m)) end end - -@testset "eig_trunc! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in (LAPACK_Simple(), LAPACK_Expert()) - A = randn(rng, T, m, m) - A *= A' # TODO: deal with eigenvalue ordering etc - # eigenvalues are sorted by ascending real component... 
- D₀ = sort!(eig_vals(A); by = abs, rev = true) - rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) - r = length(D₀) - rmin - atol = sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) - @test length(diagview(D1)) == r - @test A * V1 ≈ V1 * D1 - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 + sqrt(eps(real(T))) - trunc = trunctol(; atol = s * abs(D₀[r + 1])) - D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - # trunctol keeps order, truncrank might not - # test for same subspace - @test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2 - @test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1 - @test V1 * ((V1' * V1) \ (V1' * V3)) ≈ V3 - @test V3 * ((V3' * V3) \ (V3' * V1)) ≈ V1 +if !is_buildkite + for T in (BLASFloats..., GenericFloats...) + AT = Diagonal{T, Vector{T}} + TestSuite.test_eig(AT, m; test_blocksize = false) end end - -@testset "eig_trunc! 
specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 4 - atol = sqrt(eps(real(T))) - V = randn(rng, T, m, m) - D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - A = V * D * inv(V) - alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) - - alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1)) - D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol -end - -@testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) - rng = StableRNG(123) - m = 54 - Ad = randn(rng, T, m) - A = Diagonal(Ad) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eig_full(A) - @test D isa Diagonal{T} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eig_vals(A) - @test D2 isa AbstractVector{T} && length(D2) == m - @test diagview(D) ≈ D2 - - A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol -end diff --git a/test/eigh.jl b/test/eigh.jl index 92b0f3a0..830269da 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -3,124 +3,35 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm +using CUDA, AMDGPU BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) GenericFloats = (Float16, BigFloat, Complex{BigFloat}) -@testset "eigh_full! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in ( - LAPACK_MultipleRelativelyRobustRepresentations(), - LAPACK_DivideAndConquer(), - LAPACK_QRIteration(), - LAPACK_Bisection(), - ) - A = randn(rng, T, m, m) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A; alg) - @test A * V ≈ V * D - @test isunitary(V) - @test all(isreal, D) - - D2, V2 = eigh_full!(copy(A), (D, V), alg) - @test D2 === D - @test V2 === V - - D3 = @constinferred eigh_vals(A, alg) - @test D ≈ Diagonal(D3) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 54 +for T in BLASFloats + TestSuite.seed_rng!(123) + if is_buildkite + if CUDA.functional() + TestSuite.test_eigh(CuMatrix{T}, (m, m); test_blocksize = false) + TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m; test_blocksize = false) + end + if AMDGPU.functional() + TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_blocksize = false) + TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m; test_blocksize = false) + end + else + TestSuite.test_eigh(T, (m, m)) end end - -@testset "eigh_trunc! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in ( - LAPACK_QRIteration(), - LAPACK_Bisection(), - LAPACK_DivideAndConquer(), - LAPACK_MultipleRelativelyRobustRepresentations(), - ) - A = randn(rng, T, m, m) - A = A * A' - A = (A + A') / 2 - Ac = similar(A) - D₀ = reverse(eigh_vals(A)) - r = m - 2 - s = 1 + sqrt(eps(real(T))) - atol = sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - @test V1 * (V1' * V3) ≈ V3 - @test V3 * (V3' * V1) ≈ V1 +if !is_buildkite + for T in (BLASFloats..., GenericFloats...) + AT = Diagonal{T, Vector{T}} + TestSuite.test_eigh(AT, m; test_blocksize = false) end end - -@testset "eigh_trunc! 
specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 4 - atol = sqrt(eps(real(T))) - V = qr_compact(randn(rng, T, m, m))[1] - D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - A = V * D * V' - A = (A + A') / 2 - alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] - @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - - alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2)) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol -end - -@testset "eigh for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) - rng = StableRNG(123) - m = 54 - Ad = randn(rng, T, m) - Ad .+= conj.(Ad) - A = Diagonal(Ad) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eigh_full(A) - @test D isa Diagonal{real(T)} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eigh_vals(A) - @test D2 isa AbstractVector{real(T)} && length(D2) == m - @test diagview(D) ≈ D2 - - A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol -end diff --git a/test/genericlinearalgebra/eigh.jl b/test/genericlinearalgebra/eigh.jl deleted file mode 100644 index 7e602026..00000000 --- a/test/genericlinearalgebra/eigh.jl +++ /dev/null @@ -1,93 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm -using GenericLinearAlgebra - -const eltypes = (BigFloat, Complex{BigFloat}) - -@testset "eigh_full! 
for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - alg = GLA_QRIteration() - - A = randn(rng, T, m, m) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A; alg) - @test A * V ≈ V * D - @test isunitary(V) - @test all(isreal, D) - - D2, V2 = eigh_full!(copy(A), (D, V), alg) - @test D2 ≈ D - @test V2 ≈ V - - D3 = @constinferred eigh_vals(A, alg) - @test D ≈ Diagonal(D3) -end - -@testset "eigh_trunc! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - alg = GLA_QRIteration() - A = randn(rng, T, m, m) - A = A * A' - A = (A + A') / 2 - Ac = similar(A) - D₀ = reverse(eigh_vals(A)) - - r = m - 2 - s = 1 + sqrt(eps(real(T))) - atol = sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) - Dfull, Vfull = eigh_full(A; alg) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - @test V1 * (V1' * V3) ≈ V3 - @test V3 * (V3' * V1) ≈ V1 -end - -@testset "eigh_trunc! 
specify truncation algorithm T = $T" for T in eltypes - rng = StableRNG(123) - m = 4 - atol = sqrt(eps(real(T))) - V = qr_compact(randn(rng, T, m, m))[1] - D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - A = V * D * V' - A = (A + A') / 2 - alg = TruncatedAlgorithm(GLA_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] - @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - - alg = TruncatedAlgorithm(GLA_QRIteration(), truncerror(; atol = 0.2)) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol -end diff --git a/test/genericlinearalgebra/lq.jl b/test/genericlinearalgebra/lq.jl deleted file mode 100644 index dc186dcb..00000000 --- a/test/genericlinearalgebra/lq.jl +++ /dev/null @@ -1,124 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: diag, I, Diagonal -using GenericLinearAlgebra - -eltypes = (BigFloat, Complex{BigFloat}) - -@testset "qr_compact! 
for T = $T" for T in eltypes - - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - m = 54 - A = randn(rng, T, m, n) - L, Q = @constinferred lq_compact(A) - @test L isa Matrix{T} && size(L) == (m, minmn) - @test Q isa Matrix{T} && size(Q) == (minmn, n) - @test L * Q ≈ A - @test isisometric(Q; side = :right) - - Ac = similar(A) - L2, Q2 = @constinferred lq_compact!(copy!(Ac, A), (L, Q)) - @test L2 === L - @test Q2 === Q - - noL = similar(A, 0, minmn) - Q2 = similar(Q) - lq_compact!(copy!(Ac, A), (noL, Q2)) - @test Q == Q2 - - # Transposed QR algorithm - qr_alg = GLA_HouseholderQR() - lq_alg = LQViaTransposedQR(qr_alg) - L2, Q2 = @constinferred lq_compact!(copy!(Ac, A), (L, Q), lq_alg) - @test L2 === L - @test Q2 === Q - noL = similar(A, 0, minmn) - Q2 = similar(Q) - lq_compact!(copy!(Ac, A), (noL, Q2), lq_alg) - @test Q == Q2 - - @test_throws ArgumentError lq_compact(A; blocksize = 2) - @test_throws ArgumentError lq_compact(A; pivoted = true) - - # positive - lq_compact!(copy!(Ac, A), (L, Q); positive = true) - @test L * Q ≈ A - @test isisometric(Q; side = :right) - @test all(>=(zero(real(T))), real(diag(L))) - lq_compact!(copy!(Ac, A), (noL, Q2); positive = true) - @test Q == Q2 - end -end - -@testset "lq_full! 
for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = randn(rng, T, m, n) - L, Q = lq_full(A) - @test L isa Matrix{T} && size(L) == (m, n) - @test Q isa Matrix{T} && size(Q) == (n, n) - @test L * Q ≈ A - @test isunitary(Q) - - Ac = similar(A) - L2, Q2 = @constinferred lq_full!(copy!(Ac, A), (L, Q)) - @test L2 === L - @test Q2 === Q - @test L * Q ≈ A - @test isunitary(Q) - - noL = similar(A, 0, n) - Q2 = similar(Q) - lq_full!(copy!(Ac, A), (noL, Q2)) - @test Q[1:minmn, n] ≈ Q2[1:minmn, n] - - # Transposed QR algorithm - qr_alg = GLA_HouseholderQR() - lq_alg = LQViaTransposedQR(qr_alg) - L2, Q2 = @constinferred lq_full!(copy!(Ac, A), (L, Q), lq_alg) - @test L2 === L - @test Q2 === Q - @test L * Q ≈ A - @test Q * Q' ≈ I - noL = similar(A, 0, n) - Q2 = similar(Q) - lq_full!(copy!(Ac, A), (noL, Q2), lq_alg) - @test Q[1:minmn, n] ≈ Q2[1:minmn, n] - - # Argument errors for unsupported options - @test_throws ArgumentError lq_full(A; blocksize = 2) - @test_throws ArgumentError lq_full(A; pivoted = true) - - # positive - lq_full!(copy!(Ac, A), (L, Q); positive = true) - @test L * Q ≈ A - @test isunitary(Q) - @test all(>=(zero(real(T))), real(diag(L))) - lq_full!(copy!(Ac, A), (noL, Q2); positive = true) - @test Q[1:minmn, n] ≈ Q2[1:minmn, n] - - qr_alg = GLA_HouseholderQR(; positive = true) - lq_alg = LQViaTransposedQR(qr_alg) - lq_full!(copy!(Ac, A), (L, Q), lq_alg) - @test L * Q ≈ A - @test Q * Q' ≈ I - @test all(>=(zero(real(T))), real(diag(L))) - lq_full!(copy!(Ac, A), (noL, Q2), lq_alg) - @test Q[1:minmn, n] ≈ Q2[1:minmn, n] - - # positive and blocksize 1 - lq_full!(copy!(Ac, A), (L, Q); positive = true, blocksize = 1) - @test L * Q ≈ A - @test isunitary(Q) - @test all(>=(zero(real(T))), real(diag(L))) - lq_full!(copy!(Ac, A), (noL, Q2); positive = true, blocksize = 1) - @test Q[1:minmn, n] ≈ Q2[1:minmn, n] - end -end diff --git a/test/genericlinearalgebra/qr.jl b/test/genericlinearalgebra/qr.jl deleted file mode 
100644 index 3ce530bb..00000000 --- a/test/genericlinearalgebra/qr.jl +++ /dev/null @@ -1,109 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: diag, I, Diagonal -using GenericLinearAlgebra - -eltypes = (BigFloat, Complex{BigFloat}) - -@testset "qr_compact! for T = $T" for T in eltypes - - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - m = 54 - A = randn(rng, T, m, n) - Q, R = @constinferred qr_compact(A) - @test Q isa Matrix{T} && size(Q) == (m, minmn) - @test R isa Matrix{T} && size(R) == (minmn, n) - @test Q * R ≈ A - - Ac = similar(A) - Q2, R2 = @constinferred qr_compact!(copy!(Ac, A), (Q, R)) - @test Q2 === Q - @test R2 === R - - Q2 = similar(Q) - noR = similar(A, minmn, 0) - qr_compact!(copy!(Ac, A), (Q2, noR)) - @test Q == Q2 - - @test_throws ArgumentError qr_compact(A; blocksize = 2) - @test_throws ArgumentError qr_compact(A; pivoted = true) - - # positive - qr_compact!(copy!(Ac, A), (Q, R); positive = true) - @test Q * R ≈ A - @test isisometric(Q) - @test all(>=(zero(real(T))), real(diag(R))) - qr_compact!(copy!(Ac, A), (Q2, noR); positive = true) - @test Q == Q2 - end -end - -@testset "qr_full! 
for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = randn(rng, T, m, n) - Q, R = qr_full(A) - @test Q isa Matrix{T} && size(Q) == (m, m) - @test R isa Matrix{T} && size(R) == (m, n) - Qc, Rc = qr_compact(A) - @test Q * R ≈ A - @test isunitary(Q) - - Ac = similar(A) - Q2 = similar(Q) - noR = similar(A, m, 0) - Q2, R2 = @constinferred qr_full!(copy!(Ac, A), (Q, R)) - @test Q2 === Q - @test R2 === R - @test Q * R ≈ A - @test isunitary(Q) - qr_full!(copy!(Ac, A), (Q2, noR)) - @test Q == Q2 - - # unblocked algorithm - qr_full!(copy!(Ac, A), (Q, R); blocksize = 1) - @test Q * R ≈ A - @test isunitary(Q) - qr_full!(copy!(Ac, A), (Q2, noR); blocksize = 1) - @test Q == Q2 - if n == m - qr_full!(copy!(Q2, A), (Q2, noR); blocksize = 1) # in-place Q - @test Q ≈ Q2 - end - - # Argument errors for unsupported options - @test_throws ArgumentError qr_full(A; blocksize = 2) - @test_throws ArgumentError qr_compact(A; pivoted = true) - - # positive - qr_full!(copy!(Ac, A), (Q, R); positive = true) - @test Q * R ≈ A - @test isunitary(Q) - @test all(>=(zero(real(T))), real(diag(R))) - qr_full!(copy!(Ac, A), (Q2, noR); positive = true) - @test Q == Q2 - # positive and blocksize 1 - qr_full!(copy!(Ac, A), (Q, R); positive = true, blocksize = 1) - @test Q * R ≈ A - @test isunitary(Q) - @test all(>=(zero(real(T))), real(diag(R))) - qr_full!(copy!(Ac, A), (Q2, noR); positive = true, blocksize = 1) - @test Q == Q2 - if n <= m - # the following test tries to find the diagonal element (in order to test positivity) - # before the column permutation. 
This only works if all columns have a diagonal - # element - for j in 1:n - i = findlast(!iszero, view(R, :, j)) - @test real(R[i, j]) >= zero(real(T)) - end - end - end -end diff --git a/test/genericlinearalgebra/svd.jl b/test/genericlinearalgebra/svd.jl deleted file mode 100644 index f7177e79..00000000 --- a/test/genericlinearalgebra/svd.jl +++ /dev/null @@ -1,171 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef, norm -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, isisometric -using GenericLinearAlgebra - -eltypes = (BigFloat, Complex{BigFloat}) - -@testset "svd_compact! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63, 0) - k = min(m, n) - alg = GLA_QRIteration() - minmn = min(m, n) - A = randn(rng, T, m, n) - - if VERSION < v"1.11" - # This is type unstable on older versions of Julia. - U, S, Vᴴ = svd_compact(A; alg) - else - U, S, Vᴴ = @constinferred svd_compact(A; alg = ($alg)) - end - @test U isa Matrix{T} && size(U) == (m, minmn) - @test S isa Diagonal{real(T)} && size(S) == (minmn, minmn) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (minmn, n) - @test U * S * Vᴴ ≈ A - @test isisometric(U) - @test isisometric(Vᴴ; side = :right) - @test isposdef(S) - - Ac = similar(A) - Sc = similar(A, real(T), min(m, n)) - alg′ = @constinferred MatrixAlgebraKit.select_algorithm(svd_compact!, A, $alg) - U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg′) - @test U2 ≈ U - @test S2 ≈ S - @test V2ᴴ ≈ Vᴴ - @test U * S * Vᴴ ≈ A - @test isisometric(U) - @test isisometric(Vᴴ; side = :right) - @test isposdef(S) - - Sd = @constinferred svd_vals(A, alg′) - @test S ≈ Diagonal(Sd) - end -end - -@testset "svd_full! 
for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63, 0) - alg = GLA_QRIteration() - A = randn(rng, T, m, n) - U, S, Vᴴ = svd_full(A; alg) - @test U isa Matrix{T} && size(U) == (m, m) - @test S isa Matrix{real(T)} && size(S) == (m, n) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (n, n) - @test U * S * Vᴴ ≈ A - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(isposdef, diagview(S)) - - Ac = similar(A) - U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) - @test U2 ≈ U - @test S2 ≈ S - @test V2ᴴ ≈ Vᴴ - @test U * S * Vᴴ ≈ A - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(isposdef, diagview(S)) - - Sc = svd_vals!(copy!(Ac, A), alg) - @test diagview(S) ≈ Sc - end - @testset "size (0, 0)" begin - @testset "algorithm $alg" for alg in - (LAPACK_DivideAndConquer(), LAPACK_QRIteration()) - A = randn(rng, T, 0, 0) - U, S, Vᴴ = svd_full(A; alg) - @test U isa Matrix{T} && size(U) == (0, 0) - @test S isa Matrix{real(T)} && size(S) == (0, 0) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (0, 0) - @test U * S * Vᴴ ≈ A - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(isposdef, diagview(S)) - end - end -end - -@testset "svd_trunc! 
for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - atol = sqrt(eps(real(T))) - alg = GLA_QRIteration() - - @testset "size ($m, $n)" for n in (37, m, 63) - n > m && alg isa LAPACK_Jacobi && continue # not supported - A = randn(rng, T, m, n) - S₀ = svd_vals(A) - minmn = min(m, n) - r = minmn - 2 - - U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) - @test length(diagview(S1)) == r - @test diagview(S1) ≈ S₀[1:r] - @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] - # Test truncation error - @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol - - s = 1 + sqrt(eps(real(T))) - trunc = trunctol(; atol = s * S₀[r + 1]) - - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) - @test length(diagview(S2)) == r - @test U1 ≈ U2 - @test S1 ≈ S2 - @test V1ᴴ ≈ V2ᴴ - @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol - - trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) - U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) - @test length(diagview(S3)) == r - @test U1 ≈ U3 - @test S1 ≈ S3 - @test V1ᴴ ≈ V3ᴴ - @test ϵ3 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol - end -end - -@testset "svd_trunc! mix maxrank and tol for T = $T" for T in eltypes - rng = StableRNG(123) - alg = GLA_QRIteration() - m = 4 - U = qr_compact(randn(rng, T, m, m))[1] - S = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - Vᴴ = qr_compact(randn(rng, T, m, m))[1] - A = U * S * Vᴴ - - for trunc_fun in ( - (rtol, maxrank) -> (; rtol, maxrank), - (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), - ) - U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) - @test length(diagview(S1)) == 1 - @test diagview(S1) ≈ diagview(S)[1:1] - - U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) - @test length(diagview(S2)) == 2 - @test diagview(S2) ≈ diagview(S)[1:2] - end -end - -@testset "svd_trunc! 
specify truncation algorithm T = $T" for T in eltypes - rng = StableRNG(123) - atol = sqrt(eps(real(T))) - m = 4 - U = qr_compact(randn(rng, T, m, m))[1] - S = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - Vᴴ = qr_compact(randn(rng, T, m, m))[1] - A = U * S * Vᴴ - alg = TruncatedAlgorithm(GLA_QRIteration(), trunctol(; atol = 0.2)) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) - @test diagview(S2) ≈ diagview(S)[1:2] - @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol - @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) -end diff --git a/test/genericschur/eig.jl b/test/genericschur/eig.jl deleted file mode 100644 index ce1e8f1b..00000000 --- a/test/genericschur/eig.jl +++ /dev/null @@ -1,116 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: Diagonal -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm -using GenericSchur - -const eltypes = (BigFloat, Complex{BigFloat}) - -@testset "eig_full! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 24 - alg = GS_QRIteration() - A = randn(rng, T, m, m) - Tc = complex(T) - - D, V = @constinferred eig_full(A; alg = ($alg)) - @test eltype(D) == eltype(V) == Tc - @test A * V ≈ V * D - - alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, $alg) - - Ac = similar(A) - D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′) - @test D2 ≈ D - @test V2 ≈ V - @test A * V ≈ V * D - - Dc = @constinferred eig_vals(A, alg′) - @test eltype(Dc) == Tc - @test D ≈ Diagonal(Dc) -end - -@testset "eig_trunc! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 6 - alg = GS_QRIteration() - A = randn(rng, T, m, m) - A *= A' # TODO: deal with eigenvalue ordering etc - # eigenvalues are sorted by ascending real component... 
- D₀ = sort!(eig_vals(A); by = abs, rev = true) - rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) - r = length(D₀) - rmin - atol = sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) - D1base, V1base = @constinferred eig_full(A; alg) - - @test length(diagview(D1)) == r - @test A * V1 ≈ V1 * D1 - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 + sqrt(eps(real(T))) - trunc = trunctol(; atol = s * abs(D₀[r + 1])) - D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - # trunctol keeps order, truncrank might not - # test for same subspace - @test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2 - @test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1 - @test V1 * ((V1' * V1) \ (V1' * V3)) ≈ V3 - @test V3 * ((V3' * V3) \ (V3' * V1)) ≈ V1 -end - -@testset "eig_trunc! 
specify truncation algorithm T = $T" for T in eltypes - rng = StableRNG(123) - m = 4 - atol = sqrt(eps(real(T))) - V = randn(rng, T, m, m) - D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - A = V * D * inv(V) - alg = TruncatedAlgorithm(GS_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) - - alg = TruncatedAlgorithm(GS_QRIteration(), truncerror(; atol = 0.2, p = 1)) - D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol -end - -@testset "eig for Diagonal{$T}" for T in eltypes - rng = StableRNG(123) - m = 24 - Ad = randn(rng, T, m) - A = Diagonal(Ad) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eig_full(A) - @test D isa Diagonal{T} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eig_vals(A) - @test D2 isa AbstractVector{T} && length(D2) == m - @test diagview(D) ≈ D2 - - A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol -end diff --git a/test/linearmap.jl b/test/linearmap.jl index bbf84e5e..ad6102d7 100644 --- a/test/linearmap.jl +++ b/test/linearmap.jl @@ -3,7 +3,7 @@ module LinearMaps export LinearMap using MatrixAlgebraKit - using MatrixAlgebraKit: AbstractAlgorithm + using MatrixAlgebraKit: AbstractAlgorithm, DiagonalAlgorithm import MatrixAlgebraKit as MAK using LinearAlgebra: LinearAlgebra, lmul!, rmul! 
@@ -32,6 +32,12 @@ module LinearMaps LinearMap.(MAK.initialize_output($f!, parent(A), alg)) @eval MAK.$f!(A::LinearMap, F, alg::AbstractAlgorithm) = LinearMap.(MAK.$f!(parent(A), parent.(F), alg)) + @eval MAK.check_input(::typeof($f!), A::LinearMap, F, alg::DiagonalAlgorithm) = + MAK.check_input($f!, parent(A), parent.(F), alg) + @eval MAK.initialize_output(::typeof($f!), A::LinearMap, alg::DiagonalAlgorithm) = + LinearMap.(MAK.initialize_output($f!, parent(A), alg)) + @eval MAK.$f!(A::LinearMap, F, alg::DiagonalAlgorithm) = + LinearMap.(MAK.$f!(parent(A), parent.(F), alg)) end for f in (:qr, :lq, :svd) diff --git a/test/lq.jl b/test/lq.jl index 8de5a582..4377c2c5 100644 --- a/test/lq.jl +++ b/test/lq.jl @@ -1,255 +1,37 @@ using MatrixAlgebraKit using Test -using TestExtras using StableRNGs using LinearAlgebra: diag, I, Diagonal using MatrixAlgebraKit: LQViaTransposedQR, LAPACK_HouseholderQR +using CUDA, AMDGPU BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) GenericFloats = (Float16, BigFloat, Complex{BigFloat}) -@testset "lq_compact! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = randn(rng, T, m, n) - L, Q = @constinferred lq_compact(A) - @test L isa Matrix{T} && size(L) == (m, minmn) - @test Q isa Matrix{T} && size(Q) == (minmn, n) - @test L * Q ≈ A - @test isisometric(Q; side = :right) - Nᴴ = @constinferred lq_null(A) - @test Nᴴ isa Matrix{T} && size(Nᴴ) == (n - minmn, n) - @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) - @test isisometric(Nᴴ; side = :right) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite - Ac = similar(A) - L2, Q2 = @constinferred lq_compact!(copy!(Ac, A), (L, Q)) - @test L2 === L - @test Q2 === Q - Nᴴ2 = @constinferred lq_null!(copy!(Ac, A), Nᴴ) - @test Nᴴ2 === Nᴴ +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - noL = similar(A, 0, minmn) - Q2 = similar(Q) - lq_compact!(copy!(Ac, A), (noL, Q2)) - @test Q == Q2 - - # Transposed QR algorithm - qr_alg = LAPACK_HouseholderQR() - lq_alg = LQViaTransposedQR(qr_alg) - L2, Q2 = @constinferred lq_compact!(copy!(Ac, A), (L, Q), lq_alg) - @test L2 === L - @test Q2 === Q - Nᴴ2 = @constinferred lq_null!(copy!(Ac, A), Nᴴ, lq_alg) - @test Nᴴ2 === Nᴴ - noL = similar(A, 0, minmn) - Q2 = similar(Q) - lq_compact!(copy!(Ac, A), (noL, Q2), lq_alg) - @test Q == Q2 - - # unblocked algorithm - lq_compact!(copy!(Ac, A), (L, Q); blocksize = 1) - @test L * Q ≈ A - @test isisometric(Q; side = :right) - lq_compact!(copy!(Ac, A), (noL, Q2); blocksize = 1) - @test Q == Q2 - lq_null!(copy!(Ac, A), Nᴴ; blocksize = 1) - @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) - @test isisometric(Nᴴ; side = :right) - if m <= n - lq_compact!(copy!(Q2, A), (noL, Q2); blocksize = 1) # in-place Q - @test Q ≈ Q2 - @test_throws ArgumentError lq_compact!(copy!(Q2, A), (L, Q2); blocksize = 1) - @test_throws ArgumentError lq_compact!(copy!(Q2, A), (noL, Q2); positive = true) - @test_throws ArgumentError lq_compact!(copy!(Q2, A), (noL, Q2); blocksize = 8) +m = 54 
+for T in BLASFloats, n in (37, m, 63) + TestSuite.seed_rng!(123) + if is_buildkite + if CUDA.functional() + TestSuite.test_lq(CuMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false) + n == m && TestSuite.test_lq(Diagonal{T, CuVector{T}}, m; test_pivoted = false, test_blocksize = false) end - lq_compact!(copy!(Ac, A), (L, Q); blocksize = 8) - @test L * Q ≈ A - @test isisometric(Q; side = :right) - lq_compact!(copy!(Ac, A), (noL, Q2); blocksize = 8) - @test Q == Q2 - lq_null!(copy!(Ac, A), Nᴴ; blocksize = 8) - @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) - @test isisometric(Nᴴ; side = :right) - @test Nᴴ * Nᴴ' ≈ I - - qr_alg = LAPACK_HouseholderQR(; blocksize = 1) - lq_alg = LQViaTransposedQR(qr_alg) - lq_compact!(copy!(Ac, A), (L, Q), lq_alg) - @test L * Q ≈ A - @test Q * Q' ≈ I - lq_compact!(copy!(Ac, A), (noL, Q2), lq_alg) - @test Q == Q2 - lq_null!(copy!(Ac, A), Nᴴ, lq_alg) - @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) - @test Nᴴ * Nᴴ' ≈ I - - # pivoted - @test_throws ArgumentError lq_compact!(copy!(Ac, A), (L, Q); pivoted = true) - - # positive - lq_compact!(copy!(Ac, A), (L, Q); positive = true) - @test L * Q ≈ A - @test isisometric(Q; side = :right) - @test all(>=(zero(real(T))), real(diag(L))) - lq_compact!(copy!(Ac, A), (noL, Q2); positive = true) - @test Q == Q2 - - # positive and blocksize 1 - lq_compact!(copy!(Ac, A), (L, Q); positive = true, blocksize = 1) - @test L * Q ≈ A - @test isisometric(Q; side = :right) - @test all(>=(zero(real(T))), real(diag(L))) - lq_compact!(copy!(Ac, A), (noL, Q2); positive = true, blocksize = 1) - @test Q == Q2 - qr_alg = LAPACK_HouseholderQR(; positive = true, blocksize = 1) - lq_alg = LQViaTransposedQR(qr_alg) - lq_compact!(copy!(Ac, A), (L, Q), lq_alg) - @test L * Q ≈ A - @test Q * Q' ≈ I - @test all(>=(zero(real(T))), real(diag(L))) - lq_compact!(copy!(Ac, A), (noL, Q2), lq_alg) - @test Q == Q2 - end -end - -@testset "lq_full! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = randn(rng, T, m, n) - L, Q = lq_full(A) - @test L isa Matrix{T} && size(L) == (m, n) - @test Q isa Matrix{T} && size(Q) == (n, n) - @test L * Q ≈ A - @test isunitary(Q) - - Ac = similar(A) - L2, Q2 = @constinferred lq_full!(copy!(Ac, A), (L, Q)) - @test L2 === L - @test Q2 === Q - @test L * Q ≈ A - @test isunitary(Q) - - noL = similar(A, 0, n) - Q2 = similar(Q) - lq_full!(copy!(Ac, A), (noL, Q2)) - @test Q == Q2 - - # Transposed QR algorithm - qr_alg = LAPACK_HouseholderQR() - lq_alg = LQViaTransposedQR(qr_alg) - L2, Q2 = @constinferred lq_full!(copy!(Ac, A), (L, Q), lq_alg) - @test L2 === L - @test Q2 === Q - @test L * Q ≈ A - @test Q * Q' ≈ I - noL = similar(A, 0, n) - Q2 = similar(Q) - lq_full!(copy!(Ac, A), (noL, Q2), lq_alg) - @test Q == Q2 - - # unblocked algorithm - lq_full!(copy!(Ac, A), (L, Q); blocksize = 1) - @test L * Q ≈ A - @test isunitary(Q) - lq_full!(copy!(Ac, A), (noL, Q2); blocksize = 1) - @test Q == Q2 - if n == m - lq_full!(copy!(Q2, A), (noL, Q2); blocksize = 1) # in-place Q - @test Q ≈ Q2 - end - qr_alg = LAPACK_HouseholderQR(; blocksize = 1) - lq_alg = LQViaTransposedQR(qr_alg) - lq_full!(copy!(Ac, A), (L, Q), lq_alg) - @test L * Q ≈ A - @test Q * Q' ≈ I - lq_full!(copy!(Ac, A), (noL, Q2), lq_alg) - @test Q == Q2 - if n == m - lq_full!(copy!(Q2, A), (noL, Q2), lq_alg) # in-place Q - @test Q ≈ Q2 + if AMDGPU.functional() + TestSuite.test_lq(ROCMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false) + n == m && TestSuite.test_lq(Diagonal{T, ROCVector{T}}, m; test_pivoted = false, test_blocksize = false) end - - # other blocking - lq_full!(copy!(Ac, A), (L, Q); blocksize = 18) - @test L * Q ≈ A - @test isunitary(Q) - lq_full!(copy!(Ac, A), (noL, Q2); blocksize = 18) - @test Q == Q2 - # pivoted - @test_throws ArgumentError lq_full!(copy!(Ac, A), (L, Q); pivoted = true) - # positive - lq_full!(copy!(Ac, A), (L, Q); positive = 
true) - @test L * Q ≈ A - @test isunitary(Q) - @test all(>=(zero(real(T))), real(diag(L))) - lq_full!(copy!(Ac, A), (noL, Q2); positive = true) - @test Q == Q2 - - qr_alg = LAPACK_HouseholderQR(; positive = true) - lq_alg = LQViaTransposedQR(qr_alg) - lq_full!(copy!(Ac, A), (L, Q), lq_alg) - @test L * Q ≈ A - @test Q * Q' ≈ I - @test all(>=(zero(real(T))), real(diag(L))) - lq_full!(copy!(Ac, A), (noL, Q2), lq_alg) - @test Q == Q2 - - # positive and blocksize 1 - lq_full!(copy!(Ac, A), (L, Q); positive = true, blocksize = 1) - @test L * Q ≈ A - @test isunitary(Q) - @test all(>=(zero(real(T))), real(diag(L))) - lq_full!(copy!(Ac, A), (noL, Q2); positive = true, blocksize = 1) - @test Q == Q2 + else + TestSuite.test_lq(T, (m, n); test_pivoted = false) end end - -@testset "lq_compact, lq_full and lq_null for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) - rng = StableRNG(123) - atol = eps(real(T))^(3 / 4) - for m in (54, 0) - Ad = randn(rng, T, m) - A = Diagonal(Ad) - - # compact - L, Q = @constinferred lq_compact(A) - @test Q isa Diagonal{T} && size(Q) == (m, m) - @test L isa Diagonal{T} && size(L) == (m, m) - @test L * Q ≈ A - @test isunitary(Q) - - # compact and positive - Lp, Qp = @constinferred lq_compact(A; positive = true) - @test Qp isa Diagonal{T} && size(Qp) == (m, m) - @test Lp isa Diagonal{T} && size(Lp) == (m, m) - @test Lp * Qp ≈ A - @test isunitary(Qp) - @test all(≥(zero(real(T))), real(diag(Lp))) && - all(≈(zero(real(T)); atol), imag(diag(Lp))) - - # full - L, Q = @constinferred lq_full(A) - @test Q isa Diagonal{T} && size(Q) == (m, m) - @test L isa Diagonal{T} && size(L) == (m, m) - @test L * Q ≈ A - @test isunitary(Q) - - # full and positive - Lp, Qp = @constinferred lq_full(A; positive = true) - @test Qp isa Diagonal{T} && size(Qp) == (m, m) - @test Lp isa Diagonal{T} && size(Lp) == (m, m) - @test Lp * Qp ≈ A - @test isunitary(Qp) - @test all(≥(zero(real(T))), real(diag(Lp))) && - all(≈(zero(real(T)); atol), imag(diag(Lp))) - - # null - N 
= @constinferred lq_null(A) - @test N isa AbstractMatrix{T} && size(N) == (0, m) +if !is_buildkite + for T in (BLASFloats..., GenericFloats...) + AT = Diagonal{T, Vector{T}} + TestSuite.test_lq(AT, m; test_pivoted = false, test_blocksize = false) end end diff --git a/test/orthnull.jl b/test/orthnull.jl index ce742e8f..e8339adb 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -2,225 +2,36 @@ using MatrixAlgebraKit using Test using TestExtras using StableRNGs -using LinearAlgebra: LinearAlgebra, I +using LinearAlgebra: LinearAlgebra, I, Diagonal +using CUDA, AMDGPU -# testing non-AbstractArray codepaths: -include("linearmap.jl") +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, BigFloat, Complex{BigFloat}) -eltypes = (Float32, Float64, ComplexF32, ComplexF64) -@testset "left_orth and left_null for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = randn(rng, T, m, n) - V, C = @constinferred left_orth(A) - N = @constinferred left_null(A) - @test V isa Matrix{T} && size(V) == (m, minmn) - @test C isa Matrix{T} && size(C) == (minmn, n) - @test N isa Matrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - @test V * V' + N * N' ≈ I +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite - M = LinearMap(A) - VM, CM = @constinferred left_orth(M; alg = :svd) - @test parent(VM) * parent(CM) ≈ A +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - if m > n - nullity = 5 - V, C = @constinferred left_orth(A) - N = @constinferred left_null(A; trunc = (; maxnullity = nullity)) - @test V isa Matrix{T} && size(V) == (m, minmn) - @test C isa Matrix{T} && size(C) == (minmn, n) - @test N isa Matrix{T} && size(N) == (m, nullity) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = 
MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) +m = 54 +for T in BLASFloats, n in (37, m, 63) + TestSuite.seed_rng!(123) + if is_buildkite + if CUDA.functional() + TestSuite.test_orthnull(CuMatrix{T}, (m, n); test_blocksize = false) + n == m && TestSuite.test_orthnull(Diagonal{T, CuVector{T}}, m; test_blocksize = false) end - - # passing a kind and some kwargs - V, C = @constinferred left_orth(A; alg = :qr, positive = true) - N = @constinferred left_null(A; alg = :qr, positive = true) - @test V isa Matrix{T} && size(V) == (m, minmn) - @test C isa Matrix{T} && size(C) == (minmn, n) - @test N isa Matrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - @test V * V' + N * N' ≈ I - - # passing an algorithm - V, C = @constinferred left_orth(A; alg = LAPACK_HouseholderQR()) - N = @constinferred left_null(A; alg = :qr, positive = true) - @test V isa Matrix{T} && size(V) == (m, minmn) - @test C isa Matrix{T} && size(C) == (minmn, n) - @test N isa Matrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - @test V * V' + N * N' ≈ I - - Ac = similar(A) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C)) - N2 = @constinferred left_null!(copy!(Ac, A), N) - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - @test V2 * V2' + N2 * N2' ≈ I - - atol = eps(real(T)) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = (; atol = atol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = (; atol = atol)) - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - @test V2 * V2' + N2 * N2' ≈ I - - rtol = eps(real(T)) - for 
(trunc_orth, trunc_null) in ( - ((; rtol = rtol), (; rtol = rtol)), - (trunctol(; rtol), trunctol(; rtol, keep_below = true)), - ) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = trunc_orth) - N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = trunc_null) - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - @test V2 * V2' + N2 * N2' ≈ I - end - - for alg in (:qr, :polar, :svd) # explicit kind kwarg - m < n && alg === :polar && continue - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg))) - @test V2 * C2 ≈ A - @test isisometric(V2) - if alg != :polar - N2 = @constinferred left_null!(copy!(Ac, A), N; alg) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - @test V2 * V2' + N2 * N2' ≈ I - end - - # with kind and tol kwargs - if alg == :svd - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; atol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - @test V2 * V2' + N2 * N2' ≈ I - - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; rtol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) - @test V2 * C2 ≈ A - @test isisometric(V2) - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N2) - @test V2 * V2' + N2 * N2' ≈ I - else - @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; atol)) - @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; rtol)) - alg == :polar && continue - @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) - @test_throws ArgumentError left_null!(copy!(Ac, A), N; 
alg, trunc = (; rtol)) - end + if AMDGPU.functional() + TestSuite.test_orthnull(ROCMatrix{T}, (m, n); test_blocksize = false) + n == m && TestSuite.test_orthnull(Diagonal{T, ROCVector{T}}, m; test_blocksize = false) end + else + TestSuite.test_orthnull(T, (m, n)) end end - -@testset "right_orth and right_null for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = randn(rng, T, m, n) - C, Vᴴ = @constinferred right_orth(A) - Nᴴ = @constinferred right_null(A) - @test C isa Matrix{T} && size(C) == (m, minmn) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (minmn, n) - @test Nᴴ isa Matrix{T} && size(Nᴴ) == (n - minmn, n) - @test C * Vᴴ ≈ A - @test isisometric(Vᴴ; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ; side = :right) - @test Vᴴ' * Vᴴ + Nᴴ' * Nᴴ ≈ I - - M = LinearMap(A) - CM, VMᴴ = @constinferred right_orth(M; alg = :svd) - @test parent(CM) * parent(VMᴴ) ≈ A - - Ac = similar(A) - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ; side = :right) - @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - - atol = eps(real(T)) - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol)) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ; side = :right) - @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - - rtol = eps(real(T)) - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol)) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test 
LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ2; side = :right) - @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - - for alg in (:lq, :polar, :svd) - n < m && alg == :polar && continue - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg))) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - if alg != :polar - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg))) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ2; side = :right) - @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - end - - if alg == :svd - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; atol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; atol)) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ2; side = :right) - @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; rtol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; rtol)) - @test C2 * Vᴴ2 ≈ A - @test isisometric(Vᴴ2; side = :right) - @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(Nᴴ2; side = :right) - @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - else - @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; atol)) - @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; rtol)) - alg == :polar && continue - @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; atol)) - @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; rtol)) - end - end +if !is_buildkite + for T in (BLASFloats..., GenericFloats...) 
+ AT = Diagonal{T, Vector{T}} + TestSuite.test_orthnull(AT, m; test_blocksize = false) end end diff --git a/test/polar.jl b/test/polar.jl index 8087149e..4aa5e961 100644 --- a/test/polar.jl +++ b/test/polar.jl @@ -1,83 +1,38 @@ using MatrixAlgebraKit using Test -using TestExtras using StableRNGs -using LinearAlgebra: LinearAlgebra, I, isposdef +using LinearAlgebra: LinearAlgebra, I, isposdef, Diagonal +using CUDA, AMDGPU -@testset "left_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m) - k = min(m, n) - if LinearAlgebra.LAPACK.version() < v"3.12.0" - svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) - else - svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_Jacobi()) - end - algs = (PolarViaSVD.(svdalgs)..., PolarNewton()) - @testset "algorithm $alg" for alg in algs - A = randn(rng, T, m, n) +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, BigFloat, Complex{BigFloat}) - W, P = left_polar(A; alg) - @test W isa Matrix{T} && size(W) == (m, n) - @test P isa Matrix{T} && size(P) == (n, n) - @test W * P ≈ A - @test isisometric(W) - @test isposdef(P) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite - Ac = similar(A) - W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, P), alg) - @test W2 === W - @test P2 === P - @test W * P ≈ A - @test isisometric(W) - @test isposdef(P) +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - noP = similar(P, (0, 0)) - W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, noP), alg) - @test P2 === noP - @test W2 === W - @test isisometric(W) - P = W' * A # compute P explicitly to verify W correctness - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(project_hermitian!(P)) +m = 54 +for T in BLASFloats, n in (37, m, 63) + TestSuite.seed_rng!(123) + if is_buildkite + if 
CUDA.functional() + TestSuite.test_polar(CuMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false) + # not supported + #n == m && TestSuite.test_polar(Diagonal{T, CuVector{T}}, m; test_pivoted = false, test_blocksize = false) + end + if AMDGPU.functional() + TestSuite.test_polar(ROCMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false) + # not supported + #n == m && TestSuite.test_polar(Diagonal{T, ROCVector{T}}, m; test_pivoted = false, test_blocksize = false) end + else + TestSuite.test_polar(T, (m, n)) end end - -@testset "right_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) - rng = StableRNG(123) - n = 54 - @testset "size ($m, $n)" for m in (37, n) - k = min(m, n) - svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) - algs = (PolarViaSVD.(svdalgs)..., PolarNewton()) - @testset "algorithm $alg" for alg in algs - A = randn(rng, T, m, n) - - P, Wᴴ = right_polar(A; alg) - @test Wᴴ isa Matrix{T} && size(Wᴴ) == (m, n) - @test P isa Matrix{T} && size(P) == (m, m) - @test P * Wᴴ ≈ A - @test isisometric(Wᴴ; side = :right) - @test isposdef(P) - - Ac = similar(A) - P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (P, Wᴴ), alg) - @test P2 === P - @test Wᴴ2 === Wᴴ - @test P * Wᴴ ≈ A - @test isisometric(Wᴴ; side = :right) - @test isposdef(P) - - noP = similar(P, (0, 0)) - P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg) - @test P2 === noP - @test Wᴴ2 === Wᴴ - @test isisometric(Wᴴ; side = :right) - P = A * Wᴴ' # compute P explicitly to verify W correctness - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(project_hermitian!(P)) - end +if !is_buildkite + for T in (BLASFloats..., GenericFloats...) 
+ AT = Diagonal{T, Vector{T}} + TestSuite.test_polar(AT, m; test_pivoted = false, test_blocksize = false) end end diff --git a/test/projections.jl b/test/projections.jl index 3923528e..89ead93a 100644 --- a/test/projections.jl +++ b/test/projections.jl @@ -3,97 +3,36 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, norm, normalize! +using CUDA, AMDGPU -const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, BigFloat, Complex{BigFloat}) -@testset "project_(anti)hermitian! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - noisefactor = eps(real(T))^(3 / 4) - for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64)) - for A in (randn(rng, T, m, m), Diagonal(randn(rng, T, m))) - Ah = (A + A') / 2 - Aa = (A - A') / 2 - Ac = copy(A) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite - Bh = project_hermitian(A, alg) - @test ishermitian(Bh) - @test Bh ≈ Ah - @test A == Ac - Bh_approx = Bh + noisefactor * Aa - # this is still hermitian for real Diagonal: |A - A'| == 0 - @test !ishermitian(Bh_approx) || norm(Aa) == 0 - @test ishermitian(Bh_approx; rtol = 10 * noisefactor) +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - Ba = project_antihermitian(A, alg) - @test isantihermitian(Ba) - @test Ba ≈ Aa - @test A == Ac - Ba_approx = Ba + noisefactor * Ah - @test !isantihermitian(Ba_approx) - # this is never anti-hermitian for real Diagonal: |A - A'| == 0 - @test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0 - Bh = project_hermitian!(Ac, alg) - @test Bh === Ac - @test ishermitian(Bh) - @test Bh ≈ Ah - - copy!(Ac, A) - Ba = project_antihermitian!(Ac, alg) - @test Ba === Ac - @test isantihermitian(Ba) - @test Ba ≈ Aa +m = 54 +for T in BLASFloats + TestSuite.seed_rng!(123) + if is_buildkite + if CUDA.functional() + 
TestSuite.test_projections(CuMatrix{T}, (m, m); test_pivoted = false, test_blocksize = false) + TestSuite.test_projections(Diagonal{T, CuVector{T}}, m; test_pivoted = false, test_blocksize = false) + end + if AMDGPU.functional() + TestSuite.test_projections(ROCMatrix{T}, (m, m); test_pivoted = false, test_blocksize = false) + TestSuite.test_projections(Diagonal{T, ROCVector{T}}, m; test_pivoted = false, test_blocksize = false) end + else + TestSuite.test_projections(T, (m, m)) end - - # test approximate error calculation - A = normalize!(randn(rng, T, m, m)) - Ah = project_hermitian(A) - Aa = project_antihermitian(A) - - Ah_approx = Ah + noisefactor * Aa - ϵ = norm(project_antihermitian(Ah_approx)) - @test !ishermitian(Ah_approx; atol = (999 // 1000) * ϵ) - @test ishermitian(Ah_approx; atol = (1001 // 1000) * ϵ) - - Aa_approx = Aa + noisefactor * Ah - ϵ = norm(project_hermitian(Aa_approx)) - @test !isantihermitian(Aa_approx; atol = (999 // 1000) * ϵ) - @test isantihermitian(Aa_approx; atol = (1001 // 1000) * ϵ) end - -@testset "project_isometric! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m) - k = min(m, n) - if LinearAlgebra.LAPACK.version() < v"3.12.0" - svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) - else - svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_Jacobi()) - end - algs = (PolarViaSVD.(svdalgs)..., PolarNewton()) - @testset "algorithm $alg" for alg in algs - A = randn(rng, T, m, n) - W = project_isometric(A, alg) - @test isisometric(W) - W2 = project_isometric(W, alg) - @test W2 ≈ W # stability of the projection - @test W * (W' * A) ≈ A - - Ac = similar(A) - W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg) - @test W2 === W - @test isisometric(W) - - # test that W is closer to A then any other isometry - for k in 1:10 - δA = randn(rng, T, m, n) - W = project_isometric(A, alg) - W2 = project_isometric(A + δA / 100, alg) - @test norm(A - W2) > norm(A - W) - end - end +if !is_buildkite + for T in (BLASFloats..., GenericFloats...) + AT = Diagonal{T, Vector{T}} + TestSuite.test_projections(AT, m; test_pivoted = false, test_blocksize = false) end end diff --git a/test/qr.jl b/test/qr.jl index c4f0c9d6..71201ae2 100644 --- a/test/qr.jl +++ b/test/qr.jl @@ -1,223 +1,36 @@ using MatrixAlgebraKit using Test -using TestExtras using StableRNGs using LinearAlgebra: diag, I, Diagonal +using CUDA, AMDGPU BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) GenericFloats = (Float16, BigFloat, Complex{BigFloat}) -@testset "qr_compact! and qr_null! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = randn(rng, T, m, n) - Q, R = @constinferred qr_compact(A) - @test Q isa Matrix{T} && size(Q) == (m, minmn) - @test R isa Matrix{T} && size(R) == (minmn, n) - @test Q * R ≈ A - N = @constinferred qr_null(A) - @test N isa Matrix{T} && size(N) == (m, m - minmn) - @test isisometric(Q) - @test maximum(abs, A' * N) < eps(real(T))^(2 / 3) - @test isisometric(N) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite - Ac = similar(A) - Q2, R2 = @constinferred qr_compact!(copy!(Ac, A), (Q, R)) - @test Q2 === Q - @test R2 === R - N2 = @constinferred qr_null!(copy!(Ac, A), N) - @test N2 === N +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - Q2 = similar(Q) - noR = similar(A, minmn, 0) - qr_compact!(copy!(Ac, A), (Q2, noR)) - @test Q == Q2 - - # unblocked algorithm - qr_compact!(copy!(Ac, A), (Q, R); blocksize = 1) - @test Q * R ≈ A - @test isisometric(Q) - qr_compact!(copy!(Ac, A), (Q2, noR); blocksize = 1) - @test Q == Q2 - qr_compact!(copy!(Ac, A), (Q2, noR); blocksize = 1) - qr_null!(copy!(Ac, A), N; blocksize = 1) - @test maximum(abs, A' * N) < eps(real(T))^(2 / 3) - @test isisometric(N) - if n <= m - qr_compact!(copy!(Q2, A), (Q2, noR); blocksize = 1) # in-place Q - @test Q ≈ Q2 - @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, R); blocksize = 1) - @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, noR); positive = true) - @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, noR); blocksize = 8) +m = 54 +for T in BLASFloats, n in (37, m, 63) + TestSuite.seed_rng!(123) + if is_buildkite + if CUDA.functional() + TestSuite.test_qr(CuMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false) + n == m && TestSuite.test_qr(Diagonal{T, CuVector{T}}, m; test_pivoted = false, test_blocksize = false) end - # other blocking - qr_compact!(copy!(Ac, A), (Q, R); blocksize = 8) - @test Q * R ≈ A - @test 
isisometric(Q) - qr_compact!(copy!(Ac, A), (Q2, noR); blocksize = 8) - @test Q == Q2 - qr_null!(copy!(Ac, A), N; blocksize = 8) - @test maximum(abs, A' * N) < eps(real(T))^(2 / 3) - @test isisometric(N) - - # pivoted - qr_compact!(copy!(Ac, A), (Q, R); pivoted = true) - @test Q * R ≈ A - @test Q' * Q ≈ I - qr_compact!(copy!(Ac, A), (Q2, noR); pivoted = true) - @test Q == Q2 - # positive - qr_compact!(copy!(Ac, A), (Q, R); positive = true) - @test Q * R ≈ A - @test isisometric(Q) - @test all(>=(zero(real(T))), real(diag(R))) - qr_compact!(copy!(Ac, A), (Q2, noR); positive = true) - @test Q == Q2 - # positive and blocksize 1 - qr_compact!(copy!(Ac, A), (Q, R); positive = true, blocksize = 1) - @test Q * R ≈ A - @test isisometric(Q) - @test all(>=(zero(real(T))), real(diag(R))) - qr_compact!(copy!(Ac, A), (Q2, noR); positive = true, blocksize = 1) - @test Q == Q2 - # positive and pivoted - qr_compact!(copy!(Ac, A), (Q, R); positive = true, pivoted = true) - @test Q * R ≈ A - @test isisometric(Q) - if n <= m - # the following test tries to find the diagonal element (in order to test positivity) - # before the column permutation. This only works if all columns have a diagonal - # element - for j in 1:n - i = findlast(!iszero, view(R, :, j)) - @test real(R[i, j]) >= zero(real(T)) - end + if AMDGPU.functional() + TestSuite.test_qr(ROCMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false) + n == m && TestSuite.test_qr(Diagonal{T, ROCVector{T}}, m; test_pivoted = false, test_blocksize = false) end - qr_compact!(copy!(Ac, A), (Q2, noR); positive = true, pivoted = true) - @test Q == Q2 + else + TestSuite.test_qr(T, (m, n)) end end - -@testset "qr_full! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for n in (37, m, 63) - minmn = min(m, n) - A = randn(rng, T, m, n) - Q, R = qr_full(A) - @test Q isa Matrix{T} && size(Q) == (m, m) - @test R isa Matrix{T} && size(R) == (m, n) - @test Q * R ≈ A - @test isunitary(Q) - - Ac = similar(A) - Q2 = similar(Q) - noR = similar(A, m, 0) - Q2, R2 = @constinferred qr_full!(copy!(Ac, A), (Q, R)) - @test Q2 === Q - @test R2 === R - @test Q * R ≈ A - @test isunitary(Q) - qr_full!(copy!(Ac, A), (Q2, noR)) - @test Q == Q2 - - # unblocked algorithm - qr_full!(copy!(Ac, A), (Q, R); blocksize = 1) - @test Q * R ≈ A - @test isunitary(Q) - qr_full!(copy!(Ac, A), (Q2, noR); blocksize = 1) - @test Q == Q2 - if n == m - qr_full!(copy!(Q2, A), (Q2, noR); blocksize = 1) # in-place Q - @test Q ≈ Q2 - end - # other blocking - qr_full!(copy!(Ac, A), (Q, R); blocksize = 8) - @test Q * R ≈ A - @test isunitary(Q) - qr_full!(copy!(Ac, A), (Q2, noR); blocksize = 8) - @test Q == Q2 - # pivoted - qr_full!(copy!(Ac, A), (Q, R); pivoted = true) - @test Q * R ≈ A - @test isunitary(Q) - qr_full!(copy!(Ac, A), (Q2, noR); pivoted = true) - @test Q == Q2 - # positive - qr_full!(copy!(Ac, A), (Q, R); positive = true) - @test Q * R ≈ A - @test isunitary(Q) - @test all(>=(zero(real(T))), real(diag(R))) - qr_full!(copy!(Ac, A), (Q2, noR); positive = true) - @test Q == Q2 - # positive and blocksize 1 - qr_full!(copy!(Ac, A), (Q, R); positive = true, blocksize = 1) - @test Q * R ≈ A - @test isunitary(Q) - @test all(>=(zero(real(T))), real(diag(R))) - qr_full!(copy!(Ac, A), (Q2, noR); positive = true, blocksize = 1) - @test Q == Q2 - # positive and pivoted - qr_full!(copy!(Ac, A), (Q, R); positive = true, pivoted = true) - @test Q * R ≈ A - @test isunitary(Q) - if n <= m - # the following test tries to find the diagonal element (in order to test positivity) - # before the column permutation. 
This only works if all columns have a diagonal - # element - for j in 1:n - i = findlast(!iszero, view(R, :, j)) - @test real(R[i, j]) >= zero(real(T)) - end - end - qr_full!(copy!(Ac, A), (Q2, noR); positive = true, pivoted = true) - @test Q == Q2 - end -end - -@testset "qr_compact, qr_full and qr_null for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) - rng = StableRNG(123) - atol = eps(real(T))^(3 / 4) - for m in (54, 0) - Ad = randn(rng, T, m) - A = Diagonal(Ad) - - # compact - Q, R = @constinferred qr_compact(A) - @test Q isa Diagonal{T} && size(Q) == (m, m) - @test R isa Diagonal{T} && size(R) == (m, m) - @test Q * R ≈ A - @test isunitary(Q) - - # compact and positive - Qp, Rp = @constinferred qr_compact(A; positive = true) - @test Qp isa Diagonal{T} && size(Qp) == (m, m) - @test Rp isa Diagonal{T} && size(Rp) == (m, m) - @test Qp * Rp ≈ A - @test isunitary(Qp) - @test all(≥(zero(real(T))), real(diag(Rp))) && - all(≈(zero(real(T)); atol), imag(diag(Rp))) - - # full - Q, R = @constinferred qr_full(A) - @test Q isa Diagonal{T} && size(Q) == (m, m) - @test R isa Diagonal{T} && size(R) == (m, m) - @test Q * R ≈ A - @test isunitary(Q) - - # full and positive - Qp, Rp = @constinferred qr_full(A; positive = true) - @test Qp isa Diagonal{T} && size(Qp) == (m, m) - @test Rp isa Diagonal{T} && size(Rp) == (m, m) - @test Qp * Rp ≈ A - @test isunitary(Qp) - @test all(≥(zero(real(T))), real(diag(Rp))) && - all(≈(zero(real(T)); atol), imag(diag(Rp))) - - # null - N = @constinferred qr_null(A) - @test N isa AbstractMatrix{T} && size(N) == (m, 0) +if !is_buildkite + for T in (BLASFloats..., GenericFloats...) 
+ AT = Diagonal{T, Vector{T}} + TestSuite.test_qr(AT, m; test_pivoted = false, test_blocksize = false) end end diff --git a/test/runtests.jl b/test/runtests.jl index 1ed1f456..edc8385e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,129 +1,58 @@ using SafeTestsets -# don't run all tests on GPU, only the GPU -# specific ones is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +@safetestset "Algorithms" begin + include("algorithms.jl") +end +@safetestset "Projections" begin + include("projections.jl") +end +@safetestset "Truncate" begin + include("truncate.jl") +end +@safetestset "QR / LQ Decomposition" begin + include("qr.jl") + include("lq.jl") +end +@safetestset "Singular Value Decomposition" begin + include("svd.jl") +end +@safetestset "Hermitian Eigenvalue Decomposition" begin + include("eigh.jl") +end +@safetestset "General Eigenvalue Decomposition" begin + include("eig.jl") +end +@safetestset "Generalized Eigenvalue Decomposition" begin + include("gen_eig.jl") +end +@safetestset "Schur Decomposition" begin + include("schur.jl") +end +@safetestset "Polar Decomposition" begin + include("polar.jl") +end +@safetestset "Image and Null Space" begin + include("orthnull.jl") +end if !is_buildkite - @safetestset "Algorithms" begin - include("algorithms.jl") - end - @safetestset "Projections" begin - include("projections.jl") - end - @safetestset "Truncate" begin - include("truncate.jl") - end - @safetestset "QR / LQ Decomposition" begin - include("qr.jl") - include("lq.jl") - end - @safetestset "Singular Value Decomposition" begin - include("svd.jl") - end - @safetestset "Hermitian Eigenvalue Decomposition" begin - include("eigh.jl") - end - @safetestset "General Eigenvalue Decomposition" begin - include("eig.jl") - end - @safetestset "Generalized Eigenvalue Decomposition" begin - include("gen_eig.jl") - end - @safetestset "Schur Decomposition" begin - include("schur.jl") - end - @safetestset "Polar Decomposition" begin - include("polar.jl") - end - 
@safetestset "Image and Null Space" begin - include("orthnull.jl") - end @safetestset "Mooncake" begin include("mooncake.jl") end @safetestset "ChainRules" begin include("chainrules.jl") end - @safetestset "MatrixAlgebraKit.jl" begin - @safetestset "Code quality (Aqua.jl)" begin - using MatrixAlgebraKit - using Aqua - Aqua.test_all(MatrixAlgebraKit) - end - @safetestset "Code linting (JET.jl)" begin - using MatrixAlgebraKit - using JET - JET.test_package(MatrixAlgebraKit; target_defined_modules = true) - end - end - - using GenericLinearAlgebra - @safetestset "QR / LQ Decomposition" begin - include("genericlinearalgebra/qr.jl") - include("genericlinearalgebra/lq.jl") - end - @safetestset "Singular Value Decomposition" begin - include("genericlinearalgebra/svd.jl") - end - @safetestset "Hermitian Eigenvalue Decomposition" begin - include("genericlinearalgebra/eigh.jl") - end - - using GenericSchur - @safetestset "General Eigenvalue Decomposition" begin - include("genericschur/eig.jl") - end end - -using CUDA -if CUDA.functional() - @safetestset "CUDA QR" begin - include("cuda/qr.jl") - end - @safetestset "CUDA LQ" begin - include("cuda/lq.jl") - end - @safetestset "CUDA Projections" begin - include("cuda/projections.jl") - end - @safetestset "CUDA SVD" begin - include("cuda/svd.jl") - end - @safetestset "CUDA General Eigenvalue Decomposition" begin - include("cuda/eig.jl") - end - @safetestset "CUDA Hermitian Eigenvalue Decomposition" begin - include("cuda/eigh.jl") - end - @safetestset "CUDA Polar Decomposition" begin - include("cuda/polar.jl") - end - @safetestset "CUDA Image and Null Space" begin - include("cuda/orthnull.jl") - end -end - -using AMDGPU -if AMDGPU.functional() - @safetestset "AMDGPU QR" begin - include("amd/qr.jl") - end - @safetestset "AMDGPU LQ" begin - include("amd/lq.jl") - end - @safetestset "AMDGPU Projections" begin - include("amd/projections.jl") - end - @safetestset "AMDGPU SVD" begin - include("amd/svd.jl") - end - @safetestset "AMDGPU 
Hermitian Eigenvalue Decomposition" begin - include("amd/eigh.jl") - end - @safetestset "AMDGPU Polar Decomposition" begin - include("amd/polar.jl") - end - @safetestset "AMDGPU Image and Null Space" begin - include("amd/orthnull.jl") +@safetestset "MatrixAlgebraKit.jl" begin + @safetestset "Code quality (Aqua.jl)" begin + using MatrixAlgebraKit + using Aqua + Aqua.test_all(MatrixAlgebraKit) + end + @safetestset "Code linting (JET.jl)" begin + using MatrixAlgebraKit + using JET + JET.test_package(MatrixAlgebraKit; target_defined_modules = true) end end diff --git a/test/schur.jl b/test/schur.jl index e24de579..08c998ef 100644 --- a/test/schur.jl +++ b/test/schur.jl @@ -2,30 +2,35 @@ using MatrixAlgebraKit using Test using TestExtras using StableRNGs -using LinearAlgebra: I +using LinearAlgebra: I, Diagonal -@testset "schur_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) - rng = StableRNG(123) - m = 54 - for alg in (LAPACK_Simple(), LAPACK_Expert()) - A = randn(rng, T, m, m) - Tc = complex(T) +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, BigFloat, Complex{BigFloat}) - TA, Z, vals = @constinferred schur_full(A; alg) - @test eltype(TA) == eltype(Z) == T - @test eltype(vals) == Tc - @test isisometric(Z) - @test A * Z ≈ Z * TA +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite - Ac = similar(A) - TA2, Z2, vals2 = @constinferred schur_full!(copy!(Ac, A), (TA, Z, vals), alg) - @test TA2 === TA - @test Z2 === Z - @test vals2 === vals - @test A * Z ≈ Z * TA +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - valsc = @constinferred schur_vals(A, alg) - @test eltype(valsc) == Tc - @test valsc ≈ eig_vals(A, alg) +m = 54 +for T in BLASFloats + TestSuite.seed_rng!(123) + if is_buildkite + #=if CUDA.functional() + TestSuite.test_schur(CuMatrix{T}, (m, m); test_blocksize = false) + TestSuite.test_schur(Diagonal{T, CuVector{T}}, m; test_blocksize = false) + end + if AMDGPU.functional() + 
TestSuite.test_schur(ROCMatrix{T}, (m, m); test_blocksize = false) + TestSuite.test_schur(Diagonal{T, ROCVector{T}}, m; test_blocksize = false) + end=# # not yet supported + else + TestSuite.test_schur(T, (m, m)) + end +end +if !is_buildkite + for T in (BLASFloats..., GenericFloats...) + AT = Diagonal{T, Vector{T}} + TestSuite.test_schur(AT, m; test_blocksize = false) end end diff --git a/test/svd.jl b/test/svd.jl index d055f866..a2d09e8d 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -2,239 +2,36 @@ using MatrixAlgebraKit using Test using TestExtras using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef, norm -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, isisometric +using LinearAlgebra: Diagonal +using CUDA, AMDGPU BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) GenericFloats = (Float16, BigFloat, Complex{BigFloat}) -@testset "svd_compact! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63, 0) - k = min(m, n) - if LinearAlgebra.LAPACK.version() < v"3.12.0" - algs = ( - LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), - LAPACK_DivideAndConquer, :LAPACK_DivideAndConquer, - ) - else - algs = ( - LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), - LAPACK_Jacobi(), LAPACK_DivideAndConquer, :LAPACK_DivideAndConquer, - ) - end - @testset "algorithm $alg" for alg in algs - n > m && alg isa LAPACK_Jacobi && continue # not supported - minmn = min(m, n) - A = randn(rng, T, m, n) - - if VERSION < v"1.11" - # This is type unstable on older versions of Julia. 
- U, S, Vᴴ = svd_compact(A; alg) - else - U, S, Vᴴ = @constinferred svd_compact(A; alg = ($alg)) - end - @test U isa Matrix{T} && size(U) == (m, minmn) - @test S isa Diagonal{real(T)} && size(S) == (minmn, minmn) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (minmn, n) - @test U * S * Vᴴ ≈ A - @test isisometric(U) - @test isisometric(Vᴴ; side = :right) - @test isposdef(S) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite - Ac = similar(A) - Sc = similar(A, real(T), min(m, n)) - alg′ = @constinferred MatrixAlgebraKit.select_algorithm(svd_compact!, A, $alg) - U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg′) - @test U2 === U - @test S2 === S - @test V2ᴴ === Vᴴ - @test U * S * Vᴴ ≈ A - @test isisometric(U) - @test isisometric(Vᴴ; side = :right) - @test isposdef(S) +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - Sd = @constinferred svd_vals(A, alg′) - @test S ≈ Diagonal(Sd) +m = 54 +for T in BLASFloats, n in (37, m, 63, 0) + TestSuite.seed_rng!(123) + if is_buildkite + if CUDA.functional() + TestSuite.test_svd(CuMatrix{T}, (m, n); test_blocksize = false) + n == m && TestSuite.test_svd(Diagonal{T, CuVector{T}}, m; test_blocksize = false) end - end -end - -@testset "svd_full! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63, 0) - @testset "algorithm $alg" for alg in - (LAPACK_DivideAndConquer(), LAPACK_QRIteration()) - A = randn(rng, T, m, n) - U, S, Vᴴ = svd_full(A; alg) - @test U isa Matrix{T} && size(U) == (m, m) - @test S isa Matrix{real(T)} && size(S) == (m, n) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (n, n) - @test U * S * Vᴴ ≈ A - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(isposdef, diagview(S)) - - Ac = similar(A) - U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) - @test U2 === U - @test S2 === S - @test V2ᴴ === Vᴴ - @test U * S * Vᴴ ≈ A - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(isposdef, diagview(S)) - - Sc = similar(A, real(T), min(m, n)) - Sc2 = svd_vals!(copy!(Ac, A), Sc, alg) - @test Sc === Sc2 - @test diagview(S) ≈ Sc - end - end - @testset "size (0, 0)" begin - @testset "algorithm $alg" for alg in - (LAPACK_DivideAndConquer(), LAPACK_QRIteration()) - A = randn(rng, T, 0, 0) - U, S, Vᴴ = svd_full(A; alg) - @test U isa Matrix{T} && size(U) == (0, 0) - @test S isa Matrix{real(T)} && size(S) == (0, 0) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (0, 0) - @test U * S * Vᴴ ≈ A - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(isposdef, diagview(S)) - end - end -end - -@testset "svd_trunc! 
for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - atol = sqrt(eps(real(T))) - if LinearAlgebra.LAPACK.version() < v"3.12.0" - algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) - else - algs = ( - LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_Jacobi(), - ) - end - - @testset "size ($m, $n)" for n in (37, m, 63) - @testset "algorithm $alg" for alg in algs - n > m && alg isa LAPACK_Jacobi && continue # not supported - A = randn(rng, T, m, n) - S₀ = svd_vals(A) - minmn = min(m, n) - r = minmn - 2 - - U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) - @test length(diagview(S1)) == r - @test diagview(S1) ≈ S₀[1:r] - @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] - # Test truncation error - @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol - - s = 1 + sqrt(eps(real(T))) - trunc = trunctol(; atol = s * S₀[r + 1]) - - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) - @test length(diagview(S2)) == r - @test U1 ≈ U2 - @test S1 ≈ S2 - @test V1ᴴ ≈ V2ᴴ - @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol - - trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) - U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) - @test length(diagview(S3)) == r - @test U1 ≈ U3 - @test S1 ≈ S3 - @test V1ᴴ ≈ V3ᴴ - @test ϵ3 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + if AMDGPU.functional() + TestSuite.test_svd(ROCMatrix{T}, (m, n); test_blocksize = false) + n == m && TestSuite.test_svd(Diagonal{T, ROCVector{T}}, m; test_blocksize = false) end - end -end - -@testset "svd_trunc! 
mix maxrank and tol for T = $T" for T in BLASFloats - rng = StableRNG(123) - if LinearAlgebra.LAPACK.version() < v"3.12.0" - algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) else - algs = ( - LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_Jacobi(), - ) + TestSuite.test_svd(T, (m, n)) end - m = 4 - @testset "algorithm $alg" for alg in algs - U = qr_compact(randn(rng, T, m, m))[1] - S = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - Vᴴ = qr_compact(randn(rng, T, m, m))[1] - A = U * S * Vᴴ - - for trunc_fun in ( - (rtol, maxrank) -> (; rtol, maxrank), - (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), - ) - U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) - @test length(diagview(S1)) == 1 - @test diagview(S1) ≈ diagview(S)[1:1] - - U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) - @test length(diagview(S2)) == 2 - @test diagview(S2) ≈ diagview(S)[1:2] - end - end -end - -@testset "svd_trunc! specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - atol = sqrt(eps(real(T))) - m = 4 - U = qr_compact(randn(rng, T, m, m))[1] - S = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - Vᴴ = qr_compact(randn(rng, T, m, m))[1] - A = U * S * Vᴴ - alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), trunctol(; atol = 0.2)) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) - @test diagview(S2) ≈ diagview(S)[1:2] - @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol - @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) end - -@testset "svd for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) 
- rng = StableRNG(123) - atol = sqrt(eps(real(T))) - for m in (54, 0) - Ad = randn(T, m) - A = Diagonal(Ad) - - U, S, Vᴴ = @constinferred svd_compact(A) - @test U isa AbstractMatrix{T} && size(U) == size(A) - @test Vᴴ isa AbstractMatrix{T} && size(Vᴴ) == size(A) - @test S isa Diagonal{real(T)} && size(S) == size(A) - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(≥(0), diagview(S)) - @test A ≈ U * S * Vᴴ - - U, S, Vᴴ = @constinferred svd_full(A) - @test U isa AbstractMatrix{T} && size(U) == size(A) - @test Vᴴ isa AbstractMatrix{T} && size(Vᴴ) == size(A) - @test S isa Diagonal{real(T)} && size(S) == size(A) - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(≥(0), diagview(S)) - @test A ≈ U * S * Vᴴ - - S2 = @constinferred svd_vals(A) - @test S2 isa AbstractVector{real(T)} && length(S2) == m - @test S2 ≈ diagview(S) - - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc(A; alg) - @test diagview(S3) ≈ S2[1:min(m, 2)] - @test ϵ3 ≈ norm(S2[(min(m, 2) + 1):m]) atol = atol +if !is_buildkite + for T in (BLASFloats...,) # GenericFloats...) # not yet supported + AT = Diagonal{T, Vector{T}} + TestSuite.test_svd(AT, m; test_blocksize = false) end end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl new file mode 100644 index 00000000..b7c54259 --- /dev/null +++ b/test/testsuite/TestSuite.jl @@ -0,0 +1,76 @@ +# Based on the design of GPUArrays.jl + +""" + TestSuite + +Suite of tests that may be used for all packages inheriting from MatrixAlgebraKit. 
+ +""" +module TestSuite + +using Test +using MatrixAlgebraKit +using MatrixAlgebraKit: diagview +using LinearAlgebra: Diagonal, norm, istriu, istril +using Random, StableRNGs +using AMDGPU, CUDA + +const tests = Dict() + +macro testsuite(name, ex) + safe_name = lowercase(replace(replace(name, " " => "_"), "/" => "_")) + fn = Symbol("test_", safe_name) + return quote + $(esc(fn))(AT; eltypes = supported_eltypes(AT, $(esc(fn)))) = $(esc(ex))(AT, eltypes) + @assert !haskey(tests, $name) "testsuite already exists" + tests[$name] = $fn + end +end + +testargs_summary(args...) = string(args) + +const rng = StableRNG(123) +seed_rng!(seed) = Random.seed!(rng, seed) + +instantiate_matrix(::Type{T}, size) where {T <: Number} = randn(rng, T, size) +instantiate_matrix(::Type{AT}, size) where {AT <: Array} = randn(rng, eltype(AT), size) +instantiate_matrix(::Type{AT}, size) where {AT <: CuArray} = CuArray(randn(rng, eltype(AT), size)) +instantiate_matrix(::Type{AT}, size) where {AT <: ROCArray} = ROCArray(randn(rng, eltype(AT), size)) +instantiate_matrix(::Type{AT}, size) where {AT <: Diagonal} = Diagonal(randn(rng, eltype(AT), size)) +instantiate_matrix(::Type{AT}, size) where {T, AT <: Diagonal{T, <:CuVector}} = Diagonal(CuArray(randn(rng, eltype(AT), size))) +instantiate_matrix(::Type{AT}, size) where {T, AT <: Diagonal{T, <:ROCVector}} = Diagonal(ROCArray(randn(rng, eltype(AT), size))) + +precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T))) +precision(::Type{T}) where {T} = precision(eltype(T)) + +function has_positive_diagonal(A) + T = eltype(A) + return if T <: Real + all(≥(zero(T)), diagview(A)) + else + all(≥(zero(real(T))), real(diagview(A))) && + all(≈(zero(real(T))), imag(diagview(A))) + end +end +isleftnull(N, A; atol::Real = 0, rtol::Real = precision(eltype(A))) = + isapprox(norm(A' * N), 0; atol = max(atol, norm(A) * rtol)) + +isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A))) = + isapprox(norm(A * Nᴴ'), 0; atol = max(atol, norm(A) * 
rtol)) + +# TODO: actually make this a test +macro testinferred(ex) + return esc(:(@inferred $ex)) +end + +include("qr.jl") +include("lq.jl") +include("polar.jl") +include("orthnull.jl") +include("projections.jl") +include("eigh.jl") +include("eig.jl") +include("schur.jl") +include("svd.jl") + +end diff --git a/test/testsuite/eig.jl b/test/testsuite/eig.jl new file mode 100644 index 00000000..4c7daaf8 --- /dev/null +++ b/test/testsuite/eig.jl @@ -0,0 +1,100 @@ +using TestExtras + +function test_eig(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "eig $summary_str" begin + test_eig_full(T, sz; kwargs...) + if T <: Number || T <: Diagonal{<:Number, <:Vector} + test_eig_trunc(T, sz; kwargs...) + end + end +end + +function test_eig_full( + T::Type, sz; + test_blocksize = true, + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "eig_full! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T)) + D, V = @testinferred eig_full(A) + @test eltype(D) == eltype(V) == Tc + @test A * V ≈ V * D + + D2, V2 = @testinferred eig_full!(Ac, (D, V)) + @test D2 === D + @test V2 === V + @test A * V ≈ V * D + + Dc = @testinferred eig_vals(A) + @test eltype(Dc) == Tc + @test D ≈ Diagonal(Dc) + end +end + +function test_eig_trunc( + T::Type, sz; + test_blocksize = true, + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "eig_trunc! $summary_str" begin + A = instantiate_matrix(T, sz) + A *= A' # TODO: deal with eigenvalue ordering etc + Ac = deepcopy(A) + Tc = complex(eltype(T)) + # eigenvalues are sorted by ascending real component... 
+ D₀ = sort!(eig_vals(A); by = abs, rev = true) + m = size(A, 1) + rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) + r = length(D₀) - rmin + atol = sqrt(eps(real(eltype(T)))) + + D1, V1, ϵ1 = @testinferred eig_trunc(A; trunc = truncrank(r)) + @test length(diagview(D1)) == r + @test A * V1 ≈ V1 * D1 + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + s = 1 + sqrt(eps(real(eltype(T)))) + trunc = trunctol(; atol = s * abs(D₀[r + 1])) + D2, V2, ϵ2 = @testinferred eig_trunc(A; trunc) + @test length(diagview(D2)) == r + @test A * V2 ≈ V2 * D2 + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + s = 1 - sqrt(eps(real(eltype(T)))) + trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) + D3, V3, ϵ3 = @testinferred eig_trunc(A; trunc) + @test length(diagview(D3)) == r + @test A * V3 ≈ V3 * D3 + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + # trunctol keeps order, truncrank might not + # test for same subspace + @test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2 + @test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1 + @test V1 * ((V1' * V1) \ (V1' * V3)) ≈ V3 + @test V3 * ((V3' * V3) \ (V3' * V1)) ≈ V1 + + # TODO + #=atol = sqrt(eps(real(eltype(T)))) + V = randn(rng, T, m, m) + D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) + A = V * D * inv(V) + alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2)) + D2, V2, ϵ2 = @testinferred eig_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] + @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol + @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) + + alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1)) + D3, V3, ϵ3 = @testinferred eig_trunc(A; alg) + @test diagview(D3) ≈ diagview(D)[1:2] + @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol=# + end +end diff --git a/test/testsuite/eigh.jl b/test/testsuite/eigh.jl new file mode 100644 index 00000000..b6aaa98c --- /dev/null +++ b/test/testsuite/eigh.jl @@ -0,0 +1,107 @@ +using TestExtras + +function test_eigh(T::Type, sz; 
kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "eigh $summary_str" begin + test_eigh_full(T, sz; kwargs...) + if T <: Number && eltype(T) <: Union{Float16, ComplexF16, Float32, Float64, ComplexF32, ComplexF64} && !(T <: Diagonal) + test_eigh_trunc(T, sz; kwargs...) + end + end +end + +function test_eigh_full( + T::Type, sz; + test_blocksize = true, + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "eigh_full! $summary_str" begin + A = instantiate_matrix(T, sz) + A = (A + A') / 2 + Ac = deepcopy(A) + + D, V = @testinferred eigh_full(A) + @test A * V ≈ V * D + @test isunitary(V) + @test all(isreal, D) + + D2, V2 = eigh_full!(copy(A), (D, V)) + @test D2 === D + @test V2 === V + + D3 = @testinferred eigh_vals(A) + @test D ≈ Diagonal(D3) + end +end + +function test_eigh_trunc( + T::Type, sz; + test_blocksize = true, + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "eigh_trunc! 
$summary_str" begin + A = instantiate_matrix(T, sz) + A = A * A' + A = (A + A') / 2 + Ac = deepcopy(A) + + m = size(A, 1) + D₀ = reverse(eigh_vals(A)) + r = m - 2 + s = 1 + sqrt(eps(real(eltype(T)))) + atol = sqrt(eps(real(eltype(T)))) + # truncrank + D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r)) + @test length(diagview(D1)) == r + @test isisometric(V1) + @test A * V1 ≈ V1 * D1 + @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + # trunctol + trunc = trunctol(; atol = s * D₀[r + 1]) + D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc) + @test length(diagview(D2)) == r + @test isisometric(V2) + @test A * V2 ≈ V2 * D2 + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + #truncerror + s = 1 - sqrt(eps(real(eltype(T)))) + trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) + D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc) + @test length(diagview(D3)) == r + @test A * V3 ≈ V3 * D3 + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + # test for same subspace + @test V1 * (V1' * V2) ≈ V2 + @test V2 * (V2' * V1) ≈ V1 + @test V1 * (V1' * V3) ≈ V3 + @test V3 * (V3' * V1) ≈ V1 + + # TODO + #= + @testset "specify truncation algorithm" begin + atol = sqrt(eps(real(eltype(T)))) + V = qr_compact(instantiate_matrix(T, sz))[1] + D = Diagonal(real(eltype(T))[0.9, 0.3, 0.1, 0.01]) + A = V * D * V' + A = (A + A') / 2 + alg = TruncatedAlgorithm(MatrixAlgebraKit.default_qr_algorithm(A), truncrank(2)) + D2, V2, ϵ2 = @testinferred eigh_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] + @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) + @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol + + alg = TruncatedAlgorithm(MatrixAlgebraKit.default_qr_algorithm(A), truncerror(; atol = 0.2)) + D3, V3, ϵ3 = @testinferred eigh_trunc(A; alg) + @test diagview(D3) ≈ diagview(D)[1:2] + @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol + end=# + end +end diff --git a/test/testsuite/lq.jl 
b/test/testsuite/lq.jl new file mode 100644 index 00000000..575d299f --- /dev/null +++ b/test/testsuite/lq.jl @@ -0,0 +1,162 @@ +using TestExtras + +function test_lq(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "lq $summary_str" begin + test_lq_compact(T, sz; kwargs...) + test_lq_full(T, sz; kwargs...) + test_lq_null(T, sz; kwargs...) + end +end + +function test_lq_compact( + T::Type, sz; + test_positive = true, test_pivoted = true, test_blocksize = true, + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "lq_compact! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + + # does the elementary functionality work + L, Q = @testinferred lq_compact(A) + @test L * Q ≈ A + @test isisometric(Q; side = :right, atol, rtol) + @test istril(L) + @test A == Ac + + # can I pass in outputs? + L2, Q2 = @testinferred lq_compact!(deepcopy(A), (L, Q)) + @test L2 * Q2 ≈ A + @test isisometric(Q2; side = :right, atol, rtol) + @test istril(L2) + + # do we support `positive = true`? + if test_positive + Lpos, Qpos = @testinferred lq_compact(A; positive = true) + @test Lpos * Qpos ≈ A + @test isisometric(Qpos; side = :right, atol, rtol) + @test istril(Lpos) + @test has_positive_diagonal(Lpos) + else + @test_throws Exception lq_compact(A; positive = true) + end + + # do we support `pivoted = true`? + if test_pivoted + Lpiv, Qpiv = @testinferred lq_compact(A; pivoted = true) + @test Lpiv * Qpiv ≈ A + @test isisometric(Qpos; side = :right, atol, rtol) + else + @test_throws Exception lq_compact(A; pivoted = true) + end + + # do we support `blocksize = Int`? 
+ if test_blocksize + Lblocked, Qblocked = @testinferred lq_compact(A; blocksize = 2) + @test Lblocked * Qblocked ≈ A + @test isisometric(Qblocked; side = :right, atol, rtol) + else + @test_throws Exception lq_compact(A; blocksize = 2) + end + end +end + +function test_lq_full( + T::Type, sz; + test_positive = true, test_pivoted = true, test_blocksize = true, + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "lq_full! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + + # does the elementary functionality work + L, Q = @testinferred lq_full(A) + @test L * Q ≈ A + @test isunitary(Q; atol, rtol) + @test istril(L) + @test A == Ac + + # can I pass in outputs? + L2, Q2 = @testinferred lq_full!(deepcopy(A), (L, Q)) + @test L2 * Q2 ≈ A + @test isunitary(Q2; atol, rtol) + @test istril(L2) + + # do we support `positive = true`? + if test_positive + Lpos, Qpos = @testinferred lq_full(A; positive = true) + @test Lpos * Qpos ≈ A + @test isunitary(Qpos; atol, rtol) + @test istril(Lpos) + @test has_positive_diagonal(Lpos) + else + @test_throws Exception lq_full(A; positive = true) + end + + # do we support `pivoted = true`? + if test_pivoted + Lpiv, Qpiv = @testinferred lq_full(A; pivoted = true) + @test Lpiv * Qpiv ≈ A + @test isunitary(Qpos; atol, rtol) + else + @test_throws Exception lq_full(A; pivoted = true) + end + + # do we support `blocksize = Int`? + if test_blocksize + Lblocked, Qblocked = @testinferred lq_full(A; blocksize = 2) + @test Lblocked * Qblocked ≈ A + @test isunitary(Qblocked; atol, rtol) + else + @test_throws Exception lq_full(A; blocksize = 2) + end + end +end + +function test_lq_null( + T::Type, sz; + test_pivoted = true, test_blocksize = true, + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "lq_null! 
$summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + + # does the elementary functionality work + Nᴴ = @testinferred lq_null(A) + @test isrightnull(Nᴴ, A; atol, rtol) + @test isisometric(Nᴴ; side = :right, atol, rtol) + @test A == Ac + + # can I pass in outputs? + Nᴴ2 = @testinferred lq_null!(deepcopy(A), Nᴴ) + @test isrightnull(Nᴴ2, A; atol, rtol) + @test isisometric(Nᴴ2; side = :right, atol, rtol) + + # do we support `pivoted = true`? + if test_pivoted + Nᴴpiv = @testinferred lq_null(A; pivoted = true) + @test isrightnull(Nᴴpiv, A; atol, rtol) + @test isisometric(Nᴴpiv; side = :right, atol, rtol) + #else # DISABLE for now as lq_null does support pivoting... + # @test_throws Exception lq_null(A; pivoted = true) + end + + # do we support `blocksize = Int`? + if test_blocksize + Nᴴblocked = @testinferred lq_null(A; blocksize = 2) + @test isrightnull(Nᴴblocked, A; atol, rtol) + @test isisometric(Nᴴblocked; side = :right, atol, rtol) + else + @test_throws Exception lq_null(A; blocksize = 2) + end + end +end diff --git a/test/testsuite/orthnull.jl b/test/testsuite/orthnull.jl new file mode 100644 index 00000000..fa09214e --- /dev/null +++ b/test/testsuite/orthnull.jl @@ -0,0 +1,268 @@ +using TestExtras +using LinearAlgebra + +include("../linearmap.jl") + +function test_orthnull(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "orthnull $summary_str" begin + test_left_orthnull(T, sz; kwargs...) + test_right_orthnull(T, sz; kwargs...) + end +end + +function test_left_orthnull( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "left_orth! and left_null! 
$summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + V, C = @testinferred left_orth(A) + N = @testinferred left_null(A) + m, n = size(A) + minmn = min(m, n) + @test V isa typeof(A) && size(V) == (m, minmn) + @test C isa typeof(A) && size(C) == (minmn, n) + @test eltype(N) == eltype(A) && size(N) == (m, m - minmn) + @test V * C ≈ A + @test isisometric(V) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N) + @test collect(V) * collect(V)' + collect(N) * collect(N)' ≈ I + + M = LinearMap(A) + # broken + #VM, CM = @testinferred left_orth(M; alg = :svd) + VM, CM = left_orth(M; alg = :svd) + @test parent(VM) * parent(CM) ≈ A + + if m > n && (T <: Number || T <: Diagonal{<:Number, <:Vector}) + nullity = 5 + V, C = @testinferred left_orth(A) + N = @testinferred left_null(A; trunc = (; maxnullity = nullity)) + @test V isa typeof(A) && size(V) == (m, minmn) + @test C isa typeof(A) && size(C) == (minmn, n) + @test eltype(N) == eltype(A) && size(N) == (m, nullity) + @test V * C ≈ A + @test isisometric(V) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N) + end + + # passing a kind and some kwargs + # broken + # V, C = @testinferred left_orth(A; alg = :qr, positive = true) + V, C = left_orth(A; alg = :qr, positive = true) + N = @testinferred left_null(A; alg = :qr, positive = true) + @test V isa typeof(A) && size(V) == (m, minmn) + @test C isa typeof(A) && size(C) == (minmn, n) + @test eltype(N) == eltype(A) && size(N) == (m, m - minmn) + @test V * C ≈ A + @test isisometric(V) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N) + @test collect(V) * collect(V)' + collect(N) * collect(N)' ≈ I + + # passing an algorithm + if !isa(A, Diagonal) + V, C = @testinferred left_orth(A; alg = MatrixAlgebraKit.default_qr_algorithm(A)) + N = @testinferred left_null(A; alg = :qr, positive = true) + @test V isa typeof(A) && size(V) 
== (m, minmn) + @test C isa typeof(A) && size(C) == (minmn, n) + @test eltype(N) == eltype(A) && size(N) == (m, m - minmn) + @test V * C ≈ A + @test isisometric(V) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N) + @test collect(V) * collect(V)' + collect(N) * collect(N)' ≈ I + end + + Ac = similar(A) + V2, C2 = @testinferred left_orth!(copy!(Ac, A), (V, C)) + N2 = @testinferred left_null!(copy!(Ac, A), N) + @test V2 * C2 ≈ A + @test isisometric(V2) + @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N2) + @test collect(V2) * collect(V2)' + collect(N2) * collect(N2)' ≈ I + + # doesn't work on AMD... + atol = eps(real(eltype(T))) + V2, C2 = @testinferred left_orth!(copy!(Ac, A), (V, C); trunc = (; atol = atol)) + N2 = @testinferred left_null!(copy!(Ac, A), N; trunc = (; atol = atol)) + @test V2 * C2 ≈ A + @test isisometric(V2) + @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N2) + @test collect(V2) * collect(V2)' + collect(N2) * collect(N2)' ≈ I + + if (T <: Number || T <: Diagonal{<:Number, <:Vector}) + rtol = eps(real(eltype(T))) + for (trunc_orth, trunc_null) in ( + ((; rtol = rtol), (; rtol = rtol)), + (trunctol(; rtol), trunctol(; rtol, keep_below = true)), + ) + V2, C2 = @testinferred left_orth!(copy!(Ac, A), (V, C); trunc = trunc_orth) + N2 = @testinferred left_null!(copy!(Ac, A), N; trunc = trunc_null) + @test V2 * C2 ≈ A + @test isisometric(V2) + @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N2) + @test collect(V2) * collect(V2)' + collect(N2) * collect(N2)' ≈ I + end + end + + for alg in (:qr, :polar, :svd) # explicit kind kwarg + m < n && alg === :polar && continue + # broken + # V2, C2 = @testinferred left_orth!(copy!(Ac, A), (V, C); alg = alg) + V2, C2 = left_orth!(copy!(Ac, A), (V, C); alg = alg) + @test V2 * C2 ≈ A + @test isisometric(V2) + if alg != :polar + 
N2 = @testinferred left_null!(copy!(Ac, A), N; alg) + @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N2) + @test collect(V2) * collect(V2)' + collect(N2) * collect(N2)' ≈ I + end + + # with kind and tol kwargs + if alg == :svd + if (T <: Number || T <: Diagonal{<:Number, <:Vector}) + # broken + # V2, C2 = @testinferred left_orth!(copy!(Ac, A), (V, C); alg = alg, trunc = (; atol)) + V2, C2 = left_orth!(copy!(Ac, A), (V, C); alg = alg, trunc = (; atol)) + N2 = @testinferred left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) + @test V2 * C2 ≈ A + @test isisometric(V2) + @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N2) + @test collect(V2) * collect(V2)' + collect(N2) * collect(N2)' ≈ I + + # broken + # V2, C2 = @testinferred left_orth!(copy!(Ac, A), (V, C); alg = alg, trunc = (; rtol)) + V2, C2 = left_orth!(copy!(Ac, A), (V, C); alg = alg, trunc = (; rtol)) + N2 = @testinferred left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) + @test V2 * C2 ≈ A + @test isisometric(V2) + @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N2) + @test collect(V2) * collect(V2)' + collect(N2) * collect(N2)' ≈ I + end + else + @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; atol)) + @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; rtol)) + alg == :polar && continue + @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) + @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) + end + end + end +end + +function test_right_orthnull( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "right_orth! and right_null! 
$summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + minmn = min(m, n) + Ac = deepcopy(A) + C, Vᴴ = @testinferred right_orth(A) + Nᴴ = @testinferred right_null(A) + @test C isa typeof(A) && size(C) == (m, minmn) + @test Vᴴ isa typeof(A) && size(Vᴴ) == (minmn, n) + @test eltype(Nᴴ) == eltype(A) && size(Nᴴ) == (n - minmn, n) + @test C * Vᴴ ≈ A + @test isisometric(Vᴴ; side = :right) + @test LinearAlgebra.norm(A * adjoint(Nᴴ)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(Nᴴ; side = :right) + @test collect(Vᴴ)' * collect(Vᴴ) + collect(Nᴴ)' * collect(Nᴴ) ≈ I + + M = LinearMap(A) + # broken + #CM, VMᴴ = @testinferred right_orth(M; alg = :svd) + CM, VMᴴ = right_orth(M; alg = :svd) + @test parent(CM) * parent(VMᴴ) ≈ A + + Ac = similar(A) + C2, Vᴴ2 = @testinferred right_orth!(copy!(Ac, A), (C, Vᴴ)) + Nᴴ2 = @testinferred right_null!(copy!(Ac, A), Nᴴ) + @test C2 * Vᴴ2 ≈ A + @test isisometric(Vᴴ2; side = :right) + @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(Nᴴ; side = :right) + @test collect(Vᴴ2)' * collect(Vᴴ2) + collect(Nᴴ2)' * collect(Nᴴ2) ≈ I + + if (T <: Number || T <: Diagonal{<:Number, <:Vector}) + atol = eps(real(eltype(T))) + C2, Vᴴ2 = @testinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol)) + Nᴴ2 = @testinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol)) + @test C2 * Vᴴ2 ≈ A + @test isisometric(Vᴴ2; side = :right) + @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(Nᴴ; side = :right) + @test collect(Vᴴ2)' * collect(Vᴴ2) + collect(Nᴴ2)' * collect(Nᴴ2) ≈ I + + rtol = eps(real(eltype(T))) + C2, Vᴴ2 = @testinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol)) + Nᴴ2 = @testinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol)) + @test C2 * Vᴴ2 ≈ A + @test isisometric(Vᴴ2; side = :right) + @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(Nᴴ2; 
side = :right) + @test collect(Vᴴ2)' * collect(Vᴴ2) + collect(Nᴴ2)' * collect(Nᴴ2) ≈ I + end + + for alg in (:lq, :polar, :svd) + n < m && alg == :polar && continue + # broken + #C2, Vᴴ2 = @testinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = alg) + C2, Vᴴ2 = right_orth!(copy!(Ac, A), (C, Vᴴ); alg = alg) + @test C2 * Vᴴ2 ≈ A + @test isisometric(Vᴴ2; side = :right) + if alg != :polar + Nᴴ2 = @testinferred right_null!(copy!(Ac, A), Nᴴ; alg = alg) + @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(Nᴴ2; side = :right) + @test collect(Vᴴ2)' * collect(Vᴴ2) + collect(Nᴴ2)' * collect(Nᴴ2) ≈ I + end + + if alg == :svd + if (T <: Number || T <: Diagonal{<:Number, <:Vector}) + # broken + #C2, Vᴴ2 = @testinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = alg, trunc = (; atol)) + C2, Vᴴ2 = right_orth!(copy!(Ac, A), (C, Vᴴ); alg = alg, trunc = (; atol)) + Nᴴ2 = @testinferred right_null!(copy!(Ac, A), Nᴴ; alg = alg, trunc = (; atol)) + @test C2 * Vᴴ2 ≈ A + @test isisometric(Vᴴ2; side = :right) + @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(Nᴴ2; side = :right) + @test collect(Vᴴ2)' * collect(Vᴴ2) + collect(Nᴴ2)' * collect(Nᴴ2) ≈ I + + # broken + #C2, Vᴴ2 = @testinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = alg, trunc = (; rtol)) + C2, Vᴴ2 = right_orth!(copy!(Ac, A), (C, Vᴴ); alg = alg, trunc = (; rtol)) + Nᴴ2 = @testinferred right_null!(copy!(Ac, A), Nᴴ; alg = alg, trunc = (; rtol)) + @test C2 * Vᴴ2 ≈ A + @test isisometric(Vᴴ2; side = :right) + @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(Nᴴ2; side = :right) + @test collect(Vᴴ2)' * collect(Vᴴ2) + collect(Nᴴ2)' * collect(Nᴴ2) ≈ I + end + else + @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; atol)) + @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; rtol)) + alg == :polar && continue + @test_throws 
ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; atol)) + @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; rtol)) + end + end + end +end diff --git a/test/testsuite/polar.jl b/test/testsuite/polar.jl new file mode 100644 index 00000000..57c06e14 --- /dev/null +++ b/test/testsuite/polar.jl @@ -0,0 +1,102 @@ +using TestExtras +using LinearAlgebra: isposdef + +function test_polar(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "polar $summary_str" begin + (length(sz) == 1 || sz[1] ≥ sz[2]) && test_left_polar(T, sz; kwargs...) + (length(sz) == 1 || sz[2] ≥ sz[1]) && test_right_polar(T, sz; kwargs...) + end +end + +function test_left_polar( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "left_polar! $summary_str" begin + A = instantiate_matrix(T, sz) + algs = if T <: Diagonal + (PolarNewton(),) + elseif T <: Number + (PolarViaSVD(MatrixAlgebraKit.default_svd_algorithm(A)), PolarNewton()) + else + (PolarViaSVD(MatrixAlgebraKit.default_svd_algorithm(A)),) + end + @testset "algorithm $alg" for alg in algs + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + W, P = left_polar(A; alg) + @test eltype(W) == eltype(A) && size(W) == (size(A, 1), size(A, 2)) + @test eltype(P) == eltype(A) && size(P) == (size(A, 2), size(A, 2)) + @test W * P ≈ A + @test isisometric(W) + @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) + @test isposdef(project_hermitian!(P)) + + W2, P2 = @testinferred left_polar!(Ac, (W, P), alg) + @test W2 === W + @test P2 === P + @test W * P ≈ A + @test isisometric(W) + @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) + @test isposdef(project_hermitian!(P)) + + noP = similar(P, (0, 0)) + W2, P2 = @testinferred left_polar!(copy!(Ac, A), (W, noP), alg) + @test P2 === noP + @test W2 === W + @test isisometric(W) + P = W' * A # compute P explicitly to verify W correctness + @test 
ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) + @test isposdef(project_hermitian!(P)) + end + end +end + +function test_right_polar( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "right_polar! $summary_str" begin + A = instantiate_matrix(T, sz) + algs = if T <: Diagonal + (PolarNewton(),) + elseif T <: Number + (PolarViaSVD(MatrixAlgebraKit.default_svd_algorithm(A)), PolarNewton()) + else + (PolarViaSVD(MatrixAlgebraKit.default_svd_algorithm(A)),) + end + @testset "algorithm $alg" for alg in algs + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + P, Wᴴ = right_polar(A; alg) + @test eltype(Wᴴ) == eltype(A) && size(Wᴴ) == (size(A, 1), size(A, 2)) + @test eltype(P) == eltype(A) && size(P) == (size(A, 1), size(A, 1)) + @test P * Wᴴ ≈ A + @test isisometric(Wᴴ; side = :right) + @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) + @test isposdef(project_hermitian!(P)) + + P2, Wᴴ2 = @testinferred right_polar!(Ac, (P, Wᴴ), alg) + @test P2 === P + @test Wᴴ2 === Wᴴ + @test P * Wᴴ ≈ A + @test isisometric(Wᴴ; side = :right) + @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) + @test isposdef(project_hermitian!(P)) + + noP = similar(P, (0, 0)) + P2, Wᴴ2 = @testinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg) + @test P2 === noP + @test Wᴴ2 === Wᴴ + @test isisometric(Wᴴ; side = :right) + P = A * Wᴴ' # compute P explicitly to verify W correctness + @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) + @test isposdef(project_hermitian!(P)) + end + end +end diff --git a/test/testsuite/projections.jl b/test/testsuite/projections.jl new file mode 100644 index 00000000..e7b3d5e1 --- /dev/null +++ b/test/testsuite/projections.jl @@ -0,0 +1,138 @@ +using TestExtras +using MatrixAlgebraKit: ishermitian +using LinearAlgebra: Diagonal, normalize! + +function test_projections(T::Type, sz; kwargs...) 
+ summary_str = testargs_summary(T, sz) + return @testset "projections $summary_str" begin + test_project_antihermitian(T, sz; kwargs...) + test_project_hermitian(T, sz; kwargs...) + test_project_isometric(T, sz; kwargs...) + end +end + +function test_project_antihermitian( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "project_antihermitian! $summary_str" begin + noisefactor = eps(real(eltype(T)))^(3 / 4) + algs = (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64)) + @testset "algorithm $alg" for alg in algs + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + Ah = (A + A') / 2 + Aa = (A - A') / 2 + + Ba = project_antihermitian(A, alg) + @test isantihermitian(Ba) + @test Ba ≈ Aa + @test A == Ac + Ba_approx = Ba + noisefactor * Ah + @test !isantihermitian(Ba_approx) + # this is never anti-hermitian for real Diagonal: |A - A'| == 0 + @test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0 + + copy!(Ac, A) + Ba = project_antihermitian!(Ac, alg) + @test Ba === Ac + @test isantihermitian(Ba) + @test Ba ≈ Aa + end + + # test approximate error calculation + A = normalize!(randn(rng, eltype(T), size(A)...)) + Ah = project_hermitian(A) + Aa = project_antihermitian(A) + + Ah_approx = Ah + noisefactor * Aa + ϵ = norm(project_antihermitian(Ah_approx)) + @test !ishermitian(Ah_approx; atol = (999 // 1000) * ϵ) + @test ishermitian(Ah_approx; atol = (1001 // 1000) * ϵ) + + Aa_approx = Aa + noisefactor * Ah + ϵ = norm(project_hermitian(Aa_approx)) + @test !isantihermitian(Aa_approx; atol = (999 // 1000) * ϵ) + @test isantihermitian(Aa_approx; atol = (1001 // 1000) * ϵ) + end +end + +function test_project_hermitian( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "project_hermitian! 
$summary_str" begin + noisefactor = eps(real(eltype(T)))^(3 / 4) + algs = (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64)) + @testset "algorithm $alg" for alg in algs + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + Ah = (A + A') / 2 + Aa = (A - A') / 2 + + Bh = project_hermitian(A, alg) + @test ishermitian(Bh) + @test Bh ≈ Ah + @test A == Ac + Bh_approx = Bh + noisefactor * Aa + # this is still hermitian for real Diagonal: |A - A'| == 0 + @test !ishermitian(Bh_approx) || norm(Aa) == 0 + @test ishermitian(Bh_approx; rtol = 10 * noisefactor) + + Bh = project_hermitian!(Ac, alg) + @test Bh === Ac + @test ishermitian(Bh) + @test Bh ≈ Ah + end + + # test approximate error calculation + A = normalize!(randn(rng, eltype(T), size(A)...)) + Ah = project_hermitian(A) + Aa = project_antihermitian(A) + + Ah_approx = Ah + noisefactor * Aa + ϵ = norm(project_antihermitian(Ah_approx)) + @test !ishermitian(Ah_approx; atol = (999 // 1000) * ϵ) + @test ishermitian(Ah_approx; atol = (1001 // 1000) * ϵ) + + Aa_approx = Aa + noisefactor * Ah + ϵ = norm(project_hermitian(Aa_approx)) + @test !isantihermitian(Aa_approx; atol = (999 // 1000) * ϵ) + @test isantihermitian(Aa_approx; atol = (1001 // 1000) * ϵ) + end +end + +function test_project_isometric( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "project_isometric! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + k = min(size(A)...) 
+ W = project_isometric(A) + @test isisometric(W) + W2 = project_isometric(W) + @test W2 ≈ W # stability of the projection + @test W * (W' * A) ≈ A + + W2 = @testinferred project_isometric!(Ac, W) + @test W2 === W + @test isisometric(W) + + # test that W is closer to A then any other isometry + for k in 1:10 + δA = instantiate_matrix(T, sz) + W = project_isometric(A) + W2 = project_isometric(A + δA / 100) + # must be ≥ for real Diagonal case + @test norm(A - W2) ≥ norm(A - W) + end + end +end diff --git a/test/testsuite/qr.jl b/test/testsuite/qr.jl new file mode 100644 index 00000000..09e2a637 --- /dev/null +++ b/test/testsuite/qr.jl @@ -0,0 +1,162 @@ +using TestExtras + +function test_qr(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "qr $summary_str" begin + test_qr_compact(T, sz; kwargs...) + test_qr_full(T, sz; kwargs...) + test_qr_null(T, sz; kwargs...) + end +end + +function test_qr_compact( + T::Type, sz; + test_positive = true, test_pivoted = true, test_blocksize = true, + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "qr_compact! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + + # does the elementary functionality work + Q, R = @testinferred qr_compact(A) + @test Q * R ≈ A + @test isisometric(Q; atol, rtol) + @test istriu(R) + @test A == Ac + + # can I pass in outputs? + Q2, R2 = @testinferred qr_compact!(deepcopy(A), (Q, R)) + @test Q2 * R2 ≈ A + @test isisometric(Q2; atol, rtol) + @test istriu(R2) + + # do we support `positive = true`? + if test_positive + Qpos, Rpos = @testinferred qr_compact(A; positive = true) + @test Qpos * Rpos ≈ A + @test isisometric(Qpos; atol, rtol) + @test istriu(Rpos) + @test has_positive_diagonal(Rpos) + else + @test_throws Exception qr_compact(A; positive = true) + end + + # do we support `pivoted = true`? 
+ if test_pivoted + Qpiv, Rpiv = @testinferred qr_compact(A; pivoted = true) + @test Qpiv * Rpiv ≈ A + @test isisometric(Qpos; atol, rtol) + else + @test_throws Exception qr_compact(A; pivoted = true) + end + + # do we support `blocksize = Int`? + if test_blocksize + Qblocked, Rblocked = @testinferred qr_compact(A; blocksize = 2) + @test Qblocked * Rblocked ≈ A + @test isisometric(Qblocked; atol, rtol) + else + @test_throws Exception qr_compact(A; blocksize = 2) + end + end +end + +function test_qr_full( + T::Type, sz; + test_positive = true, test_pivoted = true, test_blocksize = true, + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "qr_full! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + + # does the elementary functionality work + Q, R = @testinferred qr_full(A) + @test Q * R ≈ A + @test isunitary(Q; atol, rtol) + @test istriu(R) + @test A == Ac + + # can I pass in outputs? + Q2, R2 = @testinferred qr_full!(deepcopy(A), (Q, R)) + @test Q2 * R2 ≈ A + @test isunitary(Q2; atol, rtol) + @test istriu(R2) + + # do we support `positive = true`? + if test_positive + Qpos, Rpos = @testinferred qr_full(A; positive = true) + @test Qpos * Rpos ≈ A + @test isunitary(Qpos; atol, rtol) + @test istriu(Rpos) + @test has_positive_diagonal(Rpos) + else + @test_throws Exception qr_full(A; positive = true) + end + + # do we support `pivoted = true`? + if test_pivoted + Qpiv, Rpiv = @testinferred qr_full(A; pivoted = true) + @test Qpiv * Rpiv ≈ A + @test isunitary(Qpos; atol, rtol) + else + @test_throws Exception qr_full(A; pivoted = true) + end + + # do we support `blocksize = Int`? 
+ if test_blocksize + Qblocked, Rblocked = @testinferred qr_full(A; blocksize = 2) + @test Qblocked * Rblocked ≈ A + @test isunitary(Qblocked; atol, rtol) + else + @test_throws Exception qr_full(A; blocksize = 2) + end + end +end + +function test_qr_null( + T::Type, sz; + test_pivoted = true, test_blocksize = true, + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "qr_null! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + + # does the elementary functionality work + N = @testinferred qr_null(A) + @test isleftnull(N, A; atol, rtol) + @test isisometric(N; atol, rtol) + @test A == Ac + + # can I pass in outputs? + N2 = @testinferred qr_null!(deepcopy(A), N) + @test isleftnull(N2, A; atol, rtol) + @test isisometric(N2; atol, rtol) + + # do we support `pivoted = true`? + if test_pivoted + Npiv = @testinferred qr_null(A; pivoted = true) + @test isleftnull(Npiv, A; atol, rtol) + @test isisometric(Npiv; atol, rtol) + else + @test_throws Exception qr_null(A; pivoted = true) + end + + # do we support `blocksize = Int`? + if test_blocksize + Nblocked = @testinferred qr_null(A; blocksize = 2) + @test isleftnull(Nblocked, A; atol, rtol) + @test isisometric(Nblocked; atol, rtol) + else + @test_throws Exception qr_null(A; blocksize = 2) + end + end +end diff --git a/test/testsuite/schur.jl b/test/testsuite/schur.jl new file mode 100644 index 00000000..4d3f306f --- /dev/null +++ b/test/testsuite/schur.jl @@ -0,0 +1,38 @@ +using TestExtras + +function test_schur(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "schur $summary_str" begin + test_schur_full(T, sz; kwargs...) + end +end + +function test_schur_full( + T::Type, sz; + test_blocksize = true, + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "schur_full!
$summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T)) + + TA, Z, vals = @testinferred schur_full(A) + @test eltype(TA) == eltype(Z) == eltype(T) + @test eltype(vals) == Tc + @test isisometric(Z) + @test A * Z ≈ Z * TA + + TA2, Z2, vals2 = @testinferred schur_full!(Ac, (TA, Z, vals)) + @test TA2 === TA + @test Z2 === Z + @test vals2 === vals + @test A * Z ≈ Z * TA + + valsc = @testinferred schur_vals(A) + @test eltype(valsc) == Tc + @test valsc ≈ eig_vals(A) + end +end diff --git a/test/testsuite/svd.jl b/test/testsuite/svd.jl new file mode 100644 index 00000000..b2a370c4 --- /dev/null +++ b/test/testsuite/svd.jl @@ -0,0 +1,167 @@ +using TestExtras + +function test_svd(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "svd $summary_str" begin + test_svd_compact(T, sz; kwargs...) + test_svd_full(T, sz; kwargs...) + if min(sz...) > 0 && (T <: Number || T <: Diagonal{<:Number, <:Vector}) + test_svd_trunc(T, sz; kwargs...) + end + end +end + +function test_svd_compact( + T::Type, sz; + test_blocksize = true, + atol::Real = 0, rtol::Real = precision(eltype(T)), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "svd_compact! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + m, n = size(A) + minmn = min(m, n) + if VERSION < v"1.11" + # This is type unstable on older versions of Julia. 
+ U, S, Vᴴ = svd_compact(A) + else + U, S, Vᴴ = @testinferred svd_compact(A) + end + @test size(U) == (m, minmn) + @test S isa Diagonal{real(eltype(T))} && size(S) == (minmn, minmn) + @test size(Vᴴ) == (minmn, n) + @test U * S * Vᴴ ≈ A + @test isisometric(U) + @test isisometric(Vᴴ; side = :right) + @test isposdef(S) + + Sc = similar(A, real(eltype(T)), min(m, n)) + U2, S2, V2ᴴ = @testinferred svd_compact!(Ac, (U, S, Vᴴ)) + @test U2 === U + @test S2 === S + @test V2ᴴ === Vᴴ + @test U * S * Vᴴ ≈ A + @test isisometric(U) + @test isisometric(Vᴴ; side = :right) + @test isposdef(S) + + Sd = @testinferred svd_vals(A) + @test S ≈ Diagonal(Sd) + end +end + +function test_svd_full( + T::Type, sz; + test_blocksize = true, + atol::Real = 0, rtol::Real = precision(eltype(T)), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "svd_full! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + m, n = size(A) + minmn = min(m, n) + + U, S, Vᴴ = svd_full(A) + @test size(U) == (m, m) + @test eltype(S) == real(eltype(T)) && size(S) == (m, n) + @test size(Vᴴ) == (n, n) + @test U * S * Vᴴ ≈ A + @test isunitary(U) + @test isunitary(Vᴴ) + @test all(isposdef, diagview(S)) + + U2, S2, V2ᴴ = @testinferred svd_full!(Ac, (U, S, Vᴴ)) + @test U2 === U + @test S2 === S + @test V2ᴴ === Vᴴ + @test U * S * Vᴴ ≈ A + @test isunitary(U) + @test isunitary(Vᴴ) + @test all(isposdef, diagview(S)) + + Sc = similar(A, real(eltype(T)), min(m, n)) + Sc2 = svd_vals!(copy!(Ac, A), Sc) + @test Sc === Sc2 + @test collect(diagview(S)) ≈ collect(Sc) + end +end + +function test_svd_trunc( + T::Type, sz; + test_blocksize = true, + atol::Real = 0, rtol::Real = precision(eltype(T)), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "svd_trunc! 
$summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + m, n = size(A) + minmn = min(m, n) + S₀ = svd_vals(A) + r = minmn - 2 + + U1, S1, V1ᴴ, ϵ1 = @testinferred svd_trunc(A; trunc = truncrank(r)) + @test length(diagview(S1)) == r + @test diagview(S1) ≈ S₀[1:r] + @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] + # Test truncation error + @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + + s = 1 + sqrt(eps(real(eltype(T)))) + trunc = trunctol(; atol = s * S₀[r + 1]) + + U2, S2, V2ᴴ, ϵ2 = @testinferred svd_trunc(A; trunc) + @test length(diagview(S2)) == r + @test U1 ≈ U2 + @test S1 ≈ S2 + @test V1ᴴ ≈ V2ᴴ + @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + + trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) + U3, S3, V3ᴴ, ϵ3 = @testinferred svd_trunc(A; trunc) + @test length(diagview(S3)) == r + @test U1 ≈ U3 + @test S1 ≈ S3 + @test V1ᴴ ≈ V3ᴴ + @test ϵ3 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + + # TODO + #=@testset "mix maxrank and tol" begin + U = qr_compact(randn(rng, T, m, m))[1] + S = Diagonal(T[0.9, 0.3, 0.1, 0.01]) + Vᴴ = qr_compact(randn(rng, T, m, m))[1] + A = U * S * Vᴴ + + for trunc_fun in ( + (rtol, maxrank) -> (; rtol, maxrank), + (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), + ) + U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; trunc = trunc_fun(0.2, 1)) + @test length(diagview(S1)) == 1 + @test diagview(S1) ≈ diagview(S)[1:1] + + U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; trunc = trunc_fun(0.2, 3)) + @test length(diagview(S2)) == 2 + @test diagview(S2) ≈ diagview(S)[1:2] + end + end + @testset "specify truncation algorithm" begin + atol = sqrt(eps(real(eltype(T)))) + m = 4 + U = qr_compact(randn(rng, T, m, m))[1] + S = Diagonal(real(eltype(T))[0.9, 0.3, 0.1, 0.01]) + Vᴴ = qr_compact(randn(rng, T, m, m))[1] + A = U * S * Vᴴ + alg = TruncatedAlgorithm(trunctol(; atol = 0.2)) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A) + @test diagview(S2) ≈ diagview(S)[1:2] + @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol + 
@test_throws ArgumentError svd_trunc(A; trunc = (; maxrank = 2)) + end=# + end +end From 186ea8f8f1aa3a29db6c4f7b8bff1f75744a5b62 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Dec 2025 09:10:58 -0500 Subject: [PATCH 2/6] AD to Testsuite --- test/ad_utils.jl | 31 -- test/chainrules.jl | 570 +--------------------------------- test/mooncake.jl | 577 ++-------------------------------- test/runtests.jl | 34 +- test/testsuite/TestSuite.jl | 4 + test/testsuite/ad_utils.jl | 352 +++++++++++++++++++++ test/testsuite/chainrules.jl | 587 +++++++++++++++++++++++++++++++++++ test/testsuite/mooncake.jl | 412 ++++++++++++++++++++++++ 8 files changed, 1401 insertions(+), 1166 deletions(-) delete mode 100644 test/ad_utils.jl create mode 100644 test/testsuite/ad_utils.jl create mode 100644 test/testsuite/chainrules.jl create mode 100644 test/testsuite/mooncake.jl diff --git a/test/ad_utils.jl b/test/ad_utils.jl deleted file mode 100644 index 4c03e50c..00000000 --- a/test/ad_utils.jl +++ /dev/null @@ -1,31 +0,0 @@ -function remove_svdgauge_dependence!( - ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true) - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 - mul!(ΔU, U, gaugepart, -1, 1) - return ΔU, ΔVᴴ -end -function remove_eiggauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = V' * ΔV - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V / (V' * V), gaugepart, -1, 1) - return ΔV -end -function remove_eighgauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = V' * ΔV - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V, gaugepart, -1, 1) - return ΔV -end - 
-precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T))) diff --git a/test/chainrules.jl b/test/chainrules.jl index 5258b839..8ed6fdb8 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,569 +1,21 @@ using MatrixAlgebraKit using Test -using TestExtras using StableRNGs -using ChainRulesCore, ChainRulesTestUtils, Zygote -using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD -using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! -include("ad_utils.jl") +#BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI -for f in - ( - :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, - :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, - :svd_compact, :svd_trunc, :svd_vals, - :left_polar, :right_polar, - ) - copy_f = Symbol(:copy_, f) - f! = Symbol(f, '!') - _hermitian = startswith(string(f), "eigh") - @eval begin - function $copy_f(input, alg) - if $_hermitian - input = (input + input') / 2 - end - return $f(input, alg) - end - function ChainRulesCore.rrule(::typeof($copy_f), input, alg) - output = MatrixAlgebraKit.initialize_output($f!, input, alg) - if $_hermitian - input = (input + input') / 2 - else - input = copy(input) - end - output, pb = ChainRulesCore.rrule($f!, input, output, alg) - return output, x -> (NoTangent(), pb(x)[2], NoTangent()) - end - end -end - -@timedtestset "QR AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - # qr_compact - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - alg = LAPACK_HouseholderQR(; positive = true) - Q, R = copy_qr_compact(A, alg) - ΔQ = randn(rng, T, m, minmn) - ΔR = randn(rng, T, minmn, n) - ΔR2 = UpperTriangular(randn(rng, T, minmn, minmn)) - ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) - test_rrule( - copy_qr_compact, A, alg ⊢ NoTangent(); - 
output_tangent = (ΔQ, ΔR), atol = atol, rtol = rtol - ) - test_rrule( - copy_qr_null, A, alg ⊢ NoTangent(); - output_tangent = ΔN, atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, qr_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔQ, ΔR), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ qr_compact, A; - fkwargs = (; positive = true), output_tangent = ΔQ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ qr_compact, A; - fkwargs = (; positive = true), output_tangent = ΔR, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, qr_null, A; - fkwargs = (; positive = true), output_tangent = ΔN, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - # qr_full - Q, R = copy_qr_full(A, alg) - Q1 = view(Q, 1:m, 1:minmn) - ΔQ = randn(rng, T, m, m) - ΔQ2 = view(ΔQ, :, (minmn + 1):m) - mul!(ΔQ2, Q1, Q1' * ΔQ2) - ΔR = randn(rng, T, m, n) - test_rrule( - copy_qr_full, A, alg ⊢ NoTangent(); - output_tangent = (ΔQ, ΔR), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, qr_full, A; - fkwargs = (; positive = true), output_tangent = (ΔQ, ΔR), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - if m > n - _, null_pb = Zygote.pullback(qr_null, A, alg) - @test_logs (:warn,) null_pb(randn(rng, T, m, max(0, m - minmn))) - _, full_pb = Zygote.pullback(qr_full, A, alg) - @test_logs (:warn,) full_pb((randn(rng, T, m, m), randn(rng, T, m, n))) - end - # rank-deficient A - r = minmn - 5 - A = randn(rng, T, m, r) * randn(rng, T, r, n) - Q, R = qr_compact(A, alg) - ΔQ = randn(rng, T, m, minmn) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - ΔQ2 .= 0 - ΔR = randn(rng, T, minmn, n) - view(ΔR, (r + 1):minmn, :) .= 0 - test_rrule( - 
copy_qr_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔQ, ΔR), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, qr_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔQ, ΔR), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "LQ AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - # lq_compact - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - alg = LAPACK_HouseholderLQ(; positive = true) - L, Q = copy_lq_compact(A, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q - test_rrule( - copy_lq_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔL, ΔQ), atol = atol, rtol = rtol - ) - test_rrule( - copy_lq_null, A, alg ⊢ NoTangent(); - output_tangent = ΔNᴴ, atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, lq_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔL, ΔQ), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ lq_compact, A; - fkwargs = (; positive = true), output_tangent = ΔL, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ lq_compact, A; - fkwargs = (; positive = true), output_tangent = ΔQ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, lq_null, A; - fkwargs = (; positive = true), output_tangent = ΔNᴴ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - # lq_full - L, Q = copy_lq_full(A, alg) - Q1 = view(Q, 1:minmn, 1:n) - ΔQ = randn(rng, T, n, n) - ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) - mul!(ΔQ2, ΔQ2 * Q1', Q1) - ΔL = randn(rng, T, m, n) - test_rrule( - copy_lq_full, A, alg ⊢ NoTangent(); - 
output_tangent = (ΔL, ΔQ), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, lq_full, A; - fkwargs = (; positive = true), output_tangent = (ΔL, ΔQ), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - if m < n - Nᴴ, null_pb = Zygote.pullback(lq_null, A, alg) - @test_logs (:warn,) null_pb(randn(rng, T, max(0, n - minmn), n)) - _, full_pb = Zygote.pullback(lq_full, A, alg) - @test_logs (:warn,) full_pb((randn(rng, T, m, n), randn(rng, T, n, n))) - end - # rank-deficient A - r = minmn - 5 - A = randn(rng, T, m, r) * randn(rng, T, r, n) - L, Q = lq_compact(A, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) - ΔQ2 .= 0 - view(ΔL, :, (r + 1):minmn) .= 0 - test_rrule( - copy_lq_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔL, ΔQ), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, lq_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔL, ΔQ), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "EIG AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = randn(rng, T, m, m) - D, V = eig_full(A) - Ddiag = diagview(D) - ΔV = randn(rng, complex(T), m, m) - ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, complex(T), m, m) - ΔD2 = Diagonal(randn(rng, complex(T), m)) - for alg in (LAPACK_Simple(), LAPACK_Expert()) - test_rrule( - copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol - ) - test_rrule( - copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol - ) - test_rrule( - copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol - ) - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by 
= abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(Ddiag[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, eig_full, A; - output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eig_full, A; - output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ eig_full, A; - output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ eig_full, A; - output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eig_vals, A; - output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false - ) -end - -@timedtestset "EIGH AD Rules 
with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = randn(rng, T, m, m) - A = A + A' - D, V = eigh_full(A) - Ddiag = diagview(D) - ΔV = randn(rng, T, m, m) - ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, real(T), m, m) - ΔD2 = Diagonal(randn(rng, real(T), m)) - for alg in ( - LAPACK_QRIteration(), LAPACK_DivideAndConquer(), LAPACK_Bisection(), - LAPACK_MultipleRelativelyRobustRepresentations(), - ) - # copy_eigh_full includes a projector onto the Hermitian part of the matrix - test_rrule( - copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol - ) - test_rrule( - copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol - ) - test_rrule( - copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol - ) - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = 
MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - # eigh_full does not include a projector onto the Hermitian part of the matrix - test_rrule( - config, eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eigh_vals ∘ Matrix ∘ Hermitian, A; - output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...) 
- for r in 1:4:m - trunc = truncrank(r; by = real) - ind = MatrixAlgebraKit.findtruncated(Ddiag, trunc) - test_rrule( - config, eigh_trunc2, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end - trunc = trunctol(; rtol = 1 / 2) - ind = MatrixAlgebraKit.findtruncated(Ddiag, trunc) - test_rrule( - config, eigh_trunc2, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) -end - -@timedtestset "SVD AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - for alg in (LAPACK_QRIteration(), LAPACK_DivideAndConquer()) - test_rrule( - copy_svd_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔU, ΔS, ΔVᴴ), atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_vals, A, alg ⊢ NoTangent(); - output_tangent = diagview(ΔS), atol, rtol - ) - for r in 1:4:minmn - truncalg = TruncatedAlgorithm(alg, truncrank(r)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), - atol = atol, rtol = 
rtol - ) - dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) - dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) - dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, svd_compact, A; - output_tangent = (ΔU, ΔS, ΔVᴴ), atol = atol, rtol = rtol, - rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_compact, A; - output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol, - rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_vals, A; - output_tangent = diagview(ΔS), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - for r in 1:4:minmn - trunc = truncrank(r) - ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) - test_rrule( - config, svd_trunc, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end - trunc = trunctol(; atol = S[1, 1] / 2) - ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) - test_rrule( - config, svd_trunc, A; - fkwargs = (; trunc = 
trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "Polar AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) - m >= n && - test_rrule(copy_left_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) - m <= n && - test_rrule(copy_right_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - m >= n && test_rrule( - config, left_polar, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - m <= n && test_rrule( - config, right_polar, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "Orth and null with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, left_orth, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, left_orth, A; - fkwargs = (; alg = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - m >= n && - test_rrule( - config, left_orth, A; - fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite - ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - test_rrule( - config, left_null, A; - fkwargs = (; alg = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, - rrule_f = rrule_via_ad, check_inferred = false - ) 
+is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - test_rrule( - config, right_orth, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, right_orth, A; fkwargs = (; alg = :lq), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - m <= n && - test_rrule( - config, right_orth, A; fkwargs = (; alg = :polar), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) +rng = StableRNG(12345) - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - test_rrule( - config, right_null, A; - fkwargs = (; alg = :lq), output_tangent = ΔNᴴ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) +m = 19 +for T in BLASFloats, n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite # doesn't work on GPU + TestSuite.test_chainrules(T, (m, n), rng; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end diff --git a/test/mooncake.jl b/test/mooncake.jl index 3e19e44d..aa449c44 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -1,569 +1,28 @@ using MatrixAlgebraKit using Test -using TestExtras using StableRNGs -using Mooncake, Mooncake.TestUtils -using Mooncake: rrule!! -using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc -using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! 
+using CUDA, AMDGPU -include("ad_utils.jl") +#BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI -make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) -make_mooncake_tangent(ΔA::Matrix{<:Real}) = ΔA -make_mooncake_tangent(ΔA::Vector{<:Real}) = ΔA -make_mooncake_tangent(ΔA::Matrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) -make_mooncake_tangent(ΔA::Vector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) -make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Real} = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) -make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_tangent(typeof(ΔD), map(make_mooncake_tangent, diagview(ΔD))) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite -make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), T...) +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" -make_mooncake_fdata(x) = make_mooncake_tangent(x) -make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) - -ETs = (Float32, ComplexF64) - -# no `alg` argument -function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) - dA_copy = make_mooncake_tangent(copy(ΔA)) - A_copy = copy(A) - dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy)) - copy_pb!!(rdata) - return dA_copy -end - -# `alg` argument -function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) - dA_copy = make_mooncake_tangent(copy(ΔA)) - A_copy = copy(A) - dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - copy_out, copy_pb!! 
= rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData())) - copy_pb!!(rdata) - return dA_copy -end - -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) - dA_inplace = make_mooncake_tangent(copy(ΔA)) - A_inplace = copy(A) - dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - # not every f! has a handwritten rrule!! - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} - has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) - if has_handwritten_rule - inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) - else - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) - end - inplace_pb!!(rdata) - return dA_inplace -end - -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) - dA_inplace = make_mooncake_tangent(copy(ΔA)) - A_inplace = copy(A) - dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - # not every f! has a handwritten rrule!! - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} - has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) - if has_handwritten_rule - inplace_out, inplace_pb!! 
= Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) - else - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) - end - inplace_pb!!(rdata) - return dA_inplace -end - -""" - test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) - -Compare the result of running the *in-place, mutating* function `f!`'s reverse rule -with the result of running its *non-mutating* partner function `f`'s reverse rule. -We must compare directly because many of the mutating functions modify `A` as a -scratch workspace, making testing `f!` against finite differences infeasible. - -The arguments to this function are: - - `f!` the mutating, in-place version of the function (accepts `args` for the function result) - - `f` the non-mutating version of the function (does not accept `args` for the function result) - - `A` the input matrix to factorize - - `args` preallocated output for `f!` (e.g. `Q` and `R` matrices for `qr_compact!`) - - `Δargs` precomputed derivatives of `args` for pullbacks of `f` and `f!`, to ensure they receive the same input - - `alg` optional algorithm keyword argument - - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) -""" -function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) - f_c = isnothing(alg) ? 
(A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg) - sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - rrule = Mooncake.build_rrule(rvs_interp, sig) - ΔA = randn(rng, eltype(A), size(A)) - - dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) - dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) - - dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] - dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] - @test dA_inplace_ ≈ dA_copy_ - return -end - -@timedtestset "QR AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - @testset for alg in ( - LAPACK_HouseholderQR(), - LAPACK_HouseholderQR(; positive = true), - ) - @testset "qr_compact" begin - QR = qr_compact(A, alg) - Q = randn(rng, T, m, minmn) - R = randn(rng, T, minmn, n) - Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_compact!, qr_compact, A, (Q, R), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) - end - @testset "qr_null" begin - Q, R = qr_compact(A, alg) - ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) - N = qr_null(A, alg) - dN = make_mooncake_tangent(copy(ΔN)) - Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dN, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) - end - @testset "qr_full" begin - Q, R = qr_full(A, alg) - Q1 = view(Q, 1:m, 1:minmn) - ΔQ = randn(rng, T, m, m) - ΔQ2 = view(ΔQ, :, (minmn + 1):m) - mul!(ΔQ2, Q1, Q1' * ΔQ2) - ΔR = randn(rng, T, m, n) - dQ = 
make_mooncake_tangent(copy(ΔQ)) - dR = make_mooncake_tangent(copy(ΔR)) - dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg) - end - @testset "qr_compact - rank-deficient A" begin - r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) - Q, R = qr_compact(Ard, alg) - QR = (Q, R) - ΔQ = randn(rng, T, m, minmn) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - ΔQ2 .= 0 - ΔR = randn(rng, T, minmn, n) - view(ΔR, (r + 1):minmn, :) .= 0 - dQ = make_mooncake_tangent(copy(ΔQ)) - dR = make_mooncake_tangent(copy(ΔR)) - dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg) - end - end - end -end - -@timedtestset "LQ AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - @testset for alg in ( - LAPACK_HouseholderLQ(), - LAPACK_HouseholderLQ(; positive = true), - ) - @testset "lq_compact" begin - L, Q = lq_compact(A, alg) - Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) - end - @testset "lq_null" begin - L, Q = lq_compact(A, alg) - ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q - Nᴴ = randn(rng, T, max(0, n - minmn), n) - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; mode 
= Mooncake.ReverseMode, output_tangent = dNᴴ, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) - end - @testset "lq_full" begin - L, Q = lq_full(A, alg) - Q1 = view(Q, 1:minmn, 1:n) - ΔQ = randn(rng, T, n, n) - ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) - mul!(ΔQ2, ΔQ2 * Q1', Q1) - ΔL = randn(rng, T, m, n) - dL = make_mooncake_tangent(ΔL) - dQ = make_mooncake_tangent(ΔQ) - dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg) - end - @testset "lq_compact - rank-deficient A" begin - r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) - L, Q = lq_compact(Ard, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) - ΔQ2 .= 0 - view(ΔL, :, (r + 1):minmn) .= 0 - dL = make_mooncake_tangent(ΔL) - dQ = make_mooncake_tangent(ΔQ) - dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg) - end - end - end -end - -@timedtestset "EIG AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = randn(rng, T, m, m) - DV = eig_full(A) - D, V = DV - Ddiag = diagview(D) - ΔV = randn(rng, complex(T), m, m) - ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, complex(T), m, m) - ΔD2 = Diagonal(randn(rng, complex(T), m)) - - dD = make_mooncake_tangent(ΔD2) - dV = make_mooncake_tangent(ΔV) - dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV) - # compute the dA 
corresponding to the above dD, dV - @testset for alg in ( - LAPACK_Simple(), - #LAPACK_Expert(), # expensive on CI - ) - @testset "eig_full" begin - Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) - end - @testset "eig_vals" begin - Mooncake.TestUtils.test_rule(rng, eig_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) +rng = StableRNG(12345) +m = 19 +for T in BLASFloats, n in (17, m, 23) + TestSuite.seed_rng!(123) + if is_buildkite + if CUDA.functional() + TestSuite.test_mooncake(CuMatrix{T}, (m, n), rng; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end - @testset "eig_trunc" begin - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - end - truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = 
make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - end - end -end - -function copy_eigh_full(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_full(A, alg; kwargs...) -end - -function copy_eigh_full!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_full!(A, DV, alg; kwargs...) -end - -function copy_eigh_vals(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_vals(A, alg; kwargs...) -end - -function copy_eigh_vals!(A, D, alg; kwargs...) - A = (A + A') / 2 - return eigh_vals!(A, D, alg; kwargs...) -end - -function copy_eigh_trunc(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc(A, alg; kwargs...) -end - -function copy_eigh_trunc!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc!(A, DV, alg; kwargs...) 
-end - -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) - -@timedtestset "EIGH AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = randn(rng, T, m, m) - A = A + A' - D, V = eigh_full(A) - ΔV = randn(rng, T, m, m) - ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, real(T), m, m) - ΔD2 = Diagonal(randn(rng, real(T), m)) - dD = make_mooncake_tangent(ΔD2) - dV = make_mooncake_tangent(ΔV) - dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV) - Ddiag = diagview(D) - @testset for alg in ( - LAPACK_QRIteration(), - #LAPACK_DivideAndConquer(), - #LAPACK_Bisection(), - #LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI - ) - @testset "eigh_full" begin - Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg) - end - @testset "eigh_vals" begin - Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg) - end - @testset "eigh_trunc" begin - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, 
zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - end - truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - end - end -end - -@timedtestset "SVD AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - @testset for alg in ( - LAPACK_QRIteration(), - #LAPACK_DivideAndConquer(), # expensive on CI - ) - @testset "svd_compact" begin - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - dS = make_mooncake_tangent(ΔS2) - dU = make_mooncake_tangent(ΔU) - dVᴴ = make_mooncake_tangent(ΔVᴴ) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dU, dS, dVᴴ) - 
Mooncake.TestUtils.test_rule(rng, svd_compact, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), alg) - end - @testset "svd_full" begin - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - ΔUfull = zeros(T, m, m) - ΔSfull = zeros(real(T), m, n) - ΔVᴴfull = zeros(T, n, n) - U, S, Vᴴ = svd_full(A) - view(ΔUfull, :, 1:minmn) .= ΔU - view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ - diagview(ΔSfull)[1:minmn] .= diagview(ΔS2) - dS = make_mooncake_tangent(ΔSfull) - dU = make_mooncake_tangent(ΔUfull) - dVᴴ = make_mooncake_tangent(ΔVᴴfull) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔUfull, ΔSfull, ΔVᴴfull)), dU, dS, dVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_full!, svd_full, A, (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull), alg) - end - @testset "svd_vals" begin - Mooncake.TestUtils.test_rule(rng, svd_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) - S = svd_vals(A, alg) - test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, randn(rng, real(T), minmn), alg) - end - @testset "svd_trunc" begin - @testset for r in 1:4:minmn - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - truncalg = TruncatedAlgorithm(alg, truncrank(r)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = 
ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - dStrunc = make_mooncake_tangent(ΔStrunc) - dUtrunc = make_mooncake_tangent(ΔUtrunc) - dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) - ϵ = zero(real(T)) - dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - end - @testset "trunctol" begin - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - dStrunc = make_mooncake_tangent(ΔStrunc) - dUtrunc = make_mooncake_tangent(ΔUtrunc) - dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) - ϵ = zero(real(T)) - dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - end - end - end - end -end - -@timedtestset "Polar AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) 
- @testset for alg in PolarViaSVD.( - ( - LAPACK_QRIteration(), - #LAPACK_DivideAndConquer(), # expensive on CI - ) - ) - if m >= n - WP = left_polar(A, alg) - Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, left_polar!, left_polar, A, WP, (randn(rng, T, m, n), randn(rng, T, n, n)), alg) - elseif m <= n - PWᴴ = right_polar(A, alg) - Mooncake.TestUtils.test_rule(rng, right_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, (randn(rng, T, m, m), randn(rng, T, m, n)), alg) - end - end - end -end - -left_orth_qr(X) = left_orth(X; alg = :qr) -left_orth_polar(X) = left_orth(X; alg = :polar) -left_null_qr(X) = left_null(X; alg = :qr) -right_orth_lq(X) = right_orth(X; alg = :lq) -right_orth_polar(X) = right_orth(X; alg = :polar) -right_null_lq(X) = right_null(X; alg = :lq) - -MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) - -@timedtestset "Orth and null with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - VC = left_orth(A) - CVᴴ = right_orth(A) - Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - 
test_pullbacks_match(rng, left_orth!, left_orth, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, right_orth!, right_orth, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - - Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - if m >= n - Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - end - - N = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - dN = make_mooncake_tangent(ΔN) - Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) - test_pullbacks_match(rng, ((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) - - Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - - if m <= n - Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, 
size(CVᴴ[2])...))) - end - - Nᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) - test_pullbacks_match(rng, ((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) + #=if AMDGPU.functional() + TestSuite.test_mooncake(ROCMatrix{T}, (m, n), rng; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end=# # not yet supported + else + TestSuite.test_mooncake(T, (m, n), rng; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end diff --git a/test/runtests.jl b/test/runtests.jl index edc8385e..ab8432b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,23 +36,23 @@ end @safetestset "Image and Null Space" begin include("orthnull.jl") end -if !is_buildkite - @safetestset "Mooncake" begin - include("mooncake.jl") - end - @safetestset "ChainRules" begin - include("chainrules.jl") - end +@safetestset "Mooncake" begin + include("mooncake.jl") end -@safetestset "MatrixAlgebraKit.jl" begin - @safetestset "Code quality (Aqua.jl)" begin - using MatrixAlgebraKit - using Aqua - Aqua.test_all(MatrixAlgebraKit) - end - @safetestset "Code linting (JET.jl)" begin - using MatrixAlgebraKit - using JET - JET.test_package(MatrixAlgebraKit; target_defined_modules = true) +@safetestset "ChainRules" begin + include("chainrules.jl") +end +if !is_buildkite + @safetestset "MatrixAlgebraKit.jl" begin + @safetestset "Code quality (Aqua.jl)" begin + using MatrixAlgebraKit + using Aqua + Aqua.test_all(MatrixAlgebraKit) + end + @safetestset "Code linting (JET.jl)" begin + using MatrixAlgebraKit + using JET + JET.test_package(MatrixAlgebraKit; target_defined_modules = true) + end end end diff --git a/test/testsuite/TestSuite.jl 
b/test/testsuite/TestSuite.jl index b7c54259..9de86172 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -63,6 +63,8 @@ macro testinferred(ex) return esc(:(@inferred $ex)) end +include("ad_utils.jl") + include("qr.jl") include("lq.jl") include("polar.jl") @@ -72,5 +74,7 @@ include("eigh.jl") include("eig.jl") include("schur.jl") include("svd.jl") +include("mooncake.jl") +include("chainrules.jl") end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl new file mode 100644 index 00000000..d93cd339 --- /dev/null +++ b/test/testsuite/ad_utils.jl @@ -0,0 +1,352 @@ +function remove_svdgauge_dependence!( + ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) + ) + gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true) + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end +function remove_eiggauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) + ) + gaugepart = V' * ΔV + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end +function remove_eighgauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) + ) + gaugepart = V' * ΔV + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V, gaugepart, -1, 1) + return ΔV +end + +function ad_qr_compact_setup(rng, A) + m, n = size(A) + minmn = min(m, n) + QR = qr_compact(A) + T = eltype(A) + ΔQ = randn(rng, T, m, minmn) + ΔR = randn(rng, T, minmn, n) + return QR, (ΔQ, ΔR) +end + +function ad_qr_null_setup(rng, A) + m, n = size(A) + minmn = min(m, n) + Q, R = qr_compact(A) + T = eltype(A) + ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) + N = qr_null(A) + return N, ΔN +end + 
+function ad_qr_full_setup(rng, A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + Q, R = qr_full(A) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn(rng, T, m, m) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn(rng, T, m, n) + return (Q, R), (ΔQ, ΔR) +end + +function ad_qr_rd_compact_setup(rng, A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + Q, R = qr_compact(Ard) + QR = (Q, R) + ΔQ = randn(rng, T, m, minmn) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = randn(rng, T, minmn, n) + view(ΔR, (r + 1):minmn, :) .= 0 + return (Q, R), (ΔQ, ΔR) +end + +function ad_lq_compact_setup(rng, A) + m, n = size(A) + minmn = min(m, n) + LQ = lq_compact(A) + T = eltype(A) + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + return LQ, (ΔL, ΔQ) +end + +function ad_lq_null_setup(rng, A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + L, Q = lq_compact(A) + ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q + Nᴴ = randn(rng, T, max(0, n - minmn), n) + return Nᴴ, ΔNᴴ +end + +function ad_lq_full_setup(rng, A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + L, Q = lq_full(A) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = randn(rng, T, n, n) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + ΔL = randn(rng, T, m, n) + return (L, Q), (ΔL, ΔQ) +end + +function ad_lq_rd_compact_setup(rng, A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + L, Q = lq_compact(Ard) + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) + ΔQ2 .= 0 + view(ΔL, :, (r + 1):minmn) .= 0 + return (L, Q), (ΔL, ΔQ) +end + +function ad_eig_full_setup(rng, A) + m, n = size(A) + T = eltype(A) + DV = eig_full(A) + D, V = DV + Ddiag = diagview(D) + ΔV = 
randn(rng, complex(T), m, m) + ΔV = remove_eiggauge_dependence!(ΔV, D, V) + ΔD = randn(rng, complex(T), m, m) + ΔD2 = Diagonal(randn(rng, complex(T), m)) + return DV, (ΔD, ΔV), (ΔD2, ΔV) +end + +function ad_eig_vals_setup(rng, A) + m, n = size(A) + T = eltype(A) + D = eig_vals(A) + ΔD = randn(rng, complex(T), m) + return D, ΔD +end + +function ad_eig_trunc_setup(rng, A, truncalg) + m, n = size(A) + T = eltype(A) + DV = eig_full(A) + D, V = DV + Ddiag = diagview(D) + ΔV = randn(rng, complex(T), m, m) + ΔV = remove_eiggauge_dependence!(ΔV, D, V) + ΔD = randn(rng, complex(T), m, m) + ΔD2 = Diagonal(randn(rng, complex(T), m)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + return DV, (ΔD2, ΔV), (ΔDtrunc, ΔVtrunc) +end + +function copy_eigh_full(A; kwargs...) + A = (A + A') / 2 + return eigh_full(A; kwargs...) +end + +function copy_eigh_full!(A, DV; kwargs...) + A = (A + A') / 2 + return eigh_full!(A, DV; kwargs...) +end + +function copy_eigh_vals(A; kwargs...) + A = (A + A') / 2 + return eigh_vals(A; kwargs...) +end + +function copy_eigh_vals!(A, D; kwargs...) + A = (A + A') / 2 + return eigh_vals!(A, D; kwargs...) +end + +function copy_eigh_trunc(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc(A, alg; kwargs...) +end + +function copy_eigh_trunc!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc!(A, DV, alg; kwargs...) 
+end + +MatrixAlgebraKit.copy_input(::typeof(copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) +MatrixAlgebraKit.copy_input(::typeof(copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) +MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) + +function ad_eigh_full_setup(rng, A) + m, n = size(A) + T = eltype(A) + DV = eigh_full(A) + D, V = DV + Ddiag = diagview(D) + ΔV = randn(rng, T, m, m) + ΔV = remove_eighgauge_dependence!(ΔV, D, V) + ΔD = randn(rng, real(T), m, m) + ΔD2 = Diagonal(randn(rng, real(T), m)) + return DV, (ΔD, ΔV), (ΔD2, ΔV) +end + +function ad_eigh_vals_setup(rng, A) + m, n = size(A) + T = eltype(A) + D = eigh_vals(A) + ΔD = randn(rng, real(T), m) + return D, ΔD +end + +function ad_eigh_trunc_setup(rng, A, truncalg) + m, n = size(A) + T = eltype(A) + DV = eigh_full(A) + D, V = DV + Ddiag = diagview(D) + ΔV = randn(rng, T, m, m) + ΔV = remove_eighgauge_dependence!(ΔV, D, V) + ΔD = randn(rng, real(T), m, m) + ΔD2 = Diagonal(randn(rng, real(T), m)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + return DV, (ΔD2, ΔV), (ΔDtrunc, ΔVtrunc) +end + +function ad_svd_compact_setup(rng, A) + m, n = size(A) + T = eltype(A) + minmn = min(m, n) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), (ΔU, ΔS2, ΔVᴴ) +end + +function ad_svd_full_setup(rng, A) + m, n = size(A) + T = eltype(A) + minmn = min(m, n) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + 
ΔUfull = zeros(T, m, m) + ΔSfull = zeros(real(T), m, n) + ΔVᴴfull = zeros(T, n, n) + U, S, Vᴴ = svd_full(A) + view(ΔUfull, :, 1:minmn) .= ΔU + view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ + diagview(ΔSfull)[1:minmn] .= diagview(ΔS2) + return (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull) +end + +function ad_svd_vals_setup(rng, A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + S = svd_vals(A) + ΔS = randn(rng, real(T), minmn) + return S, ΔS +end + +function ad_svd_trunc_setup(rng, A, truncalg) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] + ΔVᴴtrunc = ΔVᴴ[ind, :] + return (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc) +end + +function ad_left_polar_setup(rng, A) + m, n = size(A) + T = eltype(A) + WP = left_polar(A) + ΔWP = (randn(rng, T, m, n), randn(rng, T, n, n)) + return WP, ΔWP +end + +function ad_right_polar_setup(rng, A) + m, n = size(A) + T = eltype(A) + PWᴴ = right_polar(A) + ΔPWᴴ = (randn(rng, T, m, m), randn(rng, T, m, n)) + return PWᴴ, ΔPWᴴ +end + +function ad_left_orth_setup(rng, A) + m, n = size(A) + T = eltype(A) + VC = left_orth(A) + ΔVC = (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...)) + return VC, ΔVC +end + +function ad_left_null_setup(rng, A) + m, n = size(A) + T = eltype(A) + N = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + return N, ΔN +end + +function ad_right_orth_setup(rng, A) + m, n = size(A) + T = eltype(A) + CVᴴ = right_orth(A) + ΔCVᴴ = (randn(rng, T, size(CVᴴ[1])...), randn(rng, 
T, size(CVᴴ[2])...)) + return CVᴴ, ΔCVᴴ +end + +function ad_right_null_setup(rng, A) + m, n = size(A) + T = eltype(A) + Nᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] + return Nᴴ, ΔNᴴ +end diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl new file mode 100644 index 00000000..5edd8319 --- /dev/null +++ b/test/testsuite/chainrules.jl @@ -0,0 +1,587 @@ +using MatrixAlgebraKit +using ChainRulesCore, ChainRulesTestUtils, Zygote +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +include("ad_utils.jl") + +for f in + ( + :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, + :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, + :svd_compact, :svd_trunc, :svd_vals, + :left_polar, :right_polar, + ) + copy_f = Symbol(:copy_, f) + f! = Symbol(f, '!') + _hermitian = startswith(string(f), "eigh") + @eval begin + function $copy_f(input, alg) + if $_hermitian + input = (input + input') / 2 + end + return $f(input, alg) + end + function ChainRulesCore.rrule(::typeof($copy_f), input, alg) + output = MatrixAlgebraKit.initialize_output($f!, input, alg) + if $_hermitian + input = (input + input') / 2 + else + input = copy(input) + end + output, pb = ChainRulesCore.rrule($f!, input, output, alg) + return output, x -> (NoTangent(), pb(x)[2], NoTangent()) + end + end +end + +function test_chainrules(T::Type, sz, rng; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Chainrules AD $summary_str" begin + test_chainrules_qr(T, sz, rng; kwargs...) + test_chainrules_lq(T, sz, rng; kwargs...) + if length(sz) == 1 || sz[1] == sz[2] + test_chainrules_eig(T, sz, rng; kwargs...) + test_chainrules_eigh(T, sz, rng; kwargs...) + end + test_chainrules_svd(T, sz, rng; kwargs...) + test_chainrules_polar(T, sz, rng; kwargs...) 
+ test_chainrules_orthnull(T, sz, rng; kwargs...) + end +end + +function test_chainrules_qr( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "QR ChainRules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_qr_algorithm(A) + @testset "qr_compact" begin + QR, ΔQR = ad_qr_compact_setup(rng, A) + ΔQ, ΔR = ΔQR + test_rrule( + copy_qr_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔQR, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_compact, A; + fkwargs = (; positive = true), output_tangent = ΔQR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ qr_compact, A; + fkwargs = (; positive = true), output_tangent = ΔQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ qr_compact, A; + fkwargs = (; positive = true), output_tangent = ΔR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "qr_null" begin + N, ΔN = ad_qr_null_setup(rng, A) + test_rrule( + copy_qr_null, A, alg ⊢ NoTangent(); + output_tangent = ΔN, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_null, A; + fkwargs = (; positive = true), output_tangent = ΔN, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m, n = size(A) + end + @testset "qr_full" begin + QR, ΔQR = ad_qr_full_setup(rng, A) + test_rrule( + copy_qr_full, A, alg ⊢ NoTangent(); + output_tangent = ΔQR, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_full, A; + fkwargs = (; positive = true), output_tangent = ΔQR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m, n = size(A) + end + @testset "qr_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, 
n)) + QR, ΔQR = ad_qr_rd_compact_setup(rng, Ard) + ΔQ, ΔR = ΔQR + test_rrule( + copy_qr_compact, Ard, alg ⊢ NoTangent(); + output_tangent = ΔQR, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_compact, Ard; + fkwargs = (; positive = true), output_tangent = ΔQR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_lq( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "LQ Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_lq_algorithm(A) + @testset "lq_compact" begin + LQ, ΔLQ = ad_lq_compact_setup(rng, A) + ΔL, ΔQ = ΔLQ + test_rrule( + copy_lq_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔLQ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_compact, A; + fkwargs = (; positive = true), output_tangent = ΔLQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ lq_compact, A; + fkwargs = (; positive = true), output_tangent = ΔL, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ lq_compact, A; + fkwargs = (; positive = true), output_tangent = ΔQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "lq_null" begin + Nᴴ, ΔNᴴ = ad_lq_null_setup(rng, A) + test_rrule( + copy_lq_null, A, alg ⊢ NoTangent(); + output_tangent = ΔNᴴ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_null, A; + fkwargs = (; positive = true), output_tangent = ΔNᴴ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "lq_full" begin + LQ, ΔLQ = ad_lq_full_setup(rng, A) + test_rrule( + copy_lq_full, A, alg ⊢ NoTangent(); + output_tangent = ΔLQ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_full, A; + 
fkwargs = (; positive = true), output_tangent = ΔLQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "lq_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + LQ, ΔLQ = ad_lq_rd_compact_setup(rng, Ard) + test_rrule( + copy_lq_compact, Ard, alg ⊢ NoTangent(); + output_tangent = ΔLQ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_compact, Ard; + fkwargs = (; positive = true), output_tangent = ΔLQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_eig( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIG Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m = size(A, 1) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_eig_algorithm(A) + @testset "eig_full" begin + DV, ΔDV, ΔD2V = ad_eig_full_setup(rng, A) + ΔD, ΔV = ΔDV + test_rrule( + copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol + ) + test_rrule( + copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔD2V, atol, rtol + ) + test_rrule( + config, eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔDV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔD2V, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eig_vals" begin + D, ΔD = ad_eig_vals_setup(rng, A) + test_rrule( + 
copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol + ) + test_rrule( + config, eig_vals, A, alg ⊢ NoTangent(); + output_tangent = ΔD, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eig_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + DV, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(rng, A, truncalg) + test_rrule( + copy_eig_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + D, V = DV + Ddiag = diagview(D) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(Ddiag[ind]) + Vtrunc = V[:, ind] + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) + DV, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(rng, A, truncalg) + test_rrule( + copy_eig_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + D, V = DV + Ddiag = diagview(D) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(Ddiag[ind]) + Vtrunc = V[:, ind] + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end +end + +function test_chainrules_eigh( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "EIGH ChainRules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + A = A + A' + m = size(A, 1) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_eigh_algorithm(A) + # copy_eigh_xxxx includes a projector onto the Hermitian part of the matrix + @testset "eigh_full" begin + DV, ΔDV, ΔD2V = ad_eigh_full_setup(rng, A) + ΔD, ΔV = ΔDV + test_rrule( + copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol + ) + test_rrule( + copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔD2V, atol, rtol + ) + # eigh_full does not include a projector onto the Hermitian part of the matrix + test_rrule( + config, eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔDV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔD2V, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eigh_vals" begin + D, ΔD = ad_eigh_vals_setup(rng, A) + test_rrule( + copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol + ) + test_rrule( + config, eigh_vals ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔD, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eigh_trunc" begin + eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...) 
+ for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + DV, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(rng, A, truncalg) + test_rrule( + copy_eigh_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + D, V = DV + Ddiag = diagview(D) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(Ddiag[ind]) + Vtrunc = V[:, ind] + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = truncrank(r; by = real) + ind = MatrixAlgebraKit.findtruncated(Ddiag, trunc) + truncalg = TruncatedAlgorithm(alg, trunc) + DV, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(rng, A, truncalg) + test_rrule( + config, eigh_trunc2, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + D, ΔD = ad_eigh_vals_setup(rng, A) + truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) + DV, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(rng, A, truncalg) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + ΔDtrunc = Diagonal(diagview(ΔDV[1])[ind]) + ΔVtrunc = ΔDV[2][:, ind] + test_rrule( + copy_eigh_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), + atol = atol, rtol = rtol + ) + D, V = DV + Ddiag = diagview(D) + Dtrunc = Diagonal(Ddiag[ind]) + Vtrunc = V[:, ind] + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = trunctol(; rtol = 1 / 2) + truncalg = TruncatedAlgorithm(alg, trunc) + DV, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(rng, A, truncalg) + D, V = DV + Ddiag = diagview(D) + ind = 
MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(Ddiag[ind]) + Vtrunc = V[:, ind] + ind = MatrixAlgebraKit.findtruncated(Ddiag, trunc) + test_rrule( + config, eigh_trunc2, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_svd( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "SVD Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + minmn = min(size(A)...) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_svd_algorithm(A) + @testset "svd_compact" begin + USV, ΔUSVᴴ, ΔUS2Vᴴ = ad_svd_compact_setup(rng, A) + test_rrule( + copy_svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol + ) + test_rrule( + copy_svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUS2Vᴴ, atol = atol, rtol = rtol + ) + test_rrule( + config, svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUS2Vᴴ, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "svd_vals" begin + S, ΔS = ad_svd_vals_setup(rng, A) + test_rrule( + copy_svd_vals, A, alg ⊢ NoTangent(); + output_tangent = ΔS, atol, rtol + ) + test_rrule( + config, svd_vals, A, alg ⊢ NoTangent(); + output_tangent = ΔS, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "svd_trunc" begin + @testset for r in 1:4:minmn + truncalg = TruncatedAlgorithm(alg, truncrank(r)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(rng, A, truncalg) + test_rrule( + copy_svd_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + 
U, S, Vᴴ = USVᴴ + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, USVᴴ, ΔUSVᴴtrunc, ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), ΔUSVᴴtrunc) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = truncrank(r) + ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) + test_rrule( + config, svd_trunc, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + S, ΔS = ad_svd_vals_setup(rng, A) + truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(rng, A, truncalg) + test_rrule( + copy_svd_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + U, S, Vᴴ = USVᴴ + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, USVᴴ, ΔUSVᴴtrunc, ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), ΔUSVᴴtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = trunctol(; atol = S[1, 1] / 2) + ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) + test_rrule( + config, svd_trunc, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_polar( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "Polar Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_polar_algorithm(A) + @testset "left_polar" begin + if m >= n + test_rrule(copy_left_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) + test_rrule( + config, left_polar, A, alg ⊢ NoTangent(); + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end + @testset "right_polar" begin + if m <= n + test_rrule(copy_right_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) + test_rrule( + config, right_polar, A, alg ⊢ NoTangent(); + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end + end +end + +function test_chainrules_orthnull( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Orthnull Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + N, ΔN = ad_left_null_setup(rng, A) + Nᴴ, ΔNᴴ = ad_right_null_setup(rng, A) + test_rrule( + config, left_orth, A; + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, left_orth, A; + fkwargs = (; alg = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m >= n && + test_rrule( + config, left_orth, A; + fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, left_null, A; + fkwargs = (; alg = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + + test_rrule( + config, right_orth, A; + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, right_orth, A; fkwargs = (; alg = :lq), + atol = atol, rtol = rtol, rrule_f = 
rrule_via_ad, check_inferred = false + ) + m <= n && + test_rrule( + config, right_orth, A; fkwargs = (; alg = :polar), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, right_null, A; + fkwargs = (; alg = :lq), output_tangent = ΔNᴴ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end +end diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl new file mode 100644 index 00000000..c1084054 --- /dev/null +++ b/test/testsuite/mooncake.jl @@ -0,0 +1,412 @@ +using TestExtras +using MatrixAlgebraKit +using Mooncake, Mooncake.TestUtils +using Mooncake: rrule!! +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc + +make_mooncake_tangent(ΔAelem::T) where {T <: Real} = ΔAelem +make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) +make_mooncake_tangent(ΔA::Matrix{<:Real}) = ΔA +make_mooncake_tangent(ΔA::Vector{<:Real}) = ΔA +make_mooncake_tangent(ΔA::Matrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) +make_mooncake_tangent(ΔA::Vector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) +make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Real} = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) +make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_tangent(typeof(ΔD), map(make_mooncake_tangent, diagview(ΔD))) + +make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), make_mooncake_tangent.(T)...) + +make_mooncake_fdata(x) = make_mooncake_tangent(x) +make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) + +# no `alg` argument +function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) + dA_copy = make_mooncake_tangent(copy(ΔA)) + A_copy = copy(A) + dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + copy_out, copy_pb!! 
= rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy)) + copy_pb!!(rdata) + return dA_copy +end + +# `alg` argument +function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) + dA_copy = make_mooncake_tangent(copy(ΔA)) + A_copy = copy(A) + dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData())) + copy_pb!!(rdata) + return dA_copy +end + +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) + dA_inplace = make_mooncake_tangent(copy(ΔA)) + A_inplace = copy(A) + dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + # not every f! has a handwritten rrule!! + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} + has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + if has_handwritten_rule + inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) + else + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) + end + inplace_pb!!(rdata) + return dA_inplace +end + +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) + dA_inplace = make_mooncake_tangent(copy(ΔA)) + A_inplace = copy(A) + dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + # not every f! has a handwritten rrule!! 
+ inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} + has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + if has_handwritten_rule + inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) + else + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) + end + inplace_pb!!(rdata) + return dA_inplace +end + +""" + test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) + +Compare the result of running the *in-place, mutating* function `f!`'s reverse rule +with the result of running its *non-mutating* partner function `f`'s reverse rule. +We must compare directly because many of the mutating functions modify `A` as a +scratch workspace, making testing `f!` against finite differences infeasible. + +The arguments to this function are: + - `f!` the mutating, in-place version of the function (accepts `args` for the function result) + - `f` the non-mutating version of the function (does not accept `args` for the function result) + - `A` the input matrix to factorize + - `args` preallocated output for `f!` (e.g. 
`Q` and `R` matrices for `qr_compact!`) + - `Δargs` precomputed derivatives of `args` for pullbacks of `f` and `f!`, to ensure they receive the same input + - `alg` optional algorithm keyword argument + - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) +""" +function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) + f_c = isnothing(alg) ? (A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg) + sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + rrule = Mooncake.build_rrule(rvs_interp, sig) + ΔA = randn(rng, eltype(A), size(A)) + + dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) + dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) + + dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] + dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] + @test dA_inplace_ ≈ dA_copy_ + return +end + +function test_mooncake(T::Type, sz, rng; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake AD $summary_str" begin + test_mooncake_qr(T, sz, rng; kwargs...) + test_mooncake_lq(T, sz, rng; kwargs...) + if length(sz) == 1 || sz[1] == sz[2] + test_mooncake_eig(T, sz, rng; kwargs...) + test_mooncake_eigh(T, sz, rng; kwargs...) + end + test_mooncake_svd(T, sz, rng; kwargs...) + test_mooncake_polar(T, sz, rng; kwargs...) + test_mooncake_orthnull(T, sz, rng; kwargs...) + end +end + +function test_mooncake_qr( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "QR Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + @testset "qr_compact" begin + QR, ΔQR = ad_qr_compact_setup(rng, A) + Mooncake.TestUtils.test_rule(rng, qr_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(rng, qr_compact!, qr_compact, A, QR, ΔQR) + end + @testset "qr_null" begin + N, ΔN = ad_qr_null_setup(rng, A) + dN = make_mooncake_tangent(copy(ΔN)) + Mooncake.TestUtils.test_rule(rng, qr_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dN, atol = atol, rtol = rtol) + test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN) + end + @testset "qr_full" begin + QR, ΔQR = ad_qr_full_setup(rng, A) + dQR = make_mooncake_tangent(ΔQR) + Mooncake.TestUtils.test_rule(rng, qr_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) + test_pullbacks_match(rng, qr_full!, qr_full, A, QR, ΔQR) + end + @testset "qr_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + QR, ΔQR = ad_qr_rd_compact_setup(rng, Ard) + dQR = make_mooncake_tangent(ΔQR) + Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) + test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, QR, ΔQR) + end + end +end + +function test_mooncake_lq( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "LQ Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + @testset "lq_compact" begin + LQ, ΔLQ = ad_lq_compact_setup(rng, A) + Mooncake.TestUtils.test_rule(rng, lq_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(rng, lq_compact!, lq_compact, A, LQ, ΔLQ) + end + @testset "lq_null" begin + Nᴴ, ΔNᴴ = ad_lq_null_setup(rng, A) + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, lq_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ) + end + @testset "lq_full" begin + LQ, ΔLQ = ad_lq_full_setup(rng, A) + dLQ = make_mooncake_tangent(ΔLQ) + Mooncake.TestUtils.test_rule(rng, lq_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) + test_pullbacks_match(rng, lq_full!, lq_full, A, LQ, ΔLQ) + end + @testset "lq_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + LQ, ΔLQ = ad_lq_rd_compact_setup(rng, Ard) + dLQ = make_mooncake_tangent(ΔLQ) + Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) + test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, LQ, ΔLQ) + end + end +end + +function test_mooncake_eig( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "EIG Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m = size(A, 1) + @testset "eig_full" begin + DV, ΔDV, ΔD2V = ad_eig_full_setup(rng, A) + dDV = make_mooncake_tangent(ΔD2V) + Mooncake.TestUtils.test_rule(rng, eig_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dDV, atol = atol, rtol = rtol) + test_pullbacks_match(rng, eig_full!, eig_full, A, DV, ΔD2V) + end + @testset "eig_vals" begin + D, ΔD = ad_eig_vals_setup(rng, A) + dD = make_mooncake_tangent(ΔD) + Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(rng, eig_vals!, eig_vals, A, D, ΔD) + end + @testset "eig_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs)) + DV, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(rng, A, truncalg) + ϵ = zero(real(T)) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol) + test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + end + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) + DV, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(rng, A, truncalg) + ϵ = zero(real(T)) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol) + test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + end + end +end + +function test_mooncake_eigh( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "EIGH Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + A = A + A' + m = size(A, 1) + @testset "eigh_full" begin + DV, ΔDV, ΔD2V = ad_eigh_full_setup(rng, A) + dDV = make_mooncake_tangent(ΔD2V) + Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) + test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V) + end + @testset "eigh_vals" begin + D, ΔD = ad_eigh_vals_setup(rng, A) + dD = make_mooncake_tangent(ΔD) + Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol = atol, rtol = rtol) + test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D, ΔD) + end + @testset "eigh_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(r; by = abs)) + DV, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(rng, A, truncalg) + ϵ = zero(real(T)) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + end + D = eigh_vals(A / 2) + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2)) + DV, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(rng, A, truncalg) + ϵ = zero(real(T)) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), 
Mooncake.NoRData(), zero(real(T)))) + end + end +end + +function test_mooncake_svd( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "SVD Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + minmn = min(size(A)...) + @testset "svd_compact" begin + USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(rng, A) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ) + end + @testset "svd_full" begin + USVᴴ, ΔUSVᴴ = ad_svd_full_setup(rng, A) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ) + end + @testset "svd_vals" begin + S, ΔS = ad_svd_vals_setup(rng, A) + Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, ΔS) + end + @testset "svd_trunc" begin + S, ΔS = ad_svd_vals_setup(rng, A) + @testset for r in 1:4:minmn + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(rng, A, truncalg) + ϵ = zero(real(T)) + dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + end + @testset "trunctol" begin + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), 
trunctol(atol = S[1, 1] / 2)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(rng, A, truncalg) + ϵ = zero(real(T)) + dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + end + end + end +end + +function test_mooncake_polar( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Polar Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + @testset "left_polar" begin + if m >= n + WP, ΔWP = ad_left_polar_setup(rng, A) + Mooncake.TestUtils.test_rule(rng, left_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(rng, left_polar!, left_polar, A, WP, ΔWP) + end + end + @testset "right_polar" begin + if m <= n + PWᴴ, ΔPWᴴ = ad_right_polar_setup(rng, A) + Mooncake.TestUtils.test_rule(rng, right_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, ΔPWᴴ) + end + end + end +end + +left_orth_qr(X) = left_orth(X; alg = :qr) +left_orth_polar(X) = left_orth(X; alg = :polar) +left_null_qr(X) = left_null(X; alg = :qr) +right_orth_lq(X) = right_orth(X; alg = :lq) +right_orth_polar(X) = right_orth(X; alg = :polar) +right_null_lq(X) = right_null(X; alg = :lq) + +MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) +MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) +MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) +MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = 
MatrixAlgebraKit.copy_input(right_orth, A) +MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) +MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) + +function test_mooncake_orthnull( + T::Type, sz, rng; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Orthnull Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + VC, ΔVC = ad_left_orth_setup(rng, A) + CVᴴ, ΔCVᴴ = ad_right_orth_setup(rng, A) + Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(rng, left_orth!, left_orth, A, VC, ΔVC) + Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(rng, right_orth!, right_orth, A, CVᴴ, ΔCVᴴ) + + Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, ΔVC) + if m >= n + Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, ΔVC) + end + + N, ΔN = ad_left_null_setup(rng, A) + dN = make_mooncake_tangent(ΔN) + Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) + test_pullbacks_match(rng, ((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) + + Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), 
right_orth_lq, A, CVᴴ, ΔCVᴴ) + + if m <= n + Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, ΔCVᴴ) + end + + Nᴴ, ΔNᴴ = ad_right_null_setup(rng, A) + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) + test_pullbacks_match(rng, ((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) + end +end From d9da5cd73da137f65c42223c353dcfc2e6732990 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Dec 2025 11:12:24 -0500 Subject: [PATCH 3/6] More cleanup --- test/chainrules.jl | 5 +- test/mooncake.jl | 8 +- test/testsuite/ad_utils.jl | 202 +++++++++++++++-------------------- test/testsuite/chainrules.jl | 127 +++++++++------------- test/testsuite/mooncake.jl | 154 +++++++++++++------------- 5 files changed, 221 insertions(+), 275 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 8ed6fdb8..c0ab618a 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,6 +1,5 @@ using MatrixAlgebraKit using Test -using StableRNGs #BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI @@ -10,12 +9,10 @@ using .TestSuite is_buildkite = get(ENV, "BUILDKITE", "false") == "true" -rng = StableRNG(12345) - m = 19 for T in BLASFloats, n in (17, m, 23) TestSuite.seed_rng!(123) if !is_buildkite # doesn't work on GPU - TestSuite.test_chainrules(T, (m, n), rng; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + TestSuite.test_chainrules(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end diff --git a/test/mooncake.jl b/test/mooncake.jl index aa449c44..8a0c2931 100644 --- a/test/mooncake.jl 
+++ b/test/mooncake.jl @@ -1,6 +1,5 @@ using MatrixAlgebraKit using Test -using StableRNGs using CUDA, AMDGPU #BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) @@ -11,18 +10,17 @@ using .TestSuite is_buildkite = get(ENV, "BUILDKITE", "false") == "true" -rng = StableRNG(12345) m = 19 for T in BLASFloats, n in (17, m, 23) TestSuite.seed_rng!(123) if is_buildkite if CUDA.functional() - TestSuite.test_mooncake(CuMatrix{T}, (m, n), rng; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + TestSuite.test_mooncake(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end #=if AMDGPU.functional() - TestSuite.test_mooncake(ROCMatrix{T}, (m, n), rng; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + TestSuite.test_mooncake(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end=# # not yet supported else - TestSuite.test_mooncake(T, (m, n), rng; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + TestSuite.test_mooncake(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index d93cd339..c4de880c 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -28,99 +28,99 @@ function remove_eighgauge_dependence!( return ΔV end -function ad_qr_compact_setup(rng, A) +function ad_qr_compact_setup(A) m, n = size(A) minmn = min(m, n) QR = qr_compact(A) T = eltype(A) - ΔQ = randn(rng, T, m, minmn) - ΔR = randn(rng, T, minmn, n) + ΔQ = randn!(similar(A, T, m, minmn)) + ΔR = randn!(similar(A, T, minmn, n)) return QR, (ΔQ, ΔR) end -function ad_qr_null_setup(rng, A) +function ad_qr_null_setup(A) m, n = size(A) minmn = min(m, n) Q, R = qr_compact(A) T = eltype(A) - ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) + ΔN = Q * randn!(similar(A, T, minmn, max(0, m - minmn))) N = 
qr_null(A) return N, ΔN end -function ad_qr_full_setup(rng, A) +function ad_qr_full_setup(A) m, n = size(A) minmn = min(m, n) T = eltype(A) Q, R = qr_full(A) Q1 = view(Q, 1:m, 1:minmn) - ΔQ = randn(rng, T, m, m) + ΔQ = randn!(similar(A, T, m, m)) ΔQ2 = view(ΔQ, :, (minmn + 1):m) mul!(ΔQ2, Q1, Q1' * ΔQ2) - ΔR = randn(rng, T, m, n) + ΔR = randn!(similar(A, T, m, n)) return (Q, R), (ΔQ, ΔR) end -function ad_qr_rd_compact_setup(rng, A) +function ad_qr_rd_compact_setup(A) m, n = size(A) minmn = min(m, n) T = eltype(A) r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) Q, R = qr_compact(Ard) QR = (Q, R) - ΔQ = randn(rng, T, m, minmn) + ΔQ = randn!(similar(A, T, m, minmn)) Q1 = view(Q, 1:m, 1:r) Q2 = view(Q, 1:m, (r + 1):minmn) ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) ΔQ2 .= 0 - ΔR = randn(rng, T, minmn, n) + ΔR = randn!(similar(A, T, minmn, n)) view(ΔR, (r + 1):minmn, :) .= 0 return (Q, R), (ΔQ, ΔR) end -function ad_lq_compact_setup(rng, A) +function ad_lq_compact_setup(A) m, n = size(A) minmn = min(m, n) LQ = lq_compact(A) T = eltype(A) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) + ΔL = randn!(similar(A, T, m, minmn)) + ΔQ = randn!(similar(A, T, minmn, n)) return LQ, (ΔL, ΔQ) end -function ad_lq_null_setup(rng, A) +function ad_lq_null_setup(A) m, n = size(A) minmn = min(m, n) T = eltype(A) L, Q = lq_compact(A) - ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q - Nᴴ = randn(rng, T, max(0, n - minmn), n) + ΔNᴴ = randn!(similar(A, T, max(0, n - minmn), minmn)) * Q + Nᴴ = randn!(similar(A, T, max(0, n - minmn), n)) return Nᴴ, ΔNᴴ end -function ad_lq_full_setup(rng, A) +function ad_lq_full_setup(A) m, n = size(A) minmn = min(m, n) T = eltype(A) L, Q = lq_full(A) Q1 = view(Q, 1:minmn, 1:n) - ΔQ = randn(rng, T, n, n) + ΔQ = randn!(similar(A, T, n, n)) ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) mul!(ΔQ2, ΔQ2 * Q1', Q1) - ΔL = randn(rng, T, m, n) + ΔL = randn!(similar(A, T, m, n)) return (L, Q), 
(ΔL, ΔQ) end -function ad_lq_rd_compact_setup(rng, A) +function ad_lq_rd_compact_setup(A) m, n = size(A) minmn = min(m, n) T = eltype(A) r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) L, Q = lq_compact(Ard) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) + ΔL = randn!(similar(A, T, m, minmn)) + ΔQ = randn!(similar(A, T, minmn, n)) Q1 = view(Q, 1:r, 1:n) Q2 = view(Q, (r + 1):minmn, 1:n) ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) @@ -129,43 +129,35 @@ function ad_lq_rd_compact_setup(rng, A) return (L, Q), (ΔL, ΔQ) end -function ad_eig_full_setup(rng, A) +function ad_eig_full_setup(A) m, n = size(A) T = eltype(A) DV = eig_full(A) D, V = DV Ddiag = diagview(D) - ΔV = randn(rng, complex(T), m, m) + ΔV = randn!(similar(A, complex(T), m, m)) ΔV = remove_eiggauge_dependence!(ΔV, D, V) - ΔD = randn(rng, complex(T), m, m) - ΔD2 = Diagonal(randn(rng, complex(T), m)) + ΔD = randn!(similar(A, complex(T), m, m)) + ΔD2 = Diagonal(randn!(similar(A, complex(T), m))) return DV, (ΔD, ΔV), (ΔD2, ΔV) end -function ad_eig_vals_setup(rng, A) +function ad_eig_vals_setup(A) m, n = size(A) T = eltype(A) D = eig_vals(A) - ΔD = randn(rng, complex(T), m) + ΔD = randn!(similar(A, complex(T), m)) return D, ΔD end -function ad_eig_trunc_setup(rng, A, truncalg) - m, n = size(A) - T = eltype(A) - DV = eig_full(A) - D, V = DV - Ddiag = diagview(D) - ΔV = randn(rng, complex(T), m, m) - ΔV = remove_eiggauge_dependence!(ΔV, D, V) - ΔD = randn(rng, complex(T), m, m) - ΔD2 = Diagonal(randn(rng, complex(T), m)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - return DV, (ΔD2, ΔV), (ΔDtrunc, ΔVtrunc) +function ad_eig_trunc_setup(A, truncalg) + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + Dtrunc = 
Diagonal(diagview(DV[1])[ind]) + Vtrunc = DV[2][:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2V[1])[ind]) + ΔVtrunc = ΔDV[2][:, ind] + return DV, (Dtrunc, Vtrunc), ΔD2V, (ΔDtrunc, ΔVtrunc) end function copy_eigh_full(A; kwargs...) @@ -202,66 +194,58 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_full), A) = MatrixAlgebraKit.copy MatrixAlgebraKit.copy_input(::typeof(copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) -function ad_eigh_full_setup(rng, A) +function ad_eigh_full_setup(A) m, n = size(A) T = eltype(A) DV = eigh_full(A) D, V = DV Ddiag = diagview(D) - ΔV = randn(rng, T, m, m) + ΔV = randn!(similar(A, T, m, m)) ΔV = remove_eighgauge_dependence!(ΔV, D, V) - ΔD = randn(rng, real(T), m, m) - ΔD2 = Diagonal(randn(rng, real(T), m)) + ΔD = randn!(similar(A, real(T), m, m)) + ΔD2 = Diagonal(randn!(similar(A, real(T), m))) return DV, (ΔD, ΔV), (ΔD2, ΔV) end -function ad_eigh_vals_setup(rng, A) +function ad_eigh_vals_setup(A) m, n = size(A) T = eltype(A) D = eigh_vals(A) - ΔD = randn(rng, real(T), m) + ΔD = randn!(similar(A, real(T), m)) return D, ΔD end -function ad_eigh_trunc_setup(rng, A, truncalg) - m, n = size(A) - T = eltype(A) - DV = eigh_full(A) - D, V = DV - Ddiag = diagview(D) - ΔV = randn(rng, T, m, m) - ΔV = remove_eighgauge_dependence!(ΔV, D, V) - ΔD = randn(rng, real(T), m, m) - ΔD2 = Diagonal(randn(rng, real(T), m)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - return DV, (ΔD2, ΔV), (ΔDtrunc, ΔVtrunc) +function ad_eigh_trunc_setup(A, truncalg) + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + Dtrunc = Diagonal(diagview(DV[1])[ind]) + Vtrunc = DV[2][:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2V[1])[ind]) + ΔVtrunc = ΔDV[2][:, ind] + return 
DV, (Dtrunc, Vtrunc), ΔD2V, (ΔDtrunc, ΔVtrunc) end -function ad_svd_compact_setup(rng, A) +function ad_svd_compact_setup(A) m, n = size(A) T = eltype(A) minmn = min(m, n) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) + ΔU = randn!(similar(A, T, m, minmn)) + ΔS = randn!(similar(A, real(T), minmn, minmn)) + ΔS2 = Diagonal(randn!(similar(A, real(T), minmn))) + ΔVᴴ = randn!(similar(A, T, minmn, n)) U, S, Vᴴ = svd_compact(A) ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), (ΔU, ΔS2, ΔVᴴ) end -function ad_svd_full_setup(rng, A) +function ad_svd_full_setup(A) m, n = size(A) T = eltype(A) minmn = min(m, n) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) + ΔU = randn!(similar(A, T, m, minmn)) + ΔS = randn!(similar(A, real(T), minmn, minmn)) + ΔS2 = Diagonal(randn!(similar(A, real(T), minmn))) + ΔVᴴ = randn!(similar(A, T, minmn, n)) U, S, Vᴴ = svd_compact(A) ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) ΔUfull = zeros(T, m, m) @@ -274,79 +258,71 @@ function ad_svd_full_setup(rng, A) return (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull) end -function ad_svd_vals_setup(rng, A) +function ad_svd_vals_setup(A) m, n = size(A) minmn = min(m, n) T = eltype(A) S = svd_vals(A) - ΔS = randn(rng, real(T), minmn) + ΔS = randn!(similar(A, real(T), minmn)) return S, ΔS end -function ad_svd_trunc_setup(rng, A, truncalg) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, 
:] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - return (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc) +function ad_svd_trunc_setup(A, truncalg) + USVᴴ, ΔUSVᴴ, ΔUS2Vᴴ = ad_svd_compact_setup(A) + ind = MatrixAlgebraKit.findtruncated(diagview(USVᴴ[2]), truncalg.trunc) + Strunc = Diagonal(diagview(USVᴴ[2])[ind]) + Utrunc = USVᴴ[1][:, ind] + Vᴴtrunc = USVᴴ[3][ind, :] + ΔStrunc = Diagonal(diagview(ΔUS2Vᴴ[2])[ind]) + ΔUtrunc = ΔUSVᴴ[1][:, ind] + ΔVᴴtrunc = ΔUSVᴴ[3][ind, :] + return USVᴴ, ΔUS2Vᴴ, (ΔUtrunc, ΔStrunc, ΔVᴴtrunc) end -function ad_left_polar_setup(rng, A) +function ad_left_polar_setup(A) m, n = size(A) T = eltype(A) WP = left_polar(A) - ΔWP = (randn(rng, T, m, n), randn(rng, T, n, n)) + ΔWP = (randn!(similar(A, T, m, n)), randn!(similar(A, T, n, n))) return WP, ΔWP end -function ad_right_polar_setup(rng, A) +function ad_right_polar_setup(A) m, n = size(A) T = eltype(A) PWᴴ = right_polar(A) - ΔPWᴴ = (randn(rng, T, m, m), randn(rng, T, m, n)) + ΔPWᴴ = (randn!(similar(A, T, m, m)), randn!(similar(A, T, m, n))) return PWᴴ, ΔPWᴴ end -function ad_left_orth_setup(rng, A) +function ad_left_orth_setup(A) m, n = size(A) T = eltype(A) VC = left_orth(A) - ΔVC = (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...)) + ΔVC = (randn!(similar(A, T, size(VC[1])...)), randn!(similar(A, T, size(VC[2])...))) return VC, ΔVC end -function ad_left_null_setup(rng, A) +function ad_left_null_setup(A) m, n = size(A) T = eltype(A) - N = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + N = left_orth(A; alg = :qr)[1] * randn!(similar(A, T, min(m, n), m - min(m, n))) + ΔN = left_orth(A; alg = :qr)[1] * randn!(similar(A, T, min(m, n), m - min(m, n))) return N, ΔN end -function ad_right_orth_setup(rng, A) +function ad_right_orth_setup(A) m, n = size(A) T = eltype(A) CVᴴ = right_orth(A) - ΔCVᴴ = (randn(rng, T, size(CVᴴ[1])...), 
randn(rng, T, size(CVᴴ[2])...)) + ΔCVᴴ = (randn!(similar(A, T, size(CVᴴ[1])...)), randn!(similar(A, T, size(CVᴴ[2])...))) return CVᴴ, ΔCVᴴ end -function ad_right_null_setup(rng, A) +function ad_right_null_setup(A) m, n = size(A) T = eltype(A) - Nᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] + Nᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] + ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] return Nᴴ, ΔNᴴ end diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index 5edd8319..bb062c0e 100644 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -3,8 +3,6 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! -include("ad_utils.jl") - for f in ( :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, @@ -35,23 +33,23 @@ for f in end end -function test_chainrules(T::Type, sz, rng; kwargs...) +function test_chainrules(T::Type, sz; kwargs...) summary_str = testargs_summary(T, sz) return @testset "Chainrules AD $summary_str" begin - test_chainrules_qr(T, sz, rng; kwargs...) - test_chainrules_lq(T, sz, rng; kwargs...) + test_chainrules_qr(T, sz; kwargs...) + test_chainrules_lq(T, sz; kwargs...) if length(sz) == 1 || sz[1] == sz[2] - test_chainrules_eig(T, sz, rng; kwargs...) - test_chainrules_eigh(T, sz, rng; kwargs...) + test_chainrules_eig(T, sz; kwargs...) + test_chainrules_eigh(T, sz; kwargs...) end - test_chainrules_svd(T, sz, rng; kwargs...) - test_chainrules_polar(T, sz, rng; kwargs...) - test_chainrules_orthnull(T, sz, rng; kwargs...) + test_chainrules_svd(T, sz; kwargs...) + test_chainrules_polar(T, sz; kwargs...) + test_chainrules_orthnull(T, sz; kwargs...) 
end end function test_chainrules_qr( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... ) @@ -61,7 +59,7 @@ function test_chainrules_qr( config = Zygote.ZygoteRuleConfig() alg = MatrixAlgebraKit.default_qr_algorithm(A) @testset "qr_compact" begin - QR, ΔQR = ad_qr_compact_setup(rng, A) + QR, ΔQR = ad_qr_compact_setup(A) ΔQ, ΔR = ΔQR test_rrule( copy_qr_compact, A, alg ⊢ NoTangent(); @@ -84,7 +82,7 @@ function test_chainrules_qr( ) end @testset "qr_null" begin - N, ΔN = ad_qr_null_setup(rng, A) + N, ΔN = ad_qr_null_setup(A) test_rrule( copy_qr_null, A, alg ⊢ NoTangent(); output_tangent = ΔN, atol = atol, rtol = rtol @@ -97,7 +95,7 @@ function test_chainrules_qr( m, n = size(A) end @testset "qr_full" begin - QR, ΔQR = ad_qr_full_setup(rng, A) + QR, ΔQR = ad_qr_full_setup(A) test_rrule( copy_qr_full, A, alg ⊢ NoTangent(); output_tangent = ΔQR, atol = atol, rtol = rtol @@ -113,7 +111,7 @@ function test_chainrules_qr( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - QR, ΔQR = ad_qr_rd_compact_setup(rng, Ard) + QR, ΔQR = ad_qr_rd_compact_setup(Ard) ΔQ, ΔR = ΔQR test_rrule( copy_qr_compact, Ard, alg ⊢ NoTangent(); @@ -129,7 +127,7 @@ function test_chainrules_qr( end function test_chainrules_lq( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... 
) @@ -140,7 +138,7 @@ function test_chainrules_lq( config = Zygote.ZygoteRuleConfig() alg = MatrixAlgebraKit.default_lq_algorithm(A) @testset "lq_compact" begin - LQ, ΔLQ = ad_lq_compact_setup(rng, A) + LQ, ΔLQ = ad_lq_compact_setup(A) ΔL, ΔQ = ΔLQ test_rrule( copy_lq_compact, A, alg ⊢ NoTangent(); @@ -163,7 +161,7 @@ function test_chainrules_lq( ) end @testset "lq_null" begin - Nᴴ, ΔNᴴ = ad_lq_null_setup(rng, A) + Nᴴ, ΔNᴴ = ad_lq_null_setup(A) test_rrule( copy_lq_null, A, alg ⊢ NoTangent(); output_tangent = ΔNᴴ, atol = atol, rtol = rtol @@ -175,7 +173,7 @@ function test_chainrules_lq( ) end @testset "lq_full" begin - LQ, ΔLQ = ad_lq_full_setup(rng, A) + LQ, ΔLQ = ad_lq_full_setup(A) test_rrule( copy_lq_full, A, alg ⊢ NoTangent(); output_tangent = ΔLQ, atol = atol, rtol = rtol @@ -190,7 +188,7 @@ function test_chainrules_lq( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - LQ, ΔLQ = ad_lq_rd_compact_setup(rng, Ard) + LQ, ΔLQ = ad_lq_rd_compact_setup(Ard) test_rrule( copy_lq_compact, Ard, alg ⊢ NoTangent(); output_tangent = ΔLQ, atol = atol, rtol = rtol @@ -205,7 +203,7 @@ function test_chainrules_lq( end function test_chainrules_eig( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... 
) @@ -216,7 +214,7 @@ function test_chainrules_eig( config = Zygote.ZygoteRuleConfig() alg = MatrixAlgebraKit.default_eig_algorithm(A) @testset "eig_full" begin - DV, ΔDV, ΔD2V = ad_eig_full_setup(rng, A) + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) ΔD, ΔV = ΔDV test_rrule( copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol @@ -242,7 +240,7 @@ function test_chainrules_eig( ) end @testset "eig_vals" begin - D, ΔD = ad_eig_vals_setup(rng, A) + D, ΔD = ad_eig_vals_setup(A) test_rrule( copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol ) @@ -254,42 +252,34 @@ function test_chainrules_eig( @testset "eig_trunc" begin for r in 1:4:m truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - DV, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(rng, A, truncalg) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) test_rrule( copy_eig_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔDVtrunc..., zero(real(T))), atol = atol, rtol = rtol ) - D, V = DV - Ddiag = diagview(D) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(Ddiag[ind]) - Vtrunc = V[:, ind] + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), ΔDVtrunc) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) end truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) - DV, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(rng, A, truncalg) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) test_rrule( copy_eig_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔDVtrunc..., zero(real(T))), atol = atol, rtol = rtol ) - D, V = DV - Ddiag = diagview(D) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(Ddiag[ind]) - Vtrunc = V[:, ind] + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), 
truncalg.trunc) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), ΔDVtrunc) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) end end end function test_chainrules_eigh( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... ) @@ -302,7 +292,7 @@ function test_chainrules_eigh( alg = MatrixAlgebraKit.default_eigh_algorithm(A) # copy_eigh_xxxx includes a projector onto the Hermitian part of the matrix @testset "eigh_full" begin - DV, ΔDV, ΔD2V = ad_eigh_full_setup(rng, A) + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) ΔD, ΔV = ΔDV test_rrule( copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol @@ -329,7 +319,7 @@ function test_chainrules_eigh( ) end @testset "eigh_vals" begin - D, ΔD = ad_eigh_vals_setup(rng, A) + D, ΔD = ad_eigh_vals_setup(A) test_rrule( copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol ) @@ -342,24 +332,20 @@ function test_chainrules_eigh( eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...) 
for r in 1:4:m truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - DV, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(rng, A, truncalg) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) test_rrule( copy_eigh_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔDVtrunc..., zero(real(T))), atol = atol, rtol = rtol ) - D, V = DV - Ddiag = diagview(D) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(Ddiag[ind]) - Vtrunc = V[:, ind] + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) - dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), ΔDVtrunc) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) trunc = truncrank(r; by = real) - ind = MatrixAlgebraKit.findtruncated(Ddiag, trunc) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), trunc) truncalg = TruncatedAlgorithm(alg, trunc) - DV, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(rng, A, truncalg) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) test_rrule( config, eigh_trunc2, A; fkwargs = (; trunc = trunc), @@ -367,33 +353,22 @@ function test_chainrules_eigh( atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end - D, ΔD = ad_eigh_vals_setup(rng, A) + D, ΔD = ad_eigh_vals_setup(A / 2) truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) - DV, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(rng, A, truncalg) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) - ΔDtrunc = Diagonal(diagview(ΔDV[1])[ind]) - ΔVtrunc = ΔDV[2][:, ind] test_rrule( copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), + output_tangent = (ΔDVtrunc..., zero(real(T))), atol = atol, rtol = rtol ) - D, V = DV - Ddiag = diagview(D) - Dtrunc = 
Diagonal(Ddiag[ind]) - Vtrunc = V[:, ind] dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) - dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), ΔDVtrunc) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) trunc = trunctol(; rtol = 1 / 2) truncalg = TruncatedAlgorithm(alg, trunc) - DV, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(rng, A, truncalg) - D, V = DV - Ddiag = diagview(D) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(Ddiag[ind]) - Vtrunc = V[:, ind] - ind = MatrixAlgebraKit.findtruncated(Ddiag, trunc) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) test_rrule( config, eigh_trunc2, A; fkwargs = (; trunc = trunc), @@ -405,7 +380,7 @@ function test_chainrules_eigh( end function test_chainrules_svd( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... 
) @@ -416,7 +391,7 @@ function test_chainrules_svd( config = Zygote.ZygoteRuleConfig() alg = MatrixAlgebraKit.default_svd_algorithm(A) @testset "svd_compact" begin - USV, ΔUSVᴴ, ΔUS2Vᴴ = ad_svd_compact_setup(rng, A) + USV, ΔUSVᴴ, ΔUS2Vᴴ = ad_svd_compact_setup(A) test_rrule( copy_svd_compact, A, alg ⊢ NoTangent(); output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol @@ -437,7 +412,7 @@ function test_chainrules_svd( ) end @testset "svd_vals" begin - S, ΔS = ad_svd_vals_setup(rng, A) + S, ΔS = ad_svd_vals_setup(A) test_rrule( copy_svd_vals, A, alg ⊢ NoTangent(); output_tangent = ΔS, atol, rtol @@ -450,7 +425,7 @@ function test_chainrules_svd( @testset "svd_trunc" begin @testset for r in 1:4:minmn truncalg = TruncatedAlgorithm(alg, truncrank(r)) - USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(rng, A, truncalg) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) test_rrule( copy_svd_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), @@ -474,9 +449,9 @@ function test_chainrules_svd( atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end - S, ΔS = ad_svd_vals_setup(rng, A) + S, ΔS = ad_svd_vals_setup(A) truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) - USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(rng, A, truncalg) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) test_rrule( copy_svd_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), @@ -503,7 +478,7 @@ function test_chainrules_svd( end function test_chainrules_polar( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... ) @@ -535,7 +510,7 @@ function test_chainrules_polar( end function test_chainrules_orthnull( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... 
) @@ -544,8 +519,8 @@ function test_chainrules_orthnull( A = instantiate_matrix(T, sz) m, n = size(A) config = Zygote.ZygoteRuleConfig() - N, ΔN = ad_left_null_setup(rng, A) - Nᴴ, ΔNᴴ = ad_right_null_setup(rng, A) + N, ΔN = ad_left_null_setup(A) + Nᴴ, ΔNᴴ = ad_right_null_setup(A) test_rrule( config, left_orth, A; atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl index c1084054..97049784 100644 --- a/test/testsuite/mooncake.jl +++ b/test/testsuite/mooncake.jl @@ -77,7 +77,7 @@ function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) end """ - test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) + test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) Compare the result of running the *in-place, mutating* function `f!`'s reverse rule with the result of running its *non-mutating* partner function `f`'s reverse rule. @@ -93,12 +93,12 @@ The arguments to this function are: - `alg` optional algorithm keyword argument - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) """ -function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) +function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) f_c = isnothing(alg) ? (A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg) sig = isnothing(alg) ? 
Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)} rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) rrule = Mooncake.build_rrule(rvs_interp, sig) - ΔA = randn(rng, eltype(A), size(A)) + ΔA = randn!(similar(A)) dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) @@ -109,23 +109,23 @@ function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata return end -function test_mooncake(T::Type, sz, rng; kwargs...) +function test_mooncake(T::Type, sz; kwargs...) summary_str = testargs_summary(T, sz) return @testset "Mooncake AD $summary_str" begin - test_mooncake_qr(T, sz, rng; kwargs...) - test_mooncake_lq(T, sz, rng; kwargs...) + test_mooncake_qr(T, sz; kwargs...) + test_mooncake_lq(T, sz; kwargs...) if length(sz) == 1 || sz[1] == sz[2] - test_mooncake_eig(T, sz, rng; kwargs...) - test_mooncake_eigh(T, sz, rng; kwargs...) + test_mooncake_eig(T, sz; kwargs...) + test_mooncake_eigh(T, sz; kwargs...) end - test_mooncake_svd(T, sz, rng; kwargs...) - test_mooncake_polar(T, sz, rng; kwargs...) - test_mooncake_orthnull(T, sz, rng; kwargs...) + test_mooncake_svd(T, sz; kwargs...) + test_mooncake_polar(T, sz; kwargs...) + test_mooncake_orthnull(T, sz; kwargs...) end end function test_mooncake_qr( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... 
) @@ -133,36 +133,36 @@ function test_mooncake_qr( return @testset "QR Mooncake AD rules $summary_str" begin A = instantiate_matrix(T, sz) @testset "qr_compact" begin - QR, ΔQR = ad_qr_compact_setup(rng, A) + QR, ΔQR = ad_qr_compact_setup(A) Mooncake.TestUtils.test_rule(rng, qr_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_compact!, qr_compact, A, QR, ΔQR) + test_pullbacks_match(qr_compact!, qr_compact, A, QR, ΔQR) end @testset "qr_null" begin - N, ΔN = ad_qr_null_setup(rng, A) + N, ΔN = ad_qr_null_setup(A) dN = make_mooncake_tangent(copy(ΔN)) Mooncake.TestUtils.test_rule(rng, qr_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dN, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN) + test_pullbacks_match(qr_null!, qr_null, A, N, ΔN) end @testset "qr_full" begin - QR, ΔQR = ad_qr_full_setup(rng, A) + QR, ΔQR = ad_qr_full_setup(A) dQR = make_mooncake_tangent(ΔQR) Mooncake.TestUtils.test_rule(rng, qr_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_full!, qr_full, A, QR, ΔQR) + test_pullbacks_match(qr_full!, qr_full, A, QR, ΔQR) end @testset "qr_compact - rank-deficient A" begin m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - QR, ΔQR = ad_qr_rd_compact_setup(rng, Ard) + QR, ΔQR = ad_qr_rd_compact_setup(Ard) dQR = make_mooncake_tangent(ΔQR) Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, QR, ΔQR) + test_pullbacks_match(qr_compact!, qr_compact, Ard, QR, ΔQR) end end end function test_mooncake_lq( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... 
) @@ -170,36 +170,36 @@ function test_mooncake_lq( return @testset "LQ Mooncake AD rules $summary_str" begin A = instantiate_matrix(T, sz) @testset "lq_compact" begin - LQ, ΔLQ = ad_lq_compact_setup(rng, A) + LQ, ΔLQ = ad_lq_compact_setup(A) Mooncake.TestUtils.test_rule(rng, lq_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_compact!, lq_compact, A, LQ, ΔLQ) + test_pullbacks_match(lq_compact!, lq_compact, A, LQ, ΔLQ) end @testset "lq_null" begin - Nᴴ, ΔNᴴ = ad_lq_null_setup(rng, A) + Nᴴ, ΔNᴴ = ad_lq_null_setup(A) dNᴴ = make_mooncake_tangent(ΔNᴴ) Mooncake.TestUtils.test_rule(rng, lq_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ) + test_pullbacks_match(lq_null!, lq_null, A, Nᴴ, ΔNᴴ) end @testset "lq_full" begin - LQ, ΔLQ = ad_lq_full_setup(rng, A) + LQ, ΔLQ = ad_lq_full_setup(A) dLQ = make_mooncake_tangent(ΔLQ) Mooncake.TestUtils.test_rule(rng, lq_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_full!, lq_full, A, LQ, ΔLQ) + test_pullbacks_match(lq_full!, lq_full, A, LQ, ΔLQ) end @testset "lq_compact - rank-deficient A" begin m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - LQ, ΔLQ = ad_lq_rd_compact_setup(rng, Ard) + LQ, ΔLQ = ad_lq_rd_compact_setup(Ard) dLQ = make_mooncake_tangent(ΔLQ) Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, LQ, ΔLQ) + test_pullbacks_match(lq_compact!, lq_compact, Ard, LQ, ΔLQ) end end end function test_mooncake_eig( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... 
) @@ -208,38 +208,38 @@ function test_mooncake_eig( A = instantiate_matrix(T, sz) m = size(A, 1) @testset "eig_full" begin - DV, ΔDV, ΔD2V = ad_eig_full_setup(rng, A) + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) dDV = make_mooncake_tangent(ΔD2V) Mooncake.TestUtils.test_rule(rng, eig_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dDV, atol = atol, rtol = rtol) - test_pullbacks_match(rng, eig_full!, eig_full, A, DV, ΔD2V) + test_pullbacks_match(eig_full!, eig_full, A, DV, ΔD2V) end @testset "eig_vals" begin - D, ΔD = ad_eig_vals_setup(rng, A) + D, ΔD = ad_eig_vals_setup(A) dD = make_mooncake_tangent(ΔD) Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) - test_pullbacks_match(rng, eig_vals!, eig_vals, A, D, ΔD) + test_pullbacks_match(eig_vals!, eig_vals, A, D, ΔD) end @testset "eig_trunc" begin for r in 1:4:m truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs)) - DV, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(rng, A, truncalg) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) ϵ = zero(real(T)) dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) - DV, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(rng, A, truncalg) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) ϵ = zero(real(T)) dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = 
dDVerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end end end function test_mooncake_eigh( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... ) @@ -249,39 +249,39 @@ function test_mooncake_eigh( A = A + A' m = size(A, 1) @testset "eigh_full" begin - DV, ΔDV, ΔD2V = ad_eigh_full_setup(rng, A) + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) dDV = make_mooncake_tangent(ΔD2V) Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V) + test_pullbacks_match(copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V) end @testset "eigh_vals" begin - D, ΔD = ad_eigh_vals_setup(rng, A) + D, ΔD = ad_eigh_vals_setup(A) dD = make_mooncake_tangent(ΔD) Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D, ΔD) + test_pullbacks_match(copy_eigh_vals!, copy_eigh_vals, A, D, ΔD) end @testset "eigh_trunc" begin for r in 1:4:m truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(r; by = abs)) - DV, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(rng, A, truncalg) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) ϵ = zero(real(T)) dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), 
Mooncake.NoRData(), zero(real(T)))) + test_pullbacks_match(copy_eigh_trunc!, copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end D = eigh_vals(A / 2) truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2)) - DV, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(rng, A, truncalg) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) ϵ = zero(real(T)) dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + test_pullbacks_match(copy_eigh_trunc!, copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end end end function test_mooncake_svd( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... ) @@ -290,46 +290,46 @@ function test_mooncake_svd( A = instantiate_matrix(T, sz) minmn = min(size(A)...) 
@testset "svd_compact" begin - USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(rng, A) + USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A) dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) Mooncake.TestUtils.test_rule(rng, svd_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ) + test_pullbacks_match(svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ) end @testset "svd_full" begin - USVᴴ, ΔUSVᴴ = ad_svd_full_setup(rng, A) + USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) Mooncake.TestUtils.test_rule(rng, svd_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ) + test_pullbacks_match(svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ) end @testset "svd_vals" begin - S, ΔS = ad_svd_vals_setup(rng, A) + S, ΔS = ad_svd_vals_setup(A) Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, ΔS) + test_pullbacks_match(svd_vals!, svd_vals, A, S, ΔS) end @testset "svd_trunc" begin - S, ΔS = ad_svd_vals_setup(rng, A) + S, ΔS = ad_svd_vals_setup(A) @testset for r in 1:4:minmn truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r)) - USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(rng, A, truncalg) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) ϵ = zero(real(T)) dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = 
(Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end @testset "trunctol" begin truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2)) - USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(rng, A, truncalg) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) ϵ = zero(real(T)) dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end end end end function test_mooncake_polar( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... ) @@ -339,16 +339,16 @@ function test_mooncake_polar( m, n = size(A) @testset "left_polar" begin if m >= n - WP, ΔWP = ad_left_polar_setup(rng, A) + WP, ΔWP = ad_left_polar_setup(A) Mooncake.TestUtils.test_rule(rng, left_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) - test_pullbacks_match(rng, left_polar!, left_polar, A, WP, ΔWP) + test_pullbacks_match(left_polar!, left_polar, A, WP, ΔWP) end end @testset "right_polar" begin if m <= n - PWᴴ, ΔPWᴴ = ad_right_polar_setup(rng, A) + PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) Mooncake.TestUtils.test_rule(rng, right_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) - test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, ΔPWᴴ) + test_pullbacks_match(right_polar!, right_polar, A, PWᴴ, ΔPWᴴ) end end end @@ -369,7 +369,7 @@ MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.co MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = 
MatrixAlgebraKit.copy_input(right_null, A) function test_mooncake_orthnull( - T::Type, sz, rng; + T::Type, sz; atol::Real = 0, rtol::Real = precision(T), kwargs... ) @@ -377,36 +377,36 @@ function test_mooncake_orthnull( return @testset "Orthnull Mooncake AD rules $summary_str" begin A = instantiate_matrix(T, sz) m, n = size(A) - VC, ΔVC = ad_left_orth_setup(rng, A) - CVᴴ, ΔCVᴴ = ad_right_orth_setup(rng, A) + VC, ΔVC = ad_left_orth_setup(A) + CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, left_orth!, left_orth, A, VC, ΔVC) + test_pullbacks_match(left_orth!, left_orth, A, VC, ΔVC) Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, right_orth!, right_orth, A, CVᴴ, ΔCVᴴ) + test_pullbacks_match(right_orth!, right_orth, A, CVᴴ, ΔCVᴴ) Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, ΔVC) + test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, ΔVC) if m >= n Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, ΔVC) + test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, ΔVC) end - N, ΔN = ad_left_null_setup(rng, A) + N, ΔN = ad_left_null_setup(A) dN = make_mooncake_tangent(ΔN) Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) - test_pullbacks_match(rng, ((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) + 
test_pullbacks_match(((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, ΔCVᴴ) + test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, ΔCVᴴ) if m <= n Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, ΔCVᴴ) + test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, ΔCVᴴ) end - Nᴴ, ΔNᴴ = ad_right_null_setup(rng, A) + Nᴴ, ΔNᴴ = ad_right_null_setup(A) dNᴴ = make_mooncake_tangent(ΔNᴴ) Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) - test_pullbacks_match(rng, ((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) + test_pullbacks_match(((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) end end From e8c315e6a8740d16a0e66518311866f706f83b28 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Dec 2025 15:26:37 -0500 Subject: [PATCH 4/6] Loosen make_mooncake_tangent --- test/testsuite/mooncake.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl index 97049784..d99e1143 100644 --- a/test/testsuite/mooncake.jl +++ b/test/testsuite/mooncake.jl @@ -6,10 +6,10 @@ using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc make_mooncake_tangent(ΔAelem::T) where {T <: Real} = ΔAelem make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) -make_mooncake_tangent(ΔA::Matrix{<:Real}) = ΔA 
-make_mooncake_tangent(ΔA::Vector{<:Real}) = ΔA -make_mooncake_tangent(ΔA::Matrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) -make_mooncake_tangent(ΔA::Vector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) +make_mooncake_tangent(ΔA::AbstractMatrix{<:Real}) = ΔA +make_mooncake_tangent(ΔA::AbstractVector{<:Real}) = ΔA +make_mooncake_tangent(ΔA::AbstractMatrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) +make_mooncake_tangent(ΔA::AbstractVector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Real} = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_tangent(typeof(ΔD), map(make_mooncake_tangent, diagview(ΔD))) From 21db51929a01b4432aafea9e4e79201a75652287 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 12 Dec 2025 06:38:29 -0500 Subject: [PATCH 5/6] Try CUDA arrayify --- Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index d4d5a143..b3596c9a 100644 --- a/Project.toml +++ b/Project.toml @@ -57,3 +57,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Random", "Mooncake"] + +[sources] +Mooncake = {url="https://github.com/chalk-lab/Mooncake.jl", rev="ksh/cuarraify"} From 71e0575ec990481706a6c21c051bf90423c010bf Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 12 Dec 2025 08:21:01 -0500 Subject: [PATCH 6/6] Small updates for GPU --- Project.toml | 4 ++- test/testsuite/ad_utils.jl | 34 ----------------------- test/testsuite/chainrules.jl | 52 ++++++++++++++++++------------------ test/testsuite/mooncake.jl | 51 +++++++++++++++++++++++++++++------ 4 files changed, 72 insertions(+), 69 deletions(-) diff --git a/Project.toml b/Project.toml index 
b3596c9a..50cc4bd7 100644 --- a/Project.toml +++ b/Project.toml @@ -46,6 +46,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -56,7 +57,8 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Random", "Mooncake"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Random", "Mooncake", "GPUArrays"] [sources] Mooncake = {url="https://github.com/chalk-lab/Mooncake.jl", rev="ksh/cuarraify"} +GPUArrays = {url="https://github.com/JuliaGPU/GPUArrays.jl", rev="ksh/findlast"} diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index c4de880c..8b8a53e7 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -160,40 +160,6 @@ function ad_eig_trunc_setup(A, truncalg) return DV, (Dtrunc, Vtrunc), ΔD2V, (ΔDtrunc, ΔVtrunc) end -function copy_eigh_full(A; kwargs...) - A = (A + A') / 2 - return eigh_full(A; kwargs...) -end - -function copy_eigh_full!(A, DV; kwargs...) - A = (A + A') / 2 - return eigh_full!(A, DV; kwargs...) -end - -function copy_eigh_vals(A; kwargs...) - A = (A + A') / 2 - return eigh_vals(A; kwargs...) -end - -function copy_eigh_vals!(A, D; kwargs...) - A = (A + A') / 2 - return eigh_vals!(A, D; kwargs...) -end - -function copy_eigh_trunc(A, alg; kwargs...) 
- A = (A + A') / 2 - return eigh_trunc(A, alg; kwargs...) -end - -function copy_eigh_trunc!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc!(A, DV, alg; kwargs...) -end - -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) - function ad_eigh_full_setup(A) m, n = size(A) T = eltype(A) diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index bb062c0e..59b11a6f 100644 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -10,7 +10,7 @@ for f in :svd_compact, :svd_trunc, :svd_vals, :left_polar, :right_polar, ) - copy_f = Symbol(:copy_, f) + copy_f = Symbol(:cr_copy_, f) f! = Symbol(f, '!') _hermitian = startswith(string(f), "eigh") @eval begin @@ -62,7 +62,7 @@ function test_chainrules_qr( QR, ΔQR = ad_qr_compact_setup(A) ΔQ, ΔR = ΔQR test_rrule( - copy_qr_compact, A, alg ⊢ NoTangent(); + cr_copy_qr_compact, A, alg ⊢ NoTangent(); output_tangent = ΔQR, atol = atol, rtol = rtol ) test_rrule( @@ -84,7 +84,7 @@ function test_chainrules_qr( @testset "qr_null" begin N, ΔN = ad_qr_null_setup(A) test_rrule( - copy_qr_null, A, alg ⊢ NoTangent(); + cr_copy_qr_null, A, alg ⊢ NoTangent(); output_tangent = ΔN, atol = atol, rtol = rtol ) test_rrule( @@ -97,7 +97,7 @@ function test_chainrules_qr( @testset "qr_full" begin QR, ΔQR = ad_qr_full_setup(A) test_rrule( - copy_qr_full, A, alg ⊢ NoTangent(); + cr_copy_qr_full, A, alg ⊢ NoTangent(); output_tangent = ΔQR, atol = atol, rtol = rtol ) test_rrule( @@ -114,7 +114,7 @@ function test_chainrules_qr( QR, ΔQR = ad_qr_rd_compact_setup(Ard) ΔQ, ΔR = ΔQR test_rrule( - copy_qr_compact, Ard, alg ⊢ NoTangent(); + cr_copy_qr_compact, Ard, alg ⊢ NoTangent(); output_tangent = ΔQR, atol = atol, rtol = rtol ) test_rrule( @@ -141,7 +141,7 @@ 
function test_chainrules_lq( LQ, ΔLQ = ad_lq_compact_setup(A) ΔL, ΔQ = ΔLQ test_rrule( - copy_lq_compact, A, alg ⊢ NoTangent(); + cr_copy_lq_compact, A, alg ⊢ NoTangent(); output_tangent = ΔLQ, atol = atol, rtol = rtol ) test_rrule( @@ -163,7 +163,7 @@ function test_chainrules_lq( @testset "lq_null" begin Nᴴ, ΔNᴴ = ad_lq_null_setup(A) test_rrule( - copy_lq_null, A, alg ⊢ NoTangent(); + cr_copy_lq_null, A, alg ⊢ NoTangent(); output_tangent = ΔNᴴ, atol = atol, rtol = rtol ) test_rrule( @@ -175,7 +175,7 @@ function test_chainrules_lq( @testset "lq_full" begin LQ, ΔLQ = ad_lq_full_setup(A) test_rrule( - copy_lq_full, A, alg ⊢ NoTangent(); + cr_copy_lq_full, A, alg ⊢ NoTangent(); output_tangent = ΔLQ, atol = atol, rtol = rtol ) test_rrule( @@ -190,7 +190,7 @@ function test_chainrules_lq( Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) LQ, ΔLQ = ad_lq_rd_compact_setup(Ard) test_rrule( - copy_lq_compact, Ard, alg ⊢ NoTangent(); + cr_copy_lq_compact, Ard, alg ⊢ NoTangent(); output_tangent = ΔLQ, atol = atol, rtol = rtol ) test_rrule( @@ -217,10 +217,10 @@ function test_chainrules_eig( DV, ΔDV, ΔD2V = ad_eig_full_setup(A) ΔD, ΔV = ΔDV test_rrule( - copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol + cr_copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol ) test_rrule( - copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔD2V, atol, rtol + cr_copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔD2V, atol, rtol ) test_rrule( config, eig_full, A, alg ⊢ NoTangent(); @@ -242,7 +242,7 @@ function test_chainrules_eig( @testset "eig_vals" begin D, ΔD = ad_eig_vals_setup(A) test_rrule( - copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol + cr_copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol ) test_rrule( config, eig_vals, A, alg ⊢ NoTangent(); @@ -254,7 +254,7 @@ function test_chainrules_eig( truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) DV, DVtrunc, ΔDV, ΔDVtrunc = 
ad_eig_trunc_setup(A, truncalg) test_rrule( - copy_eig_trunc, A, truncalg ⊢ NoTangent(); + cr_copy_eig_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔDVtrunc..., zero(real(T))), atol = atol, rtol = rtol ) @@ -266,7 +266,7 @@ function test_chainrules_eig( truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) test_rrule( - copy_eig_trunc, A, truncalg ⊢ NoTangent(); + cr_copy_eig_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔDVtrunc..., zero(real(T))), atol = atol, rtol = rtol ) @@ -295,10 +295,10 @@ function test_chainrules_eigh( DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) ΔD, ΔV = ΔDV test_rrule( - copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol + cr_copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol ) test_rrule( - copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔD2V, atol, rtol + cr_copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔD2V, atol, rtol ) # eigh_full does not include a projector onto the Hermitian part of the matrix test_rrule( @@ -321,7 +321,7 @@ function test_chainrules_eigh( @testset "eigh_vals" begin D, ΔD = ad_eigh_vals_setup(A) test_rrule( - copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol + cr_copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol ) test_rrule( config, eigh_vals ∘ Matrix ∘ Hermitian, A; @@ -334,7 +334,7 @@ function test_chainrules_eigh( truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) test_rrule( - copy_eigh_trunc, A, truncalg ⊢ NoTangent(); + cr_copy_eigh_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔDVtrunc..., zero(real(T))), atol = atol, rtol = rtol ) @@ -358,7 +358,7 @@ function test_chainrules_eigh( DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) test_rrule( - copy_eigh_trunc, A, truncalg ⊢ 
NoTangent(); + cr_copy_eigh_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔDVtrunc..., zero(real(T))), atol = atol, rtol = rtol ) @@ -393,11 +393,11 @@ function test_chainrules_svd( @testset "svd_compact" begin USV, ΔUSVᴴ, ΔUS2Vᴴ = ad_svd_compact_setup(A) test_rrule( - copy_svd_compact, A, alg ⊢ NoTangent(); + cr_copy_svd_compact, A, alg ⊢ NoTangent(); output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol ) test_rrule( - copy_svd_compact, A, alg ⊢ NoTangent(); + cr_copy_svd_compact, A, alg ⊢ NoTangent(); output_tangent = ΔUS2Vᴴ, atol = atol, rtol = rtol ) test_rrule( @@ -414,7 +414,7 @@ function test_chainrules_svd( @testset "svd_vals" begin S, ΔS = ad_svd_vals_setup(A) test_rrule( - copy_svd_vals, A, alg ⊢ NoTangent(); + cr_copy_svd_vals, A, alg ⊢ NoTangent(); output_tangent = ΔS, atol, rtol ) test_rrule( @@ -427,7 +427,7 @@ function test_chainrules_svd( truncalg = TruncatedAlgorithm(alg, truncrank(r)) USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); + cr_copy_svd_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), atol = atol, rtol = rtol ) @@ -453,7 +453,7 @@ function test_chainrules_svd( truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); + cr_copy_svd_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), atol = atol, rtol = rtol ) @@ -490,7 +490,7 @@ function test_chainrules_polar( alg = MatrixAlgebraKit.default_polar_algorithm(A) @testset "left_polar" begin if m >= n - test_rrule(copy_left_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) + test_rrule(cr_copy_left_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) test_rrule( config, left_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false @@ -499,7 +499,7 @@ function test_chainrules_polar( end 
@testset "right_polar" begin if m <= n - test_rrule(copy_right_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) + test_rrule(cr_copy_right_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) test_rrule( config, right_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl index d99e1143..c2bb639b 100644 --- a/test/testsuite/mooncake.jl +++ b/test/testsuite/mooncake.jl @@ -4,6 +4,41 @@ using Mooncake, Mooncake.TestUtils using Mooncake: rrule!! using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc +function mc_copy_eigh_full(A; kwargs...) + A = (A + A') / 2 + return eigh_full(A; kwargs...) +end + +function mc_copy_eigh_full!(A, DV; kwargs...) + A = (A + A') / 2 + return eigh_full!(A, DV; kwargs...) +end + +function mc_copy_eigh_vals(A; kwargs...) + A = (A + A') / 2 + return eigh_vals(A; kwargs...) +end + +function mc_copy_eigh_vals!(A, D; kwargs...) + A = (A + A') / 2 + return eigh_vals!(A, D; kwargs...) +end + +function mc_copy_eigh_trunc(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc(A, alg; kwargs...) +end + +function mc_copy_eigh_trunc!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc!(A, DV, alg; kwargs...) 
+end + +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) + + make_mooncake_tangent(ΔAelem::T) where {T <: Real} = ΔAelem make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) make_mooncake_tangent(ΔA::AbstractMatrix{<:Real}) = ΔA @@ -251,14 +286,14 @@ function test_mooncake_eigh( @testset "eigh_full" begin DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) dDV = make_mooncake_tangent(ΔD2V) - Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) + test_pullbacks_match(mc_copy_eigh_full!, mc_copy_eigh_full, A, DV, ΔD2V) end @testset "eigh_vals" begin D, ΔD = ad_eigh_vals_setup(A) dD = make_mooncake_tangent(ΔD) - Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(copy_eigh_vals!, copy_eigh_vals, A, D, ΔD) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol = atol, rtol = rtol) + test_pullbacks_match(mc_copy_eigh_vals!, mc_copy_eigh_vals, A, D, ΔD) end @testset "eigh_trunc" begin for r in 1:4:m @@ -266,16 +301,16 @@ function test_mooncake_eigh( DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) ϵ = zero(real(T)) dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = 
Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(copy_eigh_trunc!, copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end D = eigh_vals(A / 2) truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2)) DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) ϵ = zero(real(T)) dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(copy_eigh_trunc!, copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end end end