grammarly / gector

Official implementation of the papers "GECToR – Grammatical Error Correction: Tag, Not Rewrite" (BEA-20) and "Text Simplification by Tagging" (BEA-21)
Apache License 2.0

Can't load newly trained model for prediction #117

Closed Aksh97 closed 3 years ago

Aksh97 commented 3 years ago

Hi there, thanks for this great repository.

I've trained through all 3 steps and got the files "model.th", "roberta_1_gector.th.1", and "best.th".

Now when I try to load any of these files (model.th, best.th, roberta_1_gector.th.1) with `model = GecBERTModel(vocab_path="vocabulary", model_paths=["model.th"])`, I get this error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-22-0093b4209421> in <module>()
----> 1 model= GecBERTModel(vocab_path="/content/drive/MyDrive/grammar2/gector/model/vocabulary", model_paths=["/content/drive/MyDrive/grammar2/gector/model/model.th"])

1 frames
/content/drive/My Drive/checking/gector/gector/gec_model.py in _get_model_data(model_path)
     98     def _get_model_data(model_path):
     99         model_name = model_path.split('/')[-1]
--> 100         tr_model, stf = model_name.split('_')[:2]
    101         return tr_model, int(stf)
    102 

ValueError: not enough values to unpack (expected 2, got 1)

And when I run `python predict.py --model_path model.th --input_file train.txt --output_file outputs.txt`, I get:

2021-07-20 07:17:25.390748: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
Traceback (most recent call last):
  File "predict.py", line 114, in <module>
    main(args)
  File "predict.py", line 42, in main
    weigths=args.weights)
  File "/content/drive/My Drive/grammar2/gector/gector/gec_model.py", line 90, in __init__
    model.load_state_dict(torch.load(model_path))
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 839, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Seq2Labels:
	size mismatch for tag_labels_projection_layer._module.weight: copying a param with shape torch.Size([1002, 768]) from checkpoint....
	size mismatch for tag_labels_projection_layer._module.bias: copying a param with shape torch.Size([1002]) from checkpoint

Can you please help me how to debug it?
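One way to narrow down a size mismatch like the one in the RuntimeError is to list every parameter whose shape differs between the checkpoint and the freshly built model. The sketch below is only an illustration: `find_shape_mismatches` is a hypothetical helper, not part of gector, and it works on plain name-to-shape dictionaries so it can be tried without loading any weights.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    checkpoint_shapes: Dict[str, Shape],
    model_shapes: Dict[str, Shape],
) -> List[str]:
    """Describe each parameter whose shape differs between checkpoint and model.

    Hypothetical helper for debugging; both arguments map parameter names
    to shape tuples.
    """
    report = []
    for name, ckpt_shape in sorted(checkpoint_shapes.items()):
        model_shape = model_shapes.get(name)
        if model_shape is None:
            # Parameter exists only in the checkpoint.
            report.append(f"{name}: in checkpoint only, shape {ckpt_shape}")
        elif model_shape != ckpt_shape:
            report.append(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
    return report
```

Assuming the `.th` file is a plain state dict (which the `model.load_state_dict(torch.load(model_path))` call in the traceback suggests), the checkpoint side could be built with `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}` and the model side from `model.state_dict()` in the same way. If only the `tag_labels_projection_layer` entries differ, the vocabulary used at prediction time is probably not the one used for training.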

skurzhanskyi commented 3 years ago

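Yes, that's because we expect a model file name like roberta_1_smth.th: the first field is the transformer name, and the second field (here 1) indicates whether the model was trained with the special_tokens_fix parameter. A name like model.th contains no underscore, so the split in _get_model_data yields only one value.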
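The naming convention can be checked before loading a model. The following is a minimal sketch that mirrors the parsing visible in the traceback (`_get_model_data`) and adds a clearer error message. `parse_model_name` and `conforming_name` are hypothetical helpers written for illustration, not gector functions, and the default suffix `gector` is an arbitrary choice.

```python
from typing import Tuple


def parse_model_name(model_path: str) -> Tuple[str, int]:
    """Split '<transformer>_<flag>_<suffix>' the way the traceback shows."""
    model_name = model_path.split("/")[-1]
    parts = model_name.split("_")
    if len(parts) < 2:
        raise ValueError(
            f"{model_name!r} does not match '<transformer>_<0|1>_<suffix>.th'"
        )
    tr_model, stf = parts[:2]
    return tr_model, int(stf)


def conforming_name(transformer: str, special_tokens_fix: int,
                    suffix: str = "gector") -> str:
    """Hypothetical helper: build a file name that the parser above accepts."""
    return f"{transformer}_{special_tokens_fix}_{suffix}.th"
```

For example, `parse_model_name("model/roberta_1_smth.th")` returns `("roberta", 1)`, while `parse_model_name("model/model.th")` raises a `ValueError`. Renaming or copying model.th to the result of `conforming_name("roberta", 1)` would satisfy the parser, provided the transformer and flag match how the model was actually trained.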

Aksh97 commented 3 years ago

So I just need to rename the model.th file?

skurzhanskyi commented 3 years ago

Exactly

Aksh97 commented 3 years ago

Thank you so much for your quick response. Highly Appreciated.

Also, if we want to train it for multiple languages, what changes would be required in the vocabulary? As I understand it, we could take this model and train it further, if I'm not wrong.

skurzhanskyi commented 3 years ago

I would suggest building a joint vocabulary for all the languages you want to support from the very beginning.
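A likely reason is that the size of the tag projection layer depends on the number of labels (the size mismatch earlier in this thread involves `tag_labels_projection_layer`), so changing the label set after training would change the layer shape. As a rough sketch, a joint label list could be built by merging per-language label lists while keeping first-seen order. `merge_label_vocabularies` is a hypothetical helper, and how the per-language lists are produced is left as an assumption.

```python
from typing import Iterable, List


def merge_label_vocabularies(
    per_language_labels: Iterable[Iterable[str]],
) -> List[str]:
    """Merge several label lists into one, dropping duplicates and blanks.

    Hypothetical helper: labels keep the order in which they first appear.
    """
    seen = set()
    merged = []
    for labels in per_language_labels:
        for label in labels:
            label = label.strip()
            if label and label not in seen:
                seen.add(label)
                merged.append(label)
    return merged
```

For instance, merging `["$KEEP", "$DELETE", "$APPEND_the"]` with `["$KEEP", "$APPEND_der"]` gives `["$KEEP", "$DELETE", "$APPEND_the", "$APPEND_der"]`. The tags here are only illustrative examples of edit labels.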

Aksh97 commented 3 years ago

Thanks a lot. 👍