google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Pipeline parallelism (Linear only) #699

Closed gobbleturk closed 2 months ago

gobbleturk commented 2 months ago

The full PR with circular support is https://github.com/google/maxtext/pull/677. We have decided to split that PR into two:

1. Without circular at first (this PR)
2. Adding circular

See https://github.com/google/maxtext/pull/677 for details on the tests run in addition to the ones added in the PR. Note that without circular pipelining the bubble is usually large (or microbatch_size is small), so it is hard to get good performance.
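To make the bubble tradeoff concrete, here is a rough sketch using the textbook idealized model, in which every stage takes the same time per microbatch and communication is free. It is not MaxText code: `bubble_fraction` and its parameters `num_stages`, `num_microbatches`, and `num_repeats` are hypothetical names chosen for illustration, with `num_repeats=1` standing for the linear schedule in this PR and `num_repeats>1` for a circular one.

```python
from fractions import Fraction


def bubble_fraction(num_stages: int, num_microbatches: int, num_repeats: int = 1) -> Fraction:
    """Idle fraction of an idealized pipeline schedule (illustrative model only).

    Assumes equal-cost stages and zero communication cost.
    num_repeats=1 models a linear pipeline; num_repeats>1 models a circular
    pipeline in which each stage holds num_repeats smaller layer chunks.
    """
    if min(num_stages, num_microbatches, num_repeats) < 1:
        raise ValueError("all arguments must be >= 1")
    # Each microbatch passes through every stage num_repeats times,
    # so each stage does this many time slots of useful work.
    work_slots = num_microbatches * num_repeats
    # Filling and draining the pipeline costs (num_stages - 1) idle slots per stage.
    idle_slots = num_stages - 1
    return Fraction(idle_slots, work_slots + idle_slots)


if __name__ == "__main__":
    # Linear pipeline, 4 stages, a fixed batch of 16 examples per pipeline.
    for microbatch_size in (4, 1):
        num_microbatches = 16 // microbatch_size
        frac = bubble_fraction(num_stages=4, num_microbatches=num_microbatches)
        print(f"linear, microbatch_size={microbatch_size}: bubble = {float(frac):.1%}")
    # Circular pipeline with 4 repeats keeps microbatch_size=4.
    frac = bubble_fraction(num_stages=4, num_microbatches=4, num_repeats=4)
    print(f"circular x4, microbatch_size=4: bubble = {float(frac):.1%}")
```

Under these assumptions, a linear pipeline with 4 stages and 4 microbatches idles for 3/7 of the time. Shrinking the bubble to 3/19 requires 16 microbatches, which for a fixed batch means a microbatch size of 1. In this model a circular schedule with 4 repeats reaches the same 3/19 while keeping the larger microbatch.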

jonb377 commented 2 months ago

@gobbleturk Do we need this one?

gobbleturk commented 2 months ago

Nope, closing