
Commit 424271f

[JAX SC] Implement pipelined SparseCore embedding layer in Flax
This change introduces `PipelinedSparseCoreEmbed` to enable pipelining of SparseCore operations with TensorCore. The layer stores inputs and activations/gradients across steps to allow for overlapping computation. The Shakespeare example is updated to use this new pipelined layer, requiring changes to how variables are handled in the training loop, including using `mutable=True` in `model.apply`. A test target for the Shakespeare example is also added.

With pipelining enabled, the model converges as shown below:

```
I 2025-09-28T13:41:31.529508-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 9: Loss = 7.518053
I 2025-09-28T13:41:31.599092-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 19: Loss = 7.317211
I 2025-09-28T13:41:31.668971-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 29: Loss = 7.0158334
I 2025-09-28T13:41:31.738482-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 39: Loss = 6.609592
I 2025-09-28T13:41:31.807324-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 49: Loss = 6.140646
I 2025-09-28T13:41:31.876329-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 59: Loss = 5.7092185
I 2025-09-28T13:41:31.949826-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 69: Loss = 5.3921995
I 2025-09-28T13:41:32.018904-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 79: Loss = 5.183298
I 2025-09-28T13:41:32.088231-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 89: Loss = 5.0565457
I 2025-09-28T13:41:32.156878-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 99: Loss = 4.9826536
...
I 2025-09-28T13:41:38.309954-07:00 4428 jax_sc_shakespeare_jit_flax.py:415] Step 999: Loss = 0.31050137
```

PiperOrigin-RevId: 812524634
1 parent e7ef80a commit 424271f
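
The commit message notes that the training loop must now handle mutable variables, passing `mutable=True` to `model.apply`. The toy module below is a minimal, self-contained sketch (hypothetical names, not the real `SparseCoreEmbed` layer or the Shakespeare example) of the Flax mechanism this relies on: pipeline state lives in a custom variable collection, the caller declares that collection mutable in `apply`, and the mutated state is merged back into the variables before the next step.

```python
# Minimal sketch of carrying pipeline state across steps via a mutable Flax
# collection. ToyPipelinedLayer is hypothetical and only mimics the one-step
# activation delay; it is not the real SparseCore layer.
import jax
import jax.numpy as jnp
from flax import linen as nn


class ToyPipelinedLayer(nn.Module):

  @nn.compact
  def __call__(self, x):
    prev = self.variable(
        'sparsecore_pipeline_state', 'prev_input', lambda: jnp.zeros_like(x)
    )
    out = prev.value   # value stored on the previous call (zeros at first)
    prev.value = x     # store the current input for the next step
    return out


layer = ToyPipelinedLayer()
x = jnp.ones((2, 3))
variables = layer.init(jax.random.PRNGKey(0), x)

for step in range(3):
  # Declare the pipeline-state collection mutable so the layer can update it.
  out, state = layer.apply(
      variables, x * (step + 1), mutable=['sparsecore_pipeline_state']
  )
  # Merge the mutated state back so the next step sees it.
  variables = {**variables, **state}
```

Per the commit message, the Shakespeare example passes `mutable=True`, which simply makes every collection mutable and works the same way.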

File tree

4 files changed (+725, -141 lines)


jax_tpu_embedding/sparsecore/examples/models/shakespeare/flax_model.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@
 from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 
+
 shard_map = jax.experimental.shard_map.shard_map
 Nested = embedding.Nested
 
@@ -38,6 +39,7 @@ class Model(nn.Module):
   feature_name: str = 'shakespeare_feature'
   mesh: jax.sharding.Mesh | None = None
   sharding_axis: str = 'sparsecore_sharding'
+  enable_pipelining: bool = False
 
   def add_sharding_constraint(self, x: jax.Array, names: tuple[str | None]):
     # Add a sharding constraint to the array.
@@ -66,6 +68,7 @@ def __call__(self, embedding_lookup_inputs: embedding.PreprocessedInput):
         feature_specs=self.feature_specs,
         mesh=self.mesh,
         sharding_axis=self.sharding_axis,
+        enable_pipelining=self.enable_pipelining,
     )(embedding_lookup_inputs)
 
     # Unpack the activations.
```

jax_tpu_embedding/sparsecore/lib/flax/embed.py

Lines changed: 225 additions & 5 deletions
```diff
@@ -16,10 +16,12 @@
 import functools
 from typing import Any, Callable, Mapping, TypeVar
 
+import flax
 from flax import linen as nn
 from flax import typing
 import jax
 from jax.experimental import layout
+import jax.numpy as jnp
 from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 from jax_tpu_embedding.sparsecore.utils import utils
@@ -69,7 +71,32 @@ def wrapper(*args, **kwargs):
 
 
 class SparseCoreEmbed(nn.Module):
-  """SparseCore embedding layer."""
+  """SparseCore embedding layer.
+
+  ## Pipelining
+
+  This layer supports pipelining of SparseCore computations if
+  `enable_pipelining` is set to True. Pipelining decouples the SC computation
+  from TC computation by processing multiple batches concurrently, stored in
+  internal state (variables). This allows for greater SC-TC overlap and
+  generally better performance at the cost of higher memory usage.
+  Convergence is, however, comparatively slower, which is tolerable in most
+  cases. See internal link:jax-sc-embedding-pipelining for more information.
+
+  When pipelining is enabled, the layer implements a two-stage pipeline:
+  embedding lookups for batch `i` run concurrently with TensorCore computations
+  for batch `i-1` and embedding gradient updates for batch `i-2`. This results
+  in activations being delayed by one step and gradient updates by two steps
+  relative to the inputs.
+
+  NOTE for pipelining:
+  * The first two steps return zero activations (warm-up), so the user needs to
+  run two additional steps. The dense inputs for the first (0) and last (N+1)
+  steps can be dummy inputs.
+  * If pipelining is enabled, the user has to pass
+  `mutable=['sparsecore_pipeline_state']` to `.apply()` to update the internal
+  pipeline state.
+  """
 
   # A sequence of FeatureSpecs to specify the configurations for the
   # input feature.
```
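
The docstring above spells out the two-stage schedule: at step `i` the SparseCore looks up batch `i`, the TensorCore consumes the activations of batch `i-1`, and the gradient update is applied for batch `i-2`. As a purely illustrative aid (not part of the commit), the snippet below prints that schedule, including the warm-up steps that return zero activations. The remainder of the `embed.py` diff follows.

```python
# Illustrative only: the per-step schedule implied by the docstring above.
def pipeline_schedule(num_steps):
  for i in range(num_steps):
    yield {
        'step': i,
        'sparsecore_lookup': f'batch {i}',
        'activations_returned': f'batch {i - 1}' if i >= 1 else 'zeros (warm-up)',
        'gradient_update_applied': f'batch {i - 2}' if i >= 2 else 'none (warm-up)',
    }


for row in pipeline_schedule(5):
  print(row)
```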
```diff
@@ -81,6 +108,7 @@ class SparseCoreEmbed(nn.Module):
   # Sharding strategy for embedding tables.
   table_sharding_strategy: str = 'MOD'
   enable_minibatching: bool = False
+  enable_pipelining: bool = False
 
   num_sc_per_device: int = -1  # Initialized in __post_init__.
 
@@ -165,6 +193,7 @@ def preprocess_inputs(
         all_reduce_interface=all_reduce_interface,
     )[0]
 
+  @nn.compact
   def __call__(
       self, embedding_lookup_inputs: EmbeddingLookupInput
   ) -> embedding.Nested[jax.Array]:
@@ -176,12 +205,19 @@ def __call__(
     Returns:
       The activations structure with the same structure as feature_specs.
     """
+    if self.enable_pipelining:
+      return self._pipelined_call(embedding_lookup_inputs)
     return _emb_lookup(
         self,
         embedding_lookup_inputs,
         self.embedding_table,
     )
 
+  def _get_embedding_table_path(self) -> str:
+    """Returns the path to the embedding table within the module."""
+    return '/'.join(self.path + (EMBEDDING_PARAM_NAME,))
+
+  @nn.compact
   def apply_gradient(
       self,
       gradients: embedding.Nested[jax.Array],
@@ -196,13 +232,131 @@ def apply_gradient(
     Returns:
       The updated activation embedding tables.
     """
+    if self.enable_pipelining:
+      return self._pipelined_apply_gradient(gradients, embedding_lookup_inputs)
     _, embed_table = _emb_lookup_bwd(
         self,
         (embedding_lookup_inputs, self.embedding_table),
         gradients,
     )
-    path = '/'.join(self.path + (EMBEDDING_PARAM_NAME,))
-    return {path: embed_table}
+    return {self._get_embedding_table_path(): embed_table}
+
+  ##############################################################################
+  # Pipelining implementation
+  ##############################################################################
+  def _pipelined_call(
+      self, embedding_lookup_inputs: embedding.PreprocessedInput
+  ) -> embedding.Nested[jax.Array]:
+    """Pipelined version of __call__."""
+    step_im1_sparse_activations = self._get_step_im1_sparse_activations()
+    step_im2_sparse_inputs = self._get_step_im2_sparse_inputs(
+        embedding_lookup_inputs
+    )
+    step_im2_sparse_gradients = self._get_step_im2_sparse_gradients()
+
+    # Update embedding table using values from step i-2.
+    # Perform the update using a custom_vjp to avoid differentiating the
+    # optimizer step.
+    updated_table_val = _perform_update(
+        self,
+        step_im2_sparse_inputs.value,
+        step_im2_sparse_gradients.value,
+        self.embedding_table,
+    )
+    if self.is_mutable_collection('params'):
+      self.scope.variables()['params'][
+          EMBEDDING_PARAM_NAME
+      ] = self.scope.variables()['params'][EMBEDDING_PARAM_NAME].replace_boxed(
+          updated_table_val
+      )
+
+    updated_table = updated_table_val
+
+    # The activations for the current step's forward pass are from step i-1.
+    result_activations = step_im1_sparse_activations.value
+
+    # Now, perform the lookup for the current step (i) using the newly updated
+    # embedding table and store the inputs and resulting activations for future
+    # steps.
+    self._get_step_im1_sparse_inputs(embedding_lookup_inputs).value = (
+        embedding_lookup_inputs
+    )
+    step_im1_sparse_activations.value = _emb_lookup(
+        self, embedding_lookup_inputs, updated_table
+    )
+
+    return result_activations
+
+  def _pipelined_apply_gradient(
+      self,
+      gradients: embedding.Nested[jax.Array],
+      embedding_lookup_inputs: embedding.PreprocessedInput,
+  ) -> Mapping[str, Mapping[str, jax.Array]]:
+    """Pipelined version of apply_gradient."""
+    step_im1_sparse_inputs = self._get_step_im1_sparse_inputs(
+        embedding_lookup_inputs
+    )
+    step_im2_sparse_inputs = self._get_step_im2_sparse_inputs(
+        embedding_lookup_inputs
+    )
+    step_im2_sparse_gradients = self._get_step_im2_sparse_gradients()
+
+    # Store sparse inputs and gradients for use on next step.
+    step_im2_sparse_inputs.value = step_im1_sparse_inputs.value
+    step_im2_sparse_gradients.value = gradients
+
+    return {}
+
+  def _get_unfrozen_feature_specs(
+      self,
+  ) -> embedding.Nested[embedding_spec.FeatureSpec]:
+    return (
+        flax.core.unfreeze(self.feature_specs)
+        if isinstance(self.feature_specs, flax.core.FrozenDict)
+        else self.feature_specs
+    )
+
+  def _get_step_im1_sparse_inputs(
+      self, embedding_lookup_inputs: embedding.PreprocessedInput
+  ) -> flax.core.scope.Variable[embedding.PreprocessedInput]:
+    return self.variable(
+        SPARSECORE_PIPELINE_STATE_COLLECTION,
+        'step_im1_sparse_inputs',
+        lambda: jax.tree.map(jnp.zeros_like, embedding_lookup_inputs),
+    )
+
+  def _get_step_im1_sparse_activations(
+      self,
+  ) -> flax.core.scope.Variable[embedding.Nested[jax.Array]]:
+    return self.variable(
+        SPARSECORE_PIPELINE_STATE_COLLECTION,
+        'step_im1_sparse_activations',
+        lambda: jax.tree.map(
+            lambda f: jnp.zeros(f.output_shape, dtype=jnp.float32),
+            self._get_unfrozen_feature_specs(),
+        ),
+    )
+
+  def _get_step_im2_sparse_inputs(
+      self, embedding_lookup_inputs: embedding.PreprocessedInput
+  ) -> flax.core.scope.Variable[embedding.PreprocessedInput]:
+    return self.variable(
+        SPARSECORE_PIPELINE_STATE_COLLECTION,
+        'step_im2_sparse_inputs',
+        lambda: jax.tree.map(jnp.zeros_like, embedding_lookup_inputs),
+    )
+
+  def _get_step_im2_sparse_gradients(
+      self,
+  ) -> flax.core.scope.Variable[embedding.Nested[jax.Array]]:
+    return self.variable(
+        SPARSECORE_PIPELINE_STATE_COLLECTION,
+        'step_im2_sparse_gradients',
+        lambda: jax.tree.map(
+            lambda f: jnp.zeros(f.output_shape, dtype=jnp.float32),
+            self._get_unfrozen_feature_specs(),
+        ),
+    )
 
 
 ################################################################################
@@ -255,7 +409,7 @@ def _emb_lookup_bwd(embedding_layer, res, gradients):
 
   pt = embedding_layer.embedding_table_partition
   pd = embedding_layer.data_partition
-  emb_table_grads = shard_map(
+  updated_emb_table = shard_map(
       functools.partial(
          embedding.tpu_sparse_dense_matmul_grad,
          feature_specs=embedding_layer.feature_specs,
@@ -272,11 +426,12 @@
       emb_table,
   )
 
+  # > Reinterpret emb table as grads (unused).
   # tpu_sparse_dense_matmul_grad returns a general Mapping (usually a dict).
   # It may not be the same type as the embedding table (e.g. FrozenDict).
   # Here we use flatten / unflatten to ensure the types are the same.
   emb_table_grads = jax.tree.unflatten(
-      jax.tree.structure(emb_table), jax.tree.leaves(emb_table_grads)
+      jax.tree.structure(emb_table), jax.tree.leaves(updated_emb_table)
   )
 
   return (
@@ -286,3 +441,68 @@
 
 
 _emb_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)
+
+
+SPARSECORE_PIPELINE_STATE_COLLECTION = 'sparsecore_pipeline_state'
+
+
+################################################################################
+# Define custom VJP for embedding update for pipelining.
+# This is used to prevent autodiff from differentiating through the optimizer
+# step.
+################################################################################
+@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
+def _perform_update(
+    module: 'SparseCoreEmbed',
+    im2_inputs: embedding.PreprocessedInput,
+    im2_grads: embedding.Nested[jax.Array],
+    emb_table: embedding.Nested[jax.Array],
+) -> embedding.Nested[jax.Array]:
+  """Performs the embedding update, but is opaque to autodiff."""
+  _, updated_table = _emb_lookup_bwd(  # pylint: disable=protected-access
+      module,
+      (im2_inputs, emb_table),
+      im2_grads,
+  )
+  return updated_table
+
+
+def _perform_update_fwd(
+    module: 'SparseCoreEmbed',
+    im2_inputs: embedding.PreprocessedInput,
+    im2_grads: embedding.Nested[jax.Array],
+    emb_table: embedding.Nested[jax.Array],
+):
+  """Forward pass for _perform_update."""
+  updated_table = _perform_update(module, im2_inputs, im2_grads, emb_table)
+  # Return inputs as residuals for backward pass.
+  return updated_table, (im2_inputs, im2_grads, emb_table)
+
+
+def _perform_update_bwd(
+    module: 'SparseCoreEmbed',
+    res: tuple[
+        embedding.PreprocessedInput,
+        embedding.Nested[jax.Array],
+        embedding.Nested[jax.Array],
+    ],
+    g: embedding.Nested[jax.Array],
+) -> tuple[
+    embedding.PreprocessedInput,
+    embedding.Nested[jax.Array],
+    embedding.Nested[jax.Array],
+]:
+  """Backward pass for _perform_update."""
+  # g is the gradient w.r.t. the output (updated_table).
+  # We want this to flow back to the original emb_table as if this function
+  # was an identity function.
+  im2_inputs, im2_grads, emb_table = res
+  del module, emb_table
+  return (
+      jax.tree.map(jnp.zeros_like, im2_inputs),
+      jax.tree.map(jnp.zeros_like, im2_grads),
+      g,  # Pass gradient through to the original embedding table.
+  )
+
+
+_perform_update.defvjp(_perform_update_fwd, _perform_update_bwd)
```
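
`_perform_update` above uses `jax.custom_vjp` so that the optimizer step applied inside the forward pass is invisible to autodiff: the backward pass returns zero cotangents for the stored inputs and gradients and passes the incoming cotangent straight through to the embedding table, as if the update were the identity. Below is a self-contained sketch of the same pattern; the toy SGD step stands in for the real SparseCore optimizer step, and all names are illustrative.

```python
# Minimal sketch of the custom_vjp "opaque update" pattern used by
# _perform_update. The toy SGD step here stands in for the real SparseCore
# optimizer step; the names are illustrative only.
import jax
import jax.numpy as jnp


@jax.custom_vjp
def opaque_update(stored_grads, table):
  # Forward: apply the update (stand-in for the embedding optimizer step).
  return table - 0.1 * stored_grads


def opaque_update_fwd(stored_grads, table):
  # Keep the stored gradients only to produce correctly shaped zero cotangents.
  return opaque_update(stored_grads, table), stored_grads


def opaque_update_bwd(stored_grads, g):
  # Backward: behave like the identity on `table`, returning zero cotangents
  # for the stored gradients and passing the incoming cotangent to the table.
  return jax.tree.map(jnp.zeros_like, stored_grads), g


opaque_update.defvjp(opaque_update_fwd, opaque_update_bwd)

table = jnp.ones((4, 3))
stored_grads = jnp.full((4, 3), 0.5)

# Autodiff never traces the optimizer arithmetic for differentiation; the
# cotangent of the updated table flows back to `table` unchanged.
loss = lambda t: jnp.sum(opaque_update(stored_grads, t) ** 2)
print(jax.grad(loss)(table))
```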
