PKU-YuanGroup / Open-Sora-Plan

This project aims to reproduce Sora (OpenAI's T2V model); we hope the open-source community will contribute to this project.
MIT License

Code for initializing CausalConv3d from pretrained Conv2D. #168

Open Sutongtong233 opened 7 months ago

Sutongtong233 commented 7 months ago

Hi, I see in CausalVideoVAE.md that you use a special initialization (tail initialization) for CausalConv3d training. I am interested in this trick and would be sincerely grateful if you could share the specific initialization code.

vivym commented 7 months ago
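w = vae_2d_ckpt["state_dict"][key_2d]            # pretrained Conv2d weight: (out_channels, in_channels, kH, kW)
new_w = torch.zeros(shape_3d, dtype=w.dtype)     # shape_3d: (out_channels, in_channels, kT, kH, kW)
new_w[:, :, -1, :, :] = w                        # tail initialization: fill only the last temporal slice, leave the rest zero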

https://github.com/vivym/OmniGen/blob/main/scripts/inflate_conv_for_video_vae.py
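To make the effect of the snippet above concrete, here is a minimal, self-contained sketch. It is not the linked script: the helper name `inflate_tail`, the tensor sizes, and the choice to pad the front of the clip by repeating the first frame are all assumptions made for illustration. It checks that a tail-initialized 3D kernel, applied with front-only (causal) temporal padding, gives the same output as running the original 2D convolution on each frame.

```python
import torch
import torch.nn.functional as F


def inflate_tail(w2d: torch.Tensor, kt: int) -> torch.Tensor:
    """Inflate a Conv2d weight (out, in, kH, kW) to (out, in, kt, kH, kW).

    Hypothetical helper: the 2D weight goes into the last temporal slice,
    every other slice stays zero.
    """
    out_c, in_c, kh, kw = w2d.shape
    w3d = torch.zeros(out_c, in_c, kt, kh, kw, dtype=w2d.dtype)
    w3d[:, :, -1, :, :] = w2d
    return w3d


torch.manual_seed(0)
kt = 3
w2d = torch.randn(8, 4, 3, 3)          # stand-in for a pretrained Conv2d weight
w3d = inflate_tail(w2d, kt)

video = torch.randn(1, 4, 5, 16, 16)   # (batch, channels, frames, height, width)

# Causal padding: add kt - 1 frames at the front only (assumed here to be
# copies of the first frame), so output frame t never sees frames after t.
front = video[:, :, :1].repeat(1, 1, kt - 1, 1, 1)
padded = torch.cat([front, video], dim=2)
out3d = F.conv3d(padded, w3d, padding=(0, 1, 1))

# Reference: apply the 2D convolution to every frame independently.
frames = video.permute(0, 2, 1, 3, 4).reshape(5, 4, 16, 16)
out2d = F.conv2d(frames, w2d, padding=1)
out2d = out2d.reshape(1, 5, 8, 16, 16).permute(0, 2, 1, 3, 4)

max_diff = (out3d - out2d).abs().max().item()
print(f"max |conv3d - conv2d| = {max_diff:.2e}")
```

Because the earlier temporal slices are zero, the inflated layer starts out as a per-frame copy of the 2D VAE layer, and training then only has to learn the temporal part.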

Birdylx commented 7 months ago

@vivym Thanks, but I have another question about the temporal upsampling at this line: https://github.com/vivym/OmniGen/blob/4f0bf7d7f7dcb6b1b79b50c90153f7477151e139/src/omni_gen/models/video_vae/upsamplers.py#L87. It isn't a 2x upsample; the output will always have an odd number of frames.

vivym commented 7 months ago

@Birdylx It is indeed an odd number of frames. You can refer to the paper https://arxiv.org/abs/2310.05737
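The odd frame count is what one would expect if, as in the causal tokenizer described in the linked paper, the first frame is handled as a standalone image and only the remaining frames are compressed in time. The following is a sketch of that arithmetic only, assuming a temporal factor of 2 per stage; the function names are hypothetical and this is not the code from `upsamplers.py`.

```python
def downsampled_frames(t: int, stages: int = 1) -> int:
    """Frame count after `stages` causal 2x temporal downsamples.

    Assumes the first frame is kept as-is and the other t - 1 frames are halved.
    """
    for _ in range(stages):
        if (t - 1) % 2 != 0:
            raise ValueError(f"expected an odd frame count, got {t}")
        t = 1 + (t - 1) // 2
    return t


def upsampled_frames(t: int, stages: int = 1) -> int:
    """Frame count after `stages` causal 2x temporal upsamples: t -> 2t - 1."""
    for _ in range(stages):
        t = 1 + 2 * (t - 1)
    return t


for latent_t in (1, 2, 3, 5):
    print(latent_t, "->", upsampled_frames(latent_t))
```

Under this assumption, `2t - 1` is odd for every `t >= 1`, and a single frame (`t = 1`) maps to a single frame, which is what lets the same model handle both images and videos.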

Birdylx commented 7 months ago

@vivym Thanks for your quick reply! I will read the paper for more details.

Birdylx commented 7 months ago

@vivym Do you train the full model, or freeze the model and train only the temporal block?

Sutongtong233 commented 7 months ago

Thanks:) I will have a try.

Sutongtong233 commented 6 months ago

It works. Thanks a lot!

Sutongtong233 commented 6 months ago

I see "Despite the VAE in Diffusion training being frozen" mentioned in your latest doc. Does that mean you've found that freezing the 2D VAE weights (the "tail" of the CausalConv3d) performs better?

Sutongtong233 commented 6 months ago

@vivym Do you train the full model, or freeze the model and train only the temporal block?

I've tried training the full model: the motion blur is alleviated, but single-frame reconstruction degrades.

Catpp01 commented 2 weeks ago

w = vae_2d_ckpt["state_dict"][key_2d]  # conv2d weight: (out_channels, in_channels, kH, kW)
new_w = torch.zeros(shape_3d, dtype=w.dtype)  # shape_3d = (out_channels, in_channels, T, kH, kW), T = temporal kernel size
new_w[:, :, -1, :, :] = w  # tail initialization
# center   : new_w[:, :, T // 2, :, :] = w
# average  : new_w[:, :, :, :, :] = w.unsqueeze(2) / T
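The three variants listed above can be put side by side in one helper. This is a sketch under stated assumptions: the function name `inflate_conv2d_weight` and its `mode` argument are hypothetical and do not come from the Open-Sora-Plan or OmniGen code, and the random tensor stands in for a real checkpoint weight.

```python
import torch


def inflate_conv2d_weight(w2d: torch.Tensor, kt: int, mode: str = "tail") -> torch.Tensor:
    """Inflate a Conv2d weight (out, in, kH, kW) to (out, in, kt, kH, kW).

    Hypothetical helper:
      tail    -> 2D weight in the last temporal slice, zeros elsewhere
      center  -> 2D weight in the middle temporal slice, zeros elsewhere
      average -> 2D weight divided by kt, repeated in every temporal slice
    """
    out_c, in_c, kh, kw = w2d.shape
    if mode == "average":
        return w2d.unsqueeze(2).repeat(1, 1, kt, 1, 1) / kt

    w3d = torch.zeros(out_c, in_c, kt, kh, kw, dtype=w2d.dtype)
    if mode == "tail":
        w3d[:, :, -1] = w2d
    elif mode == "center":
        w3d[:, :, kt // 2] = w2d
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return w3d


torch.manual_seed(0)
w2d = torch.randn(8, 4, 3, 3)  # stand-in for a pretrained Conv2d weight
weights = {m: inflate_conv2d_weight(w2d, kt=3, mode=m) for m in ("tail", "center", "average")}

for name, w3d in weights.items():
    # In every mode the temporal slices sum back to the 2D weight, so a
    # perfectly static clip is processed the same way as by the 2D layer.
    print(name, tuple(w3d.shape), torch.allclose(w3d.sum(dim=2), w2d, atol=1e-5))
```

The modes differ only on moving content. With front-only (causal) padding, the tail variant reads exactly the current frame, the center variant reads a frame from `kt - 1 - kt // 2` steps earlier, and the average variant blends the last `kt` frames.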