yuval-alaluf / stylegan3-editing

Official Implementation of "Third Time's the Charm? Image and Video Editing with StyleGAN3" (AIM ECCVW 2022) https://arxiv.org/abs/2201.13433
https://yuval-alaluf.github.io/stylegan3-editing/
MIT License

Preparing our Restyle_psp_encoder #29

Closed: rut00 closed this issue 2 years ago

rut00 commented 2 years ago

We are training our ReStyle pSp encoder (Restyle_psp_encoder) on a custom dataset.

We trained our StyleGAN3 network with the StyleGAN3-T (translation-equivariant) config and then converted the resulting .pkl file to a .pt file using the snippet provided here: #16 (issue comment)

After running train_restyle_psp.py with the command below:

!python train_restyle_psp.py \
--dataset_type=my_encode \
--encoder_type=BackboneEncoder \
--exp_dir=experiment/restyle_psp_my_encode \
--workers=8 \
--batch_size=8 \
--test_batch_size=8 \
--test_workers=8 \
--val_interval=5000 \
--save_interval=10000 \
--start_from_latent_avg 'True' \
--lpips_lambda=0.8 \
--l2_lambda=1 \
--w_norm_lambda=0 \
--id_lambda=0.1 \
--input_nc=6 \
--n_iters_per_batch=5 \
--output_size=128 \
--stylegan_weights=/content/drive/MyDrive/myfile.pt

I get the following error:

Loading StyleGAN3 generator from path: /content/drive/MyDrive/myfile.pt
Traceback (most recent call last):
  File "./models/stylegan3/model.py", line 61, in _load_checkpoint
    self.decoder.load_state_dict(torch.load(checkpoint_path), strict=True)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1498, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "synthesis.L0_36_1024.weight", "synthesis.L0_36_1024.bias", "synthesis.L0_36_1024.magnitude_ema", "synthesis.L0_36_1024.up_filter", "synthesis.L0_36_1024.down_filter", "synthesis.L0_36_1024.affine.weight", "synthesis.L0_36_1024.affine.bias", "synthesis.L1_36_1024.weight", "synthesis.L1_36_1024.bias", "synthesis.L1_36_1024.magnitude_ema", "synthesis.L1_36_1024.up_filter", "synthesis.L1_36_1024.down_filter", "synthesis.L1_36_1024.affine.weight", "synthesis.L1_36_1024.affine.bias", "synthesis.L2_52_1024.weight", "synthesis.L2_52_1024.bias", "synthesis.L2_52_1024.magnitude_ema", "synthesis.L2_52_1024.up_filter", "synthesis.L2_52_1024.down_filter", "synthesis.L2_52_1024.affine.weight", "synthesis.L2_52_1024.affine.bias", "synthesis.L3_52_1024.weight", "synthesis.L3_52_1024.bias", "synthesis.L3_52_1024.magnitude_ema", "synthesis.L3_52_1024.up_filter", "synthesis.L3_52_1024.down_filter", "synthesis.L3_52_1024.affine.weight", "synthesis.L3_52_1024.affine.bias", "synthesis.L4_84_1024.weight", "synthesis.L4_84_1024.bias", "synthesis.L4_84_1024.magnitude_ema", "synthesis.L4_84_1024.up_filter", "synthesis.L4_84_1024.down_filter", "synthesis.L4_84_1024.affine.weight", "synthesis.L4_84_1024.affine.bias", "synthesis.L5_148_1024.weight", "synthesis.L5_148_1024.bias", "synthesis.L5_148_1024.magnitude_ema", "synthesis.L5_148_1024.up_filter", "synthesis.L5_148_1024.down_filter", "synthesis.L5_148_1024.affine.weight", "synthesis.L5_148_1024.affine.bias", "synthesis.L6_148_1024.weight", "synthesis.L6_148_1024.bias", "synthesis.L6_148_1024.magnitude_ema", "synthesis.L6_148_1024.up_filter", "synthesis.L6_148_1024.down_filter", "synthesis.L6_148_1024.affine.weight", "synthesis.L6_148_1024.affine.bias", "synthesis.L7_276_645.weight", "synthesis.L7_276_645.bias", "synthesis.L7_276_645.magnitude_ema", "synthesis.L7_276_645.up_filter", "synthesis.L7_276_645.down_filter", "synthesis.L7_276_645.affine.weight", "synthesis.L7_276_645.affine.bias", 
"synthesis.L8_276_406.weight", "synthesis.L8_276_406.bias", "synthesis.L8_276_406.magnitude_ema", "synthesis.L8_276_406.up_filter", "synthesis.L8_276_406.down_filter", "synthesis.L8_276_406.affine.weight", "synthesis.L8_276_406.affine.bias", "synthesis.L9_532_256.weight", "synthesis.L9_532_256.bias", "synthesis.L9_532_256.magnitude_ema", "synthesis.L9_532_256.up_filter", "synthesis.L9_532_256.down_filter", "synthesis.L9_532_256.affine.weight", "synthesis.L9_532_256.affine.bias", "synthesis.L10_1044_161.weight", "synthesis.L10_1044_161.bias", "synthesis.L10_1044_161.magnitude_ema", "synthesis.L10_1044_161.up_filter", "synthesis.L10_1044_161.down_filter", "synthesis.L10_1044_161.affine.weight", "synthesis.L10_1044_161.affine.bias", "synthesis.L11_1044_102.weight", "synthesis.L11_1044_102.bias", "synthesis.L11_1044_102.magnitude_ema", "synthesis.L11_1044_102.up_filter", "synthesis.L11_1044_102.down_filter", "synthesis.L11_1044_102.affine.weight", "synthesis.L11_1044_102.affine.bias", "synthesis.L12_1044_64.weight", "synthesis.L12_1044_64.bias", "synthesis.L12_1044_64.magnitude_ema", "synthesis.L12_1044_64.up_filter", "synthesis.L12_1044_64.down_filter", "synthesis.L12_1044_64.affine.weight", "synthesis.L12_1044_64.affine.bias", "synthesis.L13_1024_64.weight", "synthesis.L13_1024_64.bias", "synthesis.L13_1024_64.magnitude_ema", "synthesis.L13_1024_64.up_filter", "synthesis.L13_1024_64.down_filter", "synthesis.L13_1024_64.affine.weight", "synthesis.L13_1024_64.affine.bias", "synthesis.L14_1024_3.weight", "synthesis.L14_1024_3.bias", "synthesis.L14_1024_3.magnitude_ema", "synthesis.L14_1024_3.affine.weight", "synthesis.L14_1024_3.affine.bias". 
    Unexpected key(s) in state_dict: "synthesis.L0_36_512.weight", "synthesis.L0_36_512.bias", "synthesis.L0_36_512.magnitude_ema", "synthesis.L0_36_512.up_filter", "synthesis.L0_36_512.down_filter", "synthesis.L0_36_512.affine.weight", "synthesis.L0_36_512.affine.bias", "synthesis.L1_36_512.weight", "synthesis.L1_36_512.bias", "synthesis.L1_36_512.magnitude_ema", "synthesis.L1_36_512.up_filter", "synthesis.L1_36_512.down_filter", "synthesis.L1_36_512.affine.weight", "synthesis.L1_36_512.affine.bias", "synthesis.L2_36_512.weight", "synthesis.L2_36_512.bias", "synthesis.L2_36_512.magnitude_ema", "synthesis.L2_36_512.up_filter", "synthesis.L2_36_512.down_filter", "synthesis.L2_36_512.affine.weight", "synthesis.L2_36_512.affine.bias", "synthesis.L3_52_512.weight", "synthesis.L3_52_512.bias", "synthesis.L3_52_512.magnitude_ema", "synthesis.L3_52_512.up_filter", "synthesis.L3_52_512.down_filter", "synthesis.L3_52_512.affine.weight", "synthesis.L3_52_512.affine.bias", "synthesis.L4_52_512.weight", "synthesis.L4_52_512.bias", "synthesis.L4_52_512.magnitude_ema", "synthesis.L4_52_512.up_filter", "synthesis.L4_52_512.down_filter", "synthesis.L4_52_512.affine.weight", "synthesis.L4_52_512.affine.bias", "synthesis.L5_84_512.weight", "synthesis.L5_84_512.bias", "synthesis.L5_84_512.magnitude_ema", "synthesis.L5_84_512.up_filter", "synthesis.L5_84_512.down_filter", "synthesis.L5_84_512.affine.weight", "synthesis.L5_84_512.affine.bias", "synthesis.L6_84_512.weight", "synthesis.L6_84_512.bias", "synthesis.L6_84_512.magnitude_ema", "synthesis.L6_84_512.up_filter", "synthesis.L6_84_512.down_filter", "synthesis.L6_84_512.affine.weight", "synthesis.L6_84_512.affine.bias", "synthesis.L7_148_512.weight", "synthesis.L7_148_512.bias", "synthesis.L7_148_512.magnitude_ema", "synthesis.L7_148_512.up_filter", "synthesis.L7_148_512.down_filter", "synthesis.L7_148_512.affine.weight", "synthesis.L7_148_512.affine.bias", "synthesis.L8_148_512.weight", "synthesis.L8_148_512.bias", 
"synthesis.L8_148_512.magnitude_ema", "synthesis.L8_148_512.up_filter", "synthesis.L8_148_512.down_filter", "synthesis.L8_148_512.affine.weight", "synthesis.L8_148_512.affine.bias", "synthesis.L9_148_362.weight", "synthesis.L9_148_362.bias", "synthesis.L9_148_362.magnitude_ema", "synthesis.L9_148_362.up_filter", "synthesis.L9_148_362.down_filter", "synthesis.L9_148_362.affine.weight", "synthesis.L9_148_362.affine.bias", "synthesis.L10_276_256.weight", "synthesis.L10_276_256.bias", "synthesis.L10_276_256.magnitude_ema", "synthesis.L10_276_256.up_filter", "synthesis.L10_276_256.down_filter", "synthesis.L10_276_256.affine.weight", "synthesis.L10_276_256.affine.bias", "synthesis.L11_276_181.weight", "synthesis.L11_276_181.bias", "synthesis.L11_276_181.magnitude_ema", "synthesis.L11_276_181.up_filter", "synthesis.L11_276_181.down_filter", "synthesis.L11_276_181.affine.weight", "synthesis.L11_276_181.affine.bias", "synthesis.L12_276_128.weight", "synthesis.L12_276_128.bias", "synthesis.L12_276_128.magnitude_ema", "synthesis.L12_276_128.up_filter", "synthesis.L12_276_128.down_filter", "synthesis.L12_276_128.affine.weight", "synthesis.L12_276_128.affine.bias", "synthesis.L13_256_128.weight", "synthesis.L13_256_128.bias", "synthesis.L13_256_128.magnitude_ema", "synthesis.L13_256_128.up_filter", "synthesis.L13_256_128.down_filter", "synthesis.L13_256_128.affine.weight", "synthesis.L13_256_128.affine.bias", "synthesis.L14_256_3.weight", "synthesis.L14_256_3.bias", "synthesis.L14_256_3.magnitude_ema", "synthesis.L14_256_3.affine.weight", "synthesis.L14_256_3.affine.bias". 
    size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
    size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
    size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "inversion/scripts/train_restyle_psp.py", line 29, in <module>
    main()
  File "/usr/local/lib/python3.7/dist-packages/pyrallis/argparsing.py", line 158, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "inversion/scripts/train_restyle_psp.py", line 24, in main
    coach = Coach(opts)
  File "./inversion/training/coach_restyle_psp.py", line 36, in __init__
    self.net = pSp(self.opts).to(self.device)
  File "./inversion/models/psp3.py", line 20, in __init__
    self.load_weights()
  File "./inversion/models/psp3.py", line 42, in load_weights
    self.decoder = SG3Generator(checkpoint_path=self.opts.stylegan_weights).decoder
  File "./models/stylegan3/model.py", line 56, in __init__
    self._load_checkpoint(checkpoint_path)
  File "./models/stylegan3/model.py", line 65, in _load_checkpoint
    self.decoder.load_state_dict(ckpt, strict=False)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1498, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
    size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
    size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
    size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).

I appreciate your help. Thank you.

yuval-alaluf commented 2 years ago

What images does your generator output? This error could occur because you defined a Generator object with an output resolution of 1024, but your checkpoint holds the weights of a generator with an output resolution of 512. The error occurs in the following line:
https://github.com/yuval-alaluf/stylegan3-editing/blob/ab01a5d90b8ba67e0da0e1388f0931482601006c/inversion/models/psp3.py#L42

If the problem is the generator's output resolution, you can make the following change. In the definition of SG3Generator, we can specify the output resolution with res:
https://github.com/yuval-alaluf/stylegan3-editing/blob/ab01a5d90b8ba67e0da0e1388f0931482601006c/models/stylegan3/model.py#L19-L21

Therefore, you can simply change line 42 in psp3.py to (setting res to your generator's actual output resolution):

 self.decoder = SG3Generator(checkpoint_path=self.opts.stylegan_weights, res=512).decoder 
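
One way to pick the right res value, sketched here as an assumption rather than as part of the repo: StyleGAN3 names each synthesis layer L{idx}_{size}_{channels}, and the last layer's size equals the output resolution. The helper infer_output_resolution is hypothetical. Applied to the unexpected keys in the traceback above, the last layer is synthesis.L14_256_3, which suggests a 256x256 generator; the 512 in names such as L0_36_512 is a channel count, not a resolution.

```python
import re
from typing import Iterable, Optional

# Hypothetical helper. StyleGAN3 synthesis layers are named
# "L{idx}_{out_size}_{out_channels}"; the highest-index layer's out_size
# is the generator's output resolution.
_LAYER_RE = re.compile(r"^synthesis\.L(\d+)_(\d+)_(\d+)\.")


def infer_output_resolution(state_dict_keys: Iterable[str]) -> int:
    """Return the output resolution encoded in the last synthesis layer's name."""
    last_idx: int = -1
    last_size: Optional[int] = None
    for key in state_dict_keys:
        match = _LAYER_RE.match(key)
        if match is None:
            continue  # skip mapping.*, synthesis.input.*, etc.
        idx, size = int(match.group(1)), int(match.group(2))
        if idx > last_idx:
            last_idx, last_size = idx, size
    if last_size is None:
        raise ValueError("no StyleGAN3 synthesis layer keys found")
    return last_size


if __name__ == "__main__":
    # A few keys copied from the "Unexpected key(s)" list in the traceback.
    keys = [
        "synthesis.L0_36_512.weight",
        "synthesis.L13_256_128.weight",
        "synthesis.L14_256_3.affine.bias",
    ]
    print(infer_output_resolution(keys))  # 256
```

With a real file, the keys would come from torch.load(checkpoint_path).keys(), assuming the .pt file stores a plain state_dict, as the load_state_dict(torch.load(checkpoint_path)) call in the traceback implies.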

If this does not solve your problem, it could be because we use the R config by default, while you trained your model with the T config. In that case, you can try specifying config=landscape when defining your SG3Generator. The landscape generator also uses the T config, so this should match your generator.
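
Whether config=landscape is needed can likewise be read from synthesis.input.weight, whose shape is [C, C], where C is the channel count of the first synthesis layer. This sketch assumes the checkpoint was trained with StyleGAN3's default channel settings (channel_max of 512 for the T config, doubled to 1024 for the R config); suggest_config and its mapping are hypothetical helpers, not part of stylegan3-editing.

```python
from typing import Optional, Sequence

# Hypothetical mapping, assuming default StyleGAN3 channel settings:
# T config caps channels at 512, R config at 1024.
# None means "do not pass config", i.e. keep SG3Generator's default (R config).
CONFIG_BY_INPUT_CHANNELS = {
    512: "landscape",  # T config, per the reply above
    1024: None,        # R config (the repo's default)
}


def suggest_config(input_weight_shape: Sequence[int]) -> Optional[str]:
    """Map the shape of synthesis.input.weight to a config value for SG3Generator."""
    channels = int(input_weight_shape[0])
    if channels not in CONFIG_BY_INPUT_CHANNELS:
        raise ValueError(f"unrecognized input channel count: {channels}")
    return CONFIG_BY_INPUT_CHANNELS[channels]


if __name__ == "__main__":
    # Shape reported for the checkpoint in the size-mismatch errors above.
    print(suggest_config((512, 512)))  # landscape
```

For the checkpoint in this issue, the size-mismatch errors report [512, 512], so the sketch returns "landscape". Combined with the resolution fix, that corresponds to passing both res and config="landscape" to SG3Generator in psp3.py.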

Hope this helps.

rut00 commented 2 years ago

With the above methods, the issue got resolved. Thank you for the help.