@@ -76,11 +76,6 @@ def round_up_dim_and_vocab_size(
7676 return table_to_padded_dim , tables_to_padded_vocab_size
7777
7878
79- # TODO(b/355289256): Consider better name for the stack.
80- def _get_stack_name (table_names : Sequence [str ]) -> str :
81- return "_" .join (table_names )
82-
83-
8479def _get_stack_table_names (
8580 tables : Mapping [str , embedding_spec .TableSpec ], num_sc : int
8681) -> Sequence [Sequence [str ]]:
@@ -96,12 +91,13 @@ def _get_stack_table_names(
9691
9792
9893def _verify_stack_tables (
94+ stack_name : str ,
9995 table_names : Sequence [str ],
96+ features : Sequence [embedding_spec .FeatureSpec ],
10097 tables : Mapping [str , embedding_spec .TableSpec ],
10198 table_to_padded_dim : Mapping [str , int ],
10299):
103100 """Verifies that the provided stacking groups are valid."""
104- stack_name = _get_stack_name (table_names )
105101 logging .vlog (
106102 2 ,
107103 "Verifying stack group: %s with tables: %s" ,
@@ -122,6 +118,15 @@ def _is_stacked_already(table: embedding_spec.TableSpec):
122118 f" { tables [tname ].setting_in_stack .stack_name } ."
123119 )
124120
121+ for feature in features :
122+ if (
123+ _is_stacked_already (feature .table_spec )
124+ and feature .table_spec .setting_in_stack .stack_name == stack_name
125+ ):
126+ raise ValueError (
127+ f"Cannot use stack name { stack_name } since it's already used."
128+ )
129+
125130 # A table should not be repeated in a group.
126131 counter = collections .Counter (table_names )
127132 for table , count in counter .items ():
@@ -214,6 +219,7 @@ def _get_limits_for_stack(
214219
215220
216221def _stack_feature_specs (
222+ stack_name : str ,
217223 features : Nested [embedding_spec .FeatureSpec ],
218224 table_names : Sequence [str ],
219225 table_to_padded_dim : Mapping [str , int ],
@@ -225,7 +231,6 @@ def _stack_feature_specs(
225231) -> None :
226232 """Updated the feature spec based on provided groups and stacking logic."""
227233
228- stack_name = _get_stack_name (table_names )
229234 table_name_to_feature_spec = {
230235 f .table_spec .name : f for f in tree .flatten (features )
231236 }
@@ -324,6 +329,7 @@ def stack_tables(
324329 num_sc_per_device : int = 4 ,
325330 stack_to_max_ids_per_partition : LimitsCallable = get_default_limits ,
326331 stack_to_max_unique_ids_per_partition : LimitsCallable = get_default_limits ,
332+ stack_name : str | None = None ,
327333) -> None :
328334 """Creates new feature specs based on specified stacking groups.
329335
@@ -341,7 +347,12 @@ def stack_tables(
341347 stack.
342348 stack_to_max_unique_ids_per_partition: Override the
343349 max_unique_ids_per_partition for each stack.
350+ stack_name: A unique name for the table stack. If None, a default name will
351+ be chosen.
344352 """
353+ if stack_name is None :
354+ # TODO(b/355289256): Consider better name for the stack.
355+ stack_name = "_" .join (table_names )
345356 flatten_features = tree .flatten (features )
346357 tables_in_group = {
347358 feature .table_spec .name : feature .table_spec
@@ -353,8 +364,15 @@ def stack_tables(
353364 tables_in_group , num_sc_per_device * global_device_count
354365 )
355366 )
356- _verify_stack_tables (table_names , tables_in_group , table_to_padded_dim )
367+ _verify_stack_tables (
368+ stack_name ,
369+ table_names ,
370+ flatten_features ,
371+ tables_in_group ,
372+ table_to_padded_dim ,
373+ )
357374 _stack_feature_specs (
375+ stack_name = stack_name ,
358376 features = features ,
359377 table_names = table_names ,
360378 table_to_padded_dim = table_to_padded_dim ,
0 commit comments