VITA-Group / EnlightenGAN

[IEEE TIP] "EnlightenGAN: Deep Light Enhancement without Paired Supervision" by Yifan Jiang, Xinyu Gong, Ding Liu, Yu Cheng, Chen Fang, Xiaohui Shen, Jianchao Yang, Pan Zhou, Zhangyang Wang

KeyError: 'unexpected key "module.conv1_1.weight" in state_dict' #30

Closed nelaturuharsha closed 4 years ago

nelaturuharsha commented 4 years ago

Traceback (most recent call last):
  File "predict.py", line 18, in <module>
    model = create_model(opt)
  File "/home/Documents//enlighten/EnlightenGAN/models/models.py", line 36, in create_model
    model.initialize(opt)
  File "/home/Documents/enlighten/EnlightenGAN/models/single_model.py", line 72, in initialize
    self.load_network(self.netG_A, 'G_A', which_epoch)
  File "/home/Documents/enlighten/EnlightenGAN/models/base_model.py", line 53, in load_network
    network.load_state_dict(torch.load(save_path))
  File "/home/anaconda3/envs/enlighten/lib/python3.5/site-packages/torch/nn/modules/module.py", line 522, in load_state_dict
    .format(name))
KeyError: 'unexpected key "module.conv1_1.weight" in state_dict'

Python: 3.5.6, PyTorch: 0.3.1, CPU

I'm facing this error when I run the test command in script/scripts.py. I would appreciate any help!

Thank you in advance!

nelaturuharsha commented 4 years ago

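Fix:

In models/base_model.py, replace load_network with the following:

from collections import OrderedDict  # add at the top of the file

def load_network(self, network, network_label, epoch_label):
    save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
    save_path = os.path.join(self.save_dir, save_filename)
    state_dict = torch.load(save_path)
    # The checkpoint was saved from a DataParallel-wrapped model,
    # so its keys start with "module."
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # remove the leading "module." (7 characters) only when present
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    network.load_state_dict(new_state_dict)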
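The key renaming in this fix can be tried on its own, without a checkpoint file. The following is a minimal sketch, not code from the repository: the helper name strip_module_prefix is hypothetical, and plain strings stand in for the tensors a real state_dict would hold.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key.

    Keys that do not start with the prefix are copied unchanged.
    (Hypothetical helper for illustration; not part of EnlightenGAN or PyTorch.)
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Stand-in values; a real checkpoint maps these names to tensors.
saved = OrderedDict([
    ("module.conv1_1.weight", "w"),
    ("module.conv1_1.bias", "b"),
])
print(list(strip_module_prefix(saved)))  # ['conv1_1.weight', 'conv1_1.bias']
```

Checking for the prefix before slicing keeps the helper safe to call on checkpoints that were saved without DataParallel, where a blind k[7:] would cut off the first seven characters of every key.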

Dear-Mr commented 4 years ago

Hello, after applying the modification above and changing new_state_dict = OrderedDict() to new_state_dict = collections.OrderedDict(), I get KeyError: 'unexpected key "conv10.bias" in state_dict'. How can I resolve this?

nelaturuharsha commented 4 years ago

Are you using Python 3.5 and PyTorch 0.3.1? And are you trying to run on CPU or GPU?

yifanjiang19 commented 4 years ago

@Dear-Mr @SreeHarshaNelaturu This problem comes from torch.nn.DataParallel(). DataParallel() wraps the original model, so the original model is only accessible as model.module. If you save a model while it is wrapped by DataParallel, the keys in the saved state_dict become module.conv.weight instead of conv.weight. To solve this problem, wrap the model in DataParallel before you load the pre-trained weights.

E.g.:

import torch

model = ResNet()  # ResNet stands for any torch.nn.Module defined elsewhere
parallel_model = torch.nn.DataParallel(model)
print(parallel_model.module)
# Here parallel_model.module is the same object as model
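
The effect on the saved keys can be seen with a small stand-in network. This is a minimal sketch assuming a recent PyTorch release (behavior on 0.3.1 may differ); torch.nn.Linear is used only as a placeholder for the real generator, and the variable names are illustrative.

```python
import torch

# A tiny stand-in network; any torch.nn.Module behaves the same way.
model = torch.nn.Linear(4, 2)
wrapped = torch.nn.DataParallel(model)

plain_keys = list(model.state_dict().keys())
wrapped_keys = list(wrapped.state_dict().keys())
print(plain_keys)    # ['weight', 'bias']
print(wrapped_keys)  # ['module.weight', 'module.bias']

# Option 1: take the state_dict of the inner module, so the keys carry
# no "module." prefix and load into an unwrapped model.
checkpoint = wrapped.module.state_dict()
fresh = torch.nn.Linear(4, 2)
fresh.load_state_dict(checkpoint)

# Option 2: a prefixed state_dict loads into a model wrapped the same way.
fresh_wrapped = torch.nn.DataParallel(torch.nn.Linear(4, 2))
fresh_wrapped.load_state_dict(wrapped.state_dict())
```

Option 2 corresponds to the suggestion above of wrapping the model before loading; option 1 avoids the prefix at save time instead.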

yifanjiang19 commented 4 years ago

I've pushed a new version. You can directly pull the new version to avoid this problem.

Dear-Mr commented 4 years ago

> Are you using Python 3.5 and PyTorch 0.3.1? And are you trying to run on CPU or GPU?

Thanks for your reply. My environment is Python 3.5, PyTorch 0.3.1, torchvision 0.2.0, and I tried to run on GPU.