Ah, this isn't necessarily TPU support; it's an entirely different pipeline written in JAX as opposed to PyTorch. I would LOVE to add this, since it would significantly speed things up for Colab-only users. However, I have spent maybe 2 hours max playing with JAX in my life, haha.
Is JAX something you're comfortable with? If not, perhaps we can find someone who is and who can help add a new pipeline.
I am not a JAX expert, but I've been playing with it for some time now, and I would like to help with this.
Is there anything I can do to support you if you try it?
What I'm thinking:
`make_video_pyav`, etc.)
Feel free to limit the number of features at first so we can get a minimal V0 working. We can iterate from there.
Stable Diffusion is now supported in Flax (via JAX) and can run on TPUs. With it, a TPU can generate 8 images in just 5 seconds. This seems very useful for this project, where you need to generate a lot of images.
Guide: https://huggingface.co/blog/stable_diffusion_jax
Colab: https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fast_jax.ipynb#scrollTo=5Dz5aeq_yUNq
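Much of that throughput comes from data parallelism: a TPU host typically exposes 8 cores, and the batch is split so each core generates its own image at the same time. The guide does this with `FlaxStableDiffusionPipeline` from diffusers plus Flax helpers such as `shard` and `replicate`. Below is a minimal, self-contained sketch of just the parallel pattern using `jax.pmap`, assuming a toy setup: `denoise_step` and `generate` are hypothetical placeholders, not the real UNet/scheduler or the diffusers API, and the latent shape is an illustrative assumption.

```python
import jax
import jax.numpy as jnp


def denoise_step(latents, key):
    # Hypothetical stand-in for one UNet + scheduler step (NOT the real model):
    # shrink the latents slightly and add a little noise.
    noise = jax.random.normal(key, latents.shape)
    return 0.9 * latents + 0.01 * noise


def generate(latents, key, num_steps=10):
    # Toy denoising loop over one device's slice of the batch.
    def body(i, state):
        lat, k = state
        k, sub = jax.random.split(k)
        return denoise_step(lat, sub), k

    latents, _ = jax.lax.fori_loop(0, num_steps, body, (latents, key))
    return latents


# pmap compiles `generate` once and runs one copy per device (e.g. 8 TPU cores).
p_generate = jax.pmap(generate)

num_devices = jax.device_count()
images_per_device = 1
latent_shape = (4, 64, 64)  # assumed (channels, height, width) for illustration

init_key, loop_key = jax.random.split(jax.random.PRNGKey(0))
# The leading axis is the device axis, as pmap requires.
latents = jax.random.normal(
    init_key, (num_devices, images_per_device) + latent_shape
)
# One independent RNG key per device so each device samples different noise.
keys = jax.random.split(loop_key, num_devices)

out = p_generate(latents, keys)
# Merge the (device, per-device batch) axes back into a single batch.
out = out.reshape((num_devices * images_per_device,) + latent_shape)
print(out.shape)  # (num_devices * images_per_device, 4, 64, 64)
```

On a CPU-only machine this still runs with a single device; on a Colab TPU the same code would spread the batch across all visible cores without changes.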