mseitzer / srgan

PyTorch implementation of "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"
MIT License

Fine-tuning SRGAN #9

Open hosseinahm1367 opened 4 years ago

hosseinahm1367 commented 4 years ago

I am trying to fine-tune the SRGAN network, initializing it with the pre-trained SRGAN checkpoint provided in the source code. I tried to do it the same way you suggested previously: https://github.com/mseitzer/srgan/issues/8

Basically, I simply replaced srresnet.pth with srgan.pth in the srgan config file. I get a model loading error in the "load_model_state_dict" function. How should I approach this? Thanks.

mseitzer commented 4 years ago

Yeah, this is a bit more complicated (I have never done it this way around, but it should work). Note that "pretrained_weights" in the config takes two arguments: the path to the checkpoint and a key (in this case, "model"). The key tells the loader under which entry of the checkpoint dictionary to look for the model weights.

For normal training, this key is simply "model". The checkpoint is created here: https://github.com/mseitzer/srgan/blob/70f88fbac26b5a3659847965befea2fdd49eb625/training/runner.py#L105-L109

For adversarial training, the checkpoint has a different structure, and the key for the model weights is "generator". It is created here: https://github.com/mseitzer/srgan/blob/70f88fbac26b5a3659847965befea2fdd49eb625/training/adversarial_runner.py#L183-L189

Basically, you want to replace the key "model" with "generator" in "pretrained_weights", and it should work:

"model": {
  "#include": "resnet.json",
  "pretrained_weights": ["../resources/pretrained/srgan.pth", "generator"]
}
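
If loading still fails, it may help to check which top-level keys the checkpoint actually contains. The following is only a rough sketch of the key lookup described above, not the repository's loader: select_weights is a hypothetical helper, and the "discriminator" and "epoch" entries in the stand-in dictionaries are assumptions made for illustration. Only the "model" and "generator" keys come from this thread.

```python
def select_weights(checkpoint, key):
    """Return the entry stored under `key` in a checkpoint dictionary.

    Hypothetical helper: it mirrors the idea of the second
    "pretrained_weights" argument, not the repository's actual code.
    """
    if key not in checkpoint:
        available = ", ".join(sorted(checkpoint))
        raise KeyError(
            f"no '{key}' entry in checkpoint; available keys: {available}"
        )
    return checkpoint[key]


# Stand-ins for the two checkpoint layouts. A real checkpoint would be read
# with torch.load(path, map_location="cpu"). Every key except "model" and
# "generator" is an assumption for illustration.
normal_checkpoint = {"model": {"conv1.weight": [0.1]}, "epoch": 10}
adversarial_checkpoint = {
    "generator": {"conv1.weight": [0.2]},
    "discriminator": {"fc.weight": [0.3]},
    "epoch": 10,
}

generator_weights = select_weights(adversarial_checkpoint, "generator")
```

Asking the adversarial checkpoint for "model" raises a KeyError that lists the keys that are present, which makes this kind of mismatch easy to spot.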