google / paxml

Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced, fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOPs utilization (MFU) rates.
Apache License 2.0

[NVIDIA] Add Llama2 configs #56

Closed: ashors1 closed this 9 months ago

ashors1 commented 9 months ago

Adds Llama2 7B, 13B and 70B configs for benchmarking purposes. Requires Praxis PR #37.
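The PR description does not list the model shapes. The sketch below records the published Llama 2 architecture hyperparameters (layers, model width, attention heads, key/value heads, SwiGLU hidden width, 32k vocabulary) that configs for the three sizes would need to express. It is written in plain Python. The `LlamaShape` class, the `LLAMA2` table, and `param_count` are hypothetical names for illustration only. They are not Pax or Praxis APIs, and the PR's actual config classes may be structured differently. The parameter count assumes untied input and output embeddings and bias-free projections, as in the released Llama 2 checkpoints.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LlamaShape:
    """Hypothetical container for Llama 2 shapes (not a Pax/Praxis class)."""
    num_layers: int
    model_dims: int
    num_heads: int
    num_kv_heads: int  # fewer KV heads than query heads = grouped-query attention
    ffn_dims: int      # hidden width of the SwiGLU MLP
    vocab_size: int = 32000

    @property
    def head_dim(self) -> int:
        return self.model_dims // self.num_heads


# Published Llama 2 shapes; only the 70B model uses grouped-query attention.
LLAMA2 = {
    "7B": LlamaShape(num_layers=32, model_dims=4096, num_heads=32, num_kv_heads=32, ffn_dims=11008),
    "13B": LlamaShape(num_layers=40, model_dims=5120, num_heads=40, num_kv_heads=40, ffn_dims=13824),
    "70B": LlamaShape(num_layers=80, model_dims=8192, num_heads=64, num_kv_heads=8, ffn_dims=28672),
}


def param_count(s: LlamaShape) -> int:
    """Counts weights of a bias-free decoder with untied embeddings (an assumption)."""
    d = s.model_dims
    kv_dims = s.num_kv_heads * s.head_dim
    attn = 2 * d * d + 2 * d * kv_dims  # Q and O projections, plus K and V projections
    mlp = 3 * d * s.ffn_dims            # gate, up, and down projections (SwiGLU)
    norms = 2 * d                       # two RMSNorm scale vectors per layer
    per_layer = attn + mlp + norms
    embed = 2 * s.vocab_size * d        # input embedding table plus output head
    return s.num_layers * per_layer + embed + d  # + final RMSNorm


for name, shape in LLAMA2.items():
    print(f"{name}: {param_count(shape):,} parameters")
```

Running it prints 6,738,415,616, 13,015,864,320, and 68,976,648,192 parameters, matching the published sizes of the 7B, 13B, and 70B checkpoints. Grouped-query attention is why the 70B count stays below what 64 full key/value heads would give.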