pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Support splitting physical axis in HybridMesh #8381

Open tengyifei opened 1 week ago

tengyifei commented 1 week ago

🚀 Feature

HybridMesh is a utility that generates a mapping from accelerator device IDs to logical mesh coordinates. Today it doesn't support splitting a physical axis, so you can't use, e.g., an ICI mesh shape of (64, 4) with a 16 x 16 TPU topology, because the logical axis of size 4 is smaller than the physical axis of size 16.
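To illustrate what "splitting a physical axis" means here, below is a minimal NumPy sketch. It is not the HybridMesh API: `split_physical_axis` is a hypothetical helper, and it assumes device IDs are laid out in row-major order over the 16 x 16 topology, which real TPU device enumeration does not guarantee.

```python
import numpy as np


def split_physical_axis(physical_shape, axis, factors):
    """Hypothetical helper: split one physical axis into several smaller axes.

    Assumes device IDs are numbered in row-major order over `physical_shape`.
    """
    physical_shape = tuple(physical_shape)
    factors = tuple(factors)
    if int(np.prod(factors)) != physical_shape[axis]:
        raise ValueError(
            f"factors {factors} do not multiply to axis size {physical_shape[axis]}"
        )
    device_ids = np.arange(int(np.prod(physical_shape))).reshape(physical_shape)
    new_shape = physical_shape[:axis] + factors + physical_shape[axis + 1:]
    return device_ids.reshape(new_shape)


# Split the second axis of a 16 x 16 topology into 4 x 4: shape (16, 4, 4).
split = split_physical_axis((16, 16), axis=1, factors=(4, 4))

# Merge the untouched size-16 axis with the outer factor of 4 to get the
# logical ICI mesh shape (64, 4).
ici_mesh = split.reshape(64, 4)
print(ici_mesh.shape)
```

In this sketch each logical row holds 4 devices that are adjacent along the split physical axis, and the remaining factor of 4 is folded into the logical axis of size 64.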

Motivation

The motivation is to scale a model to multiple pods of Trillium TPUs. For example, we may want to use:

over two pods of Trillium 16x16 TPUs. That requires splitting one of the size-16 physical axes into 4x4 in order to map to an ICI axis of size 4.

Pitch

To do this, we can probably reference what JAX does these days, since parts of HybridMesh were copied from JAX.