Skip to content

Commit

Permalink
Fix jax numpy deprecation
Browse files Browse the repository at this point in the history
`jax.numpy.product` was deprecated in jax 0.4.12 and removed in 0.4.16 in favor of `jax.numpy.prod`.

See changelog
https://jax.readthedocs.io/en/latest/changelog.html
  • Loading branch information
aboucaud committed Apr 29, 2024
1 parent 84db7be commit b0fca8f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax_cosmo/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def _block_det(sparse, k, N, P):
v = sparse[k + 1 : N, k : k + 1, 0:P]
Sinv_v = sparse_dot_sparse(inv(S), v)
M = sparse[k, k] - sparse_dot_sparse(u, Sinv_v)
sign = np.product(np.sign(M))
sign = np.prod(np.sign(M))
logdet = np.sum(np.log(np.abs(M)))
return sign, logdet

Expand Down Expand Up @@ -354,7 +354,7 @@ def slogdet(sparse):
"""
sparse = check_sparse(sparse, square=True)
N, _, P = sparse.shape
sign = np.product(np.sign(sparse[-1, -1]))
sign = np.prod(np.sign(sparse[-1, -1]))
logdet = np.sum(np.log(np.abs(sparse[-1, -1])))
# The individual blocks can be calculated in any order so there
# should be a better way to express this using lax.map but I
Expand Down

0 comments on commit b0fca8f

Please sign in to comment.