7 changes: 4 additions & 3 deletions keras_rs/src/layers/embedding/jax/embedding_lookup_test.py
@@ -193,7 +193,7 @@ def test_forward_pass(self, ragged: bool, stacked: bool):
 
         # Add pseudo gradients to the inputs.
         embedding_variables = jax.tree.map(
-            lambda table: (table, None),
+            lambda table: embedding.EmbeddingVariables(table=table, slot=()),
             sharded_tables,
         )
 
@@ -288,7 +288,7 @@ def test_model_sharding(
 
         # Add pseudo gradients to the inputs.
         embedding_variables = jax.tree.map(
-            lambda table: (table, None),
+            lambda table: embedding.EmbeddingVariables(table=table, slot=()),
             sharded_tables,
         )
 
@@ -479,7 +479,8 @@ def test_autograd(
             )
         )
         sharded_table_and_slot_variables = typing.cast(
-            dict[str, tuple[jax.Array, ...]], sharded_table_and_slot_variables
+            dict[str, embedding.EmbeddingVariables],
+            sharded_table_and_slot_variables,
         )
 
         # Shard samples for lookup query.
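Both test updates replace the ad-hoc (table, None) pairs with embedding.EmbeddingVariables. Judging from how it is used in this diff (keyword construction with table=/slot=, attribute access via .table and .slot), it behaves like a NamedTuple pairing a table array with its optimizer slot variables. A minimal sketch of the pattern follows, using a stand-in class since the real one is provided by jax_tpu_embedding.sparsecore.lib.nn.embedding:

# Sketch only: `EmbeddingVariables` below is a stand-in mirroring the usage above;
# the real class lives in jax_tpu_embedding.sparsecore.lib.nn.embedding.
from typing import Any, NamedTuple

import jax
from jax import numpy as jnp


class EmbeddingVariables(NamedTuple):
    table: jax.Array
    slot: tuple[Any, ...]  # () when there is no optimizer state to carry


# Previously the tests paired each table with a placeholder: (table, None).
# Now the same pytree is built from named fields with an empty slot tuple.
sharded_tables = {"vocab": jnp.zeros((8, 4))}
embedding_variables = jax.tree.map(
    lambda table: EmbeddingVariables(table=table, slot=()),
    sharded_tables,
)
assert embedding_variables["vocab"].table.shape == (8, 4)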
36 changes: 17 additions & 19 deletions keras_rs/src/layers/embedding/jax/test_utils.py
@@ -7,6 +7,7 @@
 import keras
 import numpy as np
 from jax import numpy as jnp
+from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
 from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec
@@ -142,7 +143,7 @@ def create_tables(
 def create_table_and_slot_variables(
     table_specs: Nested[TableSpec],
     keys: Nested[ArrayLike] | None = None,
-) -> Nested[ArrayLike]:
+) -> Nested[embedding.EmbeddingVariables]:
     """Creates and initializes embedding tables and slot variables.
 
     Args:
@@ -164,7 +165,7 @@ def _create_table_and_slot_variables(
     def _create_table_and_slot_variables(
         table_spec: TableSpec,
         key: ArrayLike,
-    ) -> tuple[jax.Array, tuple[jax.Array, ...]]:
+    ) -> embedding.EmbeddingVariables:
         slot_initializers = table_spec.optimizer.slot_variables_initializers()
         num_slot_variables = len(keras.tree.flatten(slot_initializers))
         slot_keys = jnp.unstack(jax.random.split(key, num_slot_variables))
@@ -178,10 +179,10 @@ def _create_table_and_slot_variables(
             slot_initializers,
             slot_keys,
         )
-        return (table, slot_variables)
+        return embedding.EmbeddingVariables(table, slot_variables)
 
     # Initialize tables.
-    output: Nested[ArrayLike] = jax.tree.map(
+    output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
         _create_table_and_slot_variables,
         table_specs,
         keys,
@@ -311,14 +312,14 @@ def _create_samples(
 
 def stack_shard_and_put_tables(
     table_specs: Nested[TableSpec],
-    tables: Nested[jax.Array],
+    tables: Nested[embedding.EmbeddingVariables],
     num_shards: int,
     sharding: jax.sharding.Sharding,
-) -> dict[str, Nested[jax.Array]]:
+) -> dict[str, embedding.EmbeddingVariables]:
     sharded_tables = embedding_utils.stack_and_shard_tables(
         table_specs, tables, num_shards
     )
-    output: dict[str, Nested[jax.Array]] = jax.device_put(
+    output: dict[str, embedding.EmbeddingVariables] = jax.device_put(
         jax.tree.map(
             # Flatten shard dimension to allow auto-sharding to split the array.
             lambda table: table.reshape((-1, table.shape[-1])),
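The reshape in the mapped lambda is the step that matters here: after stacking, each table carries a leading shard dimension, and flattening it back to 2-D lets auto-sharding split the rows across devices. A shape-only sketch (dimensions are illustrative, not taken from the PR):

# Illustrative shapes only: flatten a (num_shards, rows_per_shard, dim) stacked
# table into (num_shards * rows_per_shard, dim) so the first axis can be sharded.
from jax import numpy as jnp

num_shards, rows_per_shard, dim = 4, 8, 16
stacked = jnp.zeros((num_shards, rows_per_shard, dim))
flattened = stacked.reshape((-1, stacked.shape[-1]))
assert flattened.shape == (num_shards * rows_per_shard, dim)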
@@ -469,27 +470,24 @@ def compute_expected_lookup_grad(
 def _update_table_and_slot_variables(
     table_spec: TableSpec,
     grad: jax.Array,
-    table_and_slot_variables: tuple[jax.Array, tuple[jax.Array, ...]],
-) -> tuple[
-    jax.Array,
-    embedding_spec.SGDSlotVariables | embedding_spec.AdagradSlotVariables,
-]:
+    table_and_slot_variables: embedding.EmbeddingVariables,
+) -> embedding.EmbeddingVariables:
     """Updates a table and its slot variables based on the gradient."""
-    table = table_and_slot_variables[0]
+    table = table_and_slot_variables.table
     optimizer = table_spec.optimizer
 
     # Adagrad, update and apply gradient accumulator.
     if isinstance(optimizer, embedding_spec.AdagradOptimizerSpec):
-        accumulator = table_and_slot_variables[1][0]
+        accumulator = table_and_slot_variables.slot.accumulator
         accumulator = accumulator + grad * grad
         learning_rate = optimizer.get_learning_rate(0) / jnp.sqrt(accumulator)
-        return (
+        return embedding.EmbeddingVariables(
             table - learning_rate * grad,
             embedding_spec.AdagradSlotVariables(accumulator=accumulator),
         )
 
     # SGD
-    return (
+    return embedding.EmbeddingVariables(
         table - optimizer.get_learning_rate(0) * grad,
         embedding_spec.SGDSlotVariables(),
     )
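The helper mirrors the standard Adagrad rule: accumulate squared gradients, then scale the step by the inverse square root of the accumulator. A self-contained sketch of that arithmetic on plain arrays (the values and the learning rate are illustrative, not from the PR):

# Illustrative Adagrad step on plain arrays, matching the update reproduced above:
# accumulator += grad**2;  table -= (learning_rate / sqrt(accumulator)) * grad
from jax import numpy as jnp


def adagrad_step(table, accumulator, grad, learning_rate=0.1):
    accumulator = accumulator + grad * grad
    return table - (learning_rate / jnp.sqrt(accumulator)) * grad, accumulator


table = jnp.ones((4, 2))
accumulator = jnp.full((4, 2), 1e-2)
grad = jnp.full((4, 2), 0.5)
new_table, new_accumulator = adagrad_step(table, accumulator, grad)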
@@ -500,8 +498,8 @@ def compute_expected_updates(
     feature_samples: Nested[FeatureSamples],
     activation_gradients: Nested[jax.Array],
     table_specs: Nested[TableSpec],
-    table_and_slot_variables: Nested[jax.Array],
-) -> Nested[jax.Array]:
+    table_and_slot_variables: Nested[embedding.EmbeddingVariables],
+) -> Nested[embedding.EmbeddingVariables]:
     """Computes the expected updates for a given embedding lookup.
 
     Args:
@@ -522,7 +520,7 @@ def compute_expected_updates(
     )
 
     # Apply updates per table.
-    output: Nested[jax.Array] = jax.tree.map(
+    output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
         _update_table_and_slot_variables,
         table_specs,
         table_grads,
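One detail worth noting about this final jax.tree.map: assuming TableSpec itself is not a registered pytree node, the per-table leaves of table_specs (the first tree) set the mapping granularity, so each EmbeddingVariables in the variables tree is handed to _update_table_and_slot_variables whole rather than being flattened into its table and slot leaves. A tiny sketch of that prefix-matching behavior with placeholder values:

# Placeholder values only: the first tree's leaves ("spec") decide where mapping
# stops, so the matching subtree in `variables` arrives at the callback intact.
import jax

specs = {"vocab": "spec"}                    # stands in for Nested[TableSpec]
grads = {"vocab": 2.0}                       # stands in for per-table gradients
variables = {"vocab": ("table", ("slot",))}  # stands in for EmbeddingVariables

updated = jax.tree.map(
    lambda spec, grad, var: var,  # `var` is the whole ("table", ("slot",)) tuple
    specs,
    grads,
    variables,
)
assert updated == {"vocab": ("table", ("slot",))}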