Open dudulightricks opened 3 days ago
Hmm, this is a bit weird. Looking at the commit history: if the breaking change appeared going from the 11/06 nightly to the 11/07 nightly, the change must have been merged on 11/06, but nothing suspicious was merged on that date.
I will try to reproduce with https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py, which does FSDPv2 sharding with the decoder-only model, and see if I can reproduce the memory issue.
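For anyone following along, this is roughly the shape of what that example exercises. The sketch below is not the example script itself; the `nn.Sequential` stand-in model, mesh shape, and batch sizes are placeholders I'm assuming for illustration, while the mesh/FSDPv2 wrapping pattern follows the documented SPMD APIs:

```python
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()  # enable SPMD execution mode

# 1-D mesh over all addressable TPU devices, used as the FSDP sharding axis.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('fsdp',))
xs.set_global_mesh(mesh)

# Placeholder model; the real example uses a small decoder-only transformer.
model = nn.Sequential(
    nn.Embedding(32000, 1024),
    nn.Linear(1024, 4096),
    nn.GELU(),
    nn.Linear(4096, 1024),
).to(xm.xla_device())

# Wrap with FSDPv2 so parameters are sharded along the 'fsdp' mesh axis.
model = FSDPv2(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
tokens = torch.randint(0, 32000, (8, 512), device=xm.xla_device())

out = model(tokens)
out.sum().backward()
optimizer.step()
xm.mark_step()  # materialize the lazily traced graph on the TPU
```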
🐛 Bug
I'm working with nightly versions of torch/xla on TPU. When moving from torch==2.6.0.dev20241106+cpu to torch==2.6.0.dev20241107, I see significantly increased TPU memory usage for SPMD training (roughly 2.5x), and in some settings it also crashes with an OOM. The newest nightly still hasn't solved this problem. I suspect it might be some change in torch that affects SPMD training in torch_xla.
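The kind of setup where I see the jump looks roughly like this (a minimal sketch, not my actual training code; the `nn.Linear` model, mesh shape, and axis names are placeholders). I compare the reported per-device memory after one training step between the two nightlies:

```python
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode

device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
# Placeholder 2-D mesh; my real mesh shape and axis names differ.
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

# Placeholder model standing in for the real one.
model = nn.Linear(8192, 8192).to(device)
xs.mark_sharding(model.weight, mesh, ('model', 'data'))  # shard the weight across the mesh

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
batch = torch.randn(64, 8192, device=device)
xs.mark_sharding(batch, mesh, ('data', None))  # shard the batch along the data axis

loss = model(batch).sum()
loss.backward()
optimizer.step()
xm.mark_step()
xm.wait_device_ops()

# Per-device memory as reported by the runtime; this is where I see
# the ~2.5x difference between the 11/06 and 11/07 nightlies.
print(xm.get_memory_info(device))
```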
Environment