charlielito closed this 1 year ago
Wow! That was fast. Will take a look.
I think we can add a notebook specifically for this pipeline, `stable_diffusion_videos_flax.ipynb` or something (in the root of the repo).
Need access to the notebook. Perhaps just go ahead and commit the above. Or, if it's too messy, just save it as a gist and send that on over, please 😄
Just updated the permissions; you should be able to see it now.
Just tested @charlielito's notebook on a TPU v3-8, and it takes 20-25 seconds for 30 frames 🎉
Yes, I tested it last night too!! Very cool. I noticed that a batch size > 1 fails, though. Is this expected? I know nothing about JAX.
I think I'm fine with merging this. Let me see tomorrow if there's anything to clean up documentation-wise, but otherwise LGTM.
It shouldn't crash with `batch_size` > 1. Let me check.
@nateraw I just fixed it, and it now works with batch size > 1.
The only thing still pending is support for `negative_prompt`. I also have some questions; see below.
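Since `batch_size` > 1 came up: assuming each `jax.pmap` call takes `batch_size` frames per device, one call handles `n_devices * batch_size` frames, so the list of interpolation frames has to be cut into groups of that size. A minimal plain-Python sketch of that bookkeeping follows; `group_frames` and the choice to pad the last group by repeating its final frame are hypothetical illustrations, not the code in this PR. Padding keeps every call the same shape, which matters because a new input shape triggers a fresh compile.

```python
from typing import List, Tuple


def group_frames(
    frame_ids: List[int], n_devices: int, batch_size: int
) -> Tuple[List[List[int]], int]:
    """Split frame indices into pmap-sized groups (hypothetical helper).

    Each group holds n_devices * batch_size frames. If the last group is short,
    it is padded by repeating its final frame so every call sees the same shape.
    Returns the groups and how many padded entries must be dropped afterwards.
    """
    group_size = n_devices * batch_size
    groups = [
        list(frame_ids[start:start + group_size])
        for start in range(0, len(frame_ids), group_size)
    ]
    n_pad = 0
    if groups and len(groups[-1]) < group_size:
        n_pad = group_size - len(groups[-1])
        groups[-1].extend([groups[-1][-1]] * n_pad)
    return groups, n_pad


# 30 frames on 8 devices with batch_size=2: groups of 16, the second one padded.
groups, n_pad = group_frames(list(range(30)), n_devices=8, batch_size=2)
# len(groups) == 2 and n_pad == 2; drop the last n_pad outputs after generation.
```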
@charlielito feel free to leave out negative prompt if it's a burden. You can iterate on this as you have time. It doesn't have to be perfect; this whole project is a hacky fun thing haha.
@nateraw I'll leave the missing features for a future PR.
I also replicated your notebook using the Flax API so that it runs on the TPU. For the Gradio interface, I slightly modified `app.py` to accept a `FlaxPipeline`.
I also linked that notebook (to be merged) in the README.
SG! I glanced at it and I think it's good, but I want to run it... I'm traveling internationally today, so I'm not 100% sure I'll be able to get to it until tomorrow (depends on how much time I have on airport Wi-Fi).
Thanks again, very hype about this one.
One note, which we'll investigate further later: when batch size is > 1, I think I'm seeing out-of-order frames in the output video. NBD, I'm sure it's being ordered by TPU core or something, and we have to gather differently.
That was actually a concern I had while implementing it, and I thought it wasn't happening. But now that you mention it, I'll take a look, as it might be the case.
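One plausible mechanism, not confirmed in this thread: `pmap` returns outputs with a leading device axis, shape `(n_devices, batch_size, ...)`. If frames are handed to devices in one order (for example round-robin, so device `d` gets frames `d, d + n_devices, ...`) but the outputs are flattened device-major, the frames come back interleaved. The sketch below models only the index bookkeeping in plain Python; `shard_contiguous`, `shard_round_robin`, `flatten_device_major`, and `unshard_round_robin` are hypothetical stand-ins for the real sharding and reshape calls.

```python
from typing import List


def shard_contiguous(frames: List[int], n_devices: int) -> List[List[int]]:
    """Device d gets a contiguous slice, like reshaping (n*k, ...) -> (n, k, ...)."""
    if len(frames) % n_devices:
        raise ValueError("frame count must be divisible by n_devices")
    per_device = len(frames) // n_devices
    return [frames[d * per_device:(d + 1) * per_device] for d in range(n_devices)]


def shard_round_robin(frames: List[int], n_devices: int) -> List[List[int]]:
    """Device d gets frames d, d + n_devices, d + 2 * n_devices, ..."""
    if len(frames) % n_devices:
        raise ValueError("frame count must be divisible by n_devices")
    return [frames[d::n_devices] for d in range(n_devices)]


def flatten_device_major(per_device: List[List[int]]) -> List[int]:
    """Gather like reshaping (n, k, ...) -> (n*k, ...): all of device 0 first."""
    return [f for device_out in per_device for f in device_out]


def unshard_round_robin(per_device: List[List[int]]) -> List[int]:
    """Inverse of shard_round_robin: slot 0 from every device, then slot 1, ..."""
    per_slot = len(per_device[0])
    return [per_device[d][s] for s in range(per_slot) for d in range(len(per_device))]


frames = list(range(8))
print(flatten_device_major(shard_round_robin(frames, 4)))  # [0, 4, 1, 5, 2, 6, 3, 7]
```

The general rule the sketch illustrates: the gather must be the exact inverse of the split. With contiguous sharding, a plain device-major flatten (a reshape) already is that inverse.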
Yeah, I really jacked up the `num_interpolation_steps` param, and it made it more obvious.
I'm on Colab Pro+... used a TPU with High RAM. Batch size of 4 and `num_interpolation_steps` ~240-320, I think. The output was clearly jittering with those settings. Looks like I won't have enough time here at the airport, so I'll probably get to this either tomorrow or Friday.
Will make a release soon, after we make sure everything's all good on main :) Perhaps later tonight if I get around to it, worst case early next week.
Resolves #139
You can test the code here: https://colab.research.google.com/drive/11EoVXsfHFZKCKBFzsih3SdodZ8xFsaSl#scrollTo=YJ53vCC593-A&uniqifier=1
The first run is slower because the code needs to be compiled for the TPU. After that, it generates an image roughly every 1.3 seconds, almost 7 times faster than the GPU version in Colab (correct me if I'm wrong).
In the sample Colab, it generates a video with 30 interpolation steps in just 50 seconds.
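Because compilation happens once per input shape and later calls with the same shapes reuse the compiled program, its cost is spread across every frame of a video. A back-of-the-envelope helper for that amortization follows; only the ~1.3 s/image figure comes from this thread, while `avg_seconds_per_frame` and the 30 s compile time in the example are hypothetical placeholders.

```python
def avg_seconds_per_frame(n_frames: int, per_frame_s: float, compile_s: float) -> float:
    """Average wall-clock seconds per frame when a one-off compile precedes n_frames.

    Hypothetical helper: assumes a constant per-frame cost after compilation.
    """
    if n_frames <= 0:
        raise ValueError("n_frames must be positive")
    return (compile_s + n_frames * per_frame_s) / n_frames


# 30 frames at ~1.3 s each plus an assumed 30 s compile: about 2.3 s per frame.
# Longer videos push the average toward the steady-state 1.3 s.
print(avg_seconds_per_frame(30, 1.3, 30.0))
```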