You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Remove reliance on __jax_array__ to unwrap variables. (#21719)
JAX uses `__jax_array__` to handle non-JAX types. For instance when doing `a * v` where `a` is a `jax.Array` and `v` is a `keras.Variable`, the `jax.Array.__mul__` implementation calls `v.__jax_array__()` because `v` is not a JAX type.
However, `__jax_array__` did not work in all contexts, and the next version of JAX further restricts which contexts it works in.
The fix rarely involves explictly calling `v.value`. Instead, we rely on existing mechanisms that are already in place to unwrap variables in a lot of contexts:
- ops are always supposed to call `convert_to_tensor` on tensor inputs and `convert_to_tensor` extracts values from variables
- using `keras.ops` instead of native ops (+ - * / < > & etc.) unwraps variables. It is already a best practice to use `keras.ops` instead of native ops:
- to support the creation of functional models via `KerasTensor`s and their serialization
- to have consistent type promotion between backends
- to support sparse tensors and ragged tensors
This was tested via a seperate PR #21702 that won't be submitted because of https://github.com/keras-team/keras/pull/21702/files#diff-900deadc65fc119ce93fb813e340dcb644b8eab9e7c0207bf37cdc05b8e8796e .
0 commit comments