Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use jnp.interp #129

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions jax_cosmo/scipy/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from jax.numpy import zeros
from jax.tree_util import register_pytree_node_class

__all__ = ["interp"]
__all__ = ["interp", "interp_legacy"]


@functools.partial(vmap, in_axes=(0, None, None))
def interp(x, xp, fp):
def interp_legacy(x, xp, fp):
"""
Simple equivalent of np.interp that compute a linear interpolation.

Expand All @@ -40,6 +40,15 @@ def interp(x, xp, fp):
return a * x + b


def interp(x, xp, fp, left=None, right=None, period=None):
"""
Calling the jax implementation of interp

x, xp, fp need to be 1d arrays
"""
return np.interp(x, xp, fp, left=left, right=right, period=period)


@register_pytree_node_class
class InterpolatedUnivariateSpline(object):
def __init__(self, x, y, k=3, endpoints="not-a-knot", coefficients=None):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_spline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This module tests the InterpolatedUnivariateSpline implementation against
# SciPy
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
import jax.numpy as np
Expand Down