From c5eb89a2b5f45c424d2f0e0ac95c0f3fea9016ae Mon Sep 17 00:00:00 2001 From: Chandra Devarakonda Date: Thu, 27 Feb 2025 15:36:08 -0800 Subject: [PATCH] Fix incorrect number of return values in docstring of embedding.preprocess_sparsecore_inputs. PiperOrigin-RevId: 731891800 --- jax_tpu_embedding/sparsecore/lib/nn/embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py index 417401cc..65e7565b 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py @@ -334,8 +334,8 @@ def preprocess_sparse_dense_matmul_input( the max_ids_per_partition or max_unique_ids_per_partition limits. Returns: - A tuple of four dictionaries mapping the stacked table names to the - preprocessed inputs for the corresponding table. The four dictionaries are + A tuple of five dictionaries mapping the stacked table names to the + preprocessed inputs for the corresponding table. The five dictionaries are lhs_row_pointers, lhs_embedding_ids, lhs_sample_ids, lhs_gains and stats. """ tree.assert_same_structure(features, feature_specs)