Vchitect / Latte

Latte: Latent Diffusion Transformer for Video Generation.
Apache License 2.0

Question: model code and design choices #77

mrartemevmorphic commented 2 months ago

Hello! Thank you for your great paper and for publishing the code and checkpoints for the t2v models. While reading about how it all works, I had a number of questions. I hope you'll find some time to answer at least some of them. Feel free to direct me to your paper if it is already explained there :)

  1. What is the reasoning behind using such a small patch size = 2? Usually, I see patch sizes of 16 or 8 used, especially when generating 512x512 images.
  2. I see that you used LoRACompatible modules for linear projection. Have you thought about how this architecture could be expanded with LoRAs?
  3. Have you thought about adding some image-specific positional encoding to appended images?
  4. What is the purpose of args.fixed_spatial? In what cases would one want to train only spatial layers?
  5. In the provided training script, the decay for EMA is set to 0. Does that mean that the provided checkpoint was trained without EMA? Link: here
  6. Given that you are already passing "scaling_factor": 0.18215 to the VAE model, why do you scale it again in the training loop? Link: here
  7. Given that you are already doing attention masking in the encode_prompt function, why are you passing attention_mask and encoder_attention_mask arguments to the model's forward method? I'm not sure if I'm right, but it seems that both of these arguments are never used.
  8. How do you switch between using fp16 and fp32 in the training script?
  9. Training the model for more than 16 frames often results in checkerboard artifacts and significantly reduced quality. Do you think this is a limitation of the Latte model's architecture? I've seen that you recommend looking into autoregressive video modeling, but still, how can we effectively scale the number of frames generated from 16 to 32 without changing the architecture or sampling method?
  10. In the implementation of the BasicTransformerBlock, there is a lot of commented-out code with the cross-attention implementation. Does this mean that the pretrained checkpoint was trained without it?

Thank you again for your work, and I look forward to your answers!

maxin-cn commented 2 months ago

Hi, thanks for your interest.

  1. The patch size of 2 is directly inherited from DiT and PixArt-alpha (see the first sketch below).
  2. Not yet, but it should be straightforward, depending on what you want to do with it (see the LoRA sketch below).
  3. Which positional embedding do you have in mind, for example? The spatial part is already encoded with an absolute positional embedding.
  4. It is for the case where someone only wants to train the temporal modules (see the freezing sketch below).
  5. No. Setting the decay to 0 here just synchronizes the EMA parameters with the model at the beginning of training (see the EMA sketch below).
  6. The VAE itself does not multiply by this scaling factor, so the training loop has to apply it explicitly (see the VAE sketch below).
  7. attention_mask is used for training; encoder_attention_mask is used for both training and testing.
  8. The corresponding parameters are controlled in the config (see the mixed-precision sketch below).
  9. When training on more frames, such as 32, I did not experience a serious drop in quality. The autoregressive approach is just a training-free method, and there are other training-free methods that can generate longer videos than the base model (see the temporal-embedding sketch below).
  10. Yes, that code comes from diffusers and is not used in Latte.
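
To make point 1 concrete: the patch size is applied to the VAE latent, not to raw pixels, so a patch of 2 in latent space already covers a fairly large pixel region. A rough sketch of the arithmetic, where the 8x VAE downsampling and the 512x512 resolution are illustrative assumptions:

```python
image_size = 512                # pixel resolution (example value)
vae_downsample = 8              # an SD-style VAE shrinks each spatial dim by 8x
patch_size = 2                  # the DiT / PixArt-alpha default

latent_size = image_size // vae_downsample             # 64 x 64 latent
tokens_per_frame = (latent_size // patch_size) ** 2    # 32 * 32 = 1024 tokens

# One latent patch covers vae_downsample * patch_size = 16 pixels per side,
# so patch_size = 2 on the latent is roughly the familiar "patch 16"
# granularity in pixel space.
print(tokens_per_frame)         # 1024
```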
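
On point 2, a minimal sketch of how the LoRA-compatible linear projections could be adapted. This is a generic LoRA wrapper, not Latte's or diffusers' actual LoRA integration; the class name and arguments are made up for illustration:

```python
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear with a low-rank residual update (illustrative only)."""
    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False              # keep the pretrained weight frozen
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)       # adapter starts as a no-op
        self.scale = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))
```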
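
On point 4, training only the temporal modules might look roughly like the sketch below; the "temporal" name filter is an assumption about how the modules are named, not the exact Latte parameter names:

```python
import torch
import torch.nn as nn

def keep_only_temporal_trainable(model: nn.Module, keyword: str = "temporal"):
    """Freeze every parameter whose name does not contain `keyword` (hypothetical filter)."""
    for name, param in model.named_parameters():
        param.requires_grad = keyword in name
    return [p for p in model.parameters() if p.requires_grad]

# Usage sketch, where `model` is the video transformer:
# trainable = keep_only_temporal_trainable(model)
# optimizer = torch.optim.AdamW(trainable, lr=1e-4)
```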
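
On point 5, an EMA update of the kind used in DiT-style training scripts looks roughly like this sketch; calling it once with decay = 0 simply copies the online weights into the EMA copy, which is the synchronization step mentioned above:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def update_ema(ema_model: nn.Module, model: nn.Module, decay: float = 0.9999):
    """One EMA step: ema = decay * ema + (1 - decay) * model (a sketch)."""
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        # With decay = 0 this reduces to ema_params[name].copy_(param),
        # i.e. a plain synchronization of the two parameter sets.
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
```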
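
On point 6, a short sketch of the diffusers behaviour: `AutoencoderKL.encode` returns unscaled latents, and the `scaling_factor` stored in the config is only a value the caller is expected to apply; the checkpoint id below is just an example:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")  # example checkpoint
x = torch.randn(1, 3, 512, 512)

with torch.no_grad():
    # encode() does NOT multiply by scaling_factor itself ...
    latents = vae.encode(x).latent_dist.sample()
    # ... so the training loop applies it explicitly:
    latents = latents * vae.config.scaling_factor            # 0.18215

    # and the inverse is applied before decoding:
    recon = vae.decode(latents / vae.config.scaling_factor).sample
```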
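
On point 8, I am not reproducing the exact config keys here; as a generic illustration of the fp16/fp32 switch that such a flag typically controls, using torch.cuda.amp:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

# Generic mixed-precision pattern (not Latte's actual config mechanism; the
# real switch lives in the config as noted above).
use_fp16 = True
scaler = GradScaler(enabled=use_fp16)

def training_step(model, batch, optimizer):
    optimizer.zero_grad(set_to_none=True)
    with autocast(enabled=use_fp16, dtype=torch.float16):
        loss = model(batch)                 # assume the model returns a scalar loss
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()
```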
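
On point 9, one generic training-free knob for sampling more frames than the training length is to stretch the temporal positional embedding. This is offered only as an illustration of the kind of trick such methods use, not as something the Latte code does; the tensor shapes are assumptions:

```python
import torch
import torch.nn.functional as F

def stretch_temporal_pos_embed(pos_embed: torch.Tensor, new_frames: int) -> torch.Tensor:
    """Linearly interpolate a (1, T, D) temporal positional embedding to new_frames."""
    # (1, T, D) -> (1, D, T) so we can interpolate over the time axis
    stretched = F.interpolate(pos_embed.transpose(1, 2), size=new_frames,
                              mode="linear", align_corners=True)
    return stretched.transpose(1, 2)       # back to (1, new_frames, D)

# pe16 = torch.randn(1, 16, 1152)              # trained-length embedding (example dims)
# pe32 = stretch_temporal_pos_embed(pe16, 32)  # used when sampling 32 frames
```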