ikergarcia1996 opened this issue 2 years ago
Thanks for filing this issue @ikergarcia1996
@thisisalbertliang can you please take a look at this issue?
Hey @ikergarcia1996. This is a known problem and we are working on a long-term fix (hopefully landing by Q3). One workaround I can think of is to explicitly insert xm.mark_step between steps.
My guess is that you are currently using parallel_loader (if you are not, you should, since it uploads the data asynchronously). parallel_loader calls mark_step for you to cut the graph every batches_per_execution batches and lets the XLA compiler compile and execute each graph.
The issue here is that the number of parameters exceeds the TPU smem limit. One way to avoid this is to cut the graph more frequently, which results in smaller graphs with fewer parameters. If batches_per_execution in your setting is already 1, you can consider adding xm.mark_step between layers, or doing it after the forward pass so that forward and backward become two graphs instead of one, as in the sketch below.
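Not tested against your repro, but here is a minimal single-process sketch of that workaround. The model, data, and loss here are throwaway stand-ins (not the repro model); the relevant parts are the MpDeviceLoader with batches_per_execution=1 and the extra xm.mark_step() after the forward pass.

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data import DataLoader, TensorDataset

device = xm.xla_device()

# Tiny stand-in model/data; the real script trains the model from the repo above.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
dataset = TensorDataset(torch.randn(512, 128), torch.randint(0, 10, (512,)))
train_loader = DataLoader(dataset, batch_size=32)

# MpDeviceLoader uploads batches asynchronously and calls mark_step for you
# every `batches_per_execution` batches (default 1).
mp_loader = pl.MpDeviceLoader(train_loader, device, batches_per_execution=1)

for inputs, targets in mp_loader:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    xm.mark_step()                # extra cut: the forward pass becomes its own graph
    loss.backward()
    xm.optimizer_step(optimizer)  # all-reduce gradients and run the optimizer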
Another thing that is weird is that you run into this issue at step 9. This seems to indicate that the number of parameters increases as you train the model (since steps 1-8 run without error). That is unexpected, since the graph should be the same for every step; otherwise it would keep recompiling. One quick way to check is to use the PT_XLA_DEBUG=1
env var, which will print some auto-debug messages. Please check out https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#perform-a-auto-metrics-analysis
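For reference, a self-contained sketch of that check with a stand-in model (PT_XLA_DEBUG has to be set before torch_xla is imported, so exporting it in the shell before launching is the safest route):

```python
import os
os.environ["PT_XLA_DEBUG"] = "1"   # or: export PT_XLA_DEBUG=1 before launching the script

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
model = nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(20):
    optimizer.zero_grad()
    loss = model(torch.randn(32, 128, device=device)).sum()
    loss.backward()
    optimizer.step()
    xm.mark_step()
    if step in (1, 10, 19):
        # CompileTime's sample count should stop growing after the first couple
        # of steps; if it keeps increasing, the graph is changing every step.
        print(f"step {step}\n{met.metrics_report()}")
```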
I have the same issue
We also encountered this issue when trying to scale to large transformer models with FSDP in #3431:
Computation requires more parameters (4748) than supported (limit 3304).
Looking forward to the long-term fix on this!
The fix (runtime migration) is WIP; I will update here when it is ready.
FYI @will-cromar: PJRT has an option to tuplify the inputs during compilation and execution; we should (maybe optionally) use https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_client.h#L222 when we switch to PJRT.
🐛 Bug
I am trying to train a model from this repository: https://github.com/ikergarcia1996/Self-Driving-Car-in-Video-Games using a TPU v3-8 VM. Even when I train a tiny 9M-parameter model with a small batch size, I get the following error:
The error always seems to happen at step 9. The model works as expected when running on a GPU/CPU.
I found that other people have also run into this issue with large models (#1963), but in my case it happens with models of any size.
To Reproduce
I am trying to train this model: https://github.com/ikergarcia1996/Self-Driving-Car-in-Video-Games/blob/master/model.py#L778 from this repository: https://github.com/ikergarcia1996/Self-Driving-Car-in-Video-Games
Here is a small Colab notebook to reproduce the issue. There are no TPUs available in Colab right now, so I cannot test whether I get the same error there. I am using a TPU v3-8 VM: https://colab.research.google.com/drive/1nVbJooUMvMMc8V9F6ioqrYkhyuiveN1i?usp=sharing
Expected behavior
The model works fine when training on a GPU/CPU.
Environment
tensorflow 2.9.0 (tf-nightly; with the stable release I get a weird "DefaultDeviceShapeRepresentation not available in this library" error)
torch 1.11.0
torch-xla 1.11
pytorch-lightning==1.6.0rc1 (installed from source; wandb crashes with the stable release)
cloud-tpu-client==0.10
TPU
Full environment
Additional context
Full Traceback