I think it's not too bad to implement pipeline parallelism directly in Stacked. The basic idea is that we map the Layers axis of a Stacked to a (new) physical axis (called `stage` here and in the linked example), then we reshape our batch into microbatches and push them through the pipeline. Example implementation: https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/pipeline.py (which looks a lot like accumulate_gradients_sharded).

The biggest thing that's not clear to me is the partitioning of the (macro) batch itself. The easiest thing to do is replicate it across the stage axis, but I don't think that's ideal; worth taking a look at an existing pipeline-parallelism implementation first.
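To make the schedule concrete, here is a single-device sketch of the GPipe-style loop the lingvo file implements: stages compute in parallel each step while microbatches shift one stage down, with `vmap` over the stage axis standing in for the physical `stage` mesh axis. The names (`pipeline`, `stage_fn`) are placeholders, not Stacked's actual API, and this covers only the forward schedule, with no sharding, gradients, or batch partitioning.

```python
import jax
import jax.numpy as jnp

def pipeline(stage_params, stage_fn, microbatches):
    """GPipe-style schedule: M microbatches through S stacked stages.

    `stage_params` has a leading stage axis of size S; `stage_fn(p, x) -> x`
    is one stage's forward pass. In a real sharded run each vmapped slot
    would live on its own device along the `stage` mesh axis.
    """
    S = stage_params.shape[0]
    M = microbatches.shape[0]
    state = jnp.zeros((S,) + microbatches.shape[1:], microbatches.dtype)
    outputs = jnp.zeros_like(microbatches)

    def step(carry, t):
        state, outputs = carry
        # Feed the next microbatch into stage 0 (a dummy repeat once drained).
        state = state.at[0].set(microbatches[jnp.minimum(t, M - 1)])
        # All S stages compute in parallel on whatever sits in their slot.
        out = jax.vmap(stage_fn)(stage_params, state)
        # Stage S-1 finishes microbatch t - (S - 1) once the pipeline is full.
        mb = t - (S - 1)
        outputs = jnp.where(mb >= 0,
                            outputs.at[jnp.clip(mb, 0, M - 1)].set(out[-1]),
                            outputs)
        # Shift each stage's output down to the next stage; slot 0 gets
        # overwritten with a fresh microbatch at the next step.
        return (jnp.roll(out, 1, axis=0), outputs), None

    # M + S - 1 steps: M to feed everything in, S - 1 more to drain.
    (_, outputs), _ = jax.lax.scan(step, (state, outputs), jnp.arange(M + S - 1))
    return outputs
```

For example, with three scalar "layers" that multiply by 2, 3, and 4, the pipelined output equals the plain stacked forward pass (multiply by 24), just computed over `M + S - 1` steps instead of `S`.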