omertov / encoder4editing

Official implementation of "Designing an Encoder for StyleGAN Image Manipulation" (SIGGRAPH 2021) https://arxiv.org/abs/2102.02766
MIT License

Error when trying to use encoder trained on own dataset #27

Closed: Alen95 closed this issue 3 years ago

Alen95 commented 3 years ago

After training the encoder on my own dataset and trying to use it for inference, I get the following error:

Loading e4e over the pSp framework from checkpoint: e4e_ffhq_encode.pt
Traceback (most recent call last):
  File "scripts/train.py", line 88, in <module>
    main()
  File "scripts/train.py", line 28, in main
    coach = Coach(opts, previous_train_ckpt)
  File "./training/coach.py", line 39, in __init__
    self.net = pSp(self.opts).to(self.device)
  File "./models/psp.py", line 28, in __init__
    self.load_weights()
  File "./models/psp.py", line 43, in load_weights
    self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
  File "/opt/conda/envs/e4e_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1045, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Encoder4Editing:
    Unexpected key(s) in state_dict: "styles.16.convs.0.weight", "styles.16.convs.0.bias", "styles.16.convs.2.weight", "styles.16.convs.2.bias", "styles.16.convs.4.weight", "styles.16.convs.4.bias", "styles.16.convs.6.weight", "styles.16.convs.6.bias", "styles.16.convs.8.weight", "styles.16.convs.8.bias", "styles.16.convs.10.weight", "styles.16.convs.10.bias", "styles.16.linear.weight", "styles.16.linear.bias", "styles.17.convs.0.weight", "styles.17.convs.0.bias", "styles.17.convs.2.weight", "styles.17.convs.2.bias", "styles.17.convs.4.weight", "styles.17.convs.4.bias", "styles.17.convs.6.weight", "styles.17.convs.6.bias", "styles.17.convs.8.weight", "styles.17.convs.8.bias", "styles.17.convs.10.weight", "styles.17.convs.10.bias", "styles.17.linear.weight", "styles.17.linear.bias".

Has anyone faced the same problem, or does anyone have hints that might help solve it?

omertov commented 3 years ago

Hi @Alen95, at first glance it looks like you are trying to load a different checkpoint from the one you produced: the checkpoint's name is the same as the official one we published, so perhaps it is not the one you intend to load?

The error probably occurs because the checkpoint contains an encoder for a 1024-resolution StyleGAN, while in the opts object opts.stylegan_size=512 (hence the extra layers 16 and 17). I will look into it and check whether there is a bug in loading a saved checkpoint.
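The layer indices in both error messages in this thread can be checked with a small calculation. This is a minimal sketch, assuming the e4e encoder has one style block per W+ entry, i.e. 2 * log2(stylegan_size) - 2 blocks named styles.0 onward. That count is an assumption about the encoder, not something stated in the thread, but it is consistent with both tracebacks. The helper names are hypothetical.

```python
import math


def num_style_blocks(stylegan_size):
    """Number of styles.N blocks, assuming 2 * log2(size) - 2 (one per W+ entry)."""
    log_size = int(math.log2(stylegan_size))
    if 2 ** log_size != stylegan_size:
        raise ValueError(f"stylegan_size must be a power of two, got {stylegan_size}")
    return 2 * log_size - 2


def unexpected_style_indices(ckpt_size, model_size):
    """Indices of styles.N blocks present in the checkpoint but missing in the model."""
    return list(range(num_style_blocks(model_size), num_style_blocks(ckpt_size)))


for ckpt_size, model_size in [(1024, 512), (512, 256)]:
    print(ckpt_size, "->", model_size, ":", unexpected_style_indices(ckpt_size, model_size))
```

Under this assumption, a 1024 checkpoint loaded into a 512 model leaves styles.16 and styles.17 unmatched, and a 512 checkpoint loaded into a 256 model leaves styles.14 and styles.15, which are exactly the unexpected keys reported here.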

Alen95 commented 3 years ago

I followed the Training guide in the repository, using the proposed commands, to train an encoder on my own dataset. I trained for approximately 50,000 steps, took the last checkpoint to run inference, and at that point got the error above.

omertov commented 3 years ago

Can you elaborate on the inference procedure? When you train the model for 50k iterations, it should output checkpoints named best_model.pt and iteration_50000.pt under the <exp_dir>/checkpoints folder.

If you are using this repo's inference script, you should pass one of the checkpoints above. Judging from the log you posted, the script currently loads the official e4e_ffhq_encode.pt checkpoint instead. Can you share the inference script or command you are using?

Alen95 commented 3 years ago

Yes, the checkpoints can be found in the mentioned folder.

The training command:

python scripts/train.py --dataset_type my_data_encode --exp_dir new/experiment/directory --start_from_latent_avg --use_w_pool --w_discriminator_lambda 0.1 --progressive_start 20000 --id_lambda 0.5 --val_interval 10000 --max_steps 50000 --stylegan_size 512 --stylegan_weights pretrained_models/stylegan2-ffhq-config-f.pt --workers 8 --batch_size 8 --test_batch_size 4 --test_workers 4

The inference command:

python scripts/inference.py --images_dir=notebooks/images --save_dir=output best_model.pt

I always move the checkpoint to the root directory, which is why I pass "best_model.pt".

The latest error message:

Loading e4e over the pSp framework from checkpoint: best_model.pt
Traceback (most recent call last):
  File "scripts/inference.py", line 134, in <module>
    main(args)
  File "scripts/inference.py", line 22, in main
    net, opts = setup_model(args.ckpt, device)
  File "./utils/model_utils.py", line 24, in setup_model
    net = pSp(opts)
  File "./models/psp.py", line 28, in __init__
    self.load_weights()
  File "./models/psp.py", line 43, in load_weights
    self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Encoder4Editing:
    Unexpected key(s) in state_dict: "styles.14.convs.0.weight", "styles.14.convs.0.bias", "styles.14.convs.2.weight", "styles.14.convs.2.bias", "styles.14.convs.4.weight", "styles.14.convs.4.bias", "styles.14.convs.6.weight", "styles.14.convs.6.bias", "styles.14.convs.8.weight", "styles.14.convs.8.bias", "styles.14.convs.10.weight", "styles.14.convs.10.bias", "styles.14.linear.weight", "styles.14.linear.bias", "styles.15.convs.0.weight", "styles.15.convs.0.bias", "styles.15.convs.2.weight", "styles.15.convs.2.bias", "styles.15.convs.4.weight", "styles.15.convs.4.bias", "styles.15.convs.6.weight", "styles.15.convs.6.bias", "styles.15.convs.8.weight", "styles.15.convs.8.bias", "styles.15.convs.10.weight", "styles.15.convs.10.bias", "styles.15.linear.weight", "styles.15.linear.bias".

omertov commented 3 years ago

Hi @Alen95, it looks like progress was made, as the log now starts with "Loading e4e ... from checkpoint: best_model.pt". One thing I notice is that you train the encoder with --stylegan_size 512 --stylegan_weights pretrained_models/stylegan2-ffhq-config-f.pt, which I am not sure is correct, since the pretrained FFHQ StyleGAN is of size 1024.

As for the error, it seems the e4e encoder is being initialized for a 256-resolution StyleGAN, whereas the checkpoint contains an e4e encoder trained for a 512-resolution StyleGAN (hence the extra layers 14 and 15). To solve this, make sure that opts.stylegan_size=512 in the inference script (you can add a debug print to check it).
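Besides a debug print, the mismatch can be caught before pSp is built by comparing opts.stylegan_size with the resolution implied by the checkpoint's encoder keys. This is a minimal, torch-free sketch. It assumes the style-block count is 2 * log2(size) - 2 and that keys look like styles.N... (optionally prefixed with encoder.). Both are assumptions that happen to match the tracebacks above. infer_stylegan_size and check_stylegan_size are hypothetical helpers, not part of the repo.

```python
import re

# Matches "styles.<N>." with an optional "encoder." prefix (assumed key layout).
STYLE_KEY = re.compile(r"^(?:encoder\.)?styles\.(\d+)\.")


def infer_stylegan_size(state_dict_keys):
    """Infer the StyleGAN resolution from the number of styles.N blocks.

    Assumes n_styles = 2 * log2(size) - 2, so size = 2 ** ((n_styles + 2) / 2).
    """
    indices = set()
    for key in state_dict_keys:
        match = STYLE_KEY.match(key)
        if match:
            indices.add(int(match.group(1)))
    if not indices:
        raise ValueError("no 'styles.N' keys found in the state dict")
    n_styles = max(indices) + 1
    if n_styles % 2:  # 2 * log2(size) - 2 is always even
        raise ValueError(f"unexpected number of style blocks: {n_styles}")
    return 2 ** ((n_styles + 2) // 2)


def check_stylegan_size(opts_size, state_dict_keys):
    """Raise a readable error if opts and checkpoint disagree on the resolution."""
    ckpt_size = infer_stylegan_size(state_dict_keys)
    if ckpt_size != opts_size:
        raise RuntimeError(
            f"opts.stylegan_size={opts_size}, but the checkpoint's encoder has "
            f"style blocks for stylegan_size={ckpt_size}"
        )
    return ckpt_size
```

In practice the key list would come from the loaded checkpoint's encoder weights. Where exactly those weights live inside the checkpoint file is an assumption to verify against models/psp.py before relying on this check.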

Can you send the contents of the opts.json file in the experiment directory? Also, as a side note: if there is already a latents.pt file in the save_dir, I recommend deleting it for a clean inference run.
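That clean-up step can be scripted. This is a minimal sketch, assuming save_dir is the directory passed as --save_dir and that the leftover file is named latents.pt, as in the comment above. Presumably the file is removed so that stale latents from an earlier run are not reused. clear_stale_latents is a hypothetical helper.

```python
import os


def clear_stale_latents(save_dir, filename="latents.pt"):
    """Delete a leftover latents file in save_dir; return True if one was removed."""
    path = os.path.join(save_dir, filename)
    if os.path.isfile(path):
        os.remove(path)
        return True
    return False
```

Call it with the value of --save_dir (output in the command above) before running the inference script.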

Best, Omer

AleksiKnuutila commented 3 years ago

I had the same problem. It seems that model_utils.py infers stylegan_size from the dataset_type string, which led to a wrong value in my case.

omertov commented 3 years ago

Thank you @AleksiKnuutila. Overriding stylegan_size this way can lead to confusing errors like these, so I will remove it.
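Removing the override amounts to trusting the stylegan_size that was saved with the training options. This is a minimal sketch of that idea. It assumes the checkpoint is a dict whose training options sit under an "opts" key (the same options written to opts.json). build_inference_opts is a hypothetical helper, not the repo's setup_model. It also includes the debug print suggested earlier in the thread.

```python
import argparse


def build_inference_opts(ckpt, checkpoint_path, device="cuda"):
    """Turn the training options stored in a checkpoint into an argparse.Namespace.

    Assumption: ckpt is the dict returned by torch.load() and keeps the training
    options as a plain dict under ckpt["opts"].
    """
    opts = dict(ckpt["opts"])  # shallow copy: do not mutate the loaded checkpoint
    if "stylegan_size" not in opts:
        raise KeyError("saved opts lack 'stylegan_size'; set it explicitly")
    # Deliberately NOT re-deriving stylegan_size from dataset_type: a custom name
    # such as "my_data_encode" says nothing about the generator's resolution.
    opts["checkpoint_path"] = checkpoint_path
    opts["device"] = device
    ns = argparse.Namespace(**opts)
    # Debug print: confirm the size the encoder will be built for.
    print(f"dataset_type={getattr(ns, 'dataset_type', None)!r}, "
          f"stylegan_size={ns.stylegan_size}")
    return ns


if __name__ == "__main__":
    # Stand-in for the dict a real torch.load("best_model.pt", ...) would return.
    fake_ckpt = {"opts": {"dataset_type": "my_data_encode", "stylegan_size": 512}}
    build_inference_opts(fake_ckpt, "best_model.pt", device="cpu")
```

The design choice is that the options saved at training time are the single source of truth. A heuristic based on dataset_type can only agree with them or silently contradict them, which is the failure mode reported in this thread.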