From 71e916b9444b8815e2142da21f46c67c30118a85 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 28 Oct 2024 18:26:36 +0100 Subject: [PATCH 1/2] Use jnp.interp for the interpolation instead of custom code --- jax_cosmo/scipy/interpolate.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/jax_cosmo/scipy/interpolate.py b/jax_cosmo/scipy/interpolate.py index 40ad80d..ab75e2a 100644 --- a/jax_cosmo/scipy/interpolate.py +++ b/jax_cosmo/scipy/interpolate.py @@ -9,11 +9,11 @@ from jax.numpy import zeros from jax.tree_util import register_pytree_node_class -__all__ = ["interp"] +__all__ = ["interp", "interp_legacy"] @functools.partial(vmap, in_axes=(0, None, None)) -def interp(x, xp, fp): +def interp_legacy(x, xp, fp): """ Simple equivalent of np.interp that compute a linear interpolation. @@ -40,6 +40,15 @@ def interp(x, xp, fp): return a * x + b +def interp(x, xp, fp, left=None, right=None, period=None): + """ + Calling the jax implementation of interp + + x, xp, fp need to be 1d arrays + """ + return np.interp(x, xp, fp, left=left, right=right, period=period) + + @register_pytree_node_class class InterpolatedUnivariateSpline(object): def __init__(self, x, y, k=3, endpoints="not-a-knot", coefficients=None): From da97d2a60ce07eb6a530566c14b3c7af488b900c Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 28 Oct 2024 18:27:15 +0100 Subject: [PATCH 2/2] Update test --- tests/test_spline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_spline.py b/tests/test_spline.py index 725039d..39734cb 100644 --- a/tests/test_spline.py +++ b/tests/test_spline.py @@ -1,6 +1,6 @@ # This module tests the InterpolatedUnivariateSpline implementation against # SciPy -from jax.config import config +from jax import config config.update("jax_enable_x64", True) import jax.numpy as np