google / paxml

Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
Apache License 2.0

[Feature Request] Need ZeRO-1/2 to work together with PP+TP+DP, which can sometimes be faster than FSDP #64

Open MoFHeka opened 9 months ago

MoFHeka commented 9 months ago

For example, on a single A100 machine, Llama 2 13B trains faster with TP 2 × DP 4 + ZeRO-1 than with FSDP.
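To see what each layout keeps on every GPU, here is a rough back-of-envelope sketch. It assumes an 8-GPU machine (so TP 2 × DP 4 covers all devices), mixed-precision Adam costed as in the ZeRO paper (2 bytes fp16 param + 2 bytes fp16 grad + 12 bytes fp32 optimizer state per parameter), and about 13e9 parameters. Activations, buffers, and pipeline stages are ignored. The helper names `zero1_with_tp` and `fsdp` are hypothetical and do not come from Pax.

```python
# Back-of-envelope per-GPU memory for model states only (params, grads,
# Adam states). Ignores activations, temporary buffers and fragmentation.
# Assumption: mixed-precision Adam accounting from the ZeRO paper.

PARAM_BYTES = 2   # fp16/bf16 parameter
GRAD_BYTES = 2    # fp16/bf16 gradient
OPT_BYTES = 12    # fp32 master weight + momentum + variance


def zero1_with_tp(num_params: float, tp: int, dp: int) -> float:
    """Hypothetical helper: ZeRO-1 on top of tensor parallelism.

    Params and grads are split only by TP (replicated across DP);
    optimizer states are additionally split across the DP group.
    """
    per_tp_shard = num_params / tp
    replicated = (PARAM_BYTES + GRAD_BYTES) * per_tp_shard
    optimizer = OPT_BYTES * per_tp_shard / dp
    return replicated + optimizer


def fsdp(num_params: float, world: int) -> float:
    """Hypothetical helper: FSDP / ZeRO-3.

    Params, grads and optimizer states are all split across every device.
    """
    return (PARAM_BYTES + GRAD_BYTES + OPT_BYTES) * num_params / world


if __name__ == "__main__":
    N = 13e9   # approximate Llama 2 13B parameter count
    GB = 1e9
    print(f"TP2 x DP4 + ZeRO-1: {zero1_with_tp(N, tp=2, dp=4) / GB:.1f} GB per GPU")
    print(f"FSDP over 8 GPUs:   {fsdp(N, world=8) / GB:.1f} GB per GPU")
```

Under these assumptions, the ZeRO-1 layout holds about 45.5 GB of model state per GPU and FSDP about 26.0 GB, so FSDP saves memory. Both still fit within an 80 GB A100 before activations. The trade-off is communication: the ZeRO paper puts ZeRO-1/2 at the same communication volume as plain data parallelism, while ZeRO-3 (FSDP-style) needs roughly 1.5× that because it all-gathers parameters in both the forward and backward passes. This is consistent with, but does not by itself prove, the speedup reported above.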