
Slow SDXL inference with JAX on Cloud TPU v5e for sizes other than 1024x1024 #6882

Open
@zstiggz


Describe the bug

I followed the blog post Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e. It worked magically until I tried to generate an image at a different size. At 1024x1024 we get an inference latency of ~3s per image (compared to ~8s on the NVIDIA A10G). But after changing the resolution to 960x1280 (width=960, height=1280), we see next to no improvement.
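For scale, here is a rough comparison of the two resolutions. It is only a sketch: it assumes the usual SDXL VAE downsampling factor of 8, and `latent_shape` and `pixel_ratio` are hypothetical helpers written for this issue, not part of diffusers.

```python
# Hypothetical helpers (not part of diffusers) to compare two output sizes.
VAE_SCALE = 8  # assumption: the SDXL VAE downsamples each spatial dimension by 8


def latent_shape(width, height, scale=VAE_SCALE):
    """Return the (height, width) of the latent grid for a given image size."""
    if width % scale or height % scale:
        raise ValueError("width and height must be multiples of the VAE scale")
    return height // scale, width // scale


def pixel_ratio(size_a, size_b):
    """Return how many times more pixels size_a has than size_b (both as (width, height))."""
    (wa, ha), (wb, hb) = size_a, size_b
    return (wa * ha) / (wb * hb)


base = (1024, 1024)   # size used in the blog post
other = (960, 1280)   # size used in this issue

print(latent_shape(*base))        # (128, 128)
print(latent_shape(*other))       # (160, 120)
print(pixel_ratio(other, base))   # 1.171875
```

The non-square size has only about 17% more pixels. Attention cost grows faster than linearly with the latent size, so this is a rough lower bound on the extra work, but pixel count alone seems unlikely to account for losing most of the speedup.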

Reproduction

Use the same code as in the blog post: https://huggingface.co/blog/sdxl_jax

Changes:

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
    width=1024,
    height=1024,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
        jit=True,
    ).images

    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))

# First call: includes compilation for the new resolution
start = time.time()
print("Compiling ...")
generate(default_prompt, default_neg_prompt, width=960, height=1280)
print(f"Compiled in {time.time() - start}")

# Second call: should reuse the compiled function
start = time.time()
print("starting")
prompt = "llama in ancient Greece, oil on canvas"
neg_prompt = "cartoon, illustration, animation"
images = generate(prompt, neg_prompt, width=960, height=1280)
print(f"Inference in {time.time() - start}")
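To rule out a one-off measurement, the timing above can be repeated a few times so the first (compiling) call is reported separately from later calls. This is a minimal sketch: `time_runs` and `fake_generate` are hypothetical names, and the stand-in workload is there only so the snippet runs without a TPU.

```python
import time


def time_runs(fn, repeats=3):
    """Call fn() repeats + 1 times; return (first_call_seconds, list_of_later_call_seconds)."""
    durations = []
    for _ in range(repeats + 1):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations[0], durations[1:]


# Stand-in workload (hypothetical); replace with the real call when running on the TPU.
def fake_generate():
    time.sleep(0.01)


first, rest = time_runs(fake_generate, repeats=3)
print(f"first call: {first:.3f}s")
print("later calls: " + ", ".join(f"{d:.3f}s" for d in rest))
```

With the code from this issue, `fake_generate` would be replaced by something like `lambda: generate(prompt, neg_prompt, width=960, height=1280)`. Because `generate` converts the result with `np.array(...)` before returning, each measured call should include the device work rather than just the dispatch.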

Logs

No response

System Info

Python: 3.10.6
Diffusers: 0.26.2
Torch: 2.2.0+cu121
Jax: 0.4.23
Flax: 0.8.0

Who can help?

@patrickvonplaten @yiyixuxu @DN6

Labels: bug, jax/flax, stale