180 changes: 153 additions & 27 deletions jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h
@@ -187,7 +187,95 @@ struct LocalSparseCoreTensorGroupingContext {
MatrixXi& kept_unique_ids_per_partition_per_bucket;
};

inline void GroupAndDeduplicateCooTensorsForLocalSparseCoreNoBuckets(
LocalSparseCoreTensorGroupingContext context) {
// Unpack context for readability.
const PreprocessSparseDenseMatmulInputOptions& options = context.options;
const StackedTableMetadata& stacked_table_metadata =
context.stacked_table_metadata;
const std::vector<CooFormat>& coo_tensors = context.coo_tensors;
PartitionedCooTensors& grouped_coo_tensors = context.grouped_coo_tensors;
StatsPerDevice& stats = context.stats;
MatrixXi& observed_ids = context.ids_per_sc_partition_per_bucket;
MatrixXi& observed_unique_ids = context.unique_ids_per_partition_per_bucket;
MatrixXi& kept_ids = context.kept_ids_per_sc_partition_per_bucket;
MatrixXi& kept_unique_ids = context.kept_unique_ids_per_partition_per_bucket;

const bool allow_id_dropping = options.allow_id_dropping;
const uint32_t global_sc_count = options.GetNumScs();
const int max_ids_per_partition =
stacked_table_metadata.max_ids_per_partition;
const int max_unique_ids_per_partition =
stacked_table_metadata.max_unique_ids_per_partition;
uint32_t prev_col_id = std::numeric_limits<uint32_t>::max();
uint32_t prev_row_id = std::numeric_limits<uint32_t>::max();
bool dropping_current_unique_col_id = false;
for (const uint64_t key : context.keys) {
// Step 1: Unpack key to get tensor coordinates.
const uint32_t index = key & CooFormat::kIndexMask;
const CooFormat& coo_tensor = coo_tensors[index];
const uint32_t col_id = coo_tensor.col_id;
const uint32_t global_sc_id = col_id & (global_sc_count - 1);
const uint32_t row_id = coo_tensor.row_id;

// Step 2: Handle duplicates.
// An ID that is a duplicate of a previously non-dropped ID is merged.
// It does not count as a new ID for stats and does not go through dropping
// logic.
if (grouped_coo_tensors.MaybeMerge(/*bucket_id=*/0, coo_tensor)) {
continue;
}
// If the ID is a duplicate of the last seen ID, it must have been dropped
// (otherwise it would have been merged above), so drop this one too.
if (col_id == prev_col_id && row_id == prev_row_id) {
++stats.dropped_id_count;
continue;
}

// Step 3: Update observed statistics for the new ID.
const bool is_new_col = col_id != prev_col_id;
// Update observed stats. These are never decremented and are used for
// reporting.
observed_ids(global_sc_id, 0) += 1;
if (is_new_col) {
observed_unique_ids(global_sc_id, 0) += 1;
dropping_current_unique_col_id =
(kept_unique_ids(global_sc_id, 0) + 1) >
max_unique_ids_per_partition;
}

// Step 4: Determine if the ID should be dropped based on capacity limits.
// IDs are never dropped when minibatching is enabled: this no-bucket
// variant then serves as the first pass, which must observe limit
// overflows to decide whether minibatching is required.
const bool can_drop_id = !options.enable_minibatching;
const bool exceeds_ids_limit =
(kept_ids(global_sc_id, 0) + 1) > max_ids_per_partition;

// Step 5: Add ID to result or drop it.
if (can_drop_id && allow_id_dropping &&
(exceeds_ids_limit || dropping_current_unique_col_id)) {
// Dropped id.
++stats.dropped_id_count;
} else {
grouped_coo_tensors.Add(context.local_sc_id, /*bucket_id=*/0, coo_tensor);
// Update kept counts.
kept_ids(global_sc_id, 0) += 1;
if (is_new_col) {
kept_unique_ids(global_sc_id, 0) += 1;
}
}

// Step 6: Update state for next iteration.
// This must be done regardless of whether the ID was dropped to ensure
// correct stats collection for subsequent IDs.
prev_col_id = col_id;
prev_row_id = row_id;
}
}
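
// Editor's sketch (not part of this PR): a standalone analogue of the
// sort-then-scan pattern above. Sorting by (col_id, row_id) makes duplicates
// adjacent, so a single linear pass can merge duplicates and enforce a
// unique-column budget, mirroring Steps 2-5. All names and the simplified
// budget are hypothetical.
#include <algorithm>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

struct Coo { uint32_t row_id, col_id; float gain; };

std::vector<Coo> GroupWithBudget(std::vector<Coo> coos, int max_unique_cols) {
  std::sort(coos.begin(), coos.end(), [](const Coo& a, const Coo& b) {
    return std::make_pair(a.col_id, a.row_id) <
           std::make_pair(b.col_id, b.row_id);
  });
  std::vector<Coo> kept;
  uint32_t prev_col = std::numeric_limits<uint32_t>::max();
  uint32_t prev_row = std::numeric_limits<uint32_t>::max();
  int unique_cols = 0;
  bool dropping_col = false;
  for (const Coo& c : coos) {
    if (c.col_id == prev_col && c.row_id == prev_row) {
      // Duplicate of the previous ID: merge if it was kept, drop otherwise.
      if (!dropping_col) kept.back().gain += c.gain;
      continue;
    }
    if (c.col_id != prev_col) {
      ++unique_cols;
      dropping_col = unique_cols > max_unique_cols;
    }
    if (!dropping_col) kept.push_back(c);
    prev_col = c.col_id;
    prev_row = c.row_id;
  }
  return kept;
}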

inline void GroupAndDeduplicateCooTensorsForLocalSparseCoreWithBuckets(
LocalSparseCoreTensorGroupingContext context) {
// Unpack context for readability.
const PreprocessSparseDenseMatmulInputOptions& options = context.options;
@@ -298,8 +386,8 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
// NOTE: Output buffers `max_ids_per_sc`, `max_unique_ids_per_sc`, and
// `required_buffer_size_per_sc` are taken as parameters because their values
// are filled in a loop into a larger array.
template <typename SplitType>
PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
template <bool kCreateBuckets, typename SplitType>
PartitionedCooTensors SortAndGroupCooTensorsPerLocalDeviceImpl(
const ExtractedCooTensors& extracted_coo_tensors,
const StackedTableMetadata& stacked_table_metadata,
const PreprocessSparseDenseMatmulInputOptions& options,
@@ -320,20 +408,18 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
// This function can be called in two passes for minibatching. The logic for
// stats collection and ID dropping depends on the pass.
//
// Pass 1: Check if minibatching is required (`kCreateBuckets` is false).
// - No IDs are dropped.
// - Stats are collected on all observed IDs to compute splits.
//
// Pass 2: Create buckets (`kCreateBuckets` is true).
// - A dummy stats object is used (stats are not re-computed).
// - IDs may be dropped if they exceed capacity.
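// Editor's summary (not in the original comment): the modes differ as
// follows, where dropping additionally requires `allow_id_dropping`:
//   minibatching pass 1: no buckets, never drops, collects stats.
//   minibatching pass 2: buckets,    may drop,    dummy stats.
//   minibatching off:    no buckets, may drop,    collects stats.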

// Partition COO tensors among SparseCores for the local device (based on row
// id).
const int bucket_count =
kCreateBuckets ? CooFormat::kMaxMinibatchingBuckets : 1;
PartitionedCooTensors grouped_coo_tensors(
coo_tensors.size(), num_sc_per_device, global_sc_count, bucket_count);

@@ -367,31 +453,51 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
// local_embedding_id(32-num_scs bits), index(26 bits)].
// Note that this assumes `num_scs` is a power of 2.
keys.push_back(coo_tensors[coo_tensor_index].GetGroupingKey(
num_sc_bits, coo_tensor_index, kCreateBuckets,
options.minibatching_bucketing_hash_fn));
}

// The expected allocation size may be uninitialized.
DCHECK(expected_keys_size == 0 || keys.size() == expected_keys_size);
hwy::VQSort(keys.data(), keys.size(), hwy::SortAscending());
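
// Worked example (editor's illustration, hypothetical values): with
// num_scs = 4 and num_sc_bits = 2, an ID with col_id = 10 has
//   global_sc_id       = col_id & (num_scs - 1);  // 10 & 3  == 2
//   local_embedding_id = col_id >> num_sc_bits;   // 10 >> 2 == 2
// so the ascending sort orders keys first by destination SparseCore, then by
// embedding ID, and finally by input position (the low index bits), leaving
// duplicates of the same (col_id, row_id) adjacent for the
// GroupAndDeduplicate* scans invoked below.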

if constexpr (kCreateBuckets) {
internal::GroupAndDeduplicateCooTensorsForLocalSparseCoreWithBuckets({
.keys = keys,
.coo_tensors = coo_tensors,
.stacked_table_metadata = stacked_table_metadata,
.options = options,
.create_buckets = kCreateBuckets,
.local_sc_id = local_sc_id,
.grouped_coo_tensors = grouped_coo_tensors,
.ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket,
.unique_ids_per_partition_per_bucket =
unique_ids_per_partition_per_bucket,
.stats = stats,
.kept_ids_per_sc_partition_per_bucket =
kept_ids_per_sc_partition_per_bucket,
.kept_unique_ids_per_partition_per_bucket =
kept_unique_ids_per_partition_per_bucket,
});
} else {
internal::GroupAndDeduplicateCooTensorsForLocalSparseCoreNoBuckets({
.keys = keys,
.coo_tensors = coo_tensors,
.stacked_table_metadata = stacked_table_metadata,
.options = options,
.create_buckets = kCreateBuckets,
.local_sc_id = local_sc_id,
.grouped_coo_tensors = grouped_coo_tensors,
.ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket,
.unique_ids_per_partition_per_bucket =
unique_ids_per_partition_per_bucket,
.stats = stats,
.kept_ids_per_sc_partition_per_bucket =
kept_ids_per_sc_partition_per_bucket,
.kept_unique_ids_per_partition_per_bucket =
kept_unique_ids_per_partition_per_bucket,
});
}

grouped_coo_tensors.FillRemainingScBuckets();

@@ -427,7 +533,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(

// Only validate if creating minibatching buckets or when minibatching is
// disabled, not when checking if minibatching is required.
if (!options.enable_minibatching || kCreateBuckets)
internal::ValidateMaxIdsOrDie(
observed_max_ids_per_bucket, observed_max_unique_ids_per_bucket,
max_ids_per_partition, max_unique_ids_per_partition,
@@ -437,6 +543,26 @@
return grouped_coo_tensors;
}

template <typename SplitType>
PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
const ExtractedCooTensors& extracted_coo_tensors,
const StackedTableMetadata& stacked_table_metadata,
const PreprocessSparseDenseMatmulInputOptions& options,
internal::StatsPerDevice& stats, SplitType& minibatching_split) {
const bool create_buckets =
options.enable_minibatching &&
std::is_same_v<SplitType, MinibatchingSplit>;
if (create_buckets) {
return SortAndGroupCooTensorsPerLocalDeviceImpl<true>(
extracted_coo_tensors, stacked_table_metadata, options, stats,
minibatching_split);
} else {
return SortAndGroupCooTensorsPerLocalDeviceImpl<false>(
extracted_coo_tensors, stacked_table_metadata, options, stats,
minibatching_split);
}
}
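
// Editor's sketch (not part of this PR): the runtime-to-compile-time
// dispatch used above, in miniature. The runtime flag picks one of two
// template instantiations once, so inside the implementation the flag is a
// compile-time constant usable with `if constexpr`. Names are hypothetical.
template <bool kEnabled>
int AccumulateImpl(int x) {
  if constexpr (kEnabled) {
    return x * 2;  // stands in for the bucketed path
  } else {
    return x;  // stands in for the single-bucket path
  }
}

inline int Accumulate(int x, bool enabled) {
  return enabled ? AccumulateImpl<true>(x) : AccumulateImpl<false>(x);
}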

} // namespace jax_sc_embedding

#endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SORT_AND_GROUP_COO_TENSORS_IMPL_H_