Open · cgarciae opened this issue 6 months ago
Thanks for noticing this! These code paths are often hard to test because I don't have a personal multi-GPU setup for development. However, I will have access to 2 GPUs around 10 March and will examine this immediately. In the meantime, you are more than welcome to submit a fix from your end if that's convenient; I would very much appreciate it.
Hey, great job with nanodl!
I was just looking through the code and noticed that in LaMDA's Trainer the gradients are not being averaged across devices here:
https://github.com/HMUNACHI/nanodl/blob/18c7f8e3da3c0bbfe2df3638a5e87857ec84868d/nanodl/__src/models/lamda.py#L564-L565
I'm not sure if this is happening elsewhere too, but usually, to keep the replicated weights in sync, you apply a `jax.lax.pmean` over the gradients before passing them to `apply_gradients`, e.g.:
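Here is a minimal sketch of that pattern, not the actual nanodl code: it assumes `state` is a Flax `TrainState`, the loss is a hypothetical MSE, and the axis name `"devices"` is arbitrary as long as it matches the one given to `jax.pmap`:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical names for illustration: `state` is assumed to be a
# flax.training.train_state.TrainState replicated across devices.
@partial(jax.pmap, axis_name="devices")
def train_step(state, inputs, targets):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, inputs)
        return jnp.mean((logits - targets) ** 2)  # placeholder loss

    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    # Average the gradients (and, optionally, the loss) across devices so
    # every replica applies the same update and the weights stay in sync.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")

    state = state.apply_gradients(grads=grads)
    return state, loss
```

Without the `pmean`, each device updates its parameters with only its local gradients, so the replicas silently drift apart instead of training a single shared model.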