huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Slow SDXL inference with JAX on Cloud TPU v5e for sizes other than 1024x1024 #6882

Open CaptainStiggz opened 8 months ago

CaptainStiggz commented 8 months ago

Describe the bug

I'm following the blog post "Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e". It worked very well until I tried to generate an image at a different size. At 1024x1024, inference latency is ~3s per image (compared to ~8s on the NVIDIA A10G). But when I change the resolution to 1280x960, there is next to no improvement over the GPU.
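As a rough sanity check on how much extra work the larger resolution implies, the sketch below compares pixel counts and latent grid sizes for the two resolutions. It assumes the SDXL VAE downsamples each side by a factor of 8; `latent_hw` is a hypothetical helper for illustration, not part of diffusers, and the calculation says nothing about the actual cause of the slowdown.

```python
def latent_hw(width, height, vae_scale_factor=8):
    """Return (height, width) of the latent grid for an image size.

    Assumption: the VAE downsamples each side by `vae_scale_factor` (8).
    """
    return height // vae_scale_factor, width // vae_scale_factor


base = (1024, 1024)   # (width, height) used in the blog post
other = (960, 1280)   # (width, height) used in the reproduction code

# How many more pixels the non-square resolution has to produce.
pixel_ratio = (other[0] * other[1]) / (base[0] * base[1])

print("latents at 1024x1024:", latent_hw(*base))
print("latents at 960x1280: ", latent_hw(*other))
print(f"pixel ratio: {pixel_ratio:.3f}")
```

If cost scaled linearly with pixel count, the ~3s figure would grow to roughly 3.5s. Self-attention scales faster than linearly with the number of latent positions, so somewhat more would not be surprising, but that alone would not obviously account for losing most of the speedup.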

Reproduction

Use the same code as in the blog post: https://huggingface.co/blog/sdxl_jax

Changes:

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
    width=1024,
    height=1024,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
        jit=True,
    ).images

    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))

# Warm-up call: triggers compilation for this resolution.
start = time.time()
print("Compiling ...")
generate(default_prompt, default_neg_prompt, width=960, height=1280)
print(f"Compiled in {time.time() - start}")

# Timed call at the same resolution as the warm-up call.
start = time.time()
print("starting")
prompt = "llama in ancient Greece, oil on canvas"
neg_prompt = "cartoon, illustration, animation"
images = generate(prompt, neg_prompt, width=960, height=1280)
print(f"Inference in {time.time() - start}")
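A single timed call can be noisy, so it may help to time the first (compiling) call separately from several later calls. The helper below is a minimal sketch: `benchmark` is a hypothetical name, and it assumes the timed function returns only once its result is on the host, which `generate` above does by converting the images with `np.array(...)`. A trivial stand-in workload is used so the snippet runs on its own.

```python
import time


def benchmark(fn, runs=3):
    """Time the first call separately from `runs` later calls.

    Assumption: `fn` returns only after its result is materialized,
    so wall-clock time covers the full computation.
    """
    start = time.perf_counter()
    fn()
    first = time.perf_counter() - start  # includes any compilation

    later = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        later.append(time.perf_counter() - start)

    return first, min(later), sum(later) / len(later)


# Stand-in workload; with the pipeline this would be something like
# lambda: generate(prompt, neg_prompt, width=960, height=1280)
first, best, mean = benchmark(lambda: sum(range(100_000)))
print(f"first call: {first:.4f}s, best: {best:.4f}s, mean: {mean:.4f}s")
```

Reporting the best and mean of the later calls for both 1024x1024 and 960x1280 would make the comparison between resolutions easier to reproduce.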

Logs

No response

System Info

Python: 3.10.6
Diffusers: 0.26.2
Torch: 2.2.0+cu121
Jax: 0.4.23
Flax: 0.8.0

Who can help?

@patrickvonplaten @yiyixuxu @DN6

DN6 commented 8 months ago

cc: @pcuenca for visibility here

github-actions[bot] commented 6 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

pcuenca commented 6 months ago

This could be for a number of reasons, but unfortunately I don't currently have access to TPU v5e instances to test. I'll see if we can get one to verify.

huseyintemiz commented 2 months ago

I have a similar issue with SD 1.5: inference at custom resolutions takes disproportionately long. Timings for 50 steps:

512x512: 1.06 sec
512x640 (width x height): 3.36 sec
640x640: 5.02 sec
768x768: 4.34 sec

Hardware: TPU v5e, 1 chip
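To make the mismatch concrete, the sketch below normalizes the timings reported in this comment against the 512x512 run. The numbers are copied from the comment; the comparison against pixel count is only a rough yardstick, not a model of how the pipeline's cost should scale.

```python
# (width, height, seconds) for 50 steps, copied from the comment above.
measurements = [
    (512, 512, 1.06),
    (512, 640, 3.36),
    (640, 640, 5.02),
    (768, 768, 4.34),
]

base_w, base_h, base_t = measurements[0]

rows = []
for w, h, t in measurements:
    pixel_ratio = (w * h) / (base_w * base_h)  # work proxy: pixel count
    time_ratio = t / base_t                    # measured slowdown
    rows.append((w, h, pixel_ratio, time_ratio))
    print(f"{w}x{h}: {pixel_ratio:.2f}x pixels, {time_ratio:.2f}x time")
```

The 512x640 run has 1.25x the pixels of 512x512 but takes roughly 3.2x as long, and 768x768 is faster than 640x640 despite having more pixels, so the slowdown does not track pixel count alone.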