dwadden / multivers

Code and model checkpoints for the MultiVerS model for scientific claim verification.
MIT License

Issue with updating state_dict #7

Closed. Gab123789 closed this issue 1 year ago.

Gab123789 commented 1 year ago

I cloned from scratch and am encountering this error:

Traceback (most recent call last):
  File "multivers/predict.py", line 109, in <module>
    main()
  File "multivers/predict.py", line 101, in main
    predictions = get_predictions(args)
  File "multivers/predict.py", line 36, in get_predictions
    model = MultiVerSModel.load_from_checkpoint(checkpoint_path=args.checkpoint_path)
  File "/opt/miniconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 156, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/opt/miniconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 198, in _load_model_state
    model = cls(**_cls_kwargs)
  File "/Users/anna/Desktop/mv/multivers/multivers/model.py", line 86, in __init__
    self.encoder = self._get_encoder(hparams)
  File "/Users/anna/Desktop/mv/multivers/multivers/model.py", line 170, in _get_encoder
    new_state_dict[name] = orig_state_dict[name]
KeyError: 'embeddings.position_ids'

This is from line 168 in model.py: ADD_TO_CHECKPOINT = ["embeddings.position_ids"]

So I tried changing this to ADD_TO_CHECKPOINT = ["embeddings.position_embeddings.weight"], as this seemed to be the item missing from the Hugging Face state_dict.

However, I then encountered this error:

Traceback (most recent call last):
  File "multivers/predict.py", line 109, in <module>
    main()
  File "multivers/predict.py", line 101, in main
    predictions = get_predictions(args)
  File "multivers/predict.py", line 36, in get_predictions
    model = MultiVerSModel.load_from_checkpoint(checkpoint_path=args.checkpoint_path)
  File "/opt/miniconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 156, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/opt/miniconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 204, in _load_model_state
    model.load_state_dict(checkpoint['state_dict'], strict=strict)
  File "/opt/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for MultiVerSModel:
	Unexpected key(s) in state_dict: "encoder.embeddings.position_ids"

I've been unable to resolve this so far. The same thing happens with all checkpoints (e.g. scifact, healthvers). Thanks!
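For anyone hitting the same RuntimeError: a strict state_dict load fails when the checkpoint holds a key the model does not define. As a rough diagnostic sketch (not code from this repository), the key names on each side can be compared as plain sets. `diff_state_dict_keys` is a hypothetical helper, and the key lists below are illustrative stand-ins for what `model.state_dict().keys()` and `checkpoint["state_dict"].keys()` would return.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, in the sense used by strict loading.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not define
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Illustrative key names only (assumption): a model built without a
# position_ids buffer, and a checkpoint that was saved with one.
model_keys = [
    "encoder.embeddings.word_embeddings.weight",
    "encoder.embeddings.position_embeddings.weight",
]
checkpoint_keys = model_keys + ["encoder.embeddings.position_ids"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If the two sides disagree on keys like this, one possible cause is that the installed library version builds a different module layout than the one the checkpoint was saved with, so checking the environment is a reasonable next step before editing model.py.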

dwadden commented 1 year ago

Can you provide me with the command you're running to get this error so I can reproduce?

Gab123789 commented 1 year ago

python multivers/predict.py

model is scifact.ckpt, but same result when using healthvers

dwadden commented 1 year ago

Got it - it may just be an issue of calling the wrong script. Can you follow the instructions in the inference section of the README and let me know what happens? In particular, try calling bash script/predict.sh scifact rather than running the Python script.

Gab123789 commented 1 year ago

My bad, Dave. I had installed sentence-transformers, which isn't compatible with transformers 4.2.2. I uninstalled it and everything is working now. Thanks for the help, and sorry for taking up your time!
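A quick way to catch this kind of environment problem is to print what is actually installed before running inference. The sketch below uses only the standard library (`importlib.metadata`, available from Python 3.8). `installed_version` and `environment_report` are hypothetical helper names, and the pinned version and conflicting package are taken from this thread rather than from the repository's requirements file.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string for a package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def environment_report(pinned, conflicting):
    """List human-readable problems found in the current environment.

    pinned:      mapping of package name -> exact version expected
    conflicting: package names that should not be installed at all
    """
    problems = []
    for package, expected in pinned.items():
        found = installed_version(package)
        if found is None:
            problems.append(f"{package} is not installed (expected {expected})")
        elif found != expected:
            problems.append(f"{package} is {found} (expected {expected})")
    for package in conflicting:
        found = installed_version(package)
        if found is not None:
            problems.append(f"{package} {found} is installed and may conflict")
    return problems


# Values below come from this thread (assumption: they match your setup).
for line in environment_report(
    pinned={"transformers": "4.2.2"},
    conflicting=["sentence-transformers"],
):
    print(line)
```

An empty report only means these particular checks passed; it does not guarantee that the rest of the environment matches what the checkpoints expect.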

dwadden commented 1 year ago

No problem, I'm glad it worked!