stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX

Pipeline Parallelism in `Stacked` #1

Open · dlwh opened this issue 1 year ago

dlwh commented 1 year ago

I think it's not too bad to implement pipeline parallelism directly in `Stacked`. The basic idea is that we map the `Layers` axis of a `Stacked` to a (new) physical mesh axis (called `stage` here and in the link below), then reshape the batch into microbatches and push them through the pipeline.
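Here's a rough sketch of that schedule in plain JAX (not Haliax's actual API; `stage_fn`, the shapes, and the zero-padding at the tail are all made up for illustration): per-layer weights are stacked along a leading stage axis, and a `lax.scan` shifts microbatch activations down the pipeline one stage per step.

```python
import jax
import jax.numpy as jnp

num_stages = 4   # size of the hypothetical "stage" axis
num_micro = 8    # microbatches per macrobatch
micro_size = 2
d_model = 16

def stage_fn(w, x):
    # One pipeline stage: a single dense layer standing in for a Stacked block.
    return jnp.tanh(x @ w)

key = jax.random.PRNGKey(0)
# Parameters stacked along the stage axis, as Stacked already does for Layers.
ws = 0.1 * jax.random.normal(key, (num_stages, d_model, d_model))
macrobatch = jax.random.normal(key, (num_micro, micro_size, d_model))

def pipeline(ws, microbatches):
    num_steps = num_micro + num_stages - 1  # steps to fill and drain the pipe

    def step(carry, t):
        state, outputs = carry
        # Feed the next microbatch into stage 0; feed zeros once input runs out.
        inp = jnp.where(t < num_micro,
                        microbatches[jnp.minimum(t, num_micro - 1)],
                        jnp.zeros((micro_size, d_model)))
        # Shift activations down one stage, then apply every stage's own
        # weights in parallel (vmap over the stage axis).
        shifted = jnp.concatenate([inp[None], state[:-1]], axis=0)
        new_state = jax.vmap(stage_fn)(ws, shifted)
        # Microbatch m exits the last stage at step m + num_stages - 1.
        out_idx = t - (num_stages - 1)
        updated = outputs.at[jnp.maximum(out_idx, 0)].set(new_state[-1])
        outputs = jnp.where(out_idx >= 0, updated, outputs)
        return (new_state, outputs), None

    state = jnp.zeros((num_stages, micro_size, d_model))
    outputs = jnp.zeros((num_micro, micro_size, d_model))
    (_, outputs), _ = jax.lax.scan(step, (state, outputs), jnp.arange(num_steps))
    return outputs

print(pipeline(ws, macrobatch).shape)  # (8, 2, 16)
```

On a real mesh each slot of the stage axis would live on its own device, so the vmapped stage applications actually overlap in time; on one device only the schedule itself is real.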

Example implementation: https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/pipeline.py (which looks a lot like `accumulate_gradients_sharded`)
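For what it's worth, the resemblance is basically that both are a `lax.scan` over microbatches. A generic sketch of the gradient-accumulation side (not Haliax's `accumulate_gradients_sharded` itself; `loss_fn` and the shapes are toy stand-ins):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    # Toy loss so the sketch is self-contained.
    return jnp.mean((x @ params) ** 2)

def accumulate_grads(params, microbatches):
    def step(acc, mb):
        # Compute this microbatch's gradient and add it into the running sum.
        g = jax.grad(loss_fn)(params, mb)
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, zeros, microbatches)
    n = microbatches.shape[0]
    return jax.tree_util.tree_map(lambda g: g / n, total)

params = jnp.ones((16, 1))
mbs = jnp.ones((8, 4, 16))  # 8 microbatches of shape (4, 16)
grads = accumulate_grads(params, mbs)  # same pytree structure as params
```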

The biggest thing that's not clear to me is how to partition the (macro) batch itself. The easiest option is to replicate it across the `stage` axis, but I don't think that's ideal. I should take a look at an existing implementation of pipeline parallelism.
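The two placements I can see, written out with `jax.sharding` directly (again, not Haliax's API; the mesh axis name `stage` and the shapes are assumptions, and this is only interesting with more than one device):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Put however many devices we have on a 1-D "stage" mesh axis.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("stage",))

macrobatch = jnp.zeros((32, 16))  # (batch, d_model)

# Option 1: replicate the macrobatch across the stage axis. Simple, but every
# device holds the full batch.
replicated = jax.device_put(macrobatch, NamedSharding(mesh, P(None, None)))

# Option 2: shard the batch dimension over the stage axis. Each device holds
# only 1/num_stages of the batch, but each microbatch then has to be routed
# to stage 0 (or passed along) as it enters the pipeline.
sharded = jax.device_put(macrobatch, NamedSharding(mesh, P("stage", None)))
```

Option 2 trades replication memory for extra communication to get each microbatch to the stage that needs it, which is presumably what an existing implementation would settle.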

dlwh commented 1 year ago

https://github.com/google/praxis/blob/main/praxis/layers/pipeline.py

(Googlers are telling me PP is a waste of time on TPU until you cross a node boundary.)