The function I want to ask about is `load_weights`:
```python
def load_weights(self):
    if self.opts.checkpoint_path is not None:
        print(f'Loading ReStyle e4e from checkpoint: {self.opts.checkpoint_path}')
        ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
        self.encoder.load_state_dict(self.__get_keys(ckpt, 'encoder'), strict=False)#False
        self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True)#True
        self.__load_latent_avg(ckpt)
        print(f'Loading checkpoint from {self.opts.checkpoint_path} successfully!')
    else:
        encoder_ckpt = self.__get_encoder_checkpoint()
        self.encoder.load_state_dict(encoder_ckpt, strict=False)
        print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}')
        ckpt = torch.load(self.opts.stylegan_weights)
        self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
        self.__load_latent_avg(ckpt, repeat=self.n_styles)
```
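The `if` branch above loads both the encoder and the decoder from one file via `__get_keys`. When PyTorch saves a module's `state_dict()`, the parameters of each submodule are stored under keys prefixed with that submodule's attribute name (e.g. `encoder.` and `decoder.`), so a single checkpoint of the whole e4e wrapper can hold both parts. Below is a minimal sketch of what a helper like `__get_keys` presumably does; the name `get_keys` and the assumption that the weights sit under a `'state_dict'` key are hypothetical, not confirmed from the repository.

```python
from typing import Any, Dict


def get_keys(ckpt: Dict[str, Any], name: str) -> Dict[str, Any]:
    """Return the weights of submodule `name` with the 'name.' prefix stripped.

    Hypothetical layout: the full model was saved under ckpt['state_dict'],
    with keys such as 'encoder.conv1.weight' and 'decoder.style.1.weight'.
    Falls back to treating `ckpt` itself as the state dict.
    """
    state_dict = ckpt.get("state_dict", ckpt)
    prefix = name + "."
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


# Toy checkpoint; plain numbers stand in for tensors.
ckpt = {
    "state_dict": {
        "encoder.conv1.weight": 1,
        "decoder.style.1.weight": 2,
        "decoder.style.1.bias": 3,
    },
    "latent_avg": 0,
}
print(get_keys(ckpt, "encoder"))  # {'conv1.weight': 1}
print(get_keys(ckpt, "decoder"))  # {'style.1.weight': 2, 'style.1.bias': 3}
```

Under this layout, the `decoder.` entries are loaded with `strict=True`, so they must match the parameter names of the `Generator` being constructed exactly.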
1. Using the `checkpoint_path` parameter to train
I would like to use the weights of your already-trained model restyle_e4e_ffhq.pt (the download link is provided in README.md) as the initialization when I retrain the model, since I believe this will speed up training.
After running train_restyle_e4e.py with the command below, it raised the following error:
```
Traceback (most recent call last):
  File "scripts/train_restyle_e4e.py", line 86, in <module>
    main()
  File "scripts/train_restyle_e4e.py", line 28, in main
    coach = Coach(opts, previous_train_ckpt)
  File "./training/coach_restyle_e4e.py", line 34, in __init__
    self.net = e4e(self.opts).to(self.device)
  File "./models/e4e.py", line 25, in __init__
    self.load_weights()
  File "./models/e4e.py", line 49, in load_weights
    self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True)#True
  File "/root/miniconda3/envs/sg3_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
	Missing key(s) in state_dict: "style.1.weight", "style.1.bias", "style.2.weight", "style.2.bias", "style.3.weight", "style.3.bias", "style.4.weight", "style.4.bias", "style.5.weight", "style.5.bias", "style.6.weight", "style.6.bias", "style.7.weight", "style.7.bias"
```
(There are hundreds more missing keys after these, so I haven't pasted them all here.)
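When hundreds of keys are reported missing, it can help to compare the two key sets directly and group them by their first component. A plain-Python sketch; the helper names are hypothetical, and in practice `model_keys` would be `self.decoder.state_dict().keys()` and `ckpt_keys` the keys of the dictionary passed to `load_state_dict`. (With `strict=False`, PyTorch's `load_state_dict` also returns an object with `missing_keys` and `unexpected_keys` lists carrying the same information.)

```python
from collections import Counter
from typing import Dict, Iterable, List, Tuple


def diff_keys(model_keys: Iterable[str], ckpt_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected): keys the model needs but the checkpoint
    lacks, and keys the checkpoint has but the model does not use."""
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    return sorted(model_set - ckpt_set), sorted(ckpt_set - model_set)


def summarize_by_prefix(keys: Iterable[str]) -> Dict[str, int]:
    """Count keys by their first dotted component, e.g. 'style.1.weight' -> 'style'."""
    return dict(Counter(k.split(".", 1)[0] for k in keys))


# Made-up key sets standing in for the decoder's and the checkpoint's keys.
model_keys = ["style.1.weight", "style.1.bias", "conv1.weight"]
ckpt_keys = ["net.0.weight", "conv1.weight"]
missing, unexpected = diff_keys(model_keys, ckpt_keys)
print(summarize_by_prefix(missing))     # {'style': 2}
print(summarize_by_prefix(unexpected))  # {'net': 1}
```

If every missing key shares a few prefixes while the unexpected keys use entirely different names, that suggests the `Generator` class being built and the decoder weights stored in the checkpoint come from different generator implementations.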
I don't know why this happens. Am I using the `checkpoint_path` parameter the wrong way, e.g. is it not meant to point to an already-trained ReStyle e4e model? After reading this part of the code, I believe it does indicate the path of an already-trained ReStyle e4e model, so I'm very confused. I'd really appreciate your help and an explanation! 🙏
2. Question about restyle_e4e_ffhq.pt and the function `load_weights`
I also have a more fundamental question. I read in your paper that you use the ReStyle-e4e model as the encoder and the StyleGAN3 model as the decoder. So why does the `if self.opts.checkpoint_path is not None` block in `load_weights` use a single `checkpoint_path` to load both the encoder and the decoder?
Sorry if this is a naive question. I've been thinking about it for a long time and asked GPT for advice, but haven't gotten an answer I could understand. 😢
3. Training without the `checkpoint_path` parameter
Since using the `checkpoint_path` parameter as above led to errors I could not resolve, I tried to train the model without initializing its weights from the already-trained restyle_e4e_ffhq.pt.
Running train_restyle_e4e.py with the command below, another error appeared:
```
Loading encoders weights from resnet34!
Loading decoder weights from pretrained path: pretrained_models/sg3-r-ffhq-1024.pt
Traceback (most recent call last):
  File "scripts/train_restyle_e4e.py", line 86, in <module>
    main()
  File "scripts/train_restyle_e4e.py", line 28, in main
    coach = Coach(opts, previous_train_ckpt)
  File "./training/coach_restyle_e4e.py", line 34, in __init__
    self.net = e4e(self.opts).to(self.device)
  File "./models/e4e.py", line 25, in __init__
    self.load_weights()
  File "./models/e4e.py", line 64, in load_weights
    self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
KeyError: 'g_ema'
```
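`ckpt['g_ema']` assumes the StyleGAN file follows the layout used by rosinality's PyTorch StyleGAN2 training script, which stores the exponential-moving-average generator under `'g_ema'`; a file saved in a different format will not have that key. A minimal sketch of failing with a more informative message instead of a bare `KeyError` (the function name is hypothetical, and the toy checkpoints below are made up):

```python
from typing import Any, Dict


def extract_generator_state(ckpt: Dict[str, Any], key: str = "g_ema") -> Dict[str, Any]:
    """Return ckpt[key], or raise a KeyError listing the top-level keys that do exist."""
    if key in ckpt:
        return ckpt[key]
    available = ", ".join(sorted(str(k) for k in ckpt.keys()))
    raise KeyError(f"checkpoint has no {key!r} entry; top-level keys are: {available}")


print(extract_generator_state({"g_ema": {"style.1.weight": 0}}))  # {'style.1.weight': 0}
try:
    extract_generator_state({"G": {}, "latent_avg": 0})
except KeyError as err:
    print(err)
```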
Then I changed the code to:
```python
else:
    encoder_ckpt = self.__get_encoder_checkpoint()
    self.encoder.load_state_dict(encoder_ckpt, strict=False)
    print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}')
    ckpt = torch.load(self.opts.stylegan_weights)
    # see all keys
    print(ckpt.keys())
    # try to use whole ckpt to load instead of ckpt['g_ema']
    self.decoder.load_state_dict(ckpt, strict=False)
    #self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
    self.__load_latent_avg(ckpt, repeat=self.n_styles)
```
I found that `ckpt.keys()` indeed does not contain `g_ema`, and after using `self.decoder.load_state_dict(ckpt, strict=False)` the error disappears. But I'm not satisfied with this result, because I used all the models listed in the README (https://github.com/yuval-alaluf/restyle-encoder#preparing-your-data), such as the ResNet-34 model and restyle_e4e_ffhq.pt, which I guess you also used when training your model. So I'm confused about why the basic training command with ordinary parameters raises this strange error. It looks like a model mismatch, so I doubt this workaround actually works. 😭
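Passing the whole `ckpt` avoids the `KeyError` because nothing indexes `ckpt['g_ema']` anymore, but `strict=False` makes `load_state_dict` silently skip every key that does not match a parameter name, so the decoder may keep its random initialization. Before relying on that, one can check which part of the checkpoint, if any, actually matches the decoder's parameter names. A hypothetical diagnostic sketch; in practice `model_keys` would be `self.decoder.state_dict().keys()`, and the toy checkpoint below is made up:

```python
from typing import Any, Dict, Iterable, Optional, Tuple


def find_best_state_dict(ckpt: Dict[str, Any], model_keys: Iterable[str]) -> Tuple[Optional[str], float]:
    """Return (top-level key, fraction of model keys covered) for the best-matching
    candidate. A key of None means the checkpoint's top level itself matched best."""
    wanted = set(model_keys)
    if not wanted:
        return None, 0.0
    # Candidates: the checkpoint itself plus every dict-valued top-level entry
    # (PyTorch state dicts are OrderedDicts, which are dict subclasses).
    candidates: Dict[Optional[str], Dict[str, Any]] = {None: ckpt}
    candidates.update({k: v for k, v in ckpt.items() if isinstance(v, dict)})
    best_key, best_cov = None, -1.0
    for key, sub in candidates.items():
        coverage = len(wanted & set(sub.keys())) / len(wanted)
        if coverage > best_cov:
            best_key, best_cov = key, coverage
    return best_key, best_cov


toy_ckpt = {"g_ema": {"style.1.weight": 0, "style.1.bias": 0}, "d": {"conv.weight": 0}, "step": 5}
print(find_best_state_dict(toy_ckpt, ["style.1.weight", "style.1.bias"]))  # ('g_ema', 1.0)
```

A coverage of `0.0` for every candidate would mean that, with `strict=False`, no decoder weights are loaded at all.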
I apologize for asking three questions in one issue; I think they are closely related. I need help urgently and look forward to your reply! @yuval-alaluf