.at[] indexing doesn't work with brainpy array #743
Thanks for your feedback! Could you try JAX 0.4.38?

```shell
pip install jax==0.4.38
```
Still the same exception with jax==0.4.38. It is now raised at lax_numpy.py:12096, but from the same code.
I apologize for the confusion. I don't believe this is a JAX compatibility issue. BrainPy wraps JAX's Array in its own Array class, so you can perform the operation above directly with in-place indexing on the BrainPy array:

```python
import brainpy.math as bp

x = bp.arange(10)
x[x] += x
```

I hope this answers your question.
Thank you for clarifying. However, the behaviour differs when the index array contains duplicate entries:

```python
import jax.numpy as jnp
import brainpy.math as bp

x = bp.arange(10)
idx = bp.zeros(10, dtype=int)

print(x.at[jnp.array(idx)].add(x))
# prints [45 1 2 3 4 5 6 7 8 9]

x[idx] += x
print(x)
# prints [9, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

We very much want the first behaviour.
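The two results differ because of buffered versus unbuffered update semantics. The JAX documentation for `.at` notes that, unlike NumPy's in-place `x[idx] += y`, all updates to repeated indices are applied. A minimal sketch in plain NumPy, using `np.add.at` as a stand-in for JAX's `x.at[idx].add(y)`. The helper names `buffered_add` and `unbuffered_add` are hypothetical and used only for illustration:

```python
import numpy as np


def buffered_add(a, idx, vals):
    """Hypothetical helper: in-place `a[idx] += vals` on a copy.
    NumPy gathers a[idx], adds vals, then scatters back, so with repeated
    indices only one write per slot survives (in practice the last one;
    NumPy does not guarantee the order)."""
    out = a.copy()
    out[idx] += vals
    return out


def unbuffered_add(a, idx, vals):
    """Hypothetical helper: np.add.at accumulates every update, which
    matches the semantics of JAX's a.at[idx].add(vals)."""
    out = a.copy()
    np.add.at(out, idx, vals)
    return out


x = np.arange(10)

# Duplicate indices, as in the report: every update targets slot 0.
dup = np.zeros(10, dtype=int)
print(buffered_add(x, dup, x))    # slot 0 holds a single x value
print(unbuffered_add(x, dup, x))  # slot 0 holds 0 + (0 + 1 + ... + 9) = 45

# Unique indices: both variants agree, which is why x[x] += x works
# when x has no repeated entries.
uniq = np.arange(10)
print(buffered_add(x, uniq, x))   # equals 2 * x
print(unbuffered_add(x, uniq, x)) # equals 2 * x
```

With unique indices the choice does not matter. With duplicates, only the unbuffered (scatter-add) form sums the contributions.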
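Sorry for the late reply. Converting the index array to a JAX array with `bm.as_jax` gives the accumulating behaviour:

```python
import brainpy.math as bm

x = bm.arange(10)
idx = bm.zeros(10, dtype=int)
x = x.at[bm.as_jax(idx)].add(x)
print(x)
# prints [45 1 2 3 4 5 6 7 8 9]
```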
`Array.at[arr1].add(arr2)` crashes when `arr1` is a BrainPy array.
Expected output:
Actual output:
The error seems to originate from the index array being interpreted as a single value instead of an array.
Casting the index array to a JAX array works as an intermediate solution, but is of course not desired.
Versions:
Upgrading to jax==0.6 doesn't help.