Describe the bug
I followed the blog post on Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e. It worked beautifully until I tried to generate images at a different size. At 1024x1024 I see an inference latency of ~3s per image (compared to ~8s on an NVIDIA A10G), but when I change the resolution to 1280x960, the TPU shows next to no improvement over the GPU.
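For context, if I'm reading the architecture right, the workload should only grow modestly: SDXL's VAE downsamples by a factor of 8 per side, so a quick back-of-the-envelope check (plain Python, numbers only) suggests the latent grid grows by only ~17%:

# Rough workload comparison: SDXL latents are pixels / 8 per side.
for w, h in [(1024, 1024), (960, 1280)]:
    lat_w, lat_h = w // 8, h // 8
    print(f"{w}x{h}: latent grid {lat_w}x{lat_h} = {lat_w * lat_h} positions")
# 1024x1024 -> 128x128 = 16384 positions
# 960x1280  -> 120x160 = 19200 positions (~17% more)

So a near-total loss of the TPU speedup at the new resolution seems disproportionate to the extra compute.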
Reproduction
Use the same code as in the blog post: https://huggingface.co/blog/sdxl_jax
Changes:
def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
    width=1024,
    height=1024,
):
    # tokenize_prompt, replicate_all, pipeline, p_params and the default_*
    # values are unchanged from the blog post code.
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
        jit=True,
    ).images

    # Collapse the (devices, batch) leading dims and convert to PIL.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))
# First call at the new resolution triggers XLA compilation.
start = time.time()
print("Compiling ...")
generate(default_prompt, default_neg_prompt, width=960, height=1280)
print(f"Compiled in {time.time() - start:.1f}s")

# Subsequent calls reuse the compiled program.
start = time.time()
print("starting")
prompt = "llama in ancient Greece, oil on canvas"
neg_prompt = "cartoon, illustration, animation"
images = generate(prompt, neg_prompt, width=960, height=1280)
print(f"Inference in {time.time() - start:.1f}s")
Logs
No response
System Info
Python: 3.10.6
Diffusers: 0.26.2
Torch: 2.2.0+cu121
Jax: 0.4.23
Flax: 0.8.0