Circular part of #677, follow-up to the linear support in #691
Copying the description from #677:
This implementation closely follows the pax implementation here; I collaborated with the pax authors and obtained their permission.
This adds support for both circular and regular pipelining, but not loop fission. In addition, it supports multiple layers per stage.
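To make "circular" and "multiple layers per stage" concrete, here is a minimal sketch of one way layers could be assigned to stages. It assumes layers are grouped into chunks of `layers_per_stage` and dealt to stages round-robin, so each stage holds one chunk per repeat. The helper `circular_layer_assignment` and this exact layout are hypothetical illustrations, not necessarily how this PR indexes its weights. With a single repeat (`num_layers == num_stages * layers_per_stage`) the layout reduces to a regular, non-circular pipeline.

```python
def circular_layer_assignment(num_layers, num_stages, layers_per_stage):
    """Map each layer index to (repeat, stage, position within stage).

    Hypothetical layout: consecutive chunks of `layers_per_stage` layers are
    dealt to stages round-robin, so stage s holds chunk s of every repeat.
    Returns (num_repeats, {layer: (repeat, stage, position)}).
    """
    layers_per_repeat = num_stages * layers_per_stage
    if num_layers % layers_per_repeat:
        raise ValueError("num_layers must be divisible by num_stages * layers_per_stage")
    num_repeats = num_layers // layers_per_repeat
    assignment = {}
    for layer in range(num_layers):
        repeat, rem = divmod(layer, layers_per_repeat)
        stage, pos = divmod(rem, layers_per_stage)
        assignment[layer] = (repeat, stage, pos)
    return num_repeats, assignment


num_repeats, layout = circular_layer_assignment(num_layers=16, num_stages=4, layers_per_stage=2)
for stage in range(4):
    # Layers held by this stage, across all repeats.
    print(stage, [layer for layer, (r, s, p) in layout.items() if s == stage])
```

Here stage 0 holds layers 0, 1, 8 and 9, so every microbatch passes through each stage twice (`num_repeats == 2`).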
Currently there is no overlap between the main pipeline communication (activation forwarding between stages via collective permute) and the compute; adding that overlap is a work in progress with the compiler team. Over ICI this is not a big hit, but over DCN we cannot obtain great performance until we develop a way to overlap them. This implementation does roughly achieve the expected performance, i.e. the two major detrimental factors, bubble_size and exposed communication, account for the realized step time.
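The bubble cost can be checked with a small simulation of a circular schedule. This sketch assumes each stage processes one (microbatch, repeat) work item per iteration, that item k enters stage s at iteration k + s, and that `num_micro >= num_stages`, so the last stage's output from one repeat is ready before stage 0 needs it for the next repeat. The assertions check those data dependencies; the sketch does not model the collective permute itself, and `circular_schedule` is a hypothetical helper, not a function from this PR.

```python
from fractions import Fraction


def circular_schedule(num_stages, num_micro, num_repeat):
    """Simulate a hypothetical circular pipeline schedule.

    Work item k is (microbatch k % num_micro, repeat k // num_micro) and is
    computed by stage s at iteration k + s. Returns one row per iteration;
    row[s] is the (microbatch, repeat) stage s computes, or None if idle.
    """
    if num_micro < num_stages:
        raise ValueError("this sketch assumes num_micro >= num_stages")
    total_items = num_micro * num_repeat
    finished = {}  # (microbatch, repeat, stage) -> iteration it was computed
    schedule = []
    for t in range(total_items + num_stages - 1):
        row = []
        for s in range(num_stages):
            k = t - s
            if not 0 <= k < total_items:
                row.append(None)  # idle stage: part of the pipeline bubble
                continue
            mb, rep = k % num_micro, k // num_micro
            if s > 0:
                # Activation forwarded from the previous stage, same repeat.
                ready = finished.get((mb, rep, s - 1))
            elif rep > 0:
                # Circular wrap-around: stage 0 consumes the last stage's
                # output from the previous repeat of this microbatch.
                ready = finished.get((mb, rep - 1, num_stages - 1))
            else:
                ready = -1  # first repeat reads the microbatch from the input
            assert ready is not None and ready < t, "dependency not satisfied"
            finished[(mb, rep, s)] = t
            row.append((mb, rep))
        schedule.append(row)
    return schedule


schedule = circular_schedule(num_stages=4, num_micro=4, num_repeat=2)
idle = sum(cell is None for row in schedule for cell in row)
busy = sum(cell is not None for row in schedule for cell in row)
print(len(schedule), Fraction(idle, busy))
```

For 4 stages, 4 microbatches and 2 repeats this prints `11 3/8`: 11 iterations instead of 8, and the ratio of idle to busy stage-slots is (stages - 1) / (num_micro * num_repeat).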
Example runs over DCN (pipeline parallelism over DCN, FSDP within ICI):
4x v4-16
Model AI = 3 (since SwiGLU) * MLP (28k) * num_layers_per_stage (2) ≈ 170k FLOPs/byte
This estimate is pessimistic, since it ignores attention
Hardware AI = (275 TFLOP/s) / (1.6 GB/s DCN) ≈ 170k FLOPs/byte as well
However, this communication is not overlapped, so we expect each step to take 50% compute time and 50% communication time, i.e. twice as long as compute alone
Extra bubble = (stages - 1) / (num_micro * num_repeat) = 3/(4 * 2) = 3/8
So we expect (1 + 3/8)x as many steps, each 2x longer: (1 + 3/8) * 2 = 2.75x longer
In reality we see only a roughly 2x longer step time, with an MFU of ~30% (trace)
By comparison, a similar workload run on a single slice with pure FSDP reaches an MFU of around 60% (trace)
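The estimate for the 4x v4-16 run can be reproduced in a few lines. This is a sketch under stated assumptions: activations are bf16 (2 bytes per element), so forwarding one token between stages costs 2 * d_model bytes while each SwiGLU MLP layer costs 2 * 3 * d_model * d_mlp FLOPs per token, giving a model AI of 3 * d_mlp per layer. `d_mlp = 28_000` stands in for the "28k" MLP size, the 4 microbatches by 2 repeats split is read from the bubble formula above, and attention is ignored as in the text. The helper names are hypothetical.

```python
from fractions import Fraction


def model_arithmetic_intensity(d_mlp, layers_per_stage):
    """FLOPs per byte of activation forwarded between stages (MLP only).

    Assumes bf16 activations: one token costs 2 * d_model bytes to forward,
    and each SwiGLU layer costs 2 * 3 * d_model * d_mlp FLOPs per token, so
    d_model cancels. Ignoring attention makes this a pessimistic estimate.
    """
    return 3 * d_mlp * layers_per_stage


def hardware_arithmetic_intensity(peak_flops_per_s, dcn_bytes_per_s):
    """FLOPs a chip can do in the time it takes to send one byte over DCN."""
    return peak_flops_per_s / dcn_bytes_per_s


def bubble_fraction(num_stages, num_micro, num_repeat):
    """Extra iterations relative to a bubble-free pipeline."""
    return Fraction(num_stages - 1, num_micro * num_repeat)


model_ai = model_arithmetic_intensity(d_mlp=28_000, layers_per_stage=2)
hw_ai = hardware_arithmetic_intensity(275e12, 1.6e9)
print(f"model AI ~ {model_ai:,}, hardware AI ~ {hw_ai:,.0f}")

# Model AI ~= hardware AI and nothing is overlapped, so each iteration is
# treated as half compute, half communication: 2x the compute-only time.
comm_factor = 2
bubble = bubble_fraction(num_stages=4, num_micro=4, num_repeat=2)
expected_slowdown = (1 + bubble) * comm_factor
print(f"bubble = {bubble}, expected slowdown = {float(expected_slowdown)}x")
```

This prints a model AI of 168,000 against a hardware AI of 171,875 FLOPs/byte, and an expected slowdown of 2.75x.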
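3x v4-128
"Cheated" the sizes to get the best MFU by maxing out the number of layers and the MLP size
Expected extra step time = (1 + 2/24) * [(170k + 630k)/630k] ≈ 1.38x, with a bubble of (stages - 1)/(num_micro * num_repeat) = 2/24 and 630k the model AI at these sizes
In practice we see a smaller extra step time, with an MFU of ~40% (XPROF)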
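Both estimates follow the same formula, (1 + bubble) * (1 + hardware AI / model AI): with no overlap, exposed communication time relative to compute time is the ratio of the two intensities. The sketch below assumes, as in the expression above, that 630k is the model AI for the enlarged sizes and 170k the hardware AI; `expected_step_time_factor` is a hypothetical helper.

```python
from fractions import Fraction


def expected_step_time_factor(num_stages, total_microbatch_repeats, model_ai, hw_ai):
    """Expected step time relative to a bubble-free, comms-free run.

    (1 + bubble) accounts for the pipeline bubble; (1 + hw_ai / model_ai)
    accounts for exposed (non-overlapped) stage-to-stage communication.
    All arguments are integers so the result is an exact Fraction.
    Returns (bubble, factor).
    """
    bubble = Fraction(num_stages - 1, total_microbatch_repeats)
    exposed_comm = 1 + Fraction(hw_ai, model_ai)
    return bubble, (1 + bubble) * exposed_comm


# 3x v4-128: 3 stages, num_micro * num_repeat = 24.
bubble, factor = expected_step_time_factor(
    num_stages=3, total_microbatch_repeats=24, model_ai=630_000, hw_ai=170_000
)
print(f"bubble = {bubble}, expected factor ~ {float(factor):.3f}x")
```

This gives a bubble of 1/12 and a factor of about 1.376x. Plugging in the 4x v4-16 numbers (4 stages, 8 microbatch-repeats, equal model and hardware AI) gives back (1 + 3/8) * 2 = 2.75x.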
In addition to the correctness tests provided in this PR, I also ran a sanity-check convergence test on a v4-128 with ici_pipeline_parallelism=4 and got the same results as the convergence test (XPK logs).