41 changes: 30 additions & 11 deletions recml/core/data/tf_dataset_factory.py
@@ -24,6 +24,7 @@
import re
from typing import Any, Protocol

from absl import flags
from absl import logging
import jax
from recml.core.utils import types
@@ -162,12 +163,17 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
Defaults to False.
seed: An optional seed to use for deterministic shuffling / preprocessing.
Defaults to None.
tf_data_service_address: An optional URI of a tf.data service to offload
preprocessing onto during training. The URI should be in the format
"protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no
data service will be applied.
enable_tf_data_service: Whether to apply the tf.data service to this dataset.
If True, the `tf_data_service_address` flag must be set.
tf_data_service_policy: Sharding policy to use for tf.data service when it
is enabled.
tf_data_service_job_name: Job name to use for tf.data service. If None, the
default job name will be used.
tf_data_service_replicate_on_split: Whether to replicate the file dataset on
split when distributing data to tf.data service workers. Note: this is
useful when multiple datasets are processed together under the `Dynamic`
sharding mode; a dataset with `tf_data_service_replicate_on_split` enabled
is processed as if it were under the `Off` mode.
feature_spec: A mapping of feature keys to `FixedLenFeature`,
`VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be
used to parse the TF examples, or as context_features spec to parse TF
@@ -208,7 +214,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
tensorflow.
debug: An optional boolean indicating whether to debug input boundedness. If
`True`, the dataset will consist of a single batch that's cached and
infinitely repeated
infinitely repeated.
"""

cache_reading: bool = False
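For context on the fields introduced above: a minimal sketch of how the new tf.data service options might be set on the factory. This is illustrative only; the field names come from this diff, the job name is a placeholder, and any other required `TFDatasetFactory` fields are assumed to keep their defaults or be supplied elsewhere.

```python
import tensorflow as tf

# Illustrative configuration only; other TFDatasetFactory fields are assumed
# to keep their defaults. The service address itself is read from the
# `tf_data_service_address` flag rather than from a constructor argument.
factory = TFDatasetFactory(
    enable_tf_data_service=True,
    tf_data_service_policy=tf.data.experimental.service.ShardingPolicy.OFF,
    tf_data_service_job_name="shared_training_job",  # placeholder job name
    tf_data_service_replicate_on_split=False,
    seed=None,  # must remain None when the data service is enabled
)
```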
@@ -231,7 +237,8 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
readahead: str | None = None
group_uris_by_dir: bool = False
seed: int | None = None
tf_data_service_address: str | None = None
enable_tf_data_service: bool = False
tf_data_service_job_name: str | None = None
tf_data_service_policy: tf.data.experimental.service.ShardingPolicy = (
tf.data.experimental.service.ShardingPolicy.OFF
)
@@ -246,10 +253,16 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
sharding_info: DatasetShardingInfo = dataclasses.field(
default_factory=DatasetShardingInfo
)
tf_data_service_replicate_on_split: bool = False
debug: bool = False

def __post_init__(self):
if self.tf_data_service_address is not None:
if self.enable_tf_data_service:
if flags.FLAGS.tf_data_service_address is None:
raise ValueError(
"Flag `tf_data_service_address` must be set when"
" `enable_tf_data_service` is True."
)
if self.seed is not None:
raise ValueError("`seed` must be None for data service.")
if self.sharding:
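The validation above reads the service address from an absl flag rather than from a dataclass field. Below is a minimal sketch of how such a flag could be defined and supplied, assuming the real definition lives elsewhere in the codebase (only the flag name is taken from this diff).

```python
from absl import flags

# Hypothetical flag definition; the actual one is assumed to be declared
# elsewhere in the codebase. Only the flag name matches the diff.
flags.DEFINE_string(
    "tf_data_service_address",
    None,
    "URI of the tf.data service, e.g. grpc://tf-data-service:5050.",
)

# At launch time the address is then supplied on the command line, e.g.:
#   --tf_data_service_address=grpc://tf-data-service:5050
```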
@@ -464,6 +477,9 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
# Create a dataset of file / file group uris.
dataset = tf.data.Dataset.from_tensor_slices(uris)

if self.tf_data_service_replicate_on_split:
dataset = tf.data.apply_rewrite(dataset, rewrite="replicate_on_split")

# Repeat the dataset. We might need to repeat the dataset here, even if we
# have enough shards for the input data, in case this issue is encountered:
# internal screenshot link:6jAKKoEMT3afXRe
@@ -533,23 +549,26 @@ def _maybe_apply_tf_data_service(
self, dataset: tf.data.Dataset
) -> tf.data.Dataset:
"""Applies the tf.data service to the dataset."""
if self.tf_data_service_address is None:
if not self.enable_tf_data_service:
return dataset

tf_data_service_address = flags.FLAGS.tf_data_service_address

per_proc_batch_size = self.sharding_info.per_process_batch_size(
self.global_batch_size
)
logging.info(
"Applying tf.data service with address %s and per replica batch"
" size %s",
self.tf_data_service_address,
tf_data_service_address,
per_proc_batch_size,
)
return dataset.apply(
tf.data.experimental.service.distribute(
processing_mode=self.tf_data_service_policy,
service=self.tf_data_service_address,
job_name=f"bs_{per_proc_batch_size}",
service=tf_data_service_address,
job_name=self.tf_data_service_job_name
or "tf_data_service_shared_job_name",
)
)
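As a standalone illustration of what `_maybe_apply_tf_data_service` now does, here is a self-contained sketch of distributing a dataset to a tf.data service deployment with a shared job name, so that multiple consumers read from the same job. The helper name, address, and job name are placeholders, not values from this repository.

```python
import tensorflow as tf

def distribute_to_service(
    dataset: tf.data.Dataset,
    service_address: str,
    job_name: str | None = None,
) -> tf.data.Dataset:
  """Offloads preprocessing of `dataset` to a tf.data service deployment."""
  return dataset.apply(
      tf.data.experimental.service.distribute(
          # OFF lets every consumer see the full dataset; the factory selects
          # the policy via `tf_data_service_policy`.
          processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
          service=service_address,  # e.g. "grpc://tf-data-service:5050"
          # A shared job name makes all consumers read from one job instead
          # of each consumer starting its own.
          job_name=job_name or "tf_data_service_shared_job_name",
      )
  )

# Usage (requires a running tf.data service at the given address):
# ds = distribute_to_service(
#     tf.data.Dataset.range(8).batch(2), "grpc://tf-data-service:5050")
```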

16 changes: 8 additions & 8 deletions recml/core/ops/hstu_ops.py
@@ -125,9 +125,9 @@ def _apply_mask(
masks = []
if mask_ref is not None:
if k_in_lanes:
mask = pl.load(mask_ref, (slice(None), k_slice))
mask = mask_ref[:, k_slice]
else:
mask = pl.load(mask_ref, (k_slice, slice(None)))
mask = mask_ref[k_slice, :]

snm = jnp.where(should_not_mask, 1, 0)
masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@ def _apply_mask(
k_sequence = k_offset + jax.lax.broadcasted_iota(
jnp.int32, (k_slice.size, bq), 0
)
q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None))) # [1, bq]
q_sequence = q_sequence_ref[:1, :] # [1, bq]
q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))

assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@ def _apply_mask(

if q_segment_ids_ref is not None:
if k_in_lanes:
kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice)) # [1, k_slice]
kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice]
repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
if rem:
raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@ def _apply_mask(
if rem:
raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
kv_ids = pltpu.repeat(
pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
kv_segment_ids_ref[k_slice, :], repeats, axis=1
) # [k_slice, bq]
q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None))) # [1, bq]
q_ids = q_segment_ids_ref[:1, :] # [1, bq]
masks.append(q_ids == kv_ids)

if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)

q = q_ref[...]
k = pl.load(k_ref, (slice_k, slice(None)))
k = k_ref[slice_k, :]
qk = jax.lax.dot_general(
q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
)
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
)

sv_dims = NN_DIM_NUMBERS
v = pl.load(v_ref, (slice_k, slice(None)))
v = v_ref[slice_k, :]

to_float32 = lambda x: x.astype(jnp.float32)
v = to_float32(v)
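The hstu_ops changes replace explicit `pl.load` calls with direct ref indexing, which Pallas treats as equivalent and which reads more like ordinary NumPy. Below is a minimal, self-contained sketch of the same idiom; the kernel, shapes, and the doubling operation are arbitrary examples, not taken from this file.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_first_row_kernel(x_ref, o_ref):
  # Old style: row0 = pl.load(x_ref, (pl.ds(0, 1), slice(None)))
  # New style: direct ref indexing, which is equivalent and easier to read.
  row0 = x_ref[:1, :]
  o_ref[:1, :] = row0 * 2.0
  # The remaining rows are copied through unchanged.
  o_ref[1:, :] = x_ref[1:, :]

x = jnp.arange(16, dtype=jnp.float32).reshape(4, 4)
y = pl.pallas_call(
    double_first_row_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # run in interpret mode so the sketch works off-TPU
)(x)
```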