Skip to content

Commit aa8e6a6

Browse files
patnotz, recml authors
authored and committed
Add batch_number to AbstractInputBatch.
This CL adds the batch_number attribute to the AbstractInputBatch base class. The number should be unique and incremental, but can be reset to 0 on restart or between epochs.

PiperOrigin-RevId: 783536041
1 parent eb58583 commit aa8e6a6

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

recml/layers/linen/sparsecore.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -276,10 +276,12 @@ class SparsecorePreprocessor:
276276
sparsecore_config: The sparsecore config used to create the tables.
277277
global_batch_size: The global batch size across all devices to partition the
278278
inputs across.
279+
_batch_number: The batch number for preprocessing, incremented on each call.
279280
"""
280281

281282
sparsecore_config: SparsecoreConfig
282283
global_batch_size: int
284+
_batch_number: int = dataclasses.field(init=False, default=0)
283285

284286
def __post_init__(self):
285287
self.sparsecore_config.init_feature_specs(self.global_batch_size)
@@ -328,6 +330,7 @@ def _to_np(x: Any) -> np.ndarray:
328330
if weights[key] is not None:
329331
weights[key] = np.reshape(weights[key], (-1, 1))
330332

333+
self._batch_number += 1
331334
csr_inputs, _ = embedding.preprocess_sparse_dense_matmul_input(
332335
features=features,
333336
features_weights=weights,
@@ -337,6 +340,7 @@ def _to_np(x: Any) -> np.ndarray:
337340
num_sc_per_device=self.sparsecore_config.num_sc_per_device,
338341
sharding_strategy=self.sparsecore_config.sharding_strategy,
339342
allow_id_dropping=False,
343+
batch_number=self._batch_number,
340344
)
341345

342346
processed_inputs = {

0 commit comments

Comments
 (0)