AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

How to implement 1F1B pipeline parallelism in Jax? #752

Status: Open. MoFHeka opened this issue 3 months ago.

MoFHeka commented 3 months ago

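Not GPipe, which runs every forward pass before any backward pass. I mean a schedule that runs pipeline forward and backward passes at the same time, interleaving them across microbatches.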
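To make the distinction concrete, here is a minimal sketch in plain Python of the per-stage execution order for the classic non-interleaved 1F1B schedule, next to GPipe's. It is not JAX and not MaxText code. The function names `gpipe_schedule`, `one_f_one_b_schedule`, and `peak_in_flight` are hypothetical. The sketch assumes equal-cost forward and backward passes and ignores communication. It only shows the order in which each stage processes microbatches, and how many microbatches' activations a stage must keep alive.

```python
from typing import List, Tuple

# One scheduled operation: ("F" or "B", microbatch index).
Op = Tuple[str, int]


def gpipe_schedule(num_microbatches: int) -> List[Op]:
    """GPipe order: all forward passes, then all backward passes (same on every stage)."""
    forwards = [("F", m) for m in range(num_microbatches)]
    backwards = [("B", m) for m in range(num_microbatches)]
    return forwards + backwards


def one_f_one_b_schedule(num_stages: int, num_microbatches: int, stage: int) -> List[Op]:
    """Non-interleaved 1F1B order for one pipeline stage (hypothetical helper)."""
    if not 0 <= stage < num_stages:
        raise ValueError("stage must be in [0, num_stages)")
    # Earlier stages must run extra forwards before the first backward reaches them.
    num_warmup = min(num_stages - stage - 1, num_microbatches)
    ops: List[Op] = []
    next_fwd = next_bwd = 0
    # Warmup: forwards only.
    for _ in range(num_warmup):
        ops.append(("F", next_fwd))
        next_fwd += 1
    # Steady state: one forward followed by one backward.
    for _ in range(num_microbatches - num_warmup):
        ops.append(("F", next_fwd))
        next_fwd += 1
        ops.append(("B", next_bwd))
        next_bwd += 1
    # Cooldown: drain the remaining backwards.
    while next_bwd < num_microbatches:
        ops.append(("B", next_bwd))
        next_bwd += 1
    return ops


def peak_in_flight(ops: List[Op]) -> int:
    """Max number of microbatches whose activations are held (forward done, backward not yet)."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak


if __name__ == "__main__":
    P, M = 4, 8  # pipeline stages, microbatches
    for s in range(P):
        ops = one_f_one_b_schedule(P, M, s)
        print(f"stage {s}:", " ".join(f"{k}{m}" for k, m in ops),
              "| peak in flight:", peak_in_flight(ops))
    print("GPipe peak in flight:", peak_in_flight(gpipe_schedule(M)))
```

With 4 stages and 8 microbatches, stage 0 keeps activations for at most 4 microbatches and the last stage for 1, whereas GPipe keeps all 8. This non-interleaved variant has the same pipeline bubble as GPipe, so its benefit is bounded activation memory. A real JAX implementation would still have to map these per-stage orders onto devices and schedule the cross-stage transfers.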

gobbleturk commented 1 month ago

We are still looking into this on the open-source side! It is likely at least 6 months away.