young-geng / EasyLM

Large language models (LLMs) made easy. EasyLM is a one-stop solution for pre-training, finetuning, evaluating, and serving LLMs in JAX/Flax.
Apache License 2.0

A detailed question on the LLaMA training script #43

Closed · zhangzx-uiuc closed 1 year ago

zhangzx-uiuc commented 1 year ago

Thanks so much for implementing the JAX/Flax version of these foundation language models! This is really helpful for researchers working with TPU backends.

I am still a beginner with JAX/Flax, and I have a detailed question about the LLaMA training script.

When the train_state is created at https://github.com/young-geng/EasyLM/blob/e3e26579feae83352736207bac8446e8dee11840/EasyLM/models/llama/llama_train.py#L232 I wonder why you use the sharded version, sharded_create_trainstate_from_params. It seems that at https://github.com/young-geng/EasyLM/blob/e3e26579feae83352736207bac8446e8dee11840/EasyLM/models/llama/llama_train.py#L224 the sharded_fn is already passed into the checkpointer, so the output restored_params is already sharded across all TPU devices. Would there be any problems if I used create_trainstate_from_params instead of sharded_create_trainstate_from_params on line 232 (assuming that I am not using distributed training)?
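For reference, my understanding of the pattern is roughly the following (a simplified sketch, not the actual EasyLM code; the mesh, optimizer, and partition specs here are placeholders, and the real script derives per-parameter specs from the model's partition rules):

```python
# Simplified sketch of the two constructors being discussed (placeholder mesh,
# optimizer, and partition specs; not the actual EasyLM code).
import numpy as np
import jax
import optax
from flax.training.train_state import TrainState
from jax.sharding import Mesh, NamedSharding, PartitionSpec as PS

def create_trainstate_from_params(params):
    # Plain constructor: builds the optimizer state from the given params.
    return TrainState.create(apply_fn=None, params=params, tx=optax.adamw(1e-4))

# Hypothetical 1-D device mesh; EasyLM builds its mesh from config instead.
mesh = Mesh(np.asarray(jax.devices()), axis_names=('dp',))

# jit-compiled constructor with explicit output sharding (here simply
# replicated; the real script uses per-parameter partition specs).
sharded_create_trainstate_from_params = jax.jit(
    create_trainstate_from_params,
    out_shardings=NamedSharding(mesh, PS()),
)
```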

Thanks!

young-geng commented 1 year ago

Yes, you will need to use sharded_create_trainstate_from_params; otherwise your train_state will not be sharded correctly. Note that for multi-device training, it is always necessary to specify the sharding for any jit-compiled function.
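As a rough illustration of why the output sharding matters (a minimal, stand-alone sketch; make_state and the mesh axis name are made up for illustration and are not EasyLM code):

```python
# Minimal sketch: the output sharding of a jit-compiled constructor controls
# where the freshly created buffers (e.g. optimizer state) end up.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as PS

mesh = Mesh(np.asarray(jax.devices()), axis_names=('mp',))

def make_state(params):
    # Stand-in for create_trainstate_from_params: allocates a fresh optimizer
    # buffer in addition to the (possibly already sharded) params.
    return {'params': params, 'adam_mu': jnp.zeros_like(params)}

params = jnp.ones((len(jax.devices()), 1024))

# Without out_shardings, jit/XLA choose the layout of the new buffers; it is
# not guaranteed to match the partitioning used for the restored checkpoint.
state_default = jax.jit(make_state)(params)

# With out_shardings, every output leaf is created directly in the requested
# layout, sharded along the first axis across the mesh.
state_sharded = jax.jit(
    make_state,
    out_shardings=NamedSharding(mesh, PS('mp', None)),
)(params)

print(state_default['adam_mu'].sharding)
print(state_sharded['adam_mu'].sharding)
```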