allenai / OLMo

Modeling, training, eval, and inference code for OLMo
https://allenai.org/olmo
Apache License 2.0

Problem with HF loading from model checkpoint #586

Closed ryanyxw closed 2 months ago

ryanyxw commented 2 months ago

🐛 Describe the bug

I'm trying to load an OLMo-1B checkpoint into Hugging Face in order to use the HF inference and trainer scripts. However, I'm having trouble loading the model in the first place. I get the following error:

Some weights of OlmoForCausalLM were not initialized from the model checkpoint at /home/ryan/decouple/models/olmo/olmo1B_step737000 and are newly initialized: 
['lm_head.weight', 'model.embed_tokens.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight' [etc...]
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

This reinitialization occurs for all components from layers 0 through 31, suggesting that the entire model was reinitialized.
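
One way to confirm that this is a key-name mismatch (a hypothetical diagnostic, not part of the repro below; the pytorch_model.bin file name and paths are assumptions about what the converter writes) is to compare the keys stored in the converted checkpoint against the keys the HF class expects, after running the conversion step from the repro:

import torch
from transformers import AutoModelForCausalLM

# Keys actually stored in the converted checkpoint
ckpt_keys = set(torch.load("olmo1B_step737000/pytorch_model.bin", map_location="cpu").keys())
# Keys the HF OlmoForCausalLM class expects, e.g. model.layers.0.self_attn.q_proj.weight
model = AutoModelForCausalLM.from_pretrained("./olmo1B_step737000")
expected_keys = set(model.state_dict().keys())

print(sorted(ckpt_keys)[:3])
print(sorted(expected_keys)[:3])
print(len(ckpt_keys & expected_keys), "keys overlap")  # 0 would explain the warning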

You can replicate this as follows. Please let me know if there is anything wrong with this pipeline.

Execute the following in a bash file (it simply downloads a checkpoint directly from the checkpoint URLs):

checkpoint_url="https://olmo-checkpoints.org/ai2-llm/olmo-small/g4g72enr/step737000-unsharded/"
output_dir="olmo1B_step737000"

mkdir -p "$output_dir"

# Loop through the list of checkpoint files and download each one
files=("config.yaml" "model.pt" "optim.pt" "train.pt")
for file in "${files[@]}"; do
    wget -N "${checkpoint_url}${file}" -O "${output_dir}/${file}"
done

Execute the following bash script to convert the OLMo model to an HF-compatible model:

checkpoint_dir=olmo1B_step737000
python hf_olmo/convert_olmo_to_hf.py --checkpoint-dir ${checkpoint_dir}

Execute the following Python file:

from transformers import AutoModelForCausalLM

def main():
    # This call emits the "newly initialized" warning quoted above
    olmo = AutoModelForCausalLM.from_pretrained("./olmo1B_step737000")

if __name__ == "__main__":
    main()

Versions

Python 3.8.19

accelerate==0.29.1

Editable install with no version control (ai2-olmo==0.3.0)

-e /home/ryan/decouple/OLMo aiobotocore==2.12.2 aiohttp==3.9.3 aioitertools==0.11.0 aiosignal==1.3.1 annotated-types==0.6.0 antlr4-python3-runtime==4.9.3 anyascii==0.3.2 appdirs==1.4.4 async-timeout==4.0.3 attrs==23.2.0 backports.tarfile==1.0.0 beaker-gantry==0.22.2 beaker-py==1.26.3 beautifulsoup4==4.12.3 black==23.12.1 blingfire==0.1.8 boltons==24.0.0 boto3==1.34.51 botocore==1.34.51 bs4==0.0.2 build==1.2.1 cached_path==1.6.2 cachetools==5.3.3 certifi==2024.2.2 cffi==1.16.0 charset-normalizer==3.3.2 click==8.1.7 click-help-colors==0.9.4 cmake==3.26.3 contourpy==1.1.1 cryptography==42.0.5 cycler==0.12.1 datasets==2.18.0 dill==0.3.8 docker==6.1.3 docker-pycreds==0.4.0 docutils==0.20.1 dolma==1.0.2 exceptiongroup==1.2.0 face==20.1.1 fasttext-wheel==0.9.2 filelock==3.13.3 fonttools==4.51.0 frozenlist==1.4.1 fsspec==2024.3.1 ftfy==6.2.0 gitdb==4.0.11 GitPython==3.1.43 glom==23.5.0 google-api-core==2.18.0 google-api-python-client==2.125.0 google-auth==2.29.0 google-auth-httplib2==0.2.0 google-cloud-core==2.4.1 google-cloud-storage==2.16.0 google-crc32c==1.5.0 google-resumable-media==2.7.0 googleapis-common-protos==1.63.0 httplib2==0.22.0 huggingface-hub==0.21.4 idna==3.6 importlib_metadata==7.1.0 importlib_resources==6.4.0 iniconfig==2.0.0 isort==5.12.0 jaraco.classes==3.4.0 jaraco.context==5.3.0 jaraco.functools==4.0.0 jeepney==0.8.0 Jinja2==3.1.3 jmespath==1.0.1 joblib==1.3.2 keyring==25.1.0 kiwisolver==1.4.5 langdetect==1.0.9 lightning-utilities==0.11.2 lit==16.0.2 LTpycld2==0.42 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.7.5 mdurl==0.1.2 more-itertools==10.2.0 mpmath==1.3.0 msgspec==0.18.6 multidict==6.0.5 multiprocess==0.70.16 mypy==1.3.0 mypy-extensions==1.0.0 necessary==0.4.3 networkx==3.1 nh3==0.2.17 nltk==3.8.1 numpy==1.24.4 nvidia-cublas-cu11==11.11.3.6 nvidia-cuda-cupti-cu11==11.8.87 nvidia-cuda-nvrtc-cu11==11.8.89 nvidia-cuda-runtime-cu11==11.8.89 nvidia-cudnn-cu11==8.7.0.84 nvidia-cufft-cu11==10.9.0.58 nvidia-curand-cu11==10.3.0.86 nvidia-cusolver-cu11==11.4.1.48 nvidia-cusparse-cu11==11.7.5.86 nvidia-nccl-cu11==2.19.3 nvidia-nvtx-cu11==11.8.86 omegaconf==2.3.0 packaging==24.0 pandas==2.0.3 pathspec==0.12.1 peft==0.10.0 petname==2.6 pillow==10.3.0 pkginfo==1.10.0 platformdirs==4.2.0 pluggy==1.4.0 proto-plus==1.23.0 protobuf==4.25.3 psutil==5.9.8 pyarrow==15.0.2 pyarrow-hotfix==0.6 pyasn1==0.6.0 pyasn1_modules==0.4.0 pybind11==2.12.0 pycparser==2.22 pydantic==2.6.4 pydantic_core==2.16.3 Pygments==2.17.2 pyparsing==3.1.2 pyproject_hooks==1.0.0 pytest==8.1.1 pytest-sphinx==0.6.2 python-dateutil==2.9.0.post0 pytz==2024.1 PyYAML==6.0.1 readme_renderer==43.0 regex==2023.12.25 requests==2.31.0 requests-toolbelt==1.0.0 requirements-parser==0.9.0 rfc3986==2.0.0 rich==13.7.1 rsa==4.9 ruff==0.3.5 s3fs==2024.3.1 s3transfer==0.10.1 safetensors==0.4.2 scikit-learn==1.3.2 scipy==1.10.1 seaborn==0.13.2 SecretStorage==3.3.3 sentry-sdk==1.44.1 setproctitle==1.3.3 six==1.16.0 smart-open==7.0.4 smashed==0.21.5 smmap==5.0.1 soupsieve==2.5 sympy==1.12 threadpoolctl==3.4.0 tokenizers==0.19.1 tomli==2.0.1 torch==2.1.0+cu118 torchaudio==2.1.0+cu118 torchmetrics==1.3.2 torchvision==0.16.0+cu118 tqdm==4.66.2 transformers @ git+https://github.com/huggingface/transformers@73014b561d5f88d728e46a57d346f516fefe3f2d triton==2.1.0 trouting==0.3.3 twine==5.0.0 types-setuptools==69.2.0.20240317 typing_extensions==4.11.0 tzdata==2024.1 uniseg==0.8.0 uritemplate==4.1.1 urllib3==1.26.18 wandb==0.16.6 wcwidth==0.2.13 websocket-client==1.7.0 wrapt==1.16.0 xxhash==3.4.1 yarl==1.9.4 zipp==3.18.1 zstandard==0.22.0

2015aroras commented 2 months ago

Hi Ryan,

I just put up some docs yesterday about the types of checkpoints OLMo has, including how to convert from OLMo to HF: Checkpoints.md. Try using convert_olmo_to_hf_new.py instead (it takes slightly different arguments) and let us know if you still have issues.
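
As a sanity check after conversion, loading the converted directory should no longer report newly initialized weights. A minimal sketch (the output directory name and prompt are placeholders; check the script's --help for the exact arguments it takes):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Path is whatever output directory you passed to convert_olmo_to_hf_new.py
model = AutoModelForCausalLM.from_pretrained("./olmo1B_step737000-hf")
tokenizer = AutoTokenizer.from_pretrained("./olmo1B_step737000-hf")

inputs = tokenizer("Language modeling is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))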

2015aroras commented 2 months ago

The main README was outdated too, I have updated it now: https://github.com/allenai/OLMo/pull/589

ryanyxw commented 2 months ago

Hey Shane,

Thank you for your speedy reply and fix! This is really helpful.

I do have another quick question related to your change. I noticed that convert_olmo_to_hf_new.py now takes a tokenizer JSON path. The README specifies this to be tokenizers/allenai_gpt-neox-olmo-dolma-v1_5.json, which matches the tokenizers released on the HF Hub. However, the configs in configs/official/OLMo-1B.yaml and configs/official/OLMo-7B.yaml, as well as all the checkpoint configs, declare tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json as the tokenizer identifier. These two JSON files seem to describe different tokenizers with different special tokens.

Is this discrepancy intentional? Or does either tokenizer work?

Thanks!

2015aroras commented 2 months ago

My rough understanding regarding our tokenization is:

  1. The tokenizer in the config doesn't make a difference to the pretraining runs, since the data we train on is already tokenized and converted to token ids. We had tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json in our config when we ran pretraining, hence it is in our release json.
  2. Our end-of-string token in the pretokenized data had id 50279 instead of 0 for some reason, so we changed our tokenizer to have 50279 as the end-of-string token. Hence the newer tokenizers/allenai_gpt-neox-olmo-dolma-v1_5.json on the HF Hub; this is probably the tokenizer to use for inference and for tokenizing new data (see the snippet below).
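
You can see the difference directly (a sketch; it assumes both tokenizer JSON files from the repo are available locally, and the expected ids reflect the explanation above):

from tokenizers import Tokenizer

old_tok = Tokenizer.from_file("tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json")
new_tok = Tokenizer.from_file("tokenizers/allenai_gpt-neox-olmo-dolma-v1_5.json")

# Per the explanation above, the end-of-string token moved from id 0 to id 50279
print(old_tok.token_to_id("<|endoftext|>"))  # expected: 0
print(new_tok.token_to_id("<|endoftext|>"))  # expected: 50279
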
ryanyxw commented 2 months ago

Thank you so much! Everything seems to be working :)