nateraw / stable-diffusion-videos

Create 🔥 videos with Stable Diffusion by exploring the latent space and morphing between text prompts
Apache License 2.0
4.42k stars 421 forks

Flax Implementation for TPU support #140

Closed charlielito closed 1 year ago

charlielito commented 1 year ago

Resolves #139

You can test the code here: https://colab.research.google.com/drive/11EoVXsfHFZKCKBFzsih3SdodZ8xFsaSl#scrollTo=YJ53vCC593-A&uniqifier=1

The first run is slower because the code has to be compiled for the TPU. After that it generates an image roughly every 1.3 seconds, almost 7 times faster than the GPU version in Colab (correct me if I am wrong).

In the sample Colab, it generates a video with 30 interpolation steps in just 50 seconds.
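As a rough sanity check on the numbers above, here is a small sketch of the arithmetic. The helper names are made up for illustration, and the split between steady-state generation and everything else (compilation, video encoding, and so on) is only an assumption, not something measured in the notebook.

```python
# Hypothetical helpers for illustration; not part of stable_diffusion_videos.

def seconds_per_frame(total_seconds: float, num_frames: int) -> float:
    """Average wall-clock time per frame for a whole run."""
    return total_seconds / num_frames


def steady_state_total(seconds_per_image: float, num_frames: int) -> float:
    """Time the frames alone would take at the steady-state rate."""
    return seconds_per_image * num_frames


# Figures quoted in the comment above.
NUM_FRAMES = 30
STEADY_RATE = 1.3      # seconds per image after compilation
REPORTED_TOTAL = 50.0  # seconds for the whole sample video

generation_only = steady_state_total(STEADY_RATE, NUM_FRAMES)  # about 39 s
overhead = REPORTED_TOTAL - generation_only                    # about 11 s (assumed overhead)
average = seconds_per_frame(REPORTED_TOTAL, NUM_FRAMES)        # about 1.67 s per frame

# "Almost 7 times faster" implies roughly this much per image on the GPU.
implied_gpu_rate = STEADY_RATE * 7                             # about 9.1 s per image
```

So the 50 second total is consistent with the 1.3 second rate if roughly 11 seconds go to work other than steady-state generation.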

nateraw commented 1 year ago

Wow! That was fast. Will take a look.

We can add a notebook for this pipe specifically I think. stable_diffusion_videos_flax.ipynb or something (in the root of the repo).

nateraw commented 1 year ago

Need access to the notebook. Perhaps just go ahead and commit the above. Or, if it's too messy, just save it as a gist and send that on over, please 😄

charlielito commented 1 year ago

Just updated the permissions, you should be able to see it now

cgarciae commented 1 year ago

Just tested @charlielito's notebook on a TPU v3-8 and it takes 20-25 seconds for 30 frames 🎉

[Image: Screenshot from 2023-01-15 07-23-15]

nateraw commented 1 year ago

Yes, I tested it last night too!! Very cool. I noticed that a batch size param > 1 fails, though. Is this expected? I know nothing about JAX.

I'm fine with merging this, I think. Let me see tomorrow if there's anything to clean up documentation-wise, but otherwise LGTM.

charlielito commented 1 year ago

It shouldn't crash with batch_size>1. Let me check

charlielito commented 1 year ago

@nateraw I just fixed it, and it now works with batch size > 1. The only pending item is support for negative_prompt. I also have some questions; see below.

nateraw commented 1 year ago

@charlielito feel free to leave out negative prompt if it's a burden. Feel free to iterate on this as you have time. It doesn't have to be perfect - this whole project is a hacky fun thing haha.

charlielito commented 1 year ago

@nateraw I'll leave the missing features for a future PR. I also replicated your notebook using the Flax API so that it runs on the TPU. For the Gradio interface, I slightly modified app.py to accept a FlaxPipeline. I also linked that notebook (to be merged) in the README.

nateraw commented 1 year ago

SG! Glanced at it and I think it's good, but I want to run it... I'm traveling internationally today, so I'm not 100% sure I'll be able to get to it until tomorrow (depends on how much time I have on airport wifi).

Thanks again, very hype about this one.

One note, which we'll investigate further later: when batch size is >1 I think I'm seeing out-of-order frames in the output video. nbd, I'm sure it's being ordered by TPU core or something, and we have to gather differently.

charlielito commented 1 year ago

That was actually a concern I had while implementing it, and I thought it was not happening. But now that you mention it, I'll take a look, as it might be the case.
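To make the suspected failure mode concrete, here is a minimal pure-Python sketch of how frames can end up out of order when a batch is split across devices and then gathered the wrong way. It does not use the actual pipeline code; `shard`, `gather_in_order`, and `gather_interleaved` are hypothetical helpers, and whether this is the real cause of the jitter is only an assumption.

```python
# Hypothetical illustration; integers stand in for frames, lists stand in for devices.

def shard(frames, num_devices):
    """Split a flat list of frames into num_devices contiguous chunks."""
    if len(frames) % num_devices != 0:
        raise ValueError("number of frames must be divisible by num_devices")
    per_device = len(frames) // num_devices
    return [frames[d * per_device:(d + 1) * per_device] for d in range(num_devices)]


def gather_in_order(sharded):
    """Concatenate device outputs device by device, which restores the original order."""
    return [frame for device_frames in sharded for frame in device_frames]


def gather_interleaved(sharded):
    """A plausible mistake: take slot 0 from every device, then slot 1, and so on."""
    return [frame for slot in zip(*sharded) for frame in slot]


frames = list(range(8))           # frame indices 0..7
sharded = shard(frames, 4)        # [[0, 1], [2, 3], [4, 5], [6, 7]]

ordered = gather_in_order(sharded)        # [0, 1, 2, 3, 4, 5, 6, 7]
shuffled = gather_interleaved(sharded)    # [0, 2, 4, 6, 1, 3, 5, 7]
```

With one frame per device the two gathers give the same result, which would match the observation that the problem only shows up when batch size is > 1.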

nateraw commented 1 year ago

Yea, I really jacked up the num_interpolation_steps param and it made it more obvious.

I am on Colab Pro+... used TPU w/ High RAM. Batch size of 4 and num_interpolation_steps ~240-320, I think. It was clearly jittering with those settings. Looks like I won't have enough time here at the airport, so I'll probably get to this either tomorrow or Friday.

nateraw commented 1 year ago

Will make a release soon, after we make sure everything's all good on main :) perhaps later tonight if I get around to it, worst case early next week.