Hello Phil, thank you for the great implementation!
I see that you normalize lr by (1 + lamb) for a fair comparison, but I'd like to ask one question to make sure I haven't missed anything.
In grokfast we modify the gradients to be new_grad = grad + lamb * grok_exp_avg. At first glance, since we assume that ||new_grad|| ~= ||grad|| * (1 + lamb), normalizing the learning rate sounds quite reasonable. However, these modified gradients are then passed to AdamW, so it seems to me that this gradient-norm amplification should be canceled when the final parameter update is computed: a uniform scaling of the gradient scales the first-moment estimate and the square root of the second-moment estimate by the same factor, so it drops out of the ratio (neglecting eps).
Do we really need to normalize lr for a fair comparison?
We would probably need it for SGD, but I am a bit unsure about AdamW.
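To illustrate what I mean, here is a minimal sketch (not from the repo, just a toy check under my assumptions): it compares a single AdamW step computed from a gradient versus the same gradient scaled by (1 + lamb). The names `lamb` and `lr` mirror the discussion; the helper `adamw_first_step` is hypothetical.

```python
# Toy check: does uniformly scaling the gradient by (1 + lamb) change the
# AdamW step? Weight decay is disabled to isolate the gradient term.
import torch

torch.manual_seed(0)
lamb, lr = 2.0, 1e-3
grad = torch.randn(1000)

def adamw_first_step(g, lr, betas=(0.9, 0.999), eps=1e-8):
    """One AdamW step from zero optimizer state on a single parameter tensor."""
    p = torch.zeros_like(g, requires_grad=True)
    opt = torch.optim.AdamW([p], lr=lr, betas=betas, eps=eps, weight_decay=0.0)
    p.grad = g.clone()
    opt.step()
    return p.detach().clone()

step_plain  = adamw_first_step(grad, lr)               # grad
step_scaled = adamw_first_step((1 + lamb) * grad, lr)  # grad * (1 + lamb)

# The (1 + lamb) factor appears in both the first-moment numerator and the
# sqrt(second-moment) denominator, so the two updates agree up to eps.
print(torch.max((step_plain - step_scaled).abs()))  # effectively zero
```

If my reasoning above is right, the only place the scaling survives is through eps (and through interactions with weight decay / warm optimizer state), which seems negligible here, whereas for plain SGD the factor would pass straight through to the update.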