sanghyun-son / EDSR-PyTorch

PyTorch version of the paper 'Enhanced Deep Residual Networks for Single Image Super-Resolution' (CVPRW 2017)
MIT License
2.43k stars · 669 forks

When training GAN-based SR models, the discriminator's state_dict is not saved correctly. #245

Closed splinter21 closed 4 years ago

splinter21 commented 4 years ago

So if "GAN" appears in the loss function, training cannot be resumed.

(I'm using PyTorch 0.4.1, so the code version is legacy/0.4.1)

```
Preparing loss function:
	1.000 * VGG54
	0.005 * RGAN
	0.010 * L1
Traceback (most recent call last):
  File "main.py", line 22, in <module>
    loss = loss.Loss(args, checkpoint) if not args.test_only else None
  File "/all-data/sv6-disk1/timchen_home/SR/src_gan/loss/__init__.py", line 67, in __init__
    if args.load != '': self.load(ckp.dir, cpu=args.cpu)
  File "/all-data/sv6-disk1/timchen_home/SR/src_gan/loss/__init__.py", line 137, in load
    **kwargs
  File "/all-data/sv6-disk1/timchen_home/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Loss:
	Missing key(s) in state_dict: "loss_module.1.dis.features.0.0.weight", "loss_module.1.dis.features.0.1.weight", "loss_module.1.dis.features.0.1.bias", "loss_module.1.dis.features.0.1.running_mean", "loss_module.1.dis.features.0.1.running_var", "loss_module.1.dis.features.1.0.weight", "loss_module.1.dis.features.1.1.weight", "loss_module.1.dis.features.1.1.bias", "loss_module.1.dis.features.1.1.running_mean", "loss_module.1.dis.features.1.1.running_var", "loss_module.1.dis.features.2.0.weight", "loss_module.1.dis.features.2.1.weight", "loss_module.1.dis.features.2.1.bias", "loss_module.1.dis.features.2.1.running_mean", "loss_module.1.dis.features.2.1.running_var", "loss_module.1.dis.features.3.0.weight", "loss_module.1.dis.features.3.1.weight", "loss_module.1.dis.features.3.1.bias", "loss_module.1.dis.features.3.1.running_mean", "loss_module.1.dis.features.3.1.running_var", "loss_module.1.dis.features.4.0.weight", "loss_module.1.dis.features.4.1.weight", "loss_module.1.dis.features.4.1.bias", "loss_module.1.dis.features.4.1.running_mean", "loss_module.1.dis.features.4.1.running_var", "loss_module.1.dis.features.5.0.weight", "loss_module.1.dis.features.5.1.weight", "loss_module.1.dis.features.5.1.bias", "loss_module.1.dis.features.5.1.running_mean", "loss_module.1.dis.features.5.1.running_var", "loss_module.1.dis.features.6.0.weight", "loss_module.1.dis.features.6.1.weight", "loss_module.1.dis.features.6.1.bias", "loss_module.1.dis.features.6.1.running_mean", "loss_module.1.dis.features.6.1.running_var", "loss_module.1.dis.features.7.0.weight", "loss_module.1.dis.features.7.1.weight", "loss_module.1.dis.features.7.1.bias", "loss_module.1.dis.features.7.1.running_mean", "loss_module.1.dis.features.7.1.running_var", "loss_module.1.dis.classifier.0.weight", "loss_module.1.dis.classifier.0.bias", "loss_module.1.dis.classifier.2.weight", "loss_module.1.dis.classifier.2.bias".
	Unexpected key(s) in state_dict: "loss_module.1.features.0.0.weight", "loss_module.1.features.0.1.weight", "loss_module.1.features.0.1.bias", "loss_module.1.features.0.1.running_mean", "loss_module.1.features.0.1.running_var", "loss_module.1.features.0.1.num_batches_tracked", "loss_module.1.features.1.0.weight", "loss_module.1.features.1.1.weight", "loss_module.1.features.1.1.bias", "loss_module.1.features.1.1.running_mean", "loss_module.1.features.1.1.running_var", "loss_module.1.features.1.1.num_batches_tracked", "loss_module.1.features.2.0.weight", "loss_module.1.features.2.1.weight", "loss_module.1.features.2.1.bias", "loss_module.1.features.2.1.running_mean", "loss_module.1.features.2.1.running_var", "loss_module.1.features.2.1.num_batches_tracked", "loss_module.1.features.3.0.weight", "loss_module.1.features.3.1.weight", "loss_module.1.features.3.1.bias", "loss_module.1.features.3.1.running_mean", "loss_module.1.features.3.1.running_var", "loss_module.1.features.3.1.num_batches_tracked", "loss_module.1.features.4.0.weight", "loss_module.1.features.4.1.weight", "loss_module.1.features.4.1.bias", "loss_module.1.features.4.1.running_mean", "loss_module.1.features.4.1.running_var", "loss_module.1.features.4.1.num_batches_tracked", "loss_module.1.features.5.0.weight", "loss_module.1.features.5.1.weight", "loss_module.1.features.5.1.bias", "loss_module.1.features.5.1.running_mean", "loss_module.1.features.5.1.running_var", "loss_module.1.features.5.1.num_batches_tracked", "loss_module.1.features.6.0.weight", "loss_module.1.features.6.1.weight", "loss_module.1.features.6.1.bias", "loss_module.1.features.6.1.running_mean", "loss_module.1.features.6.1.running_var", "loss_module.1.features.6.1.num_batches_tracked", "loss_module.1.features.7.0.weight", "loss_module.1.features.7.1.weight", "loss_module.1.features.7.1.bias", "loss_module.1.features.7.1.running_mean", "loss_module.1.features.7.1.running_var", "loss_module.1.features.7.1.num_batches_tracked", "loss_module.1.classifier.0.weight", "loss_module.1.classifier.0.bias", "loss_module.1.classifier.2.weight", "loss_module.1.classifier.2.bias".
```
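The error above is a pure prefix mismatch: the checkpoint stores the discriminator's weights as `loss_module.1.features...`, while the adversarial loss module expects them one level deeper, under `loss_module.1.dis.features...`. A sketch of a workaround that remaps the old keys before calling `load_state_dict` (the function name and `module_idx` argument are mine, not the repository's; the leftover `num_batches_tracked` buffers can then be tolerated with `strict=False`):

```python
def remap_adversarial_keys(state_dict, module_idx=1):
    """Reinsert the 'dis.' segment that the adversarial loss module expects.

    Keys outside loss_module.<module_idx> are passed through unchanged.
    """
    old_prefix = 'loss_module.{}.'.format(module_idx)
    new_prefix = old_prefix + 'dis.'
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(old_prefix) and not key.startswith(new_prefix):
            # 'loss_module.1.features...' -> 'loss_module.1.dis.features...'
            remapped[new_prefix + key[len(old_prefix):]] = value
        else:
            remapped[key] = value
    return remapped
```

With this, an old checkpoint could be loaded as `loss.load_state_dict(remap_adversarial_keys(ckpt), strict=False)`, though fixing the save path (as below) is the cleaner solution.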

splinter21 commented 4 years ago

Oh, I have fixed it.

1. Delete `state_dict()` in `loss/adversarial.py`, but then the optimizer parameters are lost.
2. In `loss/__init__.py`:

add this to `__init__()`:

```python
self.epoch = len(ckp.log)
```

add this to `save()`:

```python
opt_d = [l.optimizer.state_dict() for l in self.loss_module if hasattr(l, 'optimizer')]
with open(os.path.join(apath, 'opt_d.pt'), 'wb') as _f:
    pickle.dump(opt_d, _f)
```

add this to `load()`:

```python
# read back the optimizer states written by save()
with open(os.path.join(apath, 'opt_d.pt'), 'rb') as _f:
    opt_d = pickle.load(_f)

i = 0
for l in self.loss_module:
    if hasattr(l, 'optimizer'):
        l.optimizer.load_state_dict(opt_d[i])
        if self.epoch > 1:
            # fast-forward the LR scheduler to the resumed epoch
            for _ in range(self.epoch):
                l.optimizer.scheduler.step()
        i += 1
```
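For later readers: instead of pickling the optimizer states into a separate `opt_d.pt`, the discriminator and its optimizer can be bundled into a single checkpoint file, so nothing is lost on resume. A minimal sketch (the function names are mine, not the repository's):

```python
import torch

def save_gan_checkpoint(path, discriminator, optimizer, epoch):
    """Write discriminator weights, optimizer state, and epoch to one file."""
    torch.save({
        'epoch': epoch,
        'dis': discriminator.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, path)

def load_gan_checkpoint(path, discriminator, optimizer):
    """Restore everything saved above; returns the epoch to resume from."""
    ckpt = torch.load(path, map_location='cpu')
    discriminator.load_state_dict(ckpt['dis'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return ckpt['epoch']
```

Because the optimizer's `state_dict` travels with the weights, the Adam moment estimates survive the restart instead of being reinitialized.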