microsoft / IRNet

An algorithm for cross-domain NL2SQL
MIT License

RuntimeError: Error(s) in loading state_dict for IRNet #50

Open anshudaur opened 3 years ago


Hi all, I am getting the error below while loading the trained IRNet model. I am using the final end_model.model for evaluation:

    /python3.8/site-packages/torch/nn/modules/module.py", line 846, in load_state_dict
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    RuntimeError: Error(s) in loading state_dict for IRNet:
        size mismatch for lf_decoder_lstm.weight_ih: copying a param with shape torch.Size([1200, 556]) from checkpoint, the shape in current model is torch.Size([1200, 492]).
        size mismatch for sketch_decoder_lstm.weight_ih: copying a param with shape torch.Size([1200, 556]) from checkpoint, the shape in current model is torch.Size([1200, 492]).
        size mismatch for type_embed.weight: copying a param with shape torch.Size([10, 128]) from checkpoint, the shape in current model is torch.Size([10, 64]).
        size mismatch for att_project.weight: copying a param with shape torch.Size([300, 428]) from checkpoint, the shape in current model is torch.Size([300, 364]).
    Namespace(data_path='./spider', input_path='predict_lf.json', output_path='./results/irnet')
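Every mismatch in this traceback differs by exactly 64 along one dimension: type_embed.weight is 128 wide in the checkpoint but 64 in the current model, and the input widths of lf_decoder_lstm and sketch_decoder_lstm (556 vs 492) and of att_project (428 vs 364) each shrink by the same 64. That pattern suggests the evaluation model was constructed with a smaller type-embedding size than the one used at training time, rather than a corrupted checkpoint. A quick way to confirm this is to list every parameter whose shape differs. The sketch below uses a hypothetical helper, find_shape_mismatches, which is not part of IRNet or PyTorch; it only assumes that both arguments map parameter names to objects with a .shape attribute, which holds for a state dict loaded with torch.load and for model.state_dict().

```python
from typing import Any, List, Mapping, Tuple


def find_shape_mismatches(
    checkpoint: Mapping[str, Any], model_state: Mapping[str, Any]
) -> List[Tuple[str, tuple, tuple]]:
    """Return (name, checkpoint_shape, model_shape) for every parameter that
    appears in both mappings but with a different shape.

    Hypothetical helper: works with any mapping whose values expose .shape,
    e.g. PyTorch state dicts (torch.Size is a tuple subclass).
    """
    mismatches = []
    for name, ckpt_value in checkpoint.items():
        if name not in model_state:
            # Missing/unexpected keys are a separate problem; skip them here.
            continue
        ckpt_shape = tuple(ckpt_value.shape)
        model_shape = tuple(model_state[name].shape)
        if ckpt_shape != model_shape:
            mismatches.append((name, ckpt_shape, model_shape))
    return mismatches
```

Called as find_shape_mismatches(torch.load(path, map_location="cpu"), model.state_dict()), assuming the saved file is a plain state dict, it should list the same four parameters as the traceback, which makes it easy to compare the model options used for training and for evaluation.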

I have tried the following:

  1. model.load_state_dict(pretrained_model, strict=False)
  2. Copying the checkpoint into a new OrderedDict and printing any key the model does not have:

    from collections import OrderedDict

    model_keys = model.state_dict().keys()
    new_state_dict = OrderedDict()
    for k, v in pretrained_model.items():
        if k not in model_keys:
            # del pretrained_model[k]
            print("actual key ::", k)
        new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
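Neither attempt can succeed with this checkpoint. In PyTorch, strict=False only tolerates missing and unexpected keys; a parameter present on both sides with a different shape is still reported as a size mismatch, so load_state_dict raises the same RuntimeError. The OrderedDict copy keeps every key and tensor unchanged, so it hits the same check. To load only the weights that fit, the mismatched entries have to be dropped explicitly. The sketch below uses a hypothetical helper, split_compatible (not an IRNet or PyTorch function), and assumes both mappings hold objects with a .shape attribute, as PyTorch state dicts do.

```python
from typing import Any, Dict, List, Mapping, Tuple


def split_compatible(
    checkpoint: Mapping[str, Any], model_state: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[str]]:
    """Split a checkpoint into entries whose name and shape match the model,
    plus the names of entries that were skipped.

    Hypothetical helper: values only need a .shape attribute.
    """
    compatible: Dict[str, Any] = {}
    skipped: List[str] = []
    for name, value in checkpoint.items():
        target = model_state.get(name)
        # Keep an entry only if the model has the same key with the same shape.
        if target is not None and tuple(target.shape) == tuple(value.shape):
            compatible[name] = value
        else:
            skipped.append(name)
    return compatible, skipped
```

Running compatible, skipped = split_compatible(pretrained_model, model.state_dict()) and then model.load_state_dict(compatible, strict=False) loads whatever fits, but every skipped layer keeps its random initialization, so evaluation results will not reflect the trained model. The more likely fix is to build the evaluation model with the same options used for training (judging by the checkpoint shapes, a type-embedding size of 128); the exact option name should be checked against the training and evaluation scripts.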

Any suggestions would be helpful. Thanks