google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0
1.39k stars, 247 forks

How to implement 1F1B pipeline parallelism in Jax? #752

Open: MoFHeka opened this issue 2 weeks ago

MoFHeka commented 2 weeks ago

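I mean 1F1B, not GPipe: instead of running every microbatch's forward pass before any backward pass, each pipeline stage should interleave forward and backward passes of different microbatches.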
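To make the distinction concrete, here is a minimal pure-Python sketch of the per-stage operation order only. It assumes the non-interleaved (PipeDream-Flush style) 1F1B variant and unit-time forward and backward steps. The function names `one_f_one_b_schedule`, `gpipe_schedule`, `peak_in_flight` and `simulate` are hypothetical illustrations, not MaxText or JAX APIs. Mapping the schedule onto devices in JAX (for example, computing each stage's backward with `jax.vjp` and passing activations between stages with `jax.lax.ppermute`) is not shown.

```python
from typing import List, Set, Tuple

Op = Tuple[str, int]  # ("F" or "B", microbatch index); hypothetical representation


def gpipe_schedule(num_stages: int, num_microbatches: int) -> List[List[Op]]:
    """GPipe: every stage runs all forwards, then all backwards."""
    fwd = [("F", m) for m in range(num_microbatches)]
    bwd = [("B", m) for m in range(num_microbatches)]
    return [fwd + bwd for _ in range(num_stages)]


def one_f_one_b_schedule(num_stages: int, num_microbatches: int) -> List[List[Op]]:
    """1F1B: warm up with a few forwards, alternate one forward and one
    backward, then drain the outstanding backwards."""
    schedule = []
    for stage in range(num_stages):
        # Later stages need fewer warm-up forwards before their first backward.
        warmup = min(num_stages - stage - 1, num_microbatches)
        ops: List[Op] = [("F", m) for m in range(warmup)]
        next_fwd, next_bwd = warmup, 0
        # Steady state: one forward followed by one backward.
        while next_fwd < num_microbatches:
            ops.append(("F", next_fwd))
            next_fwd += 1
            ops.append(("B", next_bwd))
            next_bwd += 1
        # Cool-down: finish the backwards still outstanding.
        while next_bwd < num_microbatches:
            ops.append(("B", next_bwd))
            next_bwd += 1
        schedule.append(ops)
    return schedule


def peak_in_flight(ops: List[Op]) -> int:
    """Max number of microbatches whose activations a stage holds at once
    (forward done, backward not yet done)."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak


def simulate(schedule: List[List[Op]]) -> int:
    """Execute per-stage op lists with unit-time ops; return the step count.
    F(s, m) needs F(s-1, m); B(s, m) needs F(s, m) and B(s+1, m)."""
    num_stages = len(schedule)
    done: Set[Tuple[str, int, int]] = set()  # ops finished in earlier steps
    pos = [0] * num_stages
    steps = 0
    while any(pos[s] < len(schedule[s]) for s in range(num_stages)):
        finished_now = []
        for s in range(num_stages):
            if pos[s] == len(schedule[s]):
                continue
            kind, m = schedule[s][pos[s]]
            if kind == "F":
                ready = s == 0 or ("F", s - 1, m) in done
            else:
                ready = ("F", s, m) in done and (
                    s == num_stages - 1 or ("B", s + 1, m) in done)
            if ready:
                finished_now.append((kind, s, m))
                pos[s] += 1
        if not finished_now:
            raise RuntimeError("deadlock: schedule violates pipeline dependencies")
        done.update(finished_now)
        steps += 1
    return steps


if __name__ == "__main__":
    for s, ops in enumerate(one_f_one_b_schedule(4, 8)):
        print(f"stage {s}:", " ".join(f"{k}{m}" for k, m in ops),
              f"(peak in flight: {peak_in_flight(ops)})")
```

With 4 stages and 8 microbatches, stage 0 runs `F0 F1 F2 F3 B0 F4 B1 F5 B2 F6 B3 F7 B4 B5 B6 B7` and never holds activations for more than 4 microbatches, whereas under GPipe every stage holds all 8 before its first backward. Bounding that activation memory by the number of stages rather than the number of microbatches is the usual motivation for 1F1B.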