guotong1988 / NL2SQL-RULE

Content Enhanced BERT-based Text-to-SQL Generation https://arxiv.org/abs/1910.07179

ERROR when running train.py (RuntimeError: Error(s) in loading state_dict for Seq2SQL_v1) #11

Closed mellahysf closed 4 years ago

mellahysf commented 4 years ago

Hi, I am trying to run `python3 train.py --trained --bert_type_abb uS`, but it fails with this error: RuntimeError: Error(s) in loading state_dict for Seq2SQL_v1.

Details of the execution are below:

XXXX@YYYY:/mnt/c/users/administrateur/desktop/sqlova$ python3 train.py --trained --bert_type_abb uS
BERT-type: uncased_L-12_H-768_A-12
Batch_size = 32
BERT parameters:
learning rate: 1e-05
Fine-tune BERT: False
vocab size: 30522
hidden_size: 768
num_hidden_layer: 12
num_attention_heads: 12
hidden_act: gelu
intermediate_size: 3072
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
max_position_embeddings: 512
type_vocab_size: 2
initializer_range: 0.02
Load pre-trained parameters.
Seq-to-SQL: the number of final BERT layers to be used: 2
Seq-to-SQL: the size of hidden dimension = 100
Seq-to-SQL: LSTM encoding layer size = 2
Seq-to-SQL: dropout rate = 0.3
Seq-to-SQL: learning rate = 0.001
Traceback (most recent call last):
  File "train.py", line 741, in <module>
    path_model_bert=path_model_bert, path_model=path_model)
  File "train.py", line 196, in get_models
    model.load_state_dict(res['model'])
  File "/home/ysfmell/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Seq2SQL_v1:
    size mismatch for scp.W_att.weight: copying a param with shape torch.Size([103, 105]) from checkpoint, the shape in current model is torch.Size([100, 100]).
    size mismatch for scp.W_att.bias: copying a param with shape torch.Size([103]) from checkpoint, the shape in current model is torch.Size([100]).
    size mismatch for scp.W_c.weight: copying a param with shape torch.Size([100, 105]) from checkpoint, the shape in current model is torch.Size([100, 100]).
    size mismatch for scp.W_hs.weight: copying a param with shape torch.Size([100, 103]) from checkpoint, the shape in current model is torch.Size([100, 100]).
    size mismatch for sap.W_att.weight: copying a param with shape torch.Size([103, 105]) from checkpoint, the shape in current model is torch.Size([100, 100]).
    size mismatch for sap.W_att.bias: copying a param with shape torch.Size([103]) from checkpoint, the shape in current model is torch.Size([100]).
    size mismatch for sap.sa_out.0.weight: copying a param with shape torch.Size([100, 105]) from checkpoint, the shape in current model is torch.Size([100, 100]).
    size mismatch for wnp.W_att_h.weight: copying a param with shape torch.Size([1, 103]) from checkpoint, the shape in current model is torch.Size([1, 100]).
    size mismatch for wnp.W_hidden.weight: copying a param with shape torch.Size([200, 103]) from checkpoint, the shape in current model is torch.Size([200, 100]).
    size mismatch for wnp.W_cell.weight: copying a param with shape torch.Size([200, 103]) from checkpoint, the shape in current model is torch.Size([200, 100]).
    size mismatch for wnp.W_att_n.weight: copying a param with shape torch.Size([1, 105]) from checkpoint, the shape in current model is torch.Size([1, 100]).
    size mismatch for wnp.wn_out.0.weight: copying a param with shape torch.Size([100, 105]) from checkpoint, the shape in current model is torch.Size([100, 100]).
    size mismatch for wcp.W_att.weight: copying a param with shape torch.Size([103, 105]) from checkpoint, the shape in current model is torch.Size([100, 100]).
    size mismatch for wcp.W_att.bias: copying a param with shape torch.Size([103]) from checkpoint, the shape in current model is torch.Size([100]).
    size mismatch for wcp.W_c.weight: copying a param with shape torch.Size([100, 105]) from checkpoint, the shape in current model is torch.Size([100, 100]).
    size mismatch for wcp.W_hs.weight: copying a param with shape torch.Size([100, 103]) from checkpoint, the shape in current model is torch.Size([100, 100]).
    size mismatch for wvp.W_att.weight: copying a param with shape torch.Size([103, 105]) from checkpoint, the shape in current model is torch.Size([100, 100]).
    size mismatch for wvp.W_att.bias: copying a param with shape torch.Size([103]) from checkpoint, the shape in current model is torch.Size([100]).
    size mismatch for wvp.W_c.weight: copying a param with shape torch.Size([100, 105]) from checkpoint, the shape in current model is torch.Size([100, 100]).
    size mismatch for wvp.W_hs.weight: copying a param with shape torch.Size([100, 103]) from checkpoint, the shape in current model is torch.Size([100, 100]).
    size mismatch for wvp.wv_out.0.weight: copying a param with shape torch.Size([100, 405]) from checkpoint, the shape in current model is torch.Size([100, 400]).
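For anyone debugging the same mismatch, a quick way to see what the checkpoint actually contains is to load it directly and print the parameter shapes. This is a minimal sketch, not part of the repository; `model_best.pt` is a placeholder for whichever file train.py is pointed at, and the `'model'` key is taken from the traceback above (`model.load_state_dict(res['model'])`):

```python
import torch

# Placeholder path: substitute the checkpoint file that train.py loads.
ckpt = torch.load('model_best.pt', map_location='cpu')

# train.py calls model.load_state_dict(res['model']), so the weights live
# under the 'model' key of the saved dictionary.
for name, tensor in ckpt['model'].items():
    print(name, tuple(tensor.shape))
```

Comparing this listing with the shapes reported in the error makes it easy to see that the saved layers (e.g. 103 or 105 input features) do not match the freshly built model (100 features).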

guotong1988 commented 4 years ago

The model you are loading is not the right one.
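In other words, the checkpoint's layer shapes must match the model that `get_models()` builds. As a hypothetical helper (an assumption, not code from this repository), one could verify this before calling `load_state_dict()` and fail with a clearer message; it assumes the same `res['model']` layout shown in the traceback:

```python
import torch

def check_checkpoint_matches(model, path):
    """Compare the saved state_dict against the freshly built model and
    report any shape mismatches before load_state_dict() is called."""
    ckpt_state = torch.load(path, map_location='cpu')['model']
    model_state = model.state_dict()
    mismatched = [
        (name, tuple(tensor.shape), tuple(model_state[name].shape))
        for name, tensor in ckpt_state.items()
        if name in model_state and tensor.shape != model_state[name].shape
    ]
    if mismatched:
        raise ValueError(
            f'{path} was saved from a differently configured model: {mismatched}'
        )
    return ckpt_state
```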