rpryzant / delete_retrieve_generate

PyTorch implementation of the Delete, Retrieve, Generate style transfer algorithm
MIT License

I replaced "delete" with "sq2sq" as the model_type in yelp_config.json and then got errors. Do I need to modify train.py when I change the model_type? #31

Closed: wasedaward closed this issue 3 years ago

wasedaward commented 3 years ago
@exp16:~/delete_retrieve_generate$ python3 train.py --config yelp_config.json --bleu
2021-06-10 09:58:27,484 - INFO - Reading data ...
2021-06-10 09:58:45,074 - INFO - ...done!
/home/Ren/anaconda3/envs/drg/lib/python3.6/site-packages/torch/nn/modules/rnn.py:54: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1
  "num_layers={}".format(dropout, num_layers))
2021-06-10 09:58:46,409 - INFO - MODEL HAS 9050117 params
Traceback (most recent call last):
  File "train.py", line 114, in <module>
    checkpoint_dir=working_dir)
  File "/home/Ren/delete_retrieve_generate/src/models.py", line 38, in attempt_load_model
    model.load_state_dict(torch.load(checkpoint_path))
  File "/home/Ren/anaconda3/envs/drg/lib/python3.6/site-packages/torch/nn/modules/module.py", line 777, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SeqModel:
        Unexpected key(s) in state_dict: "attribute_embedding.weight". 
        size mismatch for c_bridge.weight: copying a param with shape torch.Size([512, 640]) from checkpoint, the shape in current model is torch.Size([512, 512]).
        size mismatch for h_bridge.weight: copying a param with shape torch.Size([512, 640]) from checkpoint, the shape in current model is torch.Size([512, 512]).
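One way to see exactly which parameters disagree before calling load_state_dict is to compare the parameter names and shapes on both sides. The sketch below is hypothetical (diff_state_dicts is not part of the repo) and works on plain name-to-shape mappings so it runs without PyTorch. With real objects, each mapping could be built as {k: tuple(v.shape) for k, v in sd.items()}, where sd is torch.load(checkpoint_path) or model.state_dict(). The bridge shapes below come from the traceback; the attribute_embedding shape is not shown there, so the value used is a placeholder.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dicts(checkpoint: Dict[str, Shape],
                     model: Dict[str, Shape]) -> Dict[str, List]:
    """Hypothetical helper: report the same three kinds of problems that
    load_state_dict complains about (unexpected keys, missing keys, and
    size mismatches) from name -> shape mappings."""
    # Keys saved in the checkpoint that the current model does not define.
    unexpected = sorted(k for k in checkpoint if k not in model)
    # Keys the current model defines that the checkpoint does not contain.
    missing = sorted(k for k in model if k not in checkpoint)
    # Keys present on both sides whose shapes differ.
    size_mismatch = sorted(
        (k, checkpoint[k], model[k])
        for k in checkpoint
        if k in model and checkpoint[k] != model[k]
    )
    return {"unexpected": unexpected, "missing": missing,
            "size_mismatch": size_mismatch}


checkpoint_shapes = {
    "attribute_embedding.weight": (2, 128),  # placeholder shape, not in the log
    "c_bridge.weight": (512, 640),            # from the traceback
    "h_bridge.weight": (512, 640),            # from the traceback
}
model_shapes = {
    "c_bridge.weight": (512, 512),            # from the traceback
    "h_bridge.weight": (512, 512),            # from the traceback
}

report = diff_state_dicts(checkpoint_shapes, model_shapes)
print(report)
```

Here the report lists attribute_embedding.weight as unexpected and both bridge weights as size mismatches (640 input columns in the checkpoint versus 512 in the current model), matching the three lines of the RuntimeError.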
rpryzant commented 3 years ago

Maybe you misspelled the model_type? Accepted values are "seq2seq", "delete_retrieve", and "delete". I've updated the error message to be more useful: https://github.com/rpryzant/delete_retrieve_generate/commit/1dac1d428823e18832b3f4cb832bff3d17caafa3
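A minimal sketch of an up-front check on model_type, using the three accepted values from the reply. This is not the code from the linked commit: check_model_type is a hypothetical name, and the assumption that model_type sits under a "model" key in yelp_config.json is not confirmed in this thread, so adjust the lookup to match your config.

```python
import json

# Accepted values, as listed in the reply above.
VALID_MODEL_TYPES = ("seq2seq", "delete_retrieve", "delete")


def check_model_type(config: dict) -> str:
    """Hypothetical helper: fail early with a readable message if
    model_type is not one of the accepted values."""
    # Assumption: the setting lives at config["model"]["model_type"].
    model_type = config.get("model", {}).get("model_type")
    if model_type not in VALID_MODEL_TYPES:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of "
            f"{', '.join(VALID_MODEL_TYPES)}"
        )
    return model_type


# In practice the config would come from json.load(open("yelp_config.json")).
config = json.loads('{"model": {"model_type": "sq2sq"}}')
try:
    check_model_type(config)
except ValueError as err:
    print(err)  # Unknown model_type 'sq2sq'; expected one of seq2seq, delete_retrieve, delete
```

Running a check like this before the model is built means a typo surfaces as a clear message instead of a state_dict error deep inside model loading.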