jy0205 / Pyramid-Flow

Code of Pyramidal Flow Matching for Efficient Video Generative Modeling
https://pyramid-flow.github.io/
MIT License

strange video output for t2v and i2v #101

YAY-3M-TA3 opened 4 days ago

YAY-3M-TA3 commented 4 days ago

I am getting a strange output for t2v and i2v.

I am on a Mac with an M2 and 24GB of RAM, Python 3.10.13, torch 2.6.0 (nightly), torchvision 0.20.0 (nightly).
I modified the code to support MPS by (a minimal sketch follows the list):

  1. switching the device from "cuda" to "mps",
  2. making sure that tensors moved to the GPU are bfloat16,
  3. in scheduling_flow_matching.py, line 195, adding a cast to bfloat16: `self.timesteps = torch.from_numpy(timesteps).to(torch.bfloat16).to(device=device)`,
  4. in the same file, line 204, adding a cast to bfloat16: `sigmas = torch.from_numpy(ratios).to(torch.bfloat16).to(device=device)`,
  5. in modeling_pyramid_mmdit.py, line 30, changing the dtype to torch.bfloat16: `scale = torch.arange(0, dim, 2, dtype=torch.bfloat16, device=pos.device) / dim`.
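A minimal sketch of the device/dtype change, assuming an Apple Silicon machine (the helper name is mine, not the repo's):

```python
import torch

def pick_device() -> torch.device:
    # Prefer MPS on Apple Silicon, fall back to CUDA, then CPU.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = pick_device()
# Per modifications 2-5, tensors moved to the GPU are kept in bfloat16:
x = torch.randn(4, 4).to(device=device, dtype=torch.bfloat16)
```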

t2v test conditions:

  model: diffusion_transformer_384p
  prompt: A campfire burning with flames and embers, gradually increasing in size and intensity before dying down towards the end, hyper quality, Ultra HD, 8K
  width: 640, height: 384
  first_frame_steps: 20, 20, 20
  video steps: 10, 10, 10
  temp: 16
  guidance scale: 7.00
  video guidance scale: 5.00
  seed: 44664248661394
  VAE tile sample min size: 256, window size: 1

Here is the t2v output: https://github.com/user-attachments/assets/fb3371e3-146c-44b4-bba3-80a4ee96992e

i2v test conditions: same set-up, but with a different prompt and an added image:

  prompt: FPV flying over the Great Wall, hyper quality, Ultra HD, 8K
  image: the_great_wall.jpg
  seed: 44664248661398

Here is that output: https://github.com/user-attachments/assets/80c3bd2d-c1b1-4ab4-bc8d-af792165ab39

What could be causing this? Is it because of the bfloat16 precision modifications I made?

jy0205 commented 4 days ago

Hi, our model does not support dynamic resolution for now! The 384p version requires width=640, height=384; the 768p version requires width=1280, height=768. Could you change the height setting and try again?
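A minimal sketch of this constraint (the helper is illustrative, not the repo's actual code; the 768p variant name is assumed by analogy with the 384p one):

```python
# Maps checkpoint variant to the (width, height) it requires, per the
# comment above.
SUPPORTED = {
    "diffusion_transformer_384p": (640, 384),
    "diffusion_transformer_768p": (1280, 768),
}

def check_resolution(variant: str, width: int, height: int) -> None:
    expected = SUPPORTED[variant]
    if (width, height) != expected:
        raise ValueError(
            f"{variant} requires width={expected[0]}, height={expected[1]}; "
            f"got width={width}, height={height}"
        )
```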

YAY-3M-TA3 commented 4 days ago

> Hi, our model does not support dynamic resolution for now! The 384p version requires width=640, height=384; the 768p version requires width=1280, height=768. Could you change the height setting and try again?

Sorry, I wrote that wrong... I am using 384 as the height.

feifeiobama commented 4 days ago

Thanks for trying on Apple Silicon, we were not expecting this!

We suspect there are several possible issues: the Python version, the PyTorch version, and MPS itself. Although we don't have time to fix them in the short term, it would be very interesting to come back and look at solutions for Apple users after the model is upgraded.

YAY-3M-TA3 commented 4 days ago

Ok! Mild success now. I was able to get the t2v model to output a correct video. I needed to add an autocast to the text encoder: `with torch.autocast("mps", dtype=torch.bfloat16):`

I also modified the autocasts in app.py to use the same for MPS.
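Roughly, the change looks like this (a minimal sketch with a stand-in encoder, not the repo's actual text encoder; it requires a machine with MPS available):

```python
import torch
import torch.nn as nn

# Stand-in for the pipeline's text encoder; the real one is loaded by the repo.
text_encoder = nn.Linear(16, 32).to("mps", dtype=torch.bfloat16)
tokens = torch.randn(1, 16, device="mps", dtype=torch.bfloat16)

# The fix: run the encoder under an MPS autocast so intermediate ops
# stay in bfloat16 instead of silently falling back to another dtype.
with torch.autocast("mps", dtype=torch.bfloat16):
    prompt_embeds = text_encoder(tokens)
```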

Here is a result. (It's only one frame for now; I'm running out of memory for more than 1 frame on a Mac M2 with 24GB: MPS backend out of memory (MPS allocated: 2.12 GB, other allocations: 25.08 GB, max allowed: 27.20 GB). Tried to allocate 256 bytes on shared pool.)

_Note: the blocks in the image are because I set `tile_sample_min_size = 64` in an effort to get this to process in under 24GB._

https://github.com/user-attachments/assets/a338589e-0c04-49ca-896e-7f06c17eac40

So, the next problem is trying to get this to process in under 24GB. (Currently this needs 27GB.)
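A quick way to watch memory pressure while hunting for a sub-24GB configuration (a hedged sketch; these torch.mps calls exist, but the numbers are only meaningful on an Apple Silicon machine):

```python
import torch

# Report how much memory the MPS backend currently holds; compare
# against your machine's unified-memory budget.
if torch.backends.mps.is_available():
    allocated = torch.mps.current_allocated_memory() / 1e9
    driver = torch.mps.driver_allocated_memory() / 1e9
    print(f"MPS allocated: {allocated:.2f} GB, driver: {driver:.2f} GB")
```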

YAY-3M-TA3 commented 3 days ago

Ok, there is still the weird video output problem.

While a 1-frame video can be made... https://github.com/user-attachments/assets/d2c9d276-f4b0-45e0-8523-cf4a0feb64f7

...a 16-frame video will still have the problem:

https://github.com/user-attachments/assets/bc4ce18c-86b1-4617-a1f1-b6966f52b6a4

feifeiobama commented 3 days ago

Could this be due to our causal VAE, which is similar to MAGVIT-v2? It produces two different types of latents for single-frame and 16-frame generation; they are treated separately and could be in different value ranges.

YAY-3M-TA3 commented 3 days ago

This sounds interesting. I had thought that maybe the 1-frame video was only grabbing the initial SD3-generated image, and that subsequent frames were being generated by the pyramid model. I'm unfamiliar with MAGVIT-v2, although it sounds like part of the "secret sauce" for the latest AI video generation.

I do see where the logic checks for idx == 0 in chunk_decode, then follows the other branch for frames beyond 0. I didn't notice the difference in the logic branches until you mentioned it (is_init_image).
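A simplified sketch of that branching as I understand it (the `chunk_decode` and `is_init_image` names come from the thread; the body here is illustrative, not the repo's actual implementation):

```python
def chunk_decode(latent_chunks, decode_fn):
    # decode_fn stands in for the VAE decode. Only the first chunk takes
    # the init-image path; later chunks go through the causal (temporal)
    # path, which is why the two latent types can differ in value range.
    frames = []
    for idx, z in enumerate(latent_chunks):
        frames.append(decode_fn(z, is_init_image=(idx == 0)))
    return frames
```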

I did make a modification in modeling_causal_conv.py, in the forward function. On MPS, I was getting an error, RuntimeError: Input type (c10::Half) and bias type (c10::BFloat16) should be the same, so I made this cast modification:

```python
# Cast the input to bfloat16 to match the conv bias, then cast back.
dtypeX = x.dtype
x = self.conv(x.to(torch.bfloat16))
x = x.to(dtypeX)  # note: .to() is not in-place, so the result must be reassigned
return x
```

Is it possible that this cast to bfloat16 is the problem?