@@ -7,6 +7,7 @@
 import keras
 import numpy as np
 from jax import numpy as jnp
+from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
 from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec
@@ -18,6 +19,8 @@
 ArrayLike: TypeAlias = Union[jax.Array, np.ndarray[Any, Any]]
 Shape: TypeAlias = tuple[int, ...]
 
+EmbeddingTable: TypeAlias = ArrayLike | embedding.EmbeddingVariables
+
 
 def has_sparsecores() -> bool:
   device_kind = jax.devices()[0].device_kind
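
Note: judging from how the new type is used later in this diff (constructed positionally, read through `.table` and `.slot`), `embedding.EmbeddingVariables` behaves like a NamedTuple pairing a weight table with its optimizer slot variables. A minimal stand-in sketch with illustrative shapes; the class below is hypothetical and only mirrors the observed usage:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EmbeddingVariables(NamedTuple):
  # Hypothetical stand-in mirroring the usage in this diff; the real
  # class lives in jax_tpu_embedding.sparsecore.lib.nn.embedding.
  table: jax.Array             # (vocab_size, embedding_dim) weights
  slot: tuple[jax.Array, ...]  # optimizer state, e.g. an Adagrad accumulator


variables = EmbeddingVariables(
    table=jnp.zeros((16, 4)),    # illustrative 16 x 4 table
    slot=(jnp.zeros((16, 4)),),  # one slot variable
)
print(variables.table.shape)  # (16, 4)
```
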
@@ -142,7 +145,7 @@ def create_tables(
 def create_table_and_slot_variables(
     table_specs: Nested[TableSpec],
     keys: Nested[ArrayLike] | None = None,
-) -> Nested[ArrayLike]:
+) -> Nested[embedding.EmbeddingVariables]:
   """Creates and initializes embedding tables and slot variables.
 
   Args:
@@ -164,7 +167,7 @@ def create_table_and_slot_variables(
   def _create_table_and_slot_variables(
       table_spec: TableSpec,
       key: ArrayLike,
-  ) -> tuple[jax.Array, tuple[jax.Array, ...]]:
+  ) -> embedding.EmbeddingVariables:
     slot_initializers = table_spec.optimizer.slot_variables_initializers()
     num_slot_variables = len(keras.tree.flatten(slot_initializers))
     slot_keys = jnp.unstack(jax.random.split(key, num_slot_variables))
@@ -178,10 +181,10 @@ def _create_table_and_slot_variables(
         slot_initializers,
         slot_keys,
     )
-    return (table, slot_variables)
+    return embedding.EmbeddingVariables(table, slot_variables)
 
   # Initialize tables.
-  output: Nested[ArrayLike] = jax.tree.map(
+  output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
       _create_table_and_slot_variables,
       table_specs,
       keys,
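
The hunk above finishes the per-table initializer: one PRNG key is split into one key per slot variable, then tables and slots are built by mapping over the spec tree. A self-contained sketch of the key-splitting step, assuming a hypothetical two-slot optimizer:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
num_slot_variables = 2  # hypothetical: e.g. accumulator plus momentum

# jax.random.split stacks the new keys along axis 0;
# jnp.unstack then yields one independent key per slot variable.
slot_keys = jnp.unstack(jax.random.split(key, num_slot_variables))
assert len(slot_keys) == num_slot_variables
```
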
@@ -311,14 +314,14 @@ def _create_samples(
 
 def stack_shard_and_put_tables(
     table_specs: Nested[TableSpec],
-    tables: Nested[jax.Array],
+    tables: EmbeddingTable,
     num_shards: int,
     sharding: jax.sharding.Sharding,
-) -> dict[str, Nested[jax.Array]]:
+) -> dict[str, EmbeddingTable]:
   sharded_tables = embedding_utils.stack_and_shard_tables(
       table_specs, tables, num_shards
   )
-  output: dict[str, Nested[jax.Array]] = jax.device_put(
+  output: dict[str, EmbeddingTable] = jax.device_put(
       jax.tree.map(
           # Flatten shard dimension to allow auto-sharding to split the array.
           lambda table: table.reshape((-1, table.shape[-1])),
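
The lambda at the end of this hunk drops the explicit shard dimension so each table is a single 2-D array that the sharding machinery can split on its own. A standalone illustration with made-up shapes:

```python
import jax.numpy as jnp

num_shards, rows_per_shard, dim = 4, 8, 16
table = jnp.zeros((num_shards, rows_per_shard, dim))

# Same reshape as the lambda above: collapse (num_shards, rows, dim)
# into (num_shards * rows, dim) so auto-sharding can split the rows.
flat = table.reshape((-1, table.shape[-1]))
print(flat.shape)  # (32, 16)
```
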
@@ -469,27 +472,24 @@ def compute_expected_lookup_grad(
 def _update_table_and_slot_variables(
     table_spec: TableSpec,
     grad: jax.Array,
-    table_and_slot_variables: tuple[jax.Array, tuple[jax.Array, ...]],
-) -> tuple[
-    jax.Array,
-    embedding_spec.SGDSlotVariables | embedding_spec.AdagradSlotVariables,
-]:
+    table_and_slot_variables: embedding.EmbeddingVariables,
+) -> embedding.EmbeddingVariables:
   """Updates a table and its slot variables based on the gradient."""
-  table = table_and_slot_variables[0]
+  table = table_and_slot_variables.table
   optimizer = table_spec.optimizer
 
   # Adagrad, update and apply gradient accumulator.
   if isinstance(optimizer, embedding_spec.AdagradOptimizerSpec):
-    accumulator = table_and_slot_variables[1][0]
+    accumulator = table_and_slot_variables.slot.accumulator
     accumulator = accumulator + grad * grad
     learning_rate = optimizer.get_learning_rate(0) / jnp.sqrt(accumulator)
-    return (
+    return embedding.EmbeddingVariables(
         table - learning_rate * grad,
         embedding_spec.AdagradSlotVariables(accumulator=accumulator),
     )
 
   # SGD
-  return (
+  return embedding.EmbeddingVariables(
       table - optimizer.get_learning_rate(0) * grad,
       embedding_spec.SGDSlotVariables(),
   )
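
The Adagrad branch above is the textbook update: squared gradients accumulate, and the per-element learning rate is the base rate divided by the square root of the accumulator. Worked through with concrete numbers (the base rate 0.1 and the shapes are arbitrary):

```python
import jax.numpy as jnp

base_learning_rate = 0.1  # arbitrary; really comes from the optimizer spec
table = jnp.ones((4, 2))
grad = jnp.full((4, 2), 0.5)
accumulator = jnp.zeros((4, 2))

# Adagrad: accumulate grad^2, then scale each element's step by
# 1 / sqrt(accumulator), as in _update_table_and_slot_variables.
accumulator = accumulator + grad * grad                     # 0.25 everywhere
learning_rate = base_learning_rate / jnp.sqrt(accumulator)  # 0.1 / 0.5 = 0.2
table = table - learning_rate * grad                        # 1.0 - 0.1 = 0.9
print(float(table[0, 0]))  # 0.9
```
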
@@ -500,8 +500,8 @@ def compute_expected_updates(
     feature_samples: Nested[FeatureSamples],
     activation_gradients: Nested[jax.Array],
     table_specs: Nested[TableSpec],
-    table_and_slot_variables: Nested[jax.Array],
-) -> Nested[jax.Array]:
+    table_and_slot_variables: Nested[embedding.EmbeddingVariables],
+) -> Nested[embedding.EmbeddingVariables]:
   """Computes the expected updates for a given embedding lookup.
 
   Args:
@@ -522,7 +522,7 @@ def compute_expected_updates(
   )
 
   # Apply updates per table.
-  output: Nested[jax.Array] = jax.tree.map(
+  output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
       _update_table_and_slot_variables,
       table_specs,
       table_grads,