j6mes / acl2021-factual-error-correction

ACL 2021
https://jamesthorne.com
Apache License 2.0

Unable to load a saved model and do predictions #4

Closed: vnik18 closed this issue 3 years ago

vnik18 commented 3 years ago

Hi,

I am using the following command to load an already-trained model and make predictions on a test file:

python -m error_correction.corrector.run \
    --model_name_or_path t5-base \
    --output_dir <path_to_trained_model_directory> \
    --do_predict \
    --test_file <path_to_test_file> \
    --train_file <path_to_train_file> \
    --val_file <path_to_val_file> \
    --reader mask \
    --mutation_source true \
    --mutation_target false \
    --labels all

I also have to pass --train_file and --val_file, since the code marks them as required parameters. However, running this command produces the following error:

Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/t5-error-correction/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/ubuntu/anaconda3/envs/t5-error-correction/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/t5-error-correction/src/error_correction/corrector/run.py", line 83, in <module>
    main(args)
  File "/home/ubuntu/t5-error-correction/src/error_correction/corrector/run.py", line 74, in main
    trainer.test(ckpt_path=checkpoints[-1])
  File "/home/ubuntu/anaconda3/envs/t5-error-correction/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1281, in test
    results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
  File "/home/ubuntu/anaconda3/envs/t5-error-correction/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1309, in __test_using_best_weights
    model.load_state_dict(ckpt['state_dict'])
  File "/home/ubuntu/anaconda3/envs/t5-error-correction/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ErrorCorrectionModule:
    Missing key(s) in state_dict: "model.decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight". 

Can you please let me know what the issue might be here? Thank you!

vnik18 commented 3 years ago

It looks like this was caused by a PyTorch version issue.
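One way to check for this kind of version mismatch is to compare the parameter names stored in the checkpoint with the names the current model expects, which is exactly what a strict load_state_dict checks before raising "Missing key(s)". The sketch below assumes a PyTorch Lightning checkpoint whose weights sit under a "state_dict" entry (as the traceback's ckpt['state_dict'] suggests); the helper names diff_state_dict_keys and compare_checkpoint are hypothetical and not part of the repository.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    expected_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, like a strict load_state_dict check.

    Hypothetical helper for illustration only.
    """
    expected = set(expected_keys)
    found = set(checkpoint_keys)
    missing = sorted(expected - found)     # the model needs these, the checkpoint lacks them
    unexpected = sorted(found - expected)  # the checkpoint has these, the model does not
    return missing, unexpected


def compare_checkpoint(model, ckpt_path: str) -> Tuple[List[str], List[str]]:
    """Diff a LightningModule's parameter names against a saved checkpoint.

    Assumes the checkpoint is a dict with a "state_dict" entry (Lightning layout).
    """
    import torch  # imported lazily so diff_state_dict_keys works without PyTorch

    # Load on CPU so no GPU is needed just to inspect key names.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    return diff_state_dict_keys(model.state_dict().keys(), ckpt["state_dict"].keys())
```

If the missing list contains keys such as model.decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight while the checkpoint otherwise matches, the checkpoint was most likely saved with a different library version than the one loading it. On recent PyTorch releases, loading a Lightning checkpoint that stores non-tensor objects may additionally require passing weights_only=False to torch.load.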