jy0205 / Pyramid-Flow

Code of Pyramidal Flow Matching for Efficient Video Generative Modeling
https://pyramid-flow.github.io/
MIT License
2.4k stars 233 forks

Where can I get LPIPS vgg checkpoints? #163

Open vanche opened 2 weeks ago

vanche commented 2 weeks ago

Hi, thank you for this great work! It seems that I need the checkpoint for LPIPS vgg to train the VAE. I referred to your utils.py code and downloaded vgg.pth, but I'm getting a "Missing key(s) in state_dict:" error, possibly because it's a different checkpoint from the model. Could you please help me find where I can obtain the checkpoint for LPIPS vgg?

jy0205 commented 2 weeks ago

Hi, the LPIPS vgg checkpoint is provided in the official vqgan repo. See here. You can find the download URL there.

jy0205 commented 2 weeks ago

You can download it from here.

vanche commented 2 weeks ago

Thank you for your prompt response. However, after checking the model checkpoint from the link you provided, I'm still encountering the same issue, so I reviewed the relevant code again. The LPIPS class is similar to the one in the Open-Sora project (link), so I compared the two implementations.

In your code, you use strict=True when calling load_state_dict. In this case, it seems that the checkpoint must also contain weights for self.scaling_layer and self.net within LPIPS. When I set strict=False in self.load_state_dict, the RuntimeError I was seeing no longer occurs:

RuntimeError: Error(s) in loading state_dict for LPIPS: Missing key(s) in state_dict: "scaling_layer.shift", "scaling_layer.scale", "net.slice1.0.weight", "net.slice1.0.bias", "net.slice1.2.weight", "net.slice1.2.bias", "net.slice2.5.weight", "net.slice2.5.bias", "net.slice2.7.weight", "net.slice2.7.bias", "net.slice3.10.weight", "net.slice3.10.bias", "net.slice3.12.weight", "net.slice3.12.bias", "net.slice3.14.weight", "net.slice3.14.bias", "net.slice4.17.weight", "net.slice4.17.bias", "net.slice4.19.weight", "net.slice4.19.bias", "net.slice4.21.weight", "net.slice4.21.bias", "net.slice5.24.weight", "net.slice5.24.bias", "net.slice5.26.weight", "net.slice5.26.bias", "net.slice5.28.weight", "net.slice5.28.bias".

Would setting strict=False align with your intended functionality?

Also, when using strict=False, it seems appropriate to set pretrained=True for self.net, which is the VGG network. I wanted to confirm whether this matches your intentions.
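
To illustrate the difference I mean, here is a minimal sketch with a toy module. TinyLPIPS, its layer sizes, and the allowed_prefixes check are hypothetical stand-ins, not the repo's code. It assumes the downloaded file holds only the linear-layer weights, which is what the missing-key list above suggests.

```python
import torch
import torch.nn as nn


class TinyLPIPS(nn.Module):
    """Hypothetical stand-in for LPIPS: a backbone plus one learned linear head."""

    def __init__(self):
        super().__init__()
        # Stands in for the VGG feature extractor (self.net in the real class).
        self.net = nn.Conv2d(3, 4, kernel_size=3)
        # Stands in for one of the LPIPS linear layers.
        self.lin0 = nn.Conv2d(4, 1, kernel_size=1, bias=False)


model = TinyLPIPS()

# Assumption: the checkpoint contains only the linear head, no backbone weights.
partial_ckpt = {"lin0.weight": torch.ones(1, 4, 1, 1)}

# strict=True raises RuntimeError because the backbone keys are absent.
try:
    model.load_state_dict(partial_ckpt, strict=True)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# strict=False loads the matching keys and reports the rest instead of raising.
result = model.load_state_dict(partial_ckpt, strict=False)
missing = sorted(result.missing_keys)
unexpected = list(result.unexpected_keys)

# Accept the partial load only if every missing key belongs to a part of the
# model that gets its values elsewhere (for example, a pretrained backbone).
allowed_prefixes = ("net.", "scaling_layer.")
assert all(key.startswith(allowed_prefixes) for key in missing), missing
assert not unexpected, unexpected

print("strict=True failed:", strict_failed)
print("missing with strict=False:", missing)
```

With strict=False, the keys reported as missing keep whatever values the module was constructed with, so the backbone would need its pretrained weights from another source. That is why I am asking about pretrained=True.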

I look forward to your response. Thank you!

Stefano-retinize commented 2 weeks ago

The file at the link you provided is only 6.7 KB, while the LPIPS model you are using has 14,716,161 parameters, so it is impossible for them to match. It also seems to be a different file, since in your project it was saved as '/home/jinyang06/models/vae/video_vae_baseline/vgg_lpips.pth', which appears to differ from the vgg.pth file you are referring to.
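
A rough back-of-the-envelope check of the size mismatch, as a sketch only. It assumes float32 storage and 1 KB = 1000 bytes, and it ignores serialization overhead. The channel widths are those of a standard VGG16-based LPIPS and are my assumption about this model.

```python
NUM_PARAMS = 14_716_161    # parameter count quoted above
BYTES_PER_FLOAT32 = 4      # assumption: weights are stored as float32
DOWNLOADED_BYTES = 6_700   # 6.7 KB, taking 1 KB = 1000 bytes

# Size a checkpoint would need in order to hold every parameter.
expected_bytes = NUM_PARAMS * BYTES_PER_FLOAT32

# Upper bound on how many float32 values the downloaded file could contain.
max_params_in_download = DOWNLOADED_BYTES // BYTES_PER_FLOAT32

# Assumption: one 1x1 linear layer per VGG16 stage, with these channel widths.
lin_channels = (64, 128, 256, 512, 512)
lin_params = sum(lin_channels)
lin_bytes = lin_params * BYTES_PER_FLOAT32

print(f"full checkpoint: about {expected_bytes / 1e6:.1f} MB")
print(f"download holds at most {max_params_in_download} float32 values")
print(f"linear layers only: {lin_params} weights, about {lin_bytes / 1e3:.1f} KB")
```

Under these assumptions the full model would need tens of megabytes, while a file containing only the linear layers would land close to the observed 6.7 KB.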

vanche commented 4 days ago

Do you have any updated answers or additional information regarding my question? @jy0205

catsled commented 7 hours ago

I ran into this problem too. Does anyone have a good solution?

catsled commented 6 hours ago

@catsled

image image

like this