Sutongtong233 opened 7 months ago
w = vae_2d_ckpt["state_dict"][key_2d] # conv2d weight
new_w = torch.zeros(shape_3d, dtype=w.dtype)
new_w[:, :, -1, :, :] = w  # copy the 2D weight into the last temporal slice (tail init)
https://github.com/vivym/OmniGen/blob/main/scripts/inflate_conv_for_video_vae.py
@vivym thanks, but I have another question about the temporal upsample at this line https://github.com/vivym/OmniGen/blob/4f0bf7d7f7dcb6b1b79b50c90153f7477151e139/src/omni_gen/models/video_vae/upsamplers.py#L87 — it isn't a 2x upsample; it always produces an odd number of frames.
@Birdylx It is indeed an odd number of frames. You can refer to the paper https://arxiv.org/abs/2310.05737
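The odd-frame behaviour can be sketched as follows. This is a hedged illustration, not the repository's actual implementation: assuming a causal temporal upsampler that duplicates every frame along time and then drops the leading duplicate so the first frame has no look-back, T input frames always become 2*T - 1 output frames.

```python
import torch

def causal_temporal_upsample(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical causal 2x temporal upsample.

    x: (B, C, T, H, W). Each frame is repeated twice along the time
    axis, then the leading duplicate is dropped so the first frame
    stays causal. T frames in -> 2*T - 1 frames out (always odd).
    """
    x = x.repeat_interleave(2, dim=2)
    return x[:, :, 1:]

x = torch.randn(1, 4, 5, 8, 8)
print(causal_temporal_upsample(x).shape)  # torch.Size([1, 4, 9, 8, 8])
```

So a 5-frame latent decodes to 9 frames, never 10, which matches the "always odd" observation above.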
@vivym thanks for your quick reply! I will read the paper for more details.
@vivym Do you train the full model? or freeze the model, just train the temporal block?
Thanks:) I will have a try.
It works. Thanks a lot!
I see, "Despite the VAE in Diffusion training being frozen" is mentioned in your latest doc. Does that mean you've found that freezing the 2D-VAE weights (the "tail" of the CausalConv3d) performs better?
@vivym Do you train the full model? or freeze the model, just train the temporal block?
I've tried training the full model: the motion blurring is alleviated, while the single-frame reconstruction degrades.
w = vae_2d_ckpt["state_dict"][key_2d] # conv2d weight
new_w = torch.zeros(shape_3d, dtype=w.dtype)  # shape_3d = (out_channels, in_channels, t, kh, kw)
new_w[:, :, -1, :, :] = w  # <-- tail initialization
# center : new_w[:, :, t // 2, :, :] = w
# average: new_w[:, :, i, :, :] = w / t for every temporal index i
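The three strategies in the snippet above can be written out as a small self-contained sketch. The shapes here are illustrative assumptions (a 3x3 Conv2d kernel inflated to a temporal kernel of size 3), not values from the repository:

```python
import torch

# Assumed shapes for illustration: pretrained Conv2d weight is
# (out_channels, in_channels, kh, kw); the inflated Conv3d weight
# adds a temporal dimension kt.
out_ch, in_ch, kt, kh, kw = 64, 3, 3, 3, 3
w = torch.randn(out_ch, in_ch, kh, kw)  # pretend this is the pretrained 2D weight

# Tail initialization: place the 2D weight at the LAST temporal slice,
# so a causal conv initially behaves exactly like the 2D conv on each frame.
tail_w = torch.zeros(out_ch, in_ch, kt, kh, kw, dtype=w.dtype)
tail_w[:, :, -1, :, :] = w

# Center initialization (for comparison): weight at the middle slice.
center_w = torch.zeros_like(tail_w)
center_w[:, :, kt // 2, :, :] = w

# Average initialization: spread the 2D weight evenly over time,
# so the temporal slices sum back to the original 2D weight.
avg_w = w.unsqueeze(2).expand(-1, -1, kt, -1, -1) / kt
```

The tail variant is the natural fit for a causal conv: at the first frame, only the last temporal tap sees valid input, so the model starts out reproducing the pretrained 2D VAE frame by frame.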
Hi, I find that you mention in CausalVideoVAE.md that you use a special initialization (tail initialization) for CausalConv3d training. I am interested in this trick, and I would be sincerely grateful if you could share the specific initialization code.