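This interpolation call crashes with a TracerArrayConversionError, while the equivalent call with numpy.interp works fine:
jc.scipy.interpolate.interp(numpy.array([0.5,1.5]),numpy.array([0.,1.,2.]),numpy.array([0.,1.,2.]))
Unfortunately my JAX fluency is poor, so I couldn't immediately solve this.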
---------------------------------------------------------------------------
TracerArrayConversionError Traceback (most recent call last)
Cell In[40], line 1
----> 1 jc.scipy.interpolate.interp(numpy.array([0.5,1.5]),numpy.array([0.,1.,2.]),numpy.array([0.,1.,2.]))
[... skipping hidden 3 frame]
File ~/opt/anaconda3/envs/unity3/lib/python3.12/site-packages/jax_cosmo/scipy/interpolate.py:33, in interp(x, xp, fp)
30 # Perform linear interpolation
31 ind = np.clip(ind, 1, len(xp) - 2)
---> 33 xi = xp[ind]
34 # Figure out if we are on the right or the left of nearest
35 s = np.sign(np.clip(x, xp[1], xp[-2]) - xi).astype(np.int64)
File ~/opt/anaconda3/envs/unity3/lib/python3.12/site-packages/jax/_src/core.py:710, in Tracer.__array__(self, *args, **kw)
709 def __array__(self, *args, **kw):
--> 710 raise TracerArrayConversionError(self)
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[].
This BatchTracer with object id 5183025136 was created on line:
/var/folders/91/bt9dzsj545130th75px54m0h0000gq/T/ipykernel_25064/3155834959.py:1 (<module>)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
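The traceback stops at xi = xp[ind]. interp is batched over x with vmap (hence the BatchTracer), so ind is a tracer. However, xp is not batched and reaches the function as the original numpy.ndarray, and NumPy's own indexing cannot accept a tracer. Below is a minimal sketch of that mechanism that does not depend on jax_cosmo. The names table, idx, lookup_numpy and lookup_jax are illustrative only.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# Host-side NumPy lookup table, like the xp/fp arrays passed in the issue.
table = np.array([0.0, 1.0, 2.0])
idx = jnp.array([0, 2])


# in_axes=None leaves `xp` unbatched, so vmap passes the numpy.ndarray
# through unchanged while `i` becomes a BatchTracer.
@partial(jax.vmap, in_axes=(0, None))
def lookup_numpy(i, xp):
    # NumPy cannot use the tracer as an integer index, so it tries to
    # convert it with __array__, which JAX forbids.
    return xp[i]


@partial(jax.vmap, in_axes=(0, None))
def lookup_jax(i, xp):
    # After conversion, JAX performs the indexing itself, so a traced
    # index is fine.
    return jnp.asarray(xp)[i]


try:
    lookup_numpy(idx, table)
    numpy_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_failed = True

print(numpy_failed)
print(lookup_jax(idx, table))
```

The same reasoning applies to the unpatched library. Passing jnp.asarray(xp) and jnp.asarray(fp) in place of the NumPy arrays should avoid the error, because xp and fp then arrive inside interp as JAX arrays.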
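A possible fix is to convert xp and fp to JAX arrays with jnp.asarray before they are indexed with the traced index ind:

import functools

import jax
import jax.numpy as jnp
import numpy as np


# x is a single query point: the function is batched over x with vmap
# (the source of the BatchTracer in the traceback), while xp and fp are shared.
@functools.partial(jax.vmap, in_axes=(0, None, None))
def interp(x, xp, fp):
    """
    Simple equivalent of np.interp that computes a linear interpolation.
    We are not doing any checks, so make sure your query points lie
    inside the range of xp.
    TODO: Implement proper interpolation!
    x, xp, fp need to be 1d arrays
    """
    # Convert possibly-NumPy inputs so they can be indexed with a tracer
    xp = jnp.asarray(xp)
    fp = jnp.asarray(fp)
    # First we find the nearest neighbour
    ind = jnp.argmin((x - xp) ** 2)
    # Perform linear interpolation
    ind = jnp.clip(ind, 1, len(xp) - 2)
    xi = xp[ind]
    # Figure out if we are on the right or the left of nearest
    s = jnp.sign(jnp.clip(x, xp[1], xp[-2]) - xi).astype(np.int64)
    # Index of the neighbour on that side (s == 0 picks the right neighbour)
    nb = ind + jnp.copysign(1, s).astype(np.int64)
    a = (fp[nb] - fp[ind]) / (xp[nb] - xp[ind])
    b = fp[ind] - a * xp[ind]
    return a * x + b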
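If the jax_cosmo helper itself is not required, jax.numpy.interp is another option. It accepts NumPy inputs directly and follows numpy.interp's conventions. The sketch below assumes only jax and numpy are installed. It checks jax.numpy.interp against numpy.interp on the inputs from this issue.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Inputs from the issue, as plain NumPy arrays.
x = np.array([0.5, 1.5])
xp = np.array([0.0, 1.0, 2.0])
fp = np.array([0.0, 1.0, 2.0])

# jnp.interp converts its NumPy arguments itself and, like np.interp,
# clamps to fp[0] / fp[-1] outside [xp[0], xp[-1]] by default.
y_jax = jnp.interp(x, xp, fp)
y_ref = np.interp(x, xp, fp)
print(y_jax, y_ref)

# Unlike np.interp, it can be differentiated with respect to the query point.
slope = jax.grad(lambda q: jnp.interp(q, xp, fp))(0.5)
print(slope)
```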