@@ -20,6 +20,7 @@
from jax_tpu_embedding.sparsecore.lib.nn import embedding
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec


shard_map = jax.experimental.shard_map.shard_map
Nested = embedding.Nested

@@ -38,6 +39,7 @@ class Model(nn.Module):
feature_name: str = 'shakespeare_feature'
mesh: jax.sharding.Mesh | None = None
sharding_axis: str = 'sparsecore_sharding'
enable_pipelining: bool = False

def add_sharding_constraint(self, x: jax.Array, names: tuple[str | None]):
# Add a sharding constraint to the array.
@@ -66,6 +68,7 @@ def __call__(self, embedding_lookup_inputs: embedding.PreprocessedInput):
feature_specs=self.feature_specs,
mesh=self.mesh,
sharding_axis=self.sharding_axis,
enable_pipelining=self.enable_pipelining,
)(embedding_lookup_inputs)

# Unpack the activations.
230 changes: 225 additions & 5 deletions jax_tpu_embedding/sparsecore/lib/flax/embed.py
@@ -16,10 +16,12 @@
import functools
from typing import Any, Callable, Mapping, TypeVar

import flax
from flax import linen as nn
from flax import typing
import jax
from jax.experimental import layout
import jax.numpy as jnp
from jax_tpu_embedding.sparsecore.lib.nn import embedding
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
from jax_tpu_embedding.sparsecore.utils import utils
@@ -69,7 +71,32 @@ def wrapper(*args, **kwargs):


class SparseCoreEmbed(nn.Module):
"""SparseCore embedding layer."""
"""SparseCore embedding layer.

## Pipelining

This layer supports pipelining of SparseCore computations when
`enable_pipelining` is set to True. Pipelining decouples the SparseCore (SC)
computation from the TensorCore (TC) computation by processing multiple
batches concurrently, keeping the in-flight batches in internal state
(variables). This allows for greater SC-TC overlap and generally better
performance at the cost of higher memory usage. It also converges
comparatively more slowly, which is tolerable in most cases.
See internal link:jax-sc-embedding-pipelining for more information.

When pipelining is enabled, the layer implements a two-stage pipeline:
embedding
lookups for batch `i` run concurrently with TensorCore computations for batch
`i-1` and embedding gradient updates for batch `i-2`. This results in
activations being delayed by one step and gradient updates by two steps
relative to the inputs.
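
Schematically, the work performed concurrently at step `i` is:

  SparseCore forward (lookup):     batch `i`
  TensorCore forward/backward:     batch `i-1`
  SparseCore backward (update):    batch `i-2`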

NOTE for pipelining:
* The first two steps return zero activations (warm-up), so the user needs to
run two additional steps. The dense inputs for the first (0) and last (N+1)
steps can be dummy inputs.
* When pipelining is enabled, the user must pass
`mutable=['sparsecore_pipeline_state']` to `.apply()` so that the internal
pipeline state is updated.
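
Example usage (a minimal sketch; `layer`, `variables`, and `preprocessed` are
illustrative names, not part of this API):

  variables = layer.init(jax.random.PRNGKey(0), preprocessed)
  activations, mutated = layer.apply(
      variables,
      preprocessed,
      mutable=['sparsecore_pipeline_state'],
  )
  # Carry the updated pipeline state into the next step.
  variables = {**variables, **mutated}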
"""

# A sequence of FeatureSpecs to specify the configurations for the
# input feature.
@@ -81,6 +108,7 @@ class SparseCoreEmbed(nn.Module):
# Sharding strategy for embedding tables.
table_sharding_strategy: str = 'MOD'
enable_minibatching: bool = False
enable_pipelining: bool = False

num_sc_per_device: int = -1 # Initialized in __post_init__.

@@ -165,6 +193,7 @@ def preprocess_inputs(
all_reduce_interface=all_reduce_interface,
)[0]

@nn.compact
def __call__(
self, embedding_lookup_inputs: EmbeddingLookupInput
) -> embedding.Nested[jax.Array]:
@@ -176,12 +205,19 @@ def __call__(
Returns:
The activations, with the same structure as feature_specs.
"""
if self.enable_pipelining:
return self._pipelined_call(embedding_lookup_inputs)
return _emb_lookup(
self,
embedding_lookup_inputs,
self.embedding_table,
)

def _get_embedding_table_path(self) -> str:
"""Returns the path to the embedding table within the module."""
return '/'.join(self.path + (EMBEDDING_PARAM_NAME,))

@nn.compact
def apply_gradient(
self,
gradients: embedding.Nested[jax.Array],
@@ -196,13 +232,131 @@ def apply_gradient(
Returns:
The updated embedding tables.
"""
if self.enable_pipelining:
return self._pipelined_apply_gradient(gradients, embedding_lookup_inputs)
_, embed_table = _emb_lookup_bwd(
self,
(embedding_lookup_inputs, self.embedding_table),
gradients,
)
path = '/'.join(self.path + (EMBEDDING_PARAM_NAME,))
return {path: embed_table}
return {self._get_embedding_table_path(): embed_table}

##############################################################################
# Pipelining implementation
##############################################################################
def _pipelined_call(
self, embedding_lookup_inputs: embedding.PreprocessedInput
) -> embedding.Nested[jax.Array]:
"""Pipelined version of __call__."""
step_im1_sparse_activations = self._get_step_im1_sparse_activations()
step_im2_sparse_inputs = self._get_step_im2_sparse_inputs(
embedding_lookup_inputs
)
step_im2_sparse_gradients = self._get_step_im2_sparse_gradients()

# Update embedding table using values from step i-2.
# Perform the update using a custom_vjp to avoid differentiating the
# optimizer step.
updated_table_val = _perform_update(
self,
step_im2_sparse_inputs.value,
step_im2_sparse_gradients.value,
self.embedding_table,
)
if self.is_mutable_collection('params'):
self.scope.variables()['params'][
EMBEDDING_PARAM_NAME
] = self.scope.variables()['params'][EMBEDDING_PARAM_NAME].replace_boxed(
updated_table_val
)

updated_table = updated_table_val

# The activations for the current step's forward pass are from step i-1.
result_activations = step_im1_sparse_activations.value

# Now, perform the lookup for the current step (i) using the newly updated
# embedding table and store the inputs and resulting activations for future
# steps.
self._get_step_im1_sparse_inputs(embedding_lookup_inputs).value = (
embedding_lookup_inputs
)
step_im1_sparse_activations.value = _emb_lookup(
self, embedding_lookup_inputs, updated_table
)

return result_activations

def _pipelined_apply_gradient(
self,
gradients: embedding.Nested[jax.Array],
embedding_lookup_inputs: embedding.PreprocessedInput,
) -> Mapping[str, Mapping[str, jax.Array]]:
"""Pipelined version of apply_gradient."""
step_im1_sparse_inputs = self._get_step_im1_sparse_inputs(
embedding_lookup_inputs
)
step_im2_sparse_inputs = self._get_step_im2_sparse_inputs(
embedding_lookup_inputs
)
step_im2_sparse_gradients = self._get_step_im2_sparse_gradients()

# Store sparse inputs and gradients for use on next step.
step_im2_sparse_inputs.value = step_im1_sparse_inputs.value
step_im2_sparse_gradients.value = gradients

return {}

def _get_unfrozen_feature_specs(
self,
) -> embedding.Nested[embedding_spec.FeatureSpec]:
return (
flax.core.unfreeze(self.feature_specs)
if isinstance(self.feature_specs, flax.core.FrozenDict)
else self.feature_specs
)

def _get_step_im1_sparse_inputs(
self, embedding_lookup_inputs: embedding.PreprocessedInput
) -> flax.core.scope.Variable[embedding.PreprocessedInput]:
return self.variable(
SPARSECORE_PIPELINE_STATE_COLLECTION,
'step_im1_sparse_inputs',
lambda: jax.tree.map(jnp.zeros_like, embedding_lookup_inputs),
)

def _get_step_im1_sparse_activations(
self,
) -> flax.core.scope.Variable[embedding.Nested[jax.Array]]:
return self.variable(
SPARSECORE_PIPELINE_STATE_COLLECTION,
'step_im1_sparse_activations',
lambda: jax.tree.map(
lambda f: jnp.zeros(f.output_shape, dtype=jnp.float32),
self._get_unfrozen_feature_specs(),
),
)

def _get_step_im2_sparse_inputs(
self, embedding_lookup_inputs: embedding.PreprocessedInput
) -> flax.core.scope.Variable[embedding.PreprocessedInput]:
return self.variable(
SPARSECORE_PIPELINE_STATE_COLLECTION,
'step_im2_sparse_inputs',
lambda: jax.tree.map(jnp.zeros_like, embedding_lookup_inputs),
)

def _get_step_im2_sparse_gradients(
self,
) -> flax.core.scope.Variable[embedding.Nested[jax.Array]]:
return self.variable(
SPARSECORE_PIPELINE_STATE_COLLECTION,
'step_im2_sparse_gradients',
lambda: jax.tree.map(
lambda f: jnp.zeros(f.output_shape, dtype=jnp.float32),
self._get_unfrozen_feature_specs(),
),
)


################################################################################
@@ -255,7 +409,7 @@ def _emb_lookup_bwd(embedding_layer, res, gradients):

pt = embedding_layer.embedding_table_partition
pd = embedding_layer.data_partition
emb_table_grads = shard_map(
updated_emb_table = shard_map(
functools.partial(
embedding.tpu_sparse_dense_matmul_grad,
feature_specs=embedding_layer.feature_specs,
@@ -272,11 +426,12 @@ def _emb_lookup_bwd(embedding_layer, res, gradients):
emb_table,
)

# > Reinterpret emb table as grads (unused).
# tpu_sparse_dense_matmul_grad returns a general Mapping (usually a dict).
# It may not be the same type as the embedding table (e.g. FrozenDict).
# Here we use flatten / unflatten to ensure the types are the same.
emb_table_grads = jax.tree.unflatten(
jax.tree.structure(emb_table), jax.tree.leaves(emb_table_grads)
jax.tree.structure(emb_table), jax.tree.leaves(updated_emb_table)
)

return (
Expand All @@ -286,3 +441,68 @@ def _emb_lookup_bwd(embedding_layer, res, gradients):


_emb_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)


SPARSECORE_PIPELINE_STATE_COLLECTION = 'sparsecore_pipeline_state'


################################################################################
# Define custom VJP for embedding update for pipelining.
# This is used to prevent autodiff from differentiating through the optimizer
# step.
################################################################################
@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def _perform_update(
module: 'SparseCoreEmbed',
im2_inputs: embedding.PreprocessedInput,
im2_grads: embedding.Nested[jax.Array],
emb_table: embedding.Nested[jax.Array],
) -> embedding.Nested[jax.Array]:
"""Performs the embedding update, but is opaque to autodiff."""
_, updated_table = _emb_lookup_bwd( # pylint: disable=protected-access
module,
(im2_inputs, emb_table),
im2_grads,
)
return updated_table


def _perform_update_fwd(
module: 'SparseCoreEmbed',
im2_inputs: embedding.PreprocessedInput,
im2_grads: embedding.Nested[jax.Array],
emb_table: embedding.Nested[jax.Array],
):
"""Forward pass for _perform_update."""
updated_table = _perform_update(module, im2_inputs, im2_grads, emb_table)
# Return inputs as residuals for backward pass.
return updated_table, (im2_inputs, im2_grads, emb_table)


def _perform_update_bwd(
module: 'SparseCoreEmbed',
res: tuple[
embedding.PreprocessedInput,
embedding.Nested[jax.Array],
embedding.Nested[jax.Array],
],
g: embedding.Nested[jax.Array],
) -> tuple[
embedding.PreprocessedInput,
embedding.Nested[jax.Array],
embedding.Nested[jax.Array],
]:
"""Backward pass for _perform_update."""
# g is the gradient w.r.t. the output (updated_table).
# We want this to flow back to the original emb_table as if this function
# was an identity function.
im2_inputs, im2_grads, emb_table = res
del module, emb_table
return (
jax.tree.map(jnp.zeros_like, im2_inputs),
jax.tree.map(jnp.zeros_like, im2_grads),
g, # Pass gradient through to the original embedding table.
)


_perform_update.defvjp(_perform_update_fwd, _perform_update_bwd)
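

# The snippet below is an illustrative sketch, not part of this module: it is
# the same custom_vjp pattern as `_perform_update`, reduced to a toy update so
# the gradient routing is easy to see (it reuses the module's `jax`/`jnp`
# imports; the `_toy_*` names are hypothetical). Autodiff never differentiates
# through the update itself; the backward pass sends the output cotangent
# straight to the original parameter, as if the update were the identity
# function.
@jax.custom_vjp
def _toy_opaque_update(grads, params):
  return params - 0.1 * grads  # stands in for the optimizer step


def _toy_opaque_update_fwd(grads, params):
  return _toy_opaque_update(grads, params), None  # no residuals needed


def _toy_opaque_update_bwd(_, g):
  # Zero cotangent for `grads`; pass the output cotangent through to `params`.
  return jnp.zeros_like(g), g


_toy_opaque_update.defvjp(_toy_opaque_update_fwd, _toy_opaque_update_bwd)

# Gradients of a loss taken through `_toy_opaque_update` therefore treat it as
# the identity in `params` and as a constant in `grads`.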