Babelscape / rebel

REBEL is a seq2seq model that simplifies Relation Extraction (EMNLP 2021).

Pytorch shape mismatch #46

Iqra840 closed this issue 2 years ago

Iqra840 commented 2 years ago

When I run model_saving.py to save the model in Hugging Face transformers format, I get the following error and am not sure how to resolve it. Is there an issue with my training, or is one of my packages incompatible? Thank you for your help!


 python rebel/src/model_saving.py

File "rebel/src/model_saving.py", line 25, in <module>
    model = pl_module.load_from_checkpoint(checkpoint_path = 'rebel/outputs/2022-10-17/08-41-38/experiments/docred/epoch=13-step=1315.ckpt', config = config, tokenizer = tokenizer, model = model)
  File "/anaconda/envs/worker_venv/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 159, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/anaconda/envs/worker_venv/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 205, in _load_model_state
    model.load_state_dict(checkpoint['state_dict'], strict=strict)
  File "/anaconda/envs/worker_venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BasePLModule:
        size mismatch for model.final_logits_bias: copying a param with shape torch.Size([1, 50278]) from checkpoint, the shape in current model is torch.Size([1, 50268]).
        size mismatch for model.model.shared.weight: copying a param with shape torch.Size([50278, 1024]) from checkpoint, the shape in current model is torch.Size([50268, 1024]).
        size mismatch for model.model.encoder.embed_tokens.weight: copying a param with shape torch.Size([50278, 1024]) from checkpoint, the shape in current model is torch.Size([50268, 1024]).
        size mismatch for model.model.decoder.embed_tokens.weight: copying a param with shape torch.Size([50278, 1024]) from checkpoint, the shape in current model is torch.Size([50268, 1024]).
        size mismatch for model.lm_head.weight: copying a param with shape torch.Size([50278, 1024]) from checkpoint, the shape in current model is torch.Size([50268, 1024]).

Packages installed:
absl-py                 1.3.0
aiohttp                 3.8.3
aiosignal               1.2.0
altair                  4.2.0
antlr4-python3-runtime  4.8
arrow                   1.2.3
astor                   0.8.1
async-timeout           4.0.2
attrs                   22.1.0
backports.zoneinfo      0.2.1
base58                  2.1.1
blinker                 1.5
bravado                 11.0.3
bravado-core            5.17.1
cachetools              5.2.0
certifi                 2022.9.24
charset-normalizer      2.1.1
click                   7.1.2
configparser            5.3.0
datasets                1.3.0
decorator               5.1.1
dill                    0.3.5.1
docker-pycreds          0.4.0
entrypoints             0.4
filelock                3.8.0
fqdn                    1.5.1
frozenlist              1.3.1
fsspec                  2022.8.2
future                  0.18.2
gitdb                   4.0.9
GitPython               3.1.29
google-auth             2.12.0
google-auth-oauthlib    0.4.6
grpcio                  1.49.1
huggingface-hub         0.10.1
hydra-core              1.0.6
idna                    3.4
importlib-metadata      5.0.0
importlib-resources     5.10.0
isoduration             20.11.0
Jinja2                  3.1.2
joblib                  1.2.0
jsonpointer             2.3
jsonref                 0.3.0
jsonschema              4.16.0
Markdown                3.4.1
MarkupSafe              2.1.1
monotonic               1.6
msgpack                 1.0.4
multidict               6.0.2
multiprocess            0.70.13
neptune-client          0.5.1
nltk                    3.7
numpy                   1.23.4
oauthlib                3.2.1
omegaconf               2.0.6
packaging               21.3
pandas                  1.5.0
pathtools               0.1.2
Pillow                  9.2.0
pip                     22.2.2
pkgutil_resolve_name    1.3.10
portalocker             2.5.1
promise                 2.3
protobuf                3.19.6
psutil                  5.8.0
pyarrow                 9.0.0
pyasn1                  0.4.8
pyasn1-modules          0.2.8
pydeck                  0.8.0b4
pyDeprecate             0.3.2
PyJWT                   2.5.0
pyparsing               3.0.9
pyrsistent              0.18.1
python-dateutil         2.8.2
pytorch-lightning       1.1.7
pytz                    2022.4
pytz-deprecation-shim   0.1.0.post0
PyYAML                  6.0
regex                   2022.9.13
requests                2.28.1
requests-oauthlib       1.3.1
rfc3339-validator       0.1.4
rfc3987                 1.3.8
rouge-score             0.0.4
rsa                     4.9
sacrebleu               1.5.0
sentry-sdk              1.9.10
setuptools              63.4.1
shortuuid               1.0.9
simplejson              3.17.6
six                     1.16.0
smmap                   5.0.0
streamlit               0.82.0
subprocess32            3.5.4
swagger-spec-validator  2.7.6
tensorboard             2.10.1
tensorboard-data-server 0.6.1
tensorboard-plugin-wit  1.8.1
tokenizers              0.12.1
toml                    0.10.2
toolz                   0.12.0
torch                   1.12.1
torchmetrics            0.10.0
tornado                 6.2
tqdm                    4.64.1
transformers            4.23.1
typing_extensions       4.4.0
tzdata                  2022.5
tzlocal                 4.2
uri-template            1.2.0
urllib3                 1.26.12
validators              0.20.0
wandb                   0.10.26
watchdog                2.1.9
webcolors               1.12
websocket-client        1.4.1
Werkzeug                2.2.2
wheel                   0.37.1
xxhash                  3.0.0
yarl                    1.8.1
zipp                    3.9.0

Iqra840 commented 2 years ago

Edit: It works for now because I set the size parameter in the reshape method to the checkpoint's vocabulary size rather than the tokeniser's length, but could you explain why this might be happening?
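
For anyone hitting the same error, here is a minimal sketch of that workaround. It assumes the "reshape method" is the `resize_token_embeddings` call on the Hugging Face model and that `model.lm_head.weight` (a key taken from the traceback above) is present in the checkpoint. The helper name `checkpoint_vocab_size` is hypothetical and not part of the repository.

```python
from types import SimpleNamespace
from typing import Any, Mapping


def checkpoint_vocab_size(state_dict: Mapping[str, Any],
                          key: str = "model.lm_head.weight") -> int:
    """Return the number of rows of the output projection stored in the
    checkpoint, i.e. the vocabulary size the model had when it was saved."""
    return int(state_dict[key].shape[0])


# Stand-in for a real tensor so the sketch runs without torch installed;
# the shape is the one reported in the traceback.
fake_state_dict = {"model.lm_head.weight": SimpleNamespace(shape=(50278, 1024))}
vocab_size = checkpoint_vocab_size(fake_state_dict)

# Assumed usage inside model_saving.py (not verified against the repository):
#   import torch
#   checkpoint = torch.load(checkpoint_path, map_location="cpu")
#   vocab_size = checkpoint_vocab_size(checkpoint["state_dict"])
#   model.resize_token_embeddings(vocab_size)  # instead of len(tokenizer)
```

Resizing to the checkpoint's size only makes the weights load. If the tokeniser is still missing the extra tokens, the saved model and tokeniser will not agree on those ids.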

LittlePea13 commented 2 years ago

Hi, this happens because the number of tokens in the loaded checkpoint differs from that of the model defined in the pl_module. The vocabulary was probably smaller (50268) when the pl_module was instantiated than when the checkpoint was saved (50278). Sorry for the inconvenience. Either initialising the model with the same number of tokens it had when the checkpoint was saved, or the workaround you suggest, should work.
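
To confirm that the vocabulary size is the only difference, one option is to compare parameter shapes before calling `load_from_checkpoint`. The sketch below works on plain name-to-shape dictionaries; `find_size_mismatches` is a hypothetical helper, and the sample shapes are copied from the traceback in this issue.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_size_mismatches(checkpoint_shapes: Dict[str, Shape],
                         model_shapes: Dict[str, Shape]
                         ) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, checkpoint_shape, model_shape) for every parameter that
    exists in both mappings but has a different shape."""
    return [
        (name, tuple(ckpt_shape), tuple(model_shapes[name]))
        for name, ckpt_shape in checkpoint_shapes.items()
        if name in model_shapes and tuple(model_shapes[name]) != tuple(ckpt_shape)
    ]


# Two of the five mismatching parameters reported in the traceback.
checkpoint_shapes = {
    "model.lm_head.weight": (50278, 1024),
    "model.final_logits_bias": (1, 50278),
}
model_shapes = {
    "model.lm_head.weight": (50268, 1024),
    "model.final_logits_bias": (1, 50268),
}

mismatches = find_size_mismatches(checkpoint_shapes, model_shapes)

# Number of tokens the freshly built model lacks compared with the checkpoint.
missing_tokens = (checkpoint_shapes["model.lm_head.weight"][0]
                  - model_shapes["model.lm_head.weight"][0])

# With real objects the mappings could be built like this (assumed usage):
#   ckpt = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
#   checkpoint_shapes = {k: tuple(v.shape) for k, v in ckpt.items()}
#   model_shapes = {k: tuple(v.shape) for k, v in pl_module.state_dict().items()}
```

Here the difference is 10 rows, which suggests the tokeniser used during training had 10 more tokens than the one built in the saving script. Adding the same tokens to the tokeniser before resizing the embeddings would be the other fix mentioned above.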

shellbreaker commented 1 year ago

How did you solve it? (same issue)