
Commit 31f605e

Remove reliance on __jax_array__ to unwrap variables. (#21719)
JAX uses `__jax_array__` to handle non-JAX types. For instance, when doing `a * v` where `a` is a `jax.Array` and `v` is a `keras.Variable`, the `jax.Array.__mul__` implementation calls `v.__jax_array__()` because `v` is not a JAX type. However, `__jax_array__` did not work in all contexts, and the next version of JAX further restricts the contexts in which it works.

The fix rarely involves explicitly calling `v.value`. Instead, we rely on mechanisms that are already in place to unwrap variables in most contexts:

- ops are always supposed to call `convert_to_tensor` on tensor inputs, and `convert_to_tensor` extracts values from variables
- using `keras.ops` instead of native ops (`+`, `-`, `*`, `/`, `<`, `>`, `&`, etc.) unwraps variables

It is already a best practice to use `keras.ops` instead of native ops:

- to support the creation of functional models via `KerasTensor`s and their serialization
- to have consistent type promotion between backends
- to support sparse tensors and ragged tensors

This was tested via a separate PR #21702 that won't be submitted because of https://github.com/keras-team/keras/pull/21702/files#diff-900deadc65fc119ce93fb813e340dcb644b8eab9e7c0207bf37cdc05b8e8796e.
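As a minimal illustration of the two unwrapping mechanisms the message refers to (a sketch, not part of the diff; it assumes the JAX backend and uses made-up variable names):

```python
import jax.numpy as jnp
import keras
from keras import ops

# Hypothetical example values, just to show the pattern.
v = keras.Variable(initializer="ones", shape=(3,), name="v")
a = jnp.arange(3.0)

# Native JAX op on a Variable: `jnp` has to coerce `v` via `v.__jax_array__()`,
# which newer JAX versions restrict.
# out = a * v

# Mechanism 1: `keras.ops` functions call `convert_to_tensor` on their inputs,
# which extracts the value from the Variable.
out = ops.multiply(a, v)

# Mechanism 2: explicit conversion to a backend tensor before using native ops.
out = jnp.multiply(a, ops.convert_to_tensor(v))
```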
1 parent ce61e6b commit 31f605e

15 files changed: +84 additions, −46 deletions

keras/src/backend/jax/nn.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -396,6 +396,8 @@ def depthwise_conv(
     feature_group_count = (
         inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
     )
+    kernel = convert_to_tensor(kernel)
+    inputs = convert_to_tensor(inputs)
     kernel = jnp.reshape(
         kernel,
         kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
```

keras/src/backend/jax/numpy.py

Lines changed: 16 additions & 9 deletions
```diff
@@ -543,15 +543,18 @@ def clip(x, x_min, x_max):
 
 def concatenate(xs, axis=0):
     bcoo_count = builtins.sum(isinstance(x, jax_sparse.BCOO) for x in xs)
-    if bcoo_count:
-        if bcoo_count == len(xs):
-            axis = canonicalize_axis(axis, len(xs[0].shape))
-            return jax_sparse.bcoo_concatenate(xs, dimension=axis)
-        else:
-            xs = [
-                x.todense() if isinstance(x, jax_sparse.JAXSparse) else x
-                for x in xs
-            ]
+    if bcoo_count == len(xs):
+        axis = canonicalize_axis(axis, len(xs[0].shape))
+        return jax_sparse.bcoo_concatenate(xs, dimension=axis)
+    elif bcoo_count:
+        xs = [
+            x.todense()
+            if isinstance(x, jax_sparse.JAXSparse)
+            else convert_to_tensor(x)
+            for x in xs
+        ]
+    else:
+        xs = [convert_to_tensor(x) for x in xs]
     return jnp.concatenate(xs, axis=axis)
 
 
@@ -1087,6 +1090,7 @@ def reshape(x, newshape):
         if None not in output_shape:
             newshape = output_shape
         return jax_sparse.bcoo_reshape(x, new_sizes=newshape)
+    x = convert_to_tensor(x)
     return jnp.reshape(x, newshape)
 
 
@@ -1149,10 +1153,12 @@ def sort(x, axis=-1):
 
 
 def split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
     return jnp.split(x, indices_or_sections, axis=axis)
 
 
 def stack(x, axis=0):
+    x = [convert_to_tensor(t) for t in x]
     return jnp.stack(x, axis=axis)
 
 
@@ -1338,6 +1344,7 @@ def squeeze(x, axis=None):
             axis = tuple(i for i, d in enumerate(x.shape) if d == 1)
         axis = to_tuple_or_list(axis)
         return jax_sparse.bcoo_squeeze(x, dimensions=axis)
+    x = convert_to_tensor(x)
    return jnp.squeeze(x, axis=axis)
```

keras/src/backend/jax/optimizer.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -36,13 +36,14 @@ def _backend_apply_gradients(self, grads, trainable_variables):
             new_g_accs = jax.lax.cond(
                 is_update_step,
                 lambda: [jnp.zeros(g.shape, dtype=g.dtype) for g in acc_grads],
-                lambda: [g + acc_g for g, acc_g in zip(grads, acc_grads)],
+                lambda: [g + acc_g.value for g, acc_g in zip(grads, acc_grads)],
             )
 
             grads = jax.lax.cond(
                 is_update_step,
                 lambda: [
-                    (g + acc_g) / steps for g, acc_g in zip(grads, acc_grads)
+                    (g + acc_g.value) / steps
+                    for g, acc_g in zip(grads, acc_grads)
                 ],
                 lambda: list(grads),
             )
```
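For context on the optimizer change above, a small sketch (assuming the JAX backend; names are illustrative) of why the branches read `.value`: the accumulator is a `keras.Variable`, and `.value` yields its backing `jax.Array`, so `g + acc_g.value` is plain JAX arithmetic with no `__jax_array__` coercion inside `jax.lax.cond`.

```python
import jax
import jax.numpy as jnp
import keras

# Illustrative stand-ins for one gradient and its accumulator variable.
g = jnp.ones((3,))
acc_g = keras.Variable(initializer="zeros", shape=(3,), name="acc_g")
is_update_step = False

new_acc = jax.lax.cond(
    is_update_step,
    lambda: jnp.zeros(g.shape, dtype=g.dtype),  # reset on update steps
    lambda: g + acc_g.value,  # accumulate using the variable's backing array
)
```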

keras/src/layers/attention/attention.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -121,7 +121,7 @@ def _calculate_scores(self, query, key):
         if self.score_mode == "dot":
             scores = ops.matmul(query, ops.transpose(key, axes=[0, 2, 1]))
             if self.scale is not None:
-                scores *= self.scale
+                scores = ops.multiply(scores, self.scale)
         elif self.score_mode == "concat":
             # Reshape tensors to enable broadcasting.
             # Reshape into [batch_size, Tq, 1, dim].
```

keras/src/layers/layer_test.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -678,7 +678,7 @@ def __init__(self):
             def call(self, x):
                 # Should not autocast.
                 assertDType(self.v, "float32")
-                return ops.cast(x, "float32") + self.v
+                return ops.add(ops.cast(x, "float32"), self.v)
 
         # A layer that is explicitly full precision.
         class InnerLayerTwo(layers.Layer):
@@ -694,7 +694,7 @@ def __init__(self):
             def call(self, x):
                 # Should not autocast.
                 assertDType(self.v, "float32")
-                return x + self.v
+                return ops.add(x, self.v)
 
         # A layer that is explicitly mixed precision but with autocast=False
         # weight.
@@ -732,7 +732,7 @@ def call(self, x):
                 # Should autocast.
                 assertDType(self.v, "float16")
                 return self.inner_three(
-                    self.inner_two(self.inner_one(x + self.v))
+                    self.inner_two(self.inner_one(ops.add(x, self.v)))
                 )
 
         layer = MixedPrecisionLayer()
@@ -935,7 +935,7 @@ def call(self, x):
                 x = x + backend.random.normal(
                     shape=(), seed=self._seed_generator
                 )
-                return x + self.tw + self.ntw
+                return ops.add(x, ops.add(self.tw, self.ntw))
 
         data = np.random.random((3, 4))
         layer = TestLayer()
```

keras/src/layers/normalization/layer_normalization_test.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -99,8 +99,8 @@ def test_correctness(self):
         ).astype("float32")
 
         out = layer(inputs)
-        out -= layer.beta
-        out /= layer.gamma
+        out = ops.subtract(out, layer.beta)
+        out = ops.divide(out, layer.gamma)
 
         self.assertAllClose(ops.mean(out), 0.0, atol=1e-1)
         self.assertAllClose(ops.std(out), 1.0, atol=1e-1)
```

keras/src/layers/normalization/rms_normalization_test.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -38,10 +38,12 @@ def test_correctness(self):
         inputs = ops.convert_to_tensor(inputs)
 
         out = layer(inputs)
-        expected = (
-            inputs
-            * ops.rsqrt(ops.mean(ops.square(inputs), axis=-1, keepdims=True))
-            * layer.scale
+        expected = ops.multiply(
+            ops.multiply(
+                inputs,
+                ops.rsqrt(ops.mean(ops.square(inputs), axis=-1, keepdims=True)),
+            ),
+            layer.scale,
         )
 
         self.assertAllClose(out, expected, atol=1e-1)
```

keras/src/layers/rnn/gru.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -261,7 +261,7 @@ def call(self, inputs, states, training=False):
         matrix_x = ops.matmul(inputs, self.kernel)
         if self.use_bias:
             # biases: bias_z_i, bias_r_i, bias_h_i
-            matrix_x += input_bias
+            matrix_x = ops.add(matrix_x, input_bias)
 
         x_z, x_r, x_h = ops.split(matrix_x, 3, axis=-1)
 
```
keras/src/layers/rnn/lstm.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -276,9 +276,9 @@ def call(self, inputs, states, training=False):
 
         z = ops.matmul(inputs, self.kernel)
 
-        z += ops.matmul(h_tm1, self.recurrent_kernel)
+        z = ops.add(z, ops.matmul(h_tm1, self.recurrent_kernel))
         if self.use_bias:
-            z += self.bias
+            z = ops.add(z, self.bias)
 
         z = ops.split(z, 4, axis=1)
         c, o = self._compute_carry_and_output_fused(z, c_tm1)
```

keras/src/layers/rnn/simple_rnn.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -160,7 +160,7 @@ def call(self, sequence, states, training=False):
             sequence = sequence * dp_mask
         h = ops.matmul(sequence, self.kernel)
         if self.bias is not None:
-            h += self.bias
+            h = ops.add(h, self.bias)
 
         if training and rec_dp_mask is not None:
             prev_output = prev_output * rec_dp_mask
```
