ironjr / grokfast

Official repository for the paper "Grokfast: Accelerated Grokking by Amplifying Slow Gradients"
https://arxiv.org/abs/2405.20233
MIT License
476 stars, 39 forks

mps compatibility #5

Closed (d0rc closed this 2 months ago)

d0rc commented 2 months ago

For some reason, when using device='mps', it produced 'nan' instead of zeros. This fixes it.

ironjr commented 2 months ago

Thanks for your contribution!