Skip to content

Commit c3ad006

Browse files
committed
Ensure that the inverse for the non-square case of the SVD is non-zero as well
1 parent 5d51aa0 commit c3ad006

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

varipeps/utils/svd.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ def _svd_jvp_rule_impl(primals, tangents, only_u_or_vt=None, use_qr=False):
9292
m, n = A.shape[-2:]
9393
if m > n and (only_u_or_vt is None or only_u_or_vt == "U"):
9494
dAV = dA @ V
95-
dU = dU + (dAV - U @ (Ut @ dAV)) / s_dim.astype(A.dtype)
95+
dU = dU + (dAV - U @ (Ut @ dAV)) * s_inv.astype(A.dtype)
9696
if n > m and (only_u_or_vt is None or only_u_or_vt == "Vt"):
9797
dAHU = _H(dA) @ U
98-
dV = dV + (dAHU - V @ (Vt @ dAHU)) / s_dim.astype(A.dtype)
98+
dV = dV + (dAHU - V @ (Vt @ dAHU)) * s_inv.astype(A.dtype)
9999

100100
if only_u_or_vt is None:
101101
return (U, s, Vt), (dU, ds, _H(dV))
@@ -293,10 +293,9 @@ def _svd_only_vt_jvp_rule(use_qr, primals, tangents):
293293
return _svd_jvp_rule_impl(primals, tangents, only_u_or_vt="Vt", use_qr=use_qr)
294294

295295

296-
@partial(jit, inline=True)
296+
@partial(jit, inline=True, static_argnums=(1,))
297297
def gauge_fixed_svd(
298298
matrix: jnp.ndarray,
299-
*,
300299
only_u_or_vh=None,
301300
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
302301
"""

0 commit comments

Comments
 (0)