Allow masking padding tokens in cross attention layers #94

Merged

jazcollins merged 8 commits into mosaicml:main from zero-pad-tokens on Nov 16, 2023

Conversation

jazcollins
Contributor

This PR adds a new parameter, mask_pad_tokens, to the stable_diffusion_xl and stable_diffusion_2 classes that allows masking out padding tokens in the cross attention layers.
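For illustration only (this is not the PR's exact code), here's a minimal sketch of the idea, assuming an HF CLIP tokenizer/text encoder and a diffusers-style UNet whose forward accepts an `encoder_attention_mask`: the padding mask comes straight from the tokenizer and just needs to be threaded through to the cross attention layers.

```python
# Minimal sketch, not the code in this PR: derive a padding mask from the
# tokenizer and hand it to the UNet so cross attention can ignore pad tokens.
# Assumes an HF CLIP tokenizer/text encoder and a diffusers-style UNet whose
# forward() accepts `encoder_attention_mask`.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

name = 'stabilityai/stable-diffusion-2-base'
tokenizer = CLIPTokenizer.from_pretrained(name, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(name, subfolder='text_encoder')

mask_pad_tokens = True  # the new flag

tokenized = tokenizer(
    ['a corgi wearing a party hat'],
    padding='max_length',
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors='pt',
)
with torch.no_grad():
    text_embeds = text_encoder(tokenized.input_ids)[0]

# attention_mask is 1 for real tokens and 0 for padding positions.
pad_mask = tokenized.attention_mask if mask_pad_tokens else None

# In the UNet forward pass, the mask keeps cross attention from attending to
# the padded positions, e.g. (diffusers-style call shape):
# noise_pred = unet(latents, timesteps,
#                   encoder_hidden_states=text_embeds,
#                   encoder_attention_mask=pad_mask).sample
```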

generate() had to get a bit more complicated because of the case where we pass in pre-tokenized inputs: we now want to allow passing the padding mask along with them (and likewise for pre-tokenized negative prompts), as sketched below. Let me know if you can think of a better way of handling this.
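Continuing the sketch above, the call shape this enables looks roughly like the following. The keyword names (tokenized_prompts, prompt_mask, tokenized_negative_prompts, negative_prompt_mask) are hypothetical stand-ins and may not match the final generate() signature, and `model` stands in for a stable_diffusion_2/stable_diffusion_xl instance:

```python
# Hypothetical usage sketch; the keyword names below are illustrative only.
tokenized = tokenizer(['a corgi wearing a party hat'], padding='max_length',
                      max_length=tokenizer.model_max_length,
                      truncation=True, return_tensors='pt')
tokenized_neg = tokenizer([''], padding='max_length',
                          max_length=tokenizer.model_max_length,
                          truncation=True, return_tensors='pt')

images = model.generate(
    tokenized_prompts=tokenized.input_ids,               # pre-tokenized prompt
    prompt_mask=tokenized.attention_mask,                 # padding mask for the prompt
    tokenized_negative_prompts=tokenized_neg.input_ids,   # pre-tokenized negative prompt
    negative_prompt_mask=tokenized_neg.attention_mask,    # padding mask for the negative prompt
    height=512,
    width=512,
)
```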

One small note: this change might be slightly redundant with the zero_out_negative_prompt arg (in the generate() function) and zero_dropped_captions (in the dataloader) that I added not too long ago for zeroing out empty negative prompts and dropped captions. mask_pad_tokens serves a similar purpose by masking out the empty text embeddings in the cross attention layers, but zero_out_negative_prompt/zero_dropped_captions additionally zero out the pooled text embedding (used in SDXL as microconditioning), so I think we still want to keep that functionality. The cleanest way would probably be to merge these all into one flag?

Collaborator

coryMosaicML left a comment

LGTM, thanks for this!

jazcollins merged commit 5decf2a into mosaicml:main Nov 16, 2023
7 checks passed
jazcollins deleted the zero-pad-tokens branch November 16, 2023 19:27
jazcollins restored the zero-pad-tokens branch November 16, 2023 19:29