windmaple opened this issue 9 months ago
@alanwaketan
🐛 Bug
Not sure if this is a feature request or a bug. I took the SPMD Gemma fine-tuning code from Hugging Face and tried to run it on Kaggle, but it didn't work.
trl seems to have an issue there.
To Reproduce
See my Kaggle notebook.
Expected behavior
Ideally it should run.
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
- torch_xla version: stock Kaggle environment
Additional context
OK, seems that code is for Cloud TPU only as mentioned in this HF blog. Then this is a feature request.
Kaggle is using an older version of torch-xla in which torch_xla.distributed.spmd was not implemented. I would recommend upgrading torch-xla:
!pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
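A quick way to check whether an installed torch_xla already has the SPMD API that the transformers FSDPv2 path calls (xs.set_global_mesh) is to probe the module from Python. This is a minimal sketch; the helper name spmd_api_report is made up for illustration, and it only inspects what is importable, without touching the TPU:

```python
import importlib


def spmd_api_report() -> dict:
    """Hypothetical helper (not part of torch_xla or transformers): report the
    torch / torch_xla versions and whether torch_xla.distributed.spmd provides
    set_global_mesh."""
    report = {
        "torch": None,
        "torch_xla": None,
        "has_set_global_mesh": False,
        "error": None,
    }
    try:
        torch = importlib.import_module("torch")
        report["torch"] = str(torch.__version__)
        torch_xla = importlib.import_module("torch_xla")
        report["torch_xla"] = str(torch_xla.__version__)
        # Older torch_xla releases either lack this module or lack the function.
        xs = importlib.import_module("torch_xla.distributed.spmd")
        report["has_set_global_mesh"] = hasattr(xs, "set_global_mesh")
    except ImportError as exc:
        # Covers a missing package as well as ABI mismatches such as
        # "undefined symbol" errors raised while loading _XLAC.
        report["error"] = f"{type(exc).__name__}: {exc}"
    return report


if __name__ == "__main__":
    for key, value in spmd_api_report().items():
        print(f"{key}: {value}")
```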
@windmaple You need to install the nightly torch-xla and torch.
Kaggle VM just silently dies after upgrading torch and torch-xla
!pip uninstall -y tensorflow
!pip install tensorflow-cpu #optional
It helped me get a little further with 2.2.0, but I still get:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[9], line 42
34 fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": [
35 "GemmaDecoderLayer"
36 ],
37 "xla": True,
38 "xla_fsdp_v2": True,
39 "xla_fsdp_grad_ckpt": True}
41 # Finally, set up the trainer and train the model.
---> 42 trainer = SFTTrainer(
43 model=model,
44 train_dataset=data,
45 args=TrainingArguments(
46 per_device_train_batch_size=64, # This is actually the global batch size for SPMD.
47 num_train_epochs=100,
48 max_steps=-1,
49 output_dir="./output",
50 optim="adafactor",
51 logging_steps=1,
52 dataloader_drop_last = True, # Required for SPMD.
53 fsdp="full_shard",
54 fsdp_config=fsdp_config,
55 ),
56 peft_config=lora_config,
57 dataset_text_field="quote",
58 max_seq_length=max_seq_length,
59 packing=True,
60 )
62 trainer.train()
File /usr/local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:299, in SFTTrainer.__init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics, peft_config, dataset_text_field, packing, formatting_func, max_seq_length, infinite, num_of_sequences, chars_per_token, dataset_num_proc, dataset_batch_size, neftune_noise_alpha, model_init_kwargs, dataset_kwargs)
293 if tokenizer.padding_side is not None and tokenizer.padding_side != "right":
294 warnings.warn(
295 "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to "
296 "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code."
297 )
--> 299 super().__init__(
300 model=model,
301 args=args,
302 data_collator=data_collator,
303 train_dataset=train_dataset,
304 eval_dataset=eval_dataset,
305 tokenizer=tokenizer,
306 model_init=model_init,
307 compute_metrics=compute_metrics,
308 callbacks=callbacks,
309 optimizers=optimizers,
310 preprocess_logits_for_metrics=preprocess_logits_for_metrics,
311 )
313 # Add tags for models that have been loaded with the correct transformers version
314 if hasattr(self.model, "add_model_tags"):
File /usr/local/lib/python3.10/site-packages/transformers/trainer.py:653, in Trainer.__init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
649 if self.is_fsdp_xla_v2_enabled:
650 # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
651 # Tensor axis is just a placeholder where it will not be used in FSDPv2.
652 num_devices = xr.global_runtime_device_count()
--> 653 xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
AttributeError: module 'torch_xla.distributed.spmd' has no attribute 'set_global_mesh'
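For context, the line that fails builds a one-dimensional mesh for FSDPv2: all devices sit on the "fsdp" axis and the "tensor" axis has size 1, which the transformers comment calls a placeholder. The sketch below reproduces only that layout in plain Python so it runs without torch_xla; the function name is hypothetical, and on a working install the real objects come from xs.Mesh and xs.set_global_mesh as in the traceback:

```python
def fsdp_v2_mesh_layout(num_devices: int):
    """Hypothetical helper: mirror the arguments transformers passes to xs.Mesh
    in the traceback (device ids, mesh shape, axis names), without torch_xla."""
    if num_devices < 1:
        raise ValueError("num_devices must be positive")
    mesh_shape = (num_devices, 1)  # every device on "fsdp", size 1 on "tensor"
    axis_names = ("fsdp", "tensor")
    # Same ids as np.array(range(num_devices)), shown laid out in mesh_shape:
    # one device id per row.
    grid = [[device_id] for device_id in range(num_devices)]
    return grid, mesh_shape, axis_names


if __name__ == "__main__":
    grid, shape, names = fsdp_v2_mesh_layout(4)
    print(shape, names, grid)
```

For example, a v4-8 that reports 4 devices (like the ResNet run later in this thread, xla:0/0 to xla:0/3) gives a (4, 1) grid of [[0], [1], [2], [3]].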
What's the right way to install nightly? I searched around but couldn't find it.
@windmaple Here are the instructions to install nightly: https://github.com/pytorch/xla#available-docker-images-and-wheels
I had the same problem as @windmaple:
AttributeError: module 'torch_xla.distributed.spmd' has no attribute 'set_global_mesh'
As @alanwaketan suggested, I installed the nightly build of torch_xla in a fresh conda env with the following packages:
conda create -n v_xla python=3.10
conda activate v_xla
pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
pip install datasets peft transformers trl
python train.py
Where train.py is this script https://huggingface.co/google/gemma-7b/blob/main/examples/example_fsdp.py
Running this script results in the following error:
Traceback (most recent call last):
File "/home/me/finetune/train.py", line 5, in <module>
import torch_xla
File "/home/me/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch_xla/__init__.py", line 7, in <module>
import _XLAC
ImportError: /home/me/miniconda3/envs/v_xla/lib/python3.10/site-packages/_XLAC.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104impl3cow23materialize_cow_storageERNS_11StorageImplE
I am looking for workarounds.
@PawKanarek I'm stuck here too.
To resolve this problem
ImportError: /home/me/miniconda3/envs/v_xla/lib/python3.10/site-packages/_XLAC.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104impl3cow23materialize_cow_storageERNS_11StorageImplE
you have to update PyTorch to nightly:
conda install pytorch-nightly::pytorch
But after this I got a new problem:
File "/home/me/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch_xla/runtime.py", line 124, in xla_device
return torch.device(torch_xla._XLAC._xla_get_default_device())
RuntimeError: Bad StatusOr access: INTERNAL: Failed to get global TPU topology.
I found similar issues: https://github.com/google/gemma_pytorch/issues/25, https://github.com/Lightning-AI/pytorch-lightning/issues/18932
@PawKanarek What's your libtpu version?
@windmaple Yeah, usually you just need nightly for both pytorch and pytorch/xla. pytorch/xla heavily depends on pytorch.
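Because pytorch/xla depends so closely on pytorch, one cheap guard is to compare the release lines of the two installed packages before digging further. A hedged sketch: matching major.minor is only a necessary condition (nightly wheels should also come from the same build date), and the helper names are hypothetical:

```python
import re


def major_minor(version: str) -> tuple:
    """Extract (major, minor) from strings such as '2.3.0.dev20240309'
    or '2.3.0+git6043185'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def release_lines_match(torch_version: str, torch_xla_version: str) -> bool:
    """Heuristic only: the same major.minor is necessary, not sufficient."""
    return major_minor(torch_version) == major_minor(torch_xla_version)
```

With the versions reported later in this thread, release_lines_match("2.3.0.dev20240309", "2.3.0+git6043185") returns True, while a stable torch such as "2.2.1" next to a 2.3 nightly torch_xla returns False.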
@alanwaketan I think my libtpu version is tpu-vm-pt-2.0, based on the command that I used to create my TPU v4-8:
gcloud compute tpus tpu-vm create my-tpu-name --zone=us-central2-b --accelerator-type=v4-8 --version=tpu-vm-pt-2.0
Oh, I see in the documentation https://cloud.google.com/tpu/docs/supported-tpu-configurations#tpu_v4 that I should use tpu-vm-v4-pt-2.0. Thanks for the insight. ;)
@PawKanarek libtpu is a pip package; you can grep for it in the output of pip list.
The latest version is:
pip list | grep libtpu
libtpu-nightly 0.1.dev20240213
If yours is older than this, you can update it via:
pip install torch-xla[tpuvm]
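The same check can be done from Python, e.g. inside a notebook where shelling out to pip list | grep libtpu is awkward. A minimal sketch with hypothetical helper names; it assumes the package name contains "libtpu" and that its version has the form 0.1.devYYYYMMDD, as in the output above:

```python
import re
from datetime import date
from importlib import metadata
from typing import Optional


def libtpu_build_date(version: str) -> Optional[date]:
    """Parse the build date from a version like '0.1.dev20240213'.
    Returns None if there is no .devYYYYMMDD suffix."""
    match = re.search(r"\.dev(\d{4})(\d{2})(\d{2})", version)
    if match is None:
        return None
    year, month, day = (int(g) for g in match.groups())
    return date(year, month, day)


def installed_libtpu_versions() -> dict:
    """Rough Python equivalent of `pip list | grep libtpu`."""
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name") or ""
        if "libtpu" in name.lower():
            found[name] = dist.version
    return found


if __name__ == "__main__":
    for name, version in installed_libtpu_versions().items():
        print(name, version, libtpu_build_date(version))
```

For example, libtpu_build_date("0.1.dev20240213") gives date(2024, 2, 13), so any earlier date is older than the version quoted above.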
I've installed this package
libtpu-nightly 0.1.dev20240213
and I still get the same error:
File "/home/me/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch_xla/runtime.py", line 124, in xla_device
return torch.device(torch_xla._XLAC._xla_get_default_device())
RuntimeError: Bad StatusOr access: INTERNAL: Failed to get global TPU topology.
@PawKanarek Could be a hardware issue then... Can you try recreating the TPU VM?
tpu-vm-v4-pt-2.0 is a somewhat old image. Do you mind following https://cloud.google.com/tpu/docs/run-calculation-pytorch to use the VM version tpu-ubuntu2204-base? If the framework and libtpu versions match and it still doesn't work, it is usually a hardware or driver issue.
I created a new machine with the command:
gcloud compute tpus tpu-vm create my-name --zone=us-central2-b --accelerator-type=v4-8 --version=tpu-ubuntu2204-base
installed all the required packages on it, and now when I try to run this script https://huggingface.co/google/gemma-7b/blob/main/examples/example_fsdp.py I get this error:
(v_xla) me@tpu-1:~/finetune$ python train.py
Aborted (core dumped)
I will look for more specific errors :)
@JackCaoG Now I created the v4-8 machine with this VM version: tpu-vm-v4-pt-2.0
gcloud compute tpus tpu-vm create myname --zone=us-central2-b --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0
And now I'm getting a different message, but at least it's readable now :)
python server/server.py
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1710013726.247688 30296 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/me/miniconda3/envs/tpu_v4/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1710013726.247769 30296 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1710013726.247774 30296 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
/home/me/miniconda3/envs/tpu_v4/lib/python3.10/site-packages/torch_xla/runtime.py:247: UserWarning: Replicating tensors already initialized on non-virtual XLA device for SPMD to force SPMD mode. This is one-time overhead to setup, and to minimize such, please set SPMD mode before initializting tensors (i.e., call use_spmd() in the beginning of the program).
warnings.warn(
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 3.07it/s]
/home/me/miniconda3/envs/tpu_v4/lib/python3.10/site-packages/transformers/training_args.py:1815: FutureWarning: `--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--hub_token` instead.
warnings.warn(
/home/me/miniconda3/envs/tpu_v4/lib/python3.10/site-packages/transformers/training_args.py:1827: FutureWarning: `--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this argument (in this case google/gemma-2-it).
warnings.warn(
https://symbolize.stripped_domain/r/?trace=7f8ddd4d4953,7f8ea111e3bf,7f8de5b4364d,7f8dde56762d,7f8de5b46273,7f8dddde807a,7f8dddbdb4ea,7f8e90515509&map=
*** SIGSEGV (@0x1d8), see go/stacktraces#s15 received by PID 30296 (TID 31841) on cpu 195; stack trace: ***
PC: @ 0x7f8ddd4d4953 (unknown) torch_xla::runtime::PjRtComputationClient::ExecuteReplicated()::{lambda()#1}::operator()()
@ 0x7f8d6c18c6a7 928 (unknown)
@ 0x7f8ea111e3c0 1984 (unknown)
@ 0x7f8de5b4364e 32 std::_Function_handler<>::_M_invoke()
@ 0x7f8dde56762e 288 Eigen::ThreadPoolDevice::parallelFor()
@ 0x7f8de5b46274 576 tsl::thread::ThreadPool::ParallelFor()
@ 0x7f8dddde807b 1168 torch_xla::runtime::PjRtComputationClient::ExecuteReplicated()
@ 0x7f8dddbdb4eb 624 torch_xla::XLAGraphExecutor::ScheduleSyncTensorsGraph()::{lambda()#1}::operator()()
@ 0x7f8e9051550a (unknown) torch::lazy::MultiWait::Complete()
@ ... and at least 1 more frames
https://symbolize.stripped_domain/r/?trace=7f8ddd4d4953,7f8d6c18c6a6,7f8ea111e3bf,7f8de5b4364d,7f8dde56762d,7f8de5b46273,7f8dddde807a,7f8dddbdb4ea,7f8e90515509&map=
E0309 19:48:53.091365 31841 coredump_hook.cc:442] RAW: Remote crash data gathering hook invoked.
E0309 19:48:53.091373 31841 coredump_hook.cc:481] RAW: Skipping coredump since rlimit was 0 at process start.
E0309 19:48:53.091379 31841 client.cc:269] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0309 19:48:53.091381 31841 coredump_hook.cc:537] RAW: Sending fingerprint to remote end.
E0309 19:48:53.091395 31841 coredump_hook.cc:546] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0309 19:48:53.091399 31841 coredump_hook.cc:598] RAW: Dumping core locally.
E0309 19:48:53.337414 31841 process_state.cc:807] RAW: Raising signal 11 with default behavior
Segmentation fault (core dumped)
Can you follow https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#sanity-check to run ResNet with fake data? I am not sure whether it is an env setup issue or a Gemma issue in your case.
Thanks for the advice. The sanity check looks good on this TPU. Imports:
python
Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch_xla
>>> print(torch.__version__)
2.3.0.dev20240309
>>> print(torch_xla.__version__)
2.3.0+git6043185
simple calculation
python3
Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> t1 = torch.tensor(100, device=xm.xla_device())
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1710270930.199793 326792 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/raix/miniconda3/envs/tpu_v4/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1710270930.199885 326792 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1710270930.199890 326792 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
>>> t2 = torch.tensor(200, device=xm.xla_device())
>>> print(t1 + t2)
tensor(300, device='xla:0')
>>>
imagenet
Epoch 18 train end 20:13:54
| Test Device=xla:0/0 Step=0 Epoch=18 Time=20:13:54
| Test Device=xla:0/1 Step=0 Epoch=18 Time=20:13:54
| Test Device=xla:0/3 Step=0 Epoch=18 Time=20:13:54
| Test Device=xla:0/2 Step=0 Epoch=18 Time=20:13:54
| Test Device=xla:0/3 Step=20 Epoch=18 Time=20:13:54
| Test Device=xla:0/2 Step=20 Epoch=18 Time=20:13:54
| Test Device=xla:0/1 Step=20 Epoch=18 Time=20:13:54
| Test Device=xla:0/0 Step=20 Epoch=18 Time=20:13:54
| Test Device=xla:0/1 Step=40 Epoch=18 Time=20:13:54
| Test Device=xla:0/0 Step=40 Epoch=18 Time=20:13:54
| Test Device=xla:0/3 Step=40 Epoch=18 Time=20:13:54
| Test Device=xla:0/2 Step=40 Epoch=18 Time=20:13:54
| Test Device=xla:0/1 Step=60 Epoch=18 Time=20:13:55
| Test Device=xla:0/3 Step=60 Epoch=18 Time=20:13:55
| Test Device=xla:0/0 Step=60 Epoch=18 Time=20:13:55
| Test Device=xla:0/2 Step=60 Epoch=18 Time=20:13:55
| Test Device=xla:0/1 Step=80 Epoch=18 Time=20:13:55
| Test Device=xla:0/2 Step=80 Epoch=18 Time=20:13:55
| Test Device=xla:0/0 Step=80 Epoch=18 Time=20:13:55
| Test Device=xla:0/3 Step=80 Epoch=18 Time=20:13:55
Epoch 18 test end 20:13:55, Accuracy=100.00
Max Accuracy: 100.00%
@PawKanarek For Gemma, have you set the following env: PJRT_DEVICE=TPU XLA_USE_SPMD=1 ?
It seems that setting export PJRT_DEVICE=TPU and export XLA_USE_SPMD=1 resolved the issue. I was certain I had exported the variables... Training now works, though it occasionally crashes on larger datasets; smaller datasets work without problems. Thanks!
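Since the fix here came down to two environment variables, a small guard at the top of the training script can catch the case where they were not actually exported in the shell that launched it. A sketch under one assumption: the variables must be in place before torch_xla is imported and its runtime initializes. The runtime warning earlier in this thread also suggests calling use_spmd() at the start of the program. Helper names are hypothetical:

```python
import os

# The two variables suggested in this thread, with the values used there.
SPMD_ENV = {"PJRT_DEVICE": "TPU", "XLA_USE_SPMD": "1"}


def missing_spmd_env(env=None) -> list:
    """Return the names of variables that are unset or set to another value."""
    env = os.environ if env is None else env
    return [name for name, value in SPMD_ENV.items() if env.get(name) != value]


def apply_spmd_env() -> None:
    """Set the variables for this process. Assumption: call this before
    importing torch_xla, i.e. at the very top of the training script."""
    os.environ.update(SPMD_ENV)


if __name__ == "__main__":
    missing = missing_spmd_env()
    if missing:
        print("not exported:", ", ".join(missing))
        apply_spmd_env()
```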
I would love to learn more about the crash as well! Do you mind opening a new bug?
@windmaple @PawKanarek Are we good to close this issue?
The problem with AttributeError: module 'torch_xla.distributed.spmd' has no attribute 'set_global_mesh' was resolved on my machine.