clovaai / stargan-v2

StarGAN v2 - Official PyTorch Implementation (CVPR 2020)

Error when loading state_dict for Generator #103

Open kw01sg opened 3 years ago

kw01sg commented 3 years ago

When loading the pre-trained network, an error is encountered while loading the state_dict for the Generator:

$ python main.py --mode sample --num_domains 2 --resume_iter 100000 --w_hpf 1 \
               --checkpoint_dir expr/checkpoints/celeba_hq \
               --result_dir expr/results/celeba_hq \
               --src_dir assets/representative/celeba_hq/src \
               --ref_dir assets/representative/celeba_hq/ref

Namespace(batch_size=8, beta1=0.0, beta2=0.99, checkpoint_dir='expr/checkpoints/celeba_hq', ds_iter=100000, eval_dir='expr/eval', eval_every=50000, f_lr=1e-06, hidden_dim=512, img_size=256, inp_dir='assets/representative/custom/female', lambda_cyc=1, lambda_ds=1, lambda_reg=1, lambda_sty=1, latent_dim=16, lm_path='expr/checkpoints/celeba_lm_mean.npz', lr=0.0001, mode='sample', num_domains=2, num_outs_per_domain=10, num_workers=4, out_dir='assets/representative/celeba_hq/src/female', print_every=10, randcrop_prob=0.5, ref_dir='assets/representative/celeba_hq/ref', result_dir='expr/results/celeba_hq', resume_iter=100000, sample_dir='expr/samples', sample_every=5000, save_every=10000, seed=777, src_dir='assets/representative/celeba_hq/src', style_dim=64, total_iters=100000, train_img_dir='data/celeba_hq/train', val_batch_size=32, val_img_dir='data/celeba_hq/val', w_hpf=1.0, weight_decay=0.0001, wing_path='expr/checkpoints/wing.ckpt')
Number of parameters of generator: 43467395
Number of parameters of mapping_network: 2438272
Number of parameters of style_encoder: 20916928
Number of parameters of discriminator: 20852290
Number of parameters of fan: 6333603
Initializing generator...
Initializing mapping_network...
Initializing style_encoder...
Initializing discriminator...
Preparing DataLoader for the generation phase...
Preparing DataLoader for the generation phase...
Loading checkpoint from expr/checkpoints/celeba_hq/100000_nets_ema.ckpt...
Traceback (most recent call last):
  File "main.py", line 182, in <module>
    main(args)
  File "main.py", line 73, in main
    solver.sample(loaders)
  File "/home/kianboon/miniconda/envs/stargan-v2/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/kianboon/test/core/solver.py", line 178, in sample
    self._load_checkpoint(args.resume_iter)
  File "/home/kianboon/test/core/solver.py", line 73, in _load_checkpoint
    ckptio.load(step)
  File "/home/kianboon/test/core/checkpoint.py", line 48, in load
    module.module.load_state_dict(module_dict[name])
  File "/home/kianboon/miniconda/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "hpf.filter". 
kw01sg commented 3 years ago

For anyone who needs a temporary solution, reverting back to commit e28bdee will allow the pre-trained model to work.
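
For anyone unfamiliar with the revert, it is a single git command run inside a clone of this repo (note that it leaves the working tree in a detached-HEAD state):

$ git checkout e28bdee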

yzcv commented 3 years ago

> For anyone who needs a temporary solution, reverting back to commit e28bdee will allow the pre-trained model to work.

Thank you very much.

Aspen233 commented 3 years ago

I solved it by making the following change to core/checkpoint.py, line 48: `module.module.load_state_dict(module_dict[name])` -> `mod.module.load_state_dict(module_dict[name], False)`
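
For clarity, the second positional argument there is load_state_dict's strict flag, so (keeping the original variable name module) the patched line is equivalent to:

# core/checkpoint.py, line 48: strict=False skips state_dict keys that do not match
module.module.load_state_dict(module_dict[name], strict=False)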

ammar-deep commented 3 years ago

@Aspen233 what is mod here? It's not defined anywhere, so it will throw an error. I think you meant module.

yzcv commented 3 years ago

> @Aspen233 what is mod here? It's not defined anywhere, so it will throw an error. I think you meant module.

Yes, you are right. Passing False for the strict argument tells load_state_dict() to ignore non-matching keys.

We can refer to this page: https://pytorch.org/tutorials/beginner/saving_loading_models.html
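
A minimal, self-contained illustration (toy modules, not StarGAN v2 code): with strict=False, load_state_dict() also returns the keys it skipped, which is handy for verifying that only the expected key (here, hpf.filter) was affected:

import torch.nn as nn

class Old(nn.Module):
    # toy stand-in for the network the checkpoint was saved from
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

class New(nn.Module):
    # toy stand-in for the current code, which expects one extra submodule
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.extra = nn.Linear(4, 4)  # keys absent from the old checkpoint

ckpt = Old().state_dict()
model = New()
# strict=True (the default) raises RuntimeError: Missing key(s) in state_dict.
# strict=False loads the overlapping keys and returns the ones it skipped.
result = model.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # ['extra.weight', 'extra.bias']
print(result.unexpected_keys)  # []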

Aspen233 commented 3 years ago

> @Aspen233 what is mod here? It's not defined anywhere, so it will throw an error. I think you meant module.

Yeah, I meant module. I changed the name because I didn't like seeing a variable with an attribute of its own name. Sorry, I forgot to change it back in my answer.

yzcv commented 3 years ago

> > @Aspen233 what is mod here? It's not defined anywhere, so it will throw an error. I think you meant module.
>
> Yeah, I meant module. I changed the name because I didn't like seeing a variable with an attribute of its own name. Sorry, I forgot to change it back in my answer.

May I ask why there is a mismatch between the model and the checkpoint?

yzcv commented 3 years ago

> > @Aspen233 what is mod here? It's not defined anywhere, so it will throw an error. I think you meant module.
>
> Yeah, I meant module. I changed the name because I didn't like seeing a variable with an attribute of its own name. Sorry, I forgot to change it back in my answer.

Yes, I totally understand. Sometimes it is confusing.

yzcv commented 3 years ago

> > > @Aspen233 what is mod here? It's not defined anywhere, so it will throw an error. I think you meant module.
> >
> > Yeah, I meant module. I changed the name because I didn't like seeing a variable with an attribute of its own name. Sorry, I forgot to change it back in my answer.
>
> May I ask why there is a mismatch between the model and the checkpoint?

I have figured out the reason behind the mismatch: class HighPass was changed by commit f306904. Just adding the False (strict) option is fine and does not affect the results.
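
If I read that commit correctly, the filter tensor went from a plain attribute to a registered buffer, which adds hpf.filter to the keys the Generator expects but which older checkpoints never saved. A rough before/after sketch (class bodies approximated; only the buffer registration is the point):

import torch
import torch.nn as nn

KERNEL = torch.tensor([[-1, -1, -1],
                       [-1, 8., -1],
                       [-1, -1, -1]])

class HighPassBefore(nn.Module):
    # sketch of the pre-f306904 version: plain attribute, not part of state_dict
    def __init__(self, w_hpf):
        super().__init__()
        self.filter = KERNEL / w_hpf

class HighPassAfter(nn.Module):
    # sketch of the post-f306904 version: registered buffer, part of state_dict
    def __init__(self, w_hpf):
        super().__init__()
        self.register_buffer('filter', KERNEL / w_hpf)

print('filter' in HighPassBefore(1.0).state_dict())  # False
print('filter' in HighPassAfter(1.0).state_dict())   # True

Because the filter is a constant rebuilt in __init__, skipping the missing key with strict=False loses nothing, which matches the observation that the results are unaffected.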

Mahmood-Hussain commented 3 years ago

> I solved it by making the following change to core/checkpoint.py, line 48: `module.module.load_state_dict(module_dict[name])` -> `mod.module.load_state_dict(module_dict[name], False)`

This fixes one issue, but then there is another one:

$ python main.py --mode sample --num_domains 2 --resume_iter 100000 --w_hpf 1 \
               --checkpoint_dir expr/checkpoints/celeba_hq \
               --result_dir expr/results/celeba_hq \
               --src_dir assets/representative/celeba_hq/src \
               --ref_dir assets/representative/celeba_hq/ref

Namespace(batch_size=8, beta1=0.0, beta2=0.99, checkpoint_dir='expr/checkpoints/celeba_hq', ds_iter=100000, eval_dir='expr/eval', eval_every=50000, f_lr=1e-06, hidden_dim=512, img_size=256, inp_dir='assets/representative/custom/female', lambda_cyc=1, lambda_ds=1, lambda_reg=1, lambda_sty=1, latent_dim=16, lm_path='expr/checkpoints/celeba_lm_mean.npz', lr=0.0001, mode='sample', num_domains=2, num_outs_per_domain=10, num_workers=4, out_dir='assets/representative/celeba_hq/src/female', print_every=10, randcrop_prob=0.5, ref_dir='assets/representative/celeba_hq/ref', result_dir='expr/results/celeba_hq', resume_iter=100000, sample_dir='expr/samples', sample_every=5000, save_every=10000, seed=777, src_dir='assets/representative/celeba_hq/src', style_dim=64, total_iters=100000, train_img_dir='data/celeba_hq/train', val_batch_size=32, val_img_dir='data/celeba_hq/val', w_hpf=1.0, weight_decay=0.0001, wing_path='expr/checkpoints/wing.ckpt')
Number of parameters of generator: 43467395
Number of parameters of mapping_network: 2438272
Number of parameters of style_encoder: 20916928
Number of parameters of discriminator: 20852290
Number of parameters of fan: 6333603
Initializing generator...
Initializing mapping_network...
Initializing style_encoder...
Initializing discriminator...
Preparing DataLoader for the generation phase...
Preparing DataLoader for the generation phase...
Loading checkpoint from expr/checkpoints/celeba_hq/100000_nets_ema.ckpt...
Working on expr/results/celeba_hq/reference.jpg...
/home/mahmood/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/functional.py:2506: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
  "See the documentation of nn.Upsample for details.".format(mode))
Working on expr/results/celeba_hq/video_ref.mp4...
video_ref:   0%|          | 0/32 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "main.py", line 182, in <module>
    main(args)
  File "main.py", line 73, in main
    solver.sample(loaders)
  File "/home/mahmood/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/mahmood/methods/stargan-v2/core/solver.py", line 189, in sample
    utils.video_ref(nets_ema, args, src.x, ref.x, ref.y, fname)
  File "/home/mahmood/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/mahmood/methods/stargan-v2/core/utils.py", line 213, in video_ref
    interpolated = interpolate(nets, args, x_src, s_prev, s_next)
  File "/home/mahmood/methods/stargan-v2/core/utils.py", line 168, in interpolate
    x_fake = nets.generator(x_src, s_ref, masks=masks)
  File "/home/mahmood/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/mahmood/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/mahmood/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/mahmood/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/home/mahmood/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/mahmood/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/mahmood/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)

How can this be solved?

Mahmood-Hussain commented 3 years ago

Because I have multiple GPUs, it throws an error even after changing `module.module.load_state_dict(module_dict[name])` to `mod.module.load_state_dict(module_dict[name], False)`.

Now I am using a single GPU and that solved my problem: CUDA_VISIBLE_DEVICES=0 <command>
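
For example, applied to the sampling command from the original report:

$ CUDA_VISIBLE_DEVICES=0 python main.py --mode sample --num_domains 2 --resume_iter 100000 --w_hpf 1 \
               --checkpoint_dir expr/checkpoints/celeba_hq \
               --result_dir expr/results/celeba_hq \
               --src_dir assets/representative/celeba_hq/src \
               --ref_dir assets/representative/celeba_hq/ref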

nguyenviettuan96 commented 1 year ago

https://github.com/clovaai/stargan-v2/issues/103#issuecomment-911051743 Do you know why there is an error when training on multiple GPUs?

SchrodingersJade commented 1 year ago

> #103 (comment) Do you know why there is an error when training on multiple GPUs?

yunjey said that they cannot support multi-GPU training.

yunus-temurlenk commented 1 year ago

> I solved it by making the following change to core/checkpoint.py, line 48: `module.module.load_state_dict(module_dict[name])` -> `mod.module.load_state_dict(module_dict[name], False)`

This was helpful, thanks. `module.module.load_state_dict(module_dict[name], False)` worked for me.