From 4a1d5926b1031e9a1d808fd9fdbd62d384950b2c Mon Sep 17 00:00:00 2001 From: maltezfaria Date: Sun, 4 May 2025 10:10:54 +0200 Subject: [PATCH] support hermitian types for `KernelMatrix` --- src/kernelmatrix.jl | 9 ++++++- src/multiplication.jl | 22 +++++++-------- test/cholesky_test.jl | 63 +++++++++++++++++++++++++++++-------------- 3 files changed, 62 insertions(+), 32 deletions(-) diff --git a/src/kernelmatrix.jl b/src/kernelmatrix.jl index 50ede61..d8fcd7b 100644 --- a/src/kernelmatrix.jl +++ b/src/kernelmatrix.jl @@ -48,6 +48,12 @@ function KernelMatrix(f, X, Y) return KernelMatrix{typeof(f),typeof(X),typeof(Y),T}(f, X, Y) end +# light wrapper for Hermitian +const HermitianKernelMatrix = Hermitian{<:Any,<:KernelMatrix} + +rowelements(K::HermitianKernelMatrix) = K |> parent |> rowelements +colelements(K::HermitianKernelMatrix) = K |> parent |> colelements + """ assemble_hmatrix(K::AbstractKernelMatrix[; atol, rank, rtol, kwargs...]) @@ -57,7 +63,8 @@ arguments are passed to the [`PartialACA`](@ref) constructor, and the remaining keyword arguments are forwarded to the main `assemble_hmatrix` function. """ function assemble_hmatrix( - K::AbstractKernelMatrix; + K::Union{KernelMatrix,HermitianKernelMatrix}; + # K::KernelMatrix; atol = 0, rank = typemax(Int), rtol = atol > 0 || rank < typemax(Int) ? 0 : sqrt(eps(Float64)), diff --git a/src/multiplication.jl b/src/multiplication.jl index ce861f1..be14fec 100644 --- a/src/multiplication.jl +++ b/src/multiplication.jl @@ -221,26 +221,26 @@ Multiplication when the target is a dense matrix. The numbering system in the fo function _mul_dense!(C::Base.Matrix, A, B, a) Adata = isleaf(A) ? data(A) : A Bdata = isleaf(B) ? data(B) : B - if Adata isa HMatrix - if Bdata isa Matrix + if parent(Adata) isa HMatrix + if parent(Bdata) isa Matrix _mul131!(C, Adata, Bdata, a) - elseif Bdata isa RkMatrix + elseif parent(Bdata) isa RkMatrix _mul132!(C, Adata, Bdata, a) end - elseif Adata isa AdjOrMat - if Bdata isa Matrix + elseif parent(Adata) isa AdjOrMat + if parent(Bdata) isa Matrix _mul111!(C, Adata, Bdata, a) - elseif Bdata isa RkMatrix + elseif parent(Bdata) isa RkMatrix _mul112!(C, Adata, Bdata, a) - elseif Bdata isa HMatrix + elseif parent(Bdata) isa HMatrix _mul113!(C, Adata, Bdata, a) end - elseif Adata isa RkMatrix - if Bdata isa Matrix + elseif parent(Adata) isa RkMatrix + if parent(Bdata) isa Matrix _mul121!(C, Adata, Bdata, a) - elseif Bdata isa RkMatrix + elseif parent(Bdata) isa RkMatrix _mul122!(C, Adata, Bdata, a) - elseif Bdata isa HMatrix + elseif parent(Bdata) isa HMatrix _mul123!(C, Adata, Bdata, a) end else diff --git a/test/cholesky_test.jl b/test/cholesky_test.jl index 1462369..5d0dc87 100644 --- a/test/cholesky_test.jl +++ b/test/cholesky_test.jl @@ -10,26 +10,49 @@ include(joinpath(HMatrices.PROJECT_ROOT, "test", "testutils.jl")) Random.seed!(1) -m = 5000 -T = Float64 -X = points_on_sphere(m) -Y = X +@testset "Assemble and solve" begin + m = 5000 + T = Float64 + X = points_on_sphere(m) + Y = X -K = laplace_matrix(X, X) + K = laplace_matrix(X, X) -splitter = CardinalitySplitter(; nmax = 50) -Xclt = ClusterTree(X, splitter) -Yclt = ClusterTree(Y, splitter) -adm = StrongAdmissibilityStd(3) -comp = PartialACA(; atol = 1e-10) -for threads in (false, true) - H = assemble_hmatrix(Hermitian(K), Xclt, Yclt; adm, comp, threads, distributed = false) - hchol = cholesky(H; atol = 1e-10) - y = rand(m) - M = Matrix(K) - exact = M \ y - approx = hchol \ y - @test norm(exact - approx, Inf) < 1e-10 - # test multiplication by checking if the solution is correct - @test hchol.L * (hchol.U * approx) ≈ y + splitter = CardinalitySplitter(; nmax = 50) + Xclt = ClusterTree(X, splitter) + Yclt = ClusterTree(Y, splitter) + adm = StrongAdmissibilityStd(3) + comp = PartialACA(; atol = 1e-10) + for threads in (false, true) + H = assemble_hmatrix( + Hermitian(K), + Xclt, + Yclt; + adm, + comp, + threads, + distributed = false, + ) + hchol = cholesky(H; atol = 1e-10) + y = rand(m) + M = Matrix(K) + exact = M \ y + approx = hchol \ y + @test norm(exact - approx, Inf) < 1e-10 + # test multiplication by checking if the solution is correct + @test hchol.L * (hchol.U * approx) ≈ y + end +end + +@testset "Issue 80" begin + m = 1000 + pts = rand(SVector{2,Float64}, m) + km = KernelMatrix((x, y) -> exp(-norm(x - y)), pts, pts) |> Hermitian + hk = assemble_hmatrix(km; atol = 1e-10) + hkf = cholesky(hk; atol = 1e-10) + full = Matrix(hk) + y = rand(m) + x_exact = full \ y + x_approx = hkf \ y + norm(x_exact - x_approx, Inf) / norm(x_exact, Inf) < 1e-8 end