ironjr / grokfast

Official repository for the paper "Grokfast: Accelerated Grokking by Amplifying Slow Gradients"
https://arxiv.org/abs/2405.20233
MIT License

Anyone already working on including this in transformers? #2

Open l4b4r4b4b4 opened 2 months ago

l4b4r4b4b4 commented 2 months ago

I'll try my best, but I thought I'd check whether anyone else wants to try this in the context of the transformers Trainer.

ironjr commented 2 months ago

Although it is a fairly small one, the main algorithmic data experiments are done with a 2-layer Transformer with 400k parameters. I'll leave this issue open so people can share their thoughts.
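
For reference, the method boils down to a gradient filter applied between `loss.backward()` and `optimizer.step()`: keep a running average of each parameter's gradient and add it back, scaled by a factor `lamb`, so the slow gradient component is amplified. Here is a minimal sketch of the EMA variant; the hyperparameter defaults are illustrative, so check `grokfast.py` in this repo for the reference implementation.

```python
# Sketch of the Grokfast EMA filter: maintain an exponential moving average (EMA)
# of each parameter's gradient and add it back, amplified by `lamb`, before the
# optimizer step. Defaults are illustrative; see grokfast.py for the reference code.
import torch.nn as nn


def gradfilter_ema_sketch(model: nn.Module, grads=None, alpha: float = 0.98, lamb: float = 2.0):
    # First call: seed the EMA state with the current gradients.
    if grads is None:
        grads = {n: p.grad.detach().clone()
                 for n, p in model.named_parameters() if p.requires_grad}
    for n, p in model.named_parameters():
        if p.requires_grad:
            grads[n] = grads[n] * alpha + p.grad.detach() * (1 - alpha)
            p.grad.add_(grads[n], alpha=lamb)  # boost the slow gradient component
    return grads


# Usage in a plain training loop:
#   grads = None
#   for batch in loader:
#       loss = model(**batch).loss
#       loss.backward()
#       grads = gradfilter_ema_sketch(model, grads)  # after backward, before step
#       optimizer.step()
#       optimizer.zero_grad()
```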

l4b4r4b4b4 commented 2 months ago

I worked the approach into a fork of the HF transformers Trainer. It runs without errors using the trl ORPO trainer and unsloth. I'll do some testing and report back over the coming week. Since I didn't follow a single best practice from the transformers library there's no PR yet, but for anyone who wants to try it out: https://github.com/l4b4r4b4b4/transformers/blob/main/src/transformers/trainer.py
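
Conceptually the hookup just applies the filter right after `training_step` has run backward. Below is a simplified sketch of that idea, not the actual fork code: the class name, keyword arguments, and defaults are illustrative, and gradient accumulation, DeepSpeed, and other wrappers need extra care (as written, the filter runs on every micro-batch).

```python
# Hypothetical sketch: wrap the Grokfast EMA filter around transformers.Trainer
# by filtering gradients right after training_step runs backward. Illustrative
# only; with gradient accumulation this applies the filter per micro-batch.
from transformers import Trainer
from grokfast import gradfilter_ema  # assumes grokfast.py from this repo is importable


class GrokfastTrainer(Trainer):
    def __init__(self, *args, grokfast_alpha=0.98, grokfast_lamb=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.grokfast_alpha = grokfast_alpha
        self.grokfast_lamb = grokfast_lamb
        self._grok_grads = None  # EMA state, keyed by parameter name

    def training_step(self, model, inputs, *args, **kwargs):
        # super().training_step() computes the loss and calls backward internally,
        # so gradients are populated when we get back here.
        loss = super().training_step(model, inputs, *args, **kwargs)
        self._grok_grads = gradfilter_ema(
            model,
            grads=self._grok_grads,
            alpha=self.grokfast_alpha,
            lamb=self.grokfast_lamb,
        )
        return loss
```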

phalexo commented 2 months ago

Is the expectation that all weights within a network are trainable, or can this be used for fine-tuning where only some layers are trainable, @ironjr?

I also tried to insert the code into the Trainer's inner training step, but I get a NoneType error.
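
My guess is that the filter hits parameters whose `.grad` is `None` (frozen layers, or trainable parameters that never receive a gradient), since the reference filter assumes every trainable parameter has a gradient. A guarded variant like the sketch below might sidestep it, though I haven't confirmed that's the actual cause.

```python
# Hedged guess at the NoneType error: skip parameters without gradients
# (frozen layers or unused heads during partial fine-tuning) and add new
# parameters to the EMA state lazily instead of assuming they all exist.
import torch.nn as nn


def gradfilter_ema_safe(model: nn.Module, grads=None, alpha: float = 0.98, lamb: float = 2.0):
    if grads is None:
        grads = {}
    for n, p in model.named_parameters():
        if not p.requires_grad or p.grad is None:
            continue  # nothing to filter for this parameter
        g = p.grad.detach()
        grads[n] = g.clone() if n not in grads else grads[n] * alpha + g * (1 - alpha)
        p.grad.add_(grads[n], alpha=lamb)
    return grads
```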

lucasjinreal commented 2 months ago

@l4b4r4b4b4 how's it going with the results?

HydrogenBombaklot commented 1 week ago

@l4b4r4b4b4 any update?

l4b4r4b4b4 commented 1 week ago

I had everything implemented in a transformers fork. I think @ehartford took it a bit further with https://github.com/cognitivecomputations/grokadamw.

ehartford commented 1 week ago

I got creative; my implementation is inspired by the paper rather than being a direct implementation of it.

phalexo commented 1 week ago

Has it been tried on larger models to assess the reduction in training time? I am planning to deploy Llama 3.1 70B Instruct for machine translation and am wondering whether fine-tuning it could benefit.

ehartford commented 1 week ago

With mine, so far I'm seeing only marginal divergence vs. fused AdamW.

An improvement but not an obvious slam dunk.

Maybe I need to improve the default grokking functions

image-8.png

phalexo commented 1 week ago

I am looking at the graphic, and the two loss lines seem to track each other; there is no divergence. The absolute difference is likely just an artefact of how the two are initialized.

ehartford commented 1 week ago

Well, I suppose you know best.

Here is how they started.

image