pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Random OOM and crashes #8216

Open alexanderswerdlow opened 1 week ago

alexanderswerdlow commented 1 week ago

❓ Questions and Help

I've found that I'm unable to train more than ~20-80K steps without a crash, and it's difficult to figure out how to debug this. In a typical PyTorch training run, I would get a clear OOM message pointing at a particular line, or any other error would at least be printed to the log/console.

However, about half the time my training run simply exits with no message on any rank, and the other half of the time it's clearly memory-related, with a "Resource Exhausted" message. The issue is that it's not clear where this new allocation happens (I have a fairly standard decoder-based transformer, no eval batches, and I'm not using any eager modes). I tried switching to nightly to get a recent dataloader memory fix, but that doesn't seem to help.

I know there are many flags that can be used for debugging, but it's unclear exactly which ones can be used during training without a large performance hit. I've done all the suggested steps, including profiling and making sure there isn't re-compilation happening, etc. Perhaps it would be good to clarify the impact of these flags somewhere to make clear which are safe, and any other advice on how to debug this would be great!
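For reference, this is roughly how I'm watching for recompilation; a minimal sketch using torch_xla.debug.metrics, where the step interval and the choice of the CompileTime metric are just what I happen to check on nightly:

```python
import torch_xla.debug.metrics as met

def log_compile_stats(step, every=1000):
    # The CompileTime sample count should stop growing after warmup if
    # nothing is recompiling; `every` is just a placeholder interval.
    if step % every != 0:
        return
    data = met.metric_data("CompileTime")  # (count, total, samples) or None
    count = data[0] if data else 0
    print(f"step {step}: CompileTime count={count}")
    print(met.short_metrics_report())
```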

Also, I should note that this occurs with SPMD multi-node training. I have not spent time testing other modes, but this has happened with between 2 and 8 TPUv4 VMs, both in DDP-like configurations and with several other mesh configurations.

JackCaoG commented 3 days ago

Hmm, good question. It seems like there are two problems here:

  1. the training code OOMs
  2. error discovery is difficult (e.g., where the OOM happened)

For 2, it is because, from XLA's perspective, it is executing a compiled program and hits the OOM in the middle of that execution (assuming there is no recompilation, the OOM is a runtime OOM). It is hard for PyTorch/XLA to map this OOM event back to a specific Python line. However, the silent exit case seems weird: when I force a runtime error in SPMD it always throws an error, so it is hard for me to think of what happened there.

Regarding the code OOM, try

watch -n0 tpu-info

You should already have tpu-info installed along with libtpu if you are using nightly. Check whether the memory usage is slowly going up as training runs; I am wondering if some small tensors are slowly accumulating in HBM.
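If you want the same signal from inside the training loop, here is a minimal sketch around xm.get_memory_info; the exact keys in the returned dict vary between releases, so it just prints the whole thing, and the logging interval is a placeholder:

```python
import torch_xla.core.xla_model as xm

def maybe_log_hbm(step, every=500):
    # Print HBM usage periodically so you can see whether it creeps up over
    # time; master_print only prints on ordinal 0 to avoid log spam.
    if step % every == 0:
        info = xm.get_memory_info(xm.xla_device())
        xm.master_print(f"step {step}: HBM {info}")
```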