Skip to content

Commit 877d5b8

Browse files
committed
Add initial truncation implementation
1 parent 6702176 commit 877d5b8

File tree

3 files changed

+118
-1
lines changed

3 files changed

+118
-1
lines changed

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,6 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")
4444

4545
# factorizations
4646
include("factorizations/svd.jl")
47+
include("factorizations/truncation.jl")
4748

4849
end

src/factorizations/truncation.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using MatrixAlgebraKit: TruncationStrategy, diagview
2+
3+
const TBlockUSVᴴ = Tuple{
4+
<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix
5+
}
6+
7+
function MatrixAlgebraKit.truncate!(
8+
::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, strategy::TruncationStrategy
9+
)
10+
ind = MatrixAlgebraKit.findtruncated(diagview(S), strategy)
11+
# cannot use regular slicing here: I want to slice without altering blockstructure
12+
# solution: use boolean indexing and slice the mask, effectively cheaply inverting the map
13+
indexmask = falses(size(S, 1))
14+
indexmask[ind] .= true
15+
16+
# first determine the block structure of the output to avoid having assumptions on the
17+
# data structures
18+
ax = axes(S, 1)
19+
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
20+
Slengths = filter!(>(0), map(counter, blocks(ax)))
21+
Sax = blockedrange(Slengths)
22+
= similar(U, axes(U, 1), Sax)
23+
= similar(S, Sax, Sax)
24+
Ṽᴴ = similar(Vᴴ, Sax, axes(Vᴴ, 2))
25+
26+
# then loop over the blocks and assign the data
27+
# TODO: figure out if we can presort and loop over the blocks -
28+
# for now this has issues with missing blocks
29+
bI_Us = collect(eachblockstoredindex(U))
30+
bI_Ss = collect(eachblockstoredindex(S))
31+
bI_Vᴴs = collect(eachblockstoredindex(Vᴴ))
32+
33+
I′ = 0 # number of skipped blocks that got fully truncated
34+
for (I, b) in enumerate(blocks(ax))
35+
mask = indexmask[b]
36+
37+
if !any(mask)
38+
I′ += 1
39+
continue
40+
end
41+
42+
bU_id = @something findfirst(x -> last(Tuple(x)) == Block(I), bI_Us) error(
43+
"No U-block found for $I"
44+
)
45+
bU = Tuple(bI_Us[bU_id])
46+
Ũ[bU[1], bU[2] - Block(I′)] = view(U, bU...)[:, mask]
47+
48+
bVᴴ_id = @something findfirst(x -> first(Tuple(x)) == Block(I), bI_Vᴴs) error(
49+
"No Vᴴ-block found for $I"
50+
)
51+
bVᴴ = Tuple(bI_Vᴴs[bVᴴ_id])
52+
Ṽᴴ[bVᴴ[1] - Block(I′), bVᴴ[2]] = view(Vᴴ, bVᴴ...)[mask, :]
53+
54+
bS_id = @something findfirst(x -> last(Tuple(x)) == Block(I), bI_Ss) error(
55+
"No S-block found for $I"
56+
)
57+
bS = Tuple(bI_Ss[bS_id])
58+
S̃[(bS .- Block(I′))...] = Diagonal(diagview(view(S, bS...))[mask])
59+
end
60+
61+
return Ũ, S̃, Ṽᴴ
62+
end

test/test_factorizations.jl

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
22
using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
3-
using MatrixAlgebraKit: svd_compact, svd_full
3+
using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc, truncrank
44
using LinearAlgebra: LinearAlgebra
55
using Random: Random
66
using Test: @inferred, @testset, @test
@@ -83,3 +83,57 @@ end
8383
usv = svd_full(c)
8484
@test test_svd(c, usv; full=true)
8585
end
86+
87+
# svd_trunc!
88+
# ----------
89+
90+
@testset "svd_trunc ($m, $n) BlockSparseMatri{$T}" for ((m, n), T) in test_params
91+
(m, n), T = first(test_params)
92+
a = BlockSparseArray{T}(undef, m, n)
93+
94+
# test blockdiagonal
95+
for i in LinearAlgebra.diagind(blocks(a))
96+
I = CartesianIndices(blocks(a))[i]
97+
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
98+
end
99+
100+
minmn = min(size(a)...)
101+
r = max(1, minmn - 2)
102+
103+
U1, S1, V1ᴴ = svd_trunc(a; trunc=truncrank(r))
104+
U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc=truncrank(r))
105+
@test size(U1) == size(U2)
106+
@test size(S1) == size(S2)
107+
@test size(V1ᴴ) == size(V2ᴴ)
108+
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ
109+
110+
@test (U1' * U1 LinearAlgebra.I)
111+
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
112+
113+
# test permuted blockdiagonal
114+
perm = Random.randperm(length(m))
115+
b = a[Block.(perm), Block.(1:length(n))]
116+
U1, S1, V1ᴴ = svd_trunc(b; trunc=truncrank(r))
117+
U2, S2, V2ᴴ = svd_trunc(Matrix(b); trunc=truncrank(r))
118+
@test size(U1) == size(U2)
119+
@test size(S1) == size(S2)
120+
@test size(V1ᴴ) == size(V2ᴴ)
121+
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ
122+
123+
@test (U1' * U1 LinearAlgebra.I)
124+
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
125+
126+
# test permuted blockdiagonal with missing row/col
127+
I_removed = rand(eachblockstoredindex(b))
128+
c = copy(b)
129+
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
130+
U1, S1, V1ᴴ = svd_trunc(c; trunc=truncrank(r))
131+
U2, S2, V2ᴴ = svd_trunc(Matrix(c); trunc=truncrank(r))
132+
@test size(U1) == size(U2)
133+
@test size(S1) == size(S2)
134+
@test size(V1ᴴ) == size(V2ᴴ)
135+
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ
136+
137+
@test (U1' * U1 LinearAlgebra.I)
138+
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
139+
end

0 commit comments

Comments
 (0)