google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Circular Pipelining #701

Closed: gobbleturk closed this 2 months ago

gobbleturk commented 2 months ago

Circular-pipelining part of #677, a follow-up to the linear (non-circular) pipelining support in #691.

Copying the description from #677:

This implementation very closely follows the Pax implementation; we collaborated with, and obtained permission from, the Pax authors.

This adds support for both circular and regular pipelining, but not loop fission. In addition, it supports multiple layers per stage.
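To make the circular schedule concrete, here is a minimal sketch in plain Python. It assumes a round-robin layer layout: with `S` stages and `R` repeats, stage `s` owns layer blocks `s, s + S, s + 2S, ...` (each block holding `layers_per_stage` layers), and every microbatch travels through all stages `R` times. The names `circular_schedule` and `layers_for` are hypothetical and are not MaxText's API. The sketch also assumes `num_microbatches >= num_stages`, so a microbatch leaving the last stage can re-enter stage 0 without colliding with work already in flight.

```python
def circular_schedule(num_stages, num_microbatches, num_repeats):
    """Hypothetical helper: schedule[t][s] is the (microbatch, repeat) pair that
    stage s processes at loop iteration t, or None if the stage is idle (bubble).

    The result stage s produces at iteration t is forwarded to stage s + 1 for
    iteration t + 1 (in a real implementation, via a collective permute).
    """
    if num_microbatches < num_stages:
        raise ValueError("this sketch assumes num_microbatches >= num_stages")
    total_work = num_microbatches * num_repeats
    total_iters = total_work + num_stages - 1  # fill + steady state + drain
    schedule = []
    for t in range(total_iters):
        row = []
        for s in range(num_stages):
            processed = t - s  # microbatch-passes stage s has already started
            if 0 <= processed < total_work:
                row.append((processed % num_microbatches, processed // num_microbatches))
            else:
                row.append(None)  # pipeline bubble: stage s has nothing to do yet/anymore
        schedule.append(row)
    return schedule


def layers_for(stage, repeat, num_stages, layers_per_stage):
    """Hypothetical helper: layer indices stage `stage` runs on pass `repeat`,
    assuming round-robin blocks of `layers_per_stage` consecutive layers."""
    block = repeat * num_stages + stage
    return range(block * layers_per_stage, (block + 1) * layers_per_stage)


if __name__ == "__main__":
    for t, row in enumerate(circular_schedule(num_stages=4, num_microbatches=4, num_repeats=2)):
        print(t, row)
```

Setting `num_repeats=1` recovers the regular (non-circular) schedule. In both cases each stage idles for `num_stages - 1` iterations, which is the pipeline bubble; circular pipelining makes each of those idle iterations cheaper because a stage runs only `1/num_repeats` of its layers per iteration.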

Currently there is no overlap between the main pipeline communication (forwarding activations between stages via a collective permute) and the compute; this is a work in progress with the compiler team. Over ICI the exposed communication costs relatively little, but over DCN we cannot reach good performance until we have a way to overlap the two. Apart from that, this implementation roughly achieves the expected performance: the two main sources of overhead, the pipeline bubble (bubble_size) and the exposed communication, account for the realized step time.
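A back-of-the-envelope cost model illustrates the last claim. This is a simplified sketch under stated assumptions, not the measurement method used in this PR: it assumes every pipeline iteration costs one stage's compute for one microbatch plus one fully exposed collective permute (because compute and communication are not yet overlapped). Under that assumption the realized step time decomposes exactly into ideal compute, bubble, and exposed communication. The function name, the dataclass, and the per-iteration timing inputs are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class StepTimeBreakdown:
    ideal: float         # compute time if there were no bubble and no exposed communication
    bubble: float        # compute slots lost while the pipeline fills and drains
    exposed_comm: float  # collective-permute time not hidden behind compute

    @property
    def realized(self) -> float:
        return self.ideal + self.bubble + self.exposed_comm


def pipeline_step_time(num_stages, num_microbatches, num_repeats,
                       compute_per_iter, comm_per_iter):
    """Hypothetical model: num_repeats=1 is regular pipelining, >1 is circular."""
    iters = num_microbatches * num_repeats + num_stages - 1
    return StepTimeBreakdown(
        ideal=num_microbatches * num_repeats * compute_per_iter,
        bubble=(num_stages - 1) * compute_per_iter,
        # With no compute/communication overlap, every iteration pays for one permute.
        exposed_comm=iters * comm_per_iter,
    )


if __name__ == "__main__":
    b = pipeline_step_time(num_stages=4, num_microbatches=8, num_repeats=2,
                           compute_per_iter=1.0, comm_per_iter=0.25)
    print(b, b.realized)
```

With circular pipelining, `compute_per_iter` covers only `1/num_repeats` of a stage's layers, so the bubble term shrinks as repeats grow, while the exposed-communication term grows with the number of iterations. In this model that is why overlapping the permute with compute matters most on slow links such as DCN.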

Example runs over DCN (pipeline parallelism over DCN, FSDP within ICI):

4x v4-16

3x v4-128: "cheated" the model sizes to get the best MFU by maxing out the number of layers and the MLP size,