
.at[] indexing doesn't work with brainpy array #743


Open
llandsmeer opened this issue Apr 21, 2025 · 5 comments
Labels
bug Something isn't working

Comments

@llandsmeer

Array.at[arr1].add(arr2) crashes

import brainpy.math as bp
x = bp.arange(10)
x.at[x].add(x)

Expected output:

Array([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18], dtype=int32)

Actual output:

File ~/.local/lib/python3.10/site-packages/jax/_src/numpy/indexing.py:817, in index_to_gather(x_shape, idx, normalize_indices)
    814   if normalize_indices:
    815     advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
    816                       for e, i, j in advanced_pairs)
--> 817   advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
    819 x_axis = 0  # Current axis in x.
    820 y_axis = 0  # Current axis in y, before collapsing. See below.

ValueError: not enough values to unpack (expected 3, got 0)

The error seems to originate from the index array being interpreted as a single value instead of an array.

Casting the index to a JAX array works as an interim workaround, but is of course not desirable:

[ins] In [12]: import jax.numpy as jnp
          ...: import brainpy.math as bp
          ...: x = bp.arange(10)
          ...: x.at[jnp.array(x)].add(x)
Out[12]: Array([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18], dtype=int32)

Versions:

pip freeze | grep -E 'jax|brainpy|taichi'
brainpy==2.6.0.post20250420
jax==0.5.3
jaxlib==0.5.3
jaxtyping==0.3.1
taichi==1.7.3

Upgrading to jax==0.6 doesn't help.

@Routhleck
Member

Thanks for your feedback!
Maybe downgrading JAX will work?

pip install jax==0.4.38

@llandsmeer
Author

Still the same exception with jax==0.4.38 (it is now raised at lax_numpy.py:12096, but the code is the same).

@Routhleck
Member

I apologize for the confusion. I don't believe this is a JAX compatibility issue. Since BrainPy has further encapsulated JAX's Array implementation, you can directly implement the operations mentioned above using the following code. I hope this helps answer your question.

import brainpy.math as bp

x = bp.arange(10)
x[x] += x

@llandsmeer
Author

Thank you for clarifying this. However, the behaviour differs when there are duplicate entries in the index array:

import jax
import brainpy.math as bp

x = bp.arange(10)
idx = bp.zeros(10, dtype=int)

print(x.at[jax.numpy.array(idx)].add(x))
# prints [45  1  2  3  4  5  6  7  8  9]

x[idx] += x
print(x)
# prints [9, 1, 2, 3, 4, 5, 6, 7, 8, 9]

We very much want the first behaviour
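
To make the difference concrete, here is a minimal plain-Python sketch of the two update semantics. It only models the behaviour observed above and is not BrainPy's or JAX's actual implementation; the function names `scatter_add` and `gather_add_assign` are hypothetical and chosen for illustration. The accumulating variant corresponds to what `.at[idx].add(...)` does in JAX (and `np.add.at` in NumPy), while the second variant models `x[idx] += values`, where the values are read once and later writes to the same index overwrite earlier ones.

```python
def scatter_add(x, idx, values):
    """Accumulating update: every occurrence of an index contributes."""
    out = list(x)
    for i, v in zip(idx, values):
        out[i] += v
    return out


def gather_add_assign(x, idx, values):
    """Model of `x[idx] += values`: gather once, add, then write back."""
    out = list(x)
    gathered = [x[i] for i in idx]                       # read the original values
    updated = [g + v for g, v in zip(gathered, values)]  # add element-wise
    for i, u in zip(idx, updated):                       # a later write overwrites an earlier one
        out[i] = u
    return out


x = list(range(10))
idx = [0] * 10  # ten duplicate indices, all pointing at position 0

print(scatter_add(x, idx, x))        # [45, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(gather_add_assign(x, idx, x))  # [9, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

With duplicate indices only the accumulating variant sums all contributions into position 0, which is the behaviour requested here.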

@Routhleck
Member

Sorry for the late reply.
For now, the only way to achieve this may be to convert the indices first, using bm.as_jax() or jax.numpy.array():

import brainpy.math as bm

x = bm.arange(10)
idx = bm.zeros(10, dtype=int)

x = x.at[bm.as_jax(idx)].add(x)
print(x)
# prints [45  1  2  3  4  5  6  7  8  9]


2 participants