|
1 | 1 | using LinearAlgebra: Diagonal, I, diag, isposdef, norm |
2 | 2 | using MatrixAlgebraKit: qr_compact, svd_trunc, truncrank |
3 | 3 | using StableRNGs: StableRNG |
4 | | -using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncdegen, truncerr |
| 4 | +using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncdegen |
5 | 5 | using Test: @test, @testset |
6 | 6 |
|
7 | 7 | elts = (Float32, Float64, ComplexF32, ComplexF64) |
@@ -152,158 +152,6 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) |
152 | 152 | @test V' * V ≈ I |
153 | 153 | @test MatrixAlgebra.svdvals(A) ≈ diag(S) |
154 | 154 | end |
155 | | - @testset "Truncation" begin |
156 | | - s = Diagonal(real(elt)[1.2, 0.9, 0.3, 0.2, 0.01]) |
157 | | - n = length(diag(s)) |
158 | | - rng = StableRNG(123) |
159 | | - u, _ = qr_compact(randn(rng, elt, n, n); positive=true) |
160 | | - v, _ = qr_compact(randn(rng, elt, n, n); positive=true) |
161 | | - a = u * s * v |
162 | | - |
163 | | - # p = 2, relative = true |
164 | | - ũ, s̃, ṽ = svd_trunc( |
165 | | - a; trunc=truncerr(; rtol=norm([0.3, 0.2, 0.01]) / norm(diag(s)) + 10eps(real(elt))) |
166 | | - ) |
167 | | - @test size(ũ) == (n, 2) |
168 | | - @test size(s̃) == (2, 2) |
169 | | - @test size(ṽ) == (2, n) |
170 | | - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.2, 0.01]) |
171 | | - ũ, s̃, ṽ = svd_trunc( |
172 | | - a; trunc=truncerr(; rtol=norm([0.3, 0.2, 0.01]) / norm(diag(s)) - 10eps(real(elt))) |
173 | | - ) |
174 | | - @test size(ũ) == (n, 3) |
175 | | - @test size(s̃) == (3, 3) |
176 | | - @test size(ṽ) == (3, n) |
177 | | - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.2, 0.01]) |
178 | | - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=0)) |
179 | | - @test size(ũ) == (n, n) |
180 | | - @test size(s̃) == (n, n) |
181 | | - @test size(ṽ) == (n, n) |
182 | | - @test ũ * s̃ * ṽ ≈ a |
183 | | - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=1)) |
184 | | - @test size(ũ) == (n, 0) |
185 | | - @test size(s̃) == (0, 0) |
186 | | - @test size(ṽ) == (0, n) |
187 | | - @test norm(ũ * s̃ * ṽ) ≈ 0 |
188 | | - |
189 | | - # p = 2, relative = false |
190 | | - ũ, s̃, ṽ = svd_trunc( |
191 | | - a; trunc=truncerr(; atol=norm([0.3, 0.2, 0.01]) + 10eps(real(elt))) |
192 | | - ) |
193 | | - @test size(ũ) == (n, 2) |
194 | | - @test size(s̃) == (2, 2) |
195 | | - @test size(ṽ) == (2, n) |
196 | | - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.2, 0.01]) |
197 | | - ũ, s̃, ṽ = svd_trunc( |
198 | | - a; trunc=truncerr(; atol=norm([0.3, 0.2, 0.01]) - 10eps(real(elt))) |
199 | | - ) |
200 | | - @test size(ũ) == (n, 3) |
201 | | - @test size(s̃) == (3, 3) |
202 | | - @test size(ṽ) == (3, n) |
203 | | - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.2, 0.01]) |
204 | | - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; atol=0)) |
205 | | - @test size(ũ) == (n, n) |
206 | | - @test size(s̃) == (n, n) |
207 | | - @test size(ṽ) == (n, n) |
208 | | - @test ũ * s̃ * ṽ ≈ a |
209 | | - ũ, s̃, ṽ = svd_trunc( |
210 | | - a; trunc=truncerr(; atol=(norm(diag(s)) * (one(real(elt)) + 10eps(real(elt))))) |
211 | | - ) |
212 | | - @test size(ũ) == (n, 0) |
213 | | - @test size(s̃) == (0, 0) |
214 | | - @test size(ṽ) == (0, n) |
215 | | - @test norm(ũ * s̃ * ṽ) ≈ 0 |
216 | | - |
217 | | - # p = 1, relative = true |
218 | | - ũ, s̃, ṽ = svd_trunc( |
219 | | - a; |
220 | | - trunc=truncerr(; |
221 | | - rtol=(norm([0.3, 0.2, 0.01], 1) / norm(diag(s), 1) + 10eps(real(elt))), p=1 |
222 | | - ), |
223 | | - ) |
224 | | - @test size(ũ) == (n, 2) |
225 | | - @test size(s̃) == (2, 2) |
226 | | - @test size(ṽ) == (2, n) |
227 | | - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.2, 0.01]) |
228 | | - ũ, s̃, ṽ = svd_trunc( |
229 | | - a; |
230 | | - trunc=truncerr(; |
231 | | - rtol=(norm([0.3, 0.2, 0.01], 1) / norm(diag(s), 1) - 10eps(real(elt))), p=1 |
232 | | - ), |
233 | | - ) |
234 | | - @test size(ũ) == (n, 3) |
235 | | - @test size(s̃) == (3, 3) |
236 | | - @test size(ṽ) == (3, n) |
237 | | - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.2, 0.01]) |
238 | | - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=0, p=1)) |
239 | | - @test size(ũ) == (n, n) |
240 | | - @test size(s̃) == (n, n) |
241 | | - @test size(ṽ) == (n, n) |
242 | | - @test ũ * s̃ * ṽ ≈ a |
243 | | - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=1, p=1)) |
244 | | - @test size(ũ) == (n, 0) |
245 | | - @test size(s̃) == (0, 0) |
246 | | - @test size(ṽ) == (0, n) |
247 | | - @test norm(ũ * s̃ * ṽ) ≈ 0 |
248 | | - |
249 | | - # p = 1, relative = false |
250 | | - ũ, s̃, ṽ = svd_trunc( |
251 | | - a; trunc=truncerr(; atol=(norm([0.3, 0.2, 0.01], 1) + 10eps(real(elt))), p=1) |
252 | | - ) |
253 | | - @test size(ũ) == (n, 2) |
254 | | - @test size(s̃) == (2, 2) |
255 | | - @test size(ṽ) == (2, n) |
256 | | - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.2, 0.01]) |
257 | | - ũ, s̃, ṽ = svd_trunc( |
258 | | - a; trunc=truncerr(; atol=(norm([0.3, 0.2, 0.01], 1) - 10eps(real(elt))), p=1) |
259 | | - ) |
260 | | - @test size(ũ) == (n, 3) |
261 | | - @test size(s̃) == (3, 3) |
262 | | - @test size(ṽ) == (3, n) |
263 | | - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.2, 0.01]) |
264 | | - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; atol=0, p=1)) |
265 | | - @test size(ũ) == (n, n) |
266 | | - @test size(s̃) == (n, n) |
267 | | - @test size(ṽ) == (n, n) |
268 | | - @test ũ * s̃ * ṽ ≈ a |
269 | | - ũ, s̃, ṽ = svd_trunc( |
270 | | - a; |
271 | | - trunc=truncerr(; atol=(norm(diag(s), 1) * (one(real(elt)) + 10eps(real(elt)))), p=1), |
272 | | - ) |
273 | | - @test size(ũ) == (n, 0) |
274 | | - @test size(s̃) == (0, 0) |
275 | | - @test size(ṽ) == (0, n) |
276 | | - @test norm(ũ * s̃ * ṽ) ≈ 0 |
277 | | - |
278 | | - # Specifying both `atol` and `rtol`. |
279 | | - s = Diagonal(real(elt)[0.1, 0.01, 0.001]) |
280 | | - n = length(diag(s)) |
281 | | - rng = StableRNG(123) |
282 | | - u, _ = qr_compact(randn(rng, elt, n, n); positive=true) |
283 | | - v, _ = qr_compact(randn(rng, elt, n, n); positive=true) |
284 | | - a = u * s * v |
285 | | - |
286 | | - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=0.002)) |
287 | | - @test size(ũ) == (n, n) |
288 | | - @test size(s̃) == (n, n) |
289 | | - @test size(ṽ) == (n, n) |
290 | | - @test ũ * s̃ * ṽ ≈ a |
291 | | - @test ũ * s̃ * ṽ ≈ a rtol = 0.002 |
292 | | - |
293 | | - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; atol=0.002)) |
294 | | - @test size(ũ) == (n, 2) |
295 | | - @test size(s̃) == (2, 2) |
296 | | - @test size(ṽ) == (2, n) |
297 | | - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.001]) |
298 | | - @test ũ * s̃ * ṽ ≈ a atol = 0.002 |
299 | | - |
300 | | - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; atol=0.002, rtol=0.002)) |
301 | | - @test size(ũ) == (n, 2) |
302 | | - @test size(s̃) == (2, 2) |
303 | | - @test size(ṽ) == (2, n) |
304 | | - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.001]) |
305 | | - @test ũ * s̃ * ṽ ≈ a atol = 0.002 rtol = 0.002 |
306 | | - end |
307 | 155 | @testset "Truncate degenerate" begin |
308 | 156 | s = Diagonal(real(elt)[2.0, 0.32, 0.3, 0.29, 0.01, 0.01]) |
309 | 157 | n = length(diag(s)) |
|
0 commit comments