sanghyun-son / EDSR-PyTorch

PyTorch version of the paper 'Enhanced Deep Residual Networks for Single Image Super-Resolution' (CVPRW 2017)
MIT License

Training a x4 model using a pretrained x2 model #8

Closed cys4 closed 6 years ago

cys4 commented 6 years ago

First of all, thank you for sharing this great work.

I have a problem when training a x4 model using a pre-trained x2 model (../experiment/EDSR_baseline_x2/model/model_best.pt) as follows.

$ python main.py --model EDSR --scale 4 --save EDSR_baseline_x4 --reset  --dir_data /data --pre_train ../experiment/EDSR_baseline_x2/model/model_best.pt
...
Loading model from ../experiment/EDSR_baseline_x2/model/model_best.pt...
Traceback (most recent call last):
  File "main.py", line 13, in <module>
    t = Trainer(my_loader, checkpoint, args)
  File "/home/yschoi/work/SR/EDSR-PyTorch_custom/code/trainer.py", line 21, in __init__
    self.model, self.loss, self.optimizer, self.scheduler = ckp.load()
  File "/home/yschoi/work/SR/EDSR-PyTorch_custom/code/utils.py", line 80, in load
    my_model = model(self.args).get_model()
  File "/home/yschoi/work/SR/EDSR-PyTorch_custom/code/model/__init__.py", line 18, in get_model
    my_model.load_state_dict(torch.load(self.args.pre_train))
  File "/home/yschoi/work/SR/EDSR-PyTorch_custom/code/model/EDSR.py", line 78, in load_state_dict
    raise KeyError('missing keys in state_dict: "{}"'.format(missing))
KeyError: 'missing keys in state_dict: "{\'tail.0.2.bias\', \'tail.0.2.weight\'}"'

All the above errors can be removed by adding 'strict=False' to the 'load_state_dict' call (line 15) in './code/model/__init__.py' as follows, but I'm not sure this is the right way to handle the situation.

- my_model.load_state_dict(torch.load(self.args.pre_train))
+ my_model.load_state_dict(torch.load(self.args.pre_train), strict=False)
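For reference, strict=False makes load_state_dict ignore checkpoint/model key mismatches and load only the overlapping parameters. A minimal sketch of the behavior with toy stand-in modules (the layer shapes below are illustrative, not the actual EDSR tail):

```python
import torch
import torch.nn as nn

# Toy stand-ins: a "x2" tail with one upsampling stage, and a "x4" tail
# with two consecutive stages (as in EDSR's x4 upsampler).
tail_x2 = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1), nn.PixelShuffle(2))
tail_x4 = nn.Sequential(
    nn.Conv2d(4, 16, 3, padding=1), nn.PixelShuffle(2),
    nn.Conv2d(4, 16, 3, padding=1), nn.PixelShuffle(2),
)

# Strict loading fails: the x4 tail has parameters the x2 checkpoint lacks.
try:
    tail_x4.load_state_dict(tail_x2.state_dict())
except RuntimeError as e:
    print("strict load failed:", "Missing key(s)" in str(e))

# strict=False loads the overlapping keys, leaves the rest randomly
# initialized, and reports what was skipped.
result = tail_x4.load_state_dict(tail_x2.state_dict(), strict=False)
print("missing keys:", result.missing_keys)
```

So with strict=False the first x2 stage of the x4 tail receives the pre-trained weights while the second stage keeps its random initialization, which matches the behavior described below.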

Please let me know if I'm missing something important.

sanghyun-son commented 6 years ago

Hello.

The x4 upsampler consists of two consecutive x2 upsamplers.

In our original approach, we randomly initialize the x4 upsampler even when we load the weights from a x2 model.

With strict=False, however, the first part of the x4 upsampler is initialized with the pre-trained x2 upsampler, and the second part is initialized randomly.

Although this situation is not intended, I don't think your solution will cause any problems.
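If you want the partial loading to be explicit rather than relying on strict=False, one option is to filter the checkpoint down to the keys whose names and shapes match before loading. This is a hypothetical helper, not part of the EDSR codebase:

```python
import torch
import torch.nn as nn

def load_matching(model: nn.Module, state_dict: dict) -> list:
    """Load only parameters whose name and shape match; return skipped keys."""
    own = model.state_dict()
    matched = {k: v for k, v in state_dict.items()
               if k in own and own[k].shape == v.shape}
    skipped = [k for k in state_dict if k not in matched]
    # strict=False is safe here: 'matched' contains only valid keys, so the
    # model's remaining parameters simply keep their current values.
    model.load_state_dict(matched, strict=False)
    return skipped

# Toy usage: copy a one-layer checkpoint into a two-layer model.
dst = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
src = nn.Sequential(nn.Linear(2, 2))
skipped = load_matching(dst, src.state_dict())
print("skipped:", skipped)
```

Returning the skipped keys makes it easy to log exactly which parts of the model (here, the second half of the x4 tail) start from random initialization.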

I will check the code and make it run without this error message.

Thank you for reporting this error.

Sanghyun.

cys4 commented 6 years ago

Okay. I see. Thank you for clarifying!