
interpol crash #121

Open
AlexGKim opened this issue Feb 20, 2024 · 1 comment

@AlexGKim

This interpolation call crashes, while the equivalent call with numpy.interp works fine. Unfortunately my JAX fluency is poor, so I couldn't immediately solve this.

jc.scipy.interpolate.interp(numpy.array([0.5,1.5]),numpy.array([0.,1.,2.]),numpy.array([0.,1.,2.]))

---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[40], line 1
----> 1 jc.scipy.interpolate.interp(numpy.array([0.5,1.5]),numpy.array([0.,1.,2.]),numpy.array([0.,1.,2.]))

    [... skipping hidden 3 frame]

File ~/opt/anaconda3/envs/unity3/lib/python3.12/site-packages/jax_cosmo/scipy/interpolate.py:33, in interp(x, xp, fp)
     30 # Perform linear interpolation
     31 ind = np.clip(ind, 1, len(xp) - 2)
---> 33 xi = xp[ind]
     34 # Figure out if we are on the right or the left of nearest
     35 s = np.sign(np.clip(x, xp[1], xp[-2]) - xi).astype(np.int64)

File ~/opt/anaconda3/envs/unity3/lib/python3.12/site-packages/jax/_src/core.py:710, in Tracer.__array__(self, *args, **kw)
    709 def __array__(self, *args, **kw):
--> 710   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[].
This BatchTracer with object id 5183025136 was created on line:
  /var/folders/91/bt9dzsj545130th75px54m0h0000gq/T/ipykernel_25064/3155834959.py:1 (<module>)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
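The traceback points at `xp[ind]`. `interp` runs under `jax.vmap` (hence the `BatchTracer`), so `ind` is a traced value, but `xp` is still the plain NumPy array that was passed in. When NumPy indexes with a tracer, it calls `Tracer.__array__`, and JAX refuses. Below is a minimal sketch of that mechanism that uses nothing from jax_cosmo. The names `table`, `lookup_numpy` and `lookup_jax` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical lookup table standing in for xp: a plain NumPy array.
table = np.array([0.0, 1.0, 2.0])


def lookup_numpy(i):
    # Under vmap, i is a BatchTracer; NumPy indexing tries to turn it into
    # a concrete array, which JAX refuses.
    return table[i]


def lookup_jax(i):
    # Wrapping the table in a JAX array lets JAX handle the traced index.
    return jnp.asarray(table)[i]


idx = jnp.array([0, 2])

try:
    jax.vmap(lookup_numpy)(idx)
    numpy_failed = False
except (jax.errors.TracerArrayConversionError,
        jax.errors.TracerIntegerConversionError):
    numpy_failed = True

result = jax.vmap(lookup_jax)(idx)
print(numpy_failed, result)
```

The only difference between the two functions is the `jnp.asarray` wrapper, and that is enough to avoid the error.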
@AlexGKim
Author

This implementation seems to work. It converts xp and fp with jnp.asarray, so they can be indexed with traced values.

from functools import partial

import jax
import jax.numpy as jnp


# Vectorise over the query points x (xp and fp are shared); the BatchTracer in
# the traceback shows jax_cosmo's interp runs under vmap in the same way.
@partial(jax.vmap, in_axes=(0, None, None))
def interp(x, xp, fp):
    """
    Simple equivalent of np.interp that computes a linear interpolation.

    We are not doing any checks, so make sure your query points lie
    within the range of xp.

    TODO: Implement proper interpolation!

    x, xp, fp need to be 1d arrays
    """
    # Convert the tables to JAX arrays once: indexing a plain NumPy array
    # with a traced index is what raised TracerArrayConversionError.
    xp = jnp.asarray(xp)
    fp = jnp.asarray(fp)

    # First we find the nearest neighbour
    ind = jnp.argmin((x - xp) ** 2)

    # Perform linear interpolation
    ind = jnp.clip(ind, 1, len(xp) - 2)
    xi = xp[ind]

    # Figure out if we are on the right or the left of nearest
    s = jnp.sign(jnp.clip(x, xp[1], xp[-2]) - xi).astype(jnp.int32)
    # Index of the second node: ind + 1 when s >= 0, ind - 1 when s < 0
    other = ind + jnp.copysign(1, s).astype(jnp.int32)
    a = (fp[other] - fp[ind]) / (xp[other] - xp[ind])
    b = fp[ind] - a * xp[ind]
    return a * x + b
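This alternative does not come from the thread. JAX itself ships `jax.numpy.interp`, which follows `numpy.interp` semantics and accepts NumPy inputs directly. The sketch below checks it against `numpy.interp` on the inputs from the report. It is vmapped over scalar queries, mirroring how the traceback suggests jax_cosmo calls `interp`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Inputs from the original report, kept as plain NumPy arrays.
x = np.array([0.5, 1.5])
xp = np.array([0.0, 1.0, 2.0])
fp = np.array([0.0, 1.0, 2.0])

# Reference result from NumPy.
expected = np.interp(x, xp, fp)

# jnp.interp on one scalar query at a time; xp and fp stay NumPy arrays
# and are closed over, while xi is traced by vmap.
result = jax.vmap(lambda xi: jnp.interp(xi, xp, fp))(x)

print(expected, result)
```

One design difference matters for callers. The nearest-neighbour function above extrapolates linearly outside `[xp[0], xp[-1]]`. By default, `jnp.interp` clamps to `fp[0]` and `fp[-1]` there, as `numpy.interp` does.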
