hkproj / pytorch-stable-diffusion

Stable Diffusion implemented from scratch in PyTorch
https://www.youtube.com/watch?v=ZBKpAp_6TGI
MIT License
608 stars, 138 forks

RuntimeError: Error(s) in loading state_dict for VAE_Encoder: #15

Open jianingPeng0382 opened 7 months ago

jianingPeng0382 commented 7 months ago

I ran into this problem when running the demo. It seems that the keys in the converted checkpoint cannot be found in the model. What should I do?
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VAE_Encoder:
    Unexpected key(s) in state_dict: "1.groupnorm_1.weight", "1.groupnorm_1.bias", "1.conv_1.weight", "1.conv_1.bias", "1.groupnorm_2.weight", "1.groupnorm_2.bias", "1.conv_2.weight", "1.conv_2.bias", "2.groupnorm_1.weight", "2.groupnorm_1.bias", "2.conv_1.weight", "2.conv_1.bias", "2.groupnorm_2.weight", "2.groupnorm_2.bias", "2.conv_2.weight", "2.conv_2.bias", "4.groupnorm_1.weight", "4.groupnorm_1.bias", "4.conv_1.weight", "4.conv_1.bias", "4.groupnorm_2.weight", "4.groupnorm_2.bias", "4.conv_2.weight", "4.conv_2.bias", "5.groupnorm_1.weight", "5.groupnorm_1.bias", "5.conv_1.weight", "5.conv_1.bias", "5.groupnorm_2.weight", "5.groupnorm_2.bias", "5.conv_2.weight", "5.conv_2.bias", "7.groupnorm_1.weight", "7.groupnorm_1.bias", "7.conv_1.weight", "7.conv_1.bias", "7.groupnorm_2.weight", "7.groupnorm_2.bias", "7.conv_2.weight", "7.conv_2.bias", "8.groupnorm_1.weight", "8.groupnorm_1.bias", "8.conv_1.weight", "8.conv_1.bias", "8.groupnorm_2.weight", "8.groupnorm_2.bias", "8.conv_2.weight", "8.conv_2.bias", "10.groupnorm_1.weight", "10.groupnorm_1.bias", "10.conv_1.weight", "10.conv_1.bias", "10.groupnorm_2.weight", "10.groupnorm_2.bias", "10.conv_2.weight", "10.conv_2.bias", "11.groupnorm_1.weight", "11.groupnorm_1.bias", "11.conv_1.weight", "11.conv_1.bias", "11.groupnorm_2.weight", "11.groupnorm_2.bias", "11.conv_2.weight", "11.conv_2.bias", "12.groupnorm_1.weight", "12.groupnorm_1.bias", "12.conv_1.weight", "12.conv_1.bias", "12.groupnorm_2.weight", "12.groupnorm_2.bias", "12.conv_2.weight", "12.conv_2.bias", "14.groupnorm_1.weight", "14.groupnorm_1.bias", "14.conv_1.weight", "14.conv_1.bias", "14.groupnorm_2.weight", "14.groupnorm_2.bias", "14.conv_2.weight", "14.conv_2.bias".
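One way to see exactly which names disagree before calling load_state_dict is to diff the two key sets. This is a minimal sketch: diff_keys is a hypothetical helper (not part of the repo), and with a real model you would pass model.state_dict().keys() for your VAE_Encoder instance and the keys of the converted checkpoint dict for it.

```python
from typing import Iterable


def diff_keys(model_keys: Iterable[str], ckpt_keys: Iterable[str]) -> tuple[list[str], list[str]]:
    """Return (missing, unexpected), mirroring how strict loading classifies keys.

    missing:    keys the model defines but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not define
    """
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


if __name__ == "__main__":
    # Toy key lists standing in for model.state_dict().keys() and the checkpoint keys.
    model_keys = ["0.weight", "0.bias", "1.groupNorm_1.weight"]
    ckpt_keys = ["0.weight", "0.bias", "1.groupnorm_1.weight"]
    missing, unexpected = diff_keys(model_keys, ckpt_keys)
    print("missing:", missing)
    print("unexpected:", unexpected)
```

If you would rather let PyTorch do the comparison, model.load_state_dict(state_dict, strict=False) returns a named tuple whose missing_keys and unexpected_keys fields hold the same two lists. Keep in mind that strict=False also skips the mismatched weights instead of failing, so it is only suitable for diagnosis.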

parth394 commented 7 months ago

The attribute names in your model have to match the naming convention used in the state dict. For example, in the encoder you need to check whether each layer is named groupnorm or groupNorm; the names must match exactly, including capitalization.
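The capitalization check described in this reply can be done mechanically. A minimal sketch, assuming you already have the model's key list and the checkpoint's key list as plain strings; case_only_mismatches is a hypothetical helper that pairs each checkpoint key with the single model key it matches only when case is ignored.

```python
from collections import defaultdict
from typing import Iterable


def case_only_mismatches(model_keys: Iterable[str], ckpt_keys: Iterable[str]) -> dict[str, str]:
    """Map checkpoint keys to model keys that differ from them only in letter case.

    A pair such as "1.groupnorm_1.weight" vs "1.groupNorm_1.weight" points at an
    attribute that was capitalized differently in the model class.
    """
    model_set = set(model_keys)
    by_lower: dict[str, list[str]] = defaultdict(list)
    for key in model_set:
        by_lower[key.lower()].append(key)

    pairs: dict[str, str] = {}
    for ckpt_key in ckpt_keys:
        if ckpt_key in model_set:
            continue  # exact match, nothing to report
        candidates = by_lower.get(ckpt_key.lower(), [])
        if len(candidates) == 1:  # skip ambiguous cases with several candidates
            pairs[ckpt_key] = candidates[0]
    return pairs


if __name__ == "__main__":
    model_keys = ["1.groupNorm_1.weight", "1.groupNorm_1.bias", "1.conv_1.weight"]
    ckpt_keys = ["1.groupnorm_1.weight", "1.groupnorm_1.bias", "1.conv_1.weight"]
    for ckpt_key, model_key in case_only_mismatches(model_keys, ckpt_keys).items():
        print(f"checkpoint {ckpt_key!r} vs model {model_key!r}")
```

Keys that still have no counterpart after this check are not a capitalization problem, so the mismatch lies elsewhere (for example a differently named or missing layer).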

meankitdas commented 7 months ago

I got the same error as well. Try checking your UNET(nn.Module) class and SwitchSequential; you may have made a mistake there. Fixing that solved my issue. Let me know if it works for you as well.
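The keys in the error above start with an integer because PyTorch names the children of an nn.Sequential by their position ("0", "1", ...). Assuming your SwitchSequential subclasses nn.Sequential, its keys are named the same way, so grouping the mismatched keys by that leading index tells you which positions to inspect. A minimal sketch; group_by_position is a hypothetical helper.

```python
from collections import defaultdict
from typing import Iterable


def _position_sort_key(item: tuple[str, list[str]]) -> tuple[bool, int, str]:
    """Order integer positions numerically ("2" before "10"), other names after them."""
    head = item[0]
    if head.isdigit():
        return (False, int(head), head)
    return (True, 0, head)


def group_by_position(keys: Iterable[str]) -> dict[str, list[str]]:
    """Group state_dict keys by their first component, e.g. "4" in "4.conv_1.weight".

    The remainder of each key is kept so you can see which parameters of that
    child module are affected.
    """
    groups: dict[str, list[str]] = defaultdict(list)
    for key in keys:
        head, _, rest = key.partition(".")
        groups[head].append(rest)
    return dict(sorted(groups.items(), key=_position_sort_key))


if __name__ == "__main__":
    # A few of the unexpected keys from the error message in this thread.
    unexpected = ["10.conv_1.weight", "1.groupnorm_1.weight", "1.conv_1.weight", "4.conv_2.bias"]
    for position, params in group_by_position(unexpected).items():
        print(position, params)
```

Applied to the full list in the error, every affected position carries the same groupnorm_1, conv_1, groupnorm_2 and conv_2 parameters, which suggests comparing the definition of that block type in your code against the checkpoint names.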