jy0205 / Pyramid-Flow

Code of Pyramidal Flow Matching for Efficient Video Generative Modeling
https://pyramid-flow.github.io/
MIT License

Training issues #159

Closed · rob-hen closed this issue 2 weeks ago

rob-hen commented 2 weeks ago

Hi all,

thank you for open-sourcing this great work. I am facing some difficulties training the Flux-based model using train_pyramid_flow_without_ar.sh.

  1. The script train_pyramid_flow_without_ar.sh sets the load_vae flag. However, the code assumes that pre-computed VAE latent files are used.
  2. On 8 A100s, even with SHARD_STRATEGY=zero3, I can only train with NUM_FRAMES=8, RESOLUTION="384p", BATCH_SIZE=4, and TASK=t2v on the Flux model. With 16 or 32 frames I get CUDA OOM errors. How did you train the model with a larger number of frames? Did you use multi-node training? I modified the code to run on 8×5 A100 GPUs and am still getting out-of-memory errors with 16 frames.
  3. Why do the comments describe NUM_FRAMES as "# e.g., 8 for 2s, 16 for 5s, 32 for 10s" if the model generates videos at 24 fps? And how do I train on 5-second clips at 24 fps, i.e. 120 frames?
rob-hen commented 2 weeks ago

I realised that NUM_FRAMES indicates the number of frames in the VAE's latent space. So NUM_FRAMES=8 should correspond to 64 frames, and NUM_FRAMES=16 to 128 frames.

jy0205 commented 2 weeks ago

Hi! Here are the responses:

  1. If you use pre-extracted VAE latents for videos, you can set load_vae to False. I set it to True because image training does not use pre-extracted VAE latents (see the data-path sketch after this list).
  2. For train_pyramid_flow_without_ar, we have not implemented specific memory optimizations yet. You can reduce the memory cost by using sequence parallelism or by enabling gradient checkpointing (see the checkpointing sketch below).
  3. The frame number and the VAE latent temporal dimension are related by num_frames = 8 * (temp - 1) + 1 (see the quick check below).
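
For item 1, here is a minimal sketch of the two data paths the flag distinguishes. It is not the repository's actual dataset code; the class name `VideoLatentDataset` and the on-disk layout (one `.pt` file per sample with `pixels` / `latent` keys) are assumptions made purely for illustration.

```python
import torch
from torch.utils.data import Dataset

class VideoLatentDataset(Dataset):
    """Illustrative only: serve either raw clips (to be encoded by the VAE
    inside the training step when load_vae=True) or pre-extracted VAE
    latents saved offline (when load_vae=False)."""

    def __init__(self, sample_paths, load_vae: bool):
        self.sample_paths = sample_paths  # hypothetical list of .pt files
        self.load_vae = load_vae

    def __len__(self):
        return len(self.sample_paths)

    def __getitem__(self, idx):
        sample = torch.load(self.sample_paths[idx])
        if self.load_vae:
            # Raw pixel clip; the training loop runs vae.encode(...) on it.
            return {"video": sample["pixels"]}
        # Latent tensor produced offline by the VAE encoder.
        return {"latent": sample["latent"]}
```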
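
For item 2, a generic PyTorch gradient-checkpointing sketch: activations of each transformer block are recomputed during the backward pass instead of being stored, trading compute for memory. It assumes the diffusion transformer exposes its blocks as an `nn.ModuleList` (here called `blocks`); this is not Pyramid-Flow's own implementation.

```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedBlocks(torch.nn.Module):
    """Run a stack of transformer blocks under activation checkpointing."""

    def __init__(self, blocks: torch.nn.ModuleList):
        super().__init__()
        self.blocks = blocks

    def forward(self, hidden_states, *block_args):
        for block in self.blocks:
            # use_reentrant=False is the mode recommended by recent PyTorch.
            hidden_states = checkpoint(
                block, hidden_states, *block_args, use_reentrant=False
            )
        return hidden_states
```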
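
For item 3, the relation above also explains the "# e.g., 8 for 2s, 16 for 5s, 32 for 10s" comment at 24 fps; a quick check:

```python
# Latent temporal dim (NUM_FRAMES in the script) -> pixel-space frames.
def pixel_frames(temp: int) -> int:
    return 8 * (temp - 1) + 1

for temp in (8, 16, 32):
    n = pixel_frames(temp)
    print(f"NUM_FRAMES={temp}: {n} frames, ~{n / 24:.1f}s at 24 fps")
# NUM_FRAMES=8:  57 frames,  ~2.4s at 24 fps
# NUM_FRAMES=16: 121 frames, ~5.0s at 24 fps
# NUM_FRAMES=32: 249 frames, ~10.4s at 24 fps
```

Under this relation, a roughly 5-second clip at 24 fps corresponds to NUM_FRAMES=16 (121 pixel frames, one more than 120 because of the leading frame).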