import functools
from typing import Any, Callable, Mapping, TypeVar

+ import flax
from flax import linen as nn
from flax import typing
import jax
from jax.experimental import layout
+ import jax.numpy as jnp
from jax_tpu_embedding.sparsecore.lib.nn import embedding
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
from jax_tpu_embedding.sparsecore.utils import utils
@@ -69,7 +71,32 @@ def wrapper(*args, **kwargs):


class SparseCoreEmbed(nn.Module):
-   """SparseCore embedding layer."""
+   """SparseCore embedding layer.
+
+   ## Pipelining
+
+   This layer supports pipelining of SparseCore computations when
+   `enable_pipelining` is set to True. Pipelining decouples the SC computation
+   from the TC computation by processing multiple batches concurrently, stored
+   in internal state (variables). This allows for greater SC-TC overlap and
+   generally better performance at the cost of higher memory usage.
+   Convergence is comparatively slower, which is tolerable in most cases. See
+   internal link:jax-sc-embedding-pipelining for more information.
+
+   When pipelining is enabled, the layer implements a two-stage pipeline:
+   embedding lookups for batch `i` run concurrently with TensorCore
+   computations for batch `i-1` and embedding gradient updates for batch
+   `i-2`. As a result, activations are delayed by one step and gradient
+   updates by two steps relative to the inputs.
+
+   NOTE for pipelining:
+   * The first two steps return zero activations (warm-up), so the user needs
+     to run two additional steps. The dense inputs for the first (0) and last
+     (N+1) steps can be dummy inputs.
+   * The user must pass `mutable=['sparsecore_pipeline_state']` to `.apply()`
+     so that the internal pipeline state is updated.
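+
+   Example (an illustrative usage sketch; `specs`, `variables`, and
+   `preprocessed_batch` are placeholder names, not part of this module):
+
+     layer = SparseCoreEmbed(feature_specs=specs, enable_pipelining=True)
+     # Pipeline state must be declared mutable so each step can update it.
+     activations, updated_state = layer.apply(
+         variables,
+         preprocessed_batch,
+         mutable=['sparsecore_pipeline_state'],
+     )
+     # The first two steps return zero activations (pipeline warm-up).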
+   """

  # A sequence of FeatureSpecs to specify the configurations for the
  # input feature.
@@ -81,6 +108,7 @@ class SparseCoreEmbed(nn.Module):
  # Sharding strategy for embedding tables.
  table_sharding_strategy: str = 'MOD'
  enable_minibatching: bool = False
+   enable_pipelining: bool = False

  num_sc_per_device: int = -1  # Initialized in __post_init__.
@@ -165,6 +193,7 @@ def preprocess_inputs(
        all_reduce_interface=all_reduce_interface,
    )[0]

+   @nn.compact
  def __call__(
      self, embedding_lookup_inputs: EmbeddingLookupInput
  ) -> embedding.Nested[jax.Array]:
@@ -176,12 +205,19 @@ def __call__(
    Returns:
      The activations structure with the same structure as feature_specs.
    """
+     if self.enable_pipelining:
+       return self._pipelined_call(embedding_lookup_inputs)
    return _emb_lookup(
        self,
        embedding_lookup_inputs,
        self.embedding_table,
    )

+   def _get_embedding_table_path(self) -> str:
+     """Returns the path to the embedding table within the module."""
+     return '/'.join(self.path + (EMBEDDING_PARAM_NAME,))
+
+   @nn.compact
  def apply_gradient(
      self,
      gradients: embedding.Nested[jax.Array],
@@ -196,13 +232,131 @@ def apply_gradient(
    Returns:
      The updated activation embedding tables.
    """
+     if self.enable_pipelining:
+       return self._pipelined_apply_gradient(gradients, embedding_lookup_inputs)
    _, embed_table = _emb_lookup_bwd(
        self,
        (embedding_lookup_inputs, self.embedding_table),
        gradients,
    )
-     path = '/'.join(self.path + (EMBEDDING_PARAM_NAME,))
-     return {path: embed_table}
+     return {self._get_embedding_table_path(): embed_table}
+
+   ############################################################################
+   # Pipelining implementation
+   ############################################################################
+   def _pipelined_call(
+       self, embedding_lookup_inputs: embedding.PreprocessedInput
+   ) -> embedding.Nested[jax.Array]:
+     """Pipelined version of __call__."""
+     step_im1_sparse_activations = self._get_step_im1_sparse_activations()
+     step_im2_sparse_inputs = self._get_step_im2_sparse_inputs(
+         embedding_lookup_inputs
+     )
+     step_im2_sparse_gradients = self._get_step_im2_sparse_gradients()
+
+     # Update embedding table using values from step i-2.
+     # Perform the update using a custom_vjp to avoid differentiating the
+     # optimizer step.
+     updated_table_val = _perform_update(
+         self,
+         step_im2_sparse_inputs.value,
+         step_im2_sparse_gradients.value,
+         self.embedding_table,
+     )
+     if self.is_mutable_collection('params'):
+       self.scope.variables()['params'][
+           EMBEDDING_PARAM_NAME
+       ] = self.scope.variables()['params'][EMBEDDING_PARAM_NAME].replace_boxed(
+           updated_table_val
+       )
+
+     updated_table = updated_table_val
+
+     # The activations for the current step's forward pass are from step i-1.
+     result_activations = step_im1_sparse_activations.value
+
+     # Now, perform the lookup for the current step (i) using the newly updated
+     # embedding table and store the inputs and resulting activations for
+     # future steps.
+     self._get_step_im1_sparse_inputs(embedding_lookup_inputs).value = (
+         embedding_lookup_inputs
+     )
+     step_im1_sparse_activations.value = _emb_lookup(
+         self, embedding_lookup_inputs, updated_table
+     )
+
+     return result_activations
+
+   def _pipelined_apply_gradient(
+       self,
+       gradients: embedding.Nested[jax.Array],
+       embedding_lookup_inputs: embedding.PreprocessedInput,
+   ) -> Mapping[str, Mapping[str, jax.Array]]:
+     """Pipelined version of apply_gradient."""
+     step_im1_sparse_inputs = self._get_step_im1_sparse_inputs(
+         embedding_lookup_inputs
+     )
+     step_im2_sparse_inputs = self._get_step_im2_sparse_inputs(
+         embedding_lookup_inputs
+     )
+     step_im2_sparse_gradients = self._get_step_im2_sparse_gradients()
+
+     # Store sparse inputs and gradients for use on the next step.
+     step_im2_sparse_inputs.value = step_im1_sparse_inputs.value
+     step_im2_sparse_gradients.value = gradients
+
+     return {}
+
+   def _get_unfrozen_feature_specs(
+       self,
+   ) -> embedding.Nested[embedding_spec.FeatureSpec]:
+     return (
+         flax.core.unfreeze(self.feature_specs)
+         if isinstance(self.feature_specs, flax.core.FrozenDict)
+         else self.feature_specs
+     )
+
+   def _get_step_im1_sparse_inputs(
+       self, embedding_lookup_inputs: embedding.PreprocessedInput
+   ) -> flax.core.scope.Variable[embedding.PreprocessedInput]:
+     return self.variable(
+         SPARSECORE_PIPELINE_STATE_COLLECTION,
+         'step_im1_sparse_inputs',
+         lambda: jax.tree.map(jnp.zeros_like, embedding_lookup_inputs),
+     )
+
+   def _get_step_im1_sparse_activations(
+       self,
+   ) -> flax.core.scope.Variable[embedding.Nested[jax.Array]]:
+     return self.variable(
+         SPARSECORE_PIPELINE_STATE_COLLECTION,
+         'step_im1_sparse_activations',
+         lambda: jax.tree.map(
+             lambda f: jnp.zeros(f.output_shape, dtype=jnp.float32),
+             self._get_unfrozen_feature_specs(),
+         ),
+     )
+
+   def _get_step_im2_sparse_inputs(
+       self, embedding_lookup_inputs: embedding.PreprocessedInput
+   ) -> flax.core.scope.Variable[embedding.PreprocessedInput]:
+     return self.variable(
+         SPARSECORE_PIPELINE_STATE_COLLECTION,
+         'step_im2_sparse_inputs',
+         lambda: jax.tree.map(jnp.zeros_like, embedding_lookup_inputs),
+     )
+
+   def _get_step_im2_sparse_gradients(
+       self,
+   ) -> flax.core.scope.Variable[embedding.Nested[jax.Array]]:
+     return self.variable(
+         SPARSECORE_PIPELINE_STATE_COLLECTION,
+         'step_im2_sparse_gradients',
+         lambda: jax.tree.map(
+             lambda f: jnp.zeros(f.output_shape, dtype=jnp.float32),
+             self._get_unfrozen_feature_specs(),
+         ),
+     )


################################################################################
@@ -255,7 +409,7 @@ def _emb_lookup_bwd(embedding_layer, res, gradients):

  pt = embedding_layer.embedding_table_partition
  pd = embedding_layer.data_partition
-   emb_table_grads = shard_map(
+   updated_emb_table = shard_map(
      functools.partial(
          embedding.tpu_sparse_dense_matmul_grad,
          feature_specs=embedding_layer.feature_specs,
@@ -272,11 +426,12 @@ def _emb_lookup_bwd(embedding_layer, res, gradients):
      emb_table,
  )

+   # > Reinterpret emb table as grads (unused).
  # tpu_sparse_dense_matmul_grad returns a general Mapping (usually a dict).
  # It may not be the same type as the embedding table (e.g. FrozenDict).
  # Here we use flatten / unflatten to ensure the types are the same.
  emb_table_grads = jax.tree.unflatten(
-     jax.tree.structure(emb_table), jax.tree.leaves(emb_table_grads)
+     jax.tree.structure(emb_table), jax.tree.leaves(updated_emb_table)
  )

  return (
@@ -286,3 +441,68 @@ def _emb_lookup_bwd(embedding_layer, res, gradients):


_emb_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)
+
+
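+ # Name of the Flax variable collection that stores the internal pipeline
+ # state. When pipelining is enabled, callers pass it to `.apply()` via
+ # `mutable=['sparsecore_pipeline_state']` so this state can be updated.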
+ SPARSECORE_PIPELINE_STATE_COLLECTION = 'sparsecore_pipeline_state'
+
+
+ ################################################################################
+ # Define custom VJP for embedding update for pipelining.
+ # This is used to prevent autodiff from differentiating through the optimizer
+ # step.
+ ################################################################################
+ @functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
+ def _perform_update(
+     module: 'SparseCoreEmbed',
+     im2_inputs: embedding.PreprocessedInput,
+     im2_grads: embedding.Nested[jax.Array],
+     emb_table: embedding.Nested[jax.Array],
+ ) -> embedding.Nested[jax.Array]:
+   """Performs the embedding update, but is opaque to autodiff."""
+   _, updated_table = _emb_lookup_bwd(  # pylint: disable=protected-access
+       module,
+       (im2_inputs, emb_table),
+       im2_grads,
+   )
+   return updated_table
+
+
+ def _perform_update_fwd(
+     module: 'SparseCoreEmbed',
+     im2_inputs: embedding.PreprocessedInput,
+     im2_grads: embedding.Nested[jax.Array],
+     emb_table: embedding.Nested[jax.Array],
+ ):
+   """Forward pass for _perform_update."""
+   updated_table = _perform_update(module, im2_inputs, im2_grads, emb_table)
+   # Return inputs as residuals for backward pass.
+   return updated_table, (im2_inputs, im2_grads, emb_table)
+
+
+ def _perform_update_bwd(
+     module: 'SparseCoreEmbed',
+     res: tuple[
+         embedding.PreprocessedInput,
+         embedding.Nested[jax.Array],
+         embedding.Nested[jax.Array],
+     ],
+     g: embedding.Nested[jax.Array],
+ ) -> tuple[
+     embedding.PreprocessedInput,
+     embedding.Nested[jax.Array],
+     embedding.Nested[jax.Array],
+ ]:
+   """Backward pass for _perform_update."""
+   # g is the gradient w.r.t. the output (updated_table).
+   # We want this to flow back to the original emb_table as if this function
+   # was an identity function.
+   im2_inputs, im2_grads, emb_table = res
+   del module, emb_table
+   return (
+       jax.tree.map(jnp.zeros_like, im2_inputs),
+       jax.tree.map(jnp.zeros_like, im2_grads),
+       g,  # Pass gradient through to the original embedding table.
+   )
+
+
+ _perform_update.defvjp(_perform_update_fwd, _perform_update_bwd)