Closed: Dali660 closed this issue 1 year ago
Hi, for distributed training, are you supposed to gather gradients from all devices before applying the gradient update? https://github.com/young-geng/EasyLM/blob/eb24c9f176c5cc9899b0fa466cac79208b37c390/EasyLM/models/llama/llama_train.py#L127 It seems that a `grads = jax.lax.psum(grads, ...)` or `pmean` is missing here?
This is only necessary for `pmap`, where you explicitly map over an extra device axis and therefore have to combine the per-device gradients yourself with `jax.lax.pmean`/`psum`. With `pjit`, the computation is partitioned automatically according to the sharding annotations, and the compiler inserts the cross-device reductions for you, so no explicit gather is needed.
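A minimal sketch of the difference (not the EasyLM code): a toy linear model with a hypothetical learning rate of 0.1, using the current `jax.sharding` API in place of `jax.experimental.pjit` to illustrate the same idea.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def loss_fn(params, batch):
    return jnp.mean((batch["x"] @ params - batch["y"]) ** 2)

# --- pmap: the device axis is explicit, so gradients must be reduced manually. ---
def pmap_train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Without this pmean, each device would apply only its local gradient.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return params - 0.1 * grads

pmap_step = jax.pmap(pmap_train_step, axis_name="batch")

# --- jit with sharding annotations (the pjit-style path): the compiler
# propagates shardings and inserts the cross-device reduction itself. ---
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

@jax.jit
def sharded_train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)  # reduction over "dp" handled by XLA
    return params - 0.1 * grads

# Replicate params, shard the batch over the data-parallel axis.
params = jax.device_put(jnp.zeros((4,)), NamedSharding(mesh, PartitionSpec()))
batch = {
    "x": jax.device_put(jnp.ones((8, 4)), NamedSharding(mesh, PartitionSpec("dp", None))),
    "y": jax.device_put(jnp.zeros((8,)), NamedSharding(mesh, PartitionSpec("dp"))),
}
new_params = sharded_train_step(params, batch)
```

In the sharded-jit version the train step body contains no collective at all; the averaging over devices falls out of the input shardings, which is why the EasyLM train step has no `psum`/`pmean`.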