Commit 2154eb2

Hilly12 (recml authors) authored and recml authors committed
Add DLRM-V2 with sparsecore.
PiperOrigin-RevId: 745051634
1 parent: c43d6ae, commit: 2154eb2

File tree: 10 files changed (+1144, -38 lines)

recml/core/data/iterator.py (+18, -12)
@@ -57,16 +57,15 @@ def __next__(self) -> clu_data.Element:
     if self._prefetched_batch is not None:
       batch = self._prefetched_batch
       self._prefetched_batch = None
-      return batch
-
-    batch = next(self._iterator)
-    if self._postprocessor is not None:
-      batch = self._postprocessor(batch)
+    else:
+      batch = next(self._iterator)
+      if self._postprocessor is not None:
+        batch = self._postprocessor(batch)
 
     def _maybe_to_numpy(
-        x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor,
+        x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
     ) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor:
-      if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)):
+      if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
         return x
       if hasattr(x, "_numpy"):
         numpy = x._numpy()  # pylint: disable=protected-access
@@ -83,13 +82,16 @@ def _maybe_to_numpy(
   @property
   def element_spec(self) -> clu_data.ElementSpec:
     if self._element_spec is not None:
-      batch = self._element_spec
-    else:
-      batch = self.__next__()
-      self._prefetched_batch = batch
+      return self._element_spec
+
+    batch = next(self._iterator)
+    if self._postprocessor is not None:
+      batch = self._postprocessor(batch)
+
+    self._prefetched_batch = batch
 
     def _to_element_spec(
-        x: np.ndarray | tf.SparseTensor | tf.RaggedTensor,
+        x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
     ) -> clu_data.ArraySpec:
       if isinstance(x, tf.SparseTensor):
         return clu_data.ArraySpec(
@@ -101,6 +103,10 @@ def _to_element_spec(
             dtype=x.dtype.as_numpy_dtype,  # pylint: disable=attribute-error
             shape=tuple(x.shape.as_list()),  # pylint: disable=attribute-error
         )
+      if isinstance(x, tf.Tensor):
+        return clu_data.ArraySpec(
+            dtype=x.dtype.as_numpy_dtype, shape=tuple(x.shape.as_list())
+        )
       return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))
 
     element_spec = tf.nest.map_structure(_to_element_spec, batch)
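
The net effect of the iterator change: `element_spec` now pulls one post-processed batch straight from the underlying iterator and caches it, and `__next__` later returns that cached batch instead of post-processing it a second time. A minimal sketch of the same prefetch-and-cache pattern, using an illustrative `SimpleIterator` over dict-of-array batches rather than the library's actual class:

import numpy as np


class SimpleIterator:
  """Caches one prefetched batch so element_spec() does not consume data."""

  def __init__(self, iterator, postprocessor=None):
    self._iterator = iterator
    self._postprocessor = postprocessor
    self._prefetched_batch = None
    self._element_spec = None

  def __next__(self):
    if self._prefetched_batch is not None:
      # Return the batch that element_spec() already pulled and processed.
      batch = self._prefetched_batch
      self._prefetched_batch = None
    else:
      batch = next(self._iterator)
      if self._postprocessor is not None:
        batch = self._postprocessor(batch)
    return batch

  @property
  def element_spec(self):
    if self._element_spec is None:
      # Pull one post-processed batch, cache it for the next __next__() call,
      # and derive (dtype, shape) specs from its arrays.
      batch = next(self._iterator)
      if self._postprocessor is not None:
        batch = self._postprocessor(batch)
      self._prefetched_batch = batch
      self._element_spec = {
          k: (np.asarray(v).dtype, np.asarray(v).shape)
          for k, v in batch.items()
      }
    return self._element_spec


batches = iter([{"x": np.ones((2, 3))}, {"x": np.zeros((2, 3))}])
it = SimpleIterator(batches)
print(it.element_spec)   # {'x': (dtype('float64'), (2, 3))}
print(next(it)["x"][0])  # the cached first batch, not the second one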

recml/core/ops/embedding_ops.py (new file, +114)
@@ -0,0 +1,114 @@
+# Copyright 2024 RecML authors <[email protected]>.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Embedding lookup ops."""
+
+from collections.abc import Mapping, Sequence
+import dataclasses
+import functools
+
+from etils import epy
+import jax
+from jax.experimental import shard_map
+
+with epy.lazy_imports():
+  # pylint: disable=g-import-not-at-top
+  from jax_tpu_embedding.sparsecore.lib.nn import embedding
+  from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
+  # pylint: enable=g-import-not-at-top
+
+
+@dataclasses.dataclass
+class SparsecoreParams:
+  """Embedding parameters."""
+
+  feature_specs: embedding.Nested[embedding_spec.FeatureSpec]
+  abstract_mesh: jax.sharding.AbstractMesh
+  data_axes: Sequence[str | None]
+  embedding_axes: Sequence[str | None]
+  sharding_strategy: str
+
+
+@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
+def sparsecore_lookup(
+    sparsecore_params: SparsecoreParams,
+    tables: Mapping[str, tuple[jax.Array, ...]],
+    csr_inputs: tuple[jax.Array, ...],
+):
+  return shard_map.shard_map(
+      functools.partial(
+          embedding.tpu_sparse_dense_matmul,
+          global_device_count=sparsecore_params.abstract_mesh.size,
+          feature_specs=sparsecore_params.feature_specs,
+          sharding_strategy=sparsecore_params.sharding_strategy,
+      ),
+      mesh=sparsecore_params.abstract_mesh,
+      in_specs=(
+          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+          jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
+      ),
+      out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+      check_rep=False,
+  )(*csr_inputs, tables)
+
+
+def _emb_lookup_fwd(
+    sparsecore_params: SparsecoreParams,
+    tables: Mapping[str, tuple[jax.Array, ...]],
+    csr_inputs: tuple[jax.Array, ...],
+):
+  out = sparsecore_lookup(sparsecore_params, tables, csr_inputs)
+  return out, (tables, csr_inputs)
+
+
+def _emb_lookup_bwd(
+    sparsecore_params: SparsecoreParams,
+    res: tuple[Mapping[str, tuple[jax.Array, ...]], tuple[jax.Array, ...]],
+    gradients: embedding.Nested[jax.Array],
+) -> tuple[embedding.Nested[jax.Array], None]:
+  """Backward pass for embedding lookup."""
+  (tables, csr_inputs) = res
+
+  emb_table_grads = shard_map.shard_map(
+      functools.partial(
+          embedding.tpu_sparse_dense_matmul_grad,
+          feature_specs=sparsecore_params.feature_specs,
+          sharding_strategy=sparsecore_params.sharding_strategy,
+      ),
+      mesh=sparsecore_params.abstract_mesh,
+      in_specs=(
+          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+          jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
+      ),
+      out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+      check_rep=False,
+  )(gradients, *csr_inputs, tables)
+
+  # `tpu_sparse_dense_matmul_grad` returns a general mapping (usually a dict).
+  # It may not be the same type as the embedding table (e.g. FrozenDict).
+  # Here we use flatten / unflatten to ensure the types are the same.
+  emb_table_grads = jax.tree.unflatten(
+      jax.tree.structure(tables), jax.tree.leaves(emb_table_grads)
+  )
+
+  return emb_table_grads, None
+
+
+sparsecore_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)
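
The new op wraps `embedding.tpu_sparse_dense_matmul` in a `shard_map` and attaches a custom VJP, so differentiating a loss with respect to the embedding tables routes through `tpu_sparse_dense_matmul_grad` rather than JAX's default autodiff (the CSR inputs get no gradient, and `nondiff_argnums=(0,)` keeps `SparsecoreParams` out of differentiation). The subtle part is the comment in `_emb_lookup_bwd`: the gradient op returns a plain mapping, so it is flattened and re-unflattened into the same container type as the table pytree. A small self-contained illustration of that flatten / unflatten trick, using `flax.core.FrozenDict` purely as an example of a differing container type and illustrative table names:

import jax
import jax.numpy as jnp
from flax.core import FrozenDict

# Parameters stored in one container type...
tables = FrozenDict({
    "movie_id": (jnp.ones((8, 4)),),
    "user_id": (jnp.ones((8, 4)),),
})
# ...while the gradient op hands back a plain dict with the same leaves.
raw_grads = {
    "movie_id": (jnp.zeros((8, 4)),),
    "user_id": (jnp.zeros((8, 4)),),
}

# Rebuild the gradients with the *table* treedef so the custom VJP returns a
# pytree whose structure matches the `tables` argument exactly.
grads = jax.tree.unflatten(
    jax.tree.structure(tables), jax.tree.leaves(raw_grads)
)
assert isinstance(grads, FrozenDict)
assert jax.tree.structure(grads) == jax.tree.structure(tables)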

recml/core/training/jax.py (+49, -6)
@@ -26,6 +26,7 @@
 from clu import periodic_actions
 import clu.metrics as clu_metrics
 from flax import struct
+import flax.linen as nn
 import jax
 import jax.numpy as jnp
 import keras
@@ -67,43 +68,85 @@ class JaxState(struct.PyTreeNode, Generic[MetaT]):
     step: A counter of the current step of the job. It starts at zero and it is
       incremented by 1 on a call to `state.update(...)`. This should be a Jax
       array and not a Python integer.
-    apply: A function that can be used to apply the forward pass of the model.
-      For Flax models this is usually set to `model.apply`.
     params: A pytree of trainable variables that will be updated by `tx` and
       used in `apply`.
     tx: An optax gradient transformation that will be used to update the
       parameters contained in `params` on a call to `state.update(...)`.
     opt_state: The optimizer state for `tx`. This is usually created by calling
       `tx.init(params)`.
+    _apply: An optional function that can be used to apply the forward pass of
+      the model. For Flax models this is usually set to `model.apply` while for
+      Haiku models this is usually set to `transform.apply`.
+    _model: An optional reference to a stateless Flax model for convenience.
     mutable: A pytree of mutable variables that are used by `apply`.
     meta: Arbitrary metadata that is recorded on the state. This can be useful
       for tracking additional references in the state.
   """
 
   step: jax.Array
-  apply: Callable[..., Any] = struct.field(pytree_node=False)
   params: PyTree = struct.field(pytree_node=True)
   tx: optax.GradientTransformation = struct.field(pytree_node=False)
   opt_state: optax.OptState = struct.field(pytree_node=True)
   mutable: PyTree = struct.field(pytree_node=True, default_factory=dict)
   meta: MetaT = struct.field(pytree_node=False, default_factory=dict)
+  _apply: Callable[..., Any] | None = struct.field(
+      pytree_node=False, default_factory=None
+  )
+  _model: nn.Module | None = struct.field(pytree_node=False, default=None)
+
+  @property
+  def model(self) -> nn.Module:
+    """Returns a reference to the model used to create the state."""
+    if self._model is None:
+      raise ValueError("No Flax `model` is set on the state.")
+    return self._model
+
+  def apply(self, *args, **kwargs) -> Any:
+    """Applies the forward pass of the model."""
+    if self._apply is None:
+      raise ValueError("No `apply` function is set on the state.")
+    return self._apply(*args, **kwargs)
 
   @classmethod
   def create(
       cls,
       *,
-      apply: Callable[..., Any],
+      apply: Callable[..., Any] | None = None,
+      model: nn.Module | None = None,
      params: PyTree,
       tx: optax.GradientTransformation,
       **kwargs,
   ) -> Self:
-    """Creates a new instance from a Jax apply function and Optax optimizer."""
+    """Creates a new instance from a Jax model / apply fn and Optax optimizer.
+
+    Args:
+      apply: A function that can be used to apply the forward pass of the model.
+        For Flax models this is usually set to `model.apply`. This cannot be set
+        along with `model`.
+      model: A reference to a stateless Flax model. This cannot be set along
+        with `apply`. When set the `apply` attribute of the state will be set to
+        `model.apply`.
+      params: A pytree of trainable variables that will be updated by `tx` and
+        used in `apply`.
+      tx: An optax gradient transformation that will be used to update the
+        parameters contained in `params` on a call to `state.update(...)`.
+      **kwargs: Other updates to set on the new state.
+
+    Returns:
+      An new instance of the state.
+    """
+    if apply is not None and model is not None:
+      raise ValueError("Only one of `apply` or `model` can be provided.")
+    elif model is not None:
+      apply = model.apply
+
     return cls(
         step=jnp.zeros([], dtype=jnp.int32),
-        apply=apply,
         params=params,
         tx=tx,
        opt_state=tx.init(params),
+        _apply=apply,
+        _model=model,
         **kwargs,
     )
 
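
With `apply` demoted to a private field and exposed as a method, `JaxState.create` now accepts either a bare apply function or a Flax module, but not both. A hedged sketch of the two construction paths; the `MLP` module and shapes are illustrative, and it assumes `JaxState` from this module is in scope:

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax


class MLP(nn.Module):

  @nn.compact
  def __call__(self, x):
    return nn.Dense(1)(nn.relu(nn.Dense(16)(x)))


model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))["params"]
tx = optax.adam(1e-3)

# Option 1: pass the Flax module. `state.apply(...)` forwards to
# `model.apply(...)` and `state.model` returns the module itself.
state = JaxState.create(model=model, params=params, tx=tx)
preds = state.apply({"params": state.params}, jnp.ones((2, 8)))

# Option 2: pass a bare apply function (e.g. a Haiku `transform.apply`);
# `state.model` would then raise because no module was recorded.
state = JaxState.create(apply=model.apply, params=params, tx=tx)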

recml/core/training/optax_factory.py (+29, -3)
@@ -29,10 +29,10 @@ def _default_weight_decay_mask(params: optax.Params) -> optax.Params:
 
 
 def _regex_mask(regex: str) -> Callable[[optax.Params], optax.Params]:
-  """Returns a weight decay mask that applies to parameters matching a regex."""
+  """Returns a mask that applies to parameters matching a regex."""
 
   def _matches_regex(path: tuple[str, ...], _: Any) -> bool:
-    key = "/".join([jax.tree_util.keystr((k,), simple=True) for k in path])
+    key = '/'.join([jax.tree_util.keystr((k,), simple=True) for k in path])
     return re.fullmatch(regex, key) is not None
 
   def _mask(params: optax.Params) -> optax.Params:
@@ -54,6 +54,8 @@ class OptimizerFactory(types.Factory[optax.GradientTransformation]):
       magnitude of the gradients during optimization. Defaults to None.
     weight_decay_mask: The weight decay mask to use when applying weight decay.
       Defaults applying weight decay to all non-1D parameters.
+    freeze_mask: Optional mask to freeze parameters during optimization.
+      Defaults to None.
 
   Example usage:
   ```
@@ -78,6 +80,7 @@ class OptimizerFactory(types.Factory[optax.GradientTransformation]):
   weight_decay_mask: str | Callable[[optax.Params], optax.Params] = (
       _default_weight_decay_mask
   )
+  freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None
 
   def make(self) -> optax.GradientTransformation:
     if self.grad_clip_norm is not None:
@@ -99,13 +102,30 @@ def make(self) -> optax.GradientTransformation:
     else:
       weight_decay = optax.identity()
 
-    return optax.chain(*[
+    tx = optax.chain(*[
         apply_clipping,
         self.scaling,
         weight_decay,
         lr_scaling,
     ])
 
+    if self.freeze_mask is not None:
+      if isinstance(self.freeze_mask, str):
+        mask = _regex_mask(self.freeze_mask)
+      else:
+        mask = self.freeze_mask
+
+      def _param_labels(params: optax.Params) -> optax.Params:
+        return jax.tree.map(
+            lambda p: 'frozen' if mask(p) else 'trainable', params
+        )
+
+      tx = optax.multi_transform(
+          transforms={'trainable': tx, 'frozen': optax.set_to_zero()},
+          param_labels=_param_labels,
+      )
+    return tx
+
 
 class AdamFactory(types.Factory[optax.GradientTransformation]):
   """Adam optimizer factory.
@@ -121,6 +141,8 @@ class AdamFactory(types.Factory[optax.GradientTransformation]):
       magnitude of the gradients during optimization. Defaults to None.
     weight_decay_mask: The weight decay mask to use when applying weight decay.
       Defaults applying weight decay to all non-1D parameters.
+    freeze_mask: Optional mask to freeze parameters during optimization.
+      Defaults to None.
 
   Example usage:
   ```
@@ -143,6 +165,7 @@ class AdamFactory(types.Factory[optax.GradientTransformation]):
   weight_decay_mask: str | Callable[[optax.Params], optax.Params] = (
       _default_weight_decay_mask
   )
+  freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None
 
   def make(self) -> optax.GradientTransformation:
     return OptimizerFactory(
@@ -164,6 +187,8 @@ class AdagradFactory(types.Factory[optax.GradientTransformation]):
     eps: The epsilon coefficient for the Adagrad optimizer. Defaults to 1e-7.
     grad_clip_norm: Optional gradient clipping norm to limit the maximum
       magnitude of the gradients during optimization. Defaults to None.
+    freeze_mask: Optional mask to freeze parameters during optimization.
+      Defaults to None.
 
   Example usage:
   ```
@@ -175,6 +200,7 @@ class AdagradFactory(types.Factory[optax.GradientTransformation]):
   initial_accumulator_value: float = 0.1
   eps: float = 1e-7
   grad_clip_norm: float | None = None
+  freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None
 
   def make(self) -> optax.GradientTransformation:
     return OptimizerFactory(
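
`freeze_mask` works by labelling each parameter 'frozen' or 'trainable' and wrapping the optimizer in `optax.multi_transform`, with `optax.set_to_zero()` applied to the frozen group. A self-contained sketch of that mechanism; the parameter names and regex are illustrative, and the mask helper mirrors `_regex_mask` above rather than calling it:

import re

import jax
import jax.numpy as jnp
import optax

params = {
    "embedding": {"table": jnp.ones((4, 2))},
    "dense": {"kernel": jnp.ones((2, 2))},
}


def regex_mask(regex):
  """Mirrors `_regex_mask`: True (frozen) where the joined path matches."""

  def matches(path, _):
    key = "/".join(jax.tree_util.keystr((k,), simple=True) for k in path)
    return re.fullmatch(regex, key) is not None

  return lambda p: jax.tree_util.tree_map_with_path(matches, p)


mask = regex_mask(r"embedding/.*")
labels = jax.tree.map(
    lambda frozen: "frozen" if frozen else "trainable", mask(params)
)
tx = optax.multi_transform(
    {"trainable": optax.sgd(0.1), "frozen": optax.set_to_zero()}, labels
)

opt_state = tx.init(params)
grads = jax.tree.map(jnp.ones_like, params)
updates, _ = tx.update(grads, opt_state, params)
# updates["embedding"]["table"] is all zeros; updates["dense"]["kernel"] is -0.1.

With the factories, the same effect comes from passing `freeze_mask=r"embedding/.*"` (or a callable mask) to `OptimizerFactory`, `AdamFactory`, or `AdagradFactory`.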
