diff --git a/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h b/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h
index e43ee736..ab5079b2 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h
+++ b/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h
@@ -187,7 +187,95 @@ struct LocalSparseCoreTensorGroupingContext {
   MatrixXi& kept_unique_ids_per_partition_per_bucket;
 };
 
-inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
+inline void GroupAndDeduplicateCooTensorsForLocalSparseCoreNoBuckets(
+    LocalSparseCoreTensorGroupingContext context) {
+  // Unpack context for readability.
+  const PreprocessSparseDenseMatmulInputOptions& options = context.options;
+  const StackedTableMetadata& stacked_table_metadata =
+      context.stacked_table_metadata;
+  const std::vector<CooFormat>& coo_tensors = context.coo_tensors;
+  PartitionedCooTensors& grouped_coo_tensors = context.grouped_coo_tensors;
+  StatsPerDevice& stats = context.stats;
+  MatrixXi& observed_ids = context.ids_per_sc_partition_per_bucket;
+  MatrixXi& observed_unique_ids = context.unique_ids_per_partition_per_bucket;
+  MatrixXi& kept_ids = context.kept_ids_per_sc_partition_per_bucket;
+  MatrixXi& kept_unique_ids = context.kept_unique_ids_per_partition_per_bucket;
+
+  const bool allow_id_dropping = options.allow_id_dropping;
+  const uint32_t global_sc_count = options.GetNumScs();
+  const int max_ids_per_partition =
+      stacked_table_metadata.max_ids_per_partition;
+  const int max_unique_ids_per_partition =
+      stacked_table_metadata.max_unique_ids_per_partition;
+  uint32_t prev_col_id = std::numeric_limits<uint32_t>::max();
+  uint32_t prev_row_id = std::numeric_limits<uint32_t>::max();
+  bool dropping_current_unique_col_id = false;
+  for (const uint64_t key : context.keys) {
+    // Step 1: Unpack key to get tensor coordinates.
+    const uint32_t index = key & CooFormat::kIndexMask;
+    const CooFormat& coo_tensor = coo_tensors[index];
+    const uint32_t col_id = coo_tensor.col_id;
+    const uint32_t global_sc_id = coo_tensor.col_id & (global_sc_count - 1);
+    const uint32_t row_id = coo_tensor.row_id;
+
+    // Step 2: Handle duplicates.
+    // An ID that is a duplicate of a previously non-dropped ID is merged.
+    // It does not count as a new ID for stats and does not go through dropping
+    // logic.
+    if (grouped_coo_tensors.MaybeMerge(/*bucket_id=*/0, coo_tensor)) {
+      continue;
+    }
+    // If the ID is a duplicate of the last seen ID, it must have been dropped
+    // (otherwise it would have been merged above), so drop this one too.
+    if (col_id == prev_col_id && row_id == prev_row_id) {
+      ++stats.dropped_id_count;
+      continue;
+    }
+
+    // Step 3: Update observed statistics for the new ID.
+    const bool is_new_col = col_id != prev_col_id;
+    // Update observed stats. These are never decremented and are used for
+    // reporting.
+    observed_ids(global_sc_id, 0) += 1;
+    if (is_new_col) {
+      observed_unique_ids(global_sc_id, 0) += 1;
+      dropping_current_unique_col_id =
+          (kept_unique_ids(global_sc_id, 0) + 1) >
+          max_unique_ids_per_partition;
+    }
+
+    // Step 4: Determine if the ID should be dropped based on capacity limits.
+    // We do NOT drop IDs when minibatching is enabled and we are in the
+    // first pass (`create_buckets=false`), as we need to detect limit
+    // overflows to decide if minibatching is required.
+    const bool can_drop_id =
+        !options.enable_minibatching;
+    const bool exceeds_ids_limit =
+        (kept_ids(global_sc_id, 0) + 1) > max_ids_per_partition;
+
+    // Step 5: Add ID to result or drop it.
+    if (can_drop_id && allow_id_dropping &&
+        (exceeds_ids_limit || dropping_current_unique_col_id)) {
+      // Dropped id.
+      ++stats.dropped_id_count;
+    } else {
+      grouped_coo_tensors.Add(context.local_sc_id, /*bucket_id=*/0, coo_tensor);
+      // Update kept counts.
+      kept_ids(global_sc_id, 0) += 1;
+      if (is_new_col) {
+        kept_unique_ids(global_sc_id, 0) += 1;
+      }
+    }
+
+    // Step 6: Update state for next iteration.
+    // This must be done regardless of whether the ID was dropped to ensure
+    // correct stats collection for subsequent IDs.
+    prev_col_id = col_id;
+    prev_row_id = row_id;
+  }
+}
+
+inline void GroupAndDeduplicateCooTensorsForLocalSparseCoreWithBuckets(
     LocalSparseCoreTensorGroupingContext context) {
   // Unpack context for readability.
   const PreprocessSparseDenseMatmulInputOptions& options = context.options;
@@ -298,8 +386,8 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
 // NOTE: We use output buffers `max_ids_per_sc`, `max_unique_ids_per_sc`, and
 // `required_buffer_size_per_sc` because we fill values in a loop to a bigger
 // array.
-template <typename SplitType>
-PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
+template <typename SplitType, bool kCreateBuckets>
+PartitionedCooTensors SortAndGroupCooTensorsPerLocalDeviceImpl(
     const ExtractedCooTensors& extracted_coo_tensors,
     const StackedTableMetadata& stacked_table_metadata,
     const PreprocessSparseDenseMatmulInputOptions& options,
@@ -320,20 +408,18 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
   // This function can be called in two passes for minibatching. The logic for
   // stats collection and ID dropping depends on the pass.
   //
-  // Pass 1: Check if minibatching is required (`create_buckets` is false).
+  // Pass 1: Check if minibatching is required (`kCreateBuckets` is false).
   //   - No IDs are dropped.
   //   - Stats are collected on all observed IDs to compute splits.
   //
-  // Pass 2: Create buckets (`create_buckets` is true).
+  // Pass 2: Create buckets (`kCreateBuckets` is true).
   //   - A dummy stats object is used (stats are not re-computed).
   //   - IDs may be dropped if they exceed capacity.
-  const bool create_buckets = options.enable_minibatching &&
-                              (std::is_same_v<SplitType, MinibatchingSplit>);
 
   // Partition COO tensors among SparseCores for the local device (based on row
   // id).
   const int bucket_count =
-      create_buckets ? CooFormat::kMaxMinibatchingBuckets : 1;
+      kCreateBuckets ? CooFormat::kMaxMinibatchingBuckets : 1;
   PartitionedCooTensors grouped_coo_tensors(
       coo_tensors.size(), num_sc_per_device, global_sc_count, bucket_count);
 
@@ -367,7 +453,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
       //   local_embedding_id(32-num_scs bits), index(26 bits)].
       // Note that this assumes `num_scs` is a power of 2.
       keys.push_back(coo_tensors[coo_tensor_index].GetGroupingKey(
-          num_sc_bits, coo_tensor_index, create_buckets,
+          num_sc_bits, coo_tensor_index, kCreateBuckets,
           options.minibatching_bucketing_hash_fn));
     }
 
@@ -375,23 +461,43 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
     DCHECK(expected_keys_size == 0 || keys.size() == expected_keys_size);
     hwy::VQSort(keys.data(), keys.size(), hwy::SortAscending());
 
-    internal::GroupAndDeduplicateCooTensorsForLocalSparseCore({
-        .keys = keys,
-        .coo_tensors = coo_tensors,
-        .stacked_table_metadata = stacked_table_metadata,
-        .options = options,
-        .create_buckets = create_buckets,
-        .local_sc_id = local_sc_id,
-        .grouped_coo_tensors = grouped_coo_tensors,
-        .ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket,
-        .unique_ids_per_partition_per_bucket =
-            unique_ids_per_partition_per_bucket,
-        .stats = stats,
-        .kept_ids_per_sc_partition_per_bucket =
-            kept_ids_per_sc_partition_per_bucket,
-        .kept_unique_ids_per_partition_per_bucket =
-            kept_unique_ids_per_partition_per_bucket,
-    });
+    if constexpr (kCreateBuckets) {
+      internal::GroupAndDeduplicateCooTensorsForLocalSparseCoreWithBuckets({
+          .keys = keys,
+          .coo_tensors = coo_tensors,
+          .stacked_table_metadata = stacked_table_metadata,
+          .options = options,
+          .create_buckets = kCreateBuckets,
+          .local_sc_id = local_sc_id,
+          .grouped_coo_tensors = grouped_coo_tensors,
+          .ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket,
+          .unique_ids_per_partition_per_bucket =
+              unique_ids_per_partition_per_bucket,
+          .stats = stats,
+          .kept_ids_per_sc_partition_per_bucket =
+              kept_ids_per_sc_partition_per_bucket,
+          .kept_unique_ids_per_partition_per_bucket =
+              kept_unique_ids_per_partition_per_bucket,
+      });
+    } else {
+      internal::GroupAndDeduplicateCooTensorsForLocalSparseCoreNoBuckets({
+          .keys = keys,
+          .coo_tensors = coo_tensors,
+          .stacked_table_metadata = stacked_table_metadata,
+          .options = options,
+          .create_buckets = kCreateBuckets,
+          .local_sc_id = local_sc_id,
+          .grouped_coo_tensors = grouped_coo_tensors,
+          .ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket,
+          .unique_ids_per_partition_per_bucket =
+              unique_ids_per_partition_per_bucket,
+          .stats = stats,
+          .kept_ids_per_sc_partition_per_bucket =
+              kept_ids_per_sc_partition_per_bucket,
+          .kept_unique_ids_per_partition_per_bucket =
+              kept_unique_ids_per_partition_per_bucket,
+      });
+    }
 
     grouped_coo_tensors.FillRemainingScBuckets();
 
@@ -427,7 +533,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
 
   // Only validate if creating minibatching buckets or when minibatching is
   // disabled, not when checking if minibatching is required.
-  if (!options.enable_minibatching || create_buckets)
+  if (!options.enable_minibatching || kCreateBuckets)
     internal::ValidateMaxIdsOrDie(
         observed_max_ids_per_bucket, observed_max_unique_ids_per_bucket,
         max_ids_per_partition, max_unique_ids_per_partition,
@@ -437,6 +543,26 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
   return grouped_coo_tensors;
 }
 
+template <typename SplitType>
+PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
+    const ExtractedCooTensors& extracted_coo_tensors,
+    const StackedTableMetadata& stacked_table_metadata,
+    const PreprocessSparseDenseMatmulInputOptions& options,
+    internal::StatsPerDevice& stats, SplitType& minibatching_split) {
+  const bool create_buckets =
+      options.enable_minibatching &&
+      std::is_same_v<SplitType, MinibatchingSplit>;
+  if (create_buckets) {
+    return SortAndGroupCooTensorsPerLocalDeviceImpl<SplitType, true>(
+        extracted_coo_tensors, stacked_table_metadata, options, stats,
+        minibatching_split);
+  } else {
+    return SortAndGroupCooTensorsPerLocalDeviceImpl<SplitType, false>(
+        extracted_coo_tensors, stacked_table_metadata, options, stats,
+        minibatching_split);
+  }
+}
+
 }  // namespace jax_sc_embedding
 
 #endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SORT_AND_GROUP_COO_TENSORS_IMPL_H_
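Review note: the change above lifts the runtime `create_buckets` flag into a `bool kCreateBuckets` template parameter, splits the grouping routine into `...NoBuckets` and `...WithBuckets` variants, and keeps the original `SortAndGroupCooTensorsPerLocalDevice` signature as a thin wrapper that picks the instantiation once. Below is a minimal, self-contained sketch of that dispatch pattern; all names in it (`Process`, `ProcessImpl`, etc.) are hypothetical stand-ins, not the library's API.

```cpp
#include <iostream>

// Hypothetical stand-ins for the two specialized code paths; the real
// functions take a grouping context instead of printing.
inline void ProcessNoBuckets() { std::cout << "single-bucket path\n"; }
inline void ProcessWithBuckets() { std::cout << "bucketed path\n"; }

// The Impl is templated on the flag, so the untaken branch is discarded at
// compile time by `if constexpr` (mirrors
// SortAndGroupCooTensorsPerLocalDeviceImpl in the diff).
template <bool kCreateBuckets>
void ProcessImpl() {
  if constexpr (kCreateBuckets) {
    ProcessWithBuckets();
  } else {
    ProcessNoBuckets();
  }
}

// A thin runtime wrapper keeps the public signature unchanged and resolves
// the flag exactly once (mirrors the new SortAndGroupCooTensorsPerLocalDevice
// wrapper).
inline void Process(bool create_buckets) {
  if (create_buckets) {
    ProcessImpl<true>();
  } else {
    ProcessImpl<false>();
  }
}

int main() {
  Process(false);  // prints "single-bucket path"
  Process(true);   // prints "bucketed path"
  return 0;
}
```

The practical payoff, visible in the diff, is that the no-buckets variant can hard-code `bucket_id = 0` and drop the per-bucket bookkeeping from the hot per-key loop instead of branching on the flag for every key.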