Skip to content

Commit

Permalink
SDXL conditioning is now a flag
Browse files Browse the repository at this point in the history
  • Loading branch information
corystephenson-db committed Oct 22, 2024
1 parent 2b4af55 commit fc95831
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions diffusion/evaluation/clean_fid_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class CleanFIDEvaluator:
default_prompt (Optional[str]): An optional default prompt to add before each eval prompt. Default: ``None``.
default_negative_prompt (Optional[str]): An optional default negative prompt to add before each
negative prompt. Default: ``None``.
sdxl_conditioning (bool): Whether or not to include SDXL conditioning in the evaluation. Default: ``False``.
additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.
"""
Expand All @@ -74,6 +75,7 @@ def __init__(self,
prompts: Optional[List[str]] = None,
default_prompt: Optional[str] = None,
default_negative_prompt: Optional[str] = None,
sdxl_conditioning: bool = False,
additional_generate_kwargs: Optional[Dict] = None):
self.model = model
self.dataset = dataset
Expand All @@ -92,8 +94,8 @@ def __init__(self,
self.prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater']
self.default_prompt = default_prompt
self.default_negative_prompt = default_negative_prompt
self.sdxl_conditioning = sdxl_conditioning
self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
self.sdxl = model.sdxl

# Load the model
trainer = Trainer(model=self.model,
Expand Down Expand Up @@ -165,7 +167,7 @@ def _generate_images(self, guidance_scale: float):
if self.default_negative_prompt:
augmented_negative_prompt = [f'{self.default_negative_prompt}' for _ in text_captions]

if self.sdxl:
if self.sdxl_conditioning:
crop_params = torch.tensor([0, 0]).unsqueeze(0)
input_size_params = torch.tensor([self.size, self.size]).unsqueeze(0)
else:
Expand Down

0 comments on commit fc95831

Please sign in to comment.