lucidrains / grokfast-pytorch

Explorations into the proposal from the paper "Grokfast: Accelerated Grokking by Amplifying Slow Gradients"

The question about learning rate normalization #5

Open pkorobov opened 4 months ago

pkorobov commented 4 months ago

Hello Phil, thank you for the great implementation!

I see that you normalize the lr by (1 + lamb) for fairness of comparison, but I want to clarify one point to make sure I haven't missed anything. In Grokfast we modify the gradients to `new_grad = grad + lamb * grok_exp_avg`, and at first glance, since we expect `||new_grad|| ≈ (1 + lamb) * ||grad||`, dividing the lr by the same factor sounds quite reasonable. However, we then pass these modified gradients to AdamW, whose update is roughly `m_hat / (sqrt(v_hat) + eps)`, so it seems to me that this gradient norm amplification shows up in both the numerator and the denominator and (neglecting eps) cancels out of the final parameter update.
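To make the premise concrete, here is a minimal sketch (not the repo's actual code; `lamb`, `alpha`, and `grok_exp_avg` are just the names used above, with illustrative values) of the Grokfast-EMA filter applied to a constant gradient, showing the filtered norm approaching `(1 + lamb) * ||grad||`:

```python
import torch

lamb, alpha = 2.0, 0.98                  # hypothetical amplification factor and EMA decay
grad = torch.randn(10)                   # a fixed gradient direction
grok_exp_avg = torch.zeros_like(grad)    # slow-gradient EMA, kept per parameter

for _ in range(500):                     # let the EMA converge toward `grad`
    grok_exp_avg = alpha * grok_exp_avg + (1 - alpha) * grad

new_grad = grad + lamb * grok_exp_avg    # the modification described above
print(new_grad.norm() / grad.norm())     # ≈ 1 + lamb once the EMA has converged
```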

Do we really need to normalize lr for a fair comparison? We would probably need it for SGD, but I am a bit unsure about AdamW.
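As a quick numerical check of the cancellation argument, here is a sketch comparing one step of plain `torch.optim.AdamW` and `torch.optim.SGD` on a constant gradient versus the same gradient scaled by (1 + lamb); the `step_size` helper is just for illustration:

```python
import torch

def step_size(opt_cls, scale, **kw):
    # take one optimizer step with gradient `scale * g` and return the mean |Δparam|
    p = torch.nn.Parameter(torch.ones(5))
    opt = opt_cls([p], lr=1e-3, **kw)
    p.grad = scale * torch.full_like(p, 0.1)
    before = p.detach().clone()
    opt.step()
    return (p.detach() - before).abs().mean().item()

lamb = 2.0
# AdamW: the (1 + lamb) factor appears in both m_hat and sqrt(v_hat), so the step barely changes
print(step_size(torch.optim.AdamW, 1.0, weight_decay=0.0),
      step_size(torch.optim.AdamW, 1.0 + lamb, weight_decay=0.0))
# SGD: the step is linear in the gradient, so it grows by exactly (1 + lamb)
print(step_size(torch.optim.SGD, 1.0),
      step_size(torch.optim.SGD, 1.0 + lamb))
```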