Full SDXL Model #67
Conversation
This is great work!! It seems really difficult to manage the SDXL-specific features and the SD2 features in one model, but you did the best organization possible. The few possible refactors may not be that much better. I proposed a few suggestions, but it's up to your discretion whether they are good / worth the effort.
The only possible bug I noticed was in `log_diffusion_images.py`.
One general proposal: there are a few if-statements due to slight differences in the SDXL and HF tokenizers. If you're up for it, I think it would clean up a bit of code if you could make them more similar. I think this requires two things:

- Have a `max_length` argument in the SDXL Tokenizers, and in the calling code set the max length by `max_length = None if self.sdxl else self.tokenizer.model_max_length`.
- Have the SDXL tokenizer return a dictionary with the key `input_ids`.
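A minimal sketch of that suggestion. All names here (`SDXLTokenizerSketch`, `DummyTokenizer`) are hypothetical stand-ins for the two real CLIP tokenizers, not the repo's actual classes; the point is the unified interface: a `max_length` keyword and a dict result keyed by `input_ids`.

```python
class DummyTokenizer:
    """Stand-in for a HF tokenizer: has model_max_length, returns {'input_ids': ...}."""
    model_max_length = 77

    def __call__(self, text, max_length=None, **kwargs):
        ids = [ord(c) for c in text]
        if max_length is not None:  # None means "do not truncate", matching the suggestion
            ids = ids[:max_length]
        return {'input_ids': ids}


class SDXLTokenizerSketch:
    """Wraps two tokenizers but exposes one HF-like call signature."""

    def __init__(self, tokenizer_one, tokenizer_two):
        self.tokenizers = [tokenizer_one, tokenizer_two]
        self.model_max_length = min(t.model_max_length for t in self.tokenizers)

    def __call__(self, text, max_length=None, **kwargs):
        # Return a dict with 'input_ids' holding one sequence per sub-tokenizer,
        # so callers never branch on SDXL vs. HF.
        return {'input_ids': [t(text, max_length=max_length)['input_ids']
                              for t in self.tokenizers]}


tokenizer = SDXLTokenizerSketch(DummyTokenizer(), DummyTokenizer())
# The caller can now pick max_length uniformly, as the review suggests:
sdxl = True
max_length = None if sdxl else tokenizer.model_max_length
out = tokenizer('a photo of a cat', max_length=max_length)
```

With this shape, the same `batch['input_ids']` access works for both tokenizer families.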
Co-authored-by: Landan Seguin <[email protected]>
Amazing!! Thanks for all the fixes!
```python
torch_dtype = torch.float16 if encode_latents_in_fp16 else None
try:
    vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch_dtype)
except:  # for handling SDXL vae fp16 fixed checkpoint
```
The blanket `except` isn't great here; we should probably qualify the exception type.
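For example, one way to narrow it. The exception types and the loader callables below are assumptions for illustration, not the repo's actual API; the idea is that only the errors expected for the SDXL vae-fp16-fix checkpoint layout trigger the fallback, while anything else still propagates.

```python
def load_vae(load_with_subfolder, load_without_subfolder):
    """Try the primary loader; fall back only on expected, named errors.

    `load_with_subfolder` / `load_without_subfolder` are hypothetical callables
    standing in for the two AutoencoderKL.from_pretrained variants.
    """
    try:
        return load_with_subfolder()
    except (OSError, ValueError):  # assumed errors for a checkpoint without a 'vae' subfolder
        return load_without_subfolder()


def _primary():
    # Simulate the SDXL vae-fp16-fix checkpoint, which lacks a 'vae' subfolder.
    raise OSError('no vae subfolder')


vae = load_vae(_primary, lambda: 'vae-fp16-fix')
```

Unlike a bare `except:`, this won't swallow `KeyboardInterrupt` or genuine bugs in the loading code.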
```python
    attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv)
else:
    attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv)
log.info('Using %s with clip_val %.1f' % (attn_processor.__class__, clip_qkv))
```
Remove the '%' operator; the logger can do the string formatting itself if you pass the arguments directly.
Great PR, just some minor nits
```python
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
    assert channel
```
Print the value in the assert, so we know why it's `False` (`None`, `0`, etc.)
```diff
- assert channel
+ assert channel, f"{channel}"
```
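A quick illustration of what the suggested message buys us: on failure, the `AssertionError` now carries the offending value.

```python
channel = None  # illustrative failing value

try:
    assert channel, f"{channel}"
except AssertionError as err:
    # The error message tells us *why* the assert failed (None vs. 0, etc.).
    failure_reason = str(err)
```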
why we need to clip qkv
This PR contains the full implementation of Stable Diffusion XL (SDXL). SDXL uses two text encoders/tokenizers and also takes crop & size parameters from the dataloader as conditioning - a majority of the changes here are for supporting that.
A high-level description of the changes for each file:
`diffusion/datasets/image_caption.py`

- Added a `rand_crop` flag to choose between `LargestCenterSquare` & `RandomCropSquare` - previously only center cropping was supported. This is relevant to SD2 if one might want to train with random cropping, but doesn't apply to SDXL
- Based on `tokenizer_name_or_path`, use `RandomCropSquareReturnTransform` for SDXL, which returns the cropping parameters used as well as the original image size (for training SDXL with micro-conditioning), and return the micro-conditioning as part of the training batch (dropped via the `microcond_drop_prob` flag). This is not discussed in the SDXL paper but is reflected in Stability AI's implementation
- Use `SDXLTokenizer` for tokenization
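As a rough sketch of the crop micro-conditioning idea: the helper below is hypothetical and works on image sizes rather than PIL images (unlike the real `RandomCropSquareReturnTransform`), but it shows what gets returned alongside the crop itself - the crop's top-left corner and the original image size, which SDXL consumes as conditioning.

```python
import random

def random_crop_square_params(width, height, rng=random):
    """Hypothetical sketch: pick the largest square crop at a random position
    and return the parameters SDXL micro-conditioning needs."""
    side = min(width, height)
    left = rng.randint(0, width - side)   # random horizontal offset
    top = rng.randint(0, height - side)   # random vertical offset
    crop_params = (top, left)             # top-left corner of the square crop
    size_params = (width, height)         # original image size before cropping
    return side, crop_params, size_params

side, crop_params, size_params = random_crop_square_params(1024, 768)
```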
`diffusion/datasets/laion/transforms.py`

- Added `RandomCropSquare` (does random crop only) and `RandomCropSquareReturnTransform` (does random crop and returns crop params)

`diffusion/models/layers.py`

- Added the `zero_module` function used in SDXL init

`diffusion/models/models.py`

- Added QKV clipping (via the `clip_qkv` argument)
- Added `SDXLTokenizer` and `SDXLTextEncoder`, which contain the two tokenizers/text encoders but mostly can be used as if they are one tokenizer/text encoder

`diffusion/models/stable_diffusion.py`
- Added an `sdxl` flag to StableDiffusion to indicate if we are training an SDXL model
- Set the appropriate `latent_scale` for SD2 vs. SDXL
- Use `pooled_conditioning` from the SDXL text encoder, which is used in micro-conditioning
- Use `crop_params` and `size_params` for SDXL micro-conditioning, otherwise set them to reasonable default values
- Added a `zero_out_negative_prompt` flag that zeroes out the negative prompt if it is empty (rather than tokenizing and encoding the empty string). This was added to match the behavior of the diffusers StableDiffusionXLPipeline, and in general it just seems like a good thing to do. Note: I set the default value to `True`, so previously made generations (e.g. with SD2) will look different despite using the same prompt/seed. It can be set to `False` to match previous results.

`diffusion/callbacks/log_diffusion_images.py`
`setup.py`
There are a few remaining things I'd like to add, but this is already a big enough PR. I will add the following once this one is merged in:
- A `CenterCropSquareReturnTransform` transformation that can be used for COCO eval. Currently with SDXL training we do random crop for the eval dataset as well