Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use jnp.interp #129

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

ASKabalan
Copy link

Using jax numpy interpolate instead of the custom code that was made before jnp.interp was implemented

Notebook proving that jnp.interp is more accurate and much faster

@jecampagne
Copy link
Collaborator

Although if I agree to use jax.numpy.interp as it is now implemented in the JAX lib. I was curious to see where jax-cosmo fails and to cure the problem.

Here is my 1-cent (in the context of jax-cosmo one would use the decorator and switch jnp an np due to the jax-ccsmo naming convention (ie. np would not be the numpy lib)

#@functools.partial(jax.vmap, in_axes=(0, None, None))
def interp_modif(x, xp, fp):
    """
    Simple equivalent of np.interp that compute a linear interpolation.

    We are not doing any checks, so make sure your query points are lying
    inside the array.

    x, xp, fp need to be 1d arrays
    """

    x = jnp.atleast_1d(x)

    # First we find the nearest neighbour
    ind = jnp.argmin((x - xp) ** 2)

    # Perform linear interpolation
    ind = jnp.clip(ind, 0, len(xp) - 2)

    xi = xp[ind]


    # Figure out if we are on the right or the left of nearest
    s = jnp.sign(jnp.clip(x, xp[0], xp[-2]) - xi)
    s =jax.lax.convert_element_type(s,jnp.int32)

    one = jnp.copysign(1, s)
    one = jax.lax.convert_element_type(one,jnp.int32)
    
    a = (fp[ind + one] - fp[ind]) / (
        xp[ind + one] - xp[ind]
    )
    b = fp[ind] - a * xp[ind]
    return jnp.squeeze(a * x + b)

The failure comes essentialy from the two clipping lower bounds. I have also remove the casting to int64.
You can see the result in the Google nb: https://colab.research.google.com/drive/1QhFG-G0J8Tyq9YPUuxdojvJdNEaDGoVi?usp=sharing

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants