Add DLRM-V2 with sparsecore. #20

Open · wants to merge 1 commit into main

30 changes: 18 additions & 12 deletions recml/core/data/iterator.py
@@ -57,16 +57,15 @@ def __next__(self) -> clu_data.Element:
if self._prefetched_batch is not None:
batch = self._prefetched_batch
self._prefetched_batch = None
return batch

batch = next(self._iterator)
if self._postprocessor is not None:
batch = self._postprocessor(batch)
else:
batch = next(self._iterator)
if self._postprocessor is not None:
batch = self._postprocessor(batch)

def _maybe_to_numpy(
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor,
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor:
if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)):
if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
return x
if hasattr(x, "_numpy"):
numpy = x._numpy() # pylint: disable=protected-access
@@ -83,13 +82,16 @@ def _maybe_to_numpy(
@property
def element_spec(self) -> clu_data.ElementSpec:
if self._element_spec is not None:
batch = self._element_spec
else:
batch = self.__next__()
self._prefetched_batch = batch
return self._element_spec

batch = next(self._iterator)
if self._postprocessor is not None:
batch = self._postprocessor(batch)

self._prefetched_batch = batch

def _to_element_spec(
x: np.ndarray | tf.SparseTensor | tf.RaggedTensor,
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
) -> clu_data.ArraySpec:
if isinstance(x, tf.SparseTensor):
return clu_data.ArraySpec(
@@ -101,6 +103,10 @@ def _to_element_spec(
dtype=x.dtype.as_numpy_dtype, # pylint: disable=attribute-error
shape=tuple(x.shape.as_list()), # pylint: disable=attribute-error
)
if isinstance(x, tf.Tensor):
return clu_data.ArraySpec(
dtype=x.dtype.as_numpy_dtype, shape=tuple(x.shape.as_list())
)
return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))

element_spec = tf.nest.map_structure(_to_element_spec, batch)
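For context, the `element_spec` change above peeks one batch through the same postprocessing path as `__next__` and caches it so the peeked batch is not dropped. A minimal standalone sketch of that pattern (a hypothetical `PeekableIterator` with plain dtype/shape tuples standing in for `clu_data.ArraySpec`, not the class in this diff):

```
import numpy as np


class PeekableIterator:
  """Peeks one batch to infer specs without losing it."""

  def __init__(self, iterator, postprocessor=None):
    self._iterator = iterator
    self._postprocessor = postprocessor
    self._prefetched_batch = None
    self._element_spec = None

  def __next__(self):
    if self._prefetched_batch is not None:
      batch = self._prefetched_batch
      self._prefetched_batch = None
    else:
      batch = next(self._iterator)
      if self._postprocessor is not None:
        batch = self._postprocessor(batch)
    return batch

  @property
  def element_spec(self):
    if self._element_spec is None:
      batch = self.__next__()
      # Stash the peeked batch so the next __next__ call returns it.
      self._prefetched_batch = batch
      self._element_spec = {k: (v.dtype, v.shape) for k, v in batch.items()}
    return self._element_spec


it = PeekableIterator(iter([{"x": np.zeros((8, 4), np.float32)}]))
print(it.element_spec)      # {'x': (dtype('float32'), (8, 4))}
print(next(it)["x"].shape)  # (8, 4): the peeked batch is not lost.
```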
114 changes: 114 additions & 0 deletions recml/core/ops/embedding_ops.py
@@ -0,0 +1,114 @@
# Copyright 2024 RecML authors <[email protected]>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Embedding lookup ops."""

from collections.abc import Mapping, Sequence
import dataclasses
import functools

from etils import epy
import jax
from jax.experimental import shard_map

with epy.lazy_imports():
# pylint: disable=g-import-not-at-top
from jax_tpu_embedding.sparsecore.lib.nn import embedding
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
# pylint: enable=g-import-not-at-top


@dataclasses.dataclass
class SparsecoreParams:
"""Embedding parameters."""

feature_specs: embedding.Nested[embedding_spec.FeatureSpec]
abstract_mesh: jax.sharding.AbstractMesh
data_axes: Sequence[str | None]
embedding_axes: Sequence[str | None]
sharding_strategy: str


@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def sparsecore_lookup(
sparsecore_params: SparsecoreParams,
tables: Mapping[str, tuple[jax.Array, ...]],
csr_inputs: tuple[jax.Array, ...],
):
return shard_map.shard_map(
functools.partial(
embedding.tpu_sparse_dense_matmul,
global_device_count=sparsecore_params.abstract_mesh.size,
feature_specs=sparsecore_params.feature_specs,
sharding_strategy=sparsecore_params.sharding_strategy,
),
mesh=sparsecore_params.abstract_mesh,
in_specs=(
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
),
out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
check_rep=False,
)(*csr_inputs, tables)


def _emb_lookup_fwd(
sparsecore_params: SparsecoreParams,
tables: Mapping[str, tuple[jax.Array, ...]],
csr_inputs: tuple[jax.Array, ...],
):
out = sparsecore_lookup(sparsecore_params, tables, csr_inputs)
return out, (tables, csr_inputs)


def _emb_lookup_bwd(
sparsecore_params: SparsecoreParams,
res: tuple[Mapping[str, tuple[jax.Array, ...]], tuple[jax.Array, ...]],
gradients: embedding.Nested[jax.Array],
) -> tuple[embedding.Nested[jax.Array], None]:
"""Backward pass for embedding lookup."""
(tables, csr_inputs) = res

emb_table_grads = shard_map.shard_map(
functools.partial(
embedding.tpu_sparse_dense_matmul_grad,
feature_specs=sparsecore_params.feature_specs,
sharding_strategy=sparsecore_params.sharding_strategy,
),
mesh=sparsecore_params.abstract_mesh,
in_specs=(
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
),
out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
check_rep=False,
)(gradients, *csr_inputs, tables)

# `tpu_sparse_dense_matmul_grad` returns a general mapping (usually a dict).
# It may not have the same container type as the embedding table pytree
# (e.g. a FrozenDict), so we flatten / unflatten to make the types match.
emb_table_grads = jax.tree.unflatten(
jax.tree.structure(tables), jax.tree.leaves(emb_table_grads)
)

return emb_table_grads, None


sparsecore_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)
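For reviewers unfamiliar with the wiring above, here is a self-contained sketch of the same `jax.custom_vjp` pattern with `nondiff_argnums`, using a toy gather in place of `tpu_sparse_dense_matmul` and a scatter-add in place of its grad op. `toy_lookup`, the table names, and the `FrozenDict` container are illustrative assumptions, not part of this PR:

```
import functools

import jax
import jax.numpy as jnp
from flax.core import FrozenDict  # Stands in for any non-dict mapping pytree.


@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def toy_lookup(scale, tables, ids):
  # Forward pass: gather rows from every table and scale them.
  return {name: scale * table[ids] for name, table in tables.items()}


def _toy_fwd(scale, tables, ids):
  return toy_lookup(scale, tables, ids), (tables, ids)


def _toy_bwd(scale, res, grads):
  tables, ids = res
  # Per-table gradients computed as a plain dict...
  raw = {
      name: jnp.zeros_like(table).at[ids].add(scale * grads[name])
      for name, table in tables.items()
  }
  # ...then rebuilt with the structure of `tables`, mirroring the
  # flatten / unflatten step in `_emb_lookup_bwd` above.
  table_grads = jax.tree.unflatten(
      jax.tree.structure(tables), jax.tree.leaves(raw)
  )
  return table_grads, None  # None: zero cotangent for the integer ids.


toy_lookup.defvjp(_toy_fwd, _toy_bwd)

tables = FrozenDict({"user": jnp.ones((8, 4)), "item": jnp.ones((16, 4))})
ids = jnp.array([0, 3])
out, vjp_fn = jax.vjp(lambda t: toy_lookup(2.0, t, ids), tables)
(table_grads,) = vjp_fn({name: jnp.ones_like(x) for name, x in out.items()})
```

The retyping step matters because the cotangent returned by the backward rule must have the same container types as the primal `tables` input; returning a plain dict for a `FrozenDict` input would fail the pytree structure check.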
55 changes: 49 additions & 6 deletions recml/core/training/jax.py
@@ -26,6 +26,7 @@
from clu import periodic_actions
import clu.metrics as clu_metrics
from flax import struct
import flax.linen as nn
import jax
import jax.numpy as jnp
import keras
Expand Down Expand Up @@ -67,43 +68,85 @@ class JaxState(struct.PyTreeNode, Generic[MetaT]):
step: A counter of the current step of the job. It starts at zero and is
incremented by 1 on a call to `state.update(...)`. This should be a Jax
array and not a Python integer.
apply: A function that can be used to apply the forward pass of the model.
For Flax models this is usually set to `model.apply`.
params: A pytree of trainable variables that will be updated by `tx` and
used in `apply`.
tx: An optax gradient transformation that will be used to update the
parameters contained in `params` on a call to `state.update(...)`.
opt_state: The optimizer state for `tx`. This is usually created by calling
`tx.init(params)`.
_apply: An optional function that can be used to apply the forward pass of
the model. For Flax models this is usually set to `model.apply` while for
Haiku models this is usually set to `transform.apply`.
_model: An optional reference to a stateless Flax model for convenience.
mutable: A pytree of mutable variables that are used by `apply`.
meta: Arbitrary metadata that is recorded on the state. This can be useful
for tracking additional references in the state.
"""

step: jax.Array
apply: Callable[..., Any] = struct.field(pytree_node=False)
params: PyTree = struct.field(pytree_node=True)
tx: optax.GradientTransformation = struct.field(pytree_node=False)
opt_state: optax.OptState = struct.field(pytree_node=True)
mutable: PyTree = struct.field(pytree_node=True, default_factory=dict)
meta: MetaT = struct.field(pytree_node=False, default_factory=dict)
_apply: Callable[..., Any] | None = struct.field(
    pytree_node=False, default=None
)
_model: nn.Module | None = struct.field(pytree_node=False, default=None)

@property
def model(self) -> nn.Module:
"""Returns a reference to the model used to create the state."""
if self._model is None:
raise ValueError("No Flax `model` is set on the state.")
return self._model

def apply(self, *args, **kwargs) -> Any:
"""Applies the forward pass of the model."""
if self._apply is None:
raise ValueError("No `apply` function is set on the state.")
return self._apply(*args, **kwargs)

@classmethod
def create(
cls,
*,
apply: Callable[..., Any],
apply: Callable[..., Any] | None = None,
model: nn.Module | None = None,
params: PyTree,
tx: optax.GradientTransformation,
**kwargs,
) -> Self:
"""Creates a new instance from a Jax apply function and Optax optimizer."""
"""Creates a new instance from a Jax model / apply fn and Optax optimizer.

Args:
apply: A function that can be used to apply the forward pass of the model.
For Flax models this is usually set to `model.apply`. This cannot be set
along with `model`.
model: A reference to a stateless Flax model. This cannot be set along
with `apply`. When set, the state's `apply` dispatches to `model.apply`.
params: A pytree of trainable variables that will be updated by `tx` and
used in `apply`.
tx: An optax gradient transformation that will be used to update the
parameters contained in `params` on a call to `state.update(...)`.
**kwargs: Other updates to set on the new state.

Returns:
A new instance of the state.
"""
if apply is not None and model is not None:
raise ValueError("Only one of `apply` or `model` can be provided.")
elif model is not None:
apply = model.apply

return cls(
step=jnp.zeros([], dtype=jnp.int32),
apply=apply,
params=params,
tx=tx,
opt_state=tx.init(params),
_apply=apply,
_model=model,
**kwargs,
)

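A minimal usage sketch of the new `model` / `apply` split on `JaxState.create` (the `MLP` module and optimizer are illustrative; the import path is inferred from the file path in this diff):

```
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

from recml.core.training.jax import JaxState  # Inferred import path.


class MLP(nn.Module):

  @nn.compact
  def __call__(self, x):
    return nn.Dense(1)(nn.relu(nn.Dense(16)(x)))


model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))["params"]
tx = optax.adam(1e-3)

# Pass the Flax module itself; `state.apply` dispatches to `model.apply` and
# `state.model` returns the module...
state = JaxState.create(model=model, params=params, tx=tx)
# ...or pass an apply function directly (but not both).
# state = JaxState.create(apply=model.apply, params=params, tx=tx)

preds = state.apply({"params": state.params}, jnp.ones((2, 8)))
```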
32 changes: 29 additions & 3 deletions recml/core/training/optax_factory.py
@@ -29,10 +29,10 @@ def _default_weight_decay_mask(params: optax.Params) -> optax.Params:


def _regex_mask(regex: str) -> Callable[[optax.Params], optax.Params]:
"""Returns a weight decay mask that applies to parameters matching a regex."""
"""Returns a mask that applies to parameters matching a regex."""

def _matches_regex(path: tuple[str, ...], _: Any) -> bool:
key = "/".join([jax.tree_util.keystr((k,), simple=True) for k in path])
key = '/'.join([jax.tree_util.keystr((k,), simple=True) for k in path])
return re.fullmatch(regex, key) is not None

def _mask(params: optax.Params) -> optax.Params:
@@ -54,6 +54,8 @@ class OptimizerFactory(types.Factory[optax.GradientTransformation]):
magnitude of the gradients during optimization. Defaults to None.
weight_decay_mask: The weight decay mask to use when applying weight decay.
Defaults to applying weight decay to all non-1D parameters.
freeze_mask: Optional mask (a regex string or a callable on the params)
selecting parameters to freeze during optimization. Defaults to None.

Example usage:

@@ -78,6 +80,7 @@ class OptimizerFactory(types.Factory[optax.GradientTransformation]):
weight_decay_mask: str | Callable[[optax.Params], optax.Params] = (
_default_weight_decay_mask
)
freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None

def make(self) -> optax.GradientTransformation:
if self.grad_clip_norm is not None:
@@ -99,13 +102,30 @@ def make(self) -> optax.GradientTransformation:
else:
weight_decay = optax.identity()

return optax.chain(*[
tx = optax.chain(*[
apply_clipping,
self.scaling,
weight_decay,
lr_scaling,
])

if self.freeze_mask is not None:
if isinstance(self.freeze_mask, str):
mask = _regex_mask(self.freeze_mask)
else:
mask = self.freeze_mask

def _param_labels(params: optax.Params) -> optax.Params:
  # `mask` maps the whole params pytree to a pytree of booleans, so apply
  # it once and derive a label for each leaf.
  return jax.tree.map(
      lambda is_frozen: 'frozen' if is_frozen else 'trainable', mask(params)
  )

tx = optax.multi_transform(
transforms={'trainable': tx, 'frozen': optax.set_to_zero()},
param_labels=_param_labels,
)
return tx


class AdamFactory(types.Factory[optax.GradientTransformation]):
"""Adam optimizer factory.
@@ -121,6 +141,8 @@ class AdamFactory(types.Factory[optax.GradientTransformation]):
magnitude of the gradients during optimization. Defaults to None.
weight_decay_mask: The weight decay mask to use when applying weight decay.
Defaults to applying weight decay to all non-1D parameters.
freeze_mask: Optional mask (a regex string or a callable on the params)
selecting parameters to freeze during optimization. Defaults to None.

Example usage:
```
@@ -143,6 +165,7 @@ class AdamFactory(types.Factory[optax.GradientTransformation]):
weight_decay_mask: str | Callable[[optax.Params], optax.Params] = (
_default_weight_decay_mask
)
freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None

def make(self) -> optax.GradientTransformation:
return OptimizerFactory(
@@ -164,6 +187,8 @@ class AdagradFactory(types.Factory[optax.GradientTransformation]):
eps: The epsilon coefficient for the Adagrad optimizer. Defaults to 1e-7.
grad_clip_norm: Optional gradient clipping norm to limit the maximum
magnitude of the gradients during optimization. Defaults to None.
freeze_mask: Optional mask (a regex string or a callable on the params)
selecting parameters to freeze during optimization. Defaults to None.

Example usage:
```
@@ -175,6 +200,7 @@ class AdagradFactory(types.Factory[optax.GradientTransformation]):
initial_accumulator_value: float = 0.1
eps: float = 1e-7
grad_clip_norm: float | None = None
freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None

def make(self) -> optax.GradientTransformation:
return OptimizerFactory(
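For context, a standalone sketch of the freezing mechanism that `make()` now builds: `optax.multi_transform` routes parameters labeled 'frozen' to `optax.set_to_zero()` so they receive no updates. The parameter tree and label function below are illustrative, not the factory's own code:

```
import jax
import jax.numpy as jnp
import optax

params = {
    "dense": {"kernel": jnp.ones((2, 2))},
    "embedding": {"table": jnp.ones((4, 2))},
}


def param_labels(params):
  # Freeze everything under 'embedding'; train the rest.
  return jax.tree_util.tree_map_with_path(
      lambda path, _: "frozen"
      if "embedding" in jax.tree_util.keystr(path)
      else "trainable",
      params,
  )


tx = optax.multi_transform(
    {"trainable": optax.adam(1e-3), "frozen": optax.set_to_zero()},
    param_labels,
)
state = tx.init(params)
grads = jax.tree.map(jnp.ones_like, params)
updates, state = tx.update(grads, state, params)
# updates["embedding"]["table"] is all zeros; "dense" gets real Adam updates.
```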