diff --git a/jax_cosmo/scipy/interpolate.py b/jax_cosmo/scipy/interpolate.py index 40ad80d..ab75e2a 100644 --- a/jax_cosmo/scipy/interpolate.py +++ b/jax_cosmo/scipy/interpolate.py @@ -9,11 +9,11 @@ from jax.numpy import zeros from jax.tree_util import register_pytree_node_class -__all__ = ["interp"] +__all__ = ["interp", "interp_legacy"] @functools.partial(vmap, in_axes=(0, None, None)) -def interp(x, xp, fp): +def interp_legacy(x, xp, fp): """ Simple equivalent of np.interp that compute a linear interpolation. @@ -40,6 +40,15 @@ def interp(x, xp, fp): return a * x + b +def interp(x, xp, fp, left=None, right=None, period=None): + """ + Calling the jax implementation of interp + + x, xp, fp need to be 1d arrays + """ + return np.interp(x, xp, fp, left=left, right=right, period=period) + + @register_pytree_node_class class InterpolatedUnivariateSpline(object): def __init__(self, x, y, k=3, endpoints="not-a-knot", coefficients=None): diff --git a/tests/test_spline.py b/tests/test_spline.py index 725039d..39734cb 100644 --- a/tests/test_spline.py +++ b/tests/test_spline.py @@ -1,6 +1,6 @@ # This module tests the InterpolatedUnivariateSpline implementation against # SciPy -from jax.config import config +from jax import config config.update("jax_enable_x64", True) import jax.numpy as np