kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

TpuEmbeddingEngine_WriteParameters not available in this library. #202

Closed nikhilanayak closed 2 years ago

nikhilanayak commented 2 years ago

I followed all of the instructions in the training guide but when I run the device_train script, I get this error:

2022-02-23 07:56:56.271731: F external/org_tensorflow/tensorflow/core/tpu/tpu_library_init_fns.inc:104] TpuEmbeddingEngine_WriteParameters not available in this library.

This is my exact command for the training process:

python3 device_train.py --config=configs/6B.json --tune-model-path=gs://nnrap/step_383500
whoislimshady commented 2 years ago

Check your jax version. I guess you are using 0.2.16, but the correct version for running the training is 0.2.12.

nikhilanayak commented 2 years ago

If I run pip3 list | grep jax, it returns:

jax                          0.2.12
jaxlib                       0.3.0
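
A quick way to compare an environment against known-good pins is sketched below. The pins (jax 0.2.12 with jaxlib 0.1.67) are taken from a working environment reported later in this thread, not from an official compatibility table, and the helper names `installed_version` and `find_mismatches` are hypothetical.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(expected, lookup=installed_version):
    """Map each package whose installed version differs to (installed, wanted)."""
    mismatches = {}
    for package, wanted in expected.items():
        found = lookup(package)
        if found != wanted:
            mismatches[package] = (found, wanted)
    return mismatches


if __name__ == "__main__":
    # Assumed pins, copied from a working environment reported in this thread.
    pins = {"jax": "0.2.12", "jaxlib": "0.1.67"}
    for package, (found, wanted) in find_mismatches(pins).items():
        print(f"{package}: installed {found}, expected {wanted}")
```

With the versions listed above, this would flag jaxlib (0.3.0 installed, 0.1.67 expected) while jax itself matches.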
nikhilanayak commented 2 years ago

Also @whoislimshady when I run it, it crashes and prints out

2022-02-23 17:35:25.749201: F external/org_tensorflow/tensorflow/core/tpu/tpu_library_init_fns.inc:104] TpuEmbeddingEngine_WriteParameters not available in this library.
Aborted (core dumped)
mrseeker commented 2 years ago

jax 0.2.12 won't run and crashes with this error. Using

pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

works but introduces new issues.

whoislimshady commented 2 years ago

@nikhilanayak I also faced the same issue, but it was resolved just by changing the jax version from 0.2.16 to 0.2.12. Here are the packages in my environment, with which I am able to train the model. Hope this helps:

absl-py==0.12.0
aiohttp==3.8.1
aiohttp-cors==0.7.0
aioredis==2.0.1
aiosignal==1.2.0
anyio==3.5.0
appdirs==1.4.4
asgiref==3.5.0
astunparse==1.6.3
async-timeout==4.0.2
attrs==19.3.0
Automat==0.8.0
bcrypt==3.2.0
best-download==0.0.9
black==22.1.0
blessings==1.7
BLEURT @ https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip
blinker==1.4
cachetools==4.2.2
certifi==2020.12.5
cffi==1.15.0
chardet==4.0.0
charset-normalizer==2.0.11
chex==0.1.0
clang==5.0
click==7.1.2
cloud-init==21.1
cloud-tpu-client==0.10
cloudpickle==1.3.0
colorama==0.4.3
colorful==0.5.4
command-not-found==0.3
configobj==5.0.6
constantly==15.1.0
cryptography==2.8
Cython==0.29.23
DataProperty==0.54.2
datasets==1.15.1
dbus-python==1.2.16
Deprecated==1.2.13
dill==0.3.4
distlib==0.3.1
distro==1.4.0
distro-info===0.23ubuntu1
dm-haiku==0.0.5
dm-tree==0.1.6
docker-pycreds==0.4.0
dyNET38==2.1
einops==0.3.2
entrypoints==0.3
fabric==2.6.0
fastapi==0.73.0
filelock==3.0.12
Flask==1.1.4
flatbuffers==1.12
frozenlist==1.3.0
fsspec==2022.1.0
ftfy==6.1.1
func-timeout==4.3.5
future==0.18.2
gast==0.4.0
gitdb==4.0.9
GitPython==3.1.26
google-api-core==1.28.0
google-api-python-client==1.8.0
google-auth==1.30.1
google-auth-httplib2==0.1.0
google-auth-oauthlib==0.4.4
google-cloud-core==1.7.2
google-cloud-storage==1.36.2
google-crc32c==1.3.0
google-pasta==0.2.0
google-resumable-media==1.3.3
googleapis-common-protos==1.53.0
gpustat==0.6.0
grpcio==1.38.0
h11==0.13.0
h5py==3.1.0
httplib2==0.19.1
huggingface-hub==0.4.0
hyperlink==19.0.0
idna==2.10
importlib-metadata==1.5.0
incremental==16.10.1
iniconfig==1.1.1
invoke==1.6.0
itsdangerous==1.1.0
jax==0.2.12
jaxlib==0.1.67
jieba==0.42.1
Jinja2==2.10.1
jmp==0.0.2
joblib==1.1.0
jsonlines==2.0.0
jsonpatch==1.22
jsonpointer==2.0
jsonschema==3.2.0
keras==2.6.0
Keras-Applications==1.0.8
keras-nightly==2.6.0.dev2021052400
Keras-Preprocessing==1.1.2
keyring==18.0.1
language-selector==0.1
launchpadlib==1.10.13
lazr.restfulclient==0.14.2
lazr.uri==1.0.3
libclang==13.0.0
lm-dataformat==0.0.20
lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness/@c74ca57c51c5fb0889b955878f00fe6d60ba393c
Markdown==3.3.4
MarkupSafe==1.1.0
mbstrdecoder==1.1.0
mesh-transformer @ file:///home/harsh/gptj
mock==4.0.3
more-itertools==4.2.0
msgfy==0.2.0
msgpack==1.0.3
multidict==6.0.2
multiprocess==0.70.12.2
mypy-extensions==0.4.3
nagisa==0.2.7
netifaces==0.10.4
nltk==3.7
numexpr==2.7.2
numpy==1.22.2
nvidia-ml-py3==7.352.0
oauth2client==4.1.3
oauthlib==3.1.0
openai==0.6.4
opencensus==0.8.0
opencensus-context==0.1.2
opt-einsum==3.3.0
optax==0.0.9
packaging==20.9
pandas==1.4.0
paramiko==2.9.2
pathlib2==2.3.7.post1
pathspec==0.9.0
pathtools==0.1.2
pathvalidate==2.5.0
pathy==0.6.1
pbr==5.8.1
pexpect==4.6.0
Pillow==8.2.0
platformdirs==2.5.0
pluggy==0.13.1
portalocker==2.3.2
prometheus-client==0.13.1
promise==2.3
protobuf==3.17.1
psutil==5.9.0
py==1.11.0
py-spy==0.3.11
pyarrow==7.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybind11==2.6.2
pycountry==20.7.3
pycparser==2.21
pydantic==1.9.0
PyGObject==3.36.0
PyHamcrest==1.9.0
PyJWT==1.7.1
pymacaroons==0.13.0
PyNaCl==1.3.0
pyOpenSSL==19.0.0
pyparsing==2.4.7
pyrsistent==0.15.5
pyserial==3.4
pytablewriter==0.58.0
pytest==6.2.3
python-apt==2.0.0+ubuntu0.20.4.4
python-dateutil==2.8.2
python-debian===0.1.36ubuntu1
pytz==2021.1
PyYAML==5.4.1
ray==1.4.1
redis==4.1.3
regex==2022.1.18
rehash==1.0.0
requests==2.25.1
requests-oauthlib==1.3.0
requests-unixsocket==0.2.0
rouge-score==0.0.4
rsa==4.7.2
sacrebleu==1.5.0
sacremoses==0.0.47
scikit-learn==1.0.2
scipy==1.6.3
SecretStorage==2.3.1
sentencepiece==0.1.96
sentry-sdk==1.5.5
service-identity==18.1.0
shortuuid==1.0.8
simplejson==3.16.0
six==1.15.0
smart-open==5.2.1
smmap==5.0.0
sniffio==1.2.0
sos==4.1
sqlitedict==1.6.0
ssh-import-id==5.10
starlette==0.17.1
systemd-python==234
tabledata==1.3.0
tabulate==0.8.9
tb-nightly==2.6.0a20210524
tcolorpy==0.1.1
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.8.0
tensorflow-cpu==2.6.3
tensorflow-estimator==2.6.0
tensorflow-io-gcs-filesystem==0.24.0
termcolor==1.1.0
testresources==2.0.1
tf-estimator-nightly==2.8.0.dev2021122109
tf-slim==1.1.0
threadpoolctl==3.1.0
tokenizers==0.11.4
toml==0.10.2
tomli==2.0.1
toolz==0.11.2
torch==1.8.1
torch-xla==1.8.1
torchvision==0.9.1
tqdm==4.62.3
tqdm-multiprocess==0.0.11
transformers==4.16.2
Twisted==18.9.0
typepy==1.3.0
typer==0.4.0
typing-extensions==3.10.0.2
ubuntu-advantage-tools==20.3
ufw==0.36
ujson==5.1.0
unattended-upgrades==0.1
uritemplate==3.0.1
urllib3==1.26.4
uvicorn==0.17.4
virtualenv==20.4.7
wadllib==1.3.3
wandb==0.12.10
wcwidth==0.2.5
Werkzeug==1.0.1
wrapt==1.12.1
xxhash==2.0.2
yarl==1.7.2
yaspin==2.1.0
zipp==1.0.0
zope.interface==4.7.1
zstandard==0.15.0

mrseeker commented 2 years ago

@whoislimshady It looks like you are using torch 1.8.1? I got 1.11.1...

mrseeker commented 2 years ago

Just to make sure everyone is on the same page: which TPU VM version are you actually running? It may be that people are running the wrong version, and that an older version works better than the newer ones. I am running with --version tpu-vm-tf-2.8.0, but I think this might actually need to be a lower one (2.6.3).

safeeazeem commented 2 years ago

For fine-tuning I have always used TPU version v2-alpha. I haven't come across any errors so far.

nikhilanayak commented 2 years ago

Here's a full set of commands to reproduce the error: (The GPT-J-6B/step_383500 weights are already uploaded to gcloud)

gcloud alpha compute tpus tpu-vm create my_tpu_vm --zone=us-central1-a --accelerator-type=v3-8 --version=v2-alpha
gcloud alpha compute tpus tpu-vm ssh my_tpu_vm --zone us-central1-a --project [MY PROJECT ID]
$ git clone https://github.com/kingoflolz/mesh-transformer-jax
$ cd mesh-transformer-jax
$ echo gs://[my_tfrecord] > data/main.train.index
$ cat configs/6B_roto_256.json (edit the file so that it matches the following)

{
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
  "gradient_accumulation_steps": 16,

  "warmup_steps": 3000,
  "anneal_steps": 300000,
  "lr": 5e-5,
  "end_lr": 1e-5,
  "weight_decay": 0.1,
  "total_steps": 350000,

  "tpu_size": 8,

  "bucket": "[my bucket (without the gs://)]",
  "model_dir": "mesh_jax_pile_6B_rotary",

  "train_set": "main.train.index",
  "val_set": {},

  "eval_harness_tasks": [],

  "val_batches": 100,
  "val_every": 350001,
  "ckpt_every": 500,
  "keep_every": 10000,

  "name": "GPT3_6B_pile_rotary",
  "wandb_project": "mesh-transformer-jax",
  "comment": ""
}
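
As a sanity check on a config like the one above, the sketch below derives a few implied quantities. It assumes the number of data-parallel replicas is tpu_size // cores_per_replica and that one optimizer step consumes per_replica_batch * replicas * gradient_accumulation_steps sequences; the function name `summarize_config` is hypothetical and not part of the repository.

```python
import json


def summarize_config(params):
    """Derive a few quantities implied by a mesh-transformer-jax style config."""
    # Assumption: data-parallel replicas = tpu_size // cores_per_replica.
    if params["tpu_size"] % params["cores_per_replica"] != 0:
        raise ValueError("tpu_size must be a multiple of cores_per_replica")
    if params["d_model"] % params["n_heads"] != 0:
        raise ValueError("d_model must be divisible by n_heads")
    replicas = params["tpu_size"] // params["cores_per_replica"]
    sequences_per_step = (
        params["per_replica_batch"] * replicas * params["gradient_accumulation_steps"]
    )
    return {
        "replicas": replicas,
        "head_dim": params["d_model"] // params["n_heads"],
        "sequences_per_step": sequences_per_step,
        "tokens_per_step": sequences_per_step * params["seq"],
    }


# Only the fields used by the check, with the values from the config above.
config_text = """
{
  "d_model": 4096, "n_heads": 16, "seq": 2048,
  "cores_per_replica": 8, "per_replica_batch": 1,
  "gradient_accumulation_steps": 16, "tpu_size": 8
}
"""
summary = summarize_config(json.loads(config_text))
print(summary)
```

Under these assumptions, the config corresponds to a single replica spanning all 8 cores, a head dimension of 256, and 16 sequences (32768 tokens) per optimizer step.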

$ pip install -r requirements.txt --use-deprecated=legacy-resolver
$ pip install "jax[tpu]>=0.2.12" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
$ pip install jaxlib==0.1.68
$ export LD_LIBRARY_PATH=/usr/local/lib
$ python3 device_train.py --config=configs/6B_roto_256.json --tune-model-path=gs://MY-BUCKET/step_383500/

mrseeker commented 2 years ago

Okay, I fixed the issue this way:

This seemed to fix a LOT of the issues I was having, and it's now working :)

mosmos6 commented 2 years ago

@mrseeker Your solution fixed my problem too. Thank you for sharing it!