Closed: zhangzx-uiuc closed this issue 1 year ago
Yes, you will need to use `sharded_create_trainstate_from_params`; otherwise your `train_state` will not be sharded correctly. Note that for multi-device training, it is always necessary to specify the sharding for any jit-compiled function.
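For illustration, here is a minimal, self-contained sketch of the difference: the plain function just builds the state, while the jit-compiled wrapper pins an explicit output sharding so the parameters and optimizer state land in the layout the training step expects. The mesh, partition spec, toy params, and dummy `apply_fn` below are placeholders, not the actual EasyLM setup:

```python
# Minimal sketch (not the actual EasyLM code) of why the jit-compiled wrapper
# matters: it pins an explicit output sharding for the newly created state.
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from jax.sharding import Mesh, NamedSharding, PartitionSpec as PS

# Toy 1-D mesh over all local devices (placeholder for the real mesh).
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

def create_trainstate_from_params(params):
    # Plain version: device placement is whatever follows from the inputs.
    return TrainState.create(
        apply_fn=lambda variables, x: x,  # dummy apply_fn for illustration
        params=params,
        tx=optax.adamw(1e-4),
    )

# Sharded version: out_shardings forces params and optimizer state into an
# explicit layout. PS() here means fully replicated; a real model would use
# its train_state partition spec instead.
sharded_create_trainstate_from_params = jax.jit(
    create_trainstate_from_params,
    out_shardings=NamedSharding(mesh, PS()),
)

params = {"w": jnp.zeros((8, 8))}
train_state = sharded_create_trainstate_from_params(params)
print(jax.tree_util.tree_map(lambda x: x.sharding, train_state.params))
```

On a single device both versions end up with equivalent placements, but with multiple devices only the sharded wrapper guarantees the resulting `train_state` matches the partitioning that the sharded train step expects.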
Thanks so much for implementing the Jax/Flax version of these foundation language models! This is really helpful for researchers working on TPU backends.

I am still a beginner with Jax/Flax, and I have a detailed question about the LLaMA training script. When the `train_state` is created at https://github.com/young-geng/EasyLM/blob/e3e26579feae83352736207bac8446e8dee11840/EasyLM/models/llama/llama_train.py#L232, why do you use the sharded version of `create_trainstate_from_params`? It seems that in https://github.com/young-geng/EasyLM/blob/e3e26579feae83352736207bac8446e8dee11840/EasyLM/models/llama/llama_train.py#L224 the `sharded_fn` is already passed into the checkpointer, and the output `restored_params` is already sharded across all TPU devices. Will there be any problems if I use `create_trainstate_from_params` instead of `sharded_create_trainstate_from_params` in Line 232 (assuming that I am not using distributed training)? Thanks!