Closed splinter21 closed 4 years ago
Oh, I have fixed it. 1. Delete "state_dict()" in loss->adversarial, but then the optimizer parameters are lost. 2. In loss->__init__:
add these to __init__(): . self.epoch=len(ckp.log)
add these to save(): . opt_d=[l.optimizer.state_dict()for l in self.loss_module if(hasattr(l,"optimizer"))] . with open(os.path.join(apath, 'opt_d.pt'), 'wb') as _f: pickle.dump(opt_d, _f)
add these to load(): . i=0 . for l in self.loss_module: . if(hasattr(l, "optimizer")): . l.optimizer.load_state_dict(opt_d[i]) . if self.epoch > 1: . for _ in range(self.epoch): l.optimizer.scheduler.step() . i+=1
So if "GAN" is found in the loss function, the model can't be resumed.
(I'm using PyTorch 0.4.1, so the code version is legacy/0.4.1)
Preparing loss function: 1.000 * VGG54 0.005 * RGAN 0.010 * L1 Traceback (most recent call last): File "main.py", line 22, in <module>
loss = loss.Loss(args, checkpoint) if not args.test_only else None
File "/all-data/sv6-disk1/timchen_home/SR/src_gan/loss/__init__.py", line 67, in __init__
if args.load != '': self.load(ckp.dir, cpu=args.cpu)
File "/all-data/sv6-disk1/timchen_home/SR/src_gan/loss/__init__.py", line 137, in load
**kwargs
File "/all-data/sv6-disk1/timchen_home/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Loss:
Missing key(s) in state_dict: "loss_module.1.dis.features.0.0.weight", "loss_module.1.dis.features.0.1.weight", "loss_module.1.dis.features.0.1.bias", "loss_module.1.dis.features.0.1.running_mean", "loss_module.1.dis.features.0.1.running_var", "loss_module.1.dis.features.1.0.weight", "loss_module.1.dis.features.1.1.weight", "loss_module.1.dis.features.1.1.bias", "loss_module.1.dis.features.1.1.running_mean", "loss_module.1.dis.features.1.1.running_var", "loss_module.1.dis.features.2.0.weight", "loss_module.1.dis.features.2.1.weight", "loss_module.1.dis.features.2.1.bias", "loss_module.1.dis.features.2.1.running_mean", "loss_module.1.dis.features.2.1.running_var", "loss_module.1.dis.features.3.0.weight", "loss_module.1.dis.features.3.1.weight", "loss_module.1.dis.features.3.1.bias", "loss_module.1.dis.features.3.1.running_mean", "loss_module.1.dis.features.3.1.running_var", "loss_module.1.dis.features.4.0.weight", "loss_module.1.dis.features.4.1.weight", "loss_module.1.dis.features.4.1.bias", "loss_module.1.dis.features.4.1.running_mean", "loss_module.1.dis.features.4.1.running_var", "loss_module.1.dis.features.5.0.weight", "loss_module.1.dis.features.5.1.weight", "loss_module.1.dis.features.5.1.bias", "loss_module.1.dis.features.5.1.running_mean", "loss_module.1.dis.features.5.1.running_var", "loss_module.1.dis.features.6.0.weight", "loss_module.1.dis.features.6.1.weight", "loss_module.1.dis.features.6.1.bias", "loss_module.1.dis.features.6.1.running_mean", "loss_module.1.dis.features.6.1.running_var", "loss_module.1.dis.features.7.0.weight", "loss_module.1.dis.features.7.1.weight", "loss_module.1.dis.features.7.1.bias", "loss_module.1.dis.features.7.1.running_mean", "loss_module.1.dis.features.7.1.running_var", "loss_module.1.dis.classifier.0.weight", "loss_module.1.dis.classifier.0.bias", "loss_module.1.dis.classifier.2.weight", "loss_module.1.dis.classifier.2.bias".
Unexpected key(s) in state_dict: "loss_module.1.features.0.0.weight", "loss_module.1.features.0.1.weight", "loss_module.1.features.0.1.bias", "loss_module.1.features.0.1.running_mean", "loss_module.1.features.0.1.running_var", "loss_module.1.features.0.1.num_batches_tracked", "loss_module.1.features.1.0.weight", "loss_module.1.features.1.1.weight", "loss_module.1.features.1.1.bias", "loss_module.1.features.1.1.running_mean", "loss_module.1.features.1.1.running_var", "loss_module.1.features.1.1.num_batches_tracked", "loss_module.1.features.2.0.weight", "loss_module.1.features.2.1.weight", "loss_module.1.features.2.1.bias", "loss_module.1.features.2.1.running_mean", "loss_module.1.features.2.1.running_var", "loss_module.1.features.2.1.num_batches_tracked", "loss_module.1.features.3.0.weight", "loss_module.1.features.3.1.weight", "loss_module.1.features.3.1.bias", "loss_module.1.features.3.1.running_mean", "loss_module.1.features.3.1.running_var", "loss_module.1.features.3.1.num_batches_tracked", "loss_module.1.features.4.0.weight", "loss_module.1.features.4.1.weight", "loss_module.1.features.4.1.bias", "loss_module.1.features.4.1.running_mean", "loss_module.1.features.4.1.running_var", "loss_module.1.features.4.1.num_batches_tracked", "loss_module.1.features.5.0.weight", "loss_module.1.features.5.1.weight", "loss_module.1.features.5.1.bias", "loss_module.1.features.5.1.running_mean", "loss_module.1.features.5.1.running_var", "loss_module.1.features.5.1.num_batches_tracked", "loss_module.1.features.6.0.weight", "loss_module.1.features.6.1.weight", "loss_module.1.features.6.1.bias", "loss_module.1.features.6.1.running_mean", "loss_module.1.features.6.1.running_var", "loss_module.1.features.6.1.num_batches_tracked", "loss_module.1.features.7.0.weight", "loss_module.1.features.7.1.weight", "loss_module.1.features.7.1.bias", "loss_module.1.features.7.1.running_mean", "loss_module.1.features.7.1.running_var", "loss_module.1.features.7.1.num_batches_tracked", 
"loss_module.1.classifier.0.weight", "loss_module.1.classifier.0.bias", "loss_module.1.classifier.2.weight", "loss_module.1.classifier.2.bias".