pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

TPU memory use increased significantly in torch/xla - 2.6.0.dev20241107 #8423

Open dudulightricks opened 3 days ago

dudulightricks commented 3 days ago

🐛 Bug

I'm working with nightly versions of torch/xla on TPU. When moving from torch==2.6.0.dev20241106+cpu to torch==2.6.0.dev20241107, TPU memory use for SPMD training increases significantly (roughly 2.5x), and in some settings training crashes with an OOM. The latest nightly still has this problem. I suspect a change in torch itself is affecting SPMD training in torch_xla.
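To quantify such a regression, a self-contained snippet like the one below can be run under each nightly and the printed numbers compared. This is only a minimal sketch, not the reporter's training script: the matmul workload, tensor sizes, and one-axis mesh are placeholders, and whether `xm.get_memory_info` reports per-chip or aggregate numbers under SPMD (and which dict keys it returns) differs across torch_xla versions, so the sketch just prints the whole dict.

```python
# Hypothetical helper (not from the issue): run the same snippet under two
# nightlies and compare the reported memory numbers.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # the report concerns SPMD training
device = xm.xla_device()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# Stand-in workload: a single sharded matmul, just enough to allocate
# device memory (sizes are arbitrary placeholders).
x = torch.randn(8192, 8192, device=device)
xs.mark_sharding(x, mesh, ('data', None))
y = x @ x
xm.mark_step()  # materialize the computation on the device

# Key names in the returned dict differ between torch_xla versions,
# so print the whole thing rather than indexing into specific fields.
print(xm.get_memory_info(device))
```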

Environment

JackCaoG commented 3 days ago

Hmm, this is a bit weird. Looking at the commit history:

(screenshot of the commit history)

If the breaking change came in between the 11/06 and 11/07 nightlies, it must have been merged on 11/06, but nothing suspicious was merged on that date.

JackCaoG commented 3 days ago

I will try to repro with https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py, which does FSDPv2 sharding with the decoder-only model, and see if I can reproduce the memory issue.
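For reference, the sharding pattern that example exercises looks roughly like the sketch below. This is an assumption-laden outline built from torch_xla's public SPMD/FSDPv2 APIs, not the example script itself: the toy decoder-style module, its dimensions, the batch shape, the (num_devices, 1) mesh layout, and the explicit shard_output callable are all made up for illustration.

```python
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()  # FSDPv2 is built on top of the SPMD mode

num_devices = xr.global_runtime_device_count()
# 2D mesh with a trivial second axis, so parameters shard along 'fsdp'.
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('fsdp', 'tensor'))
xs.set_global_mesh(mesh)

device = xm.xla_device()
# Toy stand-in for the decoder-only model (hypothetical sizes).
model = nn.Sequential(
    nn.Embedding(32000, 1024),
    nn.Linear(1024, 4096),
    nn.GELU(),
    nn.Linear(4096, 1024),
).to(device)

def shard_output(output, mesh):
    # Shard activations batch-wise along the 'fsdp' axis (assumption).
    xs.mark_sharding(output, mesh, ('fsdp', None, None))

# Wrap the module; parameters are sharded along the 'fsdp' mesh axis.
model = FSDPv2(model, shard_output=shard_output)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
tokens = torch.randint(0, 32000, (8, 512), device=device)

loss = model(tokens).mean()
loss.backward()
optimizer.step()
xm.mark_step()  # execute the step, then inspect device memory
print(xm.get_memory_info(device))
```

If the regression reproduces here, the same snippet run under the 11/06 and 11/07 nightlies should show the memory gap directly.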