junyanz / pytorch-CycleGAN-and-pix2pix

Image-to-Image Translation in PyTorch

Error(s) in loading state_dict for ResnetGenerator after replacing the ConvTranspose2d in ResnetGenerator #1566

Open KurobaHN opened 1 year ago

KurobaHN commented 1 year ago

I get `Error(s) in loading state_dict for ResnetGenerator` when training with the `continue_train` option after replacing the `ConvTranspose2d` layers in `ResnetGenerator`, even when calling `load_state_dict(state_dict)` with `strict=False`. Is there any way I can fix this issue?

My option settings:

```
batch_size: 1
beta1: 0.5
checkpoints_dir: ./checkpoints
continue_train: True   [default: False]
crop_size: 256
dataroot: ./datasets/test   [default: None]
dataset_mode: unaligned
direction: AtoB
display_env: main
display_freq: 400
display_id: 0
display_ncols: 4
display_port: 8097
display_server: http://localhost
display_winsize: 256
epoch: latest
epoch_count: 1
gan_mode: lsgan
gpu_ids: -1
init_gain: 0.02
init_type: normal
input_nc: 3
isTrain: True   [default: None]
lambda_A: 10.0
lambda_B: 10.0
lambda_identity: 0.5
load_iter: 0   [default: 0]
load_size: 286
lr: 0.0002
lr_decay_iters: 50
lr_policy: linear
max_dataset_size: inf
model: cycle_gan
n_epochs: 100
n_epochs_decay: 100
n_layers_D: 3
name: test   [default: experiment_name]
ndf: 64
netD: basic
netG: resnet_9blocks
ngf: 64
no_dropout: True
no_flip: False
no_html: False
norm: instance
num_threads: 4
output_nc: 3
phase: train
pool_size: 50
preprocess: resize_and_crop
print_freq: 100
save_by_iter: False
save_epoch_freq: 5
save_latest_freq: 5000
serial_batches: False
suffix:
update_html_freq: 1000
verbose: False
```

The error:

```
RuntimeError: Error(s) in loading state_dict for ResnetGenerator:
    Missing key(s) in state_dict: "model.21.weight", "model.21.bias", "model.30.weight", "model.30.bias".
    Unexpected key(s) in state_dict: "model.19.weight", "model.19.bias", "model.22.weight", "model.22.bias".
    size mismatch for model.26.weight: copying a param with shape torch.Size([3, 64, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 128, 3, 3]).
    size mismatch for model.26.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([64]).
```

taesungp commented 1 year ago

The `strict=False` option can deal with missing or unexpected parameters, but it cannot deal with a size mismatch. In your case, `model.26.weight` has shape `[3, 64, 7, 7]` in the checkpoint, but the current model architecture requires it to be `[64, 128, 3, 3]`.
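A minimal sketch of this behavior (the two `nn.Linear` layers here are stand-ins, not the actual `ResnetGenerator`): `strict=False` silences missing/unexpected keys, but a shape mismatch still raises.

```python
import torch.nn as nn

src = nn.Linear(4, 3)   # plays the role of the checkpointed layer: weight [3, 4]
dst = nn.Linear(8, 3)   # plays the role of the modified layer: weight [3, 8]

try:
    # strict=False does NOT help here: load_state_dict still raises a
    # RuntimeError when a parameter with a matching key has the wrong shape.
    dst.load_state_dict(src.state_dict(), strict=False)
except RuntimeError as e:
    print("still fails with strict=False:", type(e).__name__)
```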

There is no good way to resolve this, because it means they are two different layers that are not meant to be loaded from each other. If you'd still like to load the rest of the weights from the checkpoint, you can do so by removing the problematic weights from the state dict. For example, after you load the state_dict from the checkpoint, remove `model.26.weight` and `model.26.bias` from the dictionary before calling `load_state_dict`.
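This filtering step can be sketched as follows. `OldNet` and `NewNet` below are hypothetical stand-ins for the checkpointed and modified architectures (in the actual repo you would use `ResnetGenerator` and the saved `.pth` file); the key idea is to keep only checkpoint entries whose shapes agree with the current model, then load with `strict=False`.

```python
import torch.nn as nn

# Hypothetical stand-in for the architecture that produced the checkpoint.
class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(64, 3, 7)    # shape stored in the checkpoint

# Hypothetical stand-in for the modified architecture.
class NewNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(128, 64, 3)  # shape changed by the edit
        self.extra = nn.Conv2d(64, 64, 3)  # new layer, absent from the checkpoint

state_dict = OldNet().state_dict()         # in practice: torch.load(checkpoint_path)
model = NewNet()

# Drop every checkpoint entry whose key or shape disagrees with the current model.
model_sd = model.state_dict()
filtered = {k: v for k, v in state_dict.items()
            if k in model_sd and v.shape == model_sd[k].shape}

# strict=False tolerates the keys we dropped; those layers keep their
# freshly initialized weights. Inspect what was skipped:
incompatible = model.load_state_dict(filtered, strict=False)
print("not loaded from checkpoint:", incompatible.missing_keys)
```

The layers that were filtered out start from their random initialization, so some fine-tuning is usually needed after loading.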