markusheimerl opened this issue 4 months ago
Hi @markusheimerl, thanks for the issue! It looks like you are setting CKPT_PATH=/home/markusheimerl/gemma_ckpt/. This should be unrelated to the underlying failure you are experiencing, but CKPT_PATH should point to the actual weights file, not to the directory.
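For example, something along these lines (the filename here is hypothetical; use whatever checkpoint file actually sits in that directory):
CKPT_PATH=/home/markusheimerl/gemma_ckpt/gemma-7b.ckpt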
It looks to me like this is a Torch XLA issue. It is possible that this can be fixed by using a newer version of the base container here. If not, we may need to file an issue with Torch XLA.
What I would recommend first is trying us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20240226
and seeing if that fixes the issue. Otherwise, consider reaching out to the Torch XLA team. I can also see if I can get access to a VM with that topology to try to replicate your error.
Hi @markusheimerl, it seems you are using a v4-16 TPU, which has 2 host VMs. This multi-host setup is currently not supported.
To test it on TPU, I suggest you run it on a v4-8 / v5e-8, which is a single-host TPU architecture with 1 VM. You should be able to run the command on v4-8 / v5e-8 out of the box.
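A quick way to check what the Python process can actually see (a minimal sketch; on a single-host slice such as v4-8 this should print the local XLA devices, while on a multi-host slice such as v4-16 each host only sees its own share):

import torch_xla.core.xla_model as xm

# Lists the XLA devices visible to this host's process.
print(xm.get_xla_supported_devices())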
Hi @michaelmoynihan, I also get the Failed to get global TPU topology error on TPU v4-8, so I followed your advice:
What I would recommend first is trying
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20240226
and seeing if that fixes the issue. Otherwise, consider reaching out to the Torch XLA team. I can also see if I can get access to a VM with that topology to try to replicate your error.
So I made a Dockerfile with these contents:
FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20240226
RUN pip install datasets peft transformers trl
Then I ran this:
sudo docker build -t my-tpu-pytorch-image .
sudo docker run -v /home/me/finetune:/workspace my-tpu-pytorch-image python /workspace/train.py
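Note: the PyTorch/XLA docs suggest the container may need privileged access and host networking to see the TPU devices, along these lines (flags I had not set here):
sudo docker run --privileged --net host -v /home/me/finetune:/workspace my-tpu-pytorch-image python /workspace/train.py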
where train.py is this script for fine-tuning Gemma 7B: https://huggingface.co/google/gemma-7b/blob/main/examples/example_fsdp.py. As a result I got:
(v_xla) me@t1v-n-w-0:~/finetune$ sudo docker run -v /home/me/finetune:/workspace my-tpu-pytorch-image python /workspace/train.py
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1709816278.547533 1 pjrt_api.cc:100] GetPjrtApi was found for tpu at /usr/local/lib/python3.8/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1709816278.547628 1 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1709816278.547636 1 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
Traceback (most recent call last):
  File "/workspace/train.py", line 15, in <module>
    device = xm.xla_device()
  File "/usr/local/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 211, in xla_device
    return runtime.xla_device(n, devkind)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/runtime.py", line 88, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/runtime.py", line 117, in xla_device
    return torch.device(torch_xla._XLAC._xla_get_default_device())
RuntimeError: Bad StatusOr access: INTERNAL: Failed to get global TPU topology.
There is again the same error, INTERNAL: Failed to get global TPU topology, but I also see that there is something wrong with PJRT. I will try to reproduce that in another environment.
I tried to run this script on a TPU v3-8 and, with slight modifications (I switched the model to Gemma-2B because of a RESOURCE_EXHAUSTED error), I could start it without Docker with the command:
python train.py
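For reference, the model swap itself was a one-line change (a sketch, assuming the script loads the model via transformers as in the linked example):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"  # was "google/gemma-7b"; the 2B model fits in v3-8 memory
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)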
The script is working. It looks like I was using the wrong VM version when creating the TPU, and I forgot to set the environment variables. The correct way to create a TPU v4-8:
gcloud compute tpus tpu-vm create myname --zone=my-zone --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0
and to set these environment variables:
PJRT_DEVICE=TPU XLA_USE_SPMD=1
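Equivalently, these can be set from inside the script before torch_xla is imported (a minimal sketch):

import os

# Must happen before torch_xla initializes the runtime.
os.environ.setdefault("PJRT_DEVICE", "TPU")
os.environ.setdefault("XLA_USE_SPMD", "1")

import torch_xla.core.xla_model as xm
print(xm.xla_device())  # should now resolve to a TPU device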