Skip to content

Commit 67199c8

Browse files
committed
Introduct QR flag to gauge_fixed_svd as well
1 parent dc86911 commit 67199c8

File tree

1 file changed

+30
-21
lines changed

1 file changed

+30
-21
lines changed

varipeps/utils/svd.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,33 @@ def _H(x):
2424
return jnp.conj(_T(x))
2525

2626

27-
@custom_jvp
28-
def svd_wrapper(a):
27+
@partial(custom_jvp, nondiff_argnums=(1,))
28+
def svd_wrapper(a, use_qr=False):
2929
check_arraylike("jnp.linalg.svd", a)
3030
(a,) = promote_dtypes_inexact(jnp.asarray(a))
3131

32-
result = lax_svd(a, full_matrices=False, compute_uv=True)
33-
34-
result = lax.cond(
35-
jnp.isnan(jnp.sum(result[1])),
36-
lambda matrix, _: lax_svd(
37-
matrix,
32+
if use_qr:
33+
result = lax_svd(
34+
a,
3835
full_matrices=False,
3936
compute_uv=True,
4037
algorithm=lax.linalg.SvdAlgorithm.QR,
41-
),
42-
lambda _, res: res,
43-
a,
44-
result,
45-
)
38+
)
39+
else:
40+
result = lax_svd(a, full_matrices=False, compute_uv=True)
41+
42+
result = lax.cond(
43+
jnp.isnan(jnp.sum(result[1])),
44+
lambda matrix, _: lax_svd(
45+
matrix,
46+
full_matrices=False,
47+
compute_uv=True,
48+
algorithm=lax.linalg.SvdAlgorithm.QR,
49+
),
50+
lambda _, res: res,
51+
a,
52+
result,
53+
)
4654

4755
return result
4856

@@ -51,10 +59,10 @@ def _svd_jvp_rule_impl(primals, tangents, only_u_or_vt=None, use_qr=False):
5159
(A,) = primals
5260
(dA,) = tangents
5361

54-
if use_qr:
62+
if use_qr and only_u_or_vt is not None:
5563
U, s, Vt = _svd_only_u_vt_impl(A, u_or_vt=2, use_qr=True)
5664
else:
57-
U, s, Vt = svd_wrapper(A)
65+
U, s, Vt = svd_wrapper(A, use_qr=use_qr)
5866

5967
Ut, V = _H(U), _H(Vt)
6068
s_dim = s[..., None, :]
@@ -106,8 +114,8 @@ def _svd_jvp_rule_impl(primals, tangents, only_u_or_vt=None, use_qr=False):
106114

107115

108116
@svd_wrapper.defjvp
109-
def _svd_jvp_rule(primals, tangents):
110-
return _svd_jvp_rule_impl(primals, tangents)
117+
def _svd_jvp_rule(use_qr, primals, tangents):
118+
return _svd_jvp_rule_impl(primals, tangents, use_qr=use_qr)
111119

112120

113121
jax.ffi.register_ffi_target(
@@ -293,10 +301,11 @@ def _svd_only_vt_jvp_rule(use_qr, primals, tangents):
293301
return _svd_jvp_rule_impl(primals, tangents, only_u_or_vt="Vt", use_qr=use_qr)
294302

295303

296-
@partial(jit, inline=True, static_argnums=(1,))
304+
@partial(jit, inline=True, static_argnums=(1, 2))
297305
def gauge_fixed_svd(
298306
matrix: jnp.ndarray,
299307
only_u_or_vh=None,
308+
use_qr=False,
300309
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
301310
"""
302311
Calculate the gauge-fixed (also called sign-fixed) SVD. To this end, each
@@ -316,13 +325,13 @@ def gauge_fixed_svd(
316325
Tuple with sign-fixed U, S and Vh of the SVD.
317326
"""
318327
if only_u_or_vh is None:
319-
U, S, Vh = svd_wrapper(matrix)
328+
U, S, Vh = svd_wrapper(matrix, use_qr=use_qr)
320329
gauge_unitary = U
321330
elif only_u_or_vh == "U":
322-
U, S = svd_only_u(matrix)
331+
U, S = svd_only_u(matrix, use_qr=use_qr)
323332
gauge_unitary = U
324333
elif only_u_or_vh == "Vh":
325-
S, Vh = svd_only_vt(matrix)
334+
S, Vh = svd_only_vt(matrix, use_qr=use_qr)
326335
gauge_unitary = Vh.T.conj()
327336
else:
328337
raise ValueError("Invalid value for parameter 'only_u_or_vh'.")

0 commit comments

Comments
 (0)