yilundu / improved_contrastive_divergence

[ICML'21] Improved Contrastive Divergence Training of Energy Based Models

Error about load_state_dict #13

Open HoJ-Onle opened 2 years ago

HoJ-Onle commented 2 years ago

I want to test the performance of the pretrained model, but I get the following error when I load the CelebA model:

RuntimeError: Error(s) in loading state_dict for CelebAModel: Missing key(s) in state_dict: "res_1a.conv1.weight_orig", "res_1a.conv1.weight_u", "res_1a.conv1.weight_orig", "res_1a.conv1.weight_u", "res_1a.conv1.weight_v", "res_1a.conv2.weight_orig", "res_1a.conv2.weight_u", "res_1a.conv2.weight_orig", "res_1a.conv2.weight_u", "res_1a.conv2.weight_v", "res_1b.conv1.weight_orig", "res_1b.conv1.weight_u", "res_1b.conv1.weight_orig", "res_1b.conv1.weight_u", "res_1b.conv1.weight_v", "res_1b.conv2.weight_orig", "res_1b.conv2.weight_u", "res_1b.conv2.weight_orig", "res_1b.conv2.weight_u", "res_1b.conv2.weight_v", "res_2a.conv1.weight_orig", "res_2a.conv1.weight_u", "res_2a.conv1.weight_orig", "res_2a.conv1.weight_u", "res_2a.conv1.weight_v", "res_2a.conv2.weight_orig", "res_2a.conv2.weight_u", "res_2a.conv2.weight_orig", "res_2a.conv2.weight_u", "res_2a.conv2.weight_v", "res_2b.conv1.weight_orig", "res_2b.conv1.weight_u", "res_2b.conv1.weight_orig", "res_2b.conv1.weight_u", "res_2b.conv1.weight_v", "res_2b.conv2.weight_orig", "res_2b.conv2.weight_u", "res_2b.conv2.weight_orig", "res_2b.conv2.weight_u", "res_2b.conv2.weight_v", "res_3a.conv1.weight_orig", "res_3a.conv1.weight_u", "res_3a.conv1.weight_orig", "res_3a.conv1.weight_u", "res_3a.conv1.weight_v", "res_3a.conv2.weight_orig", "res_3a.conv2.weight_u", "res_3a.conv2.weight_orig", "res_3a.conv2.weight_u", "res_3a.conv2.weight_v", "res_3b.conv1.weight_orig", "res_3b.conv1.weight_u", "res_3b.conv1.weight_orig", "res_3b.conv1.weight_u", "res_3b.conv1.weight_v", "res_3b.conv2.weight_orig", "res_3b.conv2.weight_u", "res_3b.conv2.weight_orig", "res_3b.conv2.weight_u", "res_3b.conv2.weight_v", "res_4a.conv1.weight_orig", "res_4a.conv1.weight_u", "res_4a.conv1.weight_orig", "res_4a.conv1.weight_u", "res_4a.conv1.weight_v", "res_4a.conv2.weight_orig", "res_4a.conv2.weight_u", "res_4a.conv2.weight_orig", "res_4a.conv2.weight_u", "res_4a.conv2.weight_v", "res_4b.conv1.weight_orig", "res_4b.conv1.weight_u", 
"res_4b.conv1.weight_orig", "res_4b.conv1.weight_u", "res_4b.conv1.weight_v", "res_4b.conv2.weight_orig", "res_4b.conv2.weight_u", "res_4b.conv2.weight_orig", "res_4b.conv2.weight_u", "res_4b.conv2.weight_v". Unexpected key(s) in state_dict: "map_fc1.weight", "map_fc1.bias", "map_fc2.weight", "map_fc2.bias", "map_fc3.weight", "map_fc3.bias", "map_fc4.weight", "map_fc4.bias", "res_1a.conv1.weight", "res_1a.conv2.weight", "res_1b.conv1.weight", "res_1b.conv2.weight", "res_2a.conv1.weight", "res_2a.conv2.weight", "res_2b.conv1.weight", "res_2b.conv2.weight", "res_3a.conv1.weight", "res_3a.conv2.weight", "res_3b.conv1.weight", "res_3b.conv2.weight", "res_4a.conv1.weight", "res_4a.conv2.weight", "res_4b.conv1.weight", "res_4b.conv2.weight". size mismatch for res_1a.latent_map.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 2]). size mismatch for res_1a.latent_map_2.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 2]). size mismatch for res_1b.latent_map.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for res_1b.latent_map_2.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for res_2a.latent_map.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for res_2a.latent_map_2.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for res_2b.latent_map.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). 
size mismatch for res_2b.latent_map_2.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for res_3a.latent_map.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 2]). size mismatch for res_3a.latent_map_2.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 2]). size mismatch for res_3b.latent_map.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 2]). size mismatch for res_3b.latent_map_2.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 2]). size mismatch for res_4a.latent_map.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([512, 2]). size mismatch for res_4a.latent_map_2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([512, 2]). size mismatch for res_4b.latent_map.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([512, 2]). size mismatch for res_4b.latent_map_2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([512, 2]). size mismatch for mid_res_1a.latent_map.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for mid_res_1a.latent_map_2.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for mid_res_1b.latent_map.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). 
size mismatch for mid_res_1b.latent_map_2.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for mid_res_2a.latent_map.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for mid_res_2a.latent_map_2.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for mid_res_2b.latent_map.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for mid_res_2b.latent_map_2.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for mid_res_3a.latent_map.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 2]). size mismatch for mid_res_3a.latent_map_2.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 2]). size mismatch for mid_res_3b.latent_map.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 2]). size mismatch for mid_res_3b.latent_map_2.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 2]). size mismatch for small_res_1a.latent_map.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for small_res_1a.latent_map_2.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). 
size mismatch for small_res_1b.latent_map.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for small_res_1b.latent_map_2.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for small_res_2a.latent_map.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for small_res_2a.latent_map_2.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for small_res_2b.latent_map.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]). size mismatch for small_res_2b.latent_map_2.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 2]).

yilundu commented 2 years ago

Hi, are you using the model files provided in the Dropbox link? You need to overwrite the model files in the current repo, as the architecture of the pretrained models is a bit different.
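The error above has three parts: missing keys, unexpected keys, and size mismatches. The `weight_orig` / `weight_u` / `weight_v` names are the ones `torch.nn.utils.spectral_norm` registers, so the repo's model definition presumably wraps its convolutions in spectral normalization while the checkpoint was saved from a model that did not. A minimal sketch for summarizing such a mismatch before calling `load_state_dict` follows. The helper name `diff_state_shapes` is hypothetical, and the convolution and `map_fc1` shapes in the example are made up; only the parameter names and the `latent_map` shapes come from the error message.

```python
def diff_state_shapes(model_shapes, ckpt_shapes):
    """Compare two {parameter_name: shape} mappings.

    Returns (missing, unexpected, mismatched), mirroring the three parts
    of the load_state_dict error message.
    """
    missing = sorted(k for k in model_shapes if k not in ckpt_shapes)
    unexpected = sorted(k for k in ckpt_shapes if k not in model_shapes)
    mismatched = sorted(
        (k, tuple(ckpt_shapes[k]), tuple(model_shapes[k]))
        for k in model_shapes
        if k in ckpt_shapes and tuple(ckpt_shapes[k]) != tuple(model_shapes[k])
    )
    return missing, unexpected, mismatched


# Hand-written example with a few names from the error above.
model_shapes = {
    "res_1a.conv1.weight_orig": (64, 64, 3, 3),  # conv shape is made up
    "res_1a.conv1.weight_u": (64,),              # made up
    "res_1a.latent_map.weight": (64, 2),
}
ckpt_shapes = {
    "res_1a.conv1.weight": (64, 64, 3, 3),       # conv shape is made up
    "map_fc1.weight": (512, 512),                # made up
    "res_1a.latent_map.weight": (64, 512),
}

missing, unexpected, mismatched = diff_state_shapes(model_shapes, ckpt_shapes)
print("missing:", missing)
print("unexpected:", unexpected)
print("mismatched (name, checkpoint shape, model shape):", mismatched)
```

With PyTorch, the two mappings can be built as `{k: tuple(v.shape) for k, v in model.state_dict().items()}` and the same expression over the state dict stored in the checkpoint. If nearly every layer shows up in the report, as it does here, the model definition itself differs from the one used for the checkpoint, which matches the advice to use the model files shipped with the pretrained weights.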

HoJ-Onle commented 2 years ago

Hello. Thanks for your response!! I solved that problem, but now I have a new one. I followed your command to train a CIFAR-10 model myself:

python train.py --exp=cifar10_model --step_lr=100.0 --num_steps=40 --cuda --ensembles=1 --kl_coeff=1.0 --kl=True --multiscale --self_attn

But I got the result as follows:

Inception score of 1.3539612293243408 with std of 0.0
Obtained fid value of 329.3163920255062
FID of score 329.3163920255062

I chose resume_iter 20000 for the test, and I trained the model on a V100 GPU. Maybe something went wrong. How can I get the results in the paper? Looking forward to your help. Thank you!!

yilundu commented 2 years ago

Hi, did you train the model for 20000 iterations? What do the samples look like? The initial samples logged when resuming are poor because they have only been sampled for a short duration; they get better as the replay buffer evolves (that is, when training runs for longer). You can also use the test generation code for the full generative results.
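The point about samples improving as the replay buffer evolves can be illustrated with a toy sketch. It assumes the sampler is gradient-based Langevin dynamics, uses a hand-written quadratic energy in place of the learned network, and keeps every chain in the buffer without any reinitialization. The step size and step counts are illustrative values, not the repo's `--step_lr` or `--num_steps` settings.

```python
import random


def energy(x):
    # Toy stand-in for a learned energy function: E(x) = x^2 / 2.
    return 0.5 * x * x


def grad_energy(x):
    # Gradient of the toy energy above.
    return x


def langevin(xs, num_steps, step_size, rng):
    """Run num_steps of Langevin dynamics on every chain in xs."""
    out = []
    for x in xs:
        for _ in range(num_steps):
            noise = rng.gauss(0.0, 1.0)
            x = x - step_size * grad_energy(x) + (2.0 * step_size) ** 0.5 * noise
        out.append(x)
    return out


def mean_energy(xs):
    return sum(energy(x) for x in xs) / len(xs)


rng = random.Random(0)
fresh = [rng.uniform(-5.0, 5.0) for _ in range(1000)]  # chains start from noise

# Right after (re)starting: each buffer entry has seen one short sampling run.
buffer = langevin(fresh, num_steps=2, step_size=0.1, rng=rng)
early = mean_energy(buffer)

# Later: the same buffer entries are refined a little more on every iteration.
for _ in range(100):
    buffer = langevin(buffer, num_steps=2, step_size=0.1, rng=rng)
late = mean_energy(buffer)

print("mean energy from noise:", mean_energy(fresh))
print("mean energy after one short run:", early)
print("mean energy after many buffer updates:", late)
```

Each individual sampling run is short, but because the buffer carries chains over from one iteration to the next, the accumulated number of steps grows with training time. Samples logged shortly after a restart have not had that accumulated sampling yet.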

HoJ-Onle commented 2 years ago

Yup. I trained the model for 20000 iterations. The samples look very confusing. I found that the Inception Score is around 5.7 during training. I also found that the std of the energy is "inf" when I run test_inception.py for CIFAR-10. Maybe there's something wrong with the command I'm using? Looking forward to your help. Thanks!!

yilundu commented 2 years ago

Hi, you can use scripts similar to https://github.com/yilundu/improved_contrastive_divergence/blob/master/celeba_gen.py#L47 to generate samples for FID/Inception score evaluation. The score reported during training is only evaluated on a small number of samples and is a lower approximation of the actual FID/inception score.

HoJ-Onle commented 2 years ago

OK. I will try it later. Thanks!

HoJ-Onle commented 2 years ago

Hello! Thanks for your help!! I have now trained the model for enough iterations. How can I reproduce the FID/IS from the paper? I ran the following command:

python test_inception.py --exp cifar10_model --num_steps 10 --batch_size=512 --step_lr=100.0 --resume_iter=120000 --im_number=10000

And I got:

Inception score of 3.913252353668213 with std of 0.005254626274108887
Obtained fid value of 123.11683642496388
FID of score 123.11683642496388

I don't know how to set the correct hyperparameters. Looking forward to your reply!

yilundu commented 2 years ago

Does the following command work?

python test_inception.py --exp=<> --num_steps=20 --batch_size=512 --step_lr=10.0 --resume_iter=121200 --im_number=50000 --repeat_scale=30 --nomix=4 --ema

HoJ-Onle commented 2 years ago

Hello! Thanks for your reply! I didn't save a checkpoint at 121200, so I tested with 121000 and 122000 instead. But it doesn't work; the FID is even higher. I looked at the generated images and many of them were black and blurry.

yilundu commented 2 years ago

Hmm, can you plot your training curve around the checkpoint you evaluated?

HoJ-Onle commented 2 years ago

Hello! Thanks for your patient help!!! I'm sorry that I didn't notice my email notifications and didn't get back to you in time. As you suggested, I plotted my training curve for the IS. During training, my IS reached 6.46 at iteration 75000 and 6.314 at iteration 121200. The IS at iterations 121000 and 122000, which I used in my earlier tests, was also above 6. But at test time my IS was very low, only about 2. So I would like to ask whether I need to modify test_inception.py to test on the CIFAR-10 dataset. I found that the default dataset parameter in test_inception.py is cifar10, so I didn't change it. Looking forward to your help! Thank you so much!!

yilundu commented 2 years ago

Hi,

Yes, happy to help figure out what is going on. Can you look at the output images you get when you use the command below?

python test_inception.py --exp=<> --num_steps=20 --batch_size=512 --step_lr=10.0 --resume_iter=121200 --im_number=50000 --repeat_scale=30 --nomix=4 --ema

Best, Yilun

HoJ-Onle commented 2 years ago

Hello! Such a quick response! Thank you so much! Sure, I observed the output images and found that they were very dark and blurry. That's probably why I got a low IS. Have you ever encountered it?

yilundu commented 2 years ago

Hi, yeah that's a bit odd -- can you check to make sure you are loading the model that was getting inception scores of 6.3? That inception score is only calculated on 128 samples, so it should correspond to a test inception score of around 8.
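One reason a score computed on 128 samples sits below the score on a full sample set is finite-sample bias in the Inception score itself. Here is a minimal sketch of the standard formula, IS = exp(E_x KL(p(y|x) || p(y))), applied to toy predictions. The one-hot "classifier outputs" and the function names are assumptions for illustration; test_inception.py uses a real Inception network and may differ in details such as splitting the samples into groups.

```python
import math
import random


def inception_score(probs):
    """IS = exp(mean over images of KL(p(y|x) || p(y)))."""
    n, k = len(probs), len(probs[0])
    # Marginal class distribution p(y), estimated from the same samples.
    marginal = [sum(p[c] for p in probs) / n for c in range(k)]
    kl_sum = 0.0
    for p in probs:
        for c in range(k):
            if p[c] > 0.0:  # 0 * log(0) is treated as 0
                kl_sum += p[c] * math.log(p[c] / marginal[c])
    return math.exp(kl_sum / n)


def toy_predictions(n, num_classes, rng):
    # Hypothetical ideal generator: every image is confidently assigned
    # to one class, with classes drawn uniformly at random.
    preds = []
    for _ in range(n):
        label = rng.randrange(num_classes)
        preds.append([1.0 if c == label else 0.0 for c in range(num_classes)])
    return preds


rng = random.Random(0)
score_small = inception_score(toy_predictions(128, 10, rng))
score_large = inception_score(toy_predictions(50000, 10, rng))
print("IS on 128 samples:", score_small)
print("IS on 50000 samples:", score_large)
```

In this toy setup the ideal score is 10, the number of classes. The estimate from 128 samples falls further below it than the estimate from 50000 samples, because the class marginal estimated from a small batch is less balanced. This shows only the direction of the effect; the gap between roughly 6.3 during training and roughly 8 at test time is the author's empirical figure, not something this sketch reproduces.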