allenai / OLMo

Modeling, training, eval, and inference code for OLMo
https://allenai.org/olmo
Apache License 2.0

Cannot convert internal OLMo checkpoint to HF #654

Open viking-sudo-rm opened 1 month ago

viking-sudo-rm commented 1 month ago

🐛 Describe the bug

I'm trying to convert internal OLMo checkpoints to Huggingface format so I can analyze the attention matrices. I need Huggingface format because @2015aroras implemented a very useful output_attentions flag in the Huggingface API that returns the attention matrices from the forward pass, and this is not supported natively in the OLMo repo. The conversion script seems to work for converting the checkpoint to Huggingface (see below), but when I try to load the Huggingface checkpoint, I get an error message.

Minimal Example

I'd like to be able to handle many different checkpoints, but to reproduce the bug I'll focus on s3://ai2-llm/checkpoints/OLMo-medium/mitchish7/step0. I've downloaded this locally and unsharded it using:

ROOT=/net/nfs.cirrascale/allennlp/willm/olmo-sparsity/checkpoints
S3="s3://ai2-llm/checkpoints"
ckpt="OLMo-medium/mitchish7/step0"
aws s3 cp --recursive $S3/$ckpt $ROOT/$ckpt
python scripts/unshard.py \
    $ROOT/$ckpt \
    $ROOT/$ckpt-unsharded \
    --type local

I then run scripts/convert_olmo_to_hf_new.py to convert to Huggingface format:

python scripts/convert_olmo_to_hf_new.py \
    --input_dir $ROOT/OLMo-medium/mitchish7/step0-unsharded \
    --output_dir $ROOT/OLMo-medium/mitchish7/step0-hf \
    --tokenizer_json_path olmo_data/tokenizers/allenai_gpt-neox-olmo-dolma-v1_5.json

This fails with an error message when trying to delete the temporary files after converting to the Huggingface format:

Traceback (most recent call last):
  File "/home/willm/OLMo/scripts/convert_olmo_to_hf_new.py", line 272, in <module>
    main()
  File "/home/willm/OLMo/scripts/convert_olmo_to_hf_new.py", line 261, in main
    write_model(
  File "/home/willm/OLMo/scripts/convert_olmo_to_hf_new.py", line 191, in write_model
    shutil.rmtree(tmp_model_path)
  File "/opt/miniconda3/lib/python3.10/shutil.py", line 731, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/opt/miniconda3/lib/python3.10/shutil.py", line 729, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/net/nfs.cirrascale/allennlp/willm/olmo-sparsity/checkpoints/OLMo-medium/mitchish7/step0-hf/tmp'

However, the Huggingface model is there, and after inspecting the script, I think everything ran correctly except the last line, which deletes the temporary scratch directory. See: https://github.com/allenai/OLMo/blob/d423c11a58b6a4dcd5e6256618a6670339b4447a/scripts/convert_olmo_to_hf_new.py#L191
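One way to sanity-check that only the cleanup step failed is to confirm that the output directory contains a config and weight files outside the tmp folder. A minimal sketch, assuming the standard save_pretrained file names (config.json plus *.safetensors or *.bin weights); the helper name is hypothetical, and it only checks that the files exist, not that the weights are correct.

```python
import json
from pathlib import Path


def converted_checkpoint_looks_complete(output_dir: str) -> bool:
    """Heuristic check that a Hugging Face checkpoint directory was written.

    Assumes the usual save_pretrained layout: a parseable config.json plus
    weights stored as *.safetensors or *.bin files (single-file or sharded).
    The glob is non-recursive, so files left in the tmp/ scratch folder
    do not count.
    """
    root = Path(output_dir)
    config_path = root / "config.json"
    if not config_path.is_file():
        return False
    try:
        json.loads(config_path.read_text())
    except json.JSONDecodeError:
        return False
    weight_files = list(root.glob("*.safetensors")) + list(root.glob("*.bin"))
    return len(weight_files) > 0
```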

Because I think the output of the conversion script is correct despite the error message, I tried loading it in Huggingface as follows:

from hf_olmo import OLMoForCausalLM
ckpt = "/net/nfs.cirrascale/allennlp/willm/olmo-sparsity/checkpoints/OLMo-medium/mitchish7/step0-hf"
olmo = OLMoForCausalLM.from_pretrained(ckpt)

This throws the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/willm/.local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3122, in from_pretrained
    config, model_kwargs = cls.config_class.from_pretrained(
  File "/home/willm/.local/lib/python3.10/site-packages/transformers/configuration_utils.py", line 609, in from_pretrained
    return cls.from_dict(config_dict, **kwargs)
  File "/home/willm/.local/lib/python3.10/site-packages/transformers/configuration_utils.py", line 761, in from_dict
    config = cls(**config_dict)
  File "/home/willm/OLMo/hf_olmo/configuration_olmo.py", line 25, in __init__
    super().__init__(**all_kwargs)
  File "/home/willm/.local/lib/python3.10/site-packages/transformers/configuration_utils.py", line 375, in __init__
    raise err
  File "/home/willm/.local/lib/python3.10/site-packages/transformers/configuration_utils.py", line 372, in __init__
    setattr(self, key, value)
  File "/home/willm/.local/lib/python3.10/site-packages/transformers/configuration_utils.py", line 258, in __setattr__
    super().__setattr__(key, value)
AttributeError: can't set attribute 'hidden_size'

Possible Diagnoses

Am I converting the OLMo checkpoint to Huggingface incorrectly? Or should I be loading it some other way than passing the path to from_pretrained? Or is this just a version issue, and I shouldn't expect to be able to load arbitrary checkpoints in Huggingface? If the latter, it would be nice if the error messages made the incompatibility clearer. I would also like to know of potential workarounds for retrieving attention matrices natively in the OLMo repo without having to convert to Huggingface.
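On the last question, one possible workaround (an assumption of this sketch, not an existing OLMo feature) is to capture each layer's per-head queries and keys, for example with PyTorch forward hooks, and recompute the attention matrix yourself. For standard causal scaled dot-product attention the weights are softmax(QK^T / sqrt(head_dim)) with future positions masked out. The sketch assumes q and k were captured after rotary embeddings and any QK normalization, and that the model adds no extra attention bias such as ALiBi; the function name is hypothetical.

```python
import numpy as np


def causal_attention_weights(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Recompute causal softmax attention weights for one layer.

    q, k: arrays of shape (n_heads, seq_len, head_dim), assumed to be the
    queries and keys after rotary embeddings / any QK normalization.
    Returns an array of shape (n_heads, seq_len, seq_len) whose rows sum to 1.
    """
    head_dim = q.shape[-1]
    # Scaled dot-product scores: (n_heads, seq_len, seq_len).
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(head_dim)
    seq_len = scores.shape[-1]
    # Mask out future positions (key index j > query index i).
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Numerically stable softmax over the key dimension; the diagonal is
    # never masked, so every row has a finite maximum.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)
```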

Versions

Python 3.10.9

pip freeze accelerate==0.32.1 -e git+https://github.com/allenai/OLMo@d423c11a58b6a4dcd5e6256618a6670339b4447a#egg=ai2_olmo -e git+https://github.com/allenai/OLMo-core@eb56a9f0c2f63cf2e79e90da878a00d1a282cec9#egg=ai2_olmo_core aiohttp==3.9.5 aiosignal==1.3.1 alabaster==0.7.16 annotated-types==0.6.0 antlr4-python3-runtime==4.9.3 anyio==3.7.0 argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 arrow==1.2.3 asttokens==2.2.1 async-timeout==4.0.3 attrs==23.1.0 Babel==2.15.0 backcall==0.2.0 backoff==2.1.2 backports.tarfile==1.2.0 beaker-gantry==1.1.0 beaker-py==1.26.14 beautifulsoup4==4.12.2 black==23.12.1 bleach==6.0.0 blessed==1.20.0 blinker==1.8.2 boltons==24.0.0 boto3==1.34.96 boto3-extensions==0.23.0 botocore==1.34.96 brotlipy==0.7.0 build==1.2.1 cached_path==1.6.2 cachetools==5.3.3 certifi @ file:///croot/certifi_1671487769961/work/certifi cffi @ file:///croot/cffi_1670423208954/work charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work click==8.1.7 click-aliases==1.0.4 click-help-colors==0.9.1 cloudpickle==3.0.0 cmake==3.29.2 colorama==0.4.6 comm==0.1.3 conda==23.1.0 conda-content-trust @ file:///tmp/abs_5952f1c8-355c-4855-ad2e-538535021ba5h26t22e5/croots/recipe/conda-content-trust_1658126371814/work conda-package-handling @ file:///croot/conda-package-handling_1672865015732/work conda_package_streaming @ file:///croot/conda-package-streaming_1670508151586/work contourpy==1.2.1 cryptography @ file:///croot/cryptography_1673298753778/work cycler==0.12.1 datasets==2.7.1 dateparser==1.2.0 debugpy==1.6.7 decorator==5.1.1 defusedxml==0.7.1 dill==0.3.6 diskcache==5.6.3 distro==1.9.0 docker==6.1.3 docker-pycreds==0.4.0 docutils==0.20.1 evaluate==0.4.2 exceptiongroup==1.1.1 execnet==2.1.1 executing==1.2.0 face==20.1.1 fastapi==0.110.3 fastjsonschema==2.17.1 filelock==3.13.4 Flask==3.0.3 fonttools==4.51.0 fqdn==1.5.1 frozenlist==1.4.1 fsspec==2024.3.1 ftfy==6.2.0 furo==2023.5.20 gitdb==4.0.10 GitPython==3.1.31 glom==23.5.0 
google-api-core==2.19.0 google-auth==2.30.0 google-cloud-core==2.4.1 google-cloud-storage==2.17.0 google-crc32c==1.5.0 google-resumable-media==2.7.1 googleapis-common-protos==1.63.1 gpustat==1.1 h11==0.14.0 halo==0.0.31 httpcore==1.0.5 httptools==0.6.1 httpx==0.27.0 huggingface-hub==0.21.4 idna @ file:///croot/idna_1666125576474/work imagesize==1.4.1 importlib_metadata==7.2.0 importlib_resources==6.4.0 iniconfig==2.0.0 interegular==0.3.3 ipykernel==6.23.2 ipython==8.14.0 ipython-genutils==0.2.0 ipywidgets==8.0.6 isodate==0.6.1 isoduration==20.11.0 isort==5.12.0 itsdangerous==2.2.0 jaraco.classes==3.4.0 jaraco.context==5.3.0 jaraco.functools==4.0.1 jedi==0.18.2 jeepney==0.8.0 Jinja2==3.1.2 jmespath==1.0.1 joblib==1.4.0 jsonpointer==2.3 jsonschema==4.17.3 jupyter==1.0.0 jupyter-console==6.6.3 jupyter-events==0.6.3 jupyter_client==8.2.0 jupyter_core==5.3.1 jupyter_server==2.6.0 jupyter_server_terminals==0.4.4 jupyterlab-pygments==0.2.2 jupyterlab-widgets==3.0.7 keyring==25.2.1 kiwisolver==1.4.5 lark==1.1.9 lightning-utilities==0.11.3.post0 livereload==2.6.3 llvmlite==0.42.0 lm-format-enforcer==0.9.8 log-symbols==0.0.14 markdown-it-py==3.0.0 MarkupSafe==2.1.3 matplotlib==3.8.4 matplotlib-inline==0.1.6 maturin==1.5.1 mdit-py-plugins==0.4.1 mdurl==0.1.2 mistune==2.0.5 more-itertools==10.3.0 mpmath==1.3.0 msgpack==1.0.8 msgspec==0.18.6 multidict==6.0.5 multiprocess==0.70.14 mypy==1.3.0 mypy-extensions==1.0.0 myst-parser==2.0.0 nbclassic==1.0.0 nbclient==0.8.0 nbconvert==7.5.0 nbformat==5.9.0 necessary==0.4.3 nest-asyncio==1.5.6 networkx==3.3 nh3==0.2.17 ninja==1.11.1.1 notebook==6.5.4 notebook_shim==0.2.3 numba==0.59.1 numpy==1.26.4 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-ml-py==11.525.112 nvidia-nccl-cu12==2.18.1 
nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.1.105 omegaconf==2.3.0 oocmap==0.3 openai==1.29.0 outlines==0.0.34 overrides==7.3.1 packaging==23.1 pandas==2.2.2 pandocfilters==1.5.0 parso==0.8.3 pathspec==0.12.1 petname==2.6 pexpect==4.8.0 pickleshare==0.7.5 pillow==10.3.0 pkginfo==1.11.1 platformdirs==3.5.3 pluggy==1.5.0 prometheus-fastapi-instrumentator==7.0.0 prometheus_client==0.20.0 prompt-toolkit==3.0.38 proto-plus==1.23.0 protobuf==4.25.3 psutil==5.9.5 ptyprocess==0.7.0 pure-eval==0.2.2 py-cpuinfo==9.0.0 pyarrow==16.0.0 pyasn1==0.6.0 pyasn1_modules==0.4.0 pycosat @ file:///croot/pycosat_1666805502580/work pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work pydantic==2.7.1 pydantic_core==2.18.2 Pygments==2.15.1 pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work pyparsing==3.1.2 pyproject_hooks==1.1.0 pyrsistent==0.19.3 PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work pytest==8.2.2 pytest-sphinx==0.6.3 pytest-xdist==3.6.1 python-dateutil==2.8.2 python-dotenv==1.0.1 python-json-logger==2.0.7 pytz==2024.1 PyYAML==6.0 pyzmq==25.1.0 qtconsole==5.4.3 QtPy==2.3.1 ray==2.12.0 readme_renderer==43.0 referencing==0.35.0 regex==2024.4.28 requests @ file:///opt/conda/conda-bld/requests_1657734628632/work requests-toolbelt==1.0.0 requirements-parser==0.9.0 responses==0.18.0 rfc3339-validator==0.1.4 rfc3986==2.0.0 rfc3986-validator==0.1.1 rich==13.4.2 rpds-py==0.18.0 rsa==4.9 ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work ruff==0.4.10 rusty-dawg @ file:///home/willm/rusty-dawg/bindings/python/target/wheels/rusty_dawg-0.1.0-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=b780f6524d32a76e9dfff4137425f799401afc86deed08bbe1c73ac4885e3baf s3transfer==0.10.1 safetensors==0.4.3 scikit-learn==1.5.1 scipy==1.13.0 seaborn==0.13.2 SecretStorage==3.3.3 Send2Trash==1.8.2 sentencepiece==0.2.0 sentry-sdk==2.7.1 setproctitle==1.3.3
six @ file:///tmp/build/80754af9/six_1644875935023/work smart-open==7.0.4 smashed==0.21.5 smmap==5.0.0 sniffio==1.3.0 snowballstemmer==2.2.0 soupsieve==2.4.1 Sphinx==7.0.1 sphinx-autobuild==2021.3.14 sphinx-autodoc-typehints==1.23.3 sphinx-basic-ng==1.0.0b2 sphinx-copybutton==0.5.2 sphinxcontrib-applehelp==1.0.8 sphinxcontrib-devhelp==1.0.6 sphinxcontrib-htmlhelp==2.0.5 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.7 sphinxcontrib-serializinghtml==1.1.10 spinners==0.0.24 stack-data==0.6.2 starlette==0.37.2 sympy==1.12 tabulate==0.9.0 termcolor==2.4.0 terminado==0.17.1 threadpoolctl==3.5.0 tiktoken==0.6.0 tinycss2==1.2.1 tokenizers==0.19.1 tomli==2.0.1 toolz @ file:///croot/toolz_1667464077321/work torch==2.1.2 torchmetrics==1.4.0.post0 tornado==6.3.2 tqdm @ file:///opt/conda/conda-bld/tqdm_1664392687731/work traitlets==5.9.0 transformers==4.40.1 triton==2.1.0 trouting==0.3.3 twine==5.1.0 typeguard==2.13.3 types-setuptools==70.1.0.20240627 typing_extensions==4.11.0 tzdata==2024.1 tzlocal==5.2 uri-template==1.2.0 urllib3 @ file:///croot/urllib3_1673575502006/work uvicorn==0.29.0 uvloop==0.19.0 vllm==0.4.2 vllm-nccl-cu12==2.18.1.0.4.0 wandb==0.17.4 watchfiles==0.21.0 wcwidth==0.2.13 webcolors==1.13 webencodings==0.5.1 websocket-client==1.5.3 websockets==12.0 Werkzeug==3.0.3 widgetsnbextension==4.0.7 wrapt==1.16.0 xformers==0.0.26.post1 xxhash==3.4.1 yarl==1.9.4 zipp==3.19.2 zstandard @ file:///opt/conda/conda-bld/zstandard_1663827383994/work

AkshitaB commented 1 month ago

@2015aroras , can you take a look at this?

2015aroras commented 1 month ago

Regarding the failure to delete the temp files: I'm guessing you're running on Beaker. Beaker has problems with deleting these files from Python (as Oyvindt discovered), so you can skip the cleanup using --no_tmp_cleanup and delete the temp folder manually afterwards. If you're only hitting the cleanup error, the main part of the script has already run successfully.
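With --no_tmp_cleanup, the scratch folder (<output_dir>/tmp in the traceback above) still has to be removed afterwards. Besides rm -rf, a minimal Python sketch that retries the deletion a few times is shown here. The helper name, retry count and delay are assumptions, and retrying only helps if whatever is holding the files open releases them in the meantime.

```python
import shutil
import time
from pathlib import Path


def remove_tmp_dir(path: str, retries: int = 5, delay: float = 1.0) -> bool:
    """Try to delete a leftover scratch directory, retrying on OSError.

    Returns True once the directory is gone (or never existed), and False
    if it still exists after all retries.
    """
    target = Path(path)
    for _ in range(retries):
        if not target.exists():
            return True
        try:
            shutil.rmtree(target)
            return True
        except OSError:
            # Files that are still open (e.g. on network filesystems) can
            # make the final rmdir fail with "Directory not empty"; wait and retry.
            time.sleep(delay)
    return not target.exists()
```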

Your issue is that from hf_olmo import OLMoForCausalLM is the "old-style" OLMo HF integration, and you're trying to load a "new-style" checkpoint with it. Using from transformers import AutoModelForCausalLM and AutoModelForCausalLM.from_pretrained should work for you (you can also use this for old-style checkpoints, as long as you import hf_olmo first).
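The traceback is consistent with this: the old-style hf_olmo config appears to expose hidden_size as a read-only property derived from d_model, so a new-style config.json that stores hidden_size directly cannot be loaded through it. A minimal sketch for telling the two styles apart before loading follows. It assumes new-style checkpoints write model_type "olmo" in config.json and old-style ones write "hf_olmo"; the helper names are hypothetical, and the loader imports transformers (and hf_olmo, for old-style checkpoints) lazily.

```python
import json
from pathlib import Path

# Assumed model_type values: the native transformers OLMo integration writes
# "olmo", while the older hf_olmo package writes "hf_olmo".
NEW_STYLE_MODEL_TYPE = "olmo"
OLD_STYLE_MODEL_TYPE = "hf_olmo"


def checkpoint_style(checkpoint_dir: str) -> str:
    """Return "new-style" or "old-style" based on config.json's model_type."""
    config = json.loads((Path(checkpoint_dir) / "config.json").read_text())
    model_type = config.get("model_type")
    if model_type == NEW_STYLE_MODEL_TYPE:
        return "new-style"
    if model_type == OLD_STYLE_MODEL_TYPE:
        return "old-style"
    raise ValueError(f"Unrecognized model_type: {model_type!r}")


def load_olmo_checkpoint(checkpoint_dir: str):
    """Load either checkpoint style through AutoModelForCausalLM."""
    from transformers import AutoModelForCausalLM

    if checkpoint_style(checkpoint_dir) == "old-style":
        import hf_olmo  # noqa: F401  (registers the old-style classes with Auto*)

        return AutoModelForCausalLM.from_pretrained(checkpoint_dir)
    # New-style: eager attention materializes the attention weights that
    # output_attentions=True returns.
    return AutoModelForCausalLM.from_pretrained(
        checkpoint_dir, attn_implementation="eager"
    )
```

With a new-style checkpoint, calling model(input_ids, output_attentions=True) should then return outputs whose attentions field holds one (batch, n_heads, seq_len, seq_len) tensor per layer.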

NB: if you convert a checkpoint that is not compatible with OLMo 1 or 1.7, the converter may fail silently, producing a checkpoint that gives incorrect outputs. In that case you'll need to use old-style HF OLMo checkpoints, but those do not implement output_attentions yet. See https://github.com/allenai/OLMo/blob/main/docs/Checkpoints.md for more details about the types of checkpoints.
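Because such a failure is silent, one way to catch it (a suggestion, not something the converter does) is to run the same token ids through the original OLMo model and through the converted HF model and compare the logits. The sketch below only does the comparison; obtaining the two logit arrays is left out, the helper names are hypothetical, and the default tolerance is an assumed value that may need loosening for bf16 weights.

```python
import numpy as np


def max_logit_difference(reference: np.ndarray, converted: np.ndarray) -> float:
    """Largest absolute difference between two logit arrays of equal shape."""
    if reference.shape != converted.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {converted.shape}")
    # Compare in float64 so the check itself adds no rounding error.
    diff = reference.astype(np.float64) - converted.astype(np.float64)
    return float(np.max(np.abs(diff)))


def conversion_looks_faithful(
    reference: np.ndarray, converted: np.ndarray, atol: float = 1e-3
) -> bool:
    """True if every logit of the converted model is within atol of the original."""
    return max_logit_difference(reference, converted) <= atol
```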