@@ -24,25 +24,33 @@ def _H(x):
24
24
return jnp .conj (_T (x ))
25
25
26
26
27
- @custom_jvp
28
- def svd_wrapper (a ):
27
+ @partial ( custom_jvp , nondiff_argnums = ( 1 ,))
28
+ def svd_wrapper (a , use_qr = False ):
29
29
check_arraylike ("jnp.linalg.svd" , a )
30
30
(a ,) = promote_dtypes_inexact (jnp .asarray (a ))
31
31
32
- result = lax_svd (a , full_matrices = False , compute_uv = True )
33
-
34
- result = lax .cond (
35
- jnp .isnan (jnp .sum (result [1 ])),
36
- lambda matrix , _ : lax_svd (
37
- matrix ,
32
+ if use_qr :
33
+ result = lax_svd (
34
+ a ,
38
35
full_matrices = False ,
39
36
compute_uv = True ,
40
37
algorithm = lax .linalg .SvdAlgorithm .QR ,
41
- ),
42
- lambda _ , res : res ,
43
- a ,
44
- result ,
45
- )
38
+ )
39
+ else :
40
+ result = lax_svd (a , full_matrices = False , compute_uv = True )
41
+
42
+ result = lax .cond (
43
+ jnp .isnan (jnp .sum (result [1 ])),
44
+ lambda matrix , _ : lax_svd (
45
+ matrix ,
46
+ full_matrices = False ,
47
+ compute_uv = True ,
48
+ algorithm = lax .linalg .SvdAlgorithm .QR ,
49
+ ),
50
+ lambda _ , res : res ,
51
+ a ,
52
+ result ,
53
+ )
46
54
47
55
return result
48
56
@@ -51,10 +59,10 @@ def _svd_jvp_rule_impl(primals, tangents, only_u_or_vt=None, use_qr=False):
51
59
(A ,) = primals
52
60
(dA ,) = tangents
53
61
54
- if use_qr :
62
+ if use_qr and only_u_or_vt is not None :
55
63
U , s , Vt = _svd_only_u_vt_impl (A , u_or_vt = 2 , use_qr = True )
56
64
else :
57
- U , s , Vt = svd_wrapper (A )
65
+ U , s , Vt = svd_wrapper (A , use_qr = use_qr )
58
66
59
67
Ut , V = _H (U ), _H (Vt )
60
68
s_dim = s [..., None , :]
@@ -106,8 +114,8 @@ def _svd_jvp_rule_impl(primals, tangents, only_u_or_vt=None, use_qr=False):
106
114
107
115
108
116
@svd_wrapper .defjvp
109
- def _svd_jvp_rule (primals , tangents ):
110
- return _svd_jvp_rule_impl (primals , tangents )
117
+ def _svd_jvp_rule (use_qr , primals , tangents ):
118
+ return _svd_jvp_rule_impl (primals , tangents , use_qr = use_qr )
111
119
112
120
113
121
jax .ffi .register_ffi_target (
@@ -293,10 +301,11 @@ def _svd_only_vt_jvp_rule(use_qr, primals, tangents):
293
301
return _svd_jvp_rule_impl (primals , tangents , only_u_or_vt = "Vt" , use_qr = use_qr )
294
302
295
303
296
- @partial (jit , inline = True , static_argnums = (1 ,))
304
+ @partial (jit , inline = True , static_argnums = (1 , 2 ))
297
305
def gauge_fixed_svd (
298
306
matrix : jnp .ndarray ,
299
307
only_u_or_vh = None ,
308
+ use_qr = False ,
300
309
) -> Tuple [jnp .ndarray , jnp .ndarray , jnp .ndarray ]:
301
310
"""
302
311
Calculate the gauge-fixed (also called sign-fixed) SVD. To this end, each
@@ -316,13 +325,13 @@ def gauge_fixed_svd(
316
325
Tuple with sign-fixed U, S and Vh of the SVD.
317
326
"""
318
327
if only_u_or_vh is None :
319
- U , S , Vh = svd_wrapper (matrix )
328
+ U , S , Vh = svd_wrapper (matrix , use_qr = use_qr )
320
329
gauge_unitary = U
321
330
elif only_u_or_vh == "U" :
322
- U , S = svd_only_u (matrix )
331
+ U , S = svd_only_u (matrix , use_qr = use_qr )
323
332
gauge_unitary = U
324
333
elif only_u_or_vh == "Vh" :
325
- S , Vh = svd_only_vt (matrix )
334
+ S , Vh = svd_only_vt (matrix , use_qr = use_qr )
326
335
gauge_unitary = Vh .T .conj ()
327
336
else :
328
337
raise ValueError ("Invalid value for parameter 'only_u_or_vh'." )
0 commit comments