pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Speeding up computation while using SPMD on large TPU pod #7987

Open · dudulightricks opened 2 months ago

dudulightricks commented 2 months ago

❓ Questions and Help

When running on a vp-128 TPU pod, we see much lower performance than on the same pod without SPMD, even when sharding only along the batch dimension.
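For readers less familiar with the setup, "sharding only by batch dimension" means each device holds a contiguous slice of the global batch's rows while every other dimension stays whole. The sketch below is a minimal pure-Python model of that layout over a 1-D device mesh. The helper name `batch_shard_layout` is hypothetical, the device count and batch size are illustrative assumptions (the number of XLA devices on a pod slice depends on the TPU generation), and the code does not call torch_xla. In PyTorch/XLA the equivalent request is roughly `xs.mark_sharding(tensor, mesh, ('data', None))` on a mesh whose single axis is named `'data'`.

```python
# Hypothetical helper: a pure-Python model of how a global batch is split when a
# tensor is sharded only along its batch (first) dimension over a 1-D device mesh.
# It does not use torch_xla; it only computes which rows each device would own.
from typing import Dict, Tuple


def batch_shard_layout(global_batch: int, num_devices: int) -> Dict[int, Tuple[int, int]]:
    """Return {device_id: (start_row, end_row)} for an even batch-dim split.

    Assumes the global batch divides evenly across devices, the simplest case
    for batch-only (data-parallel style) sharding.
    """
    if global_batch % num_devices != 0:
        raise ValueError(
            f"global batch {global_batch} is not divisible by {num_devices} devices"
        )
    per_device = global_batch // num_devices
    # Device i owns rows [i * per_device, (i + 1) * per_device).
    return {
        device: (device * per_device, (device + 1) * per_device)
        for device in range(num_devices)
    }


if __name__ == "__main__":
    # Illustrative numbers only: 64 devices and a global batch of 512 -> 8 rows each.
    layout = batch_shard_layout(global_batch=512, num_devices=64)
    print(layout[0], layout[63])  # (0, 8) (504, 512)
```

The per-device slice is just `global_batch // num_devices`, so a fixed global batch spread over more devices leaves each device with fewer rows to process per step.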

Do you have any tips on how to improve performance? Are there SPMD arguments we should set, or things we need to keep in mind when using it? Anything would help, because right now performance is lower than without SPMD by a factor. @JackCaoG

JackCaoG commented 2 months ago

Do you have a profile (xplane file) you can share? It is hard to guess what's happening without looking at the profile.

giuliano-97 commented 2 months ago

@JackCaoG I've been trying to fine-tune Gemma-2 9B on v4/v5 pods with FSDP + SPMD using HF Transformers and PyTorch/XLA, and I also have the feeling that training is slow. Do you have any benchmarks for training LLMs with the same setup?

JackCaoG commented 2 months ago

Replied in the other thread.