Skip to content

Commit 0dc44c7

Browse files
Refactor: Use embedding.EmbeddingVariables in Keras-RS JAX embedding tests.
This change updates the Keras-RS JAX embedding tests to use the `embedding.EmbeddingVariables` dataclass from `jax_tpu_embedding` for representing embedding tables and slot variables, instead of a custom tuple structure. This involves updating type hints, variable access, and build dependencies.
1 parent cee2286 commit 0dc44c7

File tree

2 files changed

+23
-22
lines changed

2 files changed

+23
-22
lines changed

keras_rs/src/layers/embedding/jax/embedding_lookup_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def test_forward_pass(self, ragged: bool, stacked: bool):
193193

194194
# Add pseudo gradients to the inputs.
195195
embedding_variables = jax.tree.map(
196-
lambda table: (table, None),
196+
lambda table: embedding.EmbeddingVariables(table=table, slot=()),
197197
sharded_tables,
198198
)
199199

@@ -288,7 +288,7 @@ def test_model_sharding(
288288

289289
# Add pseudo gradients to the inputs.
290290
embedding_variables = jax.tree.map(
291-
lambda table: (table, None),
291+
lambda table: embedding.EmbeddingVariables(table=table, slot=()),
292292
sharded_tables,
293293
)
294294

@@ -479,7 +479,8 @@ def test_autograd(
479479
)
480480
)
481481
sharded_table_and_slot_variables = typing.cast(
482-
dict[str, tuple[jax.Array, ...]], sharded_table_and_slot_variables
482+
dict[str, embedding.EmbeddingVariables],
483+
sharded_table_and_slot_variables,
483484
)
484485

485486
# Shard samples for lookup query.

keras_rs/src/layers/embedding/jax/test_utils.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import keras
88
import numpy as np
99
from jax import numpy as jnp
10+
from jax_tpu_embedding.sparsecore.lib.nn import embedding
1011
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
1112
from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
1213
from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec
@@ -18,6 +19,8 @@
1819
ArrayLike: TypeAlias = Union[jax.Array, np.ndarray[Any, Any]]
1920
Shape: TypeAlias = tuple[int, ...]
2021

22+
EmbeddingTable: TypeAlias = ArrayLike | embedding.EmbeddingVariables
23+
2124

2225
def has_sparsecores() -> bool:
2326
device_kind = jax.devices()[0].device_kind
@@ -142,7 +145,7 @@ def create_tables(
142145
def create_table_and_slot_variables(
143146
table_specs: Nested[TableSpec],
144147
keys: Nested[ArrayLike] | None = None,
145-
) -> Nested[ArrayLike]:
148+
) -> Nested[embedding.EmbeddingVariables]:
146149
"""Creates and initializes embedding tables and slot variables.
147150
148151
Args:
@@ -164,7 +167,7 @@ def create_table_and_slot_variables(
164167
def _create_table_and_slot_variables(
165168
table_spec: TableSpec,
166169
key: ArrayLike,
167-
) -> tuple[jax.Array, tuple[jax.Array, ...]]:
170+
) -> embedding.EmbeddingVariables:
168171
slot_initializers = table_spec.optimizer.slot_variables_initializers()
169172
num_slot_variables = len(keras.tree.flatten(slot_initializers))
170173
slot_keys = jnp.unstack(jax.random.split(key, num_slot_variables))
@@ -178,10 +181,10 @@ def _create_table_and_slot_variables(
178181
slot_initializers,
179182
slot_keys,
180183
)
181-
return (table, slot_variables)
184+
return embedding.EmbeddingVariables(table, slot_variables)
182185

183186
# Initialize tables.
184-
output: Nested[ArrayLike] = jax.tree.map(
187+
output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
185188
_create_table_and_slot_variables,
186189
table_specs,
187190
keys,
@@ -311,14 +314,14 @@ def _create_samples(
311314

312315
def stack_shard_and_put_tables(
313316
table_specs: Nested[TableSpec],
314-
tables: Nested[jax.Array],
317+
tables: EmbeddingTable,
315318
num_shards: int,
316319
sharding: jax.sharding.Sharding,
317-
) -> dict[str, Nested[jax.Array]]:
320+
) -> dict[str, EmbeddingTable]:
318321
sharded_tables = embedding_utils.stack_and_shard_tables(
319322
table_specs, tables, num_shards
320323
)
321-
output: dict[str, Nested[jax.Array]] = jax.device_put(
324+
output: dict[str, EmbeddingTable] = jax.device_put(
322325
jax.tree.map(
323326
# Flatten shard dimension to allow auto-sharding to split the array.
324327
lambda table: table.reshape((-1, table.shape[-1])),
@@ -469,27 +472,24 @@ def compute_expected_lookup_grad(
469472
def _update_table_and_slot_variables(
470473
table_spec: TableSpec,
471474
grad: jax.Array,
472-
table_and_slot_variables: tuple[jax.Array, tuple[jax.Array, ...]],
473-
) -> tuple[
474-
jax.Array,
475-
embedding_spec.SGDSlotVariables | embedding_spec.AdagradSlotVariables,
476-
]:
475+
table_and_slot_variables: embedding.EmbeddingVariables,
476+
) -> embedding.EmbeddingVariables:
477477
"""Updates a table and its slot variables based on the gradient."""
478-
table = table_and_slot_variables[0]
478+
table = table_and_slot_variables.table
479479
optimizer = table_spec.optimizer
480480

481481
# Adagrad, update and apply gradient accumulator.
482482
if isinstance(optimizer, embedding_spec.AdagradOptimizerSpec):
483-
accumulator = table_and_slot_variables[1][0]
483+
accumulator = table_and_slot_variables.slot.accumulator
484484
accumulator = accumulator + grad * grad
485485
learning_rate = optimizer.get_learning_rate(0) / jnp.sqrt(accumulator)
486-
return (
486+
return embedding.EmbeddingVariables(
487487
table - learning_rate * grad,
488488
embedding_spec.AdagradSlotVariables(accumulator=accumulator),
489489
)
490490

491491
# SGD
492-
return (
492+
return embedding.EmbeddingVariables(
493493
table - optimizer.get_learning_rate(0) * grad,
494494
embedding_spec.SGDSlotVariables(),
495495
)
@@ -500,8 +500,8 @@ def compute_expected_updates(
500500
feature_samples: Nested[FeatureSamples],
501501
activation_gradients: Nested[jax.Array],
502502
table_specs: Nested[TableSpec],
503-
table_and_slot_variables: Nested[jax.Array],
504-
) -> Nested[jax.Array]:
503+
table_and_slot_variables: Nested[embedding.EmbeddingVariables],
504+
) -> Nested[embedding.EmbeddingVariables]:
505505
"""Computes the expected updates for a given embedding lookup.
506506
507507
Args:
@@ -522,7 +522,7 @@ def compute_expected_updates(
522522
)
523523

524524
# Apply updates per table.
525-
output: Nested[jax.Array] = jax.tree.map(
525+
output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
526526
_update_table_and_slot_variables,
527527
table_specs,
528528
table_grads,

0 commit comments

Comments (0)