
Add batched pairwise similarity method for Semantic Dedup #581

Open
wants to merge 6 commits into base: main

Conversation


@praateekmahajan (Collaborator) commented on Mar 7, 2025

Description

Resolves #520

Currently, if a single cluster is large enough, we will likely OOM, since M @ M.T requires N**2 storage. A batched version breaks the computation into smaller batches B and performs M @ B.T. The only thing we need to be careful about is how to zero out the diagonal and keep the upper-triangular part.
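A minimal sketch of the batched idea (illustrative only; the function name and exact masking are assumptions, not the code in this PR). Each iteration computes one N x b slab of the similarity matrix instead of the full N x N, masks out the diagonal and the lower triangle, and reduces the slab immediately:

import torch
import torch.nn.functional as F

def batched_max_pairwise_similarity(embeddings: torch.Tensor, batch_size: int = 1024) -> torch.Tensor:
    """For each item, return its max cosine similarity to any earlier item (sketch, not the PR code)."""
    M = F.normalize(embeddings, dim=1)  # L2-normalize so dot products are cosine similarities
    N = M.shape[0]
    max_sims = torch.empty(N, dtype=M.dtype, device=M.device)
    for start in range(0, N, batch_size):
        end = min(start + batch_size, N)
        B = M[start:end]            # (b, D) batch of rows
        sims = M @ B.T              # (N, b) slab instead of the full (N, N) matrix
        # Keep only the strictly upper-triangular part: column j may only see
        # rows i < j, which also zeroes out the self-similarity on the diagonal.
        rows = torch.arange(N, device=M.device).unsqueeze(1)
        cols = torch.arange(start, end, device=M.device).unsqueeze(0)
        sims = sims.masked_fill(rows >= cols, float("-inf"))
        max_sims[start:end] = sims.max(dim=0).values
    max_sims[0] = -1.0              # the first item has no earlier neighbour
    return max_sims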

Other nits

  1. Always L2-normalize the embedding vectors so that M @ M.T has an absolute maximum value of 1 (see the short sketch after this list).
  2. Renamed _semdedup to pairwise_similarity.
  3. Added tests for the existing function and for the batched approach.
  4. The batched implementation is now the default.
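A short illustration of point 1 (standalone sketch, not the PR's code): once the rows of M are L2-normalized, M @ M.T contains cosine similarities, so no entry exceeds 1 in absolute value and the diagonal is exactly 1.

import torch
import torch.nn.functional as F

M = F.normalize(torch.randn(8, 32), dim=1)   # unit-length rows
sims = M @ M.T                               # pairwise cosine similarities
assert sims.abs().max() <= 1.0 + 1e-6        # bounded by 1 in absolute value
assert torch.allclose(sims.diagonal(), torch.ones(8), atol=1e-5)  # self-similarity is 1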

Usage

# Add snippet demonstrating usage

Checklist

  • I am familiar with the Contributing Guide.
  • New or Existing tests cover these changes.
  • The documentation is up to date with these changes.

Signed-off-by: Praateek <[email protected]>
@sarahyurick (Collaborator) left a comment


Nice PR, added a couple minor comments.

Comment on lines 194 to 195
# Compute pairwise cosine similarity
pairwise_sim_matrix = cluster_reps @ (cluster_reps.T)
Collaborator

Haha I don't think I've used @ before. Would there be any advantage to using torch.mm, torch.matmul, etc.?

@@ -681,6 +681,7 @@
" id_column_type=\"str\",\n",
" embedding_col=\"image_embedding\",\n",
" which_to_keep=\"hard\",\n",
" batched_cosine_similarity=1024,\n",
Collaborator

Has this been tested?

Collaborator Author

Nope, do we always manually run these notebooks for such PRs? That'll be a time sink, but I'm okay to do it if that's the practice.

Collaborator

If it is expected to produce the same results as before, it is okay with me. Sometimes I leave notebooks unchanged (or add changes so that they keep the previous default) if the output is expected to change, so that users won't be confused when their cell outputs differ from the ones on GitHub.

Collaborator

But it sounds like that isn't the case here?

@praateekmahajan changed the title from "Add batched pairwise similarity method" to "Add batched pairwise similarity method for Semantic Dedup" on Mar 7, 2025
Development

Successfully merging this pull request may close these issues.

torch.OutOfMemoryError: CUDA out of memory. while performing peft curation with sdg on default configs