From 424271f55d7556943938bebeca35521bbb37fdd8 Mon Sep 17 00:00:00 2001 From: Aditya Gupta Date: Sun, 28 Sep 2025 13:44:15 -0700 Subject: [PATCH] [JAX SC] Implement pipelined SparseCore embedding layer in Flax This change introduces `PipelinedSparseCoreEmbed` to enable pipelining of SparseCore operations with TensorCore. The layer stores inputs and activations/gradients across steps to allow for overlapping computation. The Shakespeare example is updated to use this new pipelined layer, requiring changes to how variables are handled in the training loop, including using `mutable=True` in `model.apply`. A test target for the Shakespeare example is also added. With pipelining enabled, the model converges as shown below: ``` I 2025-09-28T13:41:31.529508-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 9: Loss = 7.518053 I 2025-09-28T13:41:31.599092-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 19: Loss = 7.317211 I 2025-09-28T13:41:31.668971-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 29: Loss = 7.0158334 I 2025-09-28T13:41:31.738482-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 39: Loss = 6.609592 I 2025-09-28T13:41:31.807324-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 49: Loss = 6.140646 I 2025-09-28T13:41:31.876329-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 59: Loss = 5.7092185 I 2025-09-28T13:41:31.949826-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 69: Loss = 5.3921995 I 2025-09-28T13:41:32.018904-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 79: Loss = 5.183298 I 2025-09-28T13:41:32.088231-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 89: Loss = 5.0565457 I 2025-09-28T13:41:32.156878-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 99: Loss = 4.9826536 ... I 2025-09-28T13:41:38.309954-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 999: Loss = 0.31050137 ``` PiperOrigin-RevId: 812524634 --- .../examples/models/shakespeare/flax_model.py | 3 + .../sparsecore/lib/flax/embed.py | 230 +++++++++- .../sparsecore/lib/flax/embed_pipelining.py | 237 +++++++++++ .../sparsecore/lib/flax/tests/embed_test.py | 396 ++++++++++++------ 4 files changed, 725 insertions(+), 141 deletions(-) create mode 100644 jax_tpu_embedding/sparsecore/lib/flax/embed_pipelining.py diff --git a/jax_tpu_embedding/sparsecore/examples/models/shakespeare/flax_model.py b/jax_tpu_embedding/sparsecore/examples/models/shakespeare/flax_model.py index f4c26f20..df84dd1f 100644 --- a/jax_tpu_embedding/sparsecore/examples/models/shakespeare/flax_model.py +++ b/jax_tpu_embedding/sparsecore/examples/models/shakespeare/flax_model.py @@ -20,6 +20,7 @@ from jax_tpu_embedding.sparsecore.lib.nn import embedding from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec + shard_map = jax.experimental.shard_map.shard_map Nested = embedding.Nested @@ -38,6 +39,7 @@ class Model(nn.Module): feature_name: str = 'shakespeare_feature' mesh: jax.sharding.Mesh | None = None sharding_axis: str = 'sparsecore_sharding' + enable_pipelining: bool = False def add_sharding_constraint(self, x: jax.Array, names: tuple[str | None]): # Add a sharding constraint to the array. @@ -66,6 +68,7 @@ def __call__(self, embedding_lookup_inputs: embedding.PreprocessedInput): feature_specs=self.feature_specs, mesh=self.mesh, sharding_axis=self.sharding_axis, + enable_pipelining=self.enable_pipelining, )(embedding_lookup_inputs) # Unpack the activations. diff --git a/jax_tpu_embedding/sparsecore/lib/flax/embed.py b/jax_tpu_embedding/sparsecore/lib/flax/embed.py index 70f77bef..2100e03f 100644 --- a/jax_tpu_embedding/sparsecore/lib/flax/embed.py +++ b/jax_tpu_embedding/sparsecore/lib/flax/embed.py @@ -16,10 +16,12 @@ import functools from typing import Any, Callable, Mapping, TypeVar +import flax from flax import linen as nn from flax import typing import jax from jax.experimental import layout +import jax.numpy as jnp from jax_tpu_embedding.sparsecore.lib.nn import embedding from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec from jax_tpu_embedding.sparsecore.utils import utils @@ -69,7 +71,32 @@ def wrapper(*args, **kwargs): class SparseCoreEmbed(nn.Module): - """SparseCore embedding layer.""" + """SparseCore embedding layer. + + ## Pipelining + + This layer supports pipelining of SparseCore computations if + `enable_pipelining` is set to True. Pipelining decouples the SC computation + with TC computation by processing multiple batches concurrently stored in + internal state (variables). This allows for greater SC-TC overlap and + generally better performance at the cost of higher memory usage. + There's however a comparitively slower convergence which is tolerable in most + cases. See internal link:jax-sc-embedding-pipelining for more information. + + When pipelining is enabled, it implements a two-stage pipeline: embedding + lookups for batch `i` run concurrently with TensorCore computations for batch + `i-1` and embedding gradient updates for batch `i-2`. This results in + activations being delayed by one step and gradient updates by two steps + relative to the inputs. + + NOTE for pipelining: + * The first two steps return zero activations (warm-up), therefore user needs + to run two additional steps. The dense input for first(0) and last(N+1) could + be dummy input. + * If pipelining is enabled, user will have to pass + `mutable=['sparsecore_pipeline_state']` to `.apply()` to + update internal pipeline state. + """ # A sequence of FeatureSpecs to specify the configurations for the # input feature. @@ -81,6 +108,7 @@ class SparseCoreEmbed(nn.Module): # Sharding strategy for embedding tables. table_sharding_strategy: str = 'MOD' enable_minibatching: bool = False + enable_pipelining: bool = False num_sc_per_device: int = -1 # Initialized in __post_init__. @@ -165,6 +193,7 @@ def preprocess_inputs( all_reduce_interface=all_reduce_interface, )[0] + @nn.compact def __call__( self, embedding_lookup_inputs: EmbeddingLookupInput ) -> embedding.Nested[jax.Array]: @@ -176,12 +205,19 @@ def __call__( Returns: The activations structure with the same structure as feature_specs. """ + if self.enable_pipelining: + return self._pipelined_call(embedding_lookup_inputs) return _emb_lookup( self, embedding_lookup_inputs, self.embedding_table, ) + def _get_embedding_table_path(self) -> str: + """Returns the path to the embedding table within the module.""" + return '/'.join(self.path + (EMBEDDING_PARAM_NAME,)) + + @nn.compact def apply_gradient( self, gradients: embedding.Nested[jax.Array], @@ -196,13 +232,131 @@ def apply_gradient( Returns: The updated activation embedding tables. """ + if self.enable_pipelining: + return self._pipelined_apply_gradient(gradients, embedding_lookup_inputs) _, embed_table = _emb_lookup_bwd( self, (embedding_lookup_inputs, self.embedding_table), gradients, ) - path = '/'.join(self.path + (EMBEDDING_PARAM_NAME,)) - return {path: embed_table} + return {self._get_embedding_table_path(): embed_table} + + ############################################################################## + # Pipelining implementation + ############################################################################## + def _pipelined_call( + self, embedding_lookup_inputs: embedding.PreprocessedInput + ) -> embedding.Nested[jax.Array]: + """Pipelined version of __call__.""" + step_im1_sparse_activations = self._get_step_im1_sparse_activations() + step_im2_sparse_inputs = self._get_step_im2_sparse_inputs( + embedding_lookup_inputs + ) + step_im2_sparse_gradients = self._get_step_im2_sparse_gradients() + + # Update embedding table using values from step i-2. + # Perform the update using a custom_vjp to avoid differentiating the + # optimizer step. + updated_table_val = _perform_update( + self, + step_im2_sparse_inputs.value, + step_im2_sparse_gradients.value, + self.embedding_table, + ) + if self.is_mutable_collection('params'): + self.scope.variables()['params'][ + EMBEDDING_PARAM_NAME + ] = self.scope.variables()['params'][EMBEDDING_PARAM_NAME].replace_boxed( + updated_table_val + ) + + updated_table = updated_table_val + + # The activations for the current step's forward pass are from step i-1. + result_activations = step_im1_sparse_activations.value + + # Now, perform the lookup for the current step (i) using the newly updated + # embedding table and store the inputs and resulting activations for future + # steps. + self._get_step_im1_sparse_inputs(embedding_lookup_inputs).value = ( + embedding_lookup_inputs + ) + step_im1_sparse_activations.value = _emb_lookup( + self, embedding_lookup_inputs, updated_table + ) + + return result_activations + + def _pipelined_apply_gradient( + self, + gradients: embedding.Nested[jax.Array], + embedding_lookup_inputs: embedding.PreprocessedInput, + ) -> Mapping[str, Mapping[str, jax.Array]]: + """Pipelined version of apply_gradient.""" + step_im1_sparse_inputs = self._get_step_im1_sparse_inputs( + embedding_lookup_inputs + ) + step_im2_sparse_inputs = self._get_step_im2_sparse_inputs( + embedding_lookup_inputs + ) + step_im2_sparse_gradients = self._get_step_im2_sparse_gradients() + + # Store sparse inputs and gradients for use on next step. + step_im2_sparse_inputs.value = step_im1_sparse_inputs.value + step_im2_sparse_gradients.value = gradients + + return {} + + def _get_unfrozen_feature_specs( + self, + ) -> embedding.Nested[embedding_spec.FeatureSpec]: + return ( + flax.core.unfreeze(self.feature_specs) + if isinstance(self.feature_specs, flax.core.FrozenDict) + else self.feature_specs + ) + + def _get_step_im1_sparse_inputs( + self, embedding_lookup_inputs: embedding.PreprocessedInput + ) -> flax.core.scope.Variable[embedding.PreprocessedInput]: + return self.variable( + SPARSECORE_PIPELINE_STATE_COLLECTION, + 'step_im1_sparse_inputs', + lambda: jax.tree.map(jnp.zeros_like, embedding_lookup_inputs), + ) + + def _get_step_im1_sparse_activations( + self, + ) -> flax.core.scope.Variable[embedding.Nested[jax.Array]]: + return self.variable( + SPARSECORE_PIPELINE_STATE_COLLECTION, + 'step_im1_sparse_activations', + lambda: jax.tree.map( + lambda f: jnp.zeros(f.output_shape, dtype=jnp.float32), + self._get_unfrozen_feature_specs(), + ), + ) + + def _get_step_im2_sparse_inputs( + self, embedding_lookup_inputs: embedding.PreprocessedInput + ) -> flax.core.scope.Variable[embedding.PreprocessedInput]: + return self.variable( + SPARSECORE_PIPELINE_STATE_COLLECTION, + 'step_im2_sparse_inputs', + lambda: jax.tree.map(jnp.zeros_like, embedding_lookup_inputs), + ) + + def _get_step_im2_sparse_gradients( + self, + ) -> flax.core.scope.Variable[embedding.Nested[jax.Array]]: + return self.variable( + SPARSECORE_PIPELINE_STATE_COLLECTION, + 'step_im2_sparse_gradients', + lambda: jax.tree.map( + lambda f: jnp.zeros(f.output_shape, dtype=jnp.float32), + self._get_unfrozen_feature_specs(), + ), + ) ################################################################################ @@ -255,7 +409,7 @@ def _emb_lookup_bwd(embedding_layer, res, gradients): pt = embedding_layer.embedding_table_partition pd = embedding_layer.data_partition - emb_table_grads = shard_map( + updated_emb_table = shard_map( functools.partial( embedding.tpu_sparse_dense_matmul_grad, feature_specs=embedding_layer.feature_specs, @@ -272,11 +426,12 @@ def _emb_lookup_bwd(embedding_layer, res, gradients): emb_table, ) + # > Reinterpret emb table as grads (unused). # tpu_sparse_dense_matmul_grad returns a general Mapping (usually a dict). # It may not be the same type as the embedding table (e.g. FrozenDict). # Here we use flatten / unflatten to ensure the types are the same. emb_table_grads = jax.tree.unflatten( - jax.tree.structure(emb_table), jax.tree.leaves(emb_table_grads) + jax.tree.structure(emb_table), jax.tree.leaves(updated_emb_table) ) return ( @@ -286,3 +441,68 @@ def _emb_lookup_bwd(embedding_layer, res, gradients): _emb_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd) + + +SPARSECORE_PIPELINE_STATE_COLLECTION = 'sparsecore_pipeline_state' + + +################################################################################ +# Define custom VJP for embedding update for pipelining. +# This is used to prevent autodiff from differentiating through the optimizer +# step. +################################################################################ +@functools.partial(jax.custom_vjp, nondiff_argnums=(0,)) +def _perform_update( + module: 'SparseCoreEmbed', + im2_inputs: embedding.PreprocessedInput, + im2_grads: embedding.Nested[jax.Array], + emb_table: embedding.Nested[jax.Array], +) -> embedding.Nested[jax.Array]: + """Performs the embedding update, but is opaque to autodiff.""" + _, updated_table = _emb_lookup_bwd( # pylint: disable=protected-access + module, + (im2_inputs, emb_table), + im2_grads, + ) + return updated_table + + +def _perform_update_fwd( + module: 'SparseCoreEmbed', + im2_inputs: embedding.PreprocessedInput, + im2_grads: embedding.Nested[jax.Array], + emb_table: embedding.Nested[jax.Array], +): + """Forward pass for _perform_update.""" + updated_table = _perform_update(module, im2_inputs, im2_grads, emb_table) + # Return inputs as residuals for backward pass. + return updated_table, (im2_inputs, im2_grads, emb_table) + + +def _perform_update_bwd( + module: 'SparseCoreEmbed', + res: tuple[ + embedding.PreprocessedInput, + embedding.Nested[jax.Array], + embedding.Nested[jax.Array], + ], + g: embedding.Nested[jax.Array], +) -> tuple[ + embedding.PreprocessedInput, + embedding.Nested[jax.Array], + embedding.Nested[jax.Array], +]: + """Backward pass for _perform_update.""" + # g is the gradient w.r.t. the output (updated_table). + # We want this to flow back to the original emb_table as if this function + # was an identity function. + im2_inputs, im2_grads, emb_table = res + del module, emb_table + return ( + jax.tree.map(jnp.zeros_like, im2_inputs), + jax.tree.map(jnp.zeros_like, im2_grads), + g, # Pass gradient through to the original embedding table. + ) + + +_perform_update.defvjp(_perform_update_fwd, _perform_update_bwd) diff --git a/jax_tpu_embedding/sparsecore/lib/flax/embed_pipelining.py b/jax_tpu_embedding/sparsecore/lib/flax/embed_pipelining.py new file mode 100644 index 00000000..f724bd0d --- /dev/null +++ b/jax_tpu_embedding/sparsecore/lib/flax/embed_pipelining.py @@ -0,0 +1,237 @@ +# Copyright 2024 The JAX SC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SparseCore layer to pipeline computations with TensorCore.""" + +import functools +from typing import Mapping + +from absl import logging +import flax +from flax import linen as nn +import jax +import jax.numpy as jnp +from jax_tpu_embedding.sparsecore.lib.flax import embed +from jax_tpu_embedding.sparsecore.lib.nn import embedding +from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec + + +SPARSECORE_PIPELINE_STATE_COLLECTION = 'sparsecore_pipeline_state' + + +class PipelinedSparseCoreEmbed(embed.SparseCoreEmbed): + """A SparseCore embedding layer with pipelining support. + + This decouples the SC computation with TC computation by processing multiple + batches concurrently stored in internal state (variables). This allows for + greater SC-TC + overlap and generally better performance at the cost of higher memory usage. + There's however a comparitively slower convergence which is tolerable in most + cases. See internal link:jax-sc-embedding-pipelining for more information. + + It implements a two-stage pipeline: embedding lookups for batch `i` run + concurrently with TensorCore computations for batch `i-1` and embedding + gradient updates for batch `i-2`. This results in activations being + delayed by one step and gradient updates by two steps relative to the inputs. + + TODO {{bugsnag;a:manuadg;p:2;s:2;t:fr;c:1846895}} - Enable SC pipelining for + Flax layer. + + There are couple of missing optimizations that would be later + added (requires cl/811416653): + + * Store stacked activations and load stacked gradients to avoid reshaping + during SC operation (Currently we unstack during lookup and stack during + gradient update.) + + NOTE: + * The first two steps return zero activations (warm-up), therefore user needs + to run two additional steps. The dense input for first(0) and last(N+1) could + be dummy input. + * User will have to pass `mutable=True` to `.apply()/.apply_gradient()` to + update internal pipeline state. + """ + + @nn.compact + def __call__( + self, embedding_lookup_inputs: embedding.PreprocessedInput + ) -> embedding.Nested[jax.Array]: + + step_im1_sparse_activations = self._get_step_im1_sparse_activations() + step_im2_sparse_inputs = self._get_step_im2_sparse_inputs( + embedding_lookup_inputs + ) + step_im2_sparse_gradients = self._get_step_im2_sparse_gradients() + + # Update embedding table using values from step i-2. + # Perform the update using a custom_vjp to avoid differentiating the + # optimizer step. + updated_table = _perform_update( + self, + step_im2_sparse_inputs.value, + step_im2_sparse_gradients.value, + self.embedding_table, + ) + + # The activations for the current step's forward pass are from step i-1. + result_activations = step_im1_sparse_activations.value + + # Now, perform the lookup for the current step (i) using the newly updated + # embedding table and store the inputs and resulting activations for future + # steps. + self._get_step_im1_sparse_inputs(embedding_lookup_inputs).value = ( + embedding_lookup_inputs + ) + step_im1_sparse_activations.value = embed._emb_lookup( + self, embedding_lookup_inputs, updated_table + ) + + return result_activations + + @nn.compact + def apply_gradient( + self, + gradients: embedding.Nested[jax.Array], + embedding_lookup_inputs: embedding.PreprocessedInput, + ) -> Mapping[str, Mapping[str, jax.Array]]: + + step_im1_sparse_inputs = self._get_step_im1_sparse_inputs( + embedding_lookup_inputs + ) + step_im2_sparse_inputs = self._get_step_im2_sparse_inputs( + embedding_lookup_inputs + ) + step_im2_sparse_gradients = self._get_step_im2_sparse_gradients() + + # Store sparse inputs and gradients for use on next step. + step_im2_sparse_inputs.value = step_im1_sparse_inputs.value + step_im2_sparse_gradients.value = gradients + + return {} + + ############################################################################## + # Variables + ############################################################################## + + def _get_unfrozen_feature_specs( + self, + ) -> embedding.Nested[embedding_spec.FeatureSpec]: + return ( + flax.core.unfreeze(self.feature_specs) + if isinstance(self.feature_specs, flax.core.FrozenDict) + else self.feature_specs + ) + + def _get_step_im1_sparse_inputs( + self, embedding_lookup_inputs: embedding.PreprocessedInput + ) -> flax.core.scope.Variable[embedding.PreprocessedInput]: + return self.variable( + SPARSECORE_PIPELINE_STATE_COLLECTION, + 'step_im1_sparse_inputs', + lambda: jax.tree.map(jnp.zeros_like, embedding_lookup_inputs), + ) + + def _get_step_im1_sparse_activations( + self, + ) -> flax.core.scope.Variable[embedding.Nested[jax.Array]]: + return self.variable( + SPARSECORE_PIPELINE_STATE_COLLECTION, + 'step_im1_sparse_activations', + lambda: jax.tree.map( + lambda f: jnp.zeros(f.output_shape, dtype=jnp.float32), + self._get_unfrozen_feature_specs(), + ), + ) + + def _get_step_im2_sparse_inputs( + self, embedding_lookup_inputs: embedding.PreprocessedInput + ) -> flax.core.scope.Variable[embedding.PreprocessedInput]: + return self.variable( + SPARSECORE_PIPELINE_STATE_COLLECTION, + 'step_im2_sparse_inputs', + lambda: jax.tree.map(jnp.zeros_like, embedding_lookup_inputs), + ) + + def _get_step_im2_sparse_gradients( + self, + ) -> flax.core.scope.Variable[embedding.Nested[jax.Array]]: + return self.variable( + SPARSECORE_PIPELINE_STATE_COLLECTION, + 'step_im2_sparse_gradients', + lambda: jax.tree.map( + lambda f: jnp.zeros(f.output_shape, dtype=jnp.float32), + self._get_unfrozen_feature_specs(), + ), + ) + + +################################################################################ +# Define custom VJP for embedding update. +# This is used to prevent autodiff from differentiating through the optimizer +# step. +################################################################################ +@functools.partial(jax.custom_vjp, nondiff_argnums=(0,)) +def _perform_update( + module: 'PipelinedSparseCoreEmbed', + im2_inputs: embedding.PreprocessedInput, + im2_grads: embedding.Nested[jax.Array], + emb_table: embedding.Nested[jax.Array], +) -> embedding.Nested[jax.Array]: + """Performs the embedding update, but is opaque to autodiff.""" + _, updated_table = embed._emb_lookup_bwd( # pylint: disable=protected-access + module, + (im2_inputs, emb_table), + im2_grads, + ) + return updated_table + + +def _perform_update_fwd( + module: 'PipelinedSparseCoreEmbed', + im2_inputs: embedding.PreprocessedInput, + im2_grads: embedding.Nested[jax.Array], + emb_table: embedding.Nested[jax.Array], +): + """Forward pass for _perform_update.""" + updated_table = _perform_update(module, im2_inputs, im2_grads, emb_table) + # Return inputs as residuals for backward pass. + return updated_table, (im2_inputs, im2_grads, emb_table) + + +def _perform_update_bwd( + module: 'PipelinedSparseCoreEmbed', + res: tuple[ + embedding.PreprocessedInput, + embedding.Nested[jax.Array], + embedding.Nested[jax.Array], + ], + g: embedding.Nested[jax.Array], +) -> tuple[ + embedding.PreprocessedInput, + embedding.Nested[jax.Array], + embedding.Nested[jax.Array], +]: + """Backward pass for _perform_update.""" + # g is the gradient w.r.t. the output (updated_table). + # We want this to flow back to the original emb_table as if this function + # was an identity function. + im2_inputs, im2_grads, emb_table = res + del module, emb_table + return ( + jax.tree.map(jnp.zeros_like, im2_inputs), + jax.tree.map(jnp.zeros_like, im2_grads), + g, # Pass gradient through to the original embedding table. + ) + + +_perform_update.defvjp(_perform_update_fwd, _perform_update_bwd) diff --git a/jax_tpu_embedding/sparsecore/lib/flax/tests/embed_test.py b/jax_tpu_embedding/sparsecore/lib/flax/tests/embed_test.py index 56ae2f2e..f92072e7 100644 --- a/jax_tpu_embedding/sparsecore/lib/flax/tests/embed_test.py +++ b/jax_tpu_embedding/sparsecore/lib/flax/tests/embed_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest from absl.testing import parameterized import einops +from flax import core from flax import linen as nn import jax import jax.numpy as jnp @@ -209,6 +210,92 @@ def _row_initialize_with_padding( paddings = tuple((0, y - x) for x, y in zip(shape, padded_shape)) return np.pad(array, paddings, mode='constant', constant_values=pad_value) + def _create_embedding_variables( + self, + module: embed.SparseCoreEmbed, + feature_specs: tuple[embedding_spec.FeatureSpec, ...], + offsets: dict[str, int], + ) -> core.FrozenDict[str, embedding.EmbeddingVariables]: + """Creates sharded embedding variables for given feature specs.""" + embedding_variables = {} + devices = module.mesh.devices.flatten() + device_count = len(devices) + num_sc_per_device = module.num_sc_per_device + sharding = NamedSharding(module.mesh, P(module.sharding_axis, None)) + + unique_tables = {} + for f in feature_specs: + if f.table_spec.name not in unique_tables: + unique_tables[f.table_spec.name] = f.table_spec + + for table_name, table_spec in unique_tables.items(): + padded_vocab = table_spec.setting_in_stack.padded_vocab_size + padded_dim = table_spec.setting_in_stack.padded_embedding_dim + + emb_table = self._row_initialize_with_padding( + shape=(table_spec.vocabulary_size, table_spec.embedding_dim), + padded_shape=(padded_vocab, padded_dim), + offset=offsets.get(table_name, 0), + ) + emb_table_sharded = einops.rearrange( + emb_table, + '(v c s) f -> c (s v) f', + c=device_count, + s=num_sc_per_device, + ) + device_arrays = [ + jax.device_put(emb_table_sharded[i], device=d) + for i, d in enumerate(devices) + ] + embedding_variables[table_name] = embedding.EmbeddingVariables( + table=jax.make_array_from_single_device_arrays( + shape=(padded_vocab, padded_dim), + sharding=sharding, + arrays=device_arrays, + ), + slot=embedding_spec.SGDSlotVariables(), + ) + return core.freeze(embedding_variables) + + def _initialize_model_variables( + self, + module: embed.SparseCoreEmbed, + embedding_lookup_input: embedding.PreprocessedInput, + embedding_variables: core.FrozenDict[str, embedding.EmbeddingVariables], + ): + """Initializes model variables and replaces embedding table.""" + var_spec = jax.eval_shape( + module.init, + jax.random.PRNGKey(0), + embedding_lookup_input, + ) + out_sharding = nn.get_sharding(var_spec, module.mesh) + variables = jax.jit( + module.init, + in_shardings=( + NamedSharding(module.mesh, P()), + NamedSharding(module.mesh, P(module.sharding_axis)), + ), + out_shardings=out_sharding, + )( + jax.random.PRNGKey(0), + embedding_lookup_input, + ) + + # Replace the embedding variables in params with the ones we created. + def check_shape(a, b): + assert a.shape == b.shape + + jax.tree.map( + check_shape, + variables['params'][_EMBED_PARAM].value, + embedding_variables, + ) + variables['params'][_EMBED_PARAM] = variables['params'][ + _EMBED_PARAM + ].replace_boxed(embedding_variables) + return variables + @parameterized.named_parameters( dict(testcase_name='_with_minibatching', enable_minibatching=True), dict(testcase_name='_without_minibatching', enable_minibatching=False), @@ -248,95 +335,15 @@ def test_forward_and_backward_with_one_table(self, enable_minibatching: bool): self.feature_spec_b.table_spec.setting_in_stack.padded_embedding_dim ) - device_count = len(devices) - emb_table_a = self._row_initialize_with_padding( - shape=(_VOC_A, _DIM_A), padded_shape=(padded_vocab_a, padded_dim_a) - ) - emb_table_a_sharded = einops.rearrange( - emb_table_a, - '(v c s) f -> c (s v) f', - c=device_count, - s=num_sc_per_device, - ) - emb_table_b = self._row_initialize_with_padding( - shape=(_VOC_B, _DIM_B), padded_shape=(padded_vocab_b, padded_dim_b) - ) - emb_table_b_sharded = einops.rearrange( - emb_table_b, - '(v c s) f -> c (s v) f', - c=device_count, - s=num_sc_per_device, - ) - - embedding_variables = {} - - embedding_variables['table_a'] = [ - jax.device_put( - emb_table_a_sharded[i], - device=local_device, - ) - for i, local_device in enumerate(devices) - ] - embedding_variables['table_b'] = [ - jax.device_put( - emb_table_b_sharded[i], - device=local_device, - ) - for i, local_device in enumerate(devices) - ] - sharding = NamedSharding(sc_module.mesh, P(sc_module.sharding_axis, None)) - embedding_variables['table_a'] = embedding.EmbeddingVariables( - table=jax.make_array_from_single_device_arrays( - shape=(padded_vocab_a, padded_dim_a), - sharding=sharding, - arrays=embedding_variables['table_a'], - ), - slot=embedding_spec.SGDSlotVariables(), - ) - embedding_variables['table_b'] = embedding.EmbeddingVariables( - table=jax.make_array_from_single_device_arrays( - shape=(padded_vocab_b, padded_dim_b), - sharding=sharding, - arrays=embedding_variables['table_b'], - ), - slot=embedding_spec.SGDSlotVariables(), - ) - - var_spec = jax.eval_shape( - sc_module.init, - jax.random.PRNGKey(0), - embedding_lookup_input, + embedding_variables = self._create_embedding_variables( + sc_module, feature_specs, offsets={'table_a': 0, 'table_b': 0} ) - - out_sharding = nn.get_sharding(var_spec, sc_module.mesh) - - params = jax.jit( - sc_module.init, - in_shardings=( - NamedSharding(sc_module.mesh, P()), - NamedSharding(sc_module.mesh, P(sc_module.sharding_axis)), - ), - out_shardings=out_sharding, - )( - jax.random.PRNGKey(0), - embedding_lookup_input, + variables = self._initialize_model_variables( + sc_module, embedding_lookup_input, embedding_variables ) - # Replace the embedding variables in params with the ones we created. - def check_shape(a, b): - assert a.shape == b.shape - - jax.tree.map( - check_shape, - params['params'][_EMBED_PARAM].value, - embedding_variables, - ) - params['params'][_EMBED_PARAM] = params['params'][ - _EMBED_PARAM - ].replace_boxed(embedding_variables) - activations = jax.jit(sc_module.apply)( - params, + variables, embedding_lookup_input, ) @@ -408,7 +415,7 @@ def check_shape(a, b): params_updates = jax.jit( functools.partial(sc_module.apply, method=sc_module.apply_gradient), )( - params, + variables, activations_grad, embedding_lookup_input, ) @@ -416,9 +423,9 @@ def check_shape(a, b): # Updates params with the new embedding variables. assert len(params_updates) == 1 tree.assert_same_structure( - params_updates[_EMBED_PARAM], params['params'][_EMBED_PARAM].value + params_updates[_EMBED_PARAM], variables['params'][_EMBED_PARAM].value ) - params['params'] = params['params'] | params_updates + variables['params'] = variables['params'] | params_updates expected_grad_table_a = np.full( (padded_vocab_a, padded_dim_a), _PAD_VALUE, dtype=np.float32 @@ -427,14 +434,14 @@ def check_shape(a, b): (padded_vocab_b, padded_dim_b), _PAD_VALUE, dtype=np.float32 ) - for i, array in enumerate(embedding_variables['table_a'][0]): + for i, array in enumerate(embedding_variables['table_a'].table): col_id = array[0] new_col_id = col_id - (count_num(self.input_tensor, col_id) * 0.01) expected_grad_table_a[i, :_DIM_A] = np.full( (1, _DIM_A), new_col_id, dtype=np.float32 ) - for i, array in enumerate(embedding_variables['table_b'][0]): + for i, array in enumerate(embedding_variables['table_b'].table): col_id = array[0] new_col_id = col_id - ( count_num(self.input_tensor_table_b, col_id) * 0.01 @@ -443,10 +450,152 @@ def check_shape(a, b): (1, _DIM_B), new_col_id, dtype=np.float32 ) np.testing.assert_equal( - params['params'][_EMBED_PARAM]['table_a'][0], expected_grad_table_a + variables['params'][_EMBED_PARAM]['table_a'].table, + expected_grad_table_a, ) np.testing.assert_equal( - params['params'][_EMBED_PARAM]['table_b'][0], expected_grad_table_b + variables['params'][_EMBED_PARAM]['table_b'].table, + expected_grad_table_b, + ) + + @parameterized.named_parameters( + dict(testcase_name='_with_minibatching', enable_minibatching=True), + dict(testcase_name='_without_minibatching', enable_minibatching=False), + ) + def test_pipelined_forward_and_backward(self, enable_minibatching: bool): + devices = jax.devices() + num_sc_per_device = utils.num_sparsecores_per_device(devices[0]) + + feature_specs = (self.feature_spec_a,) + embedding.prepare_feature_specs_for_training( + feature_specs, + global_device_count=jax.device_count(), + num_sc_per_device=num_sc_per_device, + ) + + sc_module = embed.SparseCoreEmbed( + feature_specs=feature_specs, + enable_minibatching=enable_minibatching, + enable_pipelining=True, + ) + + inp0 = sc_module.preprocess_inputs( + 0, + (self.input_tensor,), + (self.input_weights,), + ) + inp1 = sc_module.preprocess_inputs( + 1, + (self.input_tensor,), + (self.input_weights,), + ) + inp2 = sc_module.preprocess_inputs( + 2, + (self.input_tensor,), + (self.input_weights,), + ) + + padded_vocab_a = ( + self.feature_spec_a.table_spec.setting_in_stack.padded_vocab_size + ) + padded_dim_a = ( + self.feature_spec_a.table_spec.setting_in_stack.padded_embedding_dim + ) + + embedding_variables = self._create_embedding_variables( + sc_module, feature_specs, offsets={'table_a': 0} + ) + variables = self._initialize_model_variables( + sc_module, inp0, embedding_variables + ) + + # step 0 + apply_fn = jax.jit(sc_module.apply, static_argnames=['mutable']) + apply_grad_fn = jax.jit( + functools.partial(sc_module.apply, method=sc_module.apply_gradient), + static_argnames=['mutable'], + ) + + activations0, variables = apply_fn(variables, inp0, mutable=True) + # Check activations are 0s + np.testing.assert_allclose( + activations0[0], jnp.zeros((_BATCH_SIZE, _DIM_A)) + ) + + grad0 = (jnp.ones((_BATCH_SIZE, _DIM_A), dtype=jnp.float32),) + _, variables = apply_grad_fn(variables, grad0, inp0, mutable=True) + + # step 1 + activations1, variables = apply_fn(variables, inp1, mutable=True) + # in step 1, activations should be from lookup of step 0. + expected_emb_activations_0 = np.broadcast_to( + np.array( + [ + [11.0], + [3.0], + [9.0], + [26.0], + [29.0], + [31.0], + [67.0], + [57.0], + [15.0], + [13.0], + [11.0], + [8.0], + [17.0], + [42.0], + [30.0], + [26.0], + ], + dtype=np.float32, + ), + (_BATCH_SIZE, _DIM_A), + ) + np.testing.assert_allclose(activations1[0], expected_emb_activations_0) + + grad1 = (jnp.ones((_BATCH_SIZE, _DIM_A), dtype=jnp.float32),) + _, variables = apply_grad_fn(variables, grad1, inp1, mutable=True) + + # step 2 + activations2, variables = apply_fn(variables, inp2, mutable=True) + # in step 2, activations should be from lookup of step 1, which used + # embedding table updated with step 0 gradients. + expected_activations_2_list = [] + for i in range(len(self.input_tensor)): + val = 0 + for embedding_id in self.input_tensor[i]: + val += embedding_id - count_num(self.input_tensor, embedding_id) * 0.01 + expected_activations_2_list.append([val]) + expected_activations_2 = np.broadcast_to( + np.array(expected_activations_2_list, dtype=np.float32), + (_BATCH_SIZE, _DIM_A), + ) + np.testing.assert_allclose( + activations2[0], expected_activations_2, rtol=1e-6 + ) + + grad2 = (jnp.ones((_BATCH_SIZE, _DIM_A), dtype=jnp.float32),) + _, variables = apply_grad_fn(variables, grad2, inp2, mutable=True) + + # In step 2 __call__, table update using grad0 and inp0 has been performed. + # The updated table is in variables['params'][_EMBED_PARAM] + + expected_grad_table_a = np.full( + (padded_vocab_a, padded_dim_a), _PAD_VALUE, dtype=np.float32 + ) + + for i, array in enumerate(embedding_variables['table_a'].table): + col_id = array[0] + if col_id != _PAD_VALUE: + new_col_id = col_id - 2 * (count_num(self.input_tensor, col_id) * 0.01) + expected_grad_table_a[i, :_DIM_A] = np.full( + (1, _DIM_A), new_col_id, dtype=np.float32 + ) + np.testing.assert_allclose( + variables['params'][_EMBED_PARAM].value['table_a'].table, + expected_grad_table_a, + rtol=1e-6, ) @parameterized.named_parameters( @@ -518,50 +667,23 @@ def test_forward_and_backward_with_table_stacking( for i, local_device in enumerate(devices) ] sharding = NamedSharding(mesh, P('x', None)) - embedding_variables['table_a_table_c'] = embedding.EmbeddingVariables( - table=jax.make_array_from_single_device_arrays( - shape=(stacked_vocab_size, padded_dim_a), - sharding=sharding, - arrays=embedding_variables['table_a_table_c'], - ), - slot=embedding_spec.SGDSlotVariables(), - ) - - var_spec = jax.eval_shape( - sc_module.init, - jax.random.PRNGKey(0), - embedding_lookup_input, - ) - - out_sharding = nn.get_sharding(var_spec, mesh) - - params = jax.jit( - sc_module.init, - in_shardings=( - NamedSharding(mesh, P()), - NamedSharding(mesh, P(sharding_axis)), - ), - out_shardings=out_sharding, - )( - jax.random.PRNGKey(0), - embedding_lookup_input, + embedding_variables['table_a_table_c'] = ( + embedding.EmbeddingVariables( + table=jax.make_array_from_single_device_arrays( + shape=(stacked_vocab_size, padded_dim_a), + sharding=sharding, + arrays=embedding_variables['table_a_table_c'], + ), + slot=embedding_spec.SGDSlotVariables(), + ) ) - # Replace the embedding variables in params with the ones we created. - def check_shape(a, b): - assert a.shape == b.shape - - jax.tree.map( - check_shape, - params['params'][_EMBED_PARAM].value, - embedding_variables, + variables = self._initialize_model_variables( + sc_module, embedding_lookup_input, core.freeze(embedding_variables) ) - params['params'][_EMBED_PARAM] = params['params'][ - _EMBED_PARAM - ].replace_boxed(embedding_variables) activations = jax.jit(sc_module.apply)( - params, + variables, embedding_lookup_input, ) @@ -606,7 +728,7 @@ def check_shape(a, b): params_updates = jax.jit( functools.partial(sc_module.apply, method=sc_module.apply_gradient), )( - params, + variables, activations_grad, embedding_lookup_input, ) @@ -614,9 +736,9 @@ def check_shape(a, b): # Updates params with the new embedding variables. assert len(params_updates) == 1 tree.assert_same_structure( - params_updates[_EMBED_PARAM], params['params'][_EMBED_PARAM].value + params_updates[_EMBED_PARAM], variables['params'][_EMBED_PARAM].value ) - params['params'] = params['params'] | params_updates + variables['params'] = variables['params'] | params_updates expected_grad_table_ac = np.full( (stacked_vocab_size, stacked_embedding_dim), @@ -624,7 +746,9 @@ def check_shape(a, b): dtype=np.float32, ) - for i, array in enumerate(embedding_variables['table_a_table_c'][0]): + for i, array in enumerate( + embedding_variables['table_a_table_c'].table + ): col_id = array[0] embedding_dim = _DIM_A if col_id < 200: @@ -639,7 +763,7 @@ def check_shape(a, b): ) np.testing.assert_equal( - params['params'][_EMBED_PARAM]['table_a_table_c'][0], + variables['params'][_EMBED_PARAM]['table_a_table_c'].table, expected_grad_table_ac, )