young-geng / EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, finetuning, evaluating, and serving LLMs in JAX/Flax.
Apache License 2.0

Bug in llama training? #66

Closed. Dali660 closed this issue 1 year ago.

Dali660 commented 1 year ago

Hi, for distributed training, are you supposed to gather the grads from all devices before applying the gradient update? https://github.com/young-geng/EasyLM/blob/eb24c9f176c5cc9899b0fa466cac79208b37c390/EasyLM/models/llama/llama_train.py#L127 It seems a `grads = jax.lax.psum(grads)` or `pmean` is missing here?
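
For context, this is a minimal sketch (not EasyLM's actual code) of the `pmap`-style training step the question has in mind, where an explicit cross-device reduction really is required; the toy linear model, learning rate, and axis name are all hypothetical:

```python
import functools
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    preds = batch["x"] @ params                      # hypothetical linear model
    return jnp.mean((preds - batch["y"]) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, batch):
    # Each device computes grads on its own shard of the batch...
    grads = jax.grad(loss_fn)(params, batch)
    # ...so the grads must be explicitly averaged across devices before the update.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

# params must be replicated and batch given a leading per-device axis before calling.
```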

young-geng commented 1 year ago

An explicit psum/pmean is only necessary with pmap, where you explicitly map over an extra device axis. pjit partitions the computation automatically according to the sharding annotations, so the cross-device reduction is inserted for you and you don't need to do that.
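
For comparison, here is a minimal sketch of the sharding-annotation approach, written with the modern `jax.jit` + `NamedSharding` API (which subsumes `pjit`); the mesh axis name, toy model, and shapes are assumptions, not EasyLM's actual configuration. Because `batch` is sharded along the data-parallel axis and `params` are replicated, XLA inserts the cross-device gradient reduction itself:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as PS

def loss_fn(params, batch):
    preds = batch["x"] @ params                      # hypothetical linear model
    return jnp.mean((preds - batch["y"]) ** 2)

@jax.jit
def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # No psum/pmean here: the reduction across the data-parallel axis is
    # derived automatically from the input shardings declared below.
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

mesh = Mesh(np.asarray(jax.devices()), axis_names=("dp",))   # hypothetical 1-D mesh
replicated = NamedSharding(mesh, PS())                       # params on every device
data_parallel = NamedSharding(mesh, PS("dp", None))          # batch split across devices

params = jax.device_put(jnp.zeros((8, 1)), replicated)
batch = {
    "x": jax.device_put(jnp.ones((32, 8)), data_parallel),
    "y": jax.device_put(jnp.ones((32, 1)), data_parallel),
}
new_params = train_step(params, batch)
```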