ironjr / grokfast

Official repository for the paper "Grokfast: Accelerated Grokking by Amplifying Slow Gradients"
https://arxiv.org/abs/2405.20233
MIT License

Exploding Gradients #9

Closed: DustinEwan closed this issue 1 week ago

DustinEwan commented 3 weeks ago

I keep running into an issue where the gradients (or rather, the grad norm) keep getting larger and larger until they eventually become 'inf':

{'loss': 11.6752, 'grad_norm': 3789384056832.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2064384}
{'loss': 11.4293, 'grad_norm': 5675928780800.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2113536}
{'loss': 11.3192, 'grad_norm': 8501688532992.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2162688}
{'loss': 11.2486, 'grad_norm': 12734252974080.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2211840}
{'loss': 11.111, 'grad_norm': 19073997996032.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2260992}
{'loss': 11.2352, 'grad_norm': 28569986138112.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2310144}
{'loss': 11.3782, 'grad_norm': 42793548382208.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2359296}
{'loss': 11.2332, 'grad_norm': 64098310029312.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2408448}
{'loss': 11.3905, 'grad_norm': 96009648603136.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2457600}
{'loss': 11.2784, 'grad_norm': 143808037650432.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2506752}
{'loss': 11.3188, 'grad_norm': 215402843996160.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2555904}
{'loss': 11.4465, 'grad_norm': 322641097392128.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2605056}
{'loss': 11.3301, 'grad_norm': 483267975315456.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2654208}
{'loss': 11.354, 'grad_norm': 723862782214144.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2703360}
{'loss': 11.2944, 'grad_norm': 1084237784547328.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2752512}
{'loss': 11.1959, 'grad_norm': 1624025381994496.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2801664}
{'loss': 11.3966, 'grad_norm': 2432546264580096.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2850816}
{'loss': 11.2565, 'grad_norm': 3643589066227712.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2899968}
{'loss': 11.3079, 'grad_norm': 5457548907905024.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2949120}
{'loss': 11.2257, 'grad_norm': 8174589242769408.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2998272}
{'loss': 11.084, 'grad_norm': 1.22443075158016e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3047424}
{'loss': 11.2356, 'grad_norm': 1.834013527166157e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3096576}
{'loss': 11.0967, 'grad_norm': 2.747076973900595e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3145728}
{'loss': 11.3159, 'grad_norm': 4.114708807077069e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3194880}
{'loss': 11.2537, 'grad_norm': 6.163215792734208e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3244032}
{'loss': 11.1508, 'grad_norm': 9.231571782257869e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3293184}
{'loss': 11.2386, 'grad_norm': 1.382750805253161e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3342336}
{'loss': 11.2199, 'grad_norm': 2.0711531456181043e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3391488}
{'loss': 11.2959, 'grad_norm': 3.102276524535972e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3440640}
{'loss': 11.3998, 'grad_norm': 4.6467440153985024e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3489792}
{'loss': 11.2736, 'grad_norm': 6.960125757368566e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3538944}

In gradfilter_ema, if I change:

p.grad.data = p.grad.data + grads[n] * lamb

to

p.grad.data = (p.grad.data + grads[n] * lamb) / (1 + lamb)

then the exploding gradients go away. However, I'm not sure of the implications of this change and wanted to reach out to hear your thoughts.
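For reference, here is a minimal sketch of the filter with this change applied. It follows the gradfilter_ema interface in this repository (an EMA state dict grads, plus alpha and lamb), but the function name gradfilter_ema_normalized, the exact structure, and the default values shown are just illustrative; only the final assignment differs from the original rule.

from typing import Dict, Optional

import torch
import torch.nn as nn


def gradfilter_ema_normalized(
    m: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    alpha: float = 0.98,
    lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
    # Initialize the EMA state from the current gradients on the first call.
    if grads is None:
        grads = {
            n: p.grad.data.detach()
            for n, p in m.named_parameters()
            if p.requires_grad and p.grad is not None
        }

    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            # Exponential moving average of the raw gradients (slow component).
            grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
            # Original rule: p.grad.data = p.grad.data + grads[n] * lamb
            # Modified rule: divide by (1 + lamb) so the result is a weighted
            # average of the fast and slow components instead of a scaled sum.
            p.grad.data = (p.grad.data + grads[n] * lamb) / (1 + lamb)

    return grads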

Thank you! This idea is so cool! Great work!

ironjr commented 3 weeks ago

Thanks for trying this out, and for the warm comment, too!

It seems that, depending on the underlying loss surface of the task, which in turn depends on the model and the dataset being used, adding the low-frequency gradients with a large amplitude may cause unstably large gradient updates. As you may know, the code change you made is equivalent to reducing the learning rate accordingly to compensate for this gradient explosion.
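To spell this out for a plain SGD step (for adaptive optimizers such as Adam the correspondence is only approximate, since the update is rescaled there anyway), write $g$ for the current gradient, $\hat{g}$ for the EMA component, $\eta$ for the learning rate, and $\lambda$ for lamb:

$$\theta \leftarrow \theta - \eta \, \frac{g + \lambda \hat{g}}{1 + \lambda} = \theta - \frac{\eta}{1 + \lambda} \left( g + \lambda \hat{g} \right),$$

i.e., the normalized filter behaves like the original filter run with the smaller learning rate $\eta / (1 + \lambda)$.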

The sequence of low-frequency parts grads[n] can be interpreted as gradients computed from a larger minibatch (note that this interpretation breaks down if the model parameters change too rapidly, since the older gradients accumulated in the filter were evaluated at a different set of parameters). Therefore, updating the gradients with p.grad.data = p.grad.data + grads[n] * lamb can be thought of as replacing the gradients with the sum of (1) the original, fast-varying, smaller-minibatch gradients and (2) the low-frequency, slow-varying, larger-minibatch gradients scaled by lamb. In other words, Grokfast can be regarded as a mixture of slow and fast gradients. This justifies your modification of the code, which takes a weighted average of those gradients instead of a scaled mixture.
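Written out, the weighted-average form decomposes as

$$\frac{g + \lambda \hat{g}}{1 + \lambda} = \frac{1}{1 + \lambda} \, g + \frac{\lambda}{1 + \lambda} \, \hat{g},$$

with mixing weights that sum to one, whereas the original rule $g + \lambda \hat{g}$ carries a total weight of $1 + \lambda$, i.e. a larger overall gradient magnitude.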

Please also note that this mixture-of-gradients interpretation has a pitfall: the distributions of the two components, p.grad.data and grads[n], can actually be very different. This can easily be checked by plotting the mean and standard deviation of each gradient component.
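A quick way to do that check (a sketch, assuming grads is the EMA dict returned by the filter and that this is called right after loss.backward(), before the filter modifies p.grad):

import torch


@torch.no_grad()
def gradient_component_stats(m, grads):
    # Compare the fast (current minibatch) and slow (EMA) gradient components
    # for each parameter; log or plot these periodically during training.
    stats = {}
    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None and n in grads:
            fast, slow = p.grad.data, grads[n]
            stats[n] = {
                'fast_mean': fast.mean().item(),
                'fast_std': fast.std().item(),
                'slow_mean': slow.mean().item(),
                'slow_std': slow.std().item(),
            }
    return stats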

Hope this helps!