Mindful / MWEasWSD

Repo for the paper "MWE as WSD: Solving Multi-Word Expression Identification with Word Sense Disambiguation"
GNU Affero General Public License v3.0

Unexpected keys in state dict #2

Closed: Yusuke196 closed this issue 2 months ago.

Yusuke196 commented 3 months ago

I downloaded https://huggingface.co/Jotanner/mweaswsd-ft and tried to evaluate the model by running

```shell
python scripts/training/wsd_eval.py --data data/WSD_Evaluation_Framework/Evaluation_Datasets/ALL/ --model $MODEL --batch_size 2
```

and hit the following error:

```
Limiting possible wordnet results to those in key candidate dictionary
build candidate dict: 100%|██████████| 155287/155287 [00:30<00:00, 5127.86it/s]
build synset key dict: 100%|██████████| 155287/155287 [00:21<00:00, 7106.39it/s]
Lightning automatically upgraded your loaded checkpoint from v1.5.10 to v1.9.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file mweaswsd-ft/model.ckpt`
/cl/work2/yusuke-i/conda/envs/mwe-as-wsd/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Traceback (most recent call last):
  File "/project/cl-work2/yusuke-i/ghq/github.com/Mindful/MWEasWSD/scripts/training/wsd_eval.py", line 106, in <module>
    main()
  File "/project/cl-work2/yusuke-i/ghq/github.com/Mindful/MWEasWSD/scripts/training/wsd_eval.py", line 38, in main
    model = ContextDictionaryBiEncoder.load_from_checkpoint(str(args.model))
  File "/cl/work2/yusuke-i/conda/envs/mwe-as-wsd/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 139, in load_from_checkpoint
    return _load_from_checkpoint(  # type: ignore[return-value]
  File "/cl/work2/yusuke-i/conda/envs/mwe-as-wsd/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 188, in _load_from_checkpoint
    return _load_state(cls, checkpoint, strict=strict, **kwargs)
  File "/cl/work2/yusuke-i/conda/envs/mwe-as-wsd/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 247, in _load_state
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
  File "/cl/work2/yusuke-i/conda/envs/mwe-as-wsd/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ContextDictionaryBiEncoder:
        Unexpected key(s) in state_dict: "context_encoder.embeddings.position_ids", "definition_encoder.embeddings.position_ids".
```

It seems that the checkpoint on Hugging Face contains the keys "context_encoder.embeddings.position_ids" and "definition_encoder.embeddings.position_ids". The model built by the current code doesn't expect these keys, and a checkpoint it produces doesn't contain them. Is there an easy way to fix this?
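The error comes from PyTorch's `load_state_dict` with `strict=True`. That call raises a `RuntimeError` if the checkpoint contains keys the model does not have ("unexpected") or if the model has keys the checkpoint lacks ("missing"). The sketch below reproduces that comparison without Lightning. It assumes you already have the two collections of key names, for example from `model.state_dict().keys()` and `torch.load(path, map_location="cpu")["state_dict"].keys()`. The helper `diff_state_dict_keys` is hypothetical and is not part of the repo, PyTorch, or Lightning:

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key names, each sorted.

    missing:    keys the model expects but the checkpoint lacks.
    unexpected: keys the checkpoint has but the model does not expect.
    """
    model_set = set(model_keys)
    ckpt_set = set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


if __name__ == "__main__":
    # Toy key names modeled on the error message above (hypothetical weights).
    model_keys = [
        "context_encoder.embeddings.word_embeddings.weight",
        "definition_encoder.embeddings.word_embeddings.weight",
    ]
    ckpt_keys = model_keys + [
        "context_encoder.embeddings.position_ids",
        "definition_encoder.embeddings.position_ids",
    ]
    missing, unexpected = diff_state_dict_keys(model_keys, ckpt_keys)
    print("missing:", missing)
    print("unexpected:", unexpected)
```

An empty "missing" list together with only `position_ids` entries under "unexpected" would match the situation in the traceback.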

If not, I guess I'll have to do all the fine-tuning myself. But if possible, I'd like to run the model from Hugging Face to make sure I'm using the same model as in the paper.

Mindful commented 2 months ago

I initially thought this was an issue with our model head configs. After inspecting the model state dict and poking around a little, though, I think it might be a 🤗 Transformers versioning issue. People have reported similar errors here: https://github.com/huggingface/transformers/issues/24921

Can you try these two things?

1) Make sure your transformers version matches the one in the requirements. If you have 4.1.x, you could try rolling back to the earliest 4.1 release.
2) See if you can just delete the position_ids keys from the model state dict. I looked at the value, and it's just a range from 0 to the max position. It may not need to be there, and the model might run fine without it.
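Before deleting the keys, you can double-check the claim that the stored buffer is just a range. Convert it to nested lists with `tensor.tolist()` and compare each row against `range`. The sketch below assumes the BERT-style buffer shape `(1, max_position_embeddings)`. The helper `is_default_position_ids` is hypothetical and is not part of the repo or of Transformers:

```python
from typing import Sequence


def is_default_position_ids(rows: Sequence[Sequence[int]]) -> bool:
    """True if every row is exactly 0, 1, ..., len(row) - 1.

    `rows` is the position_ids buffer converted with tensor.tolist(),
    e.g. [[0, 1, ..., 511]] for a (1, 512) buffer.
    """
    if not rows:
        return False
    return all(list(row) == list(range(len(row))) for row in rows)


# Possible usage with torch (not executed here; assumes the checkpoint layout
# implied by the traceback, i.e. a Lightning checkpoint with a "state_dict" entry):
#   sd = torch.load("mweaswsd-ft/model.ckpt", map_location="cpu")["state_dict"]
#   is_default_position_ids(sd["context_encoder.embeddings.position_ids"].tolist())
```

If this returns True for both encoders, the buffers carry no learned information. That supports the idea that dropping them is safe.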

Yusuke196 commented 2 months ago
  1. I had transformers 4.41.2. I set transformers==4.1.0 (or transformers==4.1.*) in requirements.txt and ran pip install -r requirements.txt, but the installation failed while building the wheel for tokenizers. I'm trying to fix this but haven't succeeded so far (I understand this is a problem on my side).

  2. I successfully ran the model (mweaswsd-ft), with the following output:

    ```
    Writing output
    Running scorer -------
    P=      74.0%
    R=      74.0%
    F1=     74.0%
    ----------------------
    Done
    ```

    I did this by adding the following method to ContextDictionaryBiEncoder:

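    ```python
    def on_load_checkpoint(self, checkpoint):
        # Lightning calls this hook before load_state_dict, so removing the
        # position_ids buffers here avoids the "Unexpected key(s)" error.
        keys_to_delete = [
            'context_encoder.embeddings.position_ids',
            'definition_encoder.embeddings.position_ids',
        ]
        for key in keys_to_delete:
            if key in checkpoint['state_dict']:
                del checkpoint['state_dict'][key]
    ```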

    The score seems to match the one reported in the paper. Thank you!
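Below is a slightly more general variant of the same hook. It assumes, as suspected above, that newer Transformers releases no longer keep `position_ids` in the state dict. Under that assumption, any key ending in `embeddings.position_ids` is a default range buffer that can be dropped. Matching by suffix avoids hard-coding the two encoder prefixes. The helper `drop_position_ids` is hypothetical:

```python
from typing import Any, Dict, List

POSITION_IDS_SUFFIX = "embeddings.position_ids"


def drop_position_ids(state_dict: Dict[str, Any]) -> List[str]:
    """Delete every key ending in POSITION_IDS_SUFFIX; return the removed keys."""
    # Collect first, then delete, so the dict is not mutated while iterating.
    removed = [key for key in state_dict if key.endswith(POSITION_IDS_SUFFIX)]
    for key in removed:
        del state_dict[key]
    return removed


# Inside ContextDictionaryBiEncoder (a LightningModule), the hook would become:
#
#     def on_load_checkpoint(self, checkpoint):
#         drop_position_ids(checkpoint["state_dict"])
```

Passing strict=False to load_from_checkpoint (the traceback shows that argument being forwarded to load_state_dict) would also get past the error. However, it silently ignores every missing or unexpected key, not just these two. Removing only the known buffers keeps any other mismatch visible.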

Mindful commented 2 months ago

Glad you were able to resolve the issue with the second suggestion (deleting the keys)! In that case, I wouldn't worry about changing the transformers version. Other things may also have changed on HF's end (for example, maybe a config file gets downloaded separately from the pinned package versions; I'm not sure).

In any case, it seems the issue is solved for now.