google / paxml

Pax is a Jax-based machine learning framework for training large scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry leading model flop utilization rates.
Apache License 2.0
456 stars 68 forks

[NVIDIA] Add LLaMA configs and scripts #68

Closed: ashors1 closed this 8 months ago

ashors1 commented 8 months ago

Adds LLaMA 7B, 13B and 70B configs and corresponding scripts to paxml/contrib/gpu/scripts_gpu

Depends on https://github.com/google/praxis/pull/50

zhangqiaorjc commented 8 months ago

Waiting for follow-up on https://github.com/google/praxis/pull/50

zhangqiaorjc commented 8 months ago

@ashors1 a few files are missing license header, do you mind adding them?

ashors1 commented 8 months ago

> @ashors1 a few files are missing license header, do you mind adding them?

Done, thanks!
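For reference, here is a minimal sketch of a local check for the license-header issue raised in this thread. It assumes the missing header is the standard Apache License 2.0 notice, since the repository is Apache-2.0 licensed. The helper names `has_apache_header` and `find_missing_headers`, the 20-line search window, and the `.py`/`.sh` suffixes are hypothetical choices made for illustration. They are not part of paxml's own tooling.

```python
from pathlib import Path

# Hypothetical marker: a line from the standard Apache License 2.0 notice.
APACHE_MARKER = "Licensed under the Apache License, Version 2.0"


def has_apache_header(text: str, max_lines: int = 20) -> bool:
    """Return True if the Apache 2.0 notice appears within the first max_lines lines."""
    head = "\n".join(text.splitlines()[:max_lines])
    return APACHE_MARKER in head


def find_missing_headers(root: str, suffixes=(".py", ".sh")) -> list:
    """Recursively list files under root (with the given suffixes) lacking the notice."""
    missing = []
    for path in sorted(Path(root).rglob("*")):
        if path.is_file() and path.suffix in suffixes:
            text = path.read_text(encoding="utf-8", errors="replace")
            if not has_apache_header(text):
                missing.append(str(path))
    return missing
```

Calling `find_missing_headers("paxml/contrib/gpu/scripts_gpu")` from the repository root would list any config or script files in that directory whose first lines lack the notice. Matching a single marker line rather than the full notice keeps the check tolerant of different comment prefixes (`#` in both Python and shell) and copyright lines.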

zhangqiaorjc commented 8 months ago

merging