
Commit 00af9f7

RecML authors authored and committed
Add replicate_on_split to TFDatasetFactory.
Reverts changelist 793734230

PiperOrigin-RevId: 815176403
1 parent 847628b commit 00af9f7

26 files changed: +4764 −458 lines

recml/core/data/tf_dataset_factory.py

Lines changed: 30 additions & 11 deletions
@@ -24,6 +24,7 @@
 import re
 from typing import Any, Protocol

+from absl import flags
 from absl import logging
 import jax
 from recml.core.utils import types
@@ -162,12 +163,17 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
       Defaults to False.
     seed: An optional seed to use for deterministic shuffling / preprocessing.
       Defaults to None.
-    tf_data_service_address: An optional URI of a tf.data service to offload
-      preprocessing onto during training. The URI should be in the format
-      "protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no
-      data service will be applied.
+    enable_tf_data_service: Whether to apply tf.data service for this dataset.
+      If True, flag `tf_data_service_address` must be set.
     tf_data_service_policy: Sharding policy to use for tf.data service when it
       is enabled.
+    tf_data_service_job_name: Job name to use for tf.data service. If None, the
+      default job name will be used.
+    tf_data_service_replicate_on_split: Whether to replicate the file dataset on
+      split when distributing data to tf.data service workers. Note: it could be
+      used in the case where multiple datasets are processed together under
+      `Dynamic` mode. The dataset with `tf_data_service_replicate_on_split`
+      enabled is equivalent to having that dataset processed as `Off` mode.
     feature_spec: A mapping of feature keys to `FixedLenFeature`,
       `VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be
       used to parse the TF examples, or as context_features spec to parse TF
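
For illustration, a minimal sketch (not part of this commit) of how the new tf.data service fields described above might be set when constructing the factory. Only fields visible in this diff are used; any other required constructor arguments are omitted and assumed to keep their defaults.

import tensorflow as tf

from recml.core.data import tf_dataset_factory

# Hypothetical configuration; enabling the service also requires the
# `tf_data_service_address` flag to be set (see __post_init__ below).
factory = tf_dataset_factory.TFDatasetFactory(
    enable_tf_data_service=True,
    tf_data_service_job_name="ranking_train_job",
    tf_data_service_replicate_on_split=True,
    tf_data_service_policy=tf.data.experimental.service.ShardingPolicy.DYNAMIC,
)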
@@ -208,7 +214,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
       tensorflow.
     debug: An optional boolean indicating whether to debug input boundedness. If
       `True`, the dataset will consist of a single batch that's cached and
-      infinitely repeated
+      infinitely repeated.
   """

   cache_reading: bool = False
@@ -231,7 +237,8 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
   readahead: str | None = None
   group_uris_by_dir: bool = False
   seed: int | None = None
-  tf_data_service_address: str | None = None
+  enable_tf_data_service: bool = False
+  tf_data_service_job_name: str | None = None
   tf_data_service_policy: tf.data.experimental.service.ShardingPolicy = (
       tf.data.experimental.service.ShardingPolicy.OFF
   )
@@ -246,10 +253,16 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
   sharding_info: DatasetShardingInfo = dataclasses.field(
       default_factory=DatasetShardingInfo
   )
+  tf_data_service_replicate_on_split: bool = False
   debug: bool = False

   def __post_init__(self):
-    if self.tf_data_service_address is not None:
+    if self.enable_tf_data_service:
+      if flags.FLAGS.tf_data_service_address is None:
+        raise ValueError(
+            "Flag `tf_data_service_address` must be set when"
+            " `enable_tf_data_service` is True."
+        )
       if self.seed is not None:
         raise ValueError("`seed` must be None for data service.")
       if self.sharding:
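
The `tf_data_service_address` flag itself is defined in a module not shown in this excerpt. A hedged sketch of the expected interplay, assuming an absl string flag of that name:

from absl import flags

# Assumed definition; the real flag lives elsewhere in the codebase.
_TF_DATA_SERVICE_ADDRESS = flags.DEFINE_string(
    "tf_data_service_address",
    None,
    "URI of the tf.data service, e.g. grpc://tf-data-service:5050.",
)

# With the flag unset, enabling the service fails fast in __post_init__:
#   TFDatasetFactory(enable_tf_data_service=True, ...)  # raises ValueError
# With --tf_data_service_address=grpc://tf-data-service:5050, the same
# construction succeeds and the address is later read via flags.FLAGS.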
@@ -464,6 +477,9 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
     # Create a dataset of file / file group uris.
     dataset = tf.data.Dataset.from_tensor_slices(uris)

+    if self.tf_data_service_replicate_on_split:
+      dataset = tf.data.apply_rewrite(dataset, rewrite="replicate_on_split")
+
     # Repeat the dataset. We might need to repeat the dataset here in case the
     # issue is encountered: internal screenshot link:6jAKKoEMT3afXRe
     # even we do have enough shards for the input data.
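
To illustrate the `Dynamic` vs. `Off` distinction in the docstring: under dynamic sharding each tf.data service worker processes a split of this file-list dataset, while replicating on split gives every worker the full list. A conceptual, local-only analogy (no data service involved, placeholder file names) using plain dataset sharding:

import tensorflow as tf

files = tf.data.Dataset.from_tensor_slices(
    [f"/tmp/data/file-{i}" for i in range(4)]
)

# Analogy for `Dynamic` splitting: worker 0 of 2 sees only part of the files.
split_view = files.shard(num_shards=2, index=0)

# Analogy for replicate_on_split (and `Off` mode): every worker sees all files.
replicated_view = files

print([p.numpy() for p in split_view])       # 2 of the 4 file names
print([p.numpy() for p in replicated_view])  # all 4 file names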
@@ -533,23 +549,26 @@ def _maybe_apply_tf_data_service(
       self, dataset: tf.data.Dataset
   ) -> tf.data.Dataset:
     """Applies the tf.data service to the dataset."""
-    if self.tf_data_service_address is None:
+    if not self.enable_tf_data_service:
       return dataset

+    tf_data_service_address = flags.FLAGS.tf_data_service_address
+
     per_proc_batch_size = self.sharding_info.per_process_batch_size(
         self.global_batch_size
     )
     logging.info(
         "Applying tf.data service with address %s and per replica batch"
         " size %s",
-        self.tf_data_service_address,
+        tf_data_service_address,
         per_proc_batch_size,
     )
     return dataset.apply(
         tf.data.experimental.service.distribute(
             processing_mode=self.tf_data_service_policy,
-            service=self.tf_data_service_address,
-            job_name=f"bs_{per_proc_batch_size}",
+            service=tf_data_service_address,
+            job_name=self.tf_data_service_job_name
+            or "tf_data_service_shared_job_name",
         )
     )

recml/core/ops/hstu_ops.py

Lines changed: 8 additions & 8 deletions
@@ -125,9 +125,9 @@ def _apply_mask(
   masks = []
   if mask_ref is not None:
     if k_in_lanes:
-      mask = pl.load(mask_ref, (slice(None), k_slice))
+      mask = mask_ref[:, k_slice]
     else:
-      mask = pl.load(mask_ref, (k_slice, slice(None)))
+      mask = mask_ref[k_slice, :]

     snm = jnp.where(should_not_mask, 1, 0)
     masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
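
This file's changes replace `pl.load(ref, idx)` with direct ref indexing, which reads the same data in a more idiomatic Pallas style. A minimal sketch of the equivalence (shapes and names are illustrative, and `interpret=True` lets it run without a TPU):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def first_row_kernel(x_ref, o_ref):
  # Equivalent reads of the first row as a [1, 128] block:
  #   row = pl.load(x_ref, (pl.ds(0, 1), slice(None)))
  row = x_ref[:1, :]
  o_ref[:1, :] = row

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
out = pl.pallas_call(
    first_row_kernel,
    out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),
    interpret=True,
)(x)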
@@ -156,7 +156,7 @@ def _apply_mask(
       k_sequence = k_offset + jax.lax.broadcasted_iota(
           jnp.int32, (k_slice.size, bq), 0
       )
-      q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_sequence = q_sequence_ref[:1, :]  # [1, bq]
       q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))

       assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@ def _apply_mask(

   if q_segment_ids_ref is not None:
     if k_in_lanes:
-      kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice))  # [1, k_slice]
+      kv_ids = kv_segment_ids_ref[:1, k_slice]  # [1, k_slice]
       repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
       if rem:
         raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@ def _apply_mask(
       if rem:
         raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
       kv_ids = pltpu.repeat(
-          pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
+          kv_segment_ids_ref[k_slice, :], repeats, axis=1
       )  # [k_slice, bq]
-      q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_ids = q_segment_ids_ref[:1, :]  # [1, bq]
     masks.append(q_ids == kv_ids)

   if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
     slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)

     q = q_ref[...]
-    k = pl.load(k_ref, (slice_k, slice(None)))
+    k = k_ref[slice_k, :]
     qk = jax.lax.dot_general(
         q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
     )
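
As in the rewritten kernel above, a `pl.ds(start, size)` slice can be passed directly as a ref index in place of `pl.load`. A small illustrative sketch under the same assumptions as the earlier example (made-up shapes, interpret mode):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def take_rows_kernel(x_ref, o_ref):
  rows = pl.ds(2, 4)  # dynamic slice: 4 rows starting at row 2
  # Same data as pl.load(x_ref, (rows, slice(None))).
  o_ref[:, :] = x_ref[rows, :]

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
out = pl.pallas_call(
    take_rows_kernel,
    out_shape=jax.ShapeDtypeStruct((4, 128), jnp.float32),
    interpret=True,
)(x)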
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
     )

     sv_dims = NN_DIM_NUMBERS
-    v = pl.load(v_ref, (slice_k, slice(None)))
+    v = v_ref[slice_k, :]

     to_float32 = lambda x: x.astype(jnp.float32)
     v = to_float32(v)
