salesforce / jaxformer

Minimal library to train LLMs on TPU in JAX with pjit().
BSD 3-Clause "New" or "Revised" License
270 stars 35 forks

Pipeline Parallelism #25

Closed by sh0416 1 year ago

sh0416 commented 1 year ago

I am wondering whether this codebase supports pipeline parallelism.

https://huggingface.co/docs/transformers/perf_train_gpu_many#dppp
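For readers unfamiliar with the term, here is a minimal pure-Python sketch of what a pipeline-parallel forward schedule looks like. It is not taken from jaxformer and uses no JAX API. The names `forward_schedule` and `idle_fraction` are hypothetical, and the sketch assumes a forward pass only, with every stage taking one equal-length tick per micro-batch.

```python
from typing import List, Optional


def forward_schedule(num_stages: int, num_microbatches: int) -> List[List[Optional[int]]]:
    """Return schedule[t][s]: the micro-batch index stage s works on at tick t, or None if idle."""
    total_ticks = num_stages + num_microbatches - 1
    schedule = []
    for t in range(total_ticks):
        row = []
        for s in range(num_stages):
            mb = t - s  # micro-batch m reaches stage s at tick m + s
            row.append(mb if 0 <= mb < num_microbatches else None)
        schedule.append(row)
    return schedule


def idle_fraction(schedule: List[List[Optional[int]]]) -> float:
    """Fraction of (tick, stage) slots in which a stage has nothing to do."""
    slots = [cell for row in schedule for cell in row]
    return sum(cell is None for cell in slots) / len(slots)


if __name__ == "__main__":
    sched = forward_schedule(num_stages=4, num_microbatches=8)
    for t, row in enumerate(sched):
        print(t, ["--" if mb is None else f"m{mb}" for mb in row])
    print(f"idle fraction: {idle_fraction(sched):.3f}")
```

Under these assumptions, splitting a batch into more micro-batches shrinks the share of idle slots, which is the main motivation for the technique. A real schedule would also have to account for the backward pass and for uneven stage costs.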

sh0416 commented 1 year ago

I think JAX doesn't need manual pipeline parallelism.