kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0
6.26k stars, 890 forks

TPU-V4 #255

Open wimjan123 opened 1 year ago

wimjan123 commented 1 year ago

How can one use this project to fine-tune on a TPU v4 instance? I have tried everything, but I always get errors. This is the most common output:

UserWarning: cloud_tpu_init failed: KeyError('v4-8') This a JAX bug; please report an issue at https://github.com/google/jax/issues
_warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report "
2023-03-05 21:55:43.305762: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-03-05 21:55:43.941977: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-03-05 21:55:43.942070: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-03-05 21:55:43.942076: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "device_train.py", line 191, in
    raise ValueError(msg)
ValueError: each shard needs a separate device, but device count (1) < shard count (4)

Wingie commented 1 year ago

Your installed JAX version is wrong for your TPU version (this repo is old). Basically, you have to keep trying different installations and images (I use the v2-alpha image on a TPU v3-8). Once the following command works, JAX is installed and working on your TPU:

python3 -c "import jax; print(jax.devices())"  # should print TpuDevice
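A slightly more explicit version of that check is sketched below. It counts the devices JAX reports by platform, so a CPU fallback like the `No GPU/TPU found, falling back to CPU` warning in the original log shows up as `cpu_fallback: True`. The helper name `summarize_devices` is hypothetical; it relies only on the `platform` attribute that JAX device objects expose.

```python
from collections import Counter
from types import SimpleNamespace


def summarize_devices(devices):
    """Count devices by platform ("tpu", "cpu", "gpu") and flag a CPU-only fallback."""
    counts = Counter(getattr(d, "platform", "unknown") for d in devices)
    return {
        "counts": dict(counts),
        "total": len(devices),
        # True only when JAX found devices and every one of them is a CPU device.
        "cpu_fallback": len(devices) > 0 and set(counts) == {"cpu"},
    }


if __name__ == "__main__":
    # Stand-in objects for illustration; on the TPU VM pass jax.devices() instead.
    demo = [SimpleNamespace(platform="cpu")]
    print(summarize_devices(demo))
```

On the TPU VM, call `summarize_devices(jax.devices())`; on a working TPU the counts should contain only `tpu` entries.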

Also, your libcudart errors mean you need to uninstall tensorflow and install tensorflow-cpu instead, since a TPU device has no GPU.
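One way to confirm that the suggested swap took effect is to ask the installed-package metadata which distributions are present. This sketch uses only the standard library's `importlib.metadata`; the helper name `installed_versions` and the list of packages checked are illustrative choices. After the swap you would expect `tensorflow` to report as not installed and `tensorflow-cpu` to show a version.

```python
from importlib import metadata


def installed_versions(dists=("tensorflow", "tensorflow-cpu", "jax", "jaxlib")):
    """Return {distribution name: installed version, or None if not installed}."""
    versions = {}
    for name in dists:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:15} {version or 'not installed'}")
```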

I would recommend going through https://github.com/ayaka14732/tpu-starter; it can help with some of the errors you are facing.

wimjan123 commented 1 year ago

I use V2-alpha-tpu4 on a TPU v4-8. The command to check whether JAX is installed returns this:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

leejason commented 1 year ago

According to the Google Cloud documentation below, the number of TPU cores has changed from 8 to 4 on TPU v4:

Display the number of TPU cores available:
jax.device_count()
The number of TPU cores is displayed. If you are using a v4 TPU, this should be 4. If you are using a v2 or v3 TPU, this should be 8.
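The core counts above and the ValueError in the first post fit together: each model replica is split across `cores_per_replica` devices, so JAX must report at least that many. A minimal sketch of that arithmetic, assuming a 2-D mesh of shape (data parallel, model parallel) with the data-parallel size equal to `device_count // cores_per_replica`; the function `mesh_shape` is hypothetical and not copied from the repository.

```python
def mesh_shape(device_count: int, cores_per_replica: int) -> tuple:
    """Return (data_parallel, model_parallel) sizes for a 2-D device mesh.

    Raises ValueError when there are fewer devices than model shards, the same
    condition as the ValueError in the original traceback.
    """
    if cores_per_replica < 1:
        raise ValueError("cores_per_replica must be positive")
    if device_count < cores_per_replica:
        raise ValueError(
            f"each shard needs a separate device, but device count "
            f"({device_count}) < shard count ({cores_per_replica})"
        )
    if device_count % cores_per_replica:
        raise ValueError(
            f"device count ({device_count}) is not divisible by "
            f"cores_per_replica ({cores_per_replica})"
        )
    return (device_count // cores_per_replica, cores_per_replica)


if __name__ == "__main__":
    print(mesh_shape(8, 8))  # v2/v3-8 with an 8-way model split -> (1, 8)
    print(mesh_shape(4, 4))  # v4-8 with a 4-way model split -> (1, 4)
```

With `device_count` 1 (the CPU fallback) and `cores_per_replica` 4, the sketch raises the same message as the log; with 4 devices and `cores_per_replica` 8 it also fails, which is why an 8-way configuration cannot run unchanged on a v4-8.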

(source: https://cloud.google.com/tpu/docs/run-calculation-jax)

wimjan123 commented 1 year ago

Aha, I see. Is there any way to fine-tune GPT-J using 4 TPU cores?

leejason commented 1 year ago

I changed the following from 8 to 4 in the configuration file:

"cores_per_replica": 4
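The same edit can be scripted. A minimal sketch using Python's `json` module, assuming the configuration is a JSON object with a top-level `cores_per_replica` key as in the line above; the helper name `set_cores_per_replica` and the two-field excerpt are illustrative, not copied from the repository's config files. Note that this only changes the training layout; it does not change how an existing checkpoint was sharded.

```python
import json


def set_cores_per_replica(config_text: str, cores: int) -> str:
    """Return the JSON config text with "cores_per_replica" set to `cores`."""
    config = json.loads(config_text)
    if "cores_per_replica" not in config:
        raise KeyError("config has no 'cores_per_replica' entry")
    config["cores_per_replica"] = cores
    return json.dumps(config, indent=2)


if __name__ == "__main__":
    # Illustrative excerpt only; a real config file has many more fields.
    original = '{"cores_per_replica": 8, "layers": 28}'
    print(set_cores_per_replica(original, 4))
```

In practice you would read your config file, pass its text through the helper, and write the result to a new file rather than overwriting the original.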

wimjan123 commented 1 year ago

If I do that, I get an "AssertionError: Incompatible checkpoints" error.

leejason commented 1 year ago

I forgot to mention that this is for pre-training from scratch. The checkpoint incompatibility you hit seems to be a real issue, since it's not clear whether checkpoints saved on 8 cores can be loaded on 4 cores.

wimjan123 commented 1 year ago

Is there any way to convert the checkpoints to, let's say, 4 shards?
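Conceptually, a tensor split evenly along one axis into 8 pieces can be regrouped into 4 pieces by concatenating neighbouring shards. The sketch below shows only that array operation on NumPy arrays. It does not read or write this repository's checkpoint files, and it assumes every parameter was split contiguously along a single known axis (a parameter replicated on every shard would instead be taken from any one shard). The function `merge_shards` is hypothetical; whether this is enough to make an 8-shard GPT-J checkpoint load on 4 cores is exactly the open question in this thread.

```python
import numpy as np


def merge_shards(shards, new_count, axis):
    """Regroup `shards` (equal, contiguous slices along `axis`) into `new_count` shards.

    Adjacent shards are concatenated, e.g. 8 -> 4 joins shards [0, 1], [2, 3], ...
    """
    old_count = len(shards)
    if old_count % new_count:
        raise ValueError(f"cannot regroup {old_count} shards into {new_count}")
    group = old_count // new_count
    return [
        np.concatenate(shards[i * group:(i + 1) * group], axis=axis)
        for i in range(new_count)
    ]


if __name__ == "__main__":
    # Toy 16x3 "weight" split 8 ways along axis 0, then regrouped into 4 shards.
    weight = np.arange(48).reshape(16, 3)
    four = merge_shards(np.split(weight, 8, axis=0), 4, axis=0)
    print([s.shape for s in four])  # [(4, 3), (4, 3), (4, 3), (4, 3)]
```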

leejason commented 1 year ago

No idea. I guess not, but I didn't try. I plan to move forward with TPU v4.

mosmos6 commented 7 months ago

I'm curious how this attempt turned out. Has anyone succeeded in running GPT-J on TPU v4?