Skip to content

Commit 917f42c

Browse files
Allow customizing the stack name when using table_stacking.stack_tables.
PiperOrigin-RevId: 715449105
1 parent 6b5bae7 commit 917f42c

File tree

2 files changed

+102
-8
lines changed

2 files changed

+102
-8
lines changed

jax_tpu_embedding/sparsecore/lib/nn/table_stacking.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,6 @@ def round_up_dim_and_vocab_size(
7676
return table_to_padded_dim, tables_to_padded_vocab_size
7777

7878

79-
# TODO(b/355289256): Consider better name for the stack.
80-
def _get_stack_name(table_names: Sequence[str]) -> str:
81-
return "_".join(table_names)
82-
83-
8479
def _get_stack_table_names(
8580
tables: Mapping[str, embedding_spec.TableSpec], num_sc: int
8681
) -> Sequence[Sequence[str]]:
@@ -96,12 +91,13 @@ def _get_stack_table_names(
9691

9792

9893
def _verify_stack_tables(
94+
stack_name: str,
9995
table_names: Sequence[str],
96+
features: Sequence[embedding_spec.FeatureSpec],
10097
tables: Mapping[str, embedding_spec.TableSpec],
10198
table_to_padded_dim: Mapping[str, int],
10299
):
103100
"""Verifies that the provided stacking groups are valid."""
104-
stack_name = _get_stack_name(table_names)
105101
logging.vlog(
106102
2,
107103
"Verifying stack group: %s with tables: %s",
@@ -122,6 +118,15 @@ def _is_stacked_already(table: embedding_spec.TableSpec):
122118
f" {tables[tname].setting_in_stack.stack_name}."
123119
)
124120

121+
for feature in features:
122+
if (
123+
_is_stacked_already(feature.table_spec)
124+
and feature.table_spec.setting_in_stack.stack_name == stack_name
125+
):
126+
raise ValueError(
127+
f"Cannot use stack name {stack_name} since it's already used."
128+
)
129+
125130
# A table should not be repeated in a group.
126131
counter = collections.Counter(table_names)
127132
for table, count in counter.items():
@@ -214,6 +219,7 @@ def _get_limits_for_stack(
214219

215220

216221
def _stack_feature_specs(
222+
stack_name: str,
217223
features: Nested[embedding_spec.FeatureSpec],
218224
table_names: Sequence[str],
219225
table_to_padded_dim: Mapping[str, int],
@@ -225,7 +231,6 @@ def _stack_feature_specs(
225231
) -> None:
226232
"""Updated the feature spec based on provided groups and stacking logic."""
227233

228-
stack_name = _get_stack_name(table_names)
229234
table_name_to_feature_spec = {
230235
f.table_spec.name: f for f in tree.flatten(features)
231236
}
@@ -324,6 +329,7 @@ def stack_tables(
324329
num_sc_per_device: int = 4,
325330
stack_to_max_ids_per_partition: LimitsCallable = get_default_limits,
326331
stack_to_max_unique_ids_per_partition: LimitsCallable = get_default_limits,
332+
stack_name: str | None = None,
327333
) -> None:
328334
"""Creates new feature specs based on specified stacking groups.
329335
@@ -341,7 +347,12 @@ def stack_tables(
341347
stack.
342348
stack_to_max_unique_ids_per_partition: Override the
343349
max_unique_ids_per_partition for each stack.
350+
stack_name: A unique name for the table stack. If None, a default name will
351+
be chosen.
344352
"""
353+
if stack_name is None:
354+
# TODO(b/355289256): Consider better name for the stack.
355+
stack_name = "_".join(table_names)
345356
flatten_features = tree.flatten(features)
346357
tables_in_group = {
347358
feature.table_spec.name: feature.table_spec
@@ -353,8 +364,15 @@ def stack_tables(
353364
tables_in_group, num_sc_per_device * global_device_count
354365
)
355366
)
356-
_verify_stack_tables(table_names, tables_in_group, table_to_padded_dim)
367+
_verify_stack_tables(
368+
stack_name,
369+
table_names,
370+
flatten_features,
371+
tables_in_group,
372+
table_to_padded_dim,
373+
)
357374
_stack_feature_specs(
375+
stack_name=stack_name,
358376
features=features,
359377
table_names=table_names,
360378
table_to_padded_dim=table_to_padded_dim,

jax_tpu_embedding/sparsecore/lib/nn/tests/table_stacking_test.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,82 @@ def limits(name: str, batch_size: int) -> int:
371371
0.5
372372
)
373373

374+
def test_manual_stacking_reuse_table_name(self):
375+
table_spec_a = embedding_spec.TableSpec(
376+
vocabulary_size=64,
377+
embedding_dim=12,
378+
initializer=lambda: jnp.zeros((64, 16), dtype=jnp.float32),
379+
optimizer=embedding_spec.SGDOptimizerSpec(learning_rate=0.5),
380+
combiner='sum',
381+
name='A',
382+
max_ids_per_partition=16,
383+
max_unique_ids_per_partition=16,
384+
)
385+
table_spec_b = embedding_spec.TableSpec(
386+
vocabulary_size=120,
387+
embedding_dim=10,
388+
initializer=lambda: jnp.zeros((128, 16), dtype=jnp.float32),
389+
optimizer=embedding_spec.SGDOptimizerSpec(learning_rate=0.5),
390+
combiner='sum',
391+
name='B',
392+
max_ids_per_partition=16,
393+
max_unique_ids_per_partition=16,
394+
)
395+
table_spec_c = embedding_spec.TableSpec(
396+
vocabulary_size=120,
397+
embedding_dim=10,
398+
initializer=lambda: jnp.zeros((128, 16), dtype=jnp.float32),
399+
optimizer=embedding_spec.SGDOptimizerSpec(learning_rate=0.5),
400+
combiner='sum',
401+
name='C',
402+
max_ids_per_partition=16,
403+
max_unique_ids_per_partition=16,
404+
)
405+
feature_specs = [
406+
embedding_spec.FeatureSpec(
407+
table_spec=table_spec_a,
408+
input_shape=(16, 1),
409+
output_shape=(
410+
16,
411+
table_spec_a.embedding_dim,
412+
),
413+
name='feature_a',
414+
),
415+
embedding_spec.FeatureSpec(
416+
table_spec=table_spec_b,
417+
input_shape=(16, 1),
418+
output_shape=(
419+
16,
420+
table_spec_b.embedding_dim,
421+
),
422+
name='feature_b',
423+
),
424+
embedding_spec.FeatureSpec(
425+
table_spec=table_spec_c,
426+
input_shape=(16, 1),
427+
output_shape=(
428+
16,
429+
table_spec_c.embedding_dim,
430+
),
431+
name='feature_c',
432+
),
433+
]
434+
435+
table_stacking.stack_tables(
436+
feature_specs,
437+
('A', 'B'),
438+
global_device_count=jax.device_count(),
439+
stack_name='custom_stack',
440+
)
441+
442+
with self.assertRaisesRegex(ValueError, 'custom_stack.*already used.*'):
443+
table_stacking.stack_tables(
444+
feature_specs,
445+
('C',),
446+
global_device_count=jax.device_count(),
447+
stack_name='custom_stack',
448+
)
449+
374450
def test_manual_stacking_not_same_optimizer(self):
375451
table_spec_a = embedding_spec.TableSpec(
376452
vocabulary_size=64,

0 commit comments

Comments
 (0)