yuval-alaluf / stylegan3-editing

Official Implementation of "Third Time's the Charm? Image and Video Editing with StyleGAN3" (AIM ECCVW 2022) https://arxiv.org/abs/2201.13433
https://yuval-alaluf.github.io/stylegan3-editing/
MIT License

Error due to load_weight function in models/e4e.py scripts #53

Open YKJamesMoriarty opened 4 days ago

YKJamesMoriarty commented 4 days ago

My questions concern the load_weights function:

def load_weights(self):
    if self.opts.checkpoint_path is not None:
        print(f'Loading ReStyle e4e from checkpoint: {self.opts.checkpoint_path}')
        ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
        self.encoder.load_state_dict(self.__get_keys(ckpt, 'encoder'), strict=False)  # False
        self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True)  # True

        self.__load_latent_avg(ckpt)
        print(f'Loading checkpoint from {self.opts.checkpoint_path} successfully!')
    else:
        encoder_ckpt = self.__get_encoder_checkpoint()
        self.encoder.load_state_dict(encoder_ckpt, strict=False)
        print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}')
        ckpt = torch.load(self.opts.stylegan_weights)
        self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
        self.__load_latent_avg(ckpt, repeat=self.n_styles)

1. Training with the checkpoint_path parameter

First, I would like to use the weights of your already trained model restyle_e4e_ffhq.pt, for which a download link is provided in readme.md, as the initialization when I retrain the model. I believe this will speed up my training. However, when I run train_restyle_e4e.py with the command below:

python scripts/train_restyle_e4e.py \
--dataset_type my_data_encode \
--encoder_type ResNetProgressiveBackboneEncoder \
--exp_dir=experiment/restyle_e4e_NTF_encode \
--batch_size 2 \
--test_batch_size 2 \
--workers 8 \
--test_workers 8 \
--val_interval 5000 \
--save_interval 10000 \
--start_from_latent_avg \
--lpips_lambda 0.8 \
--l2_lambda 1 \
--id_lambda 0.1 \
--input_nc 6 \
--n_iters_per_batch 3 \
--output_size 1024 \
--save_training_data \
--stylegan_weights pretrained_models/sg3-r-ffhq-1024.pt \
--checkpoint_path pretrained_models/restyle_e4e_ffhq.pt

it raises the following error:

Traceback (most recent call last):
  File "scripts/train_restyle_e4e.py", line 86, in <module>
    main()
  File "scripts/train_restyle_e4e.py", line 28, in main
    coach = Coach(opts, previous_train_ckpt)
  File "./training/coach_restyle_e4e.py", line 34, in __init__
    self.net = e4e(self.opts).to(self.device)
  File "./models/e4e.py", line 25, in __init__
    self.load_weights()
  File "./models/e4e.py", line 49, in load_weights
    self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True)#True
  File "/root/miniconda3/envs/sg3_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
        Missing key(s) in state_dict: "style.1.weight", "style.1.bias", "style.2.weight", "style.2.bias", "style.3.weight", "style.3.bias", "style.4.weight", "style.4.bias", "style.5.weight", "style.5.bias", "style.6.weight", "style.6.bias", "style.7.weight", "style.7.bias" (hundreds more missing keys follow, so I have not pasted them all here)

I don't know why this happens. Perhaps I am using the checkpoint_path parameter incorrectly, for example if it is not meant to be the path to an already trained restyle_e4e model. But after reading this part of the code, it does seem to indicate the path of a trained restyle_e4e model. I am very confused and would appreciate your help and explanation! 🙏
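One way to narrow down a "Missing key(s)" error like the one above is to compare the parameter names the model expects with the names stored in the checkpoint. The following is only a minimal sketch in plain Python: the helper name diff_state_dict_keys is hypothetical, and the key lists are toy stand-ins (only "style.1.weight" and "style.1.bias" come from the traceback; the other names are made up).

```python
def diff_state_dict_keys(model_keys, ckpt_keys):
    """Return (missing, unexpected) key lists, mirroring what a strict load checks.

    missing:    names the model expects but the checkpoint lacks
    unexpected: names the checkpoint holds but the model does not define
    """
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    missing = sorted(model_keys - ckpt_keys)
    unexpected = sorted(ckpt_keys - model_keys)
    return missing, unexpected


# Toy stand-ins. In practice these would be model.state_dict().keys() and
# the keys of the decoder weights extracted from the checkpoint file.
model_keys = ["style.1.weight", "style.1.bias", "shared.weight"]
ckpt_keys = ["shared.weight", "hypothetical.layer.weight"]  # made-up names

missing, unexpected = diff_state_dict_keys(model_keys, ckpt_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If almost every name ends up in both lists, the checkpoint was probably saved from a different architecture than the one being constructed, rather than from a slightly different version of the same one.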

2. Question about restyle_e4e_ffhq.pt and the load_weights function

I also have a more fundamental question. I read in your paper that you use the restyle-e4e model as the encoder and the StyleGAN3 model as the decoder. Why, then, does the if self.opts.checkpoint_path is not None block of load_weights use a single checkpoint_path to load both the encoder and the decoder?
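For context, the calls __get_keys(ckpt, 'encoder') and __get_keys(ckpt, 'decoder') suggest that one checkpoint file can hold both sub-networks under different key prefixes. The body of __get_keys is not shown in this issue, so the sketch below is an assumption about how such a split could work, with a hypothetical helper name and a made-up flat checkpoint layout.

```python
def split_by_prefix(state_dict, name):
    """Keep entries whose key starts with '<name>.' and strip that prefix."""
    prefix = name + "."
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Hypothetical flat checkpoint; real values would be tensors, not strings.
ckpt = {
    "encoder.conv1.weight": "E0",
    "decoder.layer.weight": "D0",
    "decoder.layer.bias": "D1",
}

encoder_weights = split_by_prefix(ckpt, "encoder")
decoder_weights = split_by_prefix(ckpt, "decoder")
print(sorted(encoder_weights))
print(sorted(decoder_weights))
```

Under this assumption, a checkpoint saved during training would already contain the decoder weights that were in use at the time, which would explain why no second file is read in that branch.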

I'm sorry if this is a basic question. I have been thinking about it for a long time and have asked GPT for advice, but I haven't gotten an answer I could understand. 😢

3. Training without the checkpoint_path parameter

After the errors above with the checkpoint_path parameter, which I could not resolve, I tried to train the model without using the already trained restyle_e4e_ffhq.pt model to initialize the weights. I ran train_restyle_e4e.py with the command below:

python scripts/train_restyle_e4e.py \
--dataset_type my_data_encode \
--encoder_type ResNetProgressiveBackboneEncoder \
--exp_dir=experiment/restyle_e4e_NTF_encode \
--batch_size 2 \
--test_batch_size 2 \
--workers 8 \
--test_workers 8 \
--val_interval 5000 \
--save_interval 10000 \
--start_from_latent_avg \
--lpips_lambda 0.8 \
--l2_lambda 1 \
--id_lambda 0.1 \
--input_nc 6 \
--n_iters_per_batch 3 \
--output_size 1024 \
--save_training_data \
--stylegan_weights pretrained_models/sg3-r-ffhq-1024.pt

Another error appeared:

Loading encoders weights from resnet34!
Loading decoder weights from pretrained path: pretrained_models/sg3-r-ffhq-1024.pt
Traceback (most recent call last):
  File "scripts/train_restyle_e4e.py", line 86, in <module>
    main()
  File "scripts/train_restyle_e4e.py", line 28, in main
    coach = Coach(opts, previous_train_ckpt)
  File "./training/coach_restyle_e4e.py", line 34, in __init__
    self.net = e4e(self.opts).to(self.device)
  File "./models/e4e.py", line 25, in __init__
    self.load_weights()
  File "./models/e4e.py", line 64, in load_weights
    self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
KeyError: 'g_ema'

I then changed the code to:

else:
    encoder_ckpt = self.__get_encoder_checkpoint()
    self.encoder.load_state_dict(encoder_ckpt, strict=False)
    print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}')
    ckpt = torch.load(self.opts.stylegan_weights)
    # see all keys
    print(ckpt.keys())
    # try to use whole ckpt to load instead of ckpt['g_ema']
    self.decoder.load_state_dict(ckpt, strict=False)
    # self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
    self.__load_latent_avg(ckpt, repeat=self.n_styles)

I found that ckpt.keys() indeed does not contain g_ema, and the error persists even after switching to self.decoder.load_state_dict(ckpt, strict=False). I am not satisfied with this result, because I use all the models listed in readme.md (https://github.com/yuval-alaluf/restyle-encoder#preparing-your-data), such as the ResNet-34 model and restyle_e4e_ffhq.pt, which I guess you also used when training your model. So I am confused about why the basic training command with ordinary parameters raises this error. It looks like a model mismatch that cannot work 😭
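When debugging a KeyError like 'g_ema', it can help to fail with a message that lists the keys the file actually contains, and to measure how many of the model's parameter names the loaded weights cover. This matters because load_state_dict with strict=False does not raise on missing or unexpected keys, so a load can appear to succeed while matching nothing. The sketch below is plain Python with hypothetical helper names and toy data, not code from the repository.

```python
def select_weights(ckpt, key="g_ema"):
    """Return ckpt[key], or raise a KeyError listing the keys that are present."""
    if key not in ckpt:
        available = ", ".join(sorted(map(str, ckpt.keys())))
        raise KeyError(
            f"'{key}' not found in checkpoint; available top-level keys: {available}"
        )
    return ckpt[key]


def overlap_ratio(model_keys, weight_keys):
    """Fraction of the model's parameter names that the given weights cover."""
    model_keys = set(model_keys)
    if not model_keys:
        return 0.0
    return len(model_keys & set(weight_keys)) / len(model_keys)


# Toy checkpoint with a made-up top-level key instead of 'g_ema'.
ckpt = {"hypothetical_key": {"a.weight": 0}}
try:
    select_weights(ckpt)
except KeyError as err:
    print("lookup failed:", err)

# A ratio near 0 means a non-strict load would leave the model
# at its initial weights, even though no exception is raised.
print(overlap_ratio(["style.1.weight", "style.1.bias"], ["a.weight"]))
```

A ratio close to zero would point to a file saved in a different format or for a different generator architecture than the decoder being built.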

I apologize for asking three questions in one issue; I did so because I think they are closely related. I need your help urgently and look forward to your reply! @yuval-alaluf

YKJamesMoriarty commented 2 days ago

@orpatashnik @yuval-alaluf 😭