ChengpengChen / RepGhost

RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization
MIT License

RuntimeError('Error(s) in loading state_dict #8

Closed: Choneke closed this issue 1 year ago.

Choneke commented 1 year ago

Hi, thank you for the great work. I used the pre-trained models to train on my dataset. However, I encountered the following error when running the code.

    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    RuntimeError: Error(s) in loading state_dict for RepGhostNet:
            Unexpected key(s) in state_dict: "epoch", "arch", "state_dict", "optimizer", "version", "args", "amp_scaler", "state_dict_ema", "metric".
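The unexpected keys in this error are not parameter names; they look like the top-level entries of a full training checkpoint, which suggests `torch.load` returned a wrapper dict rather than the weights themselves. A minimal sketch of a check, assuming the file has already been loaded into a Python mapping; `looks_like_training_checkpoint` and `non_parameter_keys` are hypothetical helpers, not part of RepGhost or PyTorch:

```python
from collections.abc import Mapping

# Top-level keys reported in the error above. Their presence suggests a
# training checkpoint (weights plus metadata), not a bare state_dict.
CHECKPOINT_KEYS = {"epoch", "arch", "state_dict", "optimizer", "version",
                   "args", "amp_scaler", "state_dict_ema", "metric"}


def looks_like_training_checkpoint(obj) -> bool:
    """Heuristic: True if obj is a mapping with a nested 'state_dict' mapping.

    A bare state_dict maps dotted parameter names ('<layer>.weight', ...)
    to tensors, so it normally has no nested 'state_dict' entry.
    """
    if not isinstance(obj, Mapping):
        return False
    return isinstance(obj.get("state_dict"), Mapping)


def non_parameter_keys(obj):
    """Return, sorted, the top-level keys that match known checkpoint metadata."""
    return sorted(k for k in obj if k in CHECKPOINT_KEYS)
```

In practice `obj` would come from something like `torch.load(model_dict_file, map_location="cpu")`; if the check returns True, the weights have to be taken from inside the checkpoint before they are merged into the model.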

I use the following code, which I have also used with other models, and it works fine there.

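    # Fragment of the model-building function; requires `import os` and `import torch`.
    if pretrained:
        model_dict = model.state_dict()
        model_dict_file = os.path.join('./weights', experiment_name, 'training',
                                       'repghostnet_0_5x_43M_66.95.pth.tar')
        if os.path.exists(model_dict_file):
            pretrained_dict = torch.load(model_dict_file)
        # else:
        #     pretrained_dict = model_zoo.load_url(model_urls['repghostnet_0_5x_43M_66.95'])
        else:
            # Fail clearly instead of hitting a NameError on pretrained_dict below.
            raise FileNotFoundError(model_dict_file)
        # Drop the classifier weights so the head is retrained on the new dataset.
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k.split('.')[0] != 'classifier'}
        # for k in pretrained_dict: print(k.split('.')[0])
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    return model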

Could you help me solve this issue?
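The merge in the code above relies on `model_dict.update(pretrained_dict)`, which copies every entry, including keys the model does not have; that is how the checkpoint's top-level keys reached `load_state_dict` and triggered the error. A minimal sketch of a stricter filter, assuming the values expose a `.shape` attribute as PyTorch tensors do; `filter_pretrained` is a hypothetical helper, not part of RepGhost:

```python
def filter_pretrained(model_state, pretrained, skip_prefixes=("classifier",)):
    """Keep only the pretrained entries that the model can actually load.

    An entry is kept if its top-level name (text before the first '.') is not
    in `skip_prefixes`, the key exists in `model_state`, and both values have
    the same `.shape`. Returns (kept, skipped): a dict of compatible entries
    and a list of the keys that were left out.
    """
    kept, skipped = {}, []
    for key, value in pretrained.items():
        compatible = (
            key.split(".")[0] not in skip_prefixes
            and key in model_state
            and getattr(value, "shape", None)
            == getattr(model_state[key], "shape", None)
        )
        if compatible:
            kept[key] = value
        else:
            skipped.append(key)
    return kept, skipped
```

`model_dict.update(kept)` would then replace the original update line. With the wrapped checkpoint from this issue, `kept` would come back empty, because none of its top-level keys ("epoch", "state_dict", ...) are parameter names; checking for that before calling `load_state_dict` surfaces the problem instead of silently training from random initialization.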

ChengpengChen commented 1 year ago

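The .pth.tar file is a full training checkpoint (the unexpected keys in the error, such as "epoch", "optimizer" and "amp_scaler", are its top-level entries), so you should initialize the model from the weights stored inside it. For example:

    pretrained_dict = torch.load(model_dict_file)['state_dict']

or, to use the EMA model:

    pretrained_dict = torch.load(model_dict_file)['state_dict_ema']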
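The fix above can be wrapped in a small helper so the same loading code works whether a file holds a full training checkpoint or a bare state_dict. A minimal sketch; `extract_state_dict` and its `use_ema` flag are hypothetical names, and returning the object unchanged when it has no 'state_dict' entry is an assumption about how bare weight files are stored:

```python
from collections.abc import Mapping


def extract_state_dict(checkpoint, use_ema=False):
    """Return the parameter dict held by a checkpoint.

    Takes checkpoint['state_dict'], or checkpoint['state_dict_ema'] when
    use_ema is True. If there is no 'state_dict' entry, the object is assumed
    to be a bare state_dict already and is returned unchanged.
    """
    key = "state_dict_ema" if use_ema else "state_dict"
    if isinstance(checkpoint, Mapping) and key in checkpoint:
        return checkpoint[key]
    if use_ema:
        # Asked for EMA weights but the file does not contain them.
        raise KeyError("checkpoint has no 'state_dict_ema' entry")
    return checkpoint
```

In the original code this would replace the `torch.load` line, e.g. `pretrained_dict = extract_state_dict(torch.load(model_dict_file), use_ema=True)`. The thread does not say whether the plain or the EMA weights work better for fine-tuning, so the choice is left to the caller.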

Choneke commented 1 year ago

Thank you very much. Modifying my code as suggested above solved the problem.