clovaai / stargan-v2

StarGAN v2 - Official PyTorch Implementation (CVPR 2020)

Missing key(s) in state_dict: "hpf.filter" #137

Open fanbyprinciple opened 2 years ago

fanbyprinciple commented 2 years ago
(myenv) misthios@cerebrus:~/Documents/github_projects/stargan-v2$ python main.py --mode sample --num_domains 2 --resume_iter 100000 --w_hpf 1 \
>                --checkpoint_dir expr/checkpoints/celeba_hq \
>                --result_dir expr/results/celeba_hq \
>                --src_dir assets/representative/celeba_hq/src \
>                --ref_dir assets/representative/celeba_hq/ref
Namespace(img_size=256, num_domains=2, latent_dim=16, hidden_dim=512, style_dim=64, lambda_reg=1, lambda_cyc=1, lambda_sty=1, lambda_ds=1, ds_iter=100000, w_hpf=1.0, randcrop_prob=0.5, total_iters=100000, resume_iter=100000, batch_size=8, val_batch_size=32, lr=0.0001, f_lr=1e-06, beta1=0.0, beta2=0.99, weight_decay=0.0001, num_outs_per_domain=10, mode='sample', num_workers=4, seed=777, train_img_dir='data/celeba_hq/train', val_img_dir='data/celeba_hq/val', sample_dir='expr/samples', checkpoint_dir='expr/checkpoints/celeba_hq', eval_dir='expr/eval', result_dir='expr/results/celeba_hq', src_dir='assets/representative/celeba_hq/src', ref_dir='assets/representative/celeba_hq/ref', inp_dir='assets/representative/custom/female', out_dir='assets/representative/celeba_hq/src/female', wing_path='expr/checkpoints/wing.ckpt', lm_path='expr/checkpoints/celeba_lm_mean.npz', print_every=10, sample_every=5000, save_every=10000, eval_every=50000)
Number of parameters of generator: 43467395
Number of parameters of mapping_network: 2438272
Number of parameters of style_encoder: 20916928
Number of parameters of discriminator: 20852290
Number of parameters of fan: 6333603
Initializing generator...
Initializing mapping_network...
Initializing style_encoder...
Initializing discriminator...
Preparing DataLoader for the generation phase...
Preparing DataLoader for the generation phase...
Loading checkpoint from expr/checkpoints/celeba_hq/100000_nets_ema.ckpt...
Traceback (most recent call last):
  File "/home/misthios/Documents/github_projects/stargan-v2/main.py", line 182, in <module>
    main(args)
  File "/home/misthios/Documents/github_projects/stargan-v2/main.py", line 73, in main
    solver.sample(loaders)
  File "/home/misthios/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/misthios/Documents/github_projects/stargan-v2/core/solver.py", line 178, in sample
    self._load_checkpoint(args.resume_iter)
  File "/home/misthios/Documents/github_projects/stargan-v2/core/solver.py", line 73, in _load_checkpoint
    ckptio.load(step)
  File "/home/misthios/Documents/github_projects/stargan-v2/core/checkpoint.py", line 48, in load
    module.module.load_state_dict(module_dict[name])
  File "/home/misthios/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "hpf.filter". 
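A minimal sketch of how this error arises. It assumes, as the key name suggests, that "hpf.filter" is a buffer registered on a submodule named hpf, and that the saved checkpoint simply does not contain it. ToyGenerator and HighPassToy are hypothetical stand-ins, not the repository's classes, and the 3x3 kernel is illustrative.

```python
import torch
import torch.nn as nn


class HighPassToy(nn.Module):
    """Hypothetical stand-in for a fixed (non-learned) high-pass filter."""

    def __init__(self, w_hpf: float):
        super().__init__()
        # Buffers are persistent by default, so they appear in state_dict()
        # and a strict load_state_dict() expects to find "hpf.filter".
        self.register_buffer(
            "filter",
            torch.tensor([[-1.0, -1.0, -1.0],
                          [-1.0, 8.0, -1.0],
                          [-1.0, -1.0, -1.0]]) / w_hpf,
        )


class ToyGenerator(nn.Module):
    def __init__(self, w_hpf: float = 1.0):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)
        self.hpf = HighPassToy(w_hpf)


def load_old_checkpoint() -> str:
    """Load a state dict lacking the buffer and return the error message."""
    model = ToyGenerator()
    # Simulate a checkpoint that was saved without "hpf.filter".
    old_state = {k: v for k, v in model.state_dict().items() if k != "hpf.filter"}
    try:
        model.load_state_dict(old_state)  # strict=True is the default
    except RuntimeError as err:
        return str(err)
    return ""


if __name__ == "__main__":
    print(load_old_checkpoint())
```

The printed message has the same shape as the traceback above, with ToyGenerator in place of Generator.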
424341023 commented 2 years ago

I had the same problem. Were you able to solve it?

424341023 commented 2 years ago

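I solved it by making the following change to core/checkpoint.py, line 48: replace module.module.load_state_dict(module_dict[name]) with module.module.load_state_dict(module_dict[name], False). The second argument of load_state_dict is strict, so passing False makes PyTorch skip keys that are missing from the checkpoint instead of raising an error.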
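Setting strict=False also silently hides any other mismatch between the model and the checkpoint. A slightly more cautious sketch loads non-strictly and then fails if anything other than "hpf.filter" is missing or unexpected. The helper name load_allowing and its allow-list are hypothetical, and skipping the key is only safe under the assumption that the buffer is rebuilt by the module's constructor, so its value is still correct after loading.

```python
import torch.nn as nn


def load_allowing(module: nn.Module, state_dict: dict,
                  allowed_missing=("hpf.filter",)):
    """Load non-strictly, tolerating only the listed missing keys (hypothetical helper)."""
    # load_state_dict returns a named tuple with missing_keys and unexpected_keys.
    result = module.load_state_dict(state_dict, strict=False)
    other_missing = [k for k in result.missing_keys if k not in allowed_missing]
    if other_missing or result.unexpected_keys:
        raise RuntimeError(
            f"missing: {other_missing}, unexpected: {result.unexpected_keys}"
        )
    return result
```

In core/checkpoint.py this would stand in for the original call as load_allowing(module.module, module_dict[name]).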

lawleenaja commented 1 year ago

If you're using Google Colab, add these lines to your notebook before running any main.py commands:

!cp /content/stargan-v2/core/checkpoint.py /content/stargan-v2/core/checkpoint_orig.py
!sed -i 's|module.module.load_state_dict(module_dict\[name\])|module.module.load_state_dict(module_dict[name], False)|g' /content/stargan-v2/core/checkpoint.py
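In a sed pattern, an unescaped [name] is a bracket expression matching a single character, so regex escaping is easy to get wrong. A sketch that makes the same one-line edit with a plain string replacement in Python instead; the function name patch_checkpoint is hypothetical, and the ".orig" backup suffix is an arbitrary choice.

```python
from pathlib import Path

OLD = "module.module.load_state_dict(module_dict[name])"
NEW = "module.module.load_state_dict(module_dict[name], False)"


def patch_checkpoint(path: str) -> bool:
    """Make the strict load call non-strict; return True if the file changed."""
    p = Path(path)
    src = p.read_text()
    if OLD not in src:
        return False  # already patched, or the line looks different
    Path(str(p) + ".orig").write_text(src)  # keep a backup of the original
    p.write_text(src.replace(OLD, NEW))
    return True
```

In Colab this would be called as patch_checkpoint("/content/stargan-v2/core/checkpoint.py"). Running it twice is harmless, because the patched line no longer contains the original text.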

AdityaVishwakarmaIN commented 1 year ago

> I solved it by making the following changes to core/checkpoint.py: in line 48. module.module.load_state_dict(module_dict[name]) -> module.module.load_state_dict(module_dict[name], False)

Perfect solution