diff --git a/Project.toml b/Project.toml index b13be4aa..9e61a2e4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.5.0" +version = "0.5.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/BlockSparseArrays.jl b/src/BlockSparseArrays.jl index 78ce8237..ffbfe937 100644 --- a/src/BlockSparseArrays.jl +++ b/src/BlockSparseArrays.jl @@ -44,5 +44,6 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl") # factorizations include("factorizations/svd.jl") +include("factorizations/truncation.jl") end diff --git a/src/factorizations/svd.jl b/src/factorizations/svd.jl index 4f64ff57..f83c11b3 100644 --- a/src/factorizations/svd.jl +++ b/src/factorizations/svd.jl @@ -4,7 +4,8 @@ using MatrixAlgebraKit: MatrixAlgebraKit, svd_compact!, svd_full! BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm) A wrapper for `MatrixAlgebraKit.AbstractAlgorithm` that implements the wrapped algorithm on -a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted block-diagonal matrix. +a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or +a block permuted block-diagonal matrix. """ struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <: MatrixAlgebraKit.AbstractAlgorithm diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl new file mode 100644 index 00000000..0029c7a1 --- /dev/null +++ b/src/factorizations/truncation.jl @@ -0,0 +1,102 @@ +using MatrixAlgebraKit: TruncationStrategy, diagview, svd_trunc! + +function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T} + D = BlockSparseVector{T}(undef, axes(A, 1)) + for I in eachblockstoredindex(A) + if ==(Int.(Tuple(I))...) + D[Tuple(I)[1]] = diagview(A[I]) + end + end + return D +end + +""" + BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy) + +A wrapper for `TruncationStrategy` that implements the wrapped strategy on a block-by-block +basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted +block-diagonal matrix. +""" +struct BlockPermutedDiagonalTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy + strategy::T +end + +const TBlockUSVᴴ = Tuple{ + <:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix +} + +function MatrixAlgebraKit.truncate!( + ::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, strategy::TruncationStrategy +) + # TODO assert blockdiagonal + return MatrixAlgebraKit.truncate!( + svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy) + ) +end + +# cannot use regular slicing here: I want to slice without altering blockstructure +# solution: use boolean indexing and slice the mask, effectively cheaply inverting the map +function MatrixAlgebraKit.findtruncated( + values::AbstractVector, strategy::BlockPermutedDiagonalTruncationStrategy +) + ind = MatrixAlgebraKit.findtruncated(values, strategy.strategy) + indexmask = falses(length(values)) + indexmask[ind] .= true + return indexmask +end + +function MatrixAlgebraKit.truncate!( + ::typeof(svd_trunc!), + (U, S, Vᴴ)::TBlockUSVᴴ, + strategy::BlockPermutedDiagonalTruncationStrategy, +) + indexmask = MatrixAlgebraKit.findtruncated(diagview(S), strategy) + + # first determine the block structure of the output to avoid having assumptions on the + # data structures + ax = axes(S, 1) + counter = Base.Fix1(count, Base.Fix1(getindex, indexmask)) + Slengths = filter!(>(0), map(counter, blocks(ax))) + Sax = blockedrange(Slengths) + Ũ = similar(U, axes(U, 1), Sax) + S̃ = similar(S, Sax, Sax) + Ṽᴴ = similar(Vᴴ, Sax, axes(Vᴴ, 2)) + + # then loop over the blocks and assign the data + # TODO: figure out if we can presort and loop over the blocks - + # for now this has issues with missing blocks + bI_Us = collect(eachblockstoredindex(U)) + bI_Ss = collect(eachblockstoredindex(S)) + bI_Vᴴs = collect(eachblockstoredindex(Vᴴ)) + + I′ = 0 # number of skipped blocks that got fully truncated + for I in 1:blocksize(ax, 1) + b = ax[Block(I)] + mask = indexmask[b] + + if !any(mask) + I′ += 1 + continue + end + + bU_id = @something findfirst(x -> last(Tuple(x)) == Block(I), bI_Us) error( + "No U-block found for $I" + ) + bU = Tuple(bI_Us[bU_id]) + Ũ[bU[1], bU[2] - Block(I′)] = view(U, bU...)[:, mask] + + bVᴴ_id = @something findfirst(x -> first(Tuple(x)) == Block(I), bI_Vᴴs) error( + "No Vᴴ-block found for $I" + ) + bVᴴ = Tuple(bI_Vᴴs[bVᴴ_id]) + Ṽᴴ[bVᴴ[1] - Block(I′), bVᴴ[2]] = view(Vᴴ, bVᴴ...)[mask, :] + + bS_id = findfirst(x -> last(Tuple(x)) == Block(I), bI_Ss) + if !isnothing(bS_id) + bS = Tuple(bI_Ss[bS_id]) + S̃[(bS .- Block(I′))...] = Diagonal(diagview(view(S, bS...))[mask]) + end + end + + return Ũ, S̃, Ṽᴴ +end diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 7a28fdb7..5152237b 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,6 +1,6 @@ using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex -using MatrixAlgebraKit: svd_compact, svd_full +using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc, truncrank, trunctol using LinearAlgebra: LinearAlgebra using Random: Random using Test: @inferred, @testset, @test @@ -83,3 +83,74 @@ end usv = svd_full(c) @test test_svd(c, usv; full=true) end + +# svd_trunc! +# ---------- + +@testset "svd_trunc ($m, $n) BlockSparseMatri{$T}" for ((m, n), T) in test_params + a = BlockSparseArray{T}(undef, m, n) + + # test blockdiagonal + for i in LinearAlgebra.diagind(blocks(a)) + I = CartesianIndices(blocks(a))[i] + a[Block(I.I...)] = rand(T, size(blocks(a)[i])) + end + + minmn = min(size(a)...) + r = max(1, minmn - 2) + trunc = truncrank(r) + + U1, S1, V1ᴴ = svd_trunc(a; trunc) + U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc) + @test size(U1) == size(U2) + @test size(S1) == size(S2) + @test size(V1ᴴ) == size(V2ᴴ) + @test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ + + @test (U1' * U1 ≈ LinearAlgebra.I) + @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I) + + atol = minimum(LinearAlgebra.diag(S1)) + 10 * eps(real(T)) + trunc = trunctol(atol) + + U1, S1, V1ᴴ = svd_trunc(a; trunc) + U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc) + @test size(U1) == size(U2) + @test size(S1) == size(S2) + @test size(V1ᴴ) == size(V2ᴴ) + @test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ + + @test (U1' * U1 ≈ LinearAlgebra.I) + @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I) + + # test permuted blockdiagonal + perm = Random.randperm(length(m)) + b = a[Block.(perm), Block.(1:length(n))] + for trunc in (truncrank(r), trunctol(atol)) + U1, S1, V1ᴴ = svd_trunc(b; trunc) + U2, S2, V2ᴴ = svd_trunc(Matrix(b); trunc) + @test size(U1) == size(U2) + @test size(S1) == size(S2) + @test size(V1ᴴ) == size(V2ᴴ) + @test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ + + @test (U1' * U1 ≈ LinearAlgebra.I) + @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I) + end + + # test permuted blockdiagonal with missing row/col + I_removed = rand(eachblockstoredindex(b)) + c = copy(b) + delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed)))) + for trunc in (truncrank(r), trunctol(atol)) + U1, S1, V1ᴴ = svd_trunc(c; trunc) + U2, S2, V2ᴴ = svd_trunc(Matrix(c); trunc) + @test size(U1) == size(U2) + @test size(S1) == size(S2) + @test size(V1ᴴ) == size(V2ᴴ) + @test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ + + @test (U1' * U1 ≈ LinearAlgebra.I) + @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I) + end +end