[Open] CaptainStiggz opened this issue 8 months ago
cc: @pcuenca for visibility here
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
This could be for a number of reasons, but unfortunately I don't currently have access to TPU v5e instances to test. I'll see if we can get one to verify.
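One thing worth ruling out in the meantime: in JAX, every new (height, width) pair triggers a fresh trace and XLA compilation, so a single timed call at a new resolution mixes compile time with execution time. Here is a quick sketch to separate the two (the `generate` wrapper is hypothetical; substitute the actual pipeline call from the blog post):

```python
import time
import jax

def benchmark(generate, width, height, runs=3):
    # `generate` is a hypothetical wrapper around the blog post's pipeline
    # call; it should accept `width` and `height` and return device arrays.
    start = time.perf_counter()
    jax.block_until_ready(generate(width=width, height=height))
    first = time.perf_counter() - start  # includes per-shape XLA compilation

    steady = []
    for _ in range(runs):
        start = time.perf_counter()
        jax.block_until_ready(generate(width=width, height=height))
        steady.append(time.perf_counter() - start)

    print(f"{width}x{height}: first call {first:.2f}s, "
          f"steady state {min(steady):.2f}s")
```

If the steady-state numbers are still slow at non-square resolutions, recompilation isn't the explanation.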
I have a similar issue with SD 1.5: inference at custom resolutions takes disproportionately long. All runs use 50 steps:

- 512x512 ⇒ 1.06 s
- 512w x 640h ⇒ 3.36 s
- 640x640 ⇒ 5.02 s
- 768x768 ⇒ 4.34 s

Hardware: TPU v5e, 1 chip.
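For reference, a minimal sketch of how such a sweep can be reproduced with the standard diffusers Flax setup (the model revision, dtype, and prompt below are assumptions, not the exact code I ran); each shape is run once as a warm-up so per-shape compilation is excluded from the timing:

```python
import time

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

# Setup follows the diffusers Flax docs; the bf16 revision is an assumption.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16
)
p_params = replicate(params)

prompts = jax.device_count() * ["a photo of an astronaut riding a horse"]
prompt_ids = shard(pipeline.prepare_inputs(prompts))
rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

def run(width, height):
    out = pipeline(
        prompt_ids, p_params, rng,
        num_inference_steps=50, height=height, width=width, jit=True,
    )
    jax.block_until_ready(out.images)

for width, height in [(512, 512), (512, 640), (640, 640), (768, 768)]:
    run(width, height)  # warm-up: triggers per-shape XLA compilation
    start = time.perf_counter()
    run(width, height)
    print(f"{width}x{height}: {time.perf_counter() - start:.2f} s")
```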
Describe the bug
I followed the blog post on Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e. It worked like magic until I tried to generate an image at a different size. At 1024x1024 we get an inference latency of ~3 s per image (compared to ~8 s on an NVIDIA A10G), but change the resolution to 1280x960 and we see next to no improvement over the A10G.
Reproduction
Use the same code as in the blog post: https://huggingface.co/blog/sdxl_jax
Changes: set the output resolution to 1280x960 instead of the default 1024x1024, as sketched below.
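A sketch of the reproduction (setup abbreviated from the blog post; the model id and arguments follow it, but the exact Flax SDXL `__call__` signature may vary across diffusers versions):

```python
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionXLPipeline

# Weight loading kept minimal; the blog post additionally casts the
# params to bfloat16, which is omitted here.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0"
)
p_params = replicate(params)

prompts = jax.device_count() * ["a photograph of a castle at sunset"]
prompt_ids = shard(pipeline.prepare_inputs(prompts))
rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

# The only change from the blog post: 1280x960 instead of the default 1024x1024.
images = pipeline(
    prompt_ids, p_params, rng,
    num_inference_steps=50,
    width=1280, height=960,
    jit=True,
).images
```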
Logs
No response
System Info
- Python: 3.10.6
- Diffusers: 0.26.2
- Torch: 2.2.0+cu121
- JAX: 0.4.23
- Flax: 0.8.0
Who can help?
@patrickvonplaten @yiyixuxu @DN6