moussaKam / BARThez

A French sequence-to-sequence pretrained model
Apache License 2.0
57 stars 11 forks

Unable to load weights from pytorch checkpoint file #2

Closed FlorianMuller closed 3 years ago

FlorianMuller commented 3 years ago

Hello,

I wanted to use BARThez with HuggingFace but it seems like I can't load the BARThez checkpoint.

I tried to run your HuggingFace example:

```python
text_sentence = "Paris est la capitale de la <mask>"

from transformers import AutoModelForSeq2SeqLM
import torch
import sentencepiece as spm
from transformers import BarthezTokenizer

barthez_tokenizer = BarthezTokenizer.from_pretrained("moussaKam/barthez")
barthez_model = AutoModelForSeq2SeqLM.from_pretrained("moussaKam/barthez")

input_ids = torch.tensor(
    [barthez_tokenizer.encode(text_sentence, add_special_tokens=True)]
)
mask_idx = torch.where(input_ids == barthez_tokenizer.mask_token_id)[1].tolist()[0]

barthez_model.eval()
predict = barthez_model.forward(input_ids)[0]

barthez_tokenizer.decode(predict[:, mask_idx, :].topk(5).indices[0])
```

(I don't know why, but I had to change the import order a bit to make it work: import AutoModelForSeq2SeqLM before torch, and import sentencepiece as spm before BarthezTokenizer. Without this specific, odd order I got a segmentation fault, or the Jupyter Lab kernel would restart.)

but encountered this error:

```
OSError: Unable to load weights from pytorch checkpoint file for 'moussaKam/barthez' at '/root/.cache/huggingface/transformers/83969d596ba07eda19456fd012872ce770b004cc42313bcef1bb8ea82db9bd27.fc8778edd5440e97055d6f539021d2ea934da72fe9044a3aa7fe65a9c66250c2'If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.
```
Full stacktrace
``` Traceback (most recent call last): File "/opt/conda/lib/python3.6/tarfile.py", line 189, in nti n = int(s.strip() or "0", 8) ValueError: invalid literal for int() with base 8: 'v2\nq\x03((X' During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/opt/conda/lib/python3.6/tarfile.py", line 2297, in next tarinfo = self.tarinfo.fromtarfile(self) File "/opt/conda/lib/python3.6/tarfile.py", line 1093, in fromtarfile obj = cls.frombuf(buf, tarfile.encoding, tarfile.errors) File "/opt/conda/lib/python3.6/tarfile.py", line 1035, in frombuf chksum = nti(buf[148:156]) File "/opt/conda/lib/python3.6/tarfile.py", line 191, in nti raise InvalidHeaderError("invalid header") tarfile.InvalidHeaderError: invalid header During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/opt/conda/lib/python3.6/site-packages/torch/serialization.py", line 591, in _load return legacy_load(f) File "/opt/conda/lib/python3.6/site-packages/torch/serialization.py", line 502, in legacy_load with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \ File "/opt/conda/lib/python3.6/tarfile.py", line 1589, in open return func(name, filemode, fileobj, **kwargs) File "/opt/conda/lib/python3.6/tarfile.py", line 1619, in taropen return cls(name, mode, fileobj, **kwargs) File "/opt/conda/lib/python3.6/tarfile.py", line 1482, in __init__ self.firstmember = self.next() File "/opt/conda/lib/python3.6/tarfile.py", line 2309, in next raise ReadError(str(e)) tarfile.ReadError: invalid header During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/opt/conda/lib/python3.6/site-packages/transformers/modeling_utils.py", line 1038, in from_pretrained state_dict = torch.load(resolved_archive_file, map_location="cpu") File "/opt/conda/lib/python3.6/site-packages/torch/serialization.py", line 422, in load return _load(f, map_location, 
pickle_module, **pickle_load_args) File "/opt/conda/lib/python3.6/site-packages/torch/serialization.py", line 595, in _load raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name)) RuntimeError: /root/.cache/huggingface/transformers/83969d596ba07eda19456fd012872ce770b004cc42313bcef1bb8ea82db9bd27.fc8778edd5440e97055d6f539021d2ea934da72fe9044a3aa7fe65a9c66250c2 is a zip archive (did you mean to use torch.jit.load()?) During handling of the above exception, another exception occurred: Traceback (most recent call last): File "test_barthez.py", line 9, in barthez_model = AutoModelForSeq2SeqLM.from_pretrained("moussaKam/barthez") File "/opt/conda/lib/python3.6/site-packages/transformers/models/auto/modeling_auto.py", line 1219, in from_pretrained pretrained_model_name_or_path, *model_args, config=config, **kwargs File "/opt/conda/lib/python3.6/site-packages/transformers/modeling_utils.py", line 1041, in from_pretrained f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' " OSError: Unable to load weights from pytorch checkpoint file for 'moussaKam/barthez' at '/root/.cache/huggingface/transformers/83969d596ba07eda19456fd012872ce770b004cc42313bcef1bb8ea82db9bd27.fc8778edd5440e97055d6f539021d2ea934da72fe9044a3aa7fe65a9c66250c2'If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. ```

I tried replacing AutoModelForSeq2SeqLM with MBartForConditionalGeneration, but got the same error:

Full stacktrace with MBartForConditionalGeneration ``` --------------------------------------------------------------------------- ValueError Traceback (most recent call last) /opt/conda/lib/python3.6/tarfile.py in nti(s) 188 s = nts(s, "ascii", "strict") --> 189 n = int(s.strip() or "0", 8) 190 except ValueError: ValueError: invalid literal for int() with base 8: 'v2\nq\x03((X' During handling of the above exception, another exception occurred: InvalidHeaderError Traceback (most recent call last) /opt/conda/lib/python3.6/tarfile.py in next(self) 2296 try: -> 2297 tarinfo = self.tarinfo.fromtarfile(self) 2298 except EOFHeaderError as e: /opt/conda/lib/python3.6/tarfile.py in fromtarfile(cls, tarfile) 1092 buf = tarfile.fileobj.read(BLOCKSIZE) -> 1093 obj = cls.frombuf(buf, tarfile.encoding, tarfile.errors) 1094 obj.offset = tarfile.fileobj.tell() - BLOCKSIZE /opt/conda/lib/python3.6/tarfile.py in frombuf(cls, buf, encoding, errors) 1034 -> 1035 chksum = nti(buf[148:156]) 1036 if chksum not in calc_chksums(buf): /opt/conda/lib/python3.6/tarfile.py in nti(s) 190 except ValueError: --> 191 raise InvalidHeaderError("invalid header") 192 return n InvalidHeaderError: invalid header During handling of the above exception, another exception occurred: ReadError Traceback (most recent call last) /opt/conda/lib/python3.6/site-packages/torch/serialization.py in _load(f, map_location, pickle_module, **pickle_load_args) 590 try: --> 591 return legacy_load(f) 592 except tarfile.TarError: /opt/conda/lib/python3.6/site-packages/torch/serialization.py in legacy_load(f) 501 --> 502 with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \ 503 mkdtemp() as tmpdir: /opt/conda/lib/python3.6/tarfile.py in open(cls, name, mode, fileobj, bufsize, **kwargs) 1588 raise CompressionError("unknown compression type %r" % comptype) -> 1589 return func(name, filemode, fileobj, **kwargs) 1590 /opt/conda/lib/python3.6/tarfile.py in taropen(cls, name, mode, fileobj, 
**kwargs) 1618 raise ValueError("mode must be 'r', 'a', 'w' or 'x'") -> 1619 return cls(name, mode, fileobj, **kwargs) 1620 /opt/conda/lib/python3.6/tarfile.py in __init__(self, name, mode, fileobj, format, tarinfo, dereference, ignore_zeros, encoding, errors, pax_headers, debug, errorlevel, copybufsize) 1481 self.firstmember = None -> 1482 self.firstmember = self.next() 1483 /opt/conda/lib/python3.6/tarfile.py in next(self) 2308 elif self.offset == 0: -> 2309 raise ReadError(str(e)) 2310 except EmptyHeaderError: ReadError: invalid header During handling of the above exception, another exception occurred: RuntimeError Traceback (most recent call last) /opt/conda/lib/python3.6/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs) 1037 try: -> 1038 state_dict = torch.load(resolved_archive_file, map_location="cpu") 1039 except Exception: /opt/conda/lib/python3.6/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args) 421 try: --> 422 return _load(f, map_location, pickle_module, **pickle_load_args) 423 finally: /opt/conda/lib/python3.6/site-packages/torch/serialization.py in _load(f, map_location, pickle_module, **pickle_load_args) 594 # .zip is used for torch.jit.save and will throw an un-pickling error here --> 595 raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name)) 596 # if not a tarfile, reset file offset and proceed RuntimeError: /root/.cache/huggingface/transformers/83969d596ba07eda19456fd012872ce770b004cc42313bcef1bb8ea82db9bd27.fc8778edd5440e97055d6f539021d2ea934da72fe9044a3aa7fe65a9c66250c2 is a zip archive (did you mean to use torch.jit.load()?) 
During handling of the above exception, another exception occurred: OSError Traceback (most recent call last) in 9 barthez_tokenizer = BarthezTokenizer.from_pretrained("moussaKam/barthez") 10 # barthez_model = AutoModelForSeq2SeqLM.from_pretrained("moussaKam/barthez") ---> 11 barthez_model = MBartForConditionalGeneration.from_pretrained("moussaKam/barthez") 12 13 /opt/conda/lib/python3.6/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs) 1039 except Exception: 1040 raise OSError( -> 1041 f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' " 1042 f"at '{resolved_archive_file}'" 1043 "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. " OSError: Unable to load weights from pytorch checkpoint file for 'moussaKam/barthez' at '/root/.cache/huggingface/transformers/83969d596ba07eda19456fd012872ce770b004cc42313bcef1bb8ea82db9bd27.fc8778edd5440e97055d6f539021d2ea934da72fe9044a3aa7fe65a9c66250c2'If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. ```
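The RuntimeError buried in both tracebacks ("… is a zip archive (did you mean to use torch.jit.load()?)") points at the actual cause: checkpoints saved by PyTorch >= 1.6 use a zip-based container format that the legacy loader in torch 1.3 cannot read. A stdlib-only check (the helper name here is mine, not part of torch) to confirm which format a cached checkpoint file uses:

```python
# PyTorch >= 1.6 writes checkpoints as zip archives, while torch 1.3's
# loader only understands the older tar/pickle layout -- hence the
# "is a zip archive" error above. (Helper name is mine.)
import zipfile

def is_new_style_checkpoint(path: str) -> bool:
    """True if `path` uses the zip-based serialization introduced in torch 1.6."""
    return zipfile.is_zipfile(path)
```

Running this on the cached file from the traceback would presumably return True, which is why a pre-1.6 torch.load() fails on it.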

I also tried mbarthez, but got a segmentation fault.


My environment:

Full `pip freeze` ``` absl-py==0.8.0 alabaster==0.7.12 apex==0.1 appdirs==1.4.3 ascii-graph==1.5.1 asn1crypto==0.24.0 atomicwrites==1.3.0 attrs==19.2.0 audioread==2.1.8 Babel==2.7.0 backcall==0.1.0 beautifulsoup4==4.8.0 bleach==3.1.0 boto3==1.9.240 botocore==1.12.240 certifi==2020.12.5 cffi==1.12.3 chardet==3.0.4 Click==7.0 codecov==2.0.15 conda==4.9.2 conda-build==3.18.9 conda-package-handling==1.6.0 coverage==4.5.4 cryptography==2.7 cxxfilt==0.2.0 cycler==0.10.0 cymem==2.0.2 Cython==0.28.4 cytoolz==0.9.0.1 dataclasses==0.8 DataProperty==0.43.1 datasets==1.1.3 decorator==4.4.0 defusedxml==0.6.0 dill==0.2.9 docutils==0.15.2 entrypoints==0.3 filelock==3.0.12 flake8==3.7.8 Flask==1.1.1 future==0.17.1 glob2==0.7 grpcio==1.24.0 h5py==2.10.0 html2text==2019.9.26 hypothesis==4.38.1 idna==2.8 imageio==2.5.0 imagesize==1.1.0 importlib-metadata==0.23 inflect==2.1.0 ipdb==0.12.2 ipykernel==5.1.2 ipympl==0.5.8 ipython==7.8.0 ipython-genutils==0.2.0 ipywidgets==7.5.1 itsdangerous==1.1.0 jedi==0.15.1 Jinja2==2.10.1 jmespath==0.9.4 joblib==0.14.0 json5==0.8.5 jsonschema==3.0.2 jupyter-client==5.3.3 jupyter-core==4.5.0 jupyter-tensorboard==0.1.10 jupyterlab==2.2.9 jupyterlab-server==1.2.0 jupytext==1.2.4 kiwisolver==1.1.0 libarchive-c==2.8 librosa==0.6.3 lief==0.9.0 llvmlite==0.28.0 lmdb==0.97 Mako==1.1.0 Markdown==3.1.1 MarkupSafe==1.1.1 maskrcnn-benchmark==0.1 matplotlib==3.3.3 mbstrdecoder==0.8.1 mccabe==0.6.1 mistune==0.8.4 mlperf-compliance==0.0.10 mock==3.0.5 more-itertools==7.2.0 msgfy==0.0.7 msgpack==0.6.1 msgpack-numpy==0.4.3.2 multiprocess==0.70.11.1 murmurhash==1.0.2 mysql-connector-python==8.0.22 nbconvert==5.6.0 nbformat==4.4.0 networkx==2.0 nltk==3.4.5 notebook==6.0.1 numba==0.43.1 numpy==1.17.2 nvidia-dali==0.14.0 onnx==1.5.0 opencv-python==3.4.1.15 packaging==19.2 pandas==0.24.2 pandocfilters==1.4.2 parso==0.5.1 pathvalidate==0.29.0 pexpect==4.7.0 pickleshare==0.7.5 Pillow==8.0.1 Pillow-SIMD==5.3.0.post1 pkginfo==1.5.0.1 plac==0.9.6 pluggy==0.13.0 preshed==2.0.1 
progressbar==2.5 prometheus-client==0.7.1 prompt-toolkit==2.0.9 protobuf==3.9.2 psutil==5.6.3 ptyprocess==0.6.0 py==1.8.0 pyarrow==2.0.0 pybind11==2.4.2 pycocotools==2.0+nv0.3.1 pycodestyle==2.5.0 pycosat==0.6.3 pycparser==2.19 pycuda==2019.1.2 pydot==1.4.1 pyflakes==2.1.1 Pygments==2.4.2 pymongo==3.11.2 pyOpenSSL==19.0.0 pyparsing==2.4.2 pyrsistent==0.15.4 PySocks==1.7.1 pytablewriter==0.46.1 pytest==5.2.0 pytest-cov==2.7.1 pytest-pythonpath==0.7.3 python-dateutil==2.8.0 python-nvd3==0.15.0 python-slugify==3.0.4 pytools==2019.1.1 pytorch-crf==0.7.2 pytz==2019.2 PyWavelets==1.0.3 PyYAML==5.1.2 pyzmq==18.1.0 regex==2018.1.10 requests==2.22.0 resampy==0.2.2 revtok==0.0.3 ruamel-yaml==0.15.46 s3transfer==0.2.1 sacrebleu==1.2.10 sacremoses==0.0.19 scikit-image==0.15.0 scikit-learn==0.21.3 scipy==1.3.1 seaborn==0.11.1 Send2Trash==1.5.0 sentencepiece==0.1.94 seqeval==1.2.2 six==1.12.0 snowballstemmer==1.9.1 SoundFile==0.10.2 soupsieve==1.9.3 sox==1.3.7 spacy==2.0.16 Sphinx==2.2.0 sphinx-rtd-theme==0.4.3 sphinxcontrib-applehelp==1.0.1 sphinxcontrib-devhelp==1.0.1 sphinxcontrib-htmlhelp==1.0.2 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.2 sphinxcontrib-serializinghtml==1.1.3 SSD==0.1 subword-nmt==0.3.3 tabledata==0.9.1 tabulate==0.8.5 tensorboard==2.0.0 tensorrt==6.0.1.5 terminado==0.8.2 testpath==0.4.2 text-unidecode==1.3 thinc==6.12.1 tokenizers==0.9.4 toml==0.10.0 toolz==0.10.0 torch==1.3.0a0+24ae9b5 torchtext==0.4.0 torchvision==0.5.0a0 tornado==6.0.3 tqdm==4.31.1 traitlets==4.3.2 transformers==4.2.1 typepy==0.6.0 typing==3.7.4.1 typing-extensions==3.7.4 ujson==1.35 Unidecode==1.1.1 urllib3==1.24.2 wcwidth==0.1.7 webencodings==0.5.1 Werkzeug==0.16.0 widgetsnbextension==3.5.1 wrapt==1.10.11 xxhash==2.0.0 yacs==0.1.6 zipp==0.6.0 ```

I know my PyTorch version is a bit old, but it's the latest version I can use with the CUDA driver on my machine.

I ran all my tests in an NVIDIA Docker container (nvcr.io/nvidia/pytorch:19.10-py3).

FlorianMuller commented 3 years ago

Finally, I think it was due to my old PyTorch version. I tried loading the model with PyTorch 1.7 and it worked. I then saved it with:

```python
torch.save(model.state_dict(), "pytorch_model.bin", _use_new_zipfile_serialization=False)
```

After that I could load the resulting pytorch_model.bin with my PyTorch 1.3 installation (thanks to the _use_new_zipfile_serialization=False option).
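For anyone hitting the same wall, the workaround can be sketched as a small helper (the function name is mine; it assumes torch >= 1.6 and transformers are installed on the machine doing the re-save):

```python
def resave_legacy(model_name: str, out_path: str) -> None:
    """Load a checkpoint with a modern torch (>= 1.6), then re-save the
    state dict in the legacy, non-zip format that torch < 1.6 can read."""
    import torch
    from transformers import AutoModelForSeq2SeqLM

    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    torch.save(
        model.state_dict(),
        out_path,
        _use_new_zipfile_serialization=False,
    )

# On the PyTorch 1.7 machine:
# resave_legacy("moussaKam/barthez", "pytorch_model.bin")
# The resulting pytorch_model.bin can then be loaded by torch 1.3.
```

Note that only the weights are re-saved here; the tokenizer and config files are unaffected by the serialization format change.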