Skip to content

Commit

Permalink
fix negative prompt classifier free guidance
Browse files Browse the repository at this point in the history
  • Loading branch information
jazcollins committed Oct 3, 2023
1 parent 3ae948e commit daf745c
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions diffusion/models/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def generate(
width: Optional[int] = None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 3.0,
num_images_per_prompt: Optional[int] = 1,
num_images_per_prompt: int = 1,
seed: Optional[int] = None,
progress_bar: Optional[bool] = True,
zero_out_negative_prompt: bool = True,
Expand Down Expand Up @@ -413,14 +413,13 @@ def generate(
# negative prompt is given in place of the unconditional input in classifier free guidance
pooled_embeddings = None
if do_classifier_free_guidance:
if negative_prompt_embeds is None and zero_out_negative_prompt:
if not negative_prompt and not tokenized_negative_prompts and zero_out_negative_prompt:
# Negative prompt is empty and we want to zero it out
unconditional_embeddings = torch.zeros_like(text_embeddings)
if pooled_text_embeddings is not None:
pooled_unconditional_embeddings = torch.zeros_like(pooled_text_embeddings)
else:
pooled_unconditional_embeddings = None
pooled_unconditional_embeddings = torch.zeros_like(pooled_text_embeddings) if self.sdxl else None
else:
negative_prompt = negative_prompt or ([''] * (batch_size // num_images_per_prompt)) # type: ignore
if not negative_prompt:
negative_prompt = [''] * (batch_size // num_images_per_prompt) # type: ignore
unconditional_embeddings, pooled_unconditional_embeddings = self._prepare_text_embeddings(
negative_prompt, tokenized_negative_prompts, negative_prompt_embeds, num_images_per_prompt)

Expand Down

0 comments on commit daf745c

Please sign in to comment.