jt827859032 / DRRN-pytorch

Pytorch implementation of Deep Recursive Residual Network for Super Resolution (DRRN), CVPR 2017
MIT License

training problem #11

Open meroluo opened 5 years ago

meroluo commented 5 years ago

Thank you for your wonderful work! I want to train a model on my own dataset, but something goes wrong when I resume from a saved checkpoint. The error is shown below:

===> Loading datasets
===> Building model
/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/nn/_reduction.py:49: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
  warnings.warn(warning.format(ret))
===> Setting GPU
===> load model model/model_epoch_28.pth
/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
Traceback (most recent call last):
  File "main.py", line 128, in <module>
    main()
  File "main.py", line 69, in main
    model.load_state_dict(weights['model'].state_dict())
  File "/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DRRN:
    Missing key(s) in state_dict: "input.weight", "conv1.weight", "conv2.weight", "output.weight".
    Unexpected key(s) in state_dict: "module.input.weight", "module.conv1.weight", "module.conv2.weight", "module.output.weight".

It seems that some parameters of the model are missing? I don't understand the error and hope you can give me some suggestions. I would sincerely appreciate your reply.

xsacha commented 5 years ago

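You need to remove the "module." prefix from the keys in the state_dict. The checkpoint was saved from a model wrapped in DataParallel during training, so every key is stored as "module.<name>", while the plain DRRN model expects "<name>". Unwrapping the saved model with model = model.module before taking its state_dict should achieve the same thing, I believe.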
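
For illustration, here is a minimal sketch of the key renaming described above. The helper name strip_module_prefix is hypothetical (it is not part of this repository or of PyTorch), and it only assumes that the state_dict behaves like an ordinary ordered mapping from parameter names to tensors:

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Hypothetical helper: keys that do not start with `prefix` are kept unchanged,
    so it is safe to call on checkpoints saved with or without DataParallel.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in for a real checkpoint: strings are used in place of tensors.
saved = OrderedDict([
    ("module.input.weight", "w_in"),
    ("module.conv1.weight", "w_c1"),
    ("module.conv2.weight", "w_c2"),
    ("module.output.weight", "w_out"),
])
print(list(strip_module_prefix(saved).keys()))
```

Assuming the loading code looks like the line quoted in the traceback, line 69 of main.py could then become model.load_state_dict(strip_module_prefix(weights['model'].state_dict())). Using weights['model'].module.state_dict() should be an equivalent alternative, since DataParallel keeps the wrapped network in its module attribute.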