Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced, fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization (MFU) rates.
Apache License 2.0
[Feature Request] Support ZeRO-1/2 in combination with PP+TP+DP, which can sometimes be faster than FSDP #64
For example, on a single A100 machine, training Llama 2 13B with TP=2 and DP=4 (8 GPUs in total) plus ZeRO-1 is faster than with FSDP.
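To make the requested feature concrete: ZeRO-1 keeps the model parameters replicated across the data-parallel (DP) group but shards the optimizer state, so each DP rank stores Adam moments for only about 1/DP of the parameters. ZeRO-2 additionally shards the gradients. The sketch below is a hypothetical, single-process, pure-Python simulation of one ZeRO-1 step over DP ranks. It does not use any Pax, Praxis, or JAX API; the names `zero1_step`, `shard_bounds`, and `max_divergence` are illustrative only. The all-reduce and all-gather collectives are modeled as plain list operations, and TP/PP sharding of the parameters themselves is not modeled. It checks that sharding the optimizer state yields exactly the same parameters as an unsharded Adam update.

```python
import math


def adam_update(p, g, m, v, t, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """Plain Adam step for one scalar parameter; returns (new_p, new_m, new_v)."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)  # bias correction, t starts at 1
    v_hat = v / (1 - b2 ** t)
    return p - lr * m_hat / (math.sqrt(v_hat) + eps), m, v


def shard_bounds(n, dp, rank):
    """Contiguous [start, end) range of the n parameters whose state `rank` owns."""
    per = -(-n // dp)  # ceiling division
    start = min(rank * per, n)
    return start, min(start + per, n)


def init_sharded_state(n, dp):
    """Each DP rank allocates Adam moments only for its own shard (about n / dp)."""
    states = []
    for rank in range(dp):
        start, end = shard_bounds(n, dp, rank)
        states.append({"m": [0.0] * (end - start), "v": [0.0] * (end - start)})
    return states


def zero1_step(params, per_rank_grads, states, t):
    """One simulated ZeRO-1 step; `params` is replicated on every DP rank."""
    dp, n = len(per_rank_grads), len(params)
    # 1. "All-reduce": average gradients over DP ranks. ZeRO-2 would instead
    #    reduce-scatter, so each rank keeps only its own gradient shard.
    avg = [sum(g[i] for g in per_rank_grads) / dp for i in range(n)]
    # 2. Each rank runs Adam only on the slice whose optimizer state it owns.
    shards = []
    for rank in range(dp):
        start, end = shard_bounds(n, dp, rank)
        st = states[rank]
        shard = []
        for j, i in enumerate(range(start, end)):
            p, st["m"][j], st["v"][j] = adam_update(
                params[i], avg[i], st["m"][j], st["v"][j], t)
            shard.append(p)
        shards.append(shard)
    # 3. "All-gather": concatenate shards so every rank again has full params.
    return [p for shard in shards for p in shard]


def reference_step(params, per_rank_grads, m, v, t):
    """Unsharded baseline: every rank would hold the full m and v."""
    dp, n = len(per_rank_grads), len(params)
    avg = [sum(g[i] for g in per_rank_grads) / dp for i in range(n)]
    out = []
    for i in range(n):
        p, m[i], v[i] = adam_update(params[i], avg[i], m[i], v[i], t)
        out.append(p)
    return out


def max_divergence(n=10, dp=4, steps=3):
    """Run both variants on toy gradients; return the largest parameter gap."""
    init = [math.cos(i) for i in range(n)]
    p_zero, p_ref = list(init), list(init)
    states = init_sharded_state(n, dp)
    m, v = [0.0] * n, [0.0] * n
    for t in range(1, steps + 1):
        # Toy deterministic "local" gradients, different on each DP rank.
        grads = [[math.sin(0.3 * (i + 1) * (r + 1) + t) for i in range(n)]
                 for r in range(dp)]
        p_zero = zero1_step(p_zero, grads, states, t)
        p_ref = reference_step(p_ref, grads, m, v, t)
    return max(abs(a - b) for a, b in zip(p_zero, p_ref))


print(max_divergence())  # 0.0: sharding the optimizer state does not change the math
```

The default `dp=4` mirrors the DP degree in the example above. Only steps 1 and 3 involve communication; the optimizer-state memory saving comes from step 2, where each rank touches only its own shard of `m` and `v`.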