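HybridMesh is a utility that generates a mapping from accelerator device IDs to logical mesh coordinates. Today it doesn't support splitting a physical axis, so you can't use, for example, an ICI mesh shape of (64, 4) with a 16 x 16 TPU topology: the ICI axis of size 4 is smaller than the physical axis of size 16 (4 < 16) and would have to cover only part of it.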
Motivation
The motivation is to scale a model to multiple pods of Trillium TPUs. For example, we may want to use:
ICI mesh shape: (64, 4)
DCN mesh shape: (2, 1)
over two pods of Trillium 16x16 TPUs. That requires splitting one of the size-16 physical axes into 4x4 so that it can map to an ICI axis of size 4.
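To make the requested reshaping concrete, here is a minimal NumPy sketch of the index arithmetic only. It is not how HybridMesh or JAX assign devices: the helper `split_physical_axis`, the row-major device ID layout, and the way the two pods are stacked are all assumptions for illustration, and the sketch ignores real topology concerns such as wraparound links.

```python
import numpy as np

# Assumption: one pod's device IDs laid out row-major on a 16 x 16 physical grid.
POD_SHAPE = (16, 16)
physical = np.arange(POD_SHAPE[0] * POD_SHAPE[1]).reshape(POD_SHAPE)


def split_physical_axis(grid, axis, factors):
    """Hypothetical helper: reshape one physical axis into several smaller axes."""
    shape = list(grid.shape)
    if shape[axis] != int(np.prod(factors)):
        raise ValueError("factors must multiply to the size of the axis being split")
    new_shape = shape[:axis] + list(factors) + shape[axis + 1:]
    return grid.reshape(new_shape)


# Split the second physical axis: 16 -> 4 x 4, giving shape (16, 4, 4).
split = split_physical_axis(physical, axis=1, factors=(4, 4))

# Merge the first two axes (16 * 4 = 64) to obtain the ICI mesh shape (64, 4).
ici_mesh = split.reshape(64, 4)

# Assumption: the second pod's device IDs follow the first pod's (offset by 256).
# With DCN mesh shape (2, 1), pods are stacked along the first logical axis,
# so the combined logical mesh has shape (64 * 2, 4 * 1) = (128, 4).
pods = np.stack([ici_mesh, ici_mesh + physical.size])
hybrid_mesh = pods.reshape(2 * 64, 4)

print(ici_mesh.shape, hybrid_mesh.shape)
```

In this layout each row of `ici_mesh` holds four devices that are adjacent along the split physical axis, which is the property the (64, 4) ICI shape relies on.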
Pitch
To do this, we can probably reference what JAX does today, since parts of HybridMesh were copied from JAX.