HMUNACHI / nanodl

A Jax-based library for designing and training transformer models from scratch.
MIT License

Gradient synchronization in data-parallel trainers #12

Open cgarciae opened 6 months ago

cgarciae commented 6 months ago

Hey, great job with nanodl!

I was just looking through the code and noticed that in LaMDA's Trainer the gradients are not being averaged across devices here:

https://github.com/HMUNACHI/nanodl/blob/18c7f8e3da3c0bbfe2df3638a5e87857ec84868d/nanodl/__src/models/lamda.py#L564-L565

Not sure if this is happening elsewhere, but usually, to keep the weights in sync, you apply a jax.lax.pmean over the gradients before passing them to apply_gradients, e.g.

grads = jax.lax.pmean(grads, axis_name='devices')
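
For context, here is a minimal sketch of where that pmean call typically sits in a pmapped data-parallel train step. This is not nanodl's actual trainer; `TrainState`, `loss_fn`, the batch fields, and the axis name `'devices'` are illustrative assumptions mirroring the snippet above.

```python
# Minimal data-parallel train step sketch (illustrative, not nanodl's code).
import jax
import optax
from flax.training import train_state


def train_step(state: train_state.TrainState, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['inputs'])
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch['labels']
        ).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Average gradients (and the reported loss) across devices so that
    # every replica applies the same update and the weights stay in sync.
    grads = jax.lax.pmean(grads, axis_name='devices')
    loss = jax.lax.pmean(loss, axis_name='devices')
    state = state.apply_gradients(grads=grads)
    return state, loss


# The axis_name must match the one passed to jax.pmap; state and batch are
# expected to already be replicated/sharded across the device axis.
p_train_step = jax.pmap(train_step, axis_name='devices')
```

Without the pmean, each replica applies only its local gradients, so the per-device copies of the parameters drift apart over training.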
HMUNACHI commented 6 months ago

Thanks for noticing this! It's often challenging to test these portions because I don't have a personal multi-GPU setup for development. However, I will have access to 2 GPUs around 10th March and will examine this immediately. You are more than welcome to make corrections from your end if convenient; in fact, I would very much appreciate that.