young-geng / EasyLM

Large language models (LLMs) made easy. EasyLM is a one-stop solution for pre-training, fine-tuning, evaluating, and serving LLMs in JAX/Flax.
Apache License 2.0

Use EasyLM to pre-train llama-7B using Nvidia GPU #78

Open zhpacer opened 1 year ago

zhpacer commented 1 year ago

Is there a training script for pre-training a LLaMA-7B model on GPUs such as the A100? The current examples are based on TPUs, and I don't know whether there are any differences. Thanks.

young-geng commented 1 year ago

I believe the configuration would be very similar, although you might need to tune the mesh dimensions according to your cluster configuration and network topology to get the best performance. Specifically, you'll want to add these options when training on GPUs in a multi-host environment:

python -m EasyLM.models.llama.llama_train \
    --jax_distributed.initialize_jax_distributed=True \
    --jax_distributed.coordinator_address=<your coordinator (process 0) address and port> \
    --jax_distributed.num_processes=<total number of processes (hosts)> \
    --jax_distributed.process_id=<current process id>
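
For example, a launch on two GPU nodes might look like the sketch below, run on every host with only the process id changed per host. This is a hedged illustration, not an official recipe: the coordinator address, port, and process count are placeholders, and the --mesh_dim value (data-parallel, FSDP, model-parallel axes, with -1 meaning "use all remaining devices") is the mesh option referred to above and should be tuned for your own cluster:

# On host 0; on host 1 set --jax_distributed.process_id=1
python -m EasyLM.models.llama.llama_train \
    --jax_distributed.initialize_jax_distributed=True \
    --jax_distributed.coordinator_address='10.0.0.1:1234' \
    --jax_distributed.num_processes=2 \
    --jax_distributed.process_id=0 \
    --mesh_dim='1,-1,1'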
zhpacer commented 1 year ago

Great, thanks! I will give it a try.