diff --git a/keras_rs/src/layers/embedding/jax/embedding_lookup_test.py b/keras_rs/src/layers/embedding/jax/embedding_lookup_test.py
index 0bdfa1a..7816928 100644
--- a/keras_rs/src/layers/embedding/jax/embedding_lookup_test.py
+++ b/keras_rs/src/layers/embedding/jax/embedding_lookup_test.py
@@ -193,7 +193,7 @@ def test_forward_pass(self, ragged: bool, stacked: bool):
 
         # Add pseudo gradients to the inputs.
         embedding_variables = jax.tree.map(
-            lambda table: (table, None),
+            lambda table: embedding.EmbeddingVariables(table=table, slot=()),
             sharded_tables,
         )
 
@@ -288,7 +288,7 @@ def test_model_sharding(
 
         # Add pseudo gradients to the inputs.
         embedding_variables = jax.tree.map(
-            lambda table: (table, None),
+            lambda table: embedding.EmbeddingVariables(table=table, slot=()),
             sharded_tables,
         )
 
@@ -479,7 +479,8 @@ def test_autograd(
             )
         )
         sharded_table_and_slot_variables = typing.cast(
-            dict[str, tuple[jax.Array, ...]], sharded_table_and_slot_variables
+            dict[str, embedding.EmbeddingVariables],
+            sharded_table_and_slot_variables,
         )
 
         # Shard samples for lookup query.
diff --git a/keras_rs/src/layers/embedding/jax/test_utils.py b/keras_rs/src/layers/embedding/jax/test_utils.py
index 55f2101..b051781 100644
--- a/keras_rs/src/layers/embedding/jax/test_utils.py
+++ b/keras_rs/src/layers/embedding/jax/test_utils.py
@@ -7,6 +7,7 @@
 import keras
 import numpy as np
 from jax import numpy as jnp
+from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
 from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec
@@ -142,7 +143,7 @@ def create_tables(
 def create_table_and_slot_variables(
     table_specs: Nested[TableSpec],
     keys: Nested[ArrayLike] | None = None,
-) -> Nested[ArrayLike]:
+) -> Nested[embedding.EmbeddingVariables]:
     """Creates and initializes embedding tables and slot variables.
 
     Args:
@@ -164,7 +165,7 @@ def _create_table_and_slot_variables(
     def _create_table_and_slot_variables(
         table_spec: TableSpec,
         key: ArrayLike,
-    ) -> tuple[jax.Array, tuple[jax.Array, ...]]:
+    ) -> embedding.EmbeddingVariables:
         slot_initializers = table_spec.optimizer.slot_variables_initializers()
         num_slot_variables = len(keras.tree.flatten(slot_initializers))
         slot_keys = jnp.unstack(jax.random.split(key, num_slot_variables))
@@ -178,10 +179,10 @@ def _create_table_and_slot_variables(
             slot_initializers,
             slot_keys,
         )
-        return (table, slot_variables)
+        return embedding.EmbeddingVariables(table, slot_variables)
 
     # Initialize tables.
-    output: Nested[ArrayLike] = jax.tree.map(
+    output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
         _create_table_and_slot_variables,
         table_specs,
         keys,
@@ -311,14 +312,14 @@ def _create_samples(
 
 def stack_shard_and_put_tables(
     table_specs: Nested[TableSpec],
-    tables: Nested[jax.Array],
+    tables: Nested[embedding.EmbeddingVariables],
     num_shards: int,
     sharding: jax.sharding.Sharding,
-) -> dict[str, Nested[jax.Array]]:
+) -> dict[str, embedding.EmbeddingVariables]:
     sharded_tables = embedding_utils.stack_and_shard_tables(
         table_specs, tables, num_shards
     )
-    output: dict[str, Nested[jax.Array]] = jax.device_put(
+    output: dict[str, embedding.EmbeddingVariables] = jax.device_put(
         jax.tree.map(
             # Flatten shard dimension to allow auto-sharding to split the array.
             lambda table: table.reshape((-1, table.shape[-1])),
@@ -469,27 +470,24 @@ def compute_expected_lookup_grad(
 def _update_table_and_slot_variables(
     table_spec: TableSpec,
     grad: jax.Array,
-    table_and_slot_variables: tuple[jax.Array, tuple[jax.Array, ...]],
-) -> tuple[
-    jax.Array,
-    embedding_spec.SGDSlotVariables | embedding_spec.AdagradSlotVariables,
-]:
+    table_and_slot_variables: embedding.EmbeddingVariables,
+) -> embedding.EmbeddingVariables:
     """Updates a table and its slot variables based on the gradient."""
-    table = table_and_slot_variables[0]
+    table = table_and_slot_variables.table
     optimizer = table_spec.optimizer
 
     # Adagrad, update and apply gradient accumulator.
     if isinstance(optimizer, embedding_spec.AdagradOptimizerSpec):
-        accumulator = table_and_slot_variables[1][0]
+        accumulator = table_and_slot_variables.slot.accumulator
         accumulator = accumulator + grad * grad
         learning_rate = optimizer.get_learning_rate(0) / jnp.sqrt(accumulator)
-        return (
+        return embedding.EmbeddingVariables(
             table - learning_rate * grad,
             embedding_spec.AdagradSlotVariables(accumulator=accumulator),
         )
 
     # SGD
-    return (
+    return embedding.EmbeddingVariables(
         table - optimizer.get_learning_rate(0) * grad,
         embedding_spec.SGDSlotVariables(),
     )
@@ -500,8 +498,8 @@ def compute_expected_updates(
     feature_samples: Nested[FeatureSamples],
     activation_gradients: Nested[jax.Array],
     table_specs: Nested[TableSpec],
-    table_and_slot_variables: Nested[jax.Array],
-) -> Nested[jax.Array]:
+    table_and_slot_variables: Nested[embedding.EmbeddingVariables],
+) -> Nested[embedding.EmbeddingVariables]:
     """Computes the expected updates for a given embedding lookup.
 
     Args:
@@ -522,7 +520,7 @@ def compute_expected_updates(
     )
 
     # Apply updates per table.
-    output: Nested[jax.Array] = jax.tree.map(
+    output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
         _update_table_and_slot_variables,
         table_specs,
         table_grads,
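
Note (reviewer sketch, not part of the diff): the change replaces bare (table, slot_variables) tuples with the EmbeddingVariables container imported from jax_tpu_embedding.sparsecore.lib.nn.embedding, so tests read fields by name instead of index. A minimal illustration of that access pattern, using a hypothetical stand-in NamedTuple so it runs without the jax_tpu_embedding package installed:

from typing import Any, NamedTuple

import jax.numpy as jnp


class EmbeddingVariables(NamedTuple):
    # Hypothetical stand-in mirroring the container used in the diff; the real
    # class ships with jax_tpu_embedding.sparsecore.lib.nn.embedding.
    table: jnp.ndarray
    slot: Any  # e.g. (), SGDSlotVariables(), or AdagradSlotVariables(...)


# Construct as in the updated tests, then access fields by name.
variables = EmbeddingVariables(table=jnp.zeros((8, 4)), slot=())
updated = EmbeddingVariables(table=variables.table - 0.01, slot=variables.slot)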