Skip to content

Commit

Permalink
Fix issue with negative prompts and allow model_name argument (#63)
Browse files Browse the repository at this point in the history
* Add option to change model name and fix negative prompt bugs
  • Loading branch information
Landanjs authored Aug 30, 2023
1 parent 053d32a commit 7d28e21
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions diffusion/inference/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,14 @@ class StableDiffusionInference():
Default: ``None``.
"""

def __init__(self, pretrained: bool = False, prediction_type: str = 'epsilon'):
def __init__(self,
model_name: str = 'stabilityai/stable-diffusion-2-base',
pretrained: bool = False,
prediction_type: str = 'epsilon'):
self.device = torch.cuda.current_device()

model = stable_diffusion_2(
model_name=model_name,
pretrained=pretrained,
prediction_type=prediction_type,
encode_latents_in_fp16=True,
Expand Down Expand Up @@ -68,12 +72,14 @@ def predict(self, model_requests: List[Dict[str, Any]]):
# Prompts and negative prompts if available
if isinstance(inputs, str):
prompts.append(inputs)
elif isinstance(input, Dict):
if 'prompt' not in req:
elif isinstance(inputs, Dict):
if 'prompt' not in inputs:
raise RuntimeError('"prompt" must be provided to generate call if using a dict as input')
prompts.append(inputs['prompt'])
if 'negative_prompt' in req:
if 'negative_prompt' in inputs:
negative_prompts.append(inputs['negative_prompt'])
else:
raise RuntimeError(f'Input must be of type string or dict, but it is type: {type(inputs)}')

generate_kwargs = req['parameters']

Expand Down

0 comments on commit 7d28e21

Please sign in to comment.