jy0205 / Pyramid-Flow

Code of Pyramidal Flow Matching for Efficient Video Generative Modeling
https://pyramid-flow.github.io/
MIT License

Image tensor is NaN on i2v with more than 1 frame and NaN with t2v with 1 or many frames. #98

Open YAY-3M-TA3 opened 4 days ago

YAY-3M-TA3 commented 4 days ago

I traced the NaN back to `def _get_t5_prompt_embeds`, where `prompt_embeds` comes back as all NaN.

_(I am trying to modify the project to work on Mac MPS: Python 3.10.13, torch 2.6.0 nightly, torchvision 0.20.0. I was able to use these versions successfully for ComfyUI Flux on MPS. I also modified the code to move "cuda" tensors to "mps", with all dtypes set to bfloat16.

..and in `def rope()`, I did have to change `scale = torch.arange(0, dim, 2, dtype=torch.`**`float64`**`, device=pos.device) / dim` to `scale = torch.arange(0, dim, 2, dtype=torch.`**`bfloat16`**`, device=pos.device) / dim`, since float64 is not supported on MPS.

i2v with 1 frame does work with this setup.)_
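For reference, the dtype change amounts to something like the sketch below (the `rope_scale` helper name is hypothetical, not from the repo; it just isolates the line in question). float64 is unsupported on the MPS backend, so a lower-precision dtype has to be substituted there:

```python
import torch

def rope_scale(dim: int, device=None, dtype=torch.float64) -> torch.Tensor:
    # The frequency-scale line from rope(): even indices 0, 2, ..., dim-2,
    # normalized by dim. On MPS, dtype must be float32 or bfloat16 because
    # the backend does not implement float64.
    return torch.arange(0, dim, 2, dtype=dtype, device=device) / dim

# On CUDA/CPU the original float64 works; on MPS pass a supported dtype:
scale = rope_scale(64, dtype=torch.float32)
```

Note that float32 may be a safer fallback than bfloat16 here, since these scales feed trigonometric functions where bfloat16's 8-bit mantissa loses precision.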

Here are the details from def _get_t5_prompt_embeds: self.tokenizer_3 is called with prompt = ['A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors, hyper quality, Ultra HD, 8K'] and max_sequence_length = 128

the result is:

**text_inputs** = {'input_ids': tensor([[   71,  1974,  6943,  4767,     8, 12560,    13,     8,   604,   215,
           625,   628,   388,  5119,     3,     9,  1131, 13996, 17989,  1054,
         11718, 18691,     6,  1692,  5796,     6,  3136,  9980,     6, 10276,
          1225,   869,     6,  2538,    30,  3097,   635,   814,     6, 18744,
          2602,     6,  6676,   463,     6,  8618,  3726,     6,   505,   439,
             1,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])}

Then after `self.text_encoder_3()`:

```
prompt_embeds = tensor([[[nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         ...,
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan]]], device='mps:0', dtype=torch.bfloat16)
```



Any idea why **self.text_encoder_3()** would return NaN?
feifeiobama commented 4 days ago

Thanks for trying to generalize our method to multi-frame conditioning. Our model should be able to handle this situation, since it generates in an autoregressive manner. However, there is a constraint from the VAE latent space it operates on: we use a MAGVIT-v2-like causal VAE, so the frames are grouped as:

1st frame, 2nd-9th frame, 10th-17th frame, ...

This means that the model can only be modified for single-frame / 9-frame / 17-frame conditioned video generation.
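The grouping described above (the 1st frame alone, then chunks of 8) can be checked with a small helper. This is a sketch; the function name is hypothetical and not part of the repo:

```python
def is_valid_frame_count(n: int) -> bool:
    # The causal VAE encodes the 1st frame on its own, then every subsequent
    # group of 8 frames (2nd-9th, 10th-17th, ...), so valid conditioning
    # lengths are n = 1 + 8k for k = 0, 1, 2, ...
    return n >= 1 and (n - 1) % 8 == 0

# Valid: 1, 9, 17, 25, ...  Invalid: 2, 5, 10, ...
```

An assertion like this at the start of the conditioning path would turn the silent NaN into a clear error.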

feifeiobama commented 4 days ago

As for the NaN error, I think it's because we assume the number of frames follows the rules above, so a computational error can occur if you pass an unexpected number of conditioning frames. We may include assertions in future versions of the code to guard against this.

YAY-3M-TA3 commented 3 days ago

Ok - I solved the NaN problem. For MPS, I needed to wrap the text encoder call in pyramid_dit_for_video_gen_pipeline.py in an autocast context:

```
with torch.autocast("mps", dtype=torch.bfloat16):
```

Now there are no NaNs all the way through to the VAE image.
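As a sketch, the same fix can be expressed as a small device-agnostic wrapper; the helper name is hypothetical, and only the `torch.autocast` call mirrors the actual change:

```python
import torch

def autocast_call(fn, *args, device_type="mps", dtype=torch.bfloat16, **kwargs):
    # Run any module/function call under autocast so mixed-precision ops
    # (e.g. the T5 text encoder forward pass) are handled consistently
    # instead of producing NaNs from unsupported dtype paths.
    with torch.autocast(device_type, dtype=dtype):
        return fn(*args, **kwargs)
```

The same wrapper works on CPU or CUDA by changing `device_type`, which makes it easy to keep one code path across backends.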

https://github.com/user-attachments/assets/a338589e-0c04-49ca-896e-7f06c17eac40